Add multi-sketch support for dynamic fusion

* Tensors are owned by the workload context instead of the workload sketch
  so that they can be shared by multiple sketches.
* Add an integration test for the multi-sketch case.

Resolves: COMPMID-6148
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I37d0de5ac103fb2a85020aa1c26e49eb304f47b7
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9706
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/dynamic_fusion/gpu/Integration.cpp b/tests/validation/dynamic_fusion/gpu/Integration.cpp
index 6a283f8..3a91577 100644
--- a/tests/validation/dynamic_fusion/gpu/Integration.cpp
+++ b/tests/validation/dynamic_fusion/gpu/Integration.cpp
@@ -23,24 +23,33 @@
  */
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/QuantizationInfo.h"
 #include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Types.h"
 #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
 #include "arm_compute/dynamic_fusion/sketch/attributes/CastAttributes.h"
 #include "arm_compute/dynamic_fusion/sketch/attributes/Conv2dAttributes.h"
+#include "arm_compute/dynamic_fusion/sketch/attributes/DepthwiseConv2dAttributes.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuAdd.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuCast.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuConv2d.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuDepthwiseConv2d.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
 
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSigmoid.h"
 #include "tests/CL/CLAccessor.h"
 #include "tests/framework/Macros.h"
 #include "tests/validation/Validation.h"
 #include "tests/validation/dynamic_fusion/Utils.h"
+#include "tests/validation/reference/ActivationLayer.h"
 #include "tests/validation/reference/ConvolutionLayer.h"
 #include "tests/validation/reference/DepthConvertLayer.h"
+#include "tests/validation/reference/DepthwiseConvolutionLayer.h"
 #include "tests/validation/reference/ElementwiseOperations.h"
 #include "tests/validation/reference/Permute.h"
+#include "tests/validation/reference/PixelWiseMultiplication.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 using namespace arm_compute::test::validation::utils;
@@ -69,17 +78,17 @@
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Fuse conv2d
     Conv2dAttributes conv2d_attr{};
-    TensorInfo       input_info  = sketch.create_tensor_info(t_input_shape, 1, data_type, data_layout);
-    TensorInfo       weight_info = sketch.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
+    TensorInfo       input_info  = context.create_tensor_info(t_input_shape, 1, data_type, data_layout);
+    TensorInfo       weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
 
     ITensorInfo *conv_out_info = GpuConv2d::create_op(sketch, &input_info, &weight_info, nullptr, conv2d_attr);
 
-    TensorInfo dst_info = sketch.create_tensor_info();
+    TensorInfo dst_info = context.create_tensor_info();
     GpuOutput::create_op(sketch, conv_out_info, &dst_info);
 
     // Configure runtime
@@ -156,15 +165,15 @@
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
-    TensorInfo in_0_info = sketch.create_tensor_info(t_input_shape, 1, data_type);
-    TensorInfo in_1_info = sketch.create_tensor_info(t_input_shape, 1, data_type);
-    TensorInfo in_2_info = sketch.create_tensor_info(t_input_shape, 1, data_type);
+    TensorInfo in_0_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    TensorInfo in_1_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    TensorInfo in_2_info = context.create_tensor_info(t_input_shape, 1, data_type);
 
-    TensorInfo out_0_info = sketch.create_tensor_info();
-    TensorInfo out_1_info = sketch.create_tensor_info();
+    TensorInfo out_0_info = context.create_tensor_info();
+    TensorInfo out_1_info = context.create_tensor_info();
 
     ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, &in_0_info, &in_1_info);
     GpuOutput::create_op(sketch, ans_0_info, &out_0_info);
@@ -253,15 +262,15 @@
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
-    TensorInfo in_0_info = sketch.create_tensor_info(t_input_shape, 1, data_type);
-    TensorInfo in_1_info = sketch.create_tensor_info(t_input_shape, 1, data_type);
-    TensorInfo in_2_info = sketch.create_tensor_info(t_input_shape, 1, data_type);
+    TensorInfo in_0_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    TensorInfo in_1_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    TensorInfo in_2_info = context.create_tensor_info(t_input_shape, 1, data_type);
 
-    TensorInfo out_0_info = sketch.create_tensor_info();
-    TensorInfo out_1_info = sketch.create_tensor_info();
+    TensorInfo out_0_info = context.create_tensor_info();
+    TensorInfo out_1_info = context.create_tensor_info();
 
     CastAttributes cast_0_attr;
     cast_0_attr.data_type(DataType::S32).convert_policy(ConvertPolicy::SATURATE);
@@ -348,6 +357,211 @@
     validate(CLAccessor(t_out_0), ref_t_out_0, tolerance_add_f32);
     validate(CLAccessor(t_out_1), ref_t_out_1, tolerance_cast_f32);
 }
+
+TEST_CASE(Conv2d_Sigmoid_DepthwiseConv2d_Mul, framework::DatasetMode::ALL)
+{
+    //   (tensor0)
+    //       |
+    // ======|============================================== Sketch 0
+    //       |     (tensor1)     +---- (tensor2)
+    //       |         |         |         |
+    // +-- input -- weights -- biases --+  |
+    // |                                |  |
+    // |            Conv2d              |  |
+    // |                                |  |
+    // +----------- output -------------+  |
+    //                |                    |
+    //          +-- input ---+             |
+    //          |            |             |
+    //          |  Sigmoid   |             |
+    //          |            |             |
+    //          +-- output --+             |
+    //                |                    |
+    //          +-- input ---+             |
+    //          |            |             |
+    //          |   Output   |             |
+    //          |            |             |
+    //          +-- output --+             |
+    //                |                    |
+    //            (tensor5)                |
+    //                |                    |
+    //       +--------+                    |
+    // ======|=============================|================ Sketch 1
+    //       |     (tensor3) (tensor4)     |
+    //       |         |         |         |
+    // +-- input -- weights -- biases --+  |
+    // |                                |  |
+    // |        DepthwiseConv2d         |  |
+    // |                                |  |
+    // +----------- output -------------+  |
+    //                |                    |
+    //             +--+   +----------------+
+    //             |      |
+    //        +-- lhs -- rhs --+
+    //        |                |
+    //        |    Multiply    |
+    //        |                |
+    //        +---- output ----+
+    //                |
+    //          +-- input ---+
+    //          |            |
+    //          |   Output   |
+    //          |            |
+    //          +-- output --+
+    //                |
+    //            (tensor6)
+
+    TensorShape conv2d_src_shape(10, 20, 30);
+    TensorShape conv2d_wei_shape(10, 3, 3, 5);
+    TensorShape conv2d_bia_shape(5);
+    TensorShape conv2d_dst_shape(5, 18, 28);
+    TensorShape dwc_wei_shape(5, 3, 3);
+    TensorShape dwc_bia_shape(5);
+    TensorShape dwc_dst_shape(5, 16, 26);
+
+    // Initialize the context.
+    CLScheduler::get().default_reinit();
+
+    auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+    GpuWorkloadContext context(&cl_compile_ctx);
+
+    auto tensor0_info = context.create_tensor_info(conv2d_src_shape, 1, DataType::F32, DataLayout::NHWC);
+
+    // Create the first sketch: conv2d + sigmoid + output.
+    GpuWorkloadSketch sketch0(&context);
+
+    Conv2dAttributes conv2d_attr;
+    auto tensor1_info = context.create_tensor_info(conv2d_wei_shape, 1, DataType::F32, DataLayout::NHWC);
+    auto tensor2_info = context.create_tensor_info(conv2d_bia_shape, 1, DataType::F32, DataLayout::NHWC);
+    ARM_COMPUTE_EXPECT(GpuConv2d::validate_op(sketch0, &tensor0_info, &tensor1_info, &tensor2_info, conv2d_attr), framework::LogLevel::ERRORS);
+    auto ans_info = GpuConv2d::create_op(sketch0, &tensor0_info, &tensor1_info, &tensor2_info, conv2d_attr);
+
+    ARM_COMPUTE_EXPECT(GpuSigmoid::validate_op(sketch0, ans_info), framework::LogLevel::ERRORS);
+    ans_info = GpuSigmoid::create_op(sketch0, ans_info);
+
+    DepthwiseConv2dAttributes dwc_attr;
+    auto tensor3_info = context.create_tensor_info(dwc_wei_shape, 1, DataType::F32, DataLayout::NHWC);
+    auto tensor4_info = context.create_tensor_info(dwc_bia_shape, 1, DataType::F32, DataLayout::NHWC);
+    ARM_COMPUTE_EXPECT(!GpuDepthwiseConv2d::validate_op(sketch0, ans_info, &tensor3_info, &tensor4_info, dwc_attr), framework::LogLevel::ERRORS);
+
+    auto tensor5_info = context.create_tensor_info();
+    ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch0, ans_info, &tensor5_info), framework::LogLevel::ERRORS);
+    GpuOutput::create_op(sketch0, ans_info, &tensor5_info);
+
+    // Create the first workload runtime.
+    ClWorkloadRuntime runtime0;
+    runtime0.configure(sketch0);
+
+    // Create the second sketch: dwc + mul + output.
+    GpuWorkloadSketch sketch1(&context);
+
+    ARM_COMPUTE_EXPECT(GpuDepthwiseConv2d::validate_op(sketch1, &tensor5_info, &tensor3_info, &tensor4_info, dwc_attr), framework::LogLevel::ERRORS);
+    ans_info = GpuDepthwiseConv2d::create_op(sketch1, &tensor5_info, &tensor3_info, &tensor4_info, dwc_attr);
+
+    ARM_COMPUTE_EXPECT(GpuMul::validate_op(sketch1, ans_info, &tensor2_info), framework::LogLevel::ERRORS);
+    ans_info = GpuMul::create_op(sketch1, ans_info, &tensor2_info);
+
+    auto tensor6_info = context.create_tensor_info();
+    ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch1, ans_info, &tensor6_info), framework::LogLevel::ERRORS);
+    GpuOutput::create_op(sketch1, ans_info, &tensor6_info);
+
+    // Create the second workload runtime.
+    ClWorkloadRuntime runtime1;
+    runtime1.configure(sketch1);
+
+    // Create the user tensors.
+    CLTensor tensor0;
+    CLTensor tensor1;
+    CLTensor tensor2;
+    CLTensor tensor3;
+    CLTensor tensor4;
+    CLTensor tensor5;
+    CLTensor tensor6;
+
+    tensor0.allocator()->init(tensor0_info);
+    tensor1.allocator()->init(tensor1_info);
+    tensor2.allocator()->init(tensor2_info);
+    tensor3.allocator()->init(tensor3_info);
+    tensor4.allocator()->init(tensor4_info);
+    tensor5.allocator()->init(tensor5_info);
+    tensor6.allocator()->init(tensor6_info);
+
+    tensor0.allocator()->allocate();
+    tensor1.allocator()->allocate();
+    tensor2.allocator()->allocate();
+    tensor3.allocator()->allocate();
+    tensor4.allocator()->allocate();
+    tensor5.allocator()->allocate();
+    tensor6.allocator()->allocate();
+
+    // Allocate the auxiliary tensors.
+    for(auto &data : runtime0.get_auxiliary_tensors())
+    {
+        auto tensor = std::get<0>(data);
+        auto &tensor_info = std::get<1>(data);
+        auto mem_req = std::get<2>(data);
+
+        tensor->allocator()->init(tensor_info, mem_req.alignment);
+        tensor->allocator()->allocate();
+    }
+
+    for(auto &data : runtime1.get_auxiliary_tensors())
+    {
+        auto tensor = std::get<0>(data);
+        auto &tensor_info = std::get<1>(data);
+        auto mem_req = std::get<2>(data);
+
+        tensor->allocator()->init(tensor_info, mem_req.alignment);
+        tensor->allocator()->allocate();
+    }
+
+    // Fill the input tensors with random data.
+    fill<float>(CLAccessor(tensor0), 0, library.get());
+    fill<float>(CLAccessor(tensor1), 1, library.get());
+    fill<float>(CLAccessor(tensor2), 2, library.get());
+    fill<float>(CLAccessor(tensor3), 3, library.get());
+    fill<float>(CLAccessor(tensor4), 4, library.get());
+
+    // Run each runtime.
+    runtime0.run({ &tensor0, &tensor1, &tensor2, &tensor5 });
+    runtime1.run({ &tensor5, &tensor3, &tensor4, &tensor2, &tensor6 });
+
+    // Compute the reference result.
+    SimpleTensor<float> ref_conv2d_src(conv2d_src_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+    SimpleTensor<float> ref_conv2d_wei(conv2d_wei_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+    SimpleTensor<float> ref_conv2d_bia(conv2d_bia_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+    SimpleTensor<float> ref_dwc_wei(dwc_wei_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+    SimpleTensor<float> ref_dwc_bia(dwc_bia_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
+
+    fill<float>(ref_conv2d_src, 0, library.get());
+    fill<float>(ref_conv2d_wei, 1, library.get());
+    fill<float>(ref_conv2d_bia, 2, library.get());
+    fill<float>(ref_dwc_wei, 3, library.get());
+    fill<float>(ref_dwc_bia, 4, library.get());
+
+    PermutationVector nhwc_to_nchw(1, 2, 0);
+
+    auto conv2d_dst_shape_nchw = conv2d_dst_shape;
+    permute(conv2d_dst_shape_nchw, nhwc_to_nchw);
+    const auto ref_conv2d_src_nchw = reference::permute(ref_conv2d_src, nhwc_to_nchw);
+    const auto ref_conv2d_wei_nchw = reference::permute(ref_conv2d_wei, nhwc_to_nchw);
+    const auto ref_conv2d_bia_nchw = reference::permute(ref_conv2d_bia, nhwc_to_nchw);
+    const auto ref_conv2d_dst_nchw = reference::convolution_layer(ref_conv2d_src_nchw, ref_conv2d_wei_nchw, ref_conv2d_bia_nchw, conv2d_dst_shape_nchw, PadStrideInfo());
+
+    const auto ref_sigmoid_dst_nchw = reference::activation_layer(ref_conv2d_dst_nchw, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+
+    auto dwc_dst_shape_nchw = dwc_dst_shape;
+    permute(dwc_dst_shape_nchw, nhwc_to_nchw);
+    const auto ref_dwc_wei_nchw = reference::permute(ref_dwc_wei, nhwc_to_nchw);
+    const auto ref_dwc_bia_nchw = reference::permute(ref_dwc_bia, nhwc_to_nchw);
+    const auto ref_dwc_dst_nchw = reference::depthwise_convolution(ref_sigmoid_dst_nchw, ref_dwc_wei_nchw, ref_dwc_bia_nchw, dwc_dst_shape_nchw, PadStrideInfo(), 1);
+
+    const auto ref_mul_dst_nchw = reference::pixel_wise_multiplication<float, float, float>(ref_dwc_dst_nchw, ref_conv2d_bia_nchw, 1.0, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP, DataType::F32);
+
+    constexpr RelativeTolerance<float> tolerance(0.001f);
+    validate(CLAccessor(tensor6), ref_mul_dst_nchw, tolerance);
+}
+
 TEST_SUITE(Invalid_Fusion_Should_Fail)
 TEST_CASE(Multiple_Complex_Ops_0, framework::DatasetMode::ALL)
 {
@@ -368,12 +582,12 @@
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Create tensor infos
-    TensorInfo   input_info  = sketch.create_tensor_info(t_input_shape, 1, data_type, data_layout);
-    TensorInfo   weight_info = sketch.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
+    TensorInfo   input_info  = context.create_tensor_info(t_input_shape, 1, data_type, data_layout);
+    TensorInfo   weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
     ITensorInfo *dst_info;
 
     // Fuse conv2d into the workload
@@ -386,7 +600,7 @@
     }
 
     // Create tensor infos
-    TensorInfo weight_info_2 = sketch.create_tensor_info(t_weight_info);
+    TensorInfo weight_info_2 = context.create_tensor_info(t_weight_info);
 
     // Fuse conv2d into the workload
     {
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp
index 0034b0f..d9a3d95 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp
@@ -87,12 +87,12 @@
 {
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Validate Elementwise Add
-    auto          lhs_info         = sketch.create_tensor_info(input1_info);
-    auto          rhs_info         = sketch.create_tensor_info(input2_info);
+    auto          lhs_info         = context.create_tensor_info(input1_info);
+    auto          rhs_info         = context.create_tensor_info(input2_info);
 
     bool res = bool(GpuAdd::validate_op(sketch, &lhs_info, &rhs_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp
index 177c02c..dc46dd5 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp
@@ -69,11 +69,11 @@
 {
     // Create a new workload sketch
     CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    GpuWorkloadContext context{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Fuse Clamp
-    const TensorInfo src_info = sketch.create_tensor_info(input_info);
+    const TensorInfo src_info = context.create_tensor_info(input_info);
 
     ClampAttributes attributes {};
     attributes.min_val(min_val)
diff --git a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
index b6331d7..7ab7c8a 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
@@ -242,12 +242,12 @@
                 input_info, weights_info, biases_info, padding, stride, depth_multiplier, dilation, expected)
 {
     CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
-    const TensorInfo sketch_input_info   = sketch.create_tensor_info(input_info);
-    const TensorInfo sketch_weights_info = sketch.create_tensor_info(weights_info);
-    const TensorInfo sketch_biases_info  = sketch.create_tensor_info(biases_info);
+    const TensorInfo sketch_input_info   = context.create_tensor_info(input_info);
+    const TensorInfo sketch_weights_info = context.create_tensor_info(weights_info);
+    const TensorInfo sketch_biases_info  = context.create_tensor_info(biases_info);
 
     DepthwiseConv2dAttributes attributes {};
     attributes.pad(padding)
diff --git a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
index cccad18..f27a179 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
@@ -157,12 +157,12 @@
                input_info, weights_info, biases_info, conv2d_attrs, expected)
 {
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
-    const TensorInfo sketch_input_info   = sketch.create_tensor_info(input_info);
-    const TensorInfo sketch_weights_info = sketch.create_tensor_info(weights_info);
-    const TensorInfo sketch_biases_info  = sketch.create_tensor_info(biases_info);
+    const TensorInfo sketch_input_info   = context.create_tensor_info(input_info);
+    const TensorInfo sketch_weights_info = context.create_tensor_info(weights_info);
+    const TensorInfo sketch_biases_info  = context.create_tensor_info(biases_info);
     bool is_valid = bool(GpuConv2d::validate_op(sketch, &sketch_input_info, &sketch_weights_info, &sketch_biases_info, conv2d_attrs));
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp
index a9e8f9c..2da2b9e 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp
@@ -102,12 +102,12 @@
 {
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Validate Elementwise Mul
-    auto          lhs_info         = sketch.create_tensor_info(input1_info);
-    auto          rhs_info         = sketch.create_tensor_info(input2_info);
+    auto          lhs_info         = context.create_tensor_info(input1_info);
+    auto          rhs_info         = context.create_tensor_info(input2_info);
 
     bool res = bool(GpuMul::validate_op(sketch, &lhs_info, &rhs_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
index a7772ae..f4478db 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
@@ -101,15 +101,15 @@
 {
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Declare GpuPool2d settings
     const GpuPool2dSettings &settings = GpuPool2dSettings().mixed_precision(false);
 
     // Validate Pool2d Configuration
-    auto                   src_info    = sketch.create_tensor_info(input_info);
-    auto                   dst_info    = sketch.create_tensor_info(output_info);
+    auto                   src_info    = context.create_tensor_info(input_info);
+    auto                   dst_info    = context.create_tensor_info(output_info);
     bool                   res         = bool(GpuPool2d::validate_op(sketch, &src_info, &dst_info, pool2d_attr, settings));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
index 6d88be4..bdaa1be 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
@@ -53,13 +53,13 @@
 {
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Create sketch tensors
     TensorShape input_shape = input_info.tensor_shape();
     ARM_COMPUTE_UNUSED(input_shape);
-    TensorInfo src_info = sketch.create_tensor_info(input_info);
+    TensorInfo src_info = context.create_tensor_info(input_info);
 
     ReshapeAttributes attributes;
     attributes.shape(output_shape);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp
index 696be54..5f99cd6 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp
@@ -95,10 +95,10 @@
     const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, default_data_layout };
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &gpu_ctx };
+    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch  sketch{ &context };
 
-    const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info);
+    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
 
     // nullptr is given as input
     Status status = GpuResize::validate_op(sketch, nullptr, ResizeAttributes());
@@ -135,10 +135,10 @@
         const TensorInfo input_info = TensorInfo{ default_input_shape, 1, kv.first, default_data_layout };
 
         CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        GpuWorkloadContext gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch  sketch{ &gpu_ctx };
+        GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
+        GpuWorkloadSketch  sketch{ &context };
 
-        const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info);
+        const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
 
         ResizeAttributes attributes;
         attributes.output_width(default_output_shape[0]); // shape is not important unless it's empty
@@ -157,10 +157,10 @@
     const TensorInfo output_info = TensorInfo{ default_output_shape, 1, non_default_data_type, default_data_layout };
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &gpu_ctx };
+    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch  sketch{ &context };
 
-    const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info);
+    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
 
     Status status = GpuResize::validate_op(sketch, &sketch_input_info, ResizeAttributes());
     ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
@@ -177,10 +177,10 @@
     const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, default_data_layout };
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &gpu_ctx };
+    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch  sketch{ &context };
 
-    const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info);
+    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
 
     ResizeAttributes attributes{};
     attributes.interpolation_policy(interpolation_policy)
@@ -198,10 +198,10 @@
     constexpr auto   interpolation_policy = InterpolationPolicy::AREA;
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &gpu_ctx };
+    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch  sketch{ &context };
 
-    const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info);
+    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
 
     ResizeAttributes attributes{};
     attributes.interpolation_policy(interpolation_policy);
@@ -217,10 +217,10 @@
     constexpr auto   interpolation_policy = InterpolationPolicy::BILINEAR;
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &gpu_ctx };
+    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch  sketch{ &context };
 
-    const TensorInfo sketch_input_info = sketch.create_tensor_info(input_info);
+    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
 
     ResizeAttributes attributes{};
     attributes.interpolation_policy(interpolation_policy);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp
index aace23e..5fd1180 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp
@@ -61,11 +61,11 @@
 {
     // Create a new workload sketch
     CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    GpuWorkloadContext context{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Fuse sigmoid
-    const TensorInfo src_info = sketch.create_tensor_info(input_info);
+    const TensorInfo src_info = context.create_tensor_info(input_info);
 
     const bool res = static_cast<bool>(GpuSigmoid::validate_op(sketch, &src_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
index d09454e..e8314d7 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
@@ -104,13 +104,13 @@
 {
     // Create a new workload sketch
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &gpu_ctx };
+    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch  sketch{ &context };
 
     SoftmaxAttributes softmax_attr{};
     softmax_attr.axis(axis).beta(beta).is_log_softmax(false);
-    TensorInfo src_info  = sketch.create_tensor_info(input_info);
-    TensorInfo dst_info = sketch.create_tensor_info(output_info);
+    TensorInfo src_info  = context.create_tensor_info(input_info);
+    TensorInfo dst_info = context.create_tensor_info(output_info);
     const bool res = static_cast<bool>(GpuSoftmax::validate_op(sketch, &src_info, &dst_info, softmax_attr));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp
index 977e011..0bb05c2 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp
@@ -89,12 +89,12 @@
 {
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              gpu_ctx        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Validate Elementwise Sub
-    auto          lhs_info         = sketch.create_tensor_info(input1_info);
-    auto          rhs_info         = sketch.create_tensor_info(input2_info);
+    auto          lhs_info         = context.create_tensor_info(input1_info);
+    auto          rhs_info         = context.create_tensor_info(input2_info);
 
     bool res = bool(GpuSub::validate_op(sketch, &lhs_info, &rhs_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp
index 183cd07..00c92fb 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp
@@ -61,11 +61,11 @@
 {
     // Create a new workload sketch
     CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext gpu_ctx{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &gpu_ctx };
+    GpuWorkloadContext context{ &cl_compile_ctx };
+    GpuWorkloadSketch sketch{ &context };
 
     // Fuse tanh
-    const TensorInfo src_info = sketch.create_tensor_info(input_info);
+    const TensorInfo src_info = context.create_tensor_info(input_info);
 
     const bool res = static_cast<bool>(GpuTanh::validate_op(sketch, &src_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);