COMPMID-1979: Fuse Activation Function in CLGEMM - part 4

Fused activation function in CLGEMM

Change-Id: I644fdf09349325c0b3a2cd5fef2a3ea2c974149d
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1640
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index b4d94ec..2c17f27 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1775,7 +1775,8 @@
           _gemmlowp_output_stage(),
           _fp_mixed_precision(false),
           _broadcast_bias(false),
-          _pretranpose_B(true)
+          _pretranpose_B(true),
+          _activation_info()
     {
     }
     /** Constructor
@@ -1791,9 +1792,11 @@
      * @param[in] gemmlowp_output_stage       (Optional) GEMMLowp Output stage info
      * @param[in] fp_mixed_precision          (Optional) Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy.
      * @param[in] broadcast_bias              (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
+     * @param[in] activation_info             (Optional) Activation to apply after the matrix multiplication
      */
     GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
-             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false) noexcept
+             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false,
+             const ActivationLayerInfo &activation_info = ActivationLayerInfo()) noexcept
         : _is_a_reshaped(is_a_reshaped),
           _is_b_reshaped(is_b_reshaped),
           _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -1803,7 +1806,8 @@
           _gemmlowp_output_stage(gemmlowp_output_stage),
           _fp_mixed_precision(fp_mixed_precision),
           _broadcast_bias(broadcast_bias),
-          _pretranpose_B(reshape_b_only_on_first_run)
+          _pretranpose_B(reshape_b_only_on_first_run),
+          _activation_info(activation_info)
     {
     }
     /** Flag which specifies if the matrix A has been reshaped
@@ -1896,6 +1900,14 @@
     {
         _pretranpose_B = flag;
     }
+    /** Activation layer to apply after the matrix multiplication
+     *
+     * @return ActivationLayerInfo object
+     */
+    ActivationLayerInfo activation_info() const
+    {
+        return _activation_info;
+    }
 
 private:
     bool                    _is_a_reshaped;
@@ -1908,6 +1920,7 @@
     bool                    _fp_mixed_precision;
     bool                    _broadcast_bias;
     bool                    _pretranpose_B;
+    ActivationLayerInfo     _activation_info;
 };
 
 /** Winograd information */
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 8c462fa..e2a92a8 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -127,7 +127,6 @@
 
     CLMemoryGroup                             _memory_group;
     CLGEMMMatrixMultiplyKernel                _mm_kernel;
-    CLGEMMMatrixAdditionKernel                _ma_kernel;
     CLGEMMReshapeLHSMatrixKernel              _reshape_lhs_kernel;
     CLGEMMReshapeRHSMatrixKernel              _reshape_rhs_kernel;
     CLGEMMMatrixMultiplyReshapedKernel        _mm_reshaped_kernel;
@@ -135,7 +134,6 @@
     CLTensor                                  _tmp_a;
     CLTensor                                  _tmp_b;
     const ICLTensor                          *_original_b;
-    bool                                      _run_addition;
     bool                                      _reshape_b_only_on_first_run;
     bool                                      _is_prepared;
     GEMMType                                  _gemm_type;
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
index e9a3f9b..027727c 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
@@ -163,8 +163,10 @@
      *                                       except for input of QASYMM8 type where output should be of S32 type.
      * @param[in]      gemmlowp_output_stage GEMMLowp output stage info
      * @param[in]      gemm_3d_depth         Depth of GEMM 3D
+     * @param[in]      act_info              Activation to apply after the matrix multiplication
      */
-    void configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth = 1);
+    void configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth,
+                      const ActivationLayerInfo &act_info);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines
      *
      * @param[in] input                 Input tensor. Data types supported: QASYMM8/F16/F32.
@@ -176,22 +178,21 @@
      * @param[in] gemmlowp_output_stage GEMMLowp output stage info
      * @param[in] gemm_3d_depth         Depth of GEMM 3D
      * @param[in] skip_im2col           Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout.
-     * @param[in] run_addition          Flag which specifies if @ref CLGEMMMatrixMatrixMultiplyAddition to be run.
+     * @param[in] act_info              Activation to apply after the matrix multiplication
      *
      * @return a status
      */
     static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
-                              int gemm_3d_depth, bool skip_im2col, bool run_addition);
+                              int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info);
 
 private:
-    CLMemoryGroup                        _memory_group;
-    CLConvolutionLayerReshapeWeights     _reshape_weights;
-    CLIm2ColKernel                       _im2col_kernel;
-    CLGEMM                               _mm_gemm;
-    CLGEMMLowpMatrixMultiplyCore         _mm_gemmlowp;
-    CLCol2ImKernel                       _col2im_kernel;
-    CLActivationLayer                    _activationlayer_function;
-    CLSaturatedArithmeticOperationKernel _add_bias_kernel;
+    CLMemoryGroup                    _memory_group;
+    CLConvolutionLayerReshapeWeights _reshape_weights;
+    CLIm2ColKernel                   _im2col_kernel;
+    CLGEMM                           _mm_gemm;
+    CLGEMMLowpMatrixMultiplyCore     _mm_gemmlowp;
+    CLCol2ImKernel                   _col2im_kernel;
+    CLActivationLayer                _activationlayer_function;
 
     const ICLTensor *_original_weights;
 
@@ -199,15 +200,11 @@
     CLTensor _weights_reshaped;
     CLTensor _gemm_output;
 
-    DataLayout _data_layout;
-
-    bool _append_bias;
     bool _skip_im2col;
     bool _skip_col2im;
     bool _is_quantized;
-    bool _is_activationlayer_enabled;
+    bool _fuse_activation;
     bool _is_prepared;
-    bool _run_addition;
 };
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLGEMMCONVOLUTIONLAYER_H__ */
diff --git a/examples/cl_cache.cpp b/examples/cl_cache.cpp
index 998c468..7d8a515 100644
--- a/examples/cl_cache.cpp
+++ b/examples/cl_cache.cpp
@@ -28,8 +28,6 @@
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "utils/Utils.h"
 
-#include <chrono>
-
 using namespace arm_compute;
 using namespace utils;
 
@@ -46,7 +44,7 @@
     {
         std::cout << "Once the program has run and created the file cache.bin, rerun with --restore_cache." << std::endl;
         CLScheduler::get().default_init();
-        auto start_time = std::chrono::high_resolution_clock::now();
+
         if(argc > 1)
         {
             std::string argv1 = argv[1];
@@ -88,10 +86,6 @@
         permute_nchw.configure(&tensor_nhwc, &tensor_nchw_result, vector_nhwc_to_nchw);
         tensor_nchw_result.allocator()->allocate();
 
-        auto end_time        = std::chrono::high_resolution_clock::now();
-        auto time_elapsed    = end_time - start_time;
-        auto time_elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(time_elapsed).count();
-        std::cout << "Configuration time " << time_elapsed_ms << " ms " << std::endl;
         // Save the opencl kernels to a file
         save_program_cache_to_file();
 
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index c0ccd0f..e78395f 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -48,7 +48,6 @@
 CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(std::move(memory_manager)),
       _mm_kernel(),
-      _ma_kernel(),
       _reshape_lhs_kernel(),
       _reshape_rhs_kernel(),
       _mm_reshaped_kernel(),
@@ -56,7 +55,6 @@
       _tmp_a(),
       _tmp_b(),
       _original_b(nullptr),
-      _run_addition(false),
       _reshape_b_only_on_first_run(false),
       _is_prepared(false),
       _gemm_type(GEMMType::NATIVE)
@@ -118,10 +116,10 @@
     // Set the target for the kernels
     _mm_kernel.set_target(gpu_target);
 
-    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d());
+    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
 
     // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision());
+    _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
 
     // Tune kernel statically
     CLScheduler::get().tune_kernel_static(_mm_kernel);
@@ -162,7 +160,7 @@
     lhs_info.interleave = true;
     lhs_info.transpose  = true;
 
-    GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false);
+    GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
 
     _memory_group.manage(&_tmp_a);
     if(!_reshape_b_only_on_first_run)
@@ -177,7 +175,7 @@
     _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
 
     // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision());
+    _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
 
     CLScheduler::get().tune_kernel_static(_mm_kernel);
 
@@ -200,13 +198,15 @@
     const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
     const GPUTarget    gpu_target              = CLScheduler::get().target();
     bool               broadcast_bias          = gemm_info.broadcast_bias();
-    GEMMKernelInfo     kernel_info;
+
+    GEMMKernelInfo kernel_info;
     kernel_info.m                       = m;
     kernel_info.n                       = n;
     kernel_info.k                       = k;
     kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
     kernel_info.reinterpret_input_as_3d = false;
     kernel_info.broadcast_bias          = broadcast_bias;
+    kernel_info.activation_info         = gemm_info.activation_info();
 
     // Set the target for the kernels
     _reshape_lhs_kernel.set_target(gpu_target);
@@ -255,13 +255,15 @@
     const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
     const GPUTarget    gpu_target              = CLScheduler::get().target();
     bool               broadcast_bias          = gemm_info.broadcast_bias();
-    GEMMKernelInfo     kernel_info;
+
+    GEMMKernelInfo kernel_info;
     kernel_info.m                       = m;
     kernel_info.n                       = n;
     kernel_info.k                       = k;
     kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
     kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
     kernel_info.broadcast_bias          = broadcast_bias;
+    kernel_info.activation_info         = gemm_info.activation_info();
 
     // Set the target for the kernels
     _mm_kernel.set_target(gpu_target);
@@ -305,21 +307,12 @@
     const unsigned int n                       = b->dimension(0);
     const unsigned int k                       = a->dimension(0);
     const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
-    const bool         add_c                   = (beta != 0.f && c != nullptr);
-    const bool         is_beta_one             = std::abs(1.0f - beta) < 0.00001f;
-    const bool         fuse_add                = is_beta_one && (c != nullptr && c->num_dimensions() == 1);
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
 
     // Validate matrix multiply
-    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, (add_c && fuse_add) ? c : nullptr, output, alpha, beta,
-                                                                     false, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
-
-    if(add_c && !fuse_add)
-    {
-        // Validate matrix addition kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
-    }
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta,
+                                                                     false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
 
     return Status{};
 }
@@ -340,9 +333,6 @@
     int                mult_transpose1xW_width   = 1;
     int                mult_interleave4x4_height = 1;
     const int          depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
-    const bool         add_c                     = (beta != 0.f && c != nullptr);
-    const bool         is_beta_one               = std::abs(1.0f - beta) < 0.00001f;
-    const bool         fuse_add                  = is_beta_one && (c != nullptr && c->num_dimensions() == 1);
 
     if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
     {
@@ -364,7 +354,7 @@
     lhs_info.interleave = true;
     lhs_info.transpose  = true;
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false);
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
 
     // Validate interleave kernel
     auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
@@ -375,14 +365,8 @@
     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
 
     // Validate matrix multiply
-    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, (add_c && fuse_add) ? c : nullptr, output, alpha, beta,
-                                                                     true, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
-
-    if(add_c && !fuse_add)
-    {
-        // Validate matrix addition kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
-    }
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
+                                                                     true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
 
     return Status{};
 }
@@ -405,13 +389,15 @@
     const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
     const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
     const bool         broadcast_bias          = gemm_info.broadcast_bias();
-    GEMMKernelInfo     kernel_info;
+
+    GEMMKernelInfo kernel_info;
     kernel_info.m                       = m;
     kernel_info.n                       = n;
     kernel_info.k                       = k;
     kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
     kernel_info.reinterpret_input_as_3d = false;
     kernel_info.broadcast_bias          = broadcast_bias;
+    kernel_info.activation_info         = gemm_info.activation_info();
 
     GEMMLHSMatrixInfo lhs_info;
     GEMMRHSMatrixInfo rhs_info;
@@ -452,13 +438,15 @@
     const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
     const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
     const bool         broadcast_bias          = gemm_info.broadcast_bias();
-    GEMMKernelInfo     kernel_info;
+
+    GEMMKernelInfo kernel_info;
     kernel_info.m                       = m;
     kernel_info.n                       = n;
     kernel_info.k                       = k;
     kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
     kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
     kernel_info.broadcast_bias          = broadcast_bias;
+    kernel_info.activation_info         = gemm_info.activation_info();
 
     GEMMLHSMatrixInfo lhs_info;
     GEMMRHSMatrixInfo rhs_info;
@@ -501,9 +489,7 @@
     // Select GEMMType
     _gemm_type = select_gemm_type(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
 
-    const bool is_fuse_add_c_supported = (_gemm_type == GEMMType::RESHAPED_V2) || (_gemm_type == GEMMType::RESHAPED_ONLY_RHS);
-    const bool add_c                   = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
-    const bool fuse_add_c              = add_c && is_fuse_add_c_supported;
+    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
 
     const ICLTensor *c_to_use = fuse_add_c ? c : nullptr;
 
@@ -534,13 +520,6 @@
             ARM_COMPUTE_ERROR("GEMMType not supported");
         }
     }
-
-    // Configure matrix addition kernel
-    if(add_c && !fuse_add_c)
-    {
-        _ma_kernel.configure(c, output, beta);
-        _run_addition = true;
-    }
 }
 
 Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
@@ -555,9 +534,7 @@
     // Select GEMMType
     GEMMType gemm_type = select_gemm_type(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run(), gpu_target);
 
-    const bool is_fuse_add_c_supported = (gemm_type == GEMMType::RESHAPED_V2) || (gemm_type == GEMMType::RESHAPED_ONLY_RHS);
-    const bool add_c                   = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
-    const bool fuse_add_c              = add_c && is_fuse_add_c_supported;
+    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
 
     const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
 
@@ -589,12 +566,6 @@
         }
     }
 
-    // Validate matrix addition kernel
-    if(add_c && !fuse_add_c)
-    {
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
-    }
-
     return Status{};
 }
 
@@ -609,7 +580,7 @@
     {
         case GEMMType::NATIVE:
         {
-            CLScheduler::get().enqueue(_mm_kernel, !_run_addition);
+            CLScheduler::get().enqueue(_mm_kernel, true);
             break;
         }
         case GEMMType::RESHAPED_V1:
@@ -623,7 +594,7 @@
                 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
             }
 
-            CLScheduler::get().enqueue(_mm_kernel, !_run_addition);
+            CLScheduler::get().enqueue(_mm_kernel, true);
             break;
         }
         case GEMMType::RESHAPED_V2:
@@ -637,7 +608,7 @@
                 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
             }
 
-            CLScheduler::get().enqueue(_mm_reshaped_kernel, !_run_addition);
+            CLScheduler::get().enqueue(_mm_reshaped_kernel, true);
             break;
         }
         case GEMMType::RESHAPED_ONLY_RHS:
@@ -648,7 +619,7 @@
                 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
             }
 
-            CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, !_run_addition);
+            CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
             break;
         }
         default:
@@ -656,12 +627,6 @@
             ARM_COMPUTE_ERROR("GEMMType not supported");
         }
     }
-
-    // Run matrix addition kernel
-    if(_run_addition)
-    {
-        CLScheduler::get().enqueue(_ma_kernel);
-    }
 }
 
 void CLGEMM::prepare()
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 99f045a..be6be04 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -91,22 +91,27 @@
 }
 
 CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(),
-      _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false), _skip_col2im(false), _is_quantized(false),
-      _is_activationlayer_enabled(false), _is_prepared(false), _run_addition(true)
+    : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(),
+      _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
 {
 }
 
 void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
-                                          int gemm_3d_depth)
+                                          int gemm_3d_depth, const ActivationLayerInfo &act_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), gemmlowp_output_stage, gemm_3d_depth, _skip_im2col,
-                                           _run_addition));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), gemmlowp_output_stage, gemm_3d_depth, _skip_im2col, act_info));
 
-    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
-                                         gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
-                                         false, gemmlowp_output_stage);
+    const GEMMInfo &gemm_info = GEMMInfo(false,                 // is_a_reshaped
+                                         false,                 // is_b_reshaped
+                                         true,                  // reshape_b_only_on_first_run
+                                         gemm_3d_depth,         // depth_output_gemm3d
+                                         _skip_im2col,          // reinterpret_input_as_3d
+                                         false,                 // retain_internal_weights
+                                         gemmlowp_output_stage, // gemmlowp_output_stage
+                                         false,                 // fp_mixed_precision
+                                         true,                  // broadcast_bias
+                                         act_info);             // activation_info
 
     if(_is_quantized)
     {
@@ -126,21 +131,26 @@
     }
     else
     {
-        // Bias does not need to be added in GEMM if im2col is being used or the Matrix Addition kernel needs to be run
-        const bool skip_bias_in_gemm = _run_addition || !_skip_im2col;
         // Configure matrix multiply function
-        _mm_gemm.configure(input, weights, (skip_bias_in_gemm) ? nullptr : biases, output, 1.0f, 1.0f, gemm_info);
+        _mm_gemm.configure(input, weights, biases, output, 1.0f, 1.0f, gemm_info);
     }
 }
 
 Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
-                                           const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col, bool run_addition)
+                                           const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info)
 {
     const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
 
-    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
-                                         gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */,
-                                         false, gemmlowp_output_stage);
+    const GEMMInfo &gemm_info = GEMMInfo(false,                 // is_a_reshaped
+                                         false,                 // is_b_reshaped
+                                         true,                  // reshape_b_only_on_first_run
+                                         gemm_3d_depth,         // depth_output_gemm3d
+                                         skip_im2col,           // reinterpret_input_as_3d
+                                         false,                 // retain_internal_weights
+                                         gemmlowp_output_stage, // gemmlowp_output_stage
+                                         false,                 // fp_mixed_precision
+                                         true,                  // broadcast_bias
+                                         act_info);             // activation_info
 
     if(is_quantized)
     {
@@ -159,10 +169,8 @@
     }
     else
     {
-        // Bias does not need to be added in GEMM if im2col is being used or the Matrix Addition kernel needs to be run
-        const bool skip_bias_in_gemm = run_addition || !skip_im2col;
         // Perform validation step on Matrix multiply function
-        return CLGEMM::validate(input, weights, (skip_bias_in_gemm) ? nullptr : biases, output, 1.0f, 1.0f, gemm_info);
+        return CLGEMM::validate(input, weights, biases, output, 1.0f, 1.0f, gemm_info);
     }
 }
 
@@ -194,15 +202,14 @@
     const UniformQuantizationInfo wq_info = weights->info()->quantization_info().uniform();
     const UniformQuantizationInfo oq_info = output->info()->quantization_info().uniform();
 
-    _is_prepared                = weights_info.retain_internal_weights();
-    _original_weights           = weights;
-    _is_quantized               = is_data_type_quantized_asymmetric(input->info()->data_type());
-    _data_layout                = data_layout;
-    _skip_im2col                = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
-    _skip_col2im                = data_layout == DataLayout::NHWC;
-    _append_bias                = (biases != nullptr) && (!_is_quantized);
-    _is_activationlayer_enabled = act_info.enabled();
-    _run_addition               = (_skip_im2col) && (_append_bias);
+    _is_prepared      = weights_info.retain_internal_weights();
+    _original_weights = weights;
+    _is_quantized     = is_data_type_quantized_asymmetric(input->info()->data_type());
+    _skip_im2col      = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+    _skip_col2im      = data_layout == DataLayout::NHWC;
+
+    // Only for quantize there are few cases where we cannot fuse the activation function in GEMM
+    _fuse_activation = true;
 
     // Set the GPU target for im2col and col2im
     _im2col_kernel.set_target(CLScheduler::get().target());
@@ -211,8 +218,6 @@
     const ICLTensor *gemm_input_to_use  = input;
     ICLTensor       *gemm_output_to_use = output;
 
-    const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
-
     // Get parameters from conv_info
     unsigned int stride_x = 0;
     unsigned int stride_y = 0;
@@ -230,9 +235,22 @@
 
     unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels) / num_groups;
 
-    // _weights_reshaped will be auto configured in the kernel.
-    // Just append biases and do not transpose 1xW as it will be reshaped in CLGEMM
-    _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped, num_groups);
+    const ICLTensor *biases_to_use = biases;
+    bool             append_bias   = false;
+
+    if(num_groups != 1 && biases != nullptr)
+    {
+        // num_groups != 1 can only be for NCHW
+        // Since it is missing an utility function to reshape the biases, we append the biases into the weights tensor
+        biases_to_use = nullptr;
+        append_bias   = true;
+
+        _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+    }
+    else
+    {
+        _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+    }
 
     // Create tensor to store im2col reshaped inputs
     if(!_skip_im2col)
@@ -240,7 +258,7 @@
         _memory_group.manage(&_im2col_output);
 
         // Configure and tune im2col. im2col output shape is auto-initialized
-        _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation, num_groups);
+        _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, num_groups);
 
         // Set quantization info
         _im2col_output.info()->set_quantization_info(input->info()->quantization_info());
@@ -249,11 +267,6 @@
         // Update GEMM input
         gemm_input_to_use = &_im2col_output;
     }
-    else if(_append_bias)
-    {
-        // Configure add bias kernel
-        _add_bias_kernel.configure(ArithmeticOperation::ADD, output, biases, output, ConvertPolicy::SATURATE);
-    }
 
     // Create GEMM output tensor
     if(!_skip_col2im)
@@ -299,16 +312,20 @@
                                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
                                                                                  };
 
-        if(_is_activationlayer_enabled && supported_acts.count(act_info.activation()) != 0)
+        if(act_info.enabled())
         {
-            const int a_const_int = quantize_qasymm8(act_info.a(), output_quant_info);
-            const int b_const_int = quantize_qasymm8(act_info.b(), output_quant_info);
+            if(supported_acts.count(act_info.activation()) != 0)
+            {
+                const int a_const_int = quantize_qasymm8(act_info.a(), output_quant_info);
+                const int b_const_int = quantize_qasymm8(act_info.b(), output_quant_info);
 
-            min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? output_quant_info.offset : b_const_int;
-            max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int;
-
-            // If the activation layer is RELU, BOUNDED_RELU or LU_BOUNDED_RELU, we can use the GEMMLowp output stage to perform this operation
-            _is_activationlayer_enabled = false;
+                min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? output_quant_info.offset : b_const_int;
+                max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int;
+            }
+            else
+            {
+                _fuse_activation = false;
+            }
         }
 
         // Set the GEMMLowp output stage info
@@ -323,7 +340,7 @@
     // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
     const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
 
-    configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth);
+    configure_mm(gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
 
     if(!_skip_im2col)
     {
@@ -345,7 +362,7 @@
     ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h),
                              "Output shape does not match the expected one");
 
-    if(_is_activationlayer_enabled)
+    if(!_fuse_activation)
     {
         _activationlayer_function.configure(output, nullptr, act_info);
     }
@@ -382,12 +399,10 @@
     const ITensorInfo *gemm_output_to_use = output;
     const ITensorInfo *weights_to_use     = weights;
 
-    const bool is_quantized               = is_data_type_quantized_asymmetric(data_type);
-    const bool append_bias                = (biases != nullptr) && (!is_quantized);
-    const bool skip_im2col                = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
-    const bool skip_col2im                = data_layout == DataLayout::NHWC;
-    bool       is_activationlayer_enabled = act_info.enabled();
-    const bool run_addition               = (skip_im2col) && (append_bias);
+    const bool is_quantized    = is_data_type_quantized_asymmetric(data_type);
+    const bool skip_im2col     = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+    const bool skip_col2im     = data_layout == DataLayout::NHWC;
+    bool       fuse_activation = true;
 
     const UniformQuantizationInfo iq_info = input->quantization_info().uniform();
     const UniformQuantizationInfo wq_info = weights->quantization_info().uniform();
@@ -429,10 +444,26 @@
 
     unsigned int mat_weights_cols = weights->dimension(idx_kernels) / num_groups;
 
-    // Output tensor auto inizialitation if not yet initialized
-    ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, is_quantized ? nullptr : biases, nullptr, num_groups));
-    weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, (append_bias && !skip_im2col), num_groups), 1, data_type);
-    weights_to_use        = &weights_reshaped_info;
+    const ITensorInfo *biases_to_use = biases;
+    bool               append_bias   = false;
+
+    if(num_groups != 1 && biases != nullptr)
+    {
+        // num_groups != 1 can only be for NCHW
+        // Since it is missing an utility function to reshape the biases, we append the biases into the weights tensor
+        biases_to_use = nullptr;
+        append_bias   = true;
+
+        ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, biases, nullptr, num_groups));
+        weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, true, num_groups), 1, data_type);
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ON_ERROR(CLConvolutionLayerReshapeWeights::validate(weights, nullptr, nullptr, num_groups));
+        weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, false, num_groups), 1, data_type);
+    }
+
+    weights_to_use = &weights_reshaped_info;
 
     if(!skip_im2col)
     {
@@ -446,11 +477,6 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, kernel_dims, conv_info, append_bias, dilation, num_groups));
         gemm_input_to_use = &im2col_reshaped_info;
     }
-    else if(run_addition)
-    {
-        // Validate add bias kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLSaturatedArithmeticOperationKernel::validate(ArithmeticOperation::ADD, output, biases, output, ConvertPolicy::SATURATE));
-    }
 
     // Create GEMM output tensor
     if(!skip_col2im)
@@ -490,16 +516,20 @@
                                                                                    ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
                                                                                  };
 
-        if(is_activationlayer_enabled && supported_acts.count(act_info.activation()) != 0)
+        if(act_info.enabled())
         {
-            const int a_const_int = quantize_qasymm8(act_info.a(), output_quant_info);
-            const int b_const_int = quantize_qasymm8(act_info.b(), output_quant_info);
+            if(supported_acts.count(act_info.activation()) != 0)
+            {
+                const int a_const_int = quantize_qasymm8(act_info.a(), output_quant_info);
+                const int b_const_int = quantize_qasymm8(act_info.b(), output_quant_info);
 
-            min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? output_quant_info.offset : b_const_int;
-            max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int;
-
-            // If the activation layer is RELU, BOUNDED_RELU or LU_BOUNDED_RELU, we can use the GEMMLowp output stage to perform this operation
-            is_activationlayer_enabled = false;
+                min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? output_quant_info.offset : b_const_int;
+                max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int;
+            }
+            else
+            {
+                fuse_activation = false;
+            }
         }
 
         // Set the GEMMLowp output stage info
@@ -513,7 +543,7 @@
     // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
     const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
 
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, run_addition));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, act_info));
 
     // Validate Col2Im
     if(!skip_col2im)
@@ -522,7 +552,7 @@
     }
 
     //Validate Activation Layer
-    if(is_activationlayer_enabled)
+    if(!fuse_activation)
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CLActivationLayer::validate(output, nullptr, act_info));
     }
@@ -554,19 +584,14 @@
         _mm_gemm.run();
     }
 
-    if(_run_addition)
-    {
-        CLScheduler::get().enqueue(_add_bias_kernel);
-    }
-
     // Reshape output matrix
     if(!_skip_col2im)
     {
         CLScheduler::get().enqueue(_col2im_kernel, false);
     }
 
-    //Run Activation Layer if enabled
-    if(_is_activationlayer_enabled)
+    //Run Activation Layer if we cannot fuse in GEMM
+    if(!_fuse_activation)
     {
         _activationlayer_function.run();
     }
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index 0876ae1..0ca0b04 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -55,13 +55,13 @@
 public:
     LargeGEMMOutput3DDataset()
     {
-        add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f);
+        add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f);
     }
 };
 
@@ -70,13 +70,13 @@
 public:
     LargeGEMMInputOutput3DDataset()
     {
-        add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(681U, 205U, 5U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 0.2f, 1.2f);
-        add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 0.4f, 0.7f);
-        add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U, 2U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U, 3U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
+        add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(681U, 205U, 5U), TensorShape(213U, 681U), TensorShape(213U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U), TensorShape(96U, 605U, 5U), 0.2f, 1.2f);
+        add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U), TensorShape(384U, 13U, 13U), 0.4f, 0.7f);
+        add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U), TensorShape(192U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
     }
 };
 } // namespace datasets
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index ae3c3ed..45d1a1e 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -55,12 +55,12 @@
 public:
     SmallGEMMOutput3DDataset()
     {
-        add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U, 7U, 2U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U, 1U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f);
-        add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U, 4U, 3U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f);
-        add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U, 1U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f);
-        add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U, 8U, 2U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U, 8U, 2U, 5U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f);
+        add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f);
+        add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f);
+        add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f);
     }
 };
 
@@ -69,12 +69,12 @@
 public:
     SmallGEMMInputOutput3DDataset()
     {
-        add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U, 14U, 13U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f);
-        add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U, 1U, 3U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U, 12U, 2U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f);
-        add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U, 1U, 4U, 3U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f);
-        add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U, 16U, 3U, 2U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U, 16U, 5U, 3U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.3f);
+        add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f);
+        add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f);
+        add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f);
+        add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.3f);
     }
 };
 } // namespace datasets
diff --git a/tests/validation/CL/GEMMMatrixMultiply.cpp b/tests/validation/CL/GEMMMatrixMultiply.cpp
index 21fd712..8f7c0aa 100644
--- a/tests/validation/CL/GEMMMatrixMultiply.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiply.cpp
@@ -67,7 +67,7 @@
 constexpr float         tolerance_num_f16 = 0.02f;
 
 /** Alpha values to test - Precommit */
-const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} );
+const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
 
 /** Beta values to test - Precommit */
 const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
diff --git a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
index cae94b2..5d21cf4 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
@@ -77,7 +77,7 @@
 constexpr float         tolerance_num_f16 = 0.02f;
 
 /** Alpha values to test - Precommit */
-const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} );
+const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
 
 /** Beta values to test - Precommit */
 const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index b36bb99..a04a901 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -44,7 +44,7 @@
 {
 namespace validation
 {
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_ouput_as_3d = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false>
 class GEMMValidationFixture : public framework::Fixture
 {
 public:
@@ -87,7 +87,13 @@
         // The GEMMinfo includes the values of the depth in case of reinterpreted 3d output.
         // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 0),
         // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output).
-        gemm.configure(&a, &b, (disable_c) ? nullptr : &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d));
+        gemm.configure(&a,
+                       &b,
+                       (disable_c) ? nullptr : &c,
+                       &dst,
+                       alpha, beta,
+                       GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, (reinterpret_input_as_3d
+                                || reinterpret_output_as_3d)));
         ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -122,6 +128,7 @@
                                       DataType data_type)
     {
         TensorShape shape_a_to_use = shape_a;
+
         if(reinterpret_input_as_3d)
         {
             // Collapse the second and third dimension if the input is 3D
@@ -131,22 +138,29 @@
         // Create reference
         SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
         SimpleTensor<T> b{ shape_b, data_type, 1 };
-        SimpleTensor<T> c{ shape_c, data_type, 1 };
+        SimpleTensor<T> c{ output_shape, data_type, 1 };
 
         // Fill reference
         fill(a, 0);
         fill(b, 1);
-        if(!disable_c)
+        fill(c, 2);
+
+        if(reinterpret_input_as_3d || reinterpret_output_as_3d)
         {
-            fill(c, 2);
-            return reference::gemm<T>(a, b, c, alpha, beta);
+            const int n          = shape_b[0];
+            const int m          = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1];
+            const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2];
+
+            // In case of broadcast, we need simply copy the first into the following "M" ones
+            for(int i = 1; i < m * batch_size; i++)
+            {
+                memcpy(c.data() + i * n, c.data(), n * sizeof(T));
+            }
         }
-        else
-        {
-            // Setting beta to 0 will effectively disable C for the
-            // computation of the reference: alpha * A * B + 0 * C
-            return reference::gemm<T>(a, b, c, alpha, 0.f);
-        }
+
+        // Setting beta to 0 will effectively disable C for the
+        // computation of the reference: alpha * A * B + 0 * C
+        return reference::gemm<T>(a, b, c, alpha, disable_c ? 0.f : beta);
     }
 
     TensorType      _target{};