Make GpuWorkloadContext own all tensor info objects

* The tensor info objects created by calling create_tensor_info
  are now solely owned by the context object. The user only receives
  pointers to those objects (see the first sketch below).
  - Internally, pointers to tensor info objects are used in various
    places, so it is safer for dynamic fusion to manage these objects
    directly rather than relying on the user.
  - The validation tests are updated to use the modified API.
* Make various changes to the dynamic fusion API to make it more
  user-friendly, e.g. making some of the objects movable (see the
  second sketch below).
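
  A minimal usage sketch under the new ownership model (illustrative
  only; it mirrors the updated integration tests in this patch, and
  t_input_shape stands in for any TensorShape chosen by the caller):

      auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
      auto              context        = GpuWorkloadContext{&cl_compile_ctx};
      GpuWorkloadSketch sketch{&context};

      // The context owns the TensorInfo objects; the caller only holds
      // non-owning pointers that stay valid as long as the context lives.
      ITensorInfo *in_0_info = context.create_tensor_info(t_input_shape, 1, DataType::F32);
      ITensorInfo *in_1_info = context.create_tensor_info(t_input_shape, 1, DataType::F32);
      ITensorInfo *out_info  = context.create_tensor_info();

      // Operators now take the pointers directly (previously: &in_0_info).
      ITensorInfo *ans_info = GpuAdd::create_op(sketch, in_0_info, in_1_info);
      GpuOutput::create_op(sketch, ans_info, out_info);

      // User tensors are still initialized from the context-owned infos.
      CLTensor t_in_0{};
      t_in_0.allocator()->init(*in_0_info);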

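  The new move operations make sketches and runtimes transferable; a
  short sketch, assuming only the move constructor and move assignment
  declarations added in this patch:

      GpuWorkloadSketch make_sketch(GpuWorkloadContext *context)
      {
          GpuWorkloadSketch sketch{context};
          // ... fuse operators into the sketch ...
          return sketch; // moved out of the function, no copy required
      }

      GpuWorkloadSketch sketch = make_sketch(&context);

      ClWorkloadRuntime runtime;
      runtime.configure(sketch);

      std::vector<ClWorkloadRuntime> runtimes;
      runtimes.push_back(std::move(runtime)); // ownership moves into the container
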
Partially resolves: COMPMID-6707
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: Ifee70e53c05f8e7b72bf9ef123701ff291c5ee80
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10990
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h b/arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h
index 3deaff7..6b92f12 100644
--- a/arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h
+++ b/arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef ARM_COMPUTE_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLWORKLOADRUNTIME
-#define ARM_COMPUTE_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLWORKLOADRUNTIME
+#ifndef ACL_ARM_COMPUTE_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLWORKLOADRUNTIME_H
+#define ACL_ARM_COMPUTE_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLWORKLOADRUNTIME_H
 
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/dynamic_fusion/sketch/MemoryDescriptor.h"
@@ -46,8 +46,18 @@
 class ClWorkloadRuntime
 {
 public:
+    /** Default constructor. */
     ClWorkloadRuntime();
+
+    /** Destructor */
     ~ClWorkloadRuntime();
+
+    /** Move constructor */
+    ClWorkloadRuntime(ClWorkloadRuntime &&);
+
+    /** Move assignment */
+    ClWorkloadRuntime &operator=(ClWorkloadRuntime &&);
+
     /** Configure @ref ClWorkloadRuntime
      * @note A runtime cannot be re-configured
      *
@@ -78,4 +88,4 @@
 } // namespace dynamic_fusion
 } // namespace experimental
 } // namespace arm_compute
-#endif /* ARM_COMPUTE_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLWORKLOADRUNTIME */
+#endif // ACL_ARM_COMPUTE_DYNAMIC_FUSION_RUNTIME_GPU_CL_CLWORKLOADRUNTIME_H
diff --git a/arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h b/arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h
index 38b350c..76e4255 100644
--- a/arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h
+++ b/arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADCONTEXT
-#define ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADCONTEXT
+#ifndef ACL_ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADCONTEXT_H
+#define ACL_ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADCONTEXT_H
 
 #include "arm_compute/core/GPUTarget.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -85,11 +85,14 @@
-     * @return TensorInfo Newly created tensor info
+     * @return ITensorInfo* Pointer to the newly created tensor info, which is owned by this context object
      */
     template <typename... TArgs>
-    TensorInfo create_tensor_info(TArgs &&...args)
+    ITensorInfo *create_tensor_info(TArgs &&...args)
     {
-        auto tensor_info = TensorInfo(std::forward<TArgs>(args)...);
-        register_user_tensor(tensor_info);
-        return tensor_info;
+        auto  tensor_info     = std::make_unique<TensorInfo>(std::forward<TArgs>(args)...);
+        auto *tensor_info_ptr = tensor_info.get();
+
+        register_user_tensor(std::move(tensor_info));
+
+        return tensor_info_ptr;
     }
 
     /** Get the internal implementation */
@@ -101,9 +104,11 @@
 private:
     /** Set a new ID to the tensor info and register its memory descriptor to the context.
      *
-     * @param[in,out] tensor_info @ref ITensorInfo to be registered.
+     * Ownership of the tensor info object is transferred to this context object.
+     *
+     * @param[in] tensor_info @ref TensorInfo to be registered.
      */
-    void register_user_tensor(ITensorInfo &tensor_info);
+    void register_user_tensor(std::unique_ptr<TensorInfo> &&tensor_info);
 
     /** Internal implementation */
     std::unique_ptr<Impl> _impl;
@@ -113,4 +118,4 @@
 } // namespace experimental
 } // namespace arm_compute
 
-#endif /* ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADCONTEXT */
+#endif // ACL_ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADCONTEXT_H
diff --git a/arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h b/arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h
index 75c2b1f..1c738bd 100644
--- a/arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h
+++ b/arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCH
-#define ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCH
+#ifndef ACL_ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCH_H
+#define ACL_ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCH_H
 
 #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h"
 
@@ -53,15 +53,28 @@
      * @param[in] context Gpu context for the creation of a workload
      */
     explicit GpuWorkloadSketch(GpuWorkloadContext *context);
+
     /** Destructor */
     ~GpuWorkloadSketch();
+
+    /** Move constructor */
+    GpuWorkloadSketch(GpuWorkloadSketch &&);
+
+    /** Move assignment */
+    GpuWorkloadSketch &operator=(GpuWorkloadSketch &&);
+
     /** Get the implementation */
     Implementation &implementation();
+
     /** Get the implementation */
     const Implementation &implementation() const;
+
     /** Get the gpu workload context of this sketch */
     const GpuWorkloadContext *gpu_context() const;
 
+    /** Get the gpu workload context of this sketch */
+    GpuWorkloadContext *gpu_context();
+
 private:
     std::unique_ptr<Implementation> _impl; /**< Internal opaque implementation*/
 };
@@ -69,4 +82,4 @@
 } // namespace dynamic_fusion
 } // namespace experimental
 } // namespace arm_compute
-#endif /* ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCH */
+#endif // ACL_ARM_COMPUTE_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCH_H
diff --git a/src/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.cpp b/src/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.cpp
index ba39ff4..3500a0e 100644
--- a/src/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.cpp
+++ b/src/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -289,6 +289,10 @@
 
 ClWorkloadRuntime::~ClWorkloadRuntime() = default;
 
+ClWorkloadRuntime::ClWorkloadRuntime(ClWorkloadRuntime &&) = default;
+
+ClWorkloadRuntime &ClWorkloadRuntime::operator=(ClWorkloadRuntime &&) = default;
+
 Status ClWorkloadRuntime::configure(const GpuWorkloadSketch &sketch)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(_impl->_is_configured, "ClWorkloadRuntime cannot be re-configured");
diff --git a/src/dynamic_fusion/sketch/gpu/GpuWorkloadContext.cpp b/src/dynamic_fusion/sketch/gpu/GpuWorkloadContext.cpp
index 36cad79..fab18aa 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuWorkloadContext.cpp
+++ b/src/dynamic_fusion/sketch/gpu/GpuWorkloadContext.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -60,9 +60,9 @@
     return _impl->cl_compile_context();
 }
 
-void GpuWorkloadContext::register_user_tensor(ITensorInfo &tensor_info)
+void GpuWorkloadContext::register_user_tensor(std::unique_ptr<TensorInfo> &&tensor_info)
 {
-    _impl->register_user_tensor(tensor_info);
+    _impl->register_user_tensor(std::move(tensor_info));
 }
 
 GpuWorkloadContext::Impl &GpuWorkloadContext::implementation()
@@ -99,17 +99,17 @@
     return _mem_map;
 }
 
-void GpuWorkloadContext::Impl::register_user_tensor(ITensorInfo &tensor_info)
+void GpuWorkloadContext::Impl::register_user_tensor(std::unique_ptr<TensorInfo> &&tensor_info)
 {
-    ARM_COMPUTE_ERROR_ON(tensor_info.has_valid_id());
+    ARM_COMPUTE_ERROR_ON(tensor_info->has_valid_id());
 
     const auto tensor_id = next_tensor_id();
 
-    tensor_info.set_id(tensor_id);
+    tensor_info->set_id(tensor_id);
     _mem_map[tensor_id] = MemoryDescriptor{MemoryType::User};
-    // Save a *copy* of the user tensor info in workload context for future reference
-    // Note that this means if the user modifies the @p tensor_info, the change will not be reflected in the context
-    _managed_tensor_info.emplace(tensor_info.id(), std::make_unique<TensorInfo>(tensor_info));
+    // Save the user tensor info in the workload context for future reference
+    // Note that the context takes ownership of the tensor info object, so it remains valid until the context is destroyed
+    _managed_tensor_info.emplace(tensor_info->id(), std::move(tensor_info));
 }
 
 ITensorInfo *GpuWorkloadContext::Impl::create_virtual_tensor()
diff --git a/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h b/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h
index 7d96990..b3571a6 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h
+++ b/src/dynamic_fusion/sketch/gpu/GpuWorkloadContextImpl.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -64,9 +64,11 @@
 
     /** Set a new ID and register the user tensor info.
      *
-     * @param[in, out] tensor_info The tensor info to be registered.
+     * Ownership of the tensor info object is transferred to this context object.
+     *
+     * @param[in] tensor_info The tensor info to be registered.
      */
-    void register_user_tensor(ITensorInfo &tensor_info);
+    void register_user_tensor(std::unique_ptr<TensorInfo> &&tensor_info);
 
     /** Create a virtual (see @ref MemoryType) tensor info and save it
      *
diff --git a/src/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.cpp b/src/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.cpp
index 973f7c7..357cb48 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.cpp
+++ b/src/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -31,22 +31,34 @@
 {
 namespace dynamic_fusion
 {
+
 GpuWorkloadSketch::GpuWorkloadSketch(Context *context) : _impl{std::make_unique<Implementation>(context)}
 {
 }
+
 GpuWorkloadSketch::~GpuWorkloadSketch()
 {
 }
 
+GpuWorkloadSketch::GpuWorkloadSketch(GpuWorkloadSketch &&) = default;
+
+GpuWorkloadSketch &GpuWorkloadSketch::operator=(GpuWorkloadSketch &&) = default;
+
 const GpuWorkloadSketch::Context *GpuWorkloadSketch::gpu_context() const
 {
     return _impl->context();
 }
 
+GpuWorkloadSketch::Context *GpuWorkloadSketch::gpu_context()
+{
+    return _impl->context();
+}
+
 GpuWorkloadSketch::Implementation &GpuWorkloadSketch::implementation()
 {
     return *_impl;
 }
+
 const GpuWorkloadSketch::Implementation &GpuWorkloadSketch::implementation() const
 {
     return *_impl;
diff --git a/src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h b/src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h
index fea4fe9..04e294e 100644
--- a/src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h
+++ b/src/dynamic_fusion/sketch/gpu/GpuWorkloadSketchImpl.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCHIMPL
-#define SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCHIMPL
+#ifndef ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCHIMPL_H
+#define ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCHIMPL_H
 
 #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
 #include "arm_compute/dynamic_fusion/sketch/MemoryDescriptor.h"
@@ -63,6 +63,11 @@
     {
         return _context;
     }
+    /** Get workload context */
+    Context *context()
+    {
+        return _context;
+    }
     /** Get component graph */
     const GpuKernelComponentGraph &component_graph() const
     {
@@ -126,4 +131,4 @@
 } // namespace dynamic_fusion
 } // namespace experimental
 } // namespace arm_compute
-#endif /* SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCHIMPL */
+#endif // ACL_SRC_DYNAMIC_FUSION_SKETCH_GPU_GPUWORKLOADSKETCHIMPL_H
diff --git a/tests/validation/dynamic_fusion/gpu/Integration.cpp b/tests/validation/dynamic_fusion/gpu/Integration.cpp
index 89cca5c..bb9c008 100644
--- a/tests/validation/dynamic_fusion/gpu/Integration.cpp
+++ b/tests/validation/dynamic_fusion/gpu/Integration.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -37,11 +37,10 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuDepthwiseConv2d.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
-
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSigmoid.h"
+
 #include "tests/CL/CLAccessor.h"
 #include "tests/framework/Macros.h"
-#include "tests/validation/Validation.h"
 #include "tests/validation/dynamic_fusion/Utils.h"
 #include "tests/validation/reference/ActivationLayer.h"
 #include "tests/validation/reference/ConvolutionLayer.h"
@@ -50,6 +49,7 @@
 #include "tests/validation/reference/ElementwiseOperations.h"
 #include "tests/validation/reference/Permute.h"
 #include "tests/validation/reference/PixelWiseMultiplication.h"
+#include "tests/validation/Validation.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 using namespace arm_compute::test::validation::utils;
@@ -79,18 +79,18 @@
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &context };
+    auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch sketch{&context};
 
     // Fuse conv2d
     Conv2dAttributes conv2d_attr{};
-    TensorInfo       input_info  = context.create_tensor_info(t_input_shape, 1, data_type, data_layout);
-    TensorInfo       weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
+    ITensorInfo     *input_info  = context.create_tensor_info(t_input_shape, 1, data_type, data_layout);
+    ITensorInfo     *weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
 
-    ITensorInfo *conv_out_info = GpuConv2d::create_op(sketch, &input_info, &weight_info, nullptr, conv2d_attr);
+    ITensorInfo *conv_out_info = GpuConv2d::create_op(sketch, input_info, weight_info, nullptr, conv2d_attr);
 
-    TensorInfo dst_info = context.create_tensor_info();
-    GpuOutput::create_op(sketch, conv_out_info, &dst_info);
+    ITensorInfo *dst_info = context.create_tensor_info();
+    GpuOutput::create_op(sketch, conv_out_info, dst_info);
 
     // Configure runtime
     ClWorkloadRuntime runtime;
@@ -98,7 +98,7 @@
 
     // (Important) Allocate auxiliary tensor memory if there are any
     // Instead of using ACL allocated memory, the user can choose to import memory into the tensors
-    for(auto &data : runtime.get_auxiliary_tensors())
+    for (auto &data : runtime.get_auxiliary_tensors())
     {
         CLTensor     *tensor      = std::get<0>(data);
         TensorInfo    info        = std::get<1>(data);
@@ -115,9 +115,9 @@
     CLTensor t_dst{};
 
     // Initialize user tensors
-    t_input.allocator()->init(input_info);
-    t_weight.allocator()->init(weight_info);
-    t_dst.allocator()->init(dst_info);
+    t_input.allocator()->init(*input_info);
+    t_weight.allocator()->init(*weight_info);
+    t_dst.allocator()->init(*dst_info);
 
     // Allocate and fill user tensors
     // Instead of using ACL allocator, the user can choose to import memory into the tensors
@@ -128,12 +128,12 @@
     fill<float>(CLAccessor(t_weight), 1, library.get());
 
     // Run runtime
-    runtime.run({ &t_input, &t_weight, &t_dst });
+    runtime.run({&t_input, &t_weight, &t_dst});
 
     // Create reference
-    SimpleTensor<float> ref_t_input{ t_input_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC };
-    SimpleTensor<float> ref_t_weight{ t_weight_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC };
-    SimpleTensor<float> ref_t_bias_placeholder{ t_dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC };
+    SimpleTensor<float> ref_t_input{t_input_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC};
+    SimpleTensor<float> ref_t_weight{t_weight_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC};
+    SimpleTensor<float> ref_t_bias_placeholder{t_dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC};
 
     // Fill reference
     fill<float>(ref_t_input, 0, library.get());
@@ -145,12 +145,15 @@
     auto t_dst_shape_nchw            = t_dst_shape;
     permute(t_dst_shape_nchw, PermutationVector(1U, 2U, 0U));
 
-    PadStrideInfo legacy_pad_stride(conv2d_attr.stride().x(), conv2d_attr.stride().y(), conv2d_attr.pad().left, conv2d_attr.pad().right, conv2d_attr.pad().top, conv2d_attr.pad().bottom,
+    PadStrideInfo legacy_pad_stride(conv2d_attr.stride().x(), conv2d_attr.stride().y(), conv2d_attr.pad().left,
+                                    conv2d_attr.pad().right, conv2d_attr.pad().top, conv2d_attr.pad().bottom,
                                     DimensionRoundingType{});
-    auto       ref_t_dst_nchw = reference::convolution_layer(ref_t_input_nchw, ref_t_weight_nchw, ref_t_bias_placeholder_nchw, t_dst_shape_nchw, legacy_pad_stride, conv2d_attr.dilation());
-    const auto ref_t_dst      = reference::permute(ref_t_dst_nchw, PermutationVector(2U, 0U, 1U));
+    auto ref_t_dst_nchw = reference::convolution_layer(ref_t_input_nchw, ref_t_weight_nchw, ref_t_bias_placeholder_nchw,
+                                                       t_dst_shape_nchw, legacy_pad_stride, conv2d_attr.dilation());
+    const auto ref_t_dst = reference::permute(ref_t_dst_nchw, PermutationVector(2U, 0U, 1U));
 
-    RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
+    RelativeTolerance<float> tolerance_f32(
+        0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
     validate(CLAccessor(t_dst), ref_t_dst_nchw, tolerance_f32);
 }
 #endif // ACL_INTERNAL_TEST_CKW_IN_DF
@@ -167,20 +170,20 @@
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &context };
+    auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch sketch{&context};
 
-    TensorInfo in_0_info = context.create_tensor_info(t_input_shape, 1, data_type);
-    TensorInfo in_1_info = context.create_tensor_info(t_input_shape, 1, data_type);
-    TensorInfo in_2_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    ITensorInfo *in_0_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    ITensorInfo *in_1_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    ITensorInfo *in_2_info = context.create_tensor_info(t_input_shape, 1, data_type);
 
-    TensorInfo out_0_info = context.create_tensor_info();
-    TensorInfo out_1_info = context.create_tensor_info();
+    ITensorInfo *out_0_info = context.create_tensor_info();
+    ITensorInfo *out_1_info = context.create_tensor_info();
 
-    ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, &in_0_info, &in_1_info);
-    GpuOutput::create_op(sketch, ans_0_info, &out_0_info);
-    ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, &in_2_info);
-    GpuOutput::create_op(sketch, ans_1_info, &out_1_info);
+    ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, in_0_info, in_1_info);
+    GpuOutput::create_op(sketch, ans_0_info, out_0_info);
+    ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, in_2_info);
+    GpuOutput::create_op(sketch, ans_1_info, out_1_info);
 
     // Configure runtime
     ClWorkloadRuntime runtime;
@@ -188,7 +191,7 @@
 
     // (Important) Allocate auxiliary tensor memory if there are any
     // Instead of using ACL allocated memory, the user can choose to import memory into the tensors
-    for(auto &data : runtime.get_auxiliary_tensors())
+    for (auto &data : runtime.get_auxiliary_tensors())
     {
         CLTensor     *tensor      = std::get<0>(data);
         TensorInfo    info        = std::get<1>(data);
@@ -208,12 +211,12 @@
     CLTensor t_out_1{};
 
     // Initialize user tensors
-    t_in_0.allocator()->init(in_0_info);
-    t_in_1.allocator()->init(in_1_info);
-    t_in_2.allocator()->init(in_2_info);
+    t_in_0.allocator()->init(*in_0_info);
+    t_in_1.allocator()->init(*in_1_info);
+    t_in_2.allocator()->init(*in_2_info);
 
-    t_out_0.allocator()->init(out_0_info);
-    t_out_1.allocator()->init(out_1_info);
+    t_out_0.allocator()->init(*out_0_info);
+    t_out_1.allocator()->init(*out_1_info);
 
     // Allocate and fill user tensors
     // Instead of using ACL allocator, the user can choose to import memory into the tensors
@@ -229,15 +232,15 @@
     fill<float>(CLAccessor(t_in_2), 2, library.get());
 
     // Run runtime
-    runtime.run({ &t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1 });
+    runtime.run({&t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1});
 
     // Create reference
-    SimpleTensor<float> ref_t_in_0{ t_input_shape, data_type, 1, QuantizationInfo() };
-    SimpleTensor<float> ref_t_in_1{ t_input_shape, data_type, 1, QuantizationInfo() };
-    SimpleTensor<float> ref_t_in_2{ t_input_shape, data_type, 1, QuantizationInfo() };
+    SimpleTensor<float> ref_t_in_0{t_input_shape, data_type, 1, QuantizationInfo()};
+    SimpleTensor<float> ref_t_in_1{t_input_shape, data_type, 1, QuantizationInfo()};
+    SimpleTensor<float> ref_t_in_2{t_input_shape, data_type, 1, QuantizationInfo()};
 
-    SimpleTensor<float> ref_t_out_0{ t_input_shape, data_type, 1, QuantizationInfo() };
-    SimpleTensor<float> ref_t_out_1{ t_input_shape, data_type, 1, QuantizationInfo() };
+    SimpleTensor<float> ref_t_out_0{t_input_shape, data_type, 1, QuantizationInfo()};
+    SimpleTensor<float> ref_t_out_1{t_input_shape, data_type, 1, QuantizationInfo()};
 
     // Fill reference
     fill<float>(ref_t_in_0, 0, library.get());
@@ -245,9 +248,11 @@
     fill<float>(ref_t_in_2, 2, library.get());
 
     reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_in_0, ref_t_in_1, ref_t_out_0, ConvertPolicy::WRAP);
-    reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_out_1, ConvertPolicy::WRAP);
+    reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_out_1,
+                                    ConvertPolicy::WRAP);
 
-    RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
+    RelativeTolerance<float> tolerance_f32(
+        0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
     validate(CLAccessor(t_out_0), ref_t_out_0, tolerance_f32);
     validate(CLAccessor(t_out_1), ref_t_out_1, tolerance_f32);
 }
@@ -264,15 +269,15 @@
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &context };
+    auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch sketch{&context};
 
-    TensorInfo in_0_info = context.create_tensor_info(t_input_shape, 1, data_type);
-    TensorInfo in_1_info = context.create_tensor_info(t_input_shape, 1, data_type);
-    TensorInfo in_2_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    ITensorInfo *in_0_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    ITensorInfo *in_1_info = context.create_tensor_info(t_input_shape, 1, data_type);
+    ITensorInfo *in_2_info = context.create_tensor_info(t_input_shape, 1, data_type);
 
-    TensorInfo out_0_info = context.create_tensor_info();
-    TensorInfo out_1_info = context.create_tensor_info();
+    ITensorInfo *out_0_info = context.create_tensor_info();
+    ITensorInfo *out_1_info = context.create_tensor_info();
 
     CastAttributes cast_0_attr;
     cast_0_attr.data_type(DataType::S32).convert_policy(ConvertPolicy::SATURATE);
@@ -280,12 +285,12 @@
     CastAttributes cast_1_attr;
     cast_1_attr.data_type(DataType::F32).convert_policy(ConvertPolicy::SATURATE);
 
-    ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, &in_0_info, &in_1_info);
-    GpuOutput::create_op(sketch, ans_0_info, &out_0_info);
-    ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, &in_2_info);
+    ITensorInfo *ans_0_info = GpuAdd::create_op(sketch, in_0_info, in_1_info);
+    GpuOutput::create_op(sketch, ans_0_info, out_0_info);
+    ITensorInfo *ans_1_info = GpuAdd::create_op(sketch, ans_0_info, in_2_info);
     ITensorInfo *ans_2_info = GpuCast::create_op(sketch, ans_1_info, cast_0_attr);
     ITensorInfo *ans_3_info = GpuCast::create_op(sketch, ans_2_info, cast_1_attr);
-    GpuOutput::create_op(sketch, ans_3_info, &out_1_info);
+    GpuOutput::create_op(sketch, ans_3_info, out_1_info);
 
     // Configure runtime
     ClWorkloadRuntime runtime;
@@ -293,7 +298,7 @@
 
     // (Important) Allocate auxiliary tensor memory if there are any
     // Instead of using ACL allocated memory, the user can choose to import memory into the tensors
-    for(auto &data : runtime.get_auxiliary_tensors())
+    for (auto &data : runtime.get_auxiliary_tensors())
     {
         CLTensor     *tensor      = std::get<0>(data);
         TensorInfo    info        = std::get<1>(data);
@@ -313,12 +318,12 @@
     CLTensor t_out_1{};
 
     // Initialize user tensors
-    t_in_0.allocator()->init(in_0_info);
-    t_in_1.allocator()->init(in_1_info);
-    t_in_2.allocator()->init(in_2_info);
+    t_in_0.allocator()->init(*in_0_info);
+    t_in_1.allocator()->init(*in_1_info);
+    t_in_2.allocator()->init(*in_2_info);
 
-    t_out_0.allocator()->init(out_0_info);
-    t_out_1.allocator()->init(out_1_info);
+    t_out_0.allocator()->init(*out_0_info);
+    t_out_1.allocator()->init(*out_1_info);
 
     // Allocate and fill user tensors
     // Instead of using ACL allocator, the user can choose to import memory into the tensors
@@ -334,15 +339,15 @@
     fill<float>(CLAccessor(t_in_2), 2, library.get());
 
     // Run runtime
-    runtime.run({ &t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1 });
+    runtime.run({&t_in_0, &t_in_1, &t_in_2, &t_out_0, &t_out_1});
 
     // Create reference
-    SimpleTensor<float> ref_t_in_0{ t_input_shape, data_type, 1, QuantizationInfo() };
-    SimpleTensor<float> ref_t_in_1{ t_input_shape, data_type, 1, QuantizationInfo() };
-    SimpleTensor<float> ref_t_in_2{ t_input_shape, data_type, 1, QuantizationInfo() };
+    SimpleTensor<float> ref_t_in_0{t_input_shape, data_type, 1, QuantizationInfo()};
+    SimpleTensor<float> ref_t_in_1{t_input_shape, data_type, 1, QuantizationInfo()};
+    SimpleTensor<float> ref_t_in_2{t_input_shape, data_type, 1, QuantizationInfo()};
 
-    SimpleTensor<float> ref_t_out_0{ t_input_shape, data_type, 1, QuantizationInfo() };
-    SimpleTensor<float> ref_t_ans_1{ t_input_shape, data_type, 1, QuantizationInfo() };
+    SimpleTensor<float> ref_t_out_0{t_input_shape, data_type, 1, QuantizationInfo()};
+    SimpleTensor<float> ref_t_ans_1{t_input_shape, data_type, 1, QuantizationInfo()};
 
     // Fill reference
     fill<float>(ref_t_in_0, 0, library.get());
@@ -350,9 +355,12 @@
     fill<float>(ref_t_in_2, 2, library.get());
 
     reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_in_0, ref_t_in_1, ref_t_out_0, ConvertPolicy::WRAP);
-    reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_ans_1, ConvertPolicy::WRAP);
-    const auto ref_t_ans_2 = reference::depth_convert<float, int32_t>(ref_t_ans_1, DataType::S32, ConvertPolicy::SATURATE, 0);
-    const auto ref_t_out_1 = reference::depth_convert<int32_t, float>(ref_t_ans_2, DataType::F32, ConvertPolicy::SATURATE, 0);
+    reference::arithmetic_operation(ArithmeticOperation::ADD, ref_t_out_0, ref_t_in_2, ref_t_ans_1,
+                                    ConvertPolicy::WRAP);
+    const auto ref_t_ans_2 =
+        reference::depth_convert<float, int32_t>(ref_t_ans_1, DataType::S32, ConvertPolicy::SATURATE, 0);
+    const auto ref_t_out_1 =
+        reference::depth_convert<int32_t, float>(ref_t_ans_2, DataType::F32, ConvertPolicy::SATURATE, 0);
 
     RelativeTolerance<float> tolerance_add_f32(0.001f);
     AbsoluteTolerance<float> tolerance_cast_f32(1.0f);
@@ -436,20 +444,22 @@
     Conv2dAttributes conv2d_attr;
     auto             tensor1_info = context.create_tensor_info(conv2d_wei_shape, 1, DataType::F32, DataLayout::NHWC);
     auto             tensor2_info = context.create_tensor_info(conv2d_bia_shape, 1, DataType::F32, DataLayout::NHWC);
-    ARM_COMPUTE_EXPECT(GpuConv2d::validate_op(sketch0, &tensor0_info, &tensor1_info, &tensor2_info, conv2d_attr), framework::LogLevel::ERRORS);
-    auto ans_info = GpuConv2d::create_op(sketch0, &tensor0_info, &tensor1_info, &tensor2_info, conv2d_attr);
+    ARM_COMPUTE_EXPECT(GpuConv2d::validate_op(sketch0, tensor0_info, tensor1_info, tensor2_info, conv2d_attr),
+                       framework::LogLevel::ERRORS);
+    auto ans_info = GpuConv2d::create_op(sketch0, tensor0_info, tensor1_info, tensor2_info, conv2d_attr);
 
     ARM_COMPUTE_EXPECT(GpuSigmoid::validate_op(sketch0, ans_info), framework::LogLevel::ERRORS);
     ans_info = GpuSigmoid::create_op(sketch0, ans_info);
 
     DepthwiseConv2dAttributes dwc_attr;
-    auto                      tensor3_info = context.create_tensor_info(dwc_wei_shape, 1, DataType::F32, DataLayout::NHWC);
-    auto                      tensor4_info = context.create_tensor_info(dwc_bia_shape, 1, DataType::F32, DataLayout::NHWC);
-    ARM_COMPUTE_EXPECT(!GpuDepthwiseConv2d::validate_op(sketch0, ans_info, &tensor3_info, &tensor4_info, dwc_attr), framework::LogLevel::ERRORS);
+    auto tensor3_info = context.create_tensor_info(dwc_wei_shape, 1, DataType::F32, DataLayout::NHWC);
+    auto tensor4_info = context.create_tensor_info(dwc_bia_shape, 1, DataType::F32, DataLayout::NHWC);
+    ARM_COMPUTE_EXPECT(!GpuDepthwiseConv2d::validate_op(sketch0, ans_info, tensor3_info, tensor4_info, dwc_attr),
+                       framework::LogLevel::ERRORS);
 
     auto tensor5_info = context.create_tensor_info();
-    ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch0, ans_info, &tensor5_info), framework::LogLevel::ERRORS);
-    GpuOutput::create_op(sketch0, ans_info, &tensor5_info);
+    ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch0, ans_info, tensor5_info), framework::LogLevel::ERRORS);
+    GpuOutput::create_op(sketch0, ans_info, tensor5_info);
 
     // Create the first workload runtime.
     ClWorkloadRuntime runtime0;
@@ -458,15 +468,16 @@
     // Create the second sketch: dwc + sigmoid + output.
     GpuWorkloadSketch sketch1(&context);
 
-    ARM_COMPUTE_EXPECT(GpuDepthwiseConv2d::validate_op(sketch1, &tensor5_info, &tensor3_info, &tensor4_info, dwc_attr), framework::LogLevel::ERRORS);
-    ans_info = GpuDepthwiseConv2d::create_op(sketch1, &tensor5_info, &tensor3_info, &tensor4_info, dwc_attr);
+    ARM_COMPUTE_EXPECT(GpuDepthwiseConv2d::validate_op(sketch1, tensor5_info, tensor3_info, tensor4_info, dwc_attr),
+                       framework::LogLevel::ERRORS);
+    ans_info = GpuDepthwiseConv2d::create_op(sketch1, tensor5_info, tensor3_info, tensor4_info, dwc_attr);
 
-    ARM_COMPUTE_EXPECT(GpuMul::validate_op(sketch1, ans_info, &tensor2_info), framework::LogLevel::ERRORS);
-    ans_info = GpuMul::create_op(sketch1, ans_info, &tensor2_info);
+    ARM_COMPUTE_EXPECT(GpuMul::validate_op(sketch1, ans_info, tensor2_info), framework::LogLevel::ERRORS);
+    ans_info = GpuMul::create_op(sketch1, ans_info, tensor2_info);
 
     auto tensor6_info = context.create_tensor_info();
-    ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch1, ans_info, &tensor6_info), framework::LogLevel::ERRORS);
-    GpuOutput::create_op(sketch1, ans_info, &tensor6_info);
+    ARM_COMPUTE_EXPECT(GpuOutput::validate_op(sketch1, ans_info, tensor6_info), framework::LogLevel::ERRORS);
+    GpuOutput::create_op(sketch1, ans_info, tensor6_info);
 
     // Create the second workload runtime.
     ClWorkloadRuntime runtime1;
@@ -481,13 +492,13 @@
     CLTensor tensor5;
     CLTensor tensor6;
 
-    tensor0.allocator()->init(tensor0_info);
-    tensor1.allocator()->init(tensor1_info);
-    tensor2.allocator()->init(tensor2_info);
-    tensor3.allocator()->init(tensor3_info);
-    tensor4.allocator()->init(tensor4_info);
-    tensor5.allocator()->init(tensor5_info);
-    tensor6.allocator()->init(tensor6_info);
+    tensor0.allocator()->init(*tensor0_info);
+    tensor1.allocator()->init(*tensor1_info);
+    tensor2.allocator()->init(*tensor2_info);
+    tensor3.allocator()->init(*tensor3_info);
+    tensor4.allocator()->init(*tensor4_info);
+    tensor5.allocator()->init(*tensor5_info);
+    tensor6.allocator()->init(*tensor6_info);
 
     tensor0.allocator()->allocate();
     tensor1.allocator()->allocate();
@@ -498,7 +509,7 @@
     tensor6.allocator()->allocate();
 
     // Allocate the auxiliary tensors.
-    for(auto &data : runtime0.get_auxiliary_tensors())
+    for (auto &data : runtime0.get_auxiliary_tensors())
     {
         auto  tensor      = std::get<0>(data);
         auto &tensor_info = std::get<1>(data);
@@ -508,7 +519,7 @@
         tensor->allocator()->allocate();
     }
 
-    for(auto &data : runtime1.get_auxiliary_tensors())
+    for (auto &data : runtime1.get_auxiliary_tensors())
     {
         auto  tensor      = std::get<0>(data);
         auto &tensor_info = std::get<1>(data);
@@ -526,8 +537,8 @@
     fill<float>(CLAccessor(tensor4), 4, library.get());
 
     // Run each runtime.
-    runtime0.run({ &tensor0, &tensor1, &tensor2, &tensor5 });
-    runtime1.run({ &tensor5, &tensor3, &tensor4, &tensor2, &tensor6 });
+    runtime0.run({&tensor0, &tensor1, &tensor2, &tensor5});
+    runtime1.run({&tensor5, &tensor3, &tensor4, &tensor2, &tensor6});
 
     // Compute the reference result.
     SimpleTensor<float> ref_conv2d_src(conv2d_src_shape, DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC);
@@ -549,18 +560,22 @@
     const auto ref_conv2d_src_nchw = reference::permute(ref_conv2d_src, nhwc_to_nchw);
     const auto ref_conv2d_wei_nchw = reference::permute(ref_conv2d_wei, nhwc_to_nchw);
     const auto ref_conv2d_bia_nchw = reference::permute(ref_conv2d_bia, nhwc_to_nchw);
-    const auto ref_conv2d_dst_nchw = reference::convolution_layer(ref_conv2d_src_nchw, ref_conv2d_wei_nchw, ref_conv2d_bia_nchw, conv2d_dst_shape_nchw, PadStrideInfo());
+    const auto ref_conv2d_dst_nchw = reference::convolution_layer(
+        ref_conv2d_src_nchw, ref_conv2d_wei_nchw, ref_conv2d_bia_nchw, conv2d_dst_shape_nchw, PadStrideInfo());
 
-    const auto ref_sigmoid_dst_nchw = reference::activation_layer(ref_conv2d_dst_nchw, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
+    const auto ref_sigmoid_dst_nchw = reference::activation_layer(
+        ref_conv2d_dst_nchw, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC));
 
     auto dwc_dst_shape_nchw = dwc_dst_shape;
     permute(dwc_dst_shape_nchw, nhwc_to_nchw);
     const auto ref_dwc_wei_nchw = reference::permute(ref_dwc_wei, nhwc_to_nchw);
     const auto ref_dwc_bia_nchw = reference::permute(ref_dwc_bia, nhwc_to_nchw);
-    const auto ref_dwc_dst_nchw = reference::depthwise_convolution(ref_sigmoid_dst_nchw, ref_dwc_wei_nchw, ref_dwc_bia_nchw, dwc_dst_shape_nchw, PadStrideInfo(), 1);
+    const auto ref_dwc_dst_nchw = reference::depthwise_convolution(
+        ref_sigmoid_dst_nchw, ref_dwc_wei_nchw, ref_dwc_bia_nchw, dwc_dst_shape_nchw, PadStrideInfo(), 1);
 
-    const auto ref_mul_dst_nchw = reference::pixel_wise_multiplication<float, float, float>(ref_dwc_dst_nchw, ref_conv2d_bia_nchw, 1.0, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP,
-                                                                                            DataType::F32);
+    const auto ref_mul_dst_nchw = reference::pixel_wise_multiplication<float, float, float>(
+        ref_dwc_dst_nchw, ref_conv2d_bia_nchw, 1.0, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP,
+        DataType::F32);
 
     constexpr RelativeTolerance<float> tolerance(0.001f);
     validate(CLAccessor(tensor6), ref_mul_dst_nchw, tolerance);
@@ -587,34 +602,35 @@
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &context };
+    auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch sketch{&context};
 
     // Create tensor infos
-    TensorInfo   input_info  = context.create_tensor_info(t_input_shape, 1, data_type, data_layout);
-    TensorInfo   weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
+    ITensorInfo *input_info  = context.create_tensor_info(t_input_shape, 1, data_type, data_layout);
+    ITensorInfo *weight_info = context.create_tensor_info(TensorInfo(t_weight_shape, 1, data_type, data_layout));
     ITensorInfo *dst_info;
 
     // Fuse conv2d into the workload
     {
         // Validate operator
-        const Status success = GpuConv2d::validate_op(sketch, &input_info, &weight_info, nullptr, conv2d_attr);
+        const Status success = GpuConv2d::validate_op(sketch, input_info, weight_info, nullptr, conv2d_attr);
         ARM_COMPUTE_EXPECT(bool(success), framework::LogLevel::ERRORS);
 
-        dst_info = GpuConv2d::create_op(sketch, &input_info, &weight_info, nullptr, conv2d_attr);
+        dst_info = GpuConv2d::create_op(sketch, input_info, weight_info, nullptr, conv2d_attr);
     }
 
     // Create tensor infos
-    TensorInfo weight_info_2 = context.create_tensor_info(t_weight_info);
+    ITensorInfo *weight_info_2 = context.create_tensor_info(t_weight_info);
 
     // Fuse conv2d into the workload
     {
         // Validate operator, should fail
-        const Status success            = GpuConv2d::validate_op(sketch, dst_info, &weight_info_2, nullptr, conv2d_attr);
-        const auto   expected_error_str = "Operator fusion test failed. This operator cannot be fused into the workload";
+        const Status success          = GpuConv2d::validate_op(sketch, dst_info, weight_info_2, nullptr, conv2d_attr);
+        const auto expected_error_str = "Operator fusion test failed. This operator cannot be fused into the workload";
 
         ARM_COMPUTE_EXPECT(!bool(success), framework::LogLevel::ERRORS);
-        ARM_COMPUTE_EXPECT((success.error_description().find(expected_error_str) != std::string::npos), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT((success.error_description().find(expected_error_str) != std::string::npos),
+                           framework::LogLevel::ERRORS);
     }
 }
 TEST_SUITE_END() // Invalid_Fusion_Should_Fail
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp
index 09a8f3f..a358d47 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Add.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Add.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,14 +29,13 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuAdd.h"
 
 #include "tests/CL/CLAccessor.h"
-#include "tests/framework/Fixture.h"
-#include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
-
 #include "tests/datasets/DynamicFusionDataset.h"
 #include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
 #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -97,32 +96,36 @@
     auto          lhs_info         = context.create_tensor_info(input1_info);
     auto          rhs_info         = context.create_tensor_info(input2_info);
 
-    bool res = bool(GpuAdd::validate_op(sketch, &lhs_info, &rhs_info));
+    bool res = bool(GpuAdd::validate_op(sketch, lhs_info, rhs_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
 // *INDENT-ON*
 
-constexpr AbsoluteTolerance<float> tolerance_f(0.0001f);    /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 and DataType::F16 */
-constexpr float                    tolerance_num = 0.0001f; /**< Tolerance number */
+constexpr AbsoluteTolerance<float> tolerance_f(
+    0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 and DataType::F16 */
+constexpr float tolerance_num = 0.0001f; /**< Tolerance number */
 
 template <typename T>
-using DynamicFusionCLAddFixture = DynamicFusionGpuElementwiseBinaryOneOpValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
+using DynamicFusionCLAddFixture =
+    DynamicFusionGpuElementwiseBinaryOneOpValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
 
 template <typename T>
-using DynamicFusionCLAddBroadcastFixture = DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
+using DynamicFusionCLAddBroadcastFixture =
+    DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
 
 template <typename T>
-using DynamicFusionCLAddTwoOpsFixture = DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
+using DynamicFusionCLAddTwoOpsFixture =
+    DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture<CLTensor, CLAccessor, GpuAdd, T>;
 
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionCLAddFixture<float>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f);
@@ -130,10 +133,10 @@
 FIXTURE_DATA_TEST_CASE(RunLargeOneOp,
                        DynamicFusionCLAddFixture<float>,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::LargeShapes()),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f);
@@ -141,10 +144,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
                        DynamicFusionCLAddBroadcastFixture<float>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::TemporaryLimitedSmallShapesBroadcast()),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f);
@@ -153,22 +156,23 @@
 FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp,
                        DynamicFusionCLAddBroadcastFixture<float>,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::TemporaryLimitedLargeShapesBroadcast()),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f);
 }
-FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
-                       DynamicFusionCLAddTwoOpsFixture<float>,
-                       framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
-                                                       datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()),
-                                               framework::dataset::make("DataType", { DataType::F32 })),
-                                       framework::dataset::make("InPlace", { false })),
-                               framework::dataset::make("FuseTwoOps", { true })))
+FIXTURE_DATA_TEST_CASE(
+    RunSmallTwoOps,
+    DynamicFusionCLAddTwoOpsFixture<float>,
+    framework::DatasetMode::PRECOMMIT,
+    combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
+                                    datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()),
+                            framework::dataset::make("DataType", {DataType::F32})),
+                    framework::dataset::make("InPlace", {false})),
+            framework::dataset::make("FuseTwoOps", {true})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f);
@@ -179,10 +183,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionCLAddFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::F16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f, tolerance_num);
@@ -191,10 +195,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
                        DynamicFusionCLAddBroadcastFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::TemporaryLimitedSmallShapesBroadcast()),
-                                       framework::dataset::make("DataType", { DataType::F16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f, tolerance_num);
@@ -206,10 +210,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        DynamicFusionCLAddFixture<int32_t>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::S32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::S32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -220,10 +224,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        DynamicFusionCLAddFixture<int16_t>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::S16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::S16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -231,10 +235,10 @@
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        DynamicFusionCLAddFixture<int16_t>,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::LargeShapes()),
-                                       framework::dataset::make("DataType", { DataType::S16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::S16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -245,10 +249,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        DynamicFusionCLAddFixture<uint8_t>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::ADD }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::ADD}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::U8 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::U8})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp
index 285c0d6..cef8b87 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Clamp.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,10 +29,10 @@
 #include "tests/CL/CLAccessor.h"
 #include "tests/datasets/ShapeDatasets.h"
 #include "tests/framework/Asserts.h"
-#include "tests/framework/Macros.h"
 #include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
+#include "tests/framework/Macros.h"
 #include "tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -73,13 +73,13 @@
     GpuWorkloadSketch sketch{ &context };
 
     // Fuse Clamp
-    const TensorInfo src_info = context.create_tensor_info(input_info);
+    const ITensorInfo* src_info = context.create_tensor_info(input_info);
 
     ClampAttributes attributes {};
     attributes.min_val(min_val)
               .max_val(max_val);
 
-    const bool res = static_cast<bool>(GpuClamp::validate_op(sketch, &src_info, attributes));
+    const bool res = static_cast<bool>(GpuClamp::validate_op(sketch, src_info, attributes));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
@@ -94,8 +94,9 @@
                        DynamicFusionClampOpFixture<half>,
                        framework::DatasetMode::ALL,
                        combine(combine(combine(datasets::SmallShapes(),
-                                               framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.1f).max_val(0.6f) })),
-                                       framework::dataset::make("Fuse", { false })),
+                                               framework::dataset::make(
+                                                   "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.6f)})),
+                                       framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -106,8 +107,9 @@
                        DynamicFusionClampOpFixture<half>,
                        framework::DatasetMode::ALL,
                        combine(combine(combine(datasets::Small5dShapes(),
-                                               framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.1f).max_val(0.6f) })),
-                                       framework::dataset::make("Fuse", { false })),
+                                               framework::dataset::make(
+                                                   "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.6f)})),
+                                       framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -119,8 +121,9 @@
                        DynamicFusionClampOpFixture<half>,
                        framework::DatasetMode::ALL,
                        combine(combine(combine(datasets::SmallShapes(),
-                                               framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.2f).max_val(0.4f) })),
-                                       framework::dataset::make("Fuse", { true })),
+                                               framework::dataset::make(
+                                                   "ClampAttributes", {ClampAttributes().min_val(0.2f).max_val(0.4f)})),
+                                       framework::dataset::make("Fuse", {true})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -134,8 +137,9 @@
                        DynamicFusionClampOpFixture<float>,
                        framework::DatasetMode::ALL,
                        combine(combine(combine(datasets::SmallShapes(),
-                                               framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.3f).max_val(0.7f) })),
-                                       framework::dataset::make("Fuse", { false })),
+                                               framework::dataset::make(
+                                                   "ClampAttributes", {ClampAttributes().min_val(0.3f).max_val(0.7f)})),
+                                       framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
@@ -146,8 +150,9 @@
                        DynamicFusionClampOpFixture<float>,
                        framework::DatasetMode::ALL,
                        combine(combine(combine(datasets::Small5dShapes(),
-                                               framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.3f).max_val(0.7f) })),
-                                       framework::dataset::make("Fuse", { false })),
+                                               framework::dataset::make(
+                                                   "ClampAttributes", {ClampAttributes().min_val(0.3f).max_val(0.7f)})),
+                                       framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
@@ -159,8 +164,9 @@
                        DynamicFusionClampOpFixture<float>,
                        framework::DatasetMode::ALL,
                        combine(combine(combine(datasets::SmallShapes(),
-                                               framework::dataset::make("ClampAttributes", { ClampAttributes().min_val(0.1f).max_val(0.9f) })),
-                                       framework::dataset::make("Fuse", { true })),
+                                               framework::dataset::make(
+                                                   "ClampAttributes", {ClampAttributes().min_val(0.1f).max_val(0.9f)})),
+                                       framework::dataset::make("Fuse", {true})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
diff --git a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
index aec1306..40e1ea8 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/DepthwiseConv2d.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,11 +28,11 @@
 #include "tests/datasets/DepthwiseConvolutionLayerDataset.h"
 #include "tests/datasets/DilatedDepthwiseConvolutionLayerDataset.h"
 #include "tests/framework/Asserts.h"
+#include "tests/framework/datasets/Datasets.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
 #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -40,16 +40,18 @@
 {
 namespace validation
 {
-const auto depth_multipliers       = framework::dataset::make("DepthMultiplier", { 1U, 4U });
-const auto large_depth_multipliers = framework::dataset::make("DepthMultiplier", { 1, 2, 5, 8 });
+const auto depth_multipliers       = framework::dataset::make("DepthMultiplier", {1U, 4U});
+const auto large_depth_multipliers = framework::dataset::make("DepthMultiplier", {1, 2, 5, 8});
 
 TEST_SUITE(CL)
 TEST_SUITE(DYNAMIC_FUSION)
 TEST_SUITE(DEPTHWISE_CONV2D)
 
-RelativeTolerance<float>            tolerance_f32(0.01f);                 /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
-RelativeTolerance<half_float::half> tolerance_f16(half_float::half(0.1)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
-constexpr float                     tolerance_num = 0.02f;                /**< Tolerance number */
+RelativeTolerance<float> tolerance_f32(
+    0.01f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+RelativeTolerance<half_float::half> tolerance_f16(half_float::half(
+    0.1)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr float                     tolerance_num = 0.02f; /**< Tolerance number */
 
 // *INDENT-OFF*
 // clang-format off
@@ -245,9 +247,9 @@
     GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx };
     GpuWorkloadSketch sketch{ &context };
 
-    const TensorInfo sketch_input_info   = context.create_tensor_info(input_info);
-    const TensorInfo sketch_weights_info = context.create_tensor_info(weights_info);
-    const TensorInfo sketch_biases_info  = context.create_tensor_info(biases_info);
+    const ITensorInfo* sketch_input_info   = context.create_tensor_info(input_info);
+    const ITensorInfo* sketch_weights_info = context.create_tensor_info(weights_info);
+    const ITensorInfo* sketch_biases_info  = context.create_tensor_info(biases_info);
 
     DepthwiseConv2dAttributes attributes {};
     attributes.pad(padding)
@@ -255,7 +257,7 @@
               .dilation(dilation)
               .depth_multiplier(depth_multiplier);
 
-    const Status status = GpuDepthwiseConv2d::validate_op(sketch, &sketch_input_info, &sketch_weights_info, &sketch_biases_info, attributes);
+    const Status status = GpuDepthwiseConv2d::validate_op(sketch, sketch_input_info, sketch_weights_info, sketch_biases_info, attributes);
     const bool res = bool(status);
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
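
The same substitution recurs across every validate test in this patch; with
shortened names and the old form kept as a comment for contrast:

    // Old API: const TensorInfo info = context.create_tensor_info(input_info);
    //          ... validate_op(sketch, &info, ...);
    // New API: the context owns the info objects, the test holds raw pointers.
    const ITensorInfo *in  = context.create_tensor_info(input_info);
    const ITensorInfo *wei = context.create_tensor_info(weights_info);
    const ITensorInfo *bia = context.create_tensor_info(biases_info);

    const Status status = GpuDepthwiseConv2d::validate_op(sketch, in, wei, bia, attributes);
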
@@ -263,40 +265,50 @@
 // *INDENT-ON*
 
 template <typename T>
-using DynamicFusionGpuDepthwiseConv2dFixture = DynamicFusionGpuDepthwiseConv2dValidationFixture<CLTensor, CLAccessor, GpuDepthwiseConv2d, T>;
+using DynamicFusionGpuDepthwiseConv2dFixture =
+    DynamicFusionGpuDepthwiseConv2dValidationFixture<CLTensor, CLAccessor, GpuDepthwiseConv2d, T>;
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
 TEST_SUITE(W3x3)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture<half>, framework::DatasetMode::ALL,
-                       combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
-                                               depth_multipliers),
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuDepthwiseConv2dFixture<half>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F16)),
                                framework::dataset::make("DataLayout", DataLayout::NHWC)))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
-                                                                                                                        large_depth_multipliers),
-                                                                                                                        framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                        framework::dataset::make("DataLayout", DataLayout::NHWC)))
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuDepthwiseConv2dFixture<half>,
+                       framework::DatasetMode::NIGHTLY,
+                       combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
+                                               large_depth_multipliers),
+                                       framework::dataset::make("DataType", DataType::F16)),
+                               framework::dataset::make("DataLayout", DataLayout::NHWC)))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
 #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel
 TEST_SUITE(Dilation)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
-                                                                                                                    depth_multipliers),
-                                                                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                    framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuDepthwiseConv2dFixture<half>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
+                                               depth_multipliers),
+                                       framework::dataset::make("DataType", DataType::F16)),
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture<half>, framework::DatasetMode::NIGHTLY,
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuDepthwiseConv2dFixture<half>,
+                       framework::DatasetMode::NIGHTLY,
                        combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
                                                large_depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F16)),
-                               framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
@@ -305,34 +317,44 @@
 TEST_SUITE_END() // W3x3
 
 TEST_SUITE(Generic)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
-                                                                                                                    depth_multipliers),
-                                                                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                    framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuDepthwiseConv2dFixture<half>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(), depth_multipliers),
+                                       framework::dataset::make("DataType", DataType::F16)),
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(),
-                                                                                                                        large_depth_multipliers),
-                                                                                                                        framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                        framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuDepthwiseConv2dFixture<half>,
+                       framework::DatasetMode::NIGHTLY,
+                       combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(),
+                                               large_depth_multipliers),
+                                       framework::dataset::make("DataType", DataType::F16)),
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
 }
 #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel
 TEST_SUITE(Dilation)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
-                                                                                                                    depth_multipliers),
-                                                                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                    framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuDepthwiseConv2dFixture<half>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+                                               depth_multipliers),
+                                       framework::dataset::make("DataType", DataType::F16)),
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture<half>, framework::DatasetMode::NIGHTLY,
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuDepthwiseConv2dFixture<half>,
+                       framework::DatasetMode::NIGHTLY,
                        combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset(),
                                                large_depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F16)),
-                               framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
 }
@@ -343,15 +365,18 @@
 
 TEST_SUITE(FP32)
 TEST_SUITE(W3x3)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::ALL,
-                       combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
-                                               depth_multipliers),
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(), depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F32)),
                                framework::dataset::make("DataLayout", DataLayout::NHWC)))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::NIGHTLY,
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::NIGHTLY,
                        combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset3x3(),
                                                large_depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F32)),
@@ -363,7 +388,9 @@
 #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel
 TEST_SUITE(Dilation)
 
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::ALL,
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::ALL,
                        combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset3x3(),
                                                depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F32)),
@@ -371,7 +398,9 @@
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::NIGHTLY,
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::NIGHTLY,
                        combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
                                                large_depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F32)),
@@ -384,47 +413,57 @@
 TEST_SUITE_END() // W3x3
 
 TEST_SUITE(Generic)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::ALL,
-                       combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
-                                               depth_multipliers),
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(), depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F32)),
-                               framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::NIGHTLY,
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::NIGHTLY,
                        combine(combine(combine(datasets::LargeDepthwiseConvolutionLayerDataset(),
                                                large_depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F32)),
-                               framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLargeKernelSize, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::ALL,
+FIXTURE_DATA_TEST_CASE(RunLargeKernelSize,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::ALL,
                        combine(combine(combine(datasets::LargeKernelSizeDepthwiseConvolutionLayerNHWCDataset(),
-                                               framework::dataset::make("DepthMultiplier", { 1 })),
+                                               framework::dataset::make("DepthMultiplier", {1})),
                                        framework::dataset::make("DataType", DataType::F32)),
-                               framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 
 #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test as dilation not supported yet in DepthwiseConv2d CKW kernel
 TEST_SUITE(Dilation)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
-                                                                                                                     depth_multipliers),
-                                                                                                                     framework::dataset::make("DataType", DataType::F32)),
-                                                                                                                     framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallDepthwiseDilatedConvolutionLayerDataset(),
+                                               depth_multipliers),
+                                       framework::dataset::make("DataType", DataType::F32)),
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuDepthwiseConv2dFixture<float>, framework::DatasetMode::NIGHTLY,
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuDepthwiseConv2dFixture<float>,
+                       framework::DatasetMode::NIGHTLY,
                        combine(combine(combine(datasets::LargeDepthwiseDilatedConvolutionLayerDataset3x3(),
                                                large_depth_multipliers),
                                        framework::dataset::make("DataType", DataType::F32)),
-                               framework::dataset::make("DataLayout", { DataLayout::NHWC })))
+                               framework::dataset::make("DataLayout", {DataLayout::NHWC})))
 {
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
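
One consequence of the pointer-returning create_tensor_info() worth keeping in
mind while reading these tests (an inference from the new signature, not
something this patch states): the returned ITensorInfo is only usable while
its context is alive.

    {
        auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
        auto context        = GpuWorkloadContext{&cl_compile_ctx};

        const ITensorInfo *info =
            context.create_tensor_info(TensorInfo(TensorShape(8U, 8U), 1, DataType::F32));
        // ... use info while 'context' is in scope ...
    } // 'context' is destroyed here; a saved copy of 'info' would dangle.
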
diff --git a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
index bae8cbf..dae5500 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/DirectConv2d.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,14 +24,13 @@
 
 #include "tests/AssetsLibrary.h"
 #include "tests/CL/CLAccessor.h"
+#include "tests/datasets/SmallConvolutionLayerDataset.h"
+#include "tests/framework/datasets/Datasets.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
-#include "tests/validation/reference/ConvolutionLayer.h"
-
-#include "tests/datasets/SmallConvolutionLayerDataset.h"
 #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h"
+#include "tests/validation/reference/ConvolutionLayer.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -43,10 +42,12 @@
 {
 /** Tolerances from tests/validation/CL/DirectConvolutionLayer.cpp
  */
-RelativeTolerance<float>            tolerance_f32(0.05f);                 /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
-RelativeTolerance<half_float::half> tolerance_f16(half_float::half(0.2)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
-constexpr float                     abs_tolerance_f32(0.0001f);           /**< Absolute tolerance for FP32 tests*/
-constexpr float                     tolerance_num = 0.07f;                /**< Tolerance number */
+RelativeTolerance<float> tolerance_f32(
+    0.05f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+RelativeTolerance<half_float::half> tolerance_f16(half_float::half(
+    0.2)); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr float                     abs_tolerance_f32(0.0001f); /**< Absolute tolerance for FP32 tests */
+constexpr float                     tolerance_num = 0.07f;      /**< Tolerance number */
 } // namespace
 
 TEST_SUITE(CL)
@@ -69,8 +70,13 @@
 template <typename T>
 using DynamicFusionGpuConv2dFixture = DynamicFusionGpuConv2dValidationFixture<CLTensor, CLAccessor, GpuConv2d, T>;
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuConv2dFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                    framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), framework::dataset::make("QuantizationInfo", QuantizationInfo())))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuConv2dFixture<float>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+                                               framework::dataset::make("DataType", DataType::F32)),
+                                       framework::dataset::make("DataLayout", {DataLayout::NHWC})),
+                               framework::dataset::make("QuantizationInfo", QuantizationInfo())))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -78,8 +84,13 @@
 TEST_SUITE_END() // FP32
 
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuConv2dFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                   framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), framework::dataset::make("QuantizationInfo", QuantizationInfo())))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuConv2dFixture<half>,
+                       framework::DatasetMode::ALL,
+                       combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+                                               framework::dataset::make("DataType", DataType::F16)),
+                                       framework::dataset::make("DataLayout", {DataLayout::NHWC})),
+                               framework::dataset::make("QuantizationInfo", QuantizationInfo())))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
@@ -156,10 +167,10 @@
     auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
     GpuWorkloadSketch sketch{ &context };
 
-    const TensorInfo sketch_input_info   = context.create_tensor_info(input_info);
-    const TensorInfo sketch_weights_info = context.create_tensor_info(weights_info);
-    const TensorInfo sketch_biases_info  = context.create_tensor_info(biases_info);
-    bool is_valid = bool(GpuConv2d::validate_op(sketch, &sketch_input_info, &sketch_weights_info, &sketch_biases_info, conv2d_attrs));
+    const ITensorInfo* sketch_input_info   = context.create_tensor_info(input_info);
+    const ITensorInfo* sketch_weights_info = context.create_tensor_info(weights_info);
+    const ITensorInfo* sketch_biases_info  = context.create_tensor_info(biases_info);
+    bool is_valid = bool(GpuConv2d::validate_op(sketch, sketch_input_info, sketch_weights_info, sketch_biases_info, conv2d_attrs));
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
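
A detail the casts in these tests rely on (an assumption drawn from their
consistent use of explicit conversions): Status does not convert to bool
implicitly, so the two spellings used across this patch are equivalent.

    const Status status  = GpuConv2d::validate_op(sketch, sketch_input_info, sketch_weights_info,
                                                  sketch_biases_info, conv2d_attrs);
    const bool   valid_a = bool(status);              // functional cast, as above
    const bool   valid_b = static_cast<bool>(status); // as in the Clamp test
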
 template <typename T>
diff --git a/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp b/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp
index 38c3a0c..d714a2f 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/MatMul.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,16 +24,15 @@
 #ifdef ACL_INTERNAL_TEST_CKW_IN_DF
 #include "tests/AssetsLibrary.h"
 #include "tests/CL/CLAccessor.h"
-#include "tests/framework/Fixture.h"
-#include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
 #include "tests/datasets/LargeMatMulDataset.h"
 #include "tests/datasets/SmallMatMulDataset.h"
-#include "tests/validation/Validation.h"
-#include "tests/validation/reference/Permute.h"
-#include "tests/validation/reference/GEMM.h"
-
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
 #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h"
+#include "tests/validation/reference/GEMM.h"
+#include "tests/validation/reference/Permute.h"
+#include "tests/validation/Validation.h"
 
 #include <tuple>
 
@@ -45,35 +44,37 @@
 {
 namespace
 {
-    RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
-constexpr float          abs_tolerance_f32(
+RelativeTolerance<float> tolerance_f32(
+    0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
+constexpr float abs_tolerance_f32(
     0.0001f); /**< Absolute tolerance value for comparing reference's output against implementation's output for floating point data types in case using relative tolerance fails because of small values */
 constexpr float abs_tolerance_f16(
-    0.001f);                                                   /**< Absolute tolerance value for comparing reference's output against implementation's output for fp16  data types in case using relative tolerance fails because of small values */
-    RelativeTolerance<half_float::half> tolerance_f16(half(0.02)); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
-}
+    0.001f); /**< Absolute tolerance value for comparing reference's output against implementation's output for fp16 data types in case using relative tolerance fails because of small values */
+RelativeTolerance<half_float::half> tolerance_f16(half(
+    0.02)); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
+} // namespace
 
 /** M0 values to test --precommit*/
-const auto m0_values_precommit = framework::dataset::make("M0", { 1, 3 });
+const auto m0_values_precommit = framework::dataset::make("M0", {1, 3});
 
 /** N0 values to test --precommit*/
-const auto n0_values_precommit = framework::dataset::make("N0", { 1, 2, 4 });
+const auto n0_values_precommit = framework::dataset::make("N0", {1, 2, 4});
 
 /** K0 values to test --precommit*/
-const auto k0_values_precommit = framework::dataset::make("K0", { 1, 2, 3 });
+const auto k0_values_precommit = framework::dataset::make("K0", {1, 2, 3});
 
 /** M0 values to test --nightly*/
-const auto m0_values_nightly_lhs_nt = framework::dataset::make("M0", { 1, 2, 3, 4, 5, 6, 7, 8 });
-const auto m0_values_nightly_lhs_t  = framework::dataset::make("M0", { 1, 2, 3, 4, 8 });
+const auto m0_values_nightly_lhs_nt = framework::dataset::make("M0", {1, 2, 3, 4, 5, 6, 7, 8});
+const auto m0_values_nightly_lhs_t  = framework::dataset::make("M0", {1, 2, 3, 4, 8});
 
 /** N0 values to test --nightly*/
-const auto n0_values_nightly_rhs_nt = framework::dataset::make("N0", { 1, 2, 3, 4, 8, 16 });
-const auto n0_values_nightly_rhs_t  = framework::dataset::make("N0", { 1, 2, 3, 4, 8 });
+const auto n0_values_nightly_rhs_nt = framework::dataset::make("N0", {1, 2, 3, 4, 8, 16});
+const auto n0_values_nightly_rhs_t  = framework::dataset::make("N0", {1, 2, 3, 4, 8});
 
 /** K0 values to test --nightly*/
-const auto k0_values_nightly_lhs_nt_rhs_nt = framework::dataset::make("K0", { 1, 2, 3, 4, 8, 16 });
-const auto k0_values_nightly_rhs_t         = framework::dataset::make("K0", { 1, 2, 3, 4, 8 });
-const auto k0_values_nightly_lhs_t_rhs_nt  = framework::dataset::make("K0", { 1, 2, 3, 4, 5, 6, 7, 8 });
+const auto k0_values_nightly_lhs_nt_rhs_nt = framework::dataset::make("K0", {1, 2, 3, 4, 8, 16});
+const auto k0_values_nightly_rhs_t         = framework::dataset::make("K0", {1, 2, 3, 4, 8});
+const auto k0_values_nightly_lhs_t_rhs_nt  = framework::dataset::make("K0", {1, 2, 3, 4, 5, 6, 7, 8});
 
 TEST_SUITE(CL)
 TEST_SUITE(DYNAMIC_FUSION)
@@ -85,45 +86,43 @@
 {
     using MatMulConfigurationPair = std::pair<MatMulKernelInfo, bool>;
 
-    const std::vector<MatMulConfigurationPair> supported_block_sizes =
-    {
+    const std::vector<MatMulConfigurationPair> supported_block_sizes = {
         // MatMulKernelInfo(adj_lhs, adj_rhs, M0, N0, K0, export_rhs_to_cl_image = false)
 
         // Lhs not-transposed, Rhs transposed
-        { MatMulKernelInfo(false, true, 0, 1, 1), false },  // M0 should be > 0
-        { MatMulKernelInfo(false, true, 3, 11, 1), false }, // N0 not in {1, 2, 3, 4, 8, 16}
-        { MatMulKernelInfo(false, true, 3, 7, 1), false },  // N0 not in {1, 2, 3, 4, 8, 16}
-        { MatMulKernelInfo(false, true, 3, 3, 12), false }, // K0 not in {1, 2, 3, 4, 8, 16}
-        { MatMulKernelInfo(false, true, 3, 3, 6), false },  // K0 not in {1, 2, 3, 4, 8, 16}
-        { MatMulKernelInfo(false, true, 5, 1, 2), true },
-        { MatMulKernelInfo(false, true, 3, 3, 3), true },
-        { MatMulKernelInfo(false, true, 2, 4, 8), true },
+        {MatMulKernelInfo(false, true, 0, 1, 1), false},  // M0 should be > 0
+        {MatMulKernelInfo(false, true, 3, 11, 1), false}, // N0 not in {1, 2, 3, 4, 8, 16}
+        {MatMulKernelInfo(false, true, 3, 7, 1), false},  // N0 not in {1, 2, 3, 4, 8, 16}
+        {MatMulKernelInfo(false, true, 3, 3, 12), false}, // K0 not in {1, 2, 3, 4, 8, 16}
+        {MatMulKernelInfo(false, true, 3, 3, 6), false},  // K0 not in {1, 2, 3, 4, 8, 16}
+        {MatMulKernelInfo(false, true, 5, 1, 2), true},   {MatMulKernelInfo(false, true, 3, 3, 3), true},
+        {MatMulKernelInfo(false, true, 2, 4, 8), true},
 
     };
 
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &context };
+    auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch sketch{&context};
 
     // Set big enough shapes so that block sizes are not truncated. Also, set all dimensions equal
     // so that it doesn't fail for different NT/T configurations. We aim to test the block sizes here,
     // not the shapes themselves.
-    const TensorInfo lhs_info    = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32));
-    const TensorInfo rhs_info    = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32));
+    const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32));
+    const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(TensorShape(100U, 100U), 1, DataType::F32));
 
-    for(auto &pair : supported_block_sizes)
+    for (auto &pair : supported_block_sizes)
     {
-        MatMulAttributes matmul_attr {};
+        MatMulAttributes matmul_attr{};
         matmul_attr.adj_lhs(pair.first.adj_lhs);
         matmul_attr.adj_rhs(pair.first.adj_rhs);
 
-        GpuMatMulSettings matmul_settings {};
+        GpuMatMulSettings matmul_settings{};
         matmul_settings.m0(pair.first.m0);
         matmul_settings.n0(pair.first.n0);
         matmul_settings.k0(pair.first.k0);
 
-        Status status = GpuMatMul::validate_op(sketch, &lhs_info, &rhs_info, matmul_attr, matmul_settings);
+        Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings);
         ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS);
     }
 }
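
The inline comments in the table spell out the block-size constraints this
loop exercises: M0 must be positive, while N0 and K0 must come from
{1, 2, 3, 4, 8, 16}. A passing configuration, condensed from the loop body
(block sizes taken from the first accepted entry):

    MatMulAttributes matmul_attr{};
    matmul_attr.adj_lhs(false);
    matmul_attr.adj_rhs(true);

    GpuMatMulSettings matmul_settings{};
    matmul_settings.m0(5); // any M0 > 0
    matmul_settings.n0(1); // must be in {1, 2, 3, 4, 8, 16}
    matmul_settings.k0(2); // must be in {1, 2, 3, 4, 8, 16}

    const Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings);
    ARM_COMPUTE_EXPECT(bool(status), framework::LogLevel::ERRORS);
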
@@ -132,117 +131,110 @@
 {
     // Create a sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &context };
+    auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch sketch{&context};
 
     // Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations
-    using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, bool>;
-    const std::vector<ShapeConfigurationTuple> shape_configurations =
-    {
-        { TensorShape(5U, 1U), TensorShape(3U, 5U), true },
-        { TensorShape(10U, 12U), TensorShape(3U, 10U), true },
-        { TensorShape(8U, 4U), TensorShape(2U, 8U), true },
-        { TensorShape(8U, 4U), TensorShape(2U, 5U), false }, // Mismatch in the K dimension
-        { TensorShape(5U, 0U), TensorShape(2U, 5U), false }, // Invalid dimension
-        { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), true },
-        { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // no batch broadcasting
-        { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // mismatch in batch dimension
+    using ShapeConfigurationTuple                                   = std::tuple<TensorShape, TensorShape, bool>;
+    const std::vector<ShapeConfigurationTuple> shape_configurations = {
+        {TensorShape(5U, 1U), TensorShape(3U, 5U), true},
+        {TensorShape(10U, 12U), TensorShape(3U, 10U), true},
+        {TensorShape(8U, 4U), TensorShape(2U, 8U), true},
+        {TensorShape(8U, 4U), TensorShape(2U, 5U), false}, // Mismatch in the K dimension
+        {TensorShape(5U, 0U), TensorShape(2U, 5U), false}, // Invalid dimension
+        {TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), true},
+        {TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false}, // no batch broadcasting
+        {TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U),
+         false}, // mismatch in batch dimension
     };
 
-    for(auto &tuple : shape_configurations)
+    for (auto &tuple : shape_configurations)
     {
         const bool expected = std::get<2>(tuple);
 
-        for(bool adj_lhs :
-            {
-                false
-            })
+        for (bool adj_lhs : {false})
         {
-            for(bool adj_rhs :
-                {
-                    true
-                })
+            for (bool adj_rhs : {true})
             {
                 TensorShape lhs_shape = std::get<0>(tuple);
                 TensorShape rhs_shape = std::get<1>(tuple);
 
-                if(adj_lhs)
+                if (adj_lhs)
                 {
                     permute(lhs_shape, PermutationVector(1U, 0U));
                 }
 
-                if(adj_rhs)
+                if (adj_rhs)
                 {
                     permute(rhs_shape, PermutationVector(1U, 0U));
                 }
 
-                const TensorInfo lhs_info    = context.create_tensor_info(TensorInfo(lhs_shape, 1, DataType::F32));
-                const TensorInfo rhs_info    = context.create_tensor_info(TensorInfo(rhs_shape, 1, DataType::F32));
+                const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(lhs_shape, 1, DataType::F32));
+                const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(rhs_shape, 1, DataType::F32));
 
-                MatMulAttributes matmul_attr {};
+                MatMulAttributes matmul_attr{};
                 matmul_attr.adj_lhs(adj_lhs);
                 matmul_attr.adj_rhs(adj_rhs);
 
-                GpuMatMulSettings matmul_settings {};
+                GpuMatMulSettings matmul_settings{};
                 matmul_settings.m0(1);
                 matmul_settings.n0(1);
                 matmul_settings.k0(1);
 
-                Status status = GpuMatMul::validate_op(sketch, &lhs_info, &rhs_info, matmul_attr, matmul_settings);
+                Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings);
                 ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
             }
         }
     }
 }
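
Reading the shape table: in the Nt/Nt baseline the lhs is stored as (K, M)
and the rhs as (N, K), so validity requires the lhs width to equal the rhs
height; the (8,4) x (2,5) entry fails exactly that check. When an adj flag is
set, the test swaps the first two dimensions before validating:

    TensorShape rhs_shape(2U, 8U);                 // (N, K) in the Nt layout
    permute(rhs_shape, PermutationVector(1U, 0U)); // adj_rhs: now (8, 2), i.e. (K, N)
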
 
-
 TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL)
 {
     // Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations
     using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, bool>;
-    const std::vector<DataTypeConfigurationTuple> data_type_configurations =
-    {
-        { DataType::F32, DataType::F32, DataType::F32, true },
-        { DataType::F16, DataType::F16, DataType::F16, true },
-        { DataType::F16, DataType::F32, DataType::F32, false },                                              // no mixed precision
-        { DataType::F64, DataType::F64, DataType::F64, false },                                              // no double precision
-        { DataType::QASYMM8, DataType::QASYMM8, DataType::QASYMM8, false },                                  // no quantized types
-        { DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, false },             // no quantized types
-        { DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, false }, // no quantized types
-        { DataType::QASYMM16, DataType::QASYMM16, DataType::QASYMM16, false },                               // no quantized types
-        { DataType::QSYMM16, DataType::QSYMM16, DataType::QSYMM16, false },                                  // no quantized types
-        { DataType::QSYMM8, DataType::QSYMM8, DataType::QSYMM8, false },                                     // no quantized types
-        { DataType::S64, DataType::S64, DataType::S64, false },                                              // no integral types
-        { DataType::S32, DataType::S32, DataType::S32, false },                                              // no integral types
-        { DataType::S16, DataType::S16, DataType::S16, false },                                              // no integral types
-        { DataType::S8, DataType::S8, DataType::S8, false },                                                 // no integral types
-        { DataType::U64, DataType::U64, DataType::U64, false },                                              // no integral types
-        { DataType::U32, DataType::U32, DataType::U32, false },                                              // no integral types
-        { DataType::U16, DataType::U16, DataType::U16, false },                                              // no integral types
-        { DataType::U8, DataType::U8, DataType::U8, false },                                                 // no integral types
+    const std::vector<DataTypeConfigurationTuple> data_type_configurations = {
+        {DataType::F32, DataType::F32, DataType::F32, true},
+        {DataType::F16, DataType::F16, DataType::F16, true},
+        {DataType::F16, DataType::F32, DataType::F32, false},                                  // no mixed precision
+        {DataType::F64, DataType::F64, DataType::F64, false},                                  // no double precision
+        {DataType::QASYMM8, DataType::QASYMM8, DataType::QASYMM8, false},                      // no quantized types
+        {DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, false}, // no quantized types
+        {DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL,
+         false},                                                             // no quantized types
+        {DataType::QASYMM16, DataType::QASYMM16, DataType::QASYMM16, false}, // no quantized types
+        {DataType::QSYMM16, DataType::QSYMM16, DataType::QSYMM16, false},    // no quantized types
+        {DataType::QSYMM8, DataType::QSYMM8, DataType::QSYMM8, false},       // no quantized types
+        {DataType::S64, DataType::S64, DataType::S64, false},                // no integral types
+        {DataType::S32, DataType::S32, DataType::S32, false},                // no integral types
+        {DataType::S16, DataType::S16, DataType::S16, false},                // no integral types
+        {DataType::S8, DataType::S8, DataType::S8, false},                   // no integral types
+        {DataType::U64, DataType::U64, DataType::U64, false},                // no integral types
+        {DataType::U32, DataType::U32, DataType::U32, false},                // no integral types
+        {DataType::U16, DataType::U16, DataType::U16, false},                // no integral types
+        {DataType::U8, DataType::U8, DataType::U8, false},                   // no integral types
     };
     // Create a sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &context };
+    auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch sketch{&context};
 
     const TensorShape shape = TensorShape(10U, 10U);
-    MatMulAttributes  matmul_attr {};
+    MatMulAttributes  matmul_attr{};
     matmul_attr.adj_lhs(false);
     matmul_attr.adj_rhs(false);
-    GpuMatMulSettings matmul_settings {};
+    GpuMatMulSettings matmul_settings{};
     matmul_settings.m0(1);
     matmul_settings.n0(1);
     matmul_settings.k0(1);
 
-    for(auto &tuple : data_type_configurations)
+    for (auto &tuple : data_type_configurations)
     {
         const bool expected = std::get<3>(tuple);
 
-        const TensorInfo lhs_info    = context.create_tensor_info(TensorInfo(shape, 1, std::get<0>(tuple)));
-        const TensorInfo rhs_info    = context.create_tensor_info(TensorInfo(shape, 1, std::get<1>(tuple)));
+        const ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(shape, 1, std::get<0>(tuple)));
+        const ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(shape, 1, std::get<1>(tuple)));
 
-        Status status = GpuMatMul::validate_op(sketch, &lhs_info, &rhs_info, matmul_attr, matmul_settings);
+        Status status = GpuMatMul::validate_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings);
         ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
     }
 }
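
Per the table above, only uniform F32 or F16 operand sets validate; mixed
precision, quantized and integral types are all rejected. A minimal negative
case (reusing the shape, attributes and settings from the test):

    const ITensorInfo *lhs = context.create_tensor_info(TensorInfo(shape, 1, DataType::F16));
    const ITensorInfo *rhs = context.create_tensor_info(TensorInfo(shape, 1, DataType::F32));

    // Mixed F16/F32 operands: validate_op is expected to fail.
    const Status status = GpuMatMul::validate_op(sketch, lhs, rhs, matmul_attr, matmul_settings);
    ARM_COMPUTE_EXPECT(!bool(status), framework::LogLevel::ERRORS);
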
@@ -250,59 +242,75 @@
 TEST_SUITE_END() // Validate
 
 template <typename T>
-using DynamicFusionGpuMatmulFixture = DynamicFusionGpuMatMulValidationFixture<CLTensor, CLAccessor,GpuMatMul, T>;
+using DynamicFusionGpuMatmulFixture = DynamicFusionGpuMatMulValidationFixture<CLTensor, CLAccessor, GpuMatMul, T>;
 
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
 
-FIXTURE_DATA_TEST_CASE(RunTiny, DynamicFusionGpuMatmulFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::TinyMatMulDataset(),
-                                                                                                                   framework::dataset::make("TransposeA", { false })),
-                                                                                                                   framework::dataset::make("TransposeB", { true })),
-                                                                                                                   m0_values_precommit),
-                                                                                                                   n0_values_precommit),
-                                                                                                                   k0_values_precommit),
-                                                                                                           framework::dataset::make("ExportRhsToCLImage", { false })),
-                                                                                                   framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(
+    RunTiny,
+    DynamicFusionGpuMatmulFixture<float>,
+    framework::DatasetMode::ALL,
+    combine(combine(combine(combine(combine(combine(combine(datasets::TinyMatMulDataset(),
+                                                            framework::dataset::make("TransposeA", {false})),
+                                                    framework::dataset::make("TransposeB", {true})),
+                                            m0_values_precommit),
+                                    n0_values_precommit),
+                            k0_values_precommit),
+                    framework::dataset::make("ExportRhsToCLImage", {false})),
+            framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
 }
 
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuMatmulFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(),
-                                                                                                                   framework::dataset::make("TransposeA", { false })),
-                                                                                                                   framework::dataset::make("TransposeB", { true })),
-                                                                                                                   m0_values_precommit),
-                                                                                                                   n0_values_precommit),
-                                                                                                                   k0_values_precommit),
-                                                                                                           framework::dataset::make("ExportRhsToCLImage", { false })),
-                                                                                                   framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(
+    RunSmall,
+    DynamicFusionGpuMatmulFixture<float>,
+    framework::DatasetMode::ALL,
+    combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(),
+                                                            framework::dataset::make("TransposeA", {false})),
+                                                    framework::dataset::make("TransposeB", {true})),
+                                            m0_values_precommit),
+                                    n0_values_precommit),
+                            k0_values_precommit),
+                    framework::dataset::make("ExportRhsToCLImage", {false})),
+            framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLargeRhsTransposed, DynamicFusionGpuMatmulFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(),
-                                                                                                                     framework::dataset::make("TransposeA", { false })),
-                                                                                                                     framework::dataset::make("TransposeB", { true })),
-                                                                                                                     m0_values_nightly_lhs_nt),
-                                                                                                                     n0_values_nightly_rhs_t),
-                                                                                                                     k0_values_nightly_rhs_t),
-                                                                                                                     framework::dataset::make("ExportRhsToCLImage", { false })),
-                                                                                                                     framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(
+    RunLargeRhsTransposed,
+    DynamicFusionGpuMatmulFixture<float>,
+    framework::DatasetMode::NIGHTLY,
+    combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(),
+                                                            framework::dataset::make("TransposeA", {false})),
+                                                    framework::dataset::make("TransposeB", {true})),
+                                            m0_values_nightly_lhs_nt),
+                                    n0_values_nightly_rhs_t),
+                            k0_values_nightly_rhs_t),
+                    framework::dataset::make("ExportRhsToCLImage", {false})),
+            framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
 }
 
 // Running High Dimensional test is enough for FP32, because we're stressing the number of dimensions, not data type or M0/N0/K0
-FIXTURE_DATA_TEST_CASE(RunHighDimensional, DynamicFusionGpuMatmulFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::HighDimensionalMatMulDataset(),
-                                                                                                                      framework::dataset::make("TransposeA", { false })),
-                                                                                                                      framework::dataset::make("TransposeB", { true })),
-                                                                                                                      framework::dataset::make("M0", { 2 })),
-                                                                                                                      framework::dataset::make("N0", { 2 })),
-                                                                                                                      framework::dataset::make("K0", { 2 })),
-                                                                                                                      framework::dataset::make("ExportRhsToCLImage", { false })),
-                                                                                                              framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(
+    RunHighDimensional,
+    DynamicFusionGpuMatmulFixture<float>,
+    framework::DatasetMode::ALL,
+    combine(combine(combine(combine(combine(combine(combine(datasets::HighDimensionalMatMulDataset(),
+                                                            framework::dataset::make("TransposeA", {false})),
+                                                    framework::dataset::make("TransposeB", {true})),
+                                            framework::dataset::make("M0", {2})),
+                                    framework::dataset::make("N0", {2})),
+                            framework::dataset::make("K0", {2})),
+                    framework::dataset::make("ExportRhsToCLImage", {false})),
+            framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
@@ -311,28 +319,35 @@
 
 TEST_SUITE(FP16)
 
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuMatmulFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(),
-                                                                                                                   framework::dataset::make("TransposeA", { false })),
-                                                                                                                   framework::dataset::make("TransposeB", { true })),
-                                                                                                                   m0_values_precommit),
-                                                                                                                   n0_values_precommit),
-                                                                                                                   k0_values_precommit),
-                                                                                                           framework::dataset::make("ExportRhsToCLImage", { false })),
-                                                                                                   framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(
+    RunSmall,
+    DynamicFusionGpuMatmulFixture<half>,
+    framework::DatasetMode::ALL,
+    combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(),
+                                                            framework::dataset::make("TransposeA", {false})),
+                                                    framework::dataset::make("TransposeB", {true})),
+                                            m0_values_precommit),
+                                    n0_values_precommit),
+                            k0_values_precommit),
+                    framework::dataset::make("ExportRhsToCLImage", {false})),
+            framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
 }
 
-
-FIXTURE_DATA_TEST_CASE(RunLargeRhsTransposed, DynamicFusionGpuMatmulFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(),
-                                                                                                                     framework::dataset::make("TransposeA", { false })),
-                                                                                                                     framework::dataset::make("TransposeB", { true })),
-                                                                                                                     m0_values_nightly_lhs_nt),
-                                                                                                                     n0_values_nightly_rhs_t),
-                                                                                                                     k0_values_nightly_rhs_t),
-                                                                                                                     framework::dataset::make("ExportRhsToCLImage", { false })),
-                                                                                                                     framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(
+    RunLargeRhsTransposed,
+    DynamicFusionGpuMatmulFixture<half>,
+    framework::DatasetMode::NIGHTLY,
+    combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulDataset(),
+                                                            framework::dataset::make("TransposeA", {false})),
+                                                    framework::dataset::make("TransposeB", {true})),
+                                            m0_values_nightly_lhs_nt),
+                                    n0_values_nightly_rhs_t),
+                            k0_values_nightly_rhs_t),
+                    framework::dataset::make("ExportRhsToCLImage", {false})),
+            framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
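
The MatMul hunks above are formatting-only: each FIXTURE_DATA_TEST_CASE wraps a chain of combine() calls that builds the cartesian product of the test parameters, and clang-format now breaks one axis per line. A minimal sketch of the pattern (the variable name `dataset` is illustrative, not from the patch):

    // Each combine() pairs an existing dataset with one more named axis;
    // the innermost call starts from the shape dataset.
    const auto dataset =
        combine(combine(datasets::SmallMatMulDataset(),
                        framework::dataset::make("TransposeA", {false})),
                framework::dataset::make("DataType", DataType::F32));
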
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp
index b69479f..c11bffe 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Mul.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,14 +29,13 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuMul.h"
 
 #include "tests/CL/CLAccessor.h"
-#include "tests/framework/Fixture.h"
-#include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
-
 #include "tests/datasets/DynamicFusionDataset.h"
 #include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
 #include "tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -58,8 +57,10 @@
  */
 namespace
 {
-constexpr AbsoluteTolerance<float> tolerance_f16(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
-constexpr AbsoluteTolerance<float> tolerance_f32(0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
+constexpr AbsoluteTolerance<float> tolerance_f16(
+    0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F16 */
+constexpr AbsoluteTolerance<float> tolerance_f32(
+    0.0001f); /**< Tolerance value for comparing reference's output against implementation's output for DataType::F32 */
 } // namespace
 TEST_SUITE(CL)
 TEST_SUITE(DYNAMIC_FUSION)
@@ -112,7 +113,7 @@
     auto          lhs_info         = context.create_tensor_info(input1_info);
     auto          rhs_info         = context.create_tensor_info(input2_info);
 
-    bool res = bool(GpuMul::validate_op(sketch, &lhs_info, &rhs_info));
+    bool res = bool(GpuMul::validate_op(sketch, lhs_info, rhs_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
@@ -129,9 +130,8 @@
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionCLMulFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("DataType", { DataType::F16 })),
-                               framework::dataset::make("InPlace", { false })))
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", {DataType::F16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -141,8 +141,8 @@
                        DynamicFusionCLMulBroadcastFixture<half>,
                        framework::DatasetMode::PRECOMMIT,
                        combine(combine(datasets::TemporaryLimitedSmallShapesBroadcast(),
-                                       framework::dataset::make("DataType", { DataType::F16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -152,8 +152,8 @@
                        DynamicFusionCLMulBroadcastFixture<half>,
                        framework::DatasetMode::NIGHTLY,
                        combine(combine(datasets::TemporaryLimitedLargeShapesBroadcast(),
-                                       framework::dataset::make("DataType", { DataType::F16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -164,9 +164,8 @@
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionCLMulFixture<float>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -175,9 +174,8 @@
 FIXTURE_DATA_TEST_CASE(RunLargeOneOp,
                        DynamicFusionCLMulFixture<float>,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(datasets::LargeShapes(),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                       combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -187,8 +185,8 @@
                        DynamicFusionCLMulBroadcastFixture<float>,
                        framework::DatasetMode::PRECOMMIT,
                        combine(combine(datasets::TemporaryLimitedSmallShapesBroadcast(),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -198,8 +196,8 @@
                        DynamicFusionCLMulBroadcastFixture<float>,
                        framework::DatasetMode::NIGHTLY,
                        combine(combine(datasets::TemporaryLimitedLargeShapesBroadcast(),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -209,9 +207,9 @@
                        DynamicFusionCLMulTwoOpsFixture<float>,
                        framework::DatasetMode::PRECOMMIT,
                        combine(combine(combine(datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes(),
-                                               framework::dataset::make("DataType", { DataType::F32 })),
-                                       framework::dataset::make("InPlace", { false })),
-                               framework::dataset::make("FuseTwoOps", { true })))
+                                               framework::dataset::make("DataType", {DataType::F32})),
+                                       framework::dataset::make("InPlace", {false})),
+                               framework::dataset::make("FuseTwoOps", {true})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
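
In the Mul hunks above, the functional change (as opposed to the clang-format reflow) is that validate_op() now receives the ITensorInfo pointers returned by the context directly, instead of the address of caller-owned TensorInfo copies. A minimal usage sketch, assuming illustrative shapes:

    auto               cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
    GpuWorkloadContext context{&cl_compile_ctx};
    GpuWorkloadSketch  sketch{&context};

    // create_tensor_info() returns a pointer to an object owned by `context`;
    // the caller must not delete it and must keep `context` alive while using it.
    ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(TensorShape(32U, 13U), 1, DataType::F32));
    ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(TensorShape(32U, 13U), 1, DataType::F32));

    // Pointers are forwarded as-is; no address-of operator as in the old API.
    const bool res = bool(GpuMul::validate_op(sketch, lhs_info, rhs_info));
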
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
index 411e31b..f894ce3 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Pool2d.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -25,13 +25,13 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuPool2d.h"
 
 #include "tests/CL/CLAccessor.h"
-#include "tests/datasets/ShapeDatasets.h"
 #include "tests/datasets/dynamic_fusion/PoolingLayerDataset.h"
+#include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/datasets/Datasets.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
 #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -43,15 +43,19 @@
 TEST_SUITE(DYNAMIC_FUSION)
 TEST_SUITE(POOL2D)
 
-constexpr AbsoluteTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for 32-bit floating-point type */
-constexpr AbsoluteTolerance<float> tolerance_f16(0.01f);  /**< Tolerance value for comparing reference's output against implementation's output for 16-bit floating-point type */
+constexpr AbsoluteTolerance<float> tolerance_f32(
+    0.001f); /**< Tolerance value for comparing reference's output against implementation's output for 32-bit floating-point type */
+constexpr AbsoluteTolerance<float> tolerance_f16(
+    0.01f); /**< Tolerance value for comparing reference's output against implementation's output for 16-bit floating-point type */
 
-const auto PoolingLayerDatasetFP = combine(combine(combine(combine(framework::dataset::make("PoolingType", { PoolingType::MAX, PoolingType::AVG }), framework::dataset::make("PoolingSize", { Size2D(2, 2), Size2D(3, 3) })),
-                                                           framework::dataset::make("Pad", { Padding2D() })),
-                                                   framework::dataset::make("Stride", { Size2D(1, 1), Size2D(2, 1), Size2D(5, 7) })),
-                                           framework::dataset::make("ExcludePadding", { true }));
+const auto PoolingLayerDatasetFP =
+    combine(combine(combine(combine(framework::dataset::make("PoolingType", {PoolingType::MAX, PoolingType::AVG}),
+                                    framework::dataset::make("PoolingSize", {Size2D(2, 2), Size2D(3, 3)})),
+                            framework::dataset::make("Pad", {Padding2D()})),
+                    framework::dataset::make("Stride", {Size2D(1, 1), Size2D(2, 1), Size2D(5, 7)})),
+            framework::dataset::make("ExcludePadding", {true}));
 
-const auto pool_fp_mixed_precision_dataset = framework::dataset::make("FpMixedPrecision", { true, false });
+const auto pool_fp_mixed_precision_dataset = framework::dataset::make("FpMixedPrecision", {true, false});
 
 template <typename T>
 using DynamicFusionGpuPool2dFixture = DynamicFusionGpuPool2dValidationFixture<CLTensor, CLAccessor, GpuPool2d, T>;
@@ -60,7 +64,8 @@
 using DFSpecialGpuPool2dFixture = DynamicFusionGpuPool2dSpecialValidationFixture<CLTensor, CLAccessor, GpuPool2d, T>;
 
 template <typename T>
-using DFPoolMixedPrecisionFixture = DynamicFusionGpuPool2dMixedPrecisionValidationFixture<CLTensor, CLAccessor, GpuPool2d, T>;
+using DFPoolMixedPrecisionFixture =
+    DynamicFusionGpuPool2dMixedPrecisionValidationFixture<CLTensor, CLAccessor, GpuPool2d, T>;
 // *INDENT-OFF*
 // clang-format off
 
@@ -91,7 +96,7 @@
 
     // Validate Pool2d Configuration
     auto                   src_info    = context.create_tensor_info(input_info);
-    bool                   res         = bool(GpuPool2d::validate_op(sketch, &src_info, pool2d_attr, settings));
+    bool                   res         = bool(GpuPool2d::validate_op(sketch, src_info, pool2d_attr, settings));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
 
@@ -100,53 +105,68 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuPool2dFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP),
-                                                                                                                  framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuPool2dFixture<float>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP),
+                               framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuPool2dFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP),
-                                                                                                                framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DynamicFusionGpuPool2dFixture<float>,
+                       framework::DatasetMode::NIGHTLY,
+                       combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP),
+                               framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunSpecial, DFSpecialGpuPool2dFixture<float>, framework::DatasetMode::ALL, combine(datasets::PoolingLayerDatasetSpecialDynamicFusion(),
-                                                                                                          framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSpecial,
+                       DFSpecialGpuPool2dFixture<float>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::PoolingLayerDatasetSpecialDynamicFusion(),
+                               framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 
 TEST_SUITE(GlobalPooling)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuPool2dFixture<float>, framework::DatasetMode::ALL,
-                       combine(combine(combine(combine(combine(combine(
-                                                                   framework::dataset::make("InputShape", { TensorShape(27U, 13U, 2U),
-                                                                                                            TensorShape(27U, 13U, 2U, 4U)
-                                                                                                          }),
-                                                                   framework::dataset::make("PoolingType", { PoolingType::AVG, PoolingType::MAX })),
-                                                               framework::dataset::make("PoolingSize", { Size2D(27, 13) })),
-                                                       framework::dataset::make("Pad", { Padding2D() })),
-                                               framework::dataset::make("Stride", { Size2D(1, 1) })),
-                                       framework::dataset::make("ExcludePadding", true)),
-                               framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(
+    RunSmall,
+    DynamicFusionGpuPool2dFixture<float>,
+    framework::DatasetMode::ALL,
+    combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape",
+                                                                             {TensorShape(27U, 13U, 2U),
+                                                                              TensorShape(27U, 13U, 2U, 4U)}),
+                                                    framework::dataset::make("PoolingType",
+                                                                             {PoolingType::AVG, PoolingType::MAX})),
+                                            framework::dataset::make("PoolingSize", {Size2D(27, 13)})),
+                                    framework::dataset::make("Pad", {Padding2D()})),
+                            framework::dataset::make("Stride", {Size2D(1, 1)})),
+                    framework::dataset::make("ExcludePadding", true)),
+            framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuPool2dFixture<float>, framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(combine(combine(combine(
-                                                                   framework::dataset::make("InputShape", { TensorShape(79U, 37U, 11U),
-                                                                                                            TensorShape(79U, 37U, 11U, 4U)
-                                                                                                          }),
-                                                                   framework::dataset::make("PoolingType", { PoolingType::AVG, PoolingType::MAX })),
-                                                               framework::dataset::make("PoolingSize", { Size2D(79, 37) })),
-                                                       framework::dataset::make("Pad", { Padding2D() })),
-                                               framework::dataset::make("Stride", { Size2D(1, 1) })),
-                                       framework::dataset::make("ExcludePadding", true)),
-                               framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(
+    RunLarge,
+    DynamicFusionGpuPool2dFixture<float>,
+    framework::DatasetMode::NIGHTLY,
+    combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape",
+                                                                             {TensorShape(79U, 37U, 11U),
+                                                                              TensorShape(79U, 37U, 11U, 4U)}),
+                                                    framework::dataset::make("PoolingType",
+                                                                             {PoolingType::AVG, PoolingType::MAX})),
+                                            framework::dataset::make("PoolingSize", {Size2D(79, 37)})),
+                                    framework::dataset::make("Pad", {Padding2D()})),
+                            framework::dataset::make("Stride", {Size2D(1, 1)})),
+                    framework::dataset::make("ExcludePadding", true)),
+            framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
@@ -155,49 +175,61 @@
 TEST_SUITE_END() // FP32
 
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, DFPoolMixedPrecisionFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP),
-                                                                                                                       framework::dataset::make("DataType", DataType::F16)),
-                                                                                                               pool_fp_mixed_precision_dataset))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DFPoolMixedPrecisionFixture<half>,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(datasets::SmallNoneUnitShapes(), PoolingLayerDatasetFP),
+                                       framework::dataset::make("DataType", DataType::F16)),
+                               pool_fp_mixed_precision_dataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, DFPoolMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP),
-                                                                                                                     framework::dataset::make("DataType", DataType::F16)),
-                                                                                                             pool_fp_mixed_precision_dataset))
+FIXTURE_DATA_TEST_CASE(RunLarge,
+                       DFPoolMixedPrecisionFixture<half>,
+                       framework::DatasetMode::NIGHTLY,
+                       combine(combine(combine(datasets::LargeShapes(), PoolingLayerDatasetFP),
+                                       framework::dataset::make("DataType", DataType::F16)),
+                               pool_fp_mixed_precision_dataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
 
 TEST_SUITE(GlobalPooling)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuPool2dFixture<half>, framework::DatasetMode::ALL,
-                       combine(combine(combine(combine(combine(combine(
-                                                                   framework::dataset::make("InputShape", { TensorShape(27U, 13U, 2U),
-                                                                                                            TensorShape(27U, 13U, 2U, 4U)
-                                                                                                          }),
-                                                                   framework::dataset::make("PoolingType", { PoolingType::AVG, PoolingType::MAX })),
-                                                               framework::dataset::make("PoolingSize", { Size2D(27, 13) })),
-                                                       framework::dataset::make("Pad", { Padding2D() })),
-                                               framework::dataset::make("Stride", { Size2D(1, 1) })),
-                                       framework::dataset::make("ExcludePadding", true)),
-                               framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(
+    RunSmall,
+    DynamicFusionGpuPool2dFixture<half>,
+    framework::DatasetMode::ALL,
+    combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape",
+                                                                             {TensorShape(27U, 13U, 2U),
+                                                                              TensorShape(27U, 13U, 2U, 4U)}),
+                                                    framework::dataset::make("PoolingType",
+                                                                             {PoolingType::AVG, PoolingType::MAX})),
+                                            framework::dataset::make("PoolingSize", {Size2D(27, 13)})),
+                                    framework::dataset::make("Pad", {Padding2D()})),
+                            framework::dataset::make("Stride", {Size2D(1, 1)})),
+                    framework::dataset::make("ExcludePadding", true)),
+            framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
 
-FIXTURE_DATA_TEST_CASE(RunLarge, DynamicFusionGpuPool2dFixture<half>, framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(combine(combine(combine(
-                                                                   framework::dataset::make("InputShape", { TensorShape(79U, 37U, 11U),
-                                                                                                            TensorShape(79U, 37U, 11U, 4U)
-                                                                                                          }),
-                                                                   framework::dataset::make("PoolingType", { PoolingType::AVG, PoolingType::MAX })),
-                                                               framework::dataset::make("PoolingSize", { Size2D(79, 37) })),
-                                                       framework::dataset::make("Pad", { Padding2D() })),
-                                               framework::dataset::make("Stride", { Size2D(1, 1) })),
-                                       framework::dataset::make("ExcludePadding", true)),
-                               framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(
+    RunLarge,
+    DynamicFusionGpuPool2dFixture<half>,
+    framework::DatasetMode::NIGHTLY,
+    combine(combine(combine(combine(combine(combine(framework::dataset::make("InputShape",
+                                                                             {TensorShape(79U, 37U, 11U),
+                                                                              TensorShape(79U, 37U, 11U, 4U)}),
+                                                    framework::dataset::make("PoolingType",
+                                                                             {PoolingType::AVG, PoolingType::MAX})),
+                                            framework::dataset::make("PoolingSize", {Size2D(79, 37)})),
+                                    framework::dataset::make("Pad", {Padding2D()})),
+                            framework::dataset::make("Stride", {Size2D(1, 1)})),
+                    framework::dataset::make("ExcludePadding", true)),
+            framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
@@ -209,7 +241,7 @@
 TEST_SUITE_END() // POOL2D
 TEST_SUITE_END() // DYNAMIC_FUSION
 TEST_SUITE_END() // CL
-}
-}
-}
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
 #endif // ACL_INTERNAL_TEST_CKW_IN_DF
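
Note that in the Pool2d validate test above, `auto src_info = context.create_tensor_info(input_info);` now deduces ITensorInfo * rather than TensorInfo, which is why the address-of operator is dropped from the validate_op() call; keeping `&src_info` would yield an ITensorInfo ** and fail to compile. A before/after sketch of the migration:

    // Old API (caller-owned value):
    //   TensorInfo src_info = context.create_tensor_info(input_info);
    //   GpuPool2d::validate_op(sketch, &src_info, pool2d_attr, settings);
    // New API (context-owned pointer):
    auto src_info = context.create_tensor_info(input_info); // deduces ITensorInfo *
    bool res      = bool(GpuPool2d::validate_op(sketch, src_info, pool2d_attr, settings));
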
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
index 4d038b2..43617fe 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Reshape.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,10 +24,10 @@
 #ifndef ACL_INTERNAL_TEST_CKW_IN_DF // Do not include this test if ACL_INTERNAL_TEST_CKW_IN_DF and the op has not been ported to ckw
 #include "tests/CL/CLAccessor.h"
 #include "tests/datasets/ReshapeLayerDataset.h"
-#include "tests/framework/Macros.h"
 #include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
+#include "tests/framework/Macros.h"
 #include "tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -39,41 +39,52 @@
 TEST_SUITE(DYNAMIC_FUSION)
 TEST_SUITE(RESHAPE)
 
-DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(framework::dataset::make("InputInfo",
-{
-    TensorInfo(TensorShape(9U, 5U, 7U, 3U), 1, DataType::F32), TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32), TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32) /*mismatching dimensions*/,
-}),
-framework::dataset::make("OutputShape",
-{
-    TensorShape(9U, 5U, 21U),
-    TensorShape(8U, 24U, 4U),
-    TensorShape(192U, 192U),
-})),
-framework::dataset::make("Expected", { true, true, false })),
-input_info, output_shape, expected)
+DATA_TEST_CASE(Validate,
+               framework::DatasetMode::ALL,
+               zip(zip(framework::dataset::make(
+                           "InputInfo",
+                           {
+                               TensorInfo(TensorShape(9U, 5U, 7U, 3U), 1, DataType::F32),
+                               TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32),
+                               TensorInfo(TensorShape(8U, 4U, 6U, 4U), 1, DataType::F32) /*mismatching dimensions*/,
+                           }),
+                       framework::dataset::make("OutputShape",
+                                                {
+                                                    TensorShape(9U, 5U, 21U),
+                                                    TensorShape(8U, 24U, 4U),
+                                                    TensorShape(192U, 192U),
+                                                })),
+                   framework::dataset::make("Expected", {true, true, false})),
+               input_info,
+               output_shape,
+               expected)
 {
     // Create a new workload sketch
     auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch sketch{ &context };
+    auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch sketch{&context};
 
     // Create sketch tensors
     TensorShape input_shape = input_info.tensor_shape();
     ARM_COMPUTE_UNUSED(input_shape);
-    TensorInfo src_info = context.create_tensor_info(input_info);
+    ITensorInfo *src_info = context.create_tensor_info(input_info);
 
     ReshapeAttributes attributes;
     attributes.shape(output_shape);
-    Status status = GpuReshape::validate_op(sketch, &src_info, attributes);
+    Status status = GpuReshape::validate_op(sketch, src_info, attributes);
     ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
 }
 
 template <typename T>
-using DynamicFusionGpuReshapeLayerFixture = DynamicFusionGpuReshapeLayerValidationFixture<CLTensor, CLAccessor, GpuReshape, T>;
+using DynamicFusionGpuReshapeLayerFixture =
+    DynamicFusionGpuReshapeLayerValidationFixture<CLTensor, CLAccessor, GpuReshape, T>;
 
 TEST_SUITE(F32)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture<float>, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType",
-                                                                                                                  DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuReshapeLayerFixture<float>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallReshapeLayerDataset(),
+                               framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -81,8 +92,11 @@
 TEST_SUITE_END() // F32
 
 TEST_SUITE(F16)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture<half>, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType",
-                                                                                                                 DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuReshapeLayerFixture<half>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallReshapeLayerDataset(),
+                               framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -90,8 +104,11 @@
 TEST_SUITE_END() // F16
 
 TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture<uint8_t>, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType",
-                                                                                                                    DataType::U8)))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuReshapeLayerFixture<uint8_t>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallReshapeLayerDataset(),
+                               framework::dataset::make("DataType", DataType::U8)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -99,8 +116,11 @@
 TEST_SUITE_END() // U8
 
 TEST_SUITE(S8)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture<int8_t>, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType",
-                                                                                                                   DataType::S8)))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuReshapeLayerFixture<int8_t>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallReshapeLayerDataset(),
+                               framework::dataset::make("DataType", DataType::S8)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -108,8 +128,11 @@
 TEST_SUITE_END() // S8
 
 TEST_SUITE(S16)
-FIXTURE_DATA_TEST_CASE(RunSmall, DynamicFusionGpuReshapeLayerFixture<int16_t>, framework::DatasetMode::ALL, combine(datasets::SmallReshapeLayerDataset(), framework::dataset::make("DataType",
-                                                                                                                    DataType::S16)))
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       DynamicFusionGpuReshapeLayerFixture<int16_t>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallReshapeLayerDataset(),
+                               framework::dataset::make("DataType", DataType::S16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
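
The Reshape (and, below, Resize) tests spell the new return type out explicitly (ITensorInfo * and const ITensorInfo *), which makes the ownership contract from the commit message visible at the call site: the context owns every tensor info it creates, so the pointer must not outlive the context. A minimal sketch, assuming an `input_info` and `output_shape` like those in the test above:

    ITensorInfo *src_info = context.create_tensor_info(input_info); // owned by `context`

    ReshapeAttributes attributes;
    attributes.shape(output_shape);

    // Pass the context-owned pointer straight through; do not delete it,
    // and do not use it after `context` has been destroyed.
    Status status = GpuReshape::validate_op(sketch, src_info, attributes);
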
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp
index 5f99cd6..10915ac 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Resize.cpp
@@ -1,5 +1,5 @@
 /*
-* Copyright (c) 2022-2023 Arm Limited.
+* Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,8 +29,8 @@
 #include "tests/framework/Asserts.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-#include "tests/validation/Validation.h"
 #include "tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h"
+#include "tests/validation/Validation.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 namespace arm_compute
@@ -41,10 +41,10 @@
 {
 namespace
 {
-using datasets::ScaleShapesBaseDataSet;
+using datasets::ScaleAlignCornersSamplingPolicySet;
 using datasets::ScaleInterpolationPolicySet;
 using datasets::ScaleSamplingPolicySet;
-using datasets::ScaleAlignCornersSamplingPolicySet;
+using datasets::ScaleShapesBaseDataSet;
 
 /** We consider a vector size of 16 bytes since the maximum size of
  * a vector used by @ref CLScaleKernel is currently 16 bytes (float4).
@@ -59,9 +59,9 @@
 
 /** Quantization information data set */
 const auto QuantizationInfoSet = framework::dataset::make("QuantizationInfo",
-{
-    QuantizationInfo(0.5f, -1),
-});
+                                                          {
+                                                              QuantizationInfo(0.5f, -1),
+                                                          });
 
 /** Tolerance */
 constexpr AbsoluteTolerance<uint8_t> tolerance_q8(1);
@@ -83,22 +83,20 @@
 
 TEST_SUITE(Validate)
 
-const auto default_input_shape  = TensorShape{ 2, 3, 3, 2 };
-const auto default_output_shape = TensorShape{ 4, 6, 3, 2 };
+const auto default_input_shape  = TensorShape{2, 3, 3, 2};
+const auto default_output_shape = TensorShape{4, 6, 3, 2};
 
 constexpr auto default_data_type   = DataType::U8;
 constexpr auto default_data_layout = DataLayout::NHWC;
 
 TEST_CASE(NullPtr, framework::DatasetMode::ALL)
 {
-    const TensorInfo input_info  = TensorInfo{ default_input_shape, 1, default_data_type, default_data_layout };
-    const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, default_data_layout };
+    const TensorInfo input_info  = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout};
+    const TensorInfo output_info = TensorInfo{default_output_shape, 1, default_data_type, default_data_layout};
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &context };
-
-    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
+    GpuWorkloadContext context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch  sketch{&context};
 
     // nullptr is given as input
     Status status = GpuResize::validate_op(sketch, nullptr, ResizeAttributes());
@@ -107,44 +105,43 @@
 
 TEST_CASE(SupportDataType, framework::DatasetMode::ALL)
 {
-    const std::map<DataType, bool> supported_data_types =
-    {
-        { DataType::U8, true },
-        { DataType::S8, false },
-        { DataType::QSYMM8, false },
-        { DataType::QASYMM8, true },
-        { DataType::QASYMM8_SIGNED, true },
-        { DataType::QSYMM8_PER_CHANNEL, false },
-        { DataType::U16, false },
-        { DataType::S16, true },
-        { DataType::QSYMM16, false },
-        { DataType::QASYMM16, false },
-        { DataType::U32, false },
-        { DataType::S32, false },
-        { DataType::U64, false },
-        { DataType::S64, false },
-        { DataType::BFLOAT16, false },
-        { DataType::F16, true },
-        { DataType::F32, true },
-        { DataType::F64, false },
-        { DataType::SIZET, false },
+    const std::map<DataType, bool> supported_data_types = {
+        {DataType::U8, true},
+        {DataType::S8, false},
+        {DataType::QSYMM8, false},
+        {DataType::QASYMM8, true},
+        {DataType::QASYMM8_SIGNED, true},
+        {DataType::QSYMM8_PER_CHANNEL, false},
+        {DataType::U16, false},
+        {DataType::S16, true},
+        {DataType::QSYMM16, false},
+        {DataType::QASYMM16, false},
+        {DataType::U32, false},
+        {DataType::S32, false},
+        {DataType::U64, false},
+        {DataType::S64, false},
+        {DataType::BFLOAT16, false},
+        {DataType::F16, true},
+        {DataType::F32, true},
+        {DataType::F64, false},
+        {DataType::SIZET, false},
     };
 
-    for(auto &kv : supported_data_types)
+    for (auto &kv : supported_data_types)
     {
-        const TensorInfo input_info = TensorInfo{ default_input_shape, 1, kv.first, default_data_layout };
+        const TensorInfo input_info = TensorInfo{default_input_shape, 1, kv.first, default_data_layout};
 
         CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch  sketch{ &context };
+        GpuWorkloadContext context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch  sketch{&context};
 
-        const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
+        const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
 
         ResizeAttributes attributes;
         attributes.output_width(default_output_shape[0]); // shape is not important unless it's empty
         attributes.output_height(default_output_shape[1]);
 
-        Status status = GpuResize::validate_op(sketch, &sketch_input_info, attributes);
+        Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes);
         ARM_COMPUTE_EXPECT(bool(status) == kv.second, framework::LogLevel::ERRORS);
     }
 }
@@ -153,16 +150,16 @@
 {
     constexpr DataType non_default_data_type = DataType::F32;
 
-    const TensorInfo input_info  = TensorInfo{ default_input_shape, 1, default_data_type, default_data_layout };
-    const TensorInfo output_info = TensorInfo{ default_output_shape, 1, non_default_data_type, default_data_layout };
+    const TensorInfo input_info  = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout};
+    const TensorInfo output_info = TensorInfo{default_output_shape, 1, non_default_data_type, default_data_layout};
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &context };
+    GpuWorkloadContext context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch  sketch{&context};
 
-    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
+    const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
 
-    Status status = GpuResize::validate_op(sketch, &sketch_input_info, ResizeAttributes());
+    Status status = GpuResize::validate_op(sketch, sketch_input_info, ResizeAttributes());
     ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
 }
 
@@ -173,59 +170,57 @@
     constexpr bool                align_corners        = true;
     constexpr SamplingPolicy      sampling_policy      = SamplingPolicy::CENTER;
 
-    const TensorInfo input_info  = TensorInfo{ default_input_shape, 1, default_data_type, default_data_layout };
-    const TensorInfo output_info = TensorInfo{ default_output_shape, 1, default_data_type, default_data_layout };
+    const TensorInfo input_info  = TensorInfo{default_input_shape, 1, default_data_type, default_data_layout};
+    const TensorInfo output_info = TensorInfo{default_output_shape, 1, default_data_type, default_data_layout};
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &context };
+    GpuWorkloadContext context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch  sketch{&context};
 
-    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
+    const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
 
     ResizeAttributes attributes{};
-    attributes.interpolation_policy(interpolation_policy)
-    .sampling_policy(sampling_policy)
-    .align_corners(align_corners);
+    attributes.interpolation_policy(interpolation_policy).sampling_policy(sampling_policy).align_corners(align_corners);
 
-    Status status = GpuResize::validate_op(sketch, &sketch_input_info, attributes);
+    Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes);
     ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
 }
 
 TEST_CASE(UnsupportedInterpolationPolicy, framework::DatasetMode::ALL)
 {
-    const TensorInfo input_info           = TensorInfo{ TensorShape(28U, 33U, 2U), 1, DataType::F32, default_data_layout };
-    const TensorInfo output_info          = TensorInfo{ TensorShape(26U, 21U, 2U), 1, DataType::F32, default_data_layout };
+    const TensorInfo input_info  = TensorInfo{TensorShape(28U, 33U, 2U), 1, DataType::F32, default_data_layout};
+    const TensorInfo output_info = TensorInfo{TensorShape(26U, 21U, 2U), 1, DataType::F32, default_data_layout};
     constexpr auto   interpolation_policy = InterpolationPolicy::AREA;
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &context };
+    GpuWorkloadContext context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch  sketch{&context};
 
-    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
+    const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
 
     ResizeAttributes attributes{};
     attributes.interpolation_policy(interpolation_policy);
 
-    Status status = GpuResize::validate_op(sketch, &sketch_input_info, attributes);
+    Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes);
     ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
 }
 
 TEST_CASE(UnsupportedLayout, framework::DatasetMode::ALL)
 {
-    const TensorInfo input_info           = TensorInfo{ default_input_shape, 1, default_data_type, DataLayout::NCHW };
-    const TensorInfo output_info          = TensorInfo{ default_output_shape, 1, default_data_type, DataLayout::NCHW };
+    const TensorInfo input_info           = TensorInfo{default_input_shape, 1, default_data_type, DataLayout::NCHW};
+    const TensorInfo output_info          = TensorInfo{default_output_shape, 1, default_data_type, DataLayout::NCHW};
     constexpr auto   interpolation_policy = InterpolationPolicy::BILINEAR;
 
     CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-    GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
-    GpuWorkloadSketch  sketch{ &context };
+    GpuWorkloadContext context        = GpuWorkloadContext{&cl_compile_ctx};
+    GpuWorkloadSketch  sketch{&context};
 
-    const TensorInfo sketch_input_info = context.create_tensor_info(input_info);
+    const ITensorInfo *sketch_input_info = context.create_tensor_info(input_info);
 
     ResizeAttributes attributes{};
     attributes.interpolation_policy(interpolation_policy);
 
-    Status status = GpuResize::validate_op(sketch, &sketch_input_info, attributes);
+    Status status = GpuResize::validate_op(sketch, sketch_input_info, attributes);
     ARM_COMPUTE_EXPECT(bool(status) == false, framework::LogLevel::ERRORS);
 }
 
@@ -237,43 +232,60 @@
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
 
-const auto f32_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<float>())), framework::dataset::make("DataType", DataType::F32));
+const auto f32_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<float>())),
+                               framework::dataset::make("DataType", DataType::F32));
 
-FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeFixture<float>, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(Run,
+                       DynamicFusionResizeFixture<float>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
 }
 
-FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeFixture<float>, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleAlignCornersSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(RunAlignCorners,
+                       DynamicFusionResizeFixture<float>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_shape, ScaleAlignCornersSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
 }
-const auto f32_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<float>())), framework::dataset::make("DataType", DataType::F32));
-FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeFixture<float>, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, ScaleSamplingPolicySet))
+const auto f32_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<float>())),
+                                       framework::dataset::make("DataType", DataType::F32));
+FIXTURE_DATA_TEST_CASE(RunNightly,
+                       DynamicFusionResizeFixture<float>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, ScaleSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
 }
-FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture<float>, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape,
-                       ScaleAlignCornersSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners,
+                       DynamicFusionResizeFixture<float>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(f32_nightly_shape, ScaleAlignCornersSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f32, tolerance_num_f32, tolerance_f32_absolute);
@@ -281,41 +293,58 @@
 TEST_SUITE_END() // FP32
 
 TEST_SUITE(FP16)
-const auto f16_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<half>())), framework::dataset::make("DataType", DataType::F16));
-FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeFixture<half>, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleSamplingPolicySet))
+const auto f16_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<half>())),
+                               framework::dataset::make("DataType", DataType::F16));
+FIXTURE_DATA_TEST_CASE(Run,
+                       DynamicFusionResizeFixture<half>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeFixture<half>, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleAlignCornersSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(RunAlignCorners,
+                       DynamicFusionResizeFixture<half>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_shape, ScaleAlignCornersSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16);
 }
-const auto f16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<half>())), framework::dataset::make("DataType", DataType::F16));
-FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeFixture<half>, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, ScaleSamplingPolicySet))
+const auto f16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<half>())),
+                                       framework::dataset::make("DataType", DataType::F16));
+FIXTURE_DATA_TEST_CASE(RunNightly,
+                       DynamicFusionResizeFixture<half>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, ScaleSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture<half>, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape,
-                       ScaleAlignCornersSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners,
+                       DynamicFusionResizeFixture<half>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(f16_nightly_shape, ScaleAlignCornersSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_f16, 0.0f, abs_tolerance_f16);
@@ -325,41 +354,58 @@
 
 TEST_SUITE(Integer)
 TEST_SUITE(U8)
-const auto u8_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<uint8_t>())), framework::dataset::make("DataType", DataType::U8));
-FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeFixture<uint8_t>, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_shape, ScaleSamplingPolicySet))
+const auto u8_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<uint8_t>())),
+                              framework::dataset::make("DataType", DataType::U8));
+FIXTURE_DATA_TEST_CASE(Run,
+                       DynamicFusionResizeFixture<uint8_t>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_shape, ScaleSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_q8);
 }
-FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeFixture<uint8_t>, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_shape, ScaleAlignCornersSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(RunAlignCorners,
+                       DynamicFusionResizeFixture<uint8_t>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_shape, ScaleAlignCornersSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_q8);
 }
-const auto u8_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<uint8_t>())), framework::dataset::make("DataType", DataType::U8));
-FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeFixture<uint8_t>, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_nightly_shape, ScaleSamplingPolicySet))
+const auto u8_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<uint8_t>())),
+                                      framework::dataset::make("DataType", DataType::U8));
+FIXTURE_DATA_TEST_CASE(RunNightly,
+                       DynamicFusionResizeFixture<uint8_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_nightly_shape, ScaleSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_q8);
 }
-FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture<uint8_t>, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_nightly_shape,
-                       ScaleAlignCornersSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners,
+                       DynamicFusionResizeFixture<uint8_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(u8_nightly_shape, ScaleAlignCornersSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_q8);
@@ -367,41 +413,58 @@
 TEST_SUITE_END() // U8
 
 TEST_SUITE(S16)
-const auto s16_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<int16_t>())), framework::dataset::make("DataType", DataType::S16));
-FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeFixture<int16_t>, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_shape, ScaleSamplingPolicySet))
+const auto s16_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<int16_t>())),
+                               framework::dataset::make("DataType", DataType::S16));
+FIXTURE_DATA_TEST_CASE(Run,
+                       DynamicFusionResizeFixture<int16_t>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_shape, ScaleSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_s16);
 }
-FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeFixture<int16_t>, framework::DatasetMode::ALL, ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_shape, ScaleAlignCornersSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(RunAlignCorners,
+                       DynamicFusionResizeFixture<int16_t>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_shape, ScaleAlignCornersSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_s16);
 }
-const auto s16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<int16_t>())), framework::dataset::make("DataType", DataType::S16));
-FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeFixture<int16_t>, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_nightly_shape, ScaleSamplingPolicySet))
+const auto s16_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<int16_t>())),
+                                       framework::dataset::make("DataType", DataType::S16));
+FIXTURE_DATA_TEST_CASE(RunNightly,
+                       DynamicFusionResizeFixture<int16_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_nightly_shape, ScaleSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_s16);
 }
-FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeFixture<int16_t>, framework::DatasetMode::NIGHTLY, ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_nightly_shape,
-                       ScaleAlignCornersSamplingPolicySet))
+FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners,
+                       DynamicFusionResizeFixture<int16_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_DATASET_DYNAMIC_FUSION(s16_nightly_shape, ScaleAlignCornersSamplingPolicySet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_s16);
@@ -410,50 +473,70 @@
 TEST_SUITE_END() // Integer
 
 template <typename T>
-using DynamicFusionResizeQuantizedFixture = DynamicFusionResizeQuantizedValidationFixture<CLTensor, CLAccessor, GpuResize, T>;
+using DynamicFusionResizeQuantizedFixture =
+    DynamicFusionResizeQuantizedValidationFixture<CLTensor, CLAccessor, GpuResize, T>;
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8)
-const auto qasymm8_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<uint8_t>())), framework::dataset::make("DataType", DataType::QASYMM8));
-FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_shape, ScaleSamplingPolicySet,
-                       QuantizationInfoSet))
+const auto qasymm8_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<uint8_t>())),
+                                   framework::dataset::make("DataType", DataType::QASYMM8));
+FIXTURE_DATA_TEST_CASE(Run,
+                       DynamicFusionResizeQuantizedFixture<uint8_t>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_shape,
+                                                                 ScaleSamplingPolicySet,
+                                                                 QuantizationInfoSet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_q8);
 }
-FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_shape,
-                       ScaleAlignCornersSamplingPolicySet,
-                       QuantizationInfoSet))
+FIXTURE_DATA_TEST_CASE(RunAlignCorners,
+                       DynamicFusionResizeQuantizedFixture<uint8_t>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_shape,
+                                                                 ScaleAlignCornersSamplingPolicySet,
+                                                                 QuantizationInfoSet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_q8);
 }
-const auto qasymm8_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<uint8_t>())), framework::dataset::make("DataType", DataType::QASYMM8));
-FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_nightly_shape,
-                       ScaleSamplingPolicySet,
-                       QuantizationInfoSet))
+const auto qasymm8_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<uint8_t>())),
+                                           framework::dataset::make("DataType", DataType::QASYMM8));
+FIXTURE_DATA_TEST_CASE(RunNightly,
+                       DynamicFusionResizeQuantizedFixture<uint8_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_nightly_shape,
+                                                                 ScaleSamplingPolicySet,
+                                                                 QuantizationInfoSet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_q8);
 }
-FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_nightly_shape,
-                       ScaleAlignCornersSamplingPolicySet,
-                       QuantizationInfoSet))
+FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners,
+                       DynamicFusionResizeQuantizedFixture<uint8_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_nightly_shape,
+                                                                 ScaleAlignCornersSamplingPolicySet,
+                                                                 QuantizationInfoSet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_q8);
@@ -461,47 +544,66 @@
 TEST_SUITE_END() // QASYMM8
 
 TEST_SUITE(QASYMM8_SIGNED)
-const auto qasymm8_signed_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<int8_t>())), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED));
-FIXTURE_DATA_TEST_CASE(Run, DynamicFusionResizeQuantizedFixture<int8_t>, framework::DatasetMode::ALL, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_shape, ScaleSamplingPolicySet,
-                       QuantizationInfoSet))
+const auto qasymm8_signed_shape = combine((SCALE_PRECOMMIT_SHAPE_DATASET(num_elements_per_vector<int8_t>())),
+                                          framework::dataset::make("DataType", DataType::QASYMM8_SIGNED));
+FIXTURE_DATA_TEST_CASE(Run,
+                       DynamicFusionResizeQuantizedFixture<int8_t>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_shape,
+                                                                 ScaleSamplingPolicySet,
+                                                                 QuantizationInfoSet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_qs8);
 }
-FIXTURE_DATA_TEST_CASE(RunAlignCorners, DynamicFusionResizeQuantizedFixture<int8_t>, framework::DatasetMode::ALL, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_shape,
-                       ScaleAlignCornersSamplingPolicySet,
-                       QuantizationInfoSet))
+FIXTURE_DATA_TEST_CASE(RunAlignCorners,
+                       DynamicFusionResizeQuantizedFixture<int8_t>,
+                       framework::DatasetMode::ALL,
+                       ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_shape,
+                                                                 ScaleAlignCornersSamplingPolicySet,
+                                                                 QuantizationInfoSet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_qs8);
 }
-const auto qasymm8_signed_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<int8_t>())), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED));
-FIXTURE_DATA_TEST_CASE(RunNightly, DynamicFusionResizeQuantizedFixture<int8_t>, framework::DatasetMode::NIGHTLY, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_nightly_shape,
-                       ScaleSamplingPolicySet,
-                       QuantizationInfoSet))
+const auto qasymm8_signed_nightly_shape = combine((SCALE_NIGHTLY_SHAPE_DATASET(num_elements_per_vector<int8_t>())),
+                                                  framework::dataset::make("DataType", DataType::QASYMM8_SIGNED));
+FIXTURE_DATA_TEST_CASE(RunNightly,
+                       DynamicFusionResizeQuantizedFixture<int8_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_nightly_shape,
+                                                                 ScaleSamplingPolicySet,
+                                                                 QuantizationInfoSet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_qs8);
 }
-FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners, DynamicFusionResizeQuantizedFixture<int8_t>, framework::DatasetMode::NIGHTLY, ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_nightly_shape,
-                       ScaleAlignCornersSamplingPolicySet,
-                       QuantizationInfoSet))
+FIXTURE_DATA_TEST_CASE(RunNightlyAlignCorners,
+                       DynamicFusionResizeQuantizedFixture<int8_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       ASSEMBLE_QUANTIZED_DATASET_DYNAMIC_FUSION(qasymm8_signed_nightly_shape,
+                                                                 ScaleAlignCornersSamplingPolicySet,
+                                                                 QuantizationInfoSet))
 {
     //Create valid region
     TensorInfo        src_info(_shape, 1, _data_type);
-    const ValidRegion valid_region = calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
+    const ValidRegion valid_region =
+        calculate_valid_region_scale(src_info, _reference.shape(), _interpolation_policy, _sampling_policy, false);
 
     // Validate output
     validate(CLAccessor(_target), _reference, valid_region, tolerance_qs8);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp
index e995511..0134a7c 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Sigmoid.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,10 +29,10 @@
 #include "tests/CL/CLAccessor.h"
 #include "tests/datasets/ShapeDatasets.h"
 #include "tests/framework/Asserts.h"
-#include "tests/framework/Macros.h"
 #include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
+#include "tests/framework/Macros.h"
 #include "tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -65,9 +65,9 @@
     GpuWorkloadSketch sketch{ &context };
 
     // Fuse sigmoid
-    const TensorInfo src_info = context.create_tensor_info(input_info);
+    const ITensorInfo *src_info = context.create_tensor_info(input_info);
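+    // The context owns the ITensorInfo; src_info is a non-owning pointer that
+    // stays valid as long as `context` does. Minimal sketch of the pattern
+    // (hypothetical shape and data type):
+    //   const ITensorInfo *info = context.create_tensor_info(TensorInfo(TensorShape(8U, 8U), 1, DataType::F32));
+    //   GpuSigmoid::validate_op(sketch, info);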
 
-    const bool res = static_cast<bool>(GpuSigmoid::validate_op(sketch, &src_info));
+    const bool res = static_cast<bool>(GpuSigmoid::validate_op(sketch, src_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
@@ -81,8 +81,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionSigmoidOpFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("Fuse", { false })),
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -92,8 +91,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
                        DynamicFusionSigmoidOpFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::Small5dShapes(),
-                                       framework::dataset::make("Fuse", { false })),
+                       combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -104,8 +102,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
                        DynamicFusionSigmoidOpFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("Fuse", { true })),
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -118,8 +115,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionSigmoidOpFixture<float>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("Fuse", { false })),
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
@@ -129,8 +125,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
                        DynamicFusionSigmoidOpFixture<float>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::Small5dShapes(),
-                                       framework::dataset::make("Fuse", { false })),
+                       combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
@@ -141,8 +136,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
                        DynamicFusionSigmoidOpFixture<float>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("Fuse", { true })),
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
index 340f5dc..b7cb6ba 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Softmax.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,11 +28,11 @@
 #include "tests/CL/CLAccessor.h"
 #include "tests/datasets/ShapeDatasets.h"
 #include "tests/framework/Asserts.h"
+#include "tests/framework/datasets/Datasets.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
 #include "tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h"
+#include "tests/validation/Validation.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 
@@ -110,9 +110,9 @@
 
     SoftmaxAttributes softmax_attr{};
     softmax_attr.axis(axis).beta(beta).is_log_softmax(false);
-    TensorInfo src_info  = context.create_tensor_info(input_info);
-    TensorInfo dst_info = context.create_tensor_info(output_info);
-    const bool res = static_cast<bool>(GpuSoftmax::validate_op(sketch, &src_info, &dst_info, softmax_attr));
+    ITensorInfo *src_info = context.create_tensor_info(input_info);
+    ITensorInfo *dst_info = context.create_tensor_info(output_info);
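+    // Both infos are owned by the context; the raw pointers replace the old
+    // pattern of taking the address of local TensorInfo values.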
+    const bool res = static_cast<bool>(GpuSoftmax::validate_op(sketch, src_info, dst_info, softmax_attr));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
 
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp
index 022c9b4..ef9f75b 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Sub.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,14 +29,13 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuSub.h"
 
 #include "tests/CL/CLAccessor.h"
-#include "tests/framework/Fixture.h"
-#include "tests/framework/Macros.h"
-#include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
-
 #include "tests/datasets/DynamicFusionDataset.h"
 #include "tests/datasets/ShapeDatasets.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/framework/Fixture.h"
+#include "tests/framework/Macros.h"
 #include "tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -99,29 +98,32 @@
     auto          lhs_info         = context.create_tensor_info(input1_info);
     auto          rhs_info         = context.create_tensor_info(input2_info);
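+    // create_tensor_info() now returns ITensorInfo *, so auto deduces a pointer
+    // and validate_op() below no longer takes the address of the infos.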
 
-    bool res = bool(GpuSub::validate_op(sketch, &lhs_info, &rhs_info));
+    bool res = bool(GpuSub::validate_op(sketch, lhs_info, rhs_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
 // *INDENT-ON*
 
 template <typename T>
-using DynamicFusionCLSubFixture = DynamicFusionGpuElementwiseBinaryOneOpValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
+using DynamicFusionCLSubFixture =
+    DynamicFusionGpuElementwiseBinaryOneOpValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
 
 template <typename T>
-using DynamicFusionCLSubBroadcastFixture = DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
+using DynamicFusionCLSubBroadcastFixture =
+    DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
 
 template <typename T>
-using DynamicFusionCLSubTwoOpsFixture = DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
+using DynamicFusionCLSubTwoOpsFixture =
+    DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture<CLTensor, CLAccessor, GpuSub, T>;
 
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionCLSubFixture<float>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -129,10 +131,10 @@
 FIXTURE_DATA_TEST_CASE(RunLargeOneOp,
                        DynamicFusionCLSubFixture<float>,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::LargeShapes()),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -140,10 +142,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
                        DynamicFusionCLSubBroadcastFixture<float>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::TemporaryLimitedSmallShapesBroadcast()),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -152,22 +154,23 @@
 FIXTURE_DATA_TEST_CASE(RunLargeBroadcastOneOp,
                        DynamicFusionCLSubBroadcastFixture<float>,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::TemporaryLimitedLargeShapesBroadcast()),
-                                       framework::dataset::make("DataType", { DataType::F32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
-FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
-                       DynamicFusionCLSubTwoOpsFixture<float>,
-                       framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
-                                                       datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()),
-                                               framework::dataset::make("DataType", { DataType::F32 })),
-                                       framework::dataset::make("InPlace", { false })),
-                               framework::dataset::make("FuseTwoOps", { true })))
+FIXTURE_DATA_TEST_CASE(
+    RunSmallTwoOps,
+    DynamicFusionCLSubTwoOpsFixture<float>,
+    framework::DatasetMode::PRECOMMIT,
+    combine(combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
+                                    datasets::DynamicFusionElementwiseBinaryTwoOpsSmallShapes()),
+                            framework::dataset::make("DataType", {DataType::F32})),
+                    framework::dataset::make("InPlace", {false})),
+            framework::dataset::make("FuseTwoOps", {true})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -178,10 +181,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionCLSubFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::F16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -190,10 +193,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmallBroadcastOneOp,
                        DynamicFusionCLSubBroadcastFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::TemporaryLimitedSmallShapesBroadcast()),
-                                       framework::dataset::make("DataType", { DataType::F16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::F16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -205,10 +208,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        DynamicFusionCLSubFixture<int32_t>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::S32 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::S32})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -219,10 +222,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        DynamicFusionCLSubFixture<int16_t>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::S16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::S16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -230,10 +233,10 @@
 FIXTURE_DATA_TEST_CASE(RunLarge,
                        DynamicFusionCLSubFixture<int16_t>,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::LargeShapes()),
-                                       framework::dataset::make("DataType", { DataType::S16 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::S16})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -244,10 +247,10 @@
 FIXTURE_DATA_TEST_CASE(RunSmall,
                        DynamicFusionCLSubFixture<uint8_t>,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(framework::dataset::make("ElementwiseOp", { ArithmeticOperation::SUB }),
+                       combine(combine(combine(framework::dataset::make("ElementwiseOp", {ArithmeticOperation::SUB}),
                                                datasets::SmallShapes()),
-                                       framework::dataset::make("DataType", { DataType::U8 })),
-                               framework::dataset::make("InPlace", { false })))
+                                       framework::dataset::make("DataType", {DataType::U8})),
+                               framework::dataset::make("InPlace", {false})))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
diff --git a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp
index 12f3677..2560f3a 100644
--- a/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp
+++ b/tests/validation/dynamic_fusion/gpu/cl/Tanh.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,10 +29,10 @@
 #include "tests/CL/CLAccessor.h"
 #include "tests/datasets/ShapeDatasets.h"
 #include "tests/framework/Asserts.h"
-#include "tests/framework/Macros.h"
 #include "tests/framework/datasets/Datasets.h"
-#include "tests/validation/Validation.h"
+#include "tests/framework/Macros.h"
 #include "tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h"
+#include "tests/validation/Validation.h"
 
 namespace arm_compute
 {
@@ -65,9 +65,9 @@
     GpuWorkloadSketch sketch{ &context };
 
     // Fuse tanh
-    const TensorInfo src_info = context.create_tensor_info(input_info);
+    const ITensorInfo *src_info = context.create_tensor_info(input_info);
 
-    const bool res = static_cast<bool>(GpuTanh::validate_op(sketch, &src_info));
+    const bool res = static_cast<bool>(GpuTanh::validate_op(sketch, src_info));
     ARM_COMPUTE_EXPECT(res == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
@@ -81,8 +81,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionTanhOpFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("Fuse", { false })),
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -92,8 +91,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
                        DynamicFusionTanhOpFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::Small5dShapes(),
-                                       framework::dataset::make("Fuse", { false })),
+                       combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -104,8 +102,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
                        DynamicFusionTanhOpFixture<half>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("Fuse", { true })),
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})),
                                framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
@@ -118,8 +115,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmallOneOp,
                        DynamicFusionTanhOpFixture<float>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("Fuse", { false })),
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
@@ -129,8 +125,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmall5dOneOp,
                        DynamicFusionTanhOpFixture<float>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::Small5dShapes(),
-                                       framework::dataset::make("Fuse", { false })),
+                       combine(combine(datasets::Small5dShapes(), framework::dataset::make("Fuse", {false})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
@@ -141,8 +136,7 @@
 FIXTURE_DATA_TEST_CASE(RunSmallTwoOps,
                        DynamicFusionTanhOpFixture<float>,
                        framework::DatasetMode::ALL,
-                       combine(combine(datasets::SmallShapes(),
-                                       framework::dataset::make("Fuse", { true })),
+                       combine(combine(datasets::SmallShapes(), framework::dataset::make("Fuse", {true})),
                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h
index 6498a06..ca4de11 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,14 +21,13 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DEPTHWISECONV2DFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DEPTHWISECONV2DFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DEPTHWISECONV2DFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DEPTHWISECONV2DFIXTURE_H
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
-
 #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
 #include "arm_compute/dynamic_fusion/sketch/attributes/DepthwiseConv2dAttributes.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
@@ -36,13 +35,11 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
 
 #include "tests/CL/CLAccessor.h"
-
 #include "tests/framework/Asserts.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-
-#include "tests/validation/Validation.h"
 #include "tests/validation/reference/DepthwiseConvolutionLayer.h"
+#include "tests/validation/Validation.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 
@@ -56,22 +53,30 @@
 class DynamicFusionGpuDepthwiseConv2dValidationGenericFixture : public framework::Fixture
 {
 public:
-    using TBias = typename std::conditional < std::is_same<typename std::decay<T>::type, uint8_t>::value
-                  || std::is_same<typename std::decay<T>::type, int8_t>::value,
-                  int32_t, T >::type; // If T: uint8_t or int8_t then TBias: int32_t, otherwise TBias: T
+    using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value ||
+                                                std::is_same<typename std::decay<T>::type, int8_t>::value,
+                                            int32_t,
+                                            T>::type; // If T is uint8_t or int8_t, TBias is int32_t; otherwise TBias is T
 
-    void setup(TensorShape input_shape, Size2D kernel_size, const PadStrideInfo &pad_stride, const Size2D &dilation,
-               const unsigned int depth_multiplier, const DataType data_type, const DataLayout data_layout)
+    void setup(TensorShape          input_shape,
+               Size2D               kernel_size,
+               const PadStrideInfo &pad_stride,
+               const Size2D        &dilation,
+               const unsigned int   depth_multiplier,
+               const DataType       data_type,
+               const DataLayout     data_layout)
     {
-        ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NHWC); // Dynamic fusion depthwise conv2d only supports NHWC layout
+        ARM_COMPUTE_ERROR_ON(data_layout !=
+                             DataLayout::NHWC); // Dynamic fusion depthwise conv2d only supports NHWC layout
 
         DepthwiseConv2dAttributes dwc_conv2d_attr;
-        const Padding2D           padding_2d(pad_stride.pad_left(), pad_stride.pad_right(), pad_stride.pad_top(), pad_stride.pad_bottom());
+        const Padding2D           padding_2d(pad_stride.pad_left(), pad_stride.pad_right(), pad_stride.pad_top(),
+                                             pad_stride.pad_bottom());
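+        // Padding2D is ordered (left, right, top, bottom). Each attribute setter
+        // below returns the attributes object, so the calls chain fluently.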
         dwc_conv2d_attr.pad(padding_2d)
-        .stride(Size2D(pad_stride.stride().first, pad_stride.stride().second))
-        .dilation(dilation)
-        .depth_multiplier(depth_multiplier)
-        .dimension_rounding_type(pad_stride.round());
+            .stride(Size2D(pad_stride.stride().first, pad_stride.stride().second))
+            .dilation(dilation)
+            .depth_multiplier(depth_multiplier)
+            .dimension_rounding_type(pad_stride.round());
 
         // Calculate Output and Weight Shapes
         TensorShape weights_shape = TensorShape(kernel_size.width, kernel_size.height);
@@ -79,8 +84,9 @@
         const TensorInfo in_info(input_shape, 1, data_type);
         const TensorInfo we_info(weights_shape, 1, data_type);
 
-        const ConvolutionInfo info{ pad_stride, depth_multiplier, ActivationLayerInfo(), dilation };
-        const TensorShape     output_shape = misc::shape_calculator::compute_depthwise_convolution_shape(in_info, we_info, info);
+        const ConvolutionInfo info{pad_stride, depth_multiplier, ActivationLayerInfo(), dilation};
+        const TensorShape     output_shape =
+            misc::shape_calculator::compute_depthwise_convolution_shape(in_info, we_info, info);
 
         weights_shape.set(2, output_shape.z());
         const TensorShape bias_shape = TensorShape(weights_shape[2]);
@@ -95,11 +101,11 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        switch(tensor.data_type())
+        switch (tensor.data_type())
         {
             case DataType::F16:
             {
-                arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+                arm_compute::utils::uniform_real_distribution_16bit<half> distribution{-1.0f, 1.0f};
                 library->fill(tensor, distribution, i);
                 break;
             }
@@ -115,7 +121,10 @@
     }
 
     // Given input is in nchw format
-    TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, const DepthwiseConv2dAttributes dwc_conv2d_attr)
+    TensorType compute_target(TensorShape                     input_shape,
+                              TensorShape                     weights_shape,
+                              const TensorShape              &bias_shape,
+                              const DepthwiseConv2dAttributes dwc_conv2d_attr)
     {
         ARM_COMPUTE_ERROR_ON(_data_layout != DataLayout::NHWC);
 
@@ -125,24 +134,24 @@
 
         // Create a new workload sketch
         auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch sketch{ &context };
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
         // Create sketch tensors
-        TensorInfo input_info  = context.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout));
-        TensorInfo weight_info = context.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout));
-        TensorInfo bias_info   = context.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout));
-        TensorInfo dst_info    = context.create_tensor_info();
+        ITensorInfo *input_info  = context.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout));
+        ITensorInfo *weight_info = context.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout));
+        ITensorInfo *bias_info   = context.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout));
+        ITensorInfo *dst_info    = context.create_tensor_info();
 
-        ITensorInfo *ans_info = FunctionType::create_op(sketch, &input_info, &weight_info, &bias_info, dwc_conv2d_attr);
-        GpuOutput::create_op(sketch, ans_info, &dst_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, input_info, weight_info, bias_info, dwc_conv2d_attr);
+        GpuOutput::create_op(sketch, ans_info, dst_info);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
         runtime.configure(sketch);
 
         // (Important) Allocate auxiliary tensor memory if there are any
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -158,10 +167,10 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_input.allocator()->init(input_info);
-        t_weight.allocator()->init(weight_info);
-        t_bias.allocator()->init(bias_info);
-        t_dst.allocator()->init(dst_info);
+        t_input.allocator()->init(*input_info);
+        t_weight.allocator()->init(*weight_info);
+        t_bias.allocator()->init(*bias_info);
+        t_dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         t_input.allocator()->allocate();
@@ -174,17 +183,20 @@
         fill(AccessorType(t_bias), 2);
 
         // Run runtime
-        runtime.run({ &t_input, &t_weight, &t_bias, &t_dst });
+        runtime.run({&t_input, &t_weight, &t_bias, &t_dst});
         return t_dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape,
-                                      const TensorShape &output_shape, DepthwiseConv2dAttributes dwc_conv2d_attr)
+    SimpleTensor<T> compute_reference(const TensorShape        &input_shape,
+                                      const TensorShape        &weights_shape,
+                                      const TensorShape        &bias_shape,
+                                      const TensorShape        &output_shape,
+                                      DepthwiseConv2dAttributes dwc_conv2d_attr)
     {
         // Create reference
-        SimpleTensor<T>     src{ input_shape, _data_type, 1 };
-        SimpleTensor<T>     weight{ weights_shape, _data_type, 1 };
-        SimpleTensor<TBias> bias{ bias_shape, _data_type, 1 };
+        SimpleTensor<T>     src{input_shape, _data_type, 1};
+        SimpleTensor<T>     weight{weights_shape, _data_type, 1};
+        SimpleTensor<TBias> bias{bias_shape, _data_type, 1};
 
         fill(src, 0);
         fill(weight, 1);
@@ -195,10 +207,13 @@
         auto bias_nchw         = bias;
         auto output_shape_nchw = output_shape;
 
-        PadStrideInfo legacy_pad_stride(dwc_conv2d_attr.stride().x(), dwc_conv2d_attr.stride().y(), dwc_conv2d_attr.pad().left, dwc_conv2d_attr.pad().right, dwc_conv2d_attr.pad().top,
-                                        dwc_conv2d_attr.pad().bottom,
+        PadStrideInfo legacy_pad_stride(dwc_conv2d_attr.stride().x(), dwc_conv2d_attr.stride().y(),
+                                        dwc_conv2d_attr.pad().left, dwc_conv2d_attr.pad().right,
+                                        dwc_conv2d_attr.pad().top, dwc_conv2d_attr.pad().bottom,
                                         DimensionRoundingType{});
-        auto dst_nchw = reference::depthwise_convolution(src_nchw, weights_nchw, bias_nchw, output_shape_nchw, legacy_pad_stride, dwc_conv2d_attr.depth_multiplier(), dwc_conv2d_attr.dilation());
+        auto          dst_nchw =
+            reference::depthwise_convolution(src_nchw, weights_nchw, bias_nchw, output_shape_nchw, legacy_pad_stride,
+                                             dwc_conv2d_attr.depth_multiplier(), dwc_conv2d_attr.dilation());
         return dst_nchw;
     }
 
@@ -209,16 +224,23 @@
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuDepthwiseConv2dValidationFixture : public DynamicFusionGpuDepthwiseConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuDepthwiseConv2dValidationFixture
+    : public DynamicFusionGpuDepthwiseConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape input_shape, Size2D kernel_size, const PadStrideInfo &info, const Size2D &dilation, const unsigned int depth_multiplier, DataType data_type, DataLayout data_layout)
+    void setup(TensorShape          input_shape,
+               Size2D               kernel_size,
+               const PadStrideInfo &info,
+               const Size2D        &dilation,
+               const unsigned int   depth_multiplier,
+               DataType             data_type,
+               DataLayout           data_layout)
     {
-        DynamicFusionGpuDepthwiseConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, kernel_size, info, dilation,
-                                                                                                                  depth_multiplier, data_type, data_layout);
+        DynamicFusionGpuDepthwiseConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            input_shape, kernel_size, info, dilation, depth_multiplier, data_type, data_layout);
     }
 };
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DEPTHWISECONV2DFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DEPTHWISECONV2DFIXTURE_H
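
Note: the call-site rewrite above repeats in every fixture that follows. A condensed sketch of the mechanical change applied at each site (shape, attributes and t_src are placeholders for this note, not identifiers from the patch):

    // Old call sites took the address of a locally held info and passed a copy to init():
    //     ITensorInfo *ans_info = FunctionType::create_op(sketch, &input_info, attributes);
    //     t_input.allocator()->init(input_info);
    // New call sites hold an ITensorInfo * and dereference only at init():
    ITensorInfo *src_info = context.create_tensor_info(TensorInfo(shape, 1, data_type));
    ITensorInfo *ans_info = FunctionType::create_op(sketch, src_info, attributes);
    t_src.allocator()->init(*src_info);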
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
index e30a564..1f4e223 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,13 +21,12 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE_H
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
-
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
 #include "arm_compute/dynamic_fusion/sketch/attributes/Conv2dAttributes.h"
@@ -38,9 +37,9 @@
 #include "tests/CL/CLAccessor.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-#include "tests/validation/Validation.h"
 #include "tests/validation/reference/ConvolutionLayer.h"
 #include "tests/validation/reference/Permute.h"
+#include "tests/validation/Validation.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 
@@ -55,11 +54,11 @@
 template <typename U>
 void fill(U &&tensor, int i)
 {
-    switch(tensor.data_type())
+    switch (tensor.data_type())
     {
         case DataType::F16:
         {
-            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{-1.0f, 1.0f};
             library->fill(tensor, distribution, i);
             break;
         }
@@ -84,12 +83,21 @@
 class DynamicFusionGpuConv2dValidationGenericFixture : public framework::Fixture
 {
 public:
-    using TBias = typename std::conditional < std::is_same<typename std::decay<T>::type, uint8_t>::value
-                  || std::is_same<typename std::decay<T>::type, int8_t>::value,
-                  int32_t, T >::type; // If T: uint8_t or int8_t then TBias: int32_t, otherwise TBias: T
+    using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value ||
+                                                std::is_same<typename std::decay<T>::type, int8_t>::value,
+                                            int32_t,
+                                            T>::type; // If T: uint8_t or int8_t then TBias: int32_t, otherwise TBias: T
 
-    void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, const PadStrideInfo &info, const Size2D &dilation, DataType data_type,
-               DataLayout data_layout, QuantizationInfo quantization_info, QuantizationInfo weight_quantization_info)
+    void setup(TensorShape          input_shape,
+               TensorShape          weights_shape,
+               TensorShape          bias_shape,
+               TensorShape          output_shape,
+               const PadStrideInfo &info,
+               const Size2D        &dilation,
+               DataType             data_type,
+               DataLayout           data_layout,
+               QuantizationInfo     quantization_info,
+               QuantizationInfo     weight_quantization_info)
     {
         ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NHWC); // Dynamic fusion conv2d only supports NHWC layout
         const Conv2dAttributes conv2d_attr = convert_pad_stride_info_to_conv_attr(info, dilation);
@@ -100,12 +108,15 @@
         _weight_quantization_info          = weight_quantization_info;
         _bias_data_type                    = _is_quantized ? DataType::S32 : data_type;
         _target                            = compute_target(input_shape, weights_shape, bias_shape, conv2d_attr);
-        _reference                         = compute_reference(input_shape, weights_shape, bias_shape, output_shape, conv2d_attr);
+        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, conv2d_attr);
     }
 
 protected:
     // Given input is in nchw format
-    TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, Conv2dAttributes conv2d_attr)
+    TensorType compute_target(TensorShape        input_shape,
+                              TensorShape        weights_shape,
+                              const TensorShape &bias_shape,
+                              Conv2dAttributes   conv2d_attr)
     {
         ARM_COMPUTE_ERROR_ON(_data_layout != DataLayout::NHWC);
         permute(input_shape, PermutationVector(2U, 0U, 1U));
@@ -114,23 +125,23 @@
 
         // Create a new workload sketch
         auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch sketch{ &context };
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
         // Create sketch tensors
-        TensorInfo input_info  = context.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout));
-        TensorInfo weight_info = context.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout));
-        TensorInfo bias_info   = context.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout));
-        TensorInfo dst_info    = context.create_tensor_info();
+        ITensorInfo *input_info  = context.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout));
+        ITensorInfo *weight_info = context.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout));
+        ITensorInfo *bias_info   = context.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout));
+        ITensorInfo *dst_info    = context.create_tensor_info();
 
-        ITensorInfo *ans_info = FunctionType::create_op(sketch, &input_info, &weight_info, &bias_info, conv2d_attr);
-        GpuOutput::create_op(sketch, ans_info, &dst_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, input_info, weight_info, bias_info, conv2d_attr);
+        GpuOutput::create_op(sketch, ans_info, dst_info);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
         runtime.configure(sketch);
         // (Important) Allocate auxiliary tensor memory if there are any
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -145,10 +156,10 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_input.allocator()->init(input_info);
-        t_weight.allocator()->init(weight_info);
-        t_bias.allocator()->init(bias_info);
-        t_dst.allocator()->init(dst_info);
+        t_input.allocator()->init(*input_info);
+        t_weight.allocator()->init(*weight_info);
+        t_bias.allocator()->init(*bias_info);
+        t_dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         t_input.allocator()->allocate();
@@ -161,17 +172,20 @@
         fill(AccessorType(t_bias), 2);
 
         // Run runtime
-        runtime.run({ &t_input, &t_weight, &t_bias, &t_dst });
+        runtime.run({&t_input, &t_weight, &t_bias, &t_dst});
         return t_dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape,
-                                      const TensorShape &output_shape, Conv2dAttributes conv2d_attr)
+    SimpleTensor<T> compute_reference(const TensorShape &input_shape,
+                                      const TensorShape &weights_shape,
+                                      const TensorShape &bias_shape,
+                                      const TensorShape &output_shape,
+                                      Conv2dAttributes   conv2d_attr)
     {
         // Create reference
-        SimpleTensor<T>     src{ input_shape, _data_type, 1, _quantization_info };
-        SimpleTensor<T>     weight{ weights_shape, _data_type, 1, _weight_quantization_info };
-        SimpleTensor<TBias> bias{ bias_shape, _data_type, 1, _quantization_info };
+        SimpleTensor<T>     src{input_shape, _data_type, 1, _quantization_info};
+        SimpleTensor<T>     weight{weights_shape, _data_type, 1, _weight_quantization_info};
+        SimpleTensor<TBias> bias{bias_shape, _data_type, 1, _quantization_info};
 
         fill(src, 0);
         fill(weight, 1);
@@ -182,9 +196,11 @@
         auto bias_nchw         = bias;
         auto output_shape_nchw = output_shape;
 
-        PadStrideInfo legacy_pad_stride(conv2d_attr.stride().x(), conv2d_attr.stride().y(), conv2d_attr.pad().left, conv2d_attr.pad().right, conv2d_attr.pad().top, conv2d_attr.pad().bottom,
+        PadStrideInfo legacy_pad_stride(conv2d_attr.stride().x(), conv2d_attr.stride().y(), conv2d_attr.pad().left,
+                                        conv2d_attr.pad().right, conv2d_attr.pad().top, conv2d_attr.pad().bottom,
                                         DimensionRoundingType{});
-        auto dst_nchw = reference::convolution_layer(src_nchw, weights_nchw, bias_nchw, output_shape_nchw, legacy_pad_stride, conv2d_attr.dilation());
+        auto          dst_nchw = reference::convolution_layer(src_nchw, weights_nchw, bias_nchw, output_shape_nchw,
+                                                              legacy_pad_stride, conv2d_attr.dilation());
         return dst_nchw;
     }
 
@@ -199,14 +215,23 @@
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuConv2dValidationFixture : public DynamicFusionGpuConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuConv2dValidationFixture
+    : public DynamicFusionGpuConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape output_shape, TensorShape bias_shape,
-               const PadStrideInfo &info, const Size2D &dialation, DataType data_type, DataLayout data_layout, QuantizationInfo quantization_info)
+    void setup(TensorShape          input_shape,
+               TensorShape          weights_shape,
+               TensorShape          output_shape,
+               TensorShape          bias_shape,
+               const PadStrideInfo &info,
+               const Size2D        &dilation,
+               DataType             data_type,
+               DataLayout           data_layout,
+               QuantizationInfo     quantization_info)
     {
-        DynamicFusionGpuConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, output_shape, bias_shape, info, dialation,
-                                                                                                         data_type, data_layout, quantization_info, quantization_info);
+        DynamicFusionGpuConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            input_shape, weights_shape, output_shape, bias_shape, info, dilation, data_type, data_layout,
+            quantization_info, quantization_info);
     }
 };
 
@@ -218,10 +243,19 @@
 class DynamicFusionDirectConv2dValidationGenericFixture : public framework::Fixture
 {
 public:
-    using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
+    using TBias =
+        typename std::conditional<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T>::type;
 
-    void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels,
-               DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout)
+    void setup(TensorShape      input_shape,
+               int              stride_x,
+               int              stride_y,
+               int              pad_x,
+               int              pad_y,
+               unsigned int     kernel_size,
+               unsigned int     num_kernels,
+               DataType         data_type,
+               QuantizationInfo quantization_info,
+               DataLayout       data_layout)
     {
         ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NHWC); // Dynamic fusion conv2d only supports NHWC layout
 
@@ -230,20 +264,30 @@
         const PadStrideInfo info(stride_x, stride_y, pad_x, pad_y, DimensionRoundingType::FLOOR);
         const DataType      bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
 
-        const Conv2dAttributes conv2d_attr = convert_pad_stride_info_to_conv_attr(info, { 1U, 1U } /* dilation */);
+        const Conv2dAttributes conv2d_attr = convert_pad_stride_info_to_conv_attr(info, {1U, 1U} /* dilation */);
 
         TensorInfo input_info   = TensorInfo(input_shape, 1, data_type);
         TensorInfo weights_info = TensorInfo(weights_shape, 1, data_type);
 
-        const TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_info, weights_info, info);
+        const TensorShape output_shape =
+            misc::shape_calculator::compute_deep_convolution_shape(input_info, weights_info, info);
 
-        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, conv2d_attr, data_type, bias_data_type, quantization_info, data_layout);
-        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, quantization_info);
+        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, conv2d_attr, data_type,
+                                    bias_data_type, quantization_info, data_layout);
+        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type,
+                                       bias_data_type, quantization_info);
     }
 
 protected:
-    TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const Conv2dAttributes &conv2d_attr,
-                              DataType data_type, DataType bias_data_type, QuantizationInfo quantization_info, const DataLayout &data_layout)
+    TensorType compute_target(TensorShape             input_shape,
+                              TensorShape             weights_shape,
+                              const TensorShape      &bias_shape,
+                              TensorShape             output_shape,
+                              const Conv2dAttributes &conv2d_attr,
+                              DataType                data_type,
+                              DataType                bias_data_type,
+                              QuantizationInfo        quantization_info,
+                              const DataLayout       &data_layout)
     {
         ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NHWC);
         ARM_COMPUTE_UNUSED(quantization_info);
@@ -253,8 +297,8 @@
         permute(output_shape, PermutationVector(2U, 0U, 1U));
 
         auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch sketch{ &context };
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
         // Create sketch tensors
         auto input_info  = context.create_tensor_info(TensorInfo(input_shape, 1, data_type, data_layout));
@@ -262,14 +306,14 @@
         auto bias_info   = context.create_tensor_info(TensorInfo(bias_shape, 1, bias_data_type, data_layout));
         auto dst_info    = context.create_tensor_info();
 
-        ITensorInfo *ans_info = FunctionType::create_op(sketch, &input_info, &weight_info, &bias_info, conv2d_attr);
-        GpuOutput::create_op(sketch, ans_info, &dst_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, input_info, weight_info, bias_info, conv2d_attr);
+        GpuOutput::create_op(sketch, ans_info, dst_info);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
         runtime.configure(sketch);
 
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -284,10 +328,10 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_input.allocator()->init(input_info);
-        t_weight.allocator()->init(weight_info);
-        t_bias.allocator()->init(bias_info);
-        t_dst.allocator()->init(dst_info);
+        t_input.allocator()->init(*input_info);
+        t_weight.allocator()->init(*weight_info);
+        t_bias.allocator()->init(*bias_info);
+        t_dst.allocator()->init(*dst_info);
 
         ARM_COMPUTE_ASSERT(t_input.info()->is_resizable());
         ARM_COMPUTE_ASSERT(t_weight.info()->is_resizable());
@@ -310,17 +354,23 @@
         fill(AccessorType(t_bias), 2);
 
         // Run runtime
-        runtime.run({ &t_input, &t_weight, &t_bias, &t_dst });
+        runtime.run({&t_input, &t_weight, &t_bias, &t_dst});
         return t_dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
-                                      DataType data_type, DataType bias_data_type, QuantizationInfo quantization_info)
+    SimpleTensor<T> compute_reference(const TensorShape   &input_shape,
+                                      const TensorShape   &weights_shape,
+                                      const TensorShape   &bias_shape,
+                                      const TensorShape   &output_shape,
+                                      const PadStrideInfo &info,
+                                      DataType             data_type,
+                                      DataType             bias_data_type,
+                                      QuantizationInfo     quantization_info)
     {
         // Create reference
-        SimpleTensor<T>     src{ input_shape, data_type, 1, quantization_info };
-        SimpleTensor<T>     weights{ weights_shape, data_type, 1, quantization_info };
-        SimpleTensor<TBias> bias{ bias_shape, bias_data_type, 1, quantization_info };
+        SimpleTensor<T>     src{input_shape, data_type, 1, quantization_info};
+        SimpleTensor<T>     weights{weights_shape, data_type, 1, quantization_info};
+        SimpleTensor<TBias> bias{bias_shape, bias_data_type, 1, quantization_info};
 
         // Fill reference
         fill(src, 0);
@@ -335,19 +385,27 @@
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionDirectConv2dValidationFixture : public DynamicFusionDirectConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionDirectConv2dValidationFixture
+    : public DynamicFusionDirectConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type,
-               DataLayout data_layout)
+    void setup(TensorShape  input_shape,
+               int          stride_x,
+               int          stride_y,
+               int          pad_x,
+               int          pad_y,
+               unsigned int kernel_size,
+               unsigned int num_kernels,
+               DataType     data_type,
+               DataLayout   data_layout)
     {
-        DynamicFusionDirectConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type,
-                                                                                                            QuantizationInfo(),
-                                                                                                            data_layout);
+        DynamicFusionDirectConv2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, QuantizationInfo(),
+            data_layout);
     }
 };
 
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_DIRECTCONV2DFIXTURE_H
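
Note: where a fixture declares its sketch tensors with auto (as in DynamicFusionDirectConv2dValidationGenericFixture::compute_target above), the declaration lines need no textual change, because auto now deduces ITensorInfo * instead of TensorInfo; only the uses are updated. A condensed sketch, with t_dst standing in for any user tensor:

    auto dst_info = context.create_tensor_info();     // deduces ITensorInfo * under the new API
    GpuOutput::create_op(sketch, ans_info, dst_info); // pointer passed directly, no & needed
    t_dst.allocator()->init(*dst_info);               // dereference to copy the descriptor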
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
index 567322f..69bd0ef 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_ELEMENTWISEBINARYFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_ELEMENTWISEBINARYFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_ELEMENTWISEBINARYFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_ELEMENTWISEBINARYFIXTURE_H
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -47,9 +47,15 @@
 class DynamicFusionGpuElementwiseBinaryValidationGenericFixture : public framework::Fixture
 {
 public:
-    void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2, DataType data_type, bool is_inplace, bool fuse_two_ops = false)
+    void setup(ArithmeticOperation ref_op,
+               const TensorShape  &shape0,
+               const TensorShape  &shape1,
+               const TensorShape  &shape2,
+               DataType            data_type,
+               bool                is_inplace,
+               bool                fuse_two_ops = false)
     {
-        _ref_op         = ref_op;
+        _ref_op     = ref_op;
         _is_inplace = is_inplace;
         _data_type  = data_type;
         _fuse       = fuse_two_ops;
@@ -63,12 +69,12 @@
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        if(is_data_type_float(tensor.data_type()))
+        if (is_data_type_float(tensor.data_type()))
         {
-            switch(_ref_op)
+            switch (_ref_op)
             {
                 case ArithmeticOperation::DIV:
-                    library->fill_tensor_uniform_ranged(tensor, i, { std::pair<float, float>(-0.001f, 0.001f) });
+                    library->fill_tensor_uniform_ranged(tensor, i, {std::pair<float, float>(-0.001f, 0.001f)});
                     break;
                 case ArithmeticOperation::POWER:
                     library->fill_tensor_uniform(tensor, i, 0.0f, 5.0f);
@@ -77,12 +83,12 @@
                     library->fill_tensor_uniform(tensor, i);
             }
         }
-        else if(tensor.data_type() == DataType::S32)
+        else if (tensor.data_type() == DataType::S32)
         {
-            switch(_ref_op)
+            switch (_ref_op)
             {
                 case ArithmeticOperation::DIV:
-                    library->fill_tensor_uniform_ranged(tensor, i, { std::pair<int32_t, int32_t>(-1U, 1U) });
+                    library->fill_tensor_uniform_ranged(tensor, i, {std::pair<int32_t, int32_t>(-1U, 1U)});
                     break;
                 default:
                     library->fill_tensor_uniform(tensor, i);
@@ -98,27 +104,27 @@
     {
         // Create a new workload sketch
         auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch sketch{ &context };
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
         // Fuse first element wise binary Op
-        TensorInfo lhs_info = context.create_tensor_info(TensorInfo(shape0, 1, _data_type));
-        TensorInfo rhs_info = context.create_tensor_info(TensorInfo(shape1, 1, _data_type));
-        TensorInfo dst_info = context.create_tensor_info();
+        ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(shape0, 1, _data_type));
+        ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(shape1, 1, _data_type));
+        ITensorInfo *dst_info = context.create_tensor_info();
 
-        TensorInfo rhs_info_fuse;
+        ITensorInfo *rhs_info_fuse = nullptr;
 
-        ITensorInfo *ans_info = FunctionType::create_op(sketch, &lhs_info, &rhs_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, lhs_info, rhs_info);
 
-        if(_fuse)
+        if (_fuse)
         {
             rhs_info_fuse          = context.create_tensor_info(TensorInfo(shape2, 1, _data_type));
-            ITensorInfo *ans2_info = FunctionType::create_op(sketch, ans_info, &rhs_info_fuse);
-            GpuOutput::create_op(sketch, ans2_info, &dst_info);
+            ITensorInfo *ans2_info = FunctionType::create_op(sketch, ans_info, rhs_info_fuse);
+            GpuOutput::create_op(sketch, ans2_info, dst_info);
         }
         else
         {
-            GpuOutput::create_op(sketch, ans_info, &dst_info);
+            GpuOutput::create_op(sketch, ans_info, dst_info);
         }
 
         // Configure runtime
@@ -126,7 +132,7 @@
         runtime.configure(sketch);
 
         // (Important) Allocate auxiliary tensor memory if there are any
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -142,12 +148,12 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_lhs.allocator()->init(lhs_info);
-        t_rhs.allocator()->init(rhs_info);
-        t_dst.allocator()->init(dst_info);
-        if(_fuse)
+        t_lhs.allocator()->init(*lhs_info);
+        t_rhs.allocator()->init(*rhs_info);
+        t_dst.allocator()->init(*dst_info);
+        if (_fuse)
         {
-            t_rhs_fuse.allocator()->init(rhs_info_fuse);
+            t_rhs_fuse.allocator()->init(*rhs_info_fuse);
         }
 
         // Allocate and fill user tensors
@@ -155,26 +161,26 @@
         t_lhs.allocator()->allocate();
         t_rhs.allocator()->allocate();
         t_dst.allocator()->allocate();
-        if(_fuse)
+        if (_fuse)
         {
             t_rhs_fuse.allocator()->allocate();
         }
 
         fill(AccessorType(t_lhs), 0);
         fill(AccessorType(t_rhs), 1);
-        if(_fuse)
+        if (_fuse)
         {
             fill(AccessorType(t_rhs_fuse), 2);
         }
 
         // Run runtime
-        if(_fuse)
+        if (_fuse)
         {
-            runtime.run({ &t_lhs, &t_rhs, &t_rhs_fuse, &t_dst });
+            runtime.run({&t_lhs, &t_rhs, &t_rhs_fuse, &t_dst});
         }
         else
         {
-            runtime.run({ &t_lhs, &t_rhs, &t_dst });
+            runtime.run({&t_lhs, &t_rhs, &t_dst});
         }
 
         return t_dst;
@@ -186,18 +192,18 @@
         const TensorShape out_shape_fuse = TensorShape::broadcast_shape(out_shape, shape1);
 
         // Create reference
-        SimpleTensor<T> ref_lhs{ shape0, _data_type, 1, QuantizationInfo() };
-        SimpleTensor<T> ref_rhs{ shape1, _data_type, 1, QuantizationInfo() };
-        SimpleTensor<T> ref_rhs_fuse{ shape2, _data_type, 1, QuantizationInfo() };
-        SimpleTensor<T> ref_dst{ out_shape, _data_type, 1, QuantizationInfo() };
-        SimpleTensor<T> ref_dst_fuse{ out_shape_fuse, _data_type, 1, QuantizationInfo() };
+        SimpleTensor<T> ref_lhs{shape0, _data_type, 1, QuantizationInfo()};
+        SimpleTensor<T> ref_rhs{shape1, _data_type, 1, QuantizationInfo()};
+        SimpleTensor<T> ref_rhs_fuse{shape2, _data_type, 1, QuantizationInfo()};
+        SimpleTensor<T> ref_dst{out_shape, _data_type, 1, QuantizationInfo()};
+        SimpleTensor<T> ref_dst_fuse{out_shape_fuse, _data_type, 1, QuantizationInfo()};
 
         // Fill reference
         fill(ref_lhs, 0);
         fill(ref_rhs, 1);
 
         reference::arithmetic_operation<T>(_ref_op, ref_lhs, ref_rhs, ref_dst, ConvertPolicy::WRAP);
-        if(_fuse)
+        if (_fuse)
         {
             fill(ref_rhs_fuse, 2);
             reference::arithmetic_operation<T>(_ref_op, ref_dst, ref_rhs_fuse, ref_dst_fuse, ConvertPolicy::WRAP);
@@ -206,46 +212,62 @@
         return *ret;
     }
 
-    ArithmeticOperation _ref_op{ ArithmeticOperation::ADD };
+    ArithmeticOperation _ref_op{ArithmeticOperation::ADD};
     TensorType          _target{};
     SimpleTensor<T>     _reference{};
     DataType            _data_type{};
     DataLayout          _data_layout{};
-    bool                _is_inplace{ false };
-    bool                _fuse{ false };
+    bool                _is_inplace{false};
+    bool                _fuse{false};
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuElementwiseBinaryOneOpValidationFixture : public DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuElementwiseBinaryOneOpValidationFixture
+    : public DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
     void setup(ArithmeticOperation ref_op, const TensorShape &shape0, DataType data_type, bool is_inplace)
     {
-        DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape0, TensorShape(), data_type, is_inplace);
+        DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            ref_op, shape0, shape0, TensorShape(), data_type, is_inplace);
     }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture : public DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuElementwiseBinaryBroadcastOneOpValidationFixture
+    : public DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, bool is_inplace)
+    void setup(ArithmeticOperation ref_op,
+               const TensorShape  &shape0,
+               const TensorShape  &shape1,
+               DataType            data_type,
+               bool                is_inplace)
     {
-        DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape1, TensorShape(), data_type, is_inplace);
+        DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            ref_op, shape0, shape1, TensorShape(), data_type, is_inplace);
     }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture : public DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuElementwiseBinaryTwoOpsValidationFixture
+    : public DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(ArithmeticOperation ref_op, const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2, DataType data_type, bool is_inplace, bool fuse_two_ops)
+    void setup(ArithmeticOperation ref_op,
+               const TensorShape  &shape0,
+               const TensorShape  &shape1,
+               const TensorShape  &shape2,
+               DataType            data_type,
+               bool                is_inplace,
+               bool                fuse_two_ops)
     {
-        DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(ref_op, shape0, shape1, shape2, data_type, is_inplace, fuse_two_ops);
+        DynamicFusionGpuElementwiseBinaryValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            ref_op, shape0, shape1, shape2, data_type, is_inplace, fuse_two_ops);
     }
 };
 
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_ELEMENTWISEBINARYFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_GPU_CL_ELEMENTWISEBINARYFIXTURE_H
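
Note: the elementwise binary fixture carries the one non-mechanical change in this file set. The optional fused operand used to be a default-constructed TensorInfo and is now a pointer that stays null unless a second op is fused, so it may only be dereferenced on the _fuse path. Condensed from the hunks above:

    ITensorInfo *rhs_info_fuse = nullptr;             // remains null on the unfused path
    if (_fuse)
    {
        rhs_info_fuse = context.create_tensor_info(TensorInfo(shape2, 1, _data_type));
    }
    // ... every later use keeps the same guard:
    if (_fuse)
    {
        t_rhs_fuse.allocator()->init(*rhs_info_fuse);
    }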
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h
index c6ac4b9..65a3363 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/MatMulKernelFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,7 +28,6 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
-
 #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
 #include "arm_compute/dynamic_fusion/sketch/attributes/MatMulAttributes.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
@@ -39,10 +38,10 @@
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
 #include "tests/validation/Helpers.h"
-#include "tests/validation/Validation.h"
 #include "tests/validation/reference/GEMM.h"
 #include "tests/validation/reference/Permute.h"
 #include "tests/validation/reference/ReshapeLayer.h"
+#include "tests/validation/Validation.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 
@@ -57,11 +56,11 @@
 template <typename U>
 void fill(U &&tensor, int i)
 {
-    switch(tensor.data_type())
+    switch (tensor.data_type())
     {
         case DataType::F16:
         {
-            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{-1.0f, 1.0f};
             library->fill(tensor, distribution, i);
             break;
         }
@@ -80,67 +79,83 @@
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
 class DynamicFusionGpuMatMulValidationGenericFixture : public framework::Fixture
 {
-
 public:
-    void setup(TensorShape lhs_shape, TensorShape rhs_shape, TensorShape output_shape, bool transpose_a, bool transpose_b,
-    int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type)
+    void setup(TensorShape lhs_shape,
+               TensorShape rhs_shape,
+               TensorShape output_shape,
+               bool        transpose_a,
+               bool        transpose_b,
+               int         M0,
+               int         N0,
+               int         K0,
+               bool        export_rhs_to_cl_image,
+               DataType    data_type)
     {
         //For brevity, the input shapes are assumed to be not-transposed for both a and b matrices.
-        if(transpose_a)
+        if (transpose_a)
         {
             permute(lhs_shape, PermutationVector(1U, 0U));
         }
-        if(transpose_b)
+        if (transpose_b)
         {
             permute(rhs_shape, PermutationVector(1U, 0U));
         }
 
         // Skip configurations unsupported by the device.
         _device_supports_export_to_cl_image = image2d_from_buffer_supported(CLKernelLibrary::get().get_device());
-        if(!_device_supports_export_to_cl_image && export_rhs_to_cl_image)
+        if (!_device_supports_export_to_cl_image && export_rhs_to_cl_image)
         {
             ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
             framework::ARM_COMPUTE_PRINT_INFO();
             return; // Note: Also need to skip the validate in corresponding FIXTURE_DATA_TEST_CASEs.
         }
 
-        _target    = compute_target(lhs_shape, rhs_shape, transpose_a, transpose_b, M0, N0, K0, export_rhs_to_cl_image, data_type);
+        _target    = compute_target(lhs_shape, rhs_shape, transpose_a, transpose_b, M0, N0, K0, export_rhs_to_cl_image,
+                                    data_type);
         _reference = compute_reference(lhs_shape, rhs_shape, output_shape, transpose_a, transpose_b, data_type);
     }
 
 protected:
-    TensorType compute_target(TensorShape &shape_a, TensorShape &shape_b, bool transpose_a, bool transpose_b, int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type)
+    TensorType compute_target(TensorShape &shape_a,
+                              TensorShape &shape_b,
+                              bool         transpose_a,
+                              bool         transpose_b,
+                              int          M0,
+                              int          N0,
+                              int          K0,
+                              bool         export_rhs_to_cl_image,
+                              DataType     data_type)
     {
         ARM_COMPUTE_UNUSED(export_rhs_to_cl_image);
         CLScheduler::get().default_reinit();
 
         // Create a new workload sketch
         auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch sketch{ &context };
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
         // Create sketch tensors
-        TensorInfo lhs_info    = context.create_tensor_info(TensorInfo(shape_a, 1, data_type));
-        TensorInfo rhs_info    = context.create_tensor_info(TensorInfo(shape_b, 1, data_type));
-        TensorInfo dst_info    = context.create_tensor_info();
+        ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(shape_a, 1, data_type));
+        ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(shape_b, 1, data_type));
+        ITensorInfo *dst_info = context.create_tensor_info();
 
-        MatMulAttributes matmul_attr {};
+        MatMulAttributes matmul_attr{};
         matmul_attr.adj_lhs(transpose_a);
         matmul_attr.adj_rhs(transpose_b);
 
-        GpuMatMulSettings matmul_settings {};
+        GpuMatMulSettings matmul_settings{};
         matmul_settings.m0(M0);
         matmul_settings.n0(N0);
         matmul_settings.k0(K0);
 
-        ITensorInfo *ans_info = FunctionType::create_op(sketch, &lhs_info, &rhs_info, matmul_attr, matmul_settings);
-        GpuOutput::create_op(sketch, ans_info, &dst_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, lhs_info, rhs_info, matmul_attr, matmul_settings);
+        GpuOutput::create_op(sketch, ans_info, dst_info);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
         runtime.configure(sketch);
 
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -155,9 +170,9 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_lhs.allocator()->init(lhs_info);
-        t_rhs.allocator()->init(rhs_info);
-        t_dst.allocator()->init(dst_info);
+        t_lhs.allocator()->init(*lhs_info);
+        t_rhs.allocator()->init(*rhs_info);
+        t_dst.allocator()->init(*dst_info);
 
         ARM_COMPUTE_ASSERT(t_lhs.info()->is_resizable());
         ARM_COMPUTE_ASSERT(t_rhs.info()->is_resizable());
@@ -176,12 +191,17 @@
         fill(AccessorType(t_rhs), 1);
 
         // Run runtime
-        runtime.run({ &t_lhs, &t_rhs, &t_dst });
+        runtime.run({&t_lhs, &t_rhs, &t_dst});
 
         return t_dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool pretranspose_a, bool pretranspose_b, DataType data_type)
+    SimpleTensor<T> compute_reference(const TensorShape &shape_a,
+                                      const TensorShape &shape_b,
+                                      const TensorShape &output_shape,
+                                      bool               pretranspose_a,
+                                      bool               pretranspose_b,
+                                      DataType           data_type)
     {
         // We collapse dimensions > 3 onto dimension 3, i.e. 5D+ tensors will look like 4D
         // This is necessary unless we choose to extend gemm reference for 5D+ tensors
@@ -190,9 +210,9 @@
         TensorShape shape_b_collapsed      = shape_b.collapsed_from(Window::DimZ);
 
         // Create reference
-        SimpleTensor<T> a{ shape_a_collapsed, data_type, 1 };
-        SimpleTensor<T> b{ shape_b_collapsed, data_type, 1 };
-        SimpleTensor<T> c{ output_shape_collapsed, data_type, 1 };
+        SimpleTensor<T> a{shape_a_collapsed, data_type, 1};
+        SimpleTensor<T> b{shape_b_collapsed, data_type, 1};
+        SimpleTensor<T> c{output_shape_collapsed, data_type, 1};
 
         // Fill reference
         fill(a, 0);
@@ -213,27 +233,27 @@
         b_transposed_shape.set(1, b.shape().x());
 
         // Define transposed tensors
-        SimpleTensor<T> a_transposed{ a_transposed_shape, data_type };
-        SimpleTensor<T> b_transposed{ b_transposed_shape, data_type };
+        SimpleTensor<T> a_transposed{a_transposed_shape, data_type};
+        SimpleTensor<T> b_transposed{b_transposed_shape, data_type};
 
         //pretranspose a if necessary
-        if(pretranspose_a)
+        if (pretranspose_a)
         {
             a_transposed = reference::permute<T>(a, PermutationVector(1U, 0U));
         }
 
         // pretranspose b if necessary
-        if(pretranspose_b)
+        if (pretranspose_b)
         {
             b_transposed = reference::permute<T>(b, PermutationVector(1U, 0U));
         }
 
         // Use transposed tensors if boolean enabled else use original tensors
-        SimpleTensor<T> result = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, 1.0f, 0.f);
-
+        SimpleTensor<T> result =
+            reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, 1.0f, 0.f);
 
         // We reshape the gemm output back if the tensor is high dimensional
-        if(output_shape_collapsed != output_shape)
+        if (output_shape_collapsed != output_shape)
         {
             // std::cout << "called reshape: \n";
             result = reference::reshape_layer(result, output_shape);
@@ -244,20 +264,30 @@
 
     CLTensor        _target{};
     SimpleTensor<T> _reference{};
-    bool            _device_supports_export_to_cl_image{ false };
-    bool            _device_supports_mmul{ false };
+    bool            _device_supports_export_to_cl_image{false};
+    bool            _device_supports_mmul{false};
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuMatMulValidationFixture : public DynamicFusionGpuMatMulValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuMatMulValidationFixture
+    : public DynamicFusionGpuMatMulValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
-    public:
-    void setup(TensorShape lhs_shape, TensorShape rhs_shape, TensorShape output_shape, bool transpose_a, bool transpose_b,
-    int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type)
+public:
+    void setup(TensorShape lhs_shape,
+               TensorShape rhs_shape,
+               TensorShape output_shape,
+               bool        transpose_a,
+               bool        transpose_b,
+               int         M0,
+               int         N0,
+               int         K0,
+               bool        export_rhs_to_cl_image,
+               DataType    data_type)
     {
         ARM_COMPUTE_UNUSED(export_rhs_to_cl_image);
-        DynamicFusionGpuMatMulValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(lhs_shape, rhs_shape, output_shape, transpose_a, transpose_b, M0,
-        N0, K0, false /* export_rhs_to_cl_image bias */, data_type);
+        DynamicFusionGpuMatMulValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            lhs_shape, rhs_shape, output_shape, transpose_a, transpose_b, M0, N0, K0,
+            false /* export_rhs_to_cl_image bias */, data_type);
     }
 };
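
Note: every fixture also contains the auxiliary-tensor allocation loop, but the hunks in this patch end after its first two lines. A sketch of the usual body, reconstructed from the dynamic fusion examples elsewhere in the library (the AuxMemoryInfo tuple element and the alignment-aware init() are assumptions, not shown in this diff):

    for (auto &data : runtime.get_auxiliary_tensors())
    {
        CLTensor     *tensor      = std::get<0>(data);
        TensorInfo    info        = std::get<1>(data);
        AuxMemoryInfo aux_mem_req = std::get<2>(data); // assumed third tuple element
        tensor->allocator()->init(info, aux_mem_req.alignment);
        tensor->allocator()->allocate();               // allocate ACL-managed auxiliary memory
    }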
 
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h
index 34f2647..dd3519b 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -28,14 +28,13 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
-
 #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
 #include "arm_compute/dynamic_fusion/sketch/attributes/Pool2dAttributes.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
-#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuPool2d.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
-#include "src/dynamic_fusion/utils/Utils.h"
+#include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuPool2d.h"
 
+#include "src/dynamic_fusion/utils/Utils.h"
 #include "tests/CL/CLAccessor.h"
 #include "tests/framework/Fixture.h"
 #include "tests/validation/reference/PoolingLayer.h"
@@ -54,19 +53,20 @@
 public:
     void setup(TensorShape input_shape, const Pool2dAttributes &pool_attr, DataType data_type, bool mixed_precision)
     {
-        _target    = compute_target(input_shape, pool_attr, data_type, mixed_precision);
-        _reference = compute_reference(input_shape, convert_pool_attr_to_pool_info(pool_attr, mixed_precision), data_type);
+        _target = compute_target(input_shape, pool_attr, data_type, mixed_precision);
+        _reference =
+            compute_reference(input_shape, convert_pool_attr_to_pool_info(pool_attr, mixed_precision), data_type);
     }
 
 protected:
     template <typename U>
     void fill(U &&tensor, int i)
     {
-        switch(tensor.data_type())
+        switch (tensor.data_type())
         {
             case DataType::F16:
             {
-                arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+                arm_compute::utils::uniform_real_distribution_16bit<half> distribution{-1.0f, 1.0f};
                 library->fill(tensor, distribution, i);
                 break;
             }
@@ -82,7 +82,10 @@
     }
 
     // Given input is in nchw format
-    TensorType compute_target(TensorShape input_shape, const Pool2dAttributes &pool_attr, const DataType data_type, bool mixed_precision)
+    TensorType compute_target(TensorShape             input_shape,
+                              const Pool2dAttributes &pool_attr,
+                              const DataType          data_type,
+                              bool                    mixed_precision)
     {
         CLScheduler::get().default_reinit();
 
@@ -91,8 +94,8 @@
 
         // Create a new workload sketch
         auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch sketch{ &context };
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
         // Create sketch tensors
         auto input_info = context.create_tensor_info(TensorInfo(input_shape, 1, data_type, DataLayout::NHWC));
@@ -101,14 +104,14 @@
         // Create Pool2dSettings
         GpuPool2dSettings pool_settings = GpuPool2dSettings().mixed_precision(mixed_precision);
 
-        ITensorInfo *ans_info = FunctionType::create_op(sketch, &input_info, pool_attr, pool_settings);
-        GpuOutput::create_op(sketch, ans_info, &dst_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, input_info, pool_attr, pool_settings);
+        GpuOutput::create_op(sketch, ans_info, dst_info);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
         runtime.configure(sketch);
         // (Important) Allocate auxiliary tensor memory if there are any
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -121,8 +124,8 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_input.allocator()->init(input_info);
-        t_dst.allocator()->init(dst_info);
+        t_input.allocator()->init(*input_info);
+        t_dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         t_input.allocator()->allocate();
@@ -131,7 +134,7 @@
         fill(AccessorType(t_input), 0);
 
         // Run runtime
-        runtime.run({ &t_input, &t_dst });
+        runtime.run({&t_input, &t_dst});
         return t_dst;
     }
 
@@ -149,36 +152,57 @@
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuPool2dValidationFixture : public DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuPool2dValidationFixture
+    : public DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape input_shape, PoolingType pool_type, Size2D pool_size, Padding2D pad, Size2D stride, bool exclude_padding, DataType data_type)
+    void setup(TensorShape input_shape,
+               PoolingType pool_type,
+               Size2D      pool_size,
+               Padding2D   pad,
+               Size2D      stride,
+               bool        exclude_padding,
+               DataType    data_type)
     {
-        DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape,
-                                                                                                         Pool2dAttributes().pool_type(pool_type).pool_size(pool_size).pad(pad).stride(stride).exclude_padding(exclude_padding),
-                                                                                                         data_type, false);
+        DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            input_shape,
+            Pool2dAttributes().pool_type(pool_type).pool_size(pool_size).pad(pad).stride(stride).exclude_padding(
+                exclude_padding),
+            data_type, false);
     }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuPool2dMixedPrecisionValidationFixture : public DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuPool2dMixedPrecisionValidationFixture
+    : public DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape input_shape, PoolingType pool_type, Size2D pool_size, Padding2D pad, Size2D stride, bool exclude_padding, DataType data_type, bool mixed_precision)
+    void setup(TensorShape input_shape,
+               PoolingType pool_type,
+               Size2D      pool_size,
+               Padding2D   pad,
+               Size2D      stride,
+               bool        exclude_padding,
+               DataType    data_type,
+               bool        mixed_precision)
     {
-        DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape,
-                                                                                                         Pool2dAttributes().pool_type(pool_type).pool_size(pool_size).pad(pad).stride(stride).exclude_padding(exclude_padding),
-                                                                                                         data_type, mixed_precision);
+        DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            input_shape,
+            Pool2dAttributes().pool_type(pool_type).pool_size(pool_size).pad(pad).stride(stride).exclude_padding(
+                exclude_padding),
+            data_type, mixed_precision);
     }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionGpuPool2dSpecialValidationFixture : public DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionGpuPool2dSpecialValidationFixture
+    : public DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
     void setup(TensorShape input_shape, Pool2dAttributes pool_attr, DataType data_type)
     {
-        DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, pool_attr, data_type, false);
+        DynamicFusionGpuPool2dValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            input_shape, pool_attr, data_type, false);
     }
 };
 
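Note: the hunks above capture the core of the migration in one place: create_tensor_info() now hands back an ITensorInfo pointer whose object lives in the GpuWorkloadContext, operators consume that pointer directly, and user tensors are initialised by dereferencing it. A minimal sketch of the new pattern follows; it assumes GpuPool2d is the concrete FunctionType used by this fixture and that its header sits under sketch/gpu/operators/ like GpuOutput's, so treat it as illustrative rather than authoritative.

    #include "arm_compute/core/CL/CLKernelLibrary.h"
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuPool2d.h" // assumed path
    #include "arm_compute/runtime/CL/CLTensor.h"

    using namespace arm_compute;
    using namespace arm_compute::experimental::dynamic_fusion;

    void sketch_pool2d(const TensorShape &shape, const Pool2dAttributes &attr)
    {
        auto               cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
        GpuWorkloadContext context{&cl_compile_ctx};
        GpuWorkloadSketch  sketch{&context};

        // The context owns both info objects; the caller only holds non-owning pointers.
        ITensorInfo *src_info = context.create_tensor_info(TensorInfo(shape, 1, DataType::F32, DataLayout::NHWC));
        ITensorInfo *dst_info = context.create_tensor_info();

        // Operators take the pointers directly instead of &local_tensor_info.
        ITensorInfo *ans_info = GpuPool2d::create_op(sketch, src_info, attr, GpuPool2dSettings());
        GpuOutput::create_op(sketch, ans_info, dst_info);

        // User tensors are initialised by dereferencing the context-owned infos.
        CLTensor t_src, t_dst;
        t_src.allocator()->init(*src_info);
        t_dst.allocator()->init(*dst_info);
    }
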
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h
index 18c3b6b..2f0b133 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -22,8 +22,8 @@
  * SOFTWARE.
  */
 
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_ACTIVATIONFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_ACTIVATIONFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_ACTIVATIONFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_ACTIVATIONFIXTURE_H
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -49,11 +49,11 @@
 public:
     void setup(TensorShape shape, bool fuse, DataType data_type, ActivationLayerInfo act_info, TArgs... args)
     {
-        _fuse       = fuse;
-        _data_type  = data_type;
-        _function   = act_info.activation();
-        _target     = compute_target(shape, args...);
-        _reference  = compute_reference(shape, act_info);
+        _fuse      = fuse;
+        _data_type = data_type;
+        _function  = act_info.activation();
+        _target    = compute_target(shape, args...);
+        _reference = compute_reference(shape, act_info);
     }
 
 protected:
@@ -73,17 +73,19 @@
         // To ensure all the inserted values are within the given range after subtracting/adding delta
         auto insert_values = [&boundary_values, &min, &max](const std::initializer_list<T> &new_values)
         {
-            for(auto &v : new_values)
+            for (auto &v : new_values)
             {
-                if(v >= min && v <= max)
+                if (v >= min && v <= max)
                 {
                     boundary_values.emplace_back(v);
                 }
             }
         };
 
-        insert_values({ min, static_cast<T>(min + delta), static_cast<T>(lower_quarter), static_cast<T>(center_value - delta) });                               // lower partition
-        insert_values({ static_cast<T>(center_value), static_cast<T>(center_value + delta), static_cast<T>(upper_quarter), static_cast<T>(max - delta), max }); // upper partition
+        insert_values({min, static_cast<T>(min + delta), static_cast<T>(lower_quarter),
+                       static_cast<T>(center_value - delta)}); // lower partition
+        insert_values({static_cast<T>(center_value), static_cast<T>(center_value + delta),
+                       static_cast<T>(upper_quarter), static_cast<T>(max - delta), max}); // upper partition
 
         return boundary_values;
     }
@@ -91,8 +93,8 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        float min_bound = 0;
-        float max_bound = 0;
+        float min_bound                = 0;
+        float max_bound                = 0;
         std::tie(min_bound, max_bound) = get_activation_layer_test_bounds<T>(_function, _data_type);
         library->fill_static_values(tensor, get_boundary_values(static_cast<T>(min_bound), static_cast<T>(max_bound)));
     }
@@ -101,22 +103,22 @@
     {
         // Create a new workload sketch
         CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        GpuWorkloadContext context{ &cl_compile_ctx };
-        GpuWorkloadSketch  sketch{ &context };
+        GpuWorkloadContext context{&cl_compile_ctx};
+        GpuWorkloadSketch  sketch{&context};
 
         // Create sketch tensors
-        TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
-        TensorInfo dst_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
+        ITensorInfo *src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
+        ITensorInfo *dst_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
 
-        ITensorInfo *ans_0_info = FunctionType::create_op(sketch, &src_info, args...);
-        if(_fuse)
+        ITensorInfo *ans_0_info = FunctionType::create_op(sketch, src_info, args...);
+        if (_fuse)
         {
             ITensorInfo *ans_1_info = FunctionType::create_op(sketch, ans_0_info, args...);
-            GpuOutput::create_op(sketch, ans_1_info, &dst_info);
+            GpuOutput::create_op(sketch, ans_1_info, dst_info);
         }
         else
         {
-            GpuOutput::create_op(sketch, ans_0_info, &dst_info);
+            GpuOutput::create_op(sketch, ans_0_info, dst_info);
         }
 
         // Configure runtime
@@ -128,8 +130,8 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_src.allocator()->init(src_info);
-        t_dst.allocator()->init(dst_info);
+        t_src.allocator()->init(*src_info);
+        t_dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         t_src.allocator()->allocate();
@@ -138,7 +140,7 @@
         fill(AccessorType(t_src));
 
         // Run runtime
-        runtime.run({ &t_src, &t_dst });
+        runtime.run({&t_src, &t_dst});
 
         return t_dst;
     }
@@ -146,14 +148,14 @@
     SimpleTensor<T> compute_reference(const TensorShape &shape, ActivationLayerInfo act_info)
     {
         // Create reference
-        SimpleTensor<T> src{ shape, _data_type, 1 };
+        SimpleTensor<T> src{shape, _data_type, 1};
 
         // Fill reference
         fill(src);
 
         auto tmp = reference::activation_layer<T>(src, act_info);
 
-        if(_fuse)
+        if (_fuse)
         {
             auto dst = reference::activation_layer<T>(tmp, act_info);
             return dst;
@@ -166,31 +168,35 @@
 
 protected:
     ActivationLayerInfo::ActivationFunction _function{};
-    bool                                    _fuse{ false };
+    bool                                    _fuse{false};
     DataType                                _data_type{};
     TensorType                              _target{};
     SimpleTensor<T>                         _reference{};
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionSigmoidValidationFixture : public DynamicFusionActivationValidationFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionSigmoidValidationFixture
+    : public DynamicFusionActivationValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
     void setup(TensorShape shape, bool fuse, DataType data_type)
     {
-        ActivationLayerInfo act_info{ ActivationLayerInfo::ActivationFunction::LOGISTIC };
-        DynamicFusionActivationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, fuse, data_type, act_info);
+        ActivationLayerInfo act_info{ActivationLayerInfo::ActivationFunction::LOGISTIC};
+        DynamicFusionActivationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, fuse,
+                                                                                                   data_type, act_info);
     }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionTanhValidationFixture : public DynamicFusionActivationValidationFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionTanhValidationFixture
+    : public DynamicFusionActivationValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
     void setup(TensorShape shape, bool fuse, DataType data_type)
     {
-        ActivationLayerInfo act_info{ ActivationLayerInfo::ActivationFunction::TANH };
-        DynamicFusionActivationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, fuse, data_type, act_info);
+        ActivationLayerInfo act_info{ActivationLayerInfo::ActivationFunction::TANH};
+        DynamicFusionActivationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, fuse,
+                                                                                                   data_type, act_info);
     }
 };
 
@@ -198,4 +204,4 @@
 } // namespace test
 } // namespace arm_compute
 
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_ACTIVATIONFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_ACTIVATIONFIXTURE_H
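
Note: the _fuse branch above is worth calling out. Because create_op() returns a context-owned ITensorInfo pointer, intermediate results can be chained into further operators without any local info objects. A distilled sketch of that branch, kept generic over the operator type exactly as the fixture is:

    #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"

    using namespace arm_compute;
    using namespace arm_compute::experimental::dynamic_fusion;

    // Chain one or two instances of FunctionType and bind the final result to dst_info.
    template <typename FunctionType, typename... Args>
    void build_chain(GpuWorkloadSketch &sketch, ITensorInfo *src_info, ITensorInfo *dst_info, bool fuse, Args... args)
    {
        ITensorInfo *ans_0_info = FunctionType::create_op(sketch, src_info, args...);
        if (fuse)
        {
            // The intermediate pointer is owned by the context too, so chaining is safe.
            ITensorInfo *ans_1_info = FunctionType::create_op(sketch, ans_0_info, args...);
            GpuOutput::create_op(sketch, ans_1_info, dst_info);
        }
        else
        {
            GpuOutput::create_op(sketch, ans_0_info, dst_info);
        }
    }
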
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h
index d8e250c..edf0dff 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CASTFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CASTFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CASTFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CASTFIXTURE_H
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -58,14 +58,14 @@
     void fill(U &&tensor, int i, DataType dt_in, DataType dt_out)
     {
         // Restricting range to avoid inf values
-        if(dt_out == DataType::F16)
+        if (dt_out == DataType::F16)
         {
             constexpr int signed_min   = -32000;
             constexpr int signed_max   = 32000;
             constexpr int unsigned_min = 0;
             constexpr int unsigned_max = 65000;
 
-            switch(dt_in)
+            switch (dt_in)
             {
                 case DataType::U8:
                 case DataType::QASYMM8:
@@ -78,22 +78,26 @@
                 }
                 case DataType::U16:
                 {
-                    library->fill_tensor_uniform(tensor, i, static_cast<uint16_t>(unsigned_min), static_cast<uint16_t>(unsigned_max));
+                    library->fill_tensor_uniform(tensor, i, static_cast<uint16_t>(unsigned_min),
+                                                 static_cast<uint16_t>(unsigned_max));
                     break;
                 }
                 case DataType::S16:
                 {
-                    library->fill_tensor_uniform(tensor, i, static_cast<int16_t>(signed_min), static_cast<int16_t>(signed_max));
+                    library->fill_tensor_uniform(tensor, i, static_cast<int16_t>(signed_min),
+                                                 static_cast<int16_t>(signed_max));
                     break;
                 }
                 case DataType::U32:
                 {
-                    library->fill_tensor_uniform(tensor, i, static_cast<uint32_t>(unsigned_min), static_cast<uint32_t>(unsigned_max));
+                    library->fill_tensor_uniform(tensor, i, static_cast<uint32_t>(unsigned_min),
+                                                 static_cast<uint32_t>(unsigned_max));
                     break;
                 }
                 case DataType::S32:
                 {
-                    library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(signed_min), static_cast<int32_t>(signed_max));
+                    library->fill_tensor_uniform(tensor, i, static_cast<int32_t>(signed_min),
+                                                 static_cast<int32_t>(signed_max));
                     break;
                 }
                 default:
@@ -107,29 +111,31 @@
     }
 
     // Given input is in nchw format
-    TensorType compute_target(const TensorShape &shape, const DataType dt_in, const DataType dt_out, const ConvertPolicy policy)
+    TensorType
+    compute_target(const TensorShape &shape, const DataType dt_in, const DataType dt_out, const ConvertPolicy policy)
     {
         // Create a new workload sketch
         auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch sketch{ &context };
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
         // Create sketch tensors
-        TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, dt_in, DataLayout::NCHW)); // layout is not important
-        TensorInfo dst_info = context.create_tensor_info();
+        ITensorInfo *src_info =
+            context.create_tensor_info(TensorInfo(shape, 1, dt_in, DataLayout::NCHW)); // layout is not important
+        ITensorInfo *dst_info = context.create_tensor_info();
 
         CastAttributes attributes;
         attributes.convert_policy(policy).data_type(dt_out);
 
-        ITensorInfo *ans_info = FunctionType::create_op(sketch, &src_info, attributes);
-        GpuOutput::create_op(sketch, ans_info, &dst_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, src_info, attributes);
+        GpuOutput::create_op(sketch, ans_info, dst_info);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
         runtime.configure(sketch);
 
         // (Important) Allocate auxiliary tensor memory if there are any
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -143,8 +149,8 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_src.allocator()->init(src_info);
-        t_dst.allocator()->init(dst_info);
+        t_src.allocator()->init(*src_info);
+        t_dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         t_src.allocator()->allocate();
@@ -153,14 +159,15 @@
         fill(AccessorType(t_src), 0, dt_in, dt_out);
 
         // Run runtime
-        runtime.run({ &t_src, &t_dst });
+        runtime.run({&t_src, &t_dst});
         return t_dst;
     }
 
-    SimpleTensor<T2> compute_reference(const TensorShape &shape, const DataType dt_in, const DataType dt_out, const ConvertPolicy policy)
+    SimpleTensor<T2>
+    compute_reference(const TensorShape &shape, const DataType dt_in, const DataType dt_out, const ConvertPolicy policy)
     {
         // Create reference
-        SimpleTensor<T1> src{ shape, dt_in, 1 };
+        SimpleTensor<T1> src{shape, dt_in, 1};
 
         // Fill reference
         fill(src, 0, dt_in, dt_out);
@@ -174,4 +181,4 @@
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CASTFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CASTFIXTURE_H
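
Note: the auxiliary-tensor loop is repeated verbatim across these fixtures, so a small helper keeps it in one place. The sketch below assumes the tuple layout shown in the loops above and uses C++17 structured bindings, if the toolchain settings permit them:

    #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"

    using namespace arm_compute;
    using namespace arm_compute::experimental::dynamic_fusion;

    // Allocate backing memory for any auxiliary tensors a configured runtime needs.
    void allocate_aux_tensors(ClWorkloadRuntime &runtime)
    {
        for (auto &[tensor, info, aux_mem_req] : runtime.get_auxiliary_tensors())
        {
            tensor->allocator()->init(info, aux_mem_req.alignment);
            tensor->allocator()->allocate(); // Use ACL-allocated memory
        }
    }
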
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h
index 3c325d7..e8f6f83 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CLAMPFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CLAMPFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CLAMPFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CLAMPFIXTURE_H
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -107,18 +107,18 @@
         GpuWorkloadSketch  sketch{ &context };
 
         // Create sketch tensors
-        TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
-        TensorInfo dst_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
+        ITensorInfo* src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
+        ITensorInfo* dst_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
 
-        ITensorInfo *ans_0_info = FunctionType::create_op(sketch, &src_info, attributes);
+        ITensorInfo *ans_0_info = FunctionType::create_op(sketch, src_info, attributes);
         if(_fuse)
         {
             ITensorInfo *ans_1_info = FunctionType::create_op(sketch, ans_0_info, attributes);
-            GpuOutput::create_op(sketch, ans_1_info, &dst_info);
+            GpuOutput::create_op(sketch, ans_1_info, dst_info);
         }
         else
         {
-            GpuOutput::create_op(sketch, ans_0_info, &dst_info);
+            GpuOutput::create_op(sketch, ans_0_info, dst_info);
         }
 
         // Configure runtime
@@ -130,8 +130,8 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_src.allocator()->init(src_info);
-        t_dst.allocator()->init(dst_info);
+        t_src.allocator()->init(*src_info);
+        t_dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         t_src.allocator()->allocate();
@@ -168,4 +168,4 @@
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CLAMPFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_CLAMPFIXTURE_H
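
Note: for downstream code migrating to this API, the mechanical change in every hunk above is the same. A before/after sketch, wrapped in a generic function so it compiles on its own (the names are illustrative):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
    #include "arm_compute/runtime/CL/CLTensor.h"

    using namespace arm_compute;
    using namespace arm_compute::experimental::dynamic_fusion;

    template <typename FunctionType, typename Attributes>
    void migrate_example(GpuWorkloadContext &context, GpuWorkloadSketch &sketch,
                         CLTensor &t_src, const TensorShape &shape, DataType data_type,
                         const Attributes &attributes)
    {
        // Before this patch the caller owned a TensorInfo value:
        //     TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, data_type));
        //     FunctionType::create_op(sketch, &src_info, attributes);
        //     t_src.allocator()->init(src_info);

        // After: the context owns the object and the caller keeps a pointer.
        ITensorInfo *src_info = context.create_tensor_info(TensorInfo(shape, 1, data_type));
        FunctionType::create_op(sketch, src_info, attributes);
        t_src.allocator()->init(*src_info);
    }
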
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h
index 02dc996..f02aa5e 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_MULFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_MULFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_MULFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_MULFIXTURE_H
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -31,9 +31,9 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
 
-#include "tests/Globals.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
+#include "tests/Globals.h"
 #include "tests/validation/reference/PixelWiseMultiplication.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
@@ -52,180 +52,188 @@
 class DynamicFusionMulValidationFixture : public framework::Fixture
 {
 public:
-   void setup(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2, DataType data_type, bool is_inplace, bool fuse_two_ops = false)
-   {
-       _data_type  = data_type;
-       _is_inplace = is_inplace;
-       _fuse       = fuse_two_ops;
-       ARM_COMPUTE_ERROR_ON_MSG(_fuse && shape2.total_size() == 0, "No shape2 provided for fusion of two ops.");
-       ARM_COMPUTE_ERROR_ON_MSG(_fuse && _is_inplace, "In place for fusing case not supported yet.");
-       _target    = compute_target(shape0, shape1, shape2);
-       _reference = compute_reference(shape0, shape1, shape2);
-   }
+    void setup(const TensorShape &shape0,
+               const TensorShape &shape1,
+               const TensorShape &shape2,
+               DataType           data_type,
+               bool               is_inplace,
+               bool               fuse_two_ops = false)
+    {
+        _data_type  = data_type;
+        _is_inplace = is_inplace;
+        _fuse       = fuse_two_ops;
+        ARM_COMPUTE_ERROR_ON_MSG(_fuse && shape2.total_size() == 0, "No shape2 provided for fusion of two ops.");
+        ARM_COMPUTE_ERROR_ON_MSG(_fuse && _is_inplace, "In-place computation is not yet supported for the fused case.");
+        _target    = compute_target(shape0, shape1, shape2);
+        _reference = compute_reference(shape0, shape1, shape2);
+    }
 
 protected:
-   template <typename U>
-   void fill(U &&tensor, int i)
-   {
-       library->fill_tensor_uniform(tensor, i);
-   }
+    template <typename U>
+    void fill(U &&tensor, int i)
+    {
+        library->fill_tensor_uniform(tensor, i);
+    }
 
-   TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2)
-   {
-       // Create a new workload sketch
-       auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-       auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-       GpuWorkloadSketch sketch{ &context };
+    TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2)
+    {
+        // Create a new workload sketch
+        auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
-       // Fuse first multiplication op
-       TensorInfo lhs_info = context.create_tensor_info(TensorInfo(shape0, 1, _data_type));
-       TensorInfo rhs_info = context.create_tensor_info(TensorInfo(shape1, 1, _data_type));
-       TensorInfo dst_info = context.create_tensor_info();
+        // Fuse first multiplication op
+        ITensorInfo *lhs_info = context.create_tensor_info(TensorInfo(shape0, 1, _data_type));
+        ITensorInfo *rhs_info = context.create_tensor_info(TensorInfo(shape1, 1, _data_type));
+        ITensorInfo *dst_info = context.create_tensor_info();
 
-       TensorInfo rhs_info_fuse;
+        ITensorInfo *rhs_info_fuse = nullptr;
 
-       ITensorInfo *ans_info = FunctionType::create_op(sketch, &lhs_info, &rhs_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, lhs_info, rhs_info);
 
-       if(_fuse)
-       {
-           rhs_info_fuse          = context.create_tensor_info(TensorInfo(shape2, 1, _data_type));
-           ITensorInfo *ans2_info = FunctionType::create_op(sketch, ans_info, &rhs_info_fuse);
-           GpuOutput::create_op(sketch, ans2_info, &dst_info);
-       }
-       else
-       {
-           GpuOutput::create_op(sketch, ans_info, &dst_info);
-       }
+        if (_fuse)
+        {
+            rhs_info_fuse          = context.create_tensor_info(TensorInfo(shape2, 1, _data_type));
+            ITensorInfo *ans2_info = FunctionType::create_op(sketch, ans_info, rhs_info_fuse);
+            GpuOutput::create_op(sketch, ans2_info, dst_info);
+        }
+        else
+        {
+            GpuOutput::create_op(sketch, ans_info, dst_info);
+        }
 
-       // Configure runtime
-       ClWorkloadRuntime runtime;
-       runtime.configure(sketch);
+        // Configure runtime
+        ClWorkloadRuntime runtime;
+        runtime.configure(sketch);
 
-       // (Important) Allocate auxiliary tensor memory if there are any
-       for(auto &data : runtime.get_auxiliary_tensors())
-       {
-           CLTensor     *tensor      = std::get<0>(data);
-           TensorInfo    info        = std::get<1>(data);
-           AuxMemoryInfo aux_mem_req = std::get<2>(data);
-           tensor->allocator()->init(info, aux_mem_req.alignment);
-           tensor->allocator()->allocate(); // Use ACL allocated memory
-       }
+        // (Important) Allocate auxiliary tensor memory if there are any
+        for (auto &data : runtime.get_auxiliary_tensors())
+        {
+            CLTensor     *tensor      = std::get<0>(data);
+            TensorInfo    info        = std::get<1>(data);
+            AuxMemoryInfo aux_mem_req = std::get<2>(data);
+            tensor->allocator()->init(info, aux_mem_req.alignment);
+            tensor->allocator()->allocate(); // Use ACL allocated memory
+        }
 
-       // Construct user tensors
-       TensorType t_lhs{};
-       TensorType t_rhs{};
-       TensorType t_rhs_fuse{};
-       TensorType t_dst{};
+        // Construct user tensors
+        TensorType t_lhs{};
+        TensorType t_rhs{};
+        TensorType t_rhs_fuse{};
+        TensorType t_dst{};
 
-       // Initialize user tensors
-       t_lhs.allocator()->init(lhs_info);
-       t_rhs.allocator()->init(rhs_info);
-       t_dst.allocator()->init(dst_info);
-       if(_fuse)
-       {
-           t_rhs_fuse.allocator()->init(rhs_info_fuse);
-       }
+        // Initialize user tensors
+        t_lhs.allocator()->init(*lhs_info);
+        t_rhs.allocator()->init(*rhs_info);
+        t_dst.allocator()->init(*dst_info);
+        if (_fuse)
+        {
+            t_rhs_fuse.allocator()->init(*rhs_info_fuse);
+        }
 
-       // Allocate and fill user tensors
-       // Instead of using ACL allocator, the user can choose to import memory into the tensors
-       t_lhs.allocator()->allocate();
-       t_rhs.allocator()->allocate();
-       t_dst.allocator()->allocate();
-       if(_fuse)
-       {
-           t_rhs_fuse.allocator()->allocate();
-       }
+        // Allocate and fill user tensors
+        // Instead of using ACL allocator, the user can choose to import memory into the tensors
+        t_lhs.allocator()->allocate();
+        t_rhs.allocator()->allocate();
+        t_dst.allocator()->allocate();
+        if (_fuse)
+        {
+            t_rhs_fuse.allocator()->allocate();
+        }
 
-       fill(AccessorType(t_lhs), 0);
-       fill(AccessorType(t_rhs), 1);
-       if(_fuse)
-       {
-           fill(AccessorType(t_rhs_fuse), 2);
-       }
+        fill(AccessorType(t_lhs), 0);
+        fill(AccessorType(t_rhs), 1);
+        if (_fuse)
+        {
+            fill(AccessorType(t_rhs_fuse), 2);
+        }
 
-       // Run runtime
-       if(_fuse)
-       {
-           runtime.run({ &t_lhs, &t_rhs, &t_rhs_fuse, &t_dst });
-       }
-       else
-       {
-           runtime.run({ &t_lhs, &t_rhs, &t_dst });
-       }
+        // Run runtime
+        if (_fuse)
+        {
+            runtime.run({&t_lhs, &t_rhs, &t_rhs_fuse, &t_dst});
+        }
+        else
+        {
+            runtime.run({&t_lhs, &t_rhs, &t_dst});
+        }
 
-       return t_dst;
-   }
+        return t_dst;
+    }
 
-   SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2)
-   {
-       // Create reference
-       SimpleTensor<T> ref_lhs{ shape0, _data_type, 1, QuantizationInfo() };
-       SimpleTensor<T> ref_rhs{ shape1, _data_type, 1, QuantizationInfo() };
-       SimpleTensor<T> ref_rhs_fuse{ shape2, _data_type, 1, QuantizationInfo() };
+    SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2)
+    {
+        // Create reference
+        SimpleTensor<T> ref_lhs{shape0, _data_type, 1, QuantizationInfo()};
+        SimpleTensor<T> ref_rhs{shape1, _data_type, 1, QuantizationInfo()};
+        SimpleTensor<T> ref_rhs_fuse{shape2, _data_type, 1, QuantizationInfo()};
 
-       // Fill reference
-       fill(ref_lhs, 0);
-       fill(ref_rhs, 1);
-       SimpleTensor<T> ref_dst = reference::pixel_wise_multiplication<T, T, T>(ref_lhs,
-                                                                               ref_rhs,
-                                                                               1.f,
-                                                                               ConvertPolicy::SATURATE,
-                                                                               RoundingPolicy::TO_NEAREST_UP,
-                                                                               _data_type,
-                                                                               QuantizationInfo());
-       if(_fuse)
-       {
-           fill(ref_rhs_fuse, 2);
-           SimpleTensor<T> ref_dst_fuse = reference::pixel_wise_multiplication<T, T, T>(ref_dst,
-                                                                                        ref_rhs_fuse,
-                                                                                        1.f,
-                                                                                        ConvertPolicy::SATURATE,
-                                                                                        RoundingPolicy::TO_NEAREST_UP,
-                                                                                        _data_type,
-                                                                                        QuantizationInfo());
-           return ref_dst_fuse;
-       }
-       return ref_dst;
-   }
+        // Fill reference
+        fill(ref_lhs, 0);
+        fill(ref_rhs, 1);
+        SimpleTensor<T> ref_dst = reference::pixel_wise_multiplication<T, T, T>(
+            ref_lhs, ref_rhs, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP, _data_type,
+            QuantizationInfo());
+        if (_fuse)
+        {
+            fill(ref_rhs_fuse, 2);
+            SimpleTensor<T> ref_dst_fuse = reference::pixel_wise_multiplication<T, T, T>(
+                ref_dst, ref_rhs_fuse, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_NEAREST_UP, _data_type,
+                QuantizationInfo());
+            return ref_dst_fuse;
+        }
+        return ref_dst;
+    }
 
-   TensorType      _target{};
-   SimpleTensor<T> _reference{};
-   DataType        _data_type{};
-   bool            _is_inplace{ false };
-   bool            _fuse{ false };
+    TensorType      _target{};
+    SimpleTensor<T> _reference{};
+    DataType        _data_type{};
+    bool            _is_inplace{false};
+    bool            _fuse{false};
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionMulOneOpValidationFixture : public DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionMulOneOpValidationFixture
+    : public DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-   void setup(const TensorShape &shape0, DataType data_type, bool is_inplace)
-   {
-       DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape0, TensorShape(), data_type, is_inplace);
-   }
+    void setup(const TensorShape &shape0, DataType data_type, bool is_inplace)
+    {
+        DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            shape0, shape0, TensorShape(), data_type, is_inplace);
+    }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionMulBroadcastValidationFixture : public DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionMulBroadcastValidationFixture
+    : public DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-   void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, bool is_inplace)
-   {
-       DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, TensorShape(), data_type, is_inplace);
-   }
+    void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, bool is_inplace)
+    {
+        DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            shape0, shape1, TensorShape(), data_type, is_inplace);
+    }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionMulTwoOpsValidationFixture : public DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionMulTwoOpsValidationFixture
+    : public DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-   void setup(const TensorShape &shape0, const TensorShape &shape1, const TensorShape &shape2, DataType data_type, bool is_inplace, bool fuse_two_ops)
-   {
-       DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, shape2, data_type, is_inplace, fuse_two_ops);
-   }
+    void setup(const TensorShape &shape0,
+               const TensorShape &shape1,
+               const TensorShape &shape2,
+               DataType           data_type,
+               bool               is_inplace,
+               bool               fuse_two_ops)
+    {
+        DynamicFusionMulValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            shape0, shape1, shape2, data_type, is_inplace, fuse_two_ops);
+    }
 };
 
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_MULFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_MULFIXTURE_H
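
Note: two details in the Mul fixture generalise. First, rhs_info_fuse becomes a nullptr-initialised pointer, since ITensorInfo is abstract and there is no default-constructed value to fall back on. Second, the I/O list handed to run() can be built dynamically instead of duplicating the call per branch. A sketch of the latter, assuming run() accepts a std::vector<CLTensor *> as the brace-initialised calls above suggest:

    #include <vector>

    #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
    #include "arm_compute/runtime/CL/CLTensor.h"

    using namespace arm_compute;
    using namespace arm_compute::experimental::dynamic_fusion;

    void run_maybe_fused(ClWorkloadRuntime &runtime, CLTensor &t_lhs, CLTensor &t_rhs,
                         CLTensor &t_rhs_fuse, CLTensor &t_dst, bool fuse)
    {
        // Build the tensor list once instead of duplicating the run() call per branch.
        std::vector<CLTensor *> tensors{&t_lhs, &t_rhs};
        if (fuse)
        {
            tensors.push_back(&t_rhs_fuse);
        }
        tensors.push_back(&t_dst);
        runtime.run(tensors);
    }
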
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h
index abfc645..bde3360 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESHAPEFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESHAPEFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESHAPEFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESHAPEFIXTURE_H
 
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/Types.h"
@@ -33,9 +33,9 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuReshape.h"
 
-#include "tests/Globals.h"
 #include "tests/framework/Asserts.h"
 #include "tests/framework/Fixture.h"
+#include "tests/Globals.h"
 #include "tests/validation/reference/ReshapeLayer.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
@@ -70,24 +70,24 @@
 
         // Create a new workload sketch
         auto              cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        auto              context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch sketch{ &context };
+        auto              context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch sketch{&context};
 
         // Create sketch tensors
-        TensorInfo        src_info = context.create_tensor_info(TensorInfo(input_shape, 1, data_type));
-        TensorInfo        dst_info = context.create_tensor_info(TensorInfo(output_shape, 1, data_type));
+        ITensorInfo      *src_info = context.create_tensor_info(TensorInfo(input_shape, 1, data_type));
+        ITensorInfo      *dst_info = context.create_tensor_info(TensorInfo(output_shape, 1, data_type));
         ReshapeAttributes attributes;
         attributes.shape(output_shape);
 
-        ITensorInfo *ans_info = FunctionType::create_op(sketch, &src_info, attributes);
-        GpuOutput::create_op(sketch, ans_info, &dst_info);
+        ITensorInfo *ans_info = FunctionType::create_op(sketch, src_info, attributes);
+        GpuOutput::create_op(sketch, ans_info, dst_info);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
         runtime.configure(sketch);
 
         // (Important) Allocate auxiliary tensor memory if there are any
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -100,8 +100,8 @@
         TensorType t_src{};
         TensorType t_dst{};
         // Initialize user tensors
-        t_src.allocator()->init(src_info);
-        t_dst.allocator()->init(dst_info);
+        t_src.allocator()->init(*src_info);
+        t_dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         t_src.allocator()->allocate();
@@ -110,15 +110,16 @@
         fill(AccessorType(t_src), 0);
 
         // Run runtime
-        runtime.run({ &t_src, &t_dst });
+        runtime.run({&t_src, &t_dst});
 
         return t_dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type)
+    SimpleTensor<T>
+    compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type)
     {
         // Create reference
-        SimpleTensor<T> src{ input_shape, data_type };
+        SimpleTensor<T> src{input_shape, data_type};
 
         // Fill reference
         fill(src, 0);
@@ -133,4 +134,4 @@
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESHAPEFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESHAPEFIXTURE_H
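
Note: the attribute classes used throughout these fixtures share a chainable setter style (each setter returns the attribute object). A hedged sketch of the reshape path above, assuming GpuReshape is this fixture's FunctionType as the include list suggests:

    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuReshape.h"

    using namespace arm_compute;
    using namespace arm_compute::experimental::dynamic_fusion;

    void sketch_reshape(GpuWorkloadSketch &sketch, ITensorInfo *src_info, ITensorInfo *dst_info,
                        const TensorShape &output_shape)
    {
        ReshapeAttributes attributes;
        attributes.shape(output_shape); // setters return the object, so calls can chain

        ITensorInfo *ans_info = GpuReshape::create_op(sketch, src_info, attributes);
        GpuOutput::create_op(sketch, ans_info, dst_info);
    }
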
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h
index c44f037..711767b 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h
@@ -1,5 +1,5 @@
 /*
-* Copyright (c) 2022-2023 Arm Limited.
+* Copyright (c) 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESIZEFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESIZEFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESIZEFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESIZEFIXTURE_H
 
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -33,12 +33,12 @@
 #include "arm_compute/dynamic_fusion/sketch/gpu/operators/GpuOutput.h"
 
 #include "tests/CL/CLAccessor.h"
-#include "tests/SimpleTensor.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-#include "tests/validation/Validation.h"
+#include "tests/SimpleTensor.h"
 #include "tests/validation/reference/Permute.h"
 #include "tests/validation/reference/Scale.h"
+#include "tests/validation/Validation.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 
@@ -52,9 +52,14 @@
 class DynamicFusionResizeGenericValidationFixture : public framework::Fixture
 {
 public:
-    void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout,
-               InterpolationPolicy interpolation_policy, SamplingPolicy sampling_policy,
-               bool align_corners, QuantizationInfo output_quantization_info)
+    void setup(TensorShape         shape,
+               DataType            data_type,
+               QuantizationInfo    quantization_info,
+               DataLayout          data_layout,
+               InterpolationPolicy interpolation_policy,
+               SamplingPolicy      sampling_policy,
+               bool                align_corners,
+               QuantizationInfo    output_quantization_info)
     {
         _shape                    = shape;
         _interpolation_policy     = interpolation_policy;
@@ -79,13 +84,13 @@
 protected:
     void generate_scale(const TensorShape &shape)
     {
-        static constexpr float _min_scale{ 0.25f };
-        static constexpr float _max_scale{ 3.f };
+        static constexpr float _min_scale{0.25f};
+        static constexpr float _max_scale{3.f};
 
-        constexpr float max_width{ 8192.0f };
-        constexpr float max_height{ 6384.0f };
-        constexpr float min_width{ 1.f };
-        constexpr float min_height{ 1.f };
+        constexpr float max_width{8192.0f};
+        constexpr float max_height{6384.0f};
+        constexpr float min_width{1.f};
+        constexpr float min_height{1.f};
 
         std::mt19937                          generator(library->seed());
         std::uniform_real_distribution<float> distribution_float(_min_scale, _max_scale);
@@ -93,7 +98,8 @@
         auto generate = [&](size_t input_size, float min_output, float max_output) -> int
         {
             const float generated_scale = distribution_float(generator);
-            const int   output_size     = static_cast<int>(utility::clamp(static_cast<float>(input_size) * generated_scale, min_output, max_output));
+            const int   output_size     = static_cast<int>(
+                utility::clamp(static_cast<float>(input_size) * generated_scale, min_output, max_output));
             return output_size;
         };
 
@@ -108,17 +114,17 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(tensor.data_type() == DataType::F32)
+        if (tensor.data_type() == DataType::F32)
         {
             std::uniform_real_distribution<float> distribution(-5.0f, 5.0f);
             library->fill(tensor, distribution, 0);
         }
-        else if(tensor.data_type() == DataType::F16)
+        else if (tensor.data_type() == DataType::F16)
         {
-            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -5.0f, 5.0f };
+            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{-5.0f, 5.0f};
             library->fill(tensor, distribution, 0);
         }
-        else if(is_data_type_quantized(tensor.data_type()))
+        else if (is_data_type_quantized(tensor.data_type()))
         {
             std::uniform_int_distribution<> distribution(0, 100);
             library->fill(tensor, distribution, 0);
@@ -136,26 +142,30 @@
 
         // Create a new workload sketch
         CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch  sketch{ &context };
+        GpuWorkloadContext context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch  sketch{&context};
 
         // Create sketch tensors
-        TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type, _data_layout));
-        src_info.set_quantization_info(_input_quantization_info);
-        TensorInfo dst_info = context.create_tensor_info();
+        ITensorInfo *src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type, _data_layout));
+        src_info->set_quantization_info(_input_quantization_info);
+        ITensorInfo *dst_info = context.create_tensor_info();
 
         ResizeAttributes attributes;
-        attributes.align_corners(_align_corners).sampling_policy(_sampling_policy).interpolation_policy(_interpolation_policy).output_width(_output_width).output_height(_output_height);
+        attributes.align_corners(_align_corners)
+            .sampling_policy(_sampling_policy)
+            .interpolation_policy(_interpolation_policy)
+            .output_width(_output_width)
+            .output_height(_output_height);
 
-        ITensorInfo *scale_result_info = FunctionType::create_op(sketch, &src_info, attributes);
-        GpuOutput::create_op(sketch, scale_result_info, &dst_info);
+        ITensorInfo *scale_result_info = FunctionType::create_op(sketch, src_info, attributes);
+        GpuOutput::create_op(sketch, scale_result_info, dst_info);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
         runtime.configure(sketch);
 
         // (Important) Allocate auxiliary tensor memory if there are any
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -169,8 +179,8 @@
         TensorType t_dst{};
 
         // Initialize user tensors
-        t_src.allocator()->init(src_info);
-        t_dst.allocator()->init(dst_info);
+        t_src.allocator()->init(*src_info);
+        t_dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         t_src.allocator()->allocate();
@@ -179,7 +189,7 @@
         fill(AccessorType(t_src));
 
         // Run runtime
-        runtime.run({ &t_src, &t_dst });
+        runtime.run({&t_src, &t_dst});
 
         return t_dst;
     }
@@ -187,7 +197,7 @@
     SimpleTensor<T> compute_reference(const TensorShape &shape)
     {
         // Create reference
-        SimpleTensor<T> src{ shape, _data_type, 1, _input_quantization_info };
+        SimpleTensor<T> src{shape, _data_type, 1, _input_quantization_info};
 
         // Reference code is NCHW, so the input shapes are NCHW
         const int idx_width  = get_data_layout_dimension_index(DataLayout::NCHW, DataLayoutDimension::WIDTH);
@@ -199,9 +209,9 @@
         // Fill reference
         fill(src);
 
-        return reference::scale<T>(src, scale_x, scale_y, _interpolation_policy,
-                                   BorderMode::REPLICATE, static_cast<T>(0), _sampling_policy, /* ceil_policy_scale */ false,
-                                   _align_corners, _output_quantization_info);
+        return reference::scale<T>(src, scale_x, scale_y, _interpolation_policy, BorderMode::REPLICATE,
+                                   static_cast<T>(0), _sampling_policy, /* ceil_policy_scale */ false, _align_corners,
+                                   _output_quantization_info);
     }
 
     TensorType          _target{};
@@ -213,43 +223,45 @@
     DataLayout          _data_layout{};
     QuantizationInfo    _input_quantization_info{};
     QuantizationInfo    _output_quantization_info{};
-    bool                _align_corners{ false };
-    int                 _output_width{ 0 };
-    int                 _output_height{ 0 };
+    bool                _align_corners{false};
+    int                 _output_width{0};
+    int                 _output_height{0};
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionResizeValidationFixture : public DynamicFusionResizeGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionResizeValidationFixture
+    : public DynamicFusionResizeGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape shape, DataType data_type, DataLayout data_layout, InterpolationPolicy policy, SamplingPolicy sampling_policy, bool align_corners)
+    void setup(TensorShape         shape,
+               DataType            data_type,
+               DataLayout          data_layout,
+               InterpolationPolicy policy,
+               SamplingPolicy      sampling_policy,
+               bool                align_corners)
     {
-        DynamicFusionResizeGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
-                                                                                                      data_type,
-                                                                                                      QuantizationInfo(),
-                                                                                                      data_layout,
-                                                                                                      policy,
-                                                                                                      sampling_policy,
-                                                                                                      align_corners,
-                                                                                                      QuantizationInfo());
+        DynamicFusionResizeGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            shape, data_type, QuantizationInfo(), data_layout, policy, sampling_policy, align_corners,
+            QuantizationInfo());
     }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
-class DynamicFusionResizeQuantizedValidationFixture : public DynamicFusionResizeGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionResizeQuantizedValidationFixture
+    : public DynamicFusionResizeGenericValidationFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
-    void setup(TensorShape shape, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout, InterpolationPolicy policy, SamplingPolicy sampling_policy,
-               bool align_corners)
+    void setup(TensorShape         shape,
+               DataType            data_type,
+               QuantizationInfo    quantization_info,
+               DataLayout          data_layout,
+               InterpolationPolicy policy,
+               SamplingPolicy      sampling_policy,
+               bool                align_corners)
     {
-        DynamicFusionResizeGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
-                                                                                                      data_type,
-                                                                                                      quantization_info,
-                                                                                                      data_layout,
-                                                                                                      policy,
-                                                                                                      sampling_policy,
-                                                                                                      align_corners,
-                                                                                                      quantization_info);
+        DynamicFusionResizeGenericValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            shape, data_type, quantization_info, data_layout, policy, sampling_policy, align_corners,
+            quantization_info);
     }
 };
 
@@ -257,4 +269,4 @@
 } // namespace test
 } // namespace arm_compute
 
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESIZEFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_RESIZEFIXTURE_H
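
Note: one subtlety in the resize fixture above: per-test adjustments such as the quantization info are now applied through the returned pointer, mutating the context-owned object in place rather than a local copy. A distilled sketch of that pattern (the helper name is illustrative):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadContext.h"

    using namespace arm_compute;
    using namespace arm_compute::experimental::dynamic_fusion;

    ITensorInfo *make_quantized_src(GpuWorkloadContext &context, const TensorShape &shape,
                                    DataType data_type, DataLayout layout,
                                    const QuantizationInfo &qinfo)
    {
        ITensorInfo *src_info = context.create_tensor_info(TensorInfo(shape, 1, data_type, layout));
        // Mutate the context-owned object directly; no local TensorInfo copy exists.
        src_info->set_quantization_info(qinfo);
        return src_info;
    }
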
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h
index 1ed133d..175d4ff 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h
@@ -1,5 +1,5 @@
 /*
-* Copyright (c) 2023 Arm Limited.
+* Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,18 +21,18 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_SOFTMAXFIXTURE
-#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_SOFTMAXFIXTURE
+#ifndef ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_SOFTMAXFIXTURE_H
+#define ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_SOFTMAXFIXTURE_H
 
 #include "arm_compute/dynamic_fusion/runtime/gpu/cl/ClWorkloadRuntime.h"
 #include "arm_compute/dynamic_fusion/sketch/attributes/SoftmaxAttributes.h"
 #include "arm_compute/dynamic_fusion/sketch/gpu/GpuWorkloadSketch.h"
 
-#include "tests/SimpleTensor.h"
 #include "tests/framework/Fixture.h"
 #include "tests/framework/Macros.h"
-#include "tests/validation/Validation.h"
+#include "tests/SimpleTensor.h"
 #include "tests/validation/reference/SoftmaxLayer.h"
+#include "tests/validation/Validation.h"
 
 using namespace arm_compute::experimental::dynamic_fusion;
 
@@ -56,17 +56,17 @@
     template <typename U>
     void fill(U &&tensor)
     {
-        if(tensor.data_type() == DataType::F32)
+        if (tensor.data_type() == DataType::F32)
         {
             std::uniform_real_distribution<float> distribution(-10.0f, 10.0f);
             library->fill(tensor, distribution, 0);
         }
-        else if(tensor.data_type() == DataType::F16)
+        else if (tensor.data_type() == DataType::F16)
         {
-            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -10.0f, 10.0f };
+            arm_compute::utils::uniform_real_distribution_16bit<half> distribution{-10.0f, 10.0f};
             library->fill(tensor, distribution, 0);
         }
-        else if(!is_data_type_quantized(tensor.data_type()))
+        else if (!is_data_type_quantized(tensor.data_type()))
         {
             std::uniform_int_distribution<> distribution(0, 100);
             library->fill(tensor, distribution, 0);
@@ -81,14 +81,14 @@
     {
         // Create a new workload sketch
         CLCompileContext   cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
-        GpuWorkloadContext context        = GpuWorkloadContext{ &cl_compile_ctx };
-        GpuWorkloadSketch  sketch{ &context };
+        GpuWorkloadContext context        = GpuWorkloadContext{&cl_compile_ctx};
+        GpuWorkloadSketch  sketch{&context};
 
         SoftmaxAttributes softmax_attr{};
         softmax_attr.axis(axis).beta(beta).is_log_softmax(is_log);
-        TensorInfo src_info = context.create_tensor_info(shape, 1, data_type);
-        TensorInfo dst_info = context.create_tensor_info(shape, 1, data_type);
-        FunctionType::create_op(sketch, &src_info, &dst_info, softmax_attr);
+        ITensorInfo *src_info = context.create_tensor_info(shape, 1, data_type);
+        ITensorInfo *dst_info = context.create_tensor_info(shape, 1, data_type);
+        FunctionType::create_op(sketch, src_info, dst_info, softmax_attr);
 
         // Configure runtime
         ClWorkloadRuntime runtime;
@@ -96,7 +96,7 @@
 
         // (Important) Allocate auxiliary tensor memory if there are any
         // Instead of using ACL allocated memory, the user can choose to import memory into the tensors
-        for(auto &data : runtime.get_auxiliary_tensors())
+        for (auto &data : runtime.get_auxiliary_tensors())
         {
             CLTensor     *tensor      = std::get<0>(data);
             TensorInfo    info        = std::get<1>(data);
@@ -109,8 +109,8 @@
         TensorType dst{};
 
         // Initialize user tensors
-        src.allocator()->init(src_info);
-        dst.allocator()->init(dst_info);
+        src.allocator()->init(*src_info);
+        dst.allocator()->init(*dst_info);
 
         // Allocate and fill user tensors
         src.allocator()->allocate();
@@ -118,15 +118,16 @@
         fill(AccessorType(src));
 
         // Run runtime
-        runtime.run({ &src, &dst });
+        runtime.run({&src, &dst});
 
         return dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, float beta, int32_t axis, bool is_log)
+    SimpleTensor<T>
+    compute_reference(const TensorShape &shape, DataType data_type, float beta, int32_t axis, bool is_log)
     {
         // Create reference
-        SimpleTensor<T> src{ shape, data_type, 1 };
+        SimpleTensor<T> src{shape, data_type, 1};
 
         // Fill reference
         fill(src);
@@ -139,16 +140,14 @@
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DynamicFusionSoftmaxValidationFixture : public DynamicFusionSoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DynamicFusionSoftmaxValidationFixture
+    : public DynamicFusionSoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
 {
 public:
     void setup(TensorShape shape, DataType data_type, float beta, size_t axis, bool is_log)
     {
-        DynamicFusionSoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape,
-                                                                                                       data_type,
-                                                                                                       beta,
-                                                                                                       axis,
-                                                                                                       is_log);
+        DynamicFusionSoftmaxValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(
+            shape, data_type, beta, axis, is_log);
     }
 };
 
@@ -156,4 +155,4 @@
 } // namespace test
 } // namespace arm_compute
 
-#endif /* TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_SOFTMAXFIXTURE */
+#endif // ACL_TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_SOFTMAXFIXTURE_H
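The SoftmaxFixture hunks above capture the migration pattern applied throughout these fixtures: tensor infos are no longer TensorInfo values held by the test, but ITensorInfo pointers handed out by the GpuWorkloadContext, and they are dereferenced only when initializing user tensors. A condensed before/after sketch of that pattern, assuming the API exactly as it appears in the diff:

    // Before: the fixture owned the infos and passed their addresses.
    //   TensorInfo src_info = context.create_tensor_info(shape, 1, data_type);
    //   FunctionType::create_op(sketch, &src_info, &dst_info, softmax_attr);
    //   src.allocator()->init(src_info);

    // After: the context creates the infos; the fixture only receives
    // non-owning pointers, valid for the lifetime of the context.
    ITensorInfo *src_info = context.create_tensor_info(shape, 1, data_type);
    ITensorInfo *dst_info = context.create_tensor_info(shape, 1, data_type);
    FunctionType::create_op(sketch, src_info, dst_info, softmax_attr);

    // User tensors copy the metadata out of the context-owned infos,
    // so they are unaffected once the sketch has been configured.
    src.allocator()->init(*src_info);
    dst.allocator()->init(*dst_info);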