Add Gemm MMUL Reshaped Only Rhs Support for FP32/FP16

This patch introduces a GEMM routine that is optimized for Arm(R) Mali(TM)-G715 and Arm(R) Mali(TM)-G615 GPUs, together with FP32/FP16 validation tests for it.

Resolves: COMPMID-5216
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Change-Id: I2e5d7806f5904347185bb3e250f73d73d6669dba
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7914
Reviewed-by: SiCong Li <sicong.li@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp
new file mode 100644
index 0000000..7808be8
--- /dev/null
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp
@@ -0,0 +1,231 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h"
+#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
+#include "tests/CL/CLAccessor.h"
+#include "tests/CL/Helper.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/GEMMFixture.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+using namespace arm_compute::opencl::kernels;
+
+// Operator wrapper around ClGemmReshapeRhsMatrixKernel; the fixture uses it to reshape the RHS matrix before the GEMM
+using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<ClGemmReshapeRhsMatrixKernel>;
+
+// Operator wrapper around ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel, the kernel under test
+using CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL = CLSynthetizeOperator<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>;
+
+// Fixture for CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL; instantiated below with T = float (FP32) and T = half (FP16)
+template <typename T>
+using CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture = GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL>;
+
+namespace
+{
+// *INDENT-OFF*
+// clang-format off
+// Target-vs-reference tolerances. NOTE(review): FP16 abs tolerance (0.3) is far looser than FP32 (0.0001) - confirm intended
+RelativeTolerance<float> rel_tolerance_f32(0.001f);
+constexpr float          abs_tolerance_f32(0.0001f);
+RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.001f));
+constexpr float          abs_tolerance_f16(0.3f);
+
+/** Alpha values to test - Precommit */
+const auto a_values = framework::dataset::make("alpha", {1.0f, 0.75f} );
+
+/** Beta values to test - Precommit */
+const auto beta_values = framework::dataset::make("beta", {0.0f, -0.75f} );
+
+/** M values to test (49 is not a multiple of any M0 > 1 in m0_values_precommit) */
+const auto m_values = framework::dataset::make("M", {49});
+
+/** N values to test (257 is not a multiple of any N0 in n0_values_precommit) */
+const auto n_values = framework::dataset::make("N", {257});
+
+/** K values to test. The test case requires K to be a multiple of 4 */
+const auto k_values = framework::dataset::make("K", {192});
+
+/** Batch size values to test */
+const auto b_values = framework::dataset::make("batch_size", {1, 2});
+
+/** Activation values to test */
+const auto act_values = framework::dataset::make("Activation",
+{
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+});
+
+/** M0 values to test - Precommit */
+const auto m0_values_precommit = framework::dataset::make("M0", { 1, 2, 4 });
+
+/** N0 values to test - Precommit */
+const auto n0_values_precommit = framework::dataset::make("N0", { 4, 8 });
+
+/** K0 values to test - Precommit */
+const auto k0_values_precommit = framework::dataset::make("K0", { 1 });
+
+/** Broadcast bias: true uses a single bias row (N values) for every output row, false uses a full bias matrix */
+const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
+
+} // namespace
+
+TEST_SUITE(CL)
+TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRhsMMUL)
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<float>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   framework::dataset::make("ExportToCLImage", false)),
+                                                                   framework::dataset::make("DataType", DataType::F32)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   act_values))
+{
+    // validate_result is false if the fixture's validate() calls rejected the configuration; the comparison is then skipped
+    if(validate_result)
+    {
+        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   framework::dataset::make("ExportToCLImage", false)),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   act_values))
+{
+    // validate_result is false if the fixture's validate() calls rejected the configuration; the comparison is then skipped
+    if(validate_result)
+    {
+        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(ExportToCLImage) // Same configurations as above, but with rhs_info.export_to_cl_image = true
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<float>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   framework::dataset::make("ExportToCLImage", true)),
+                                                                   framework::dataset::make("DataType", DataType::F32)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   act_values))
+{
+    // validate_result is false if the fixture's validate() calls rejected the configuration; the comparison is then skipped
+    if(validate_result)
+    {
+        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   framework::dataset::make("ExportToCLImage", true)),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   act_values))
+{
+    // validate_result is false if the fixture's validate() calls rejected the configuration; the comparison is then skipped
+    if(validate_result)
+    {
+        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+TEST_SUITE_END() // FP16
+TEST_SUITE_END() // ExportToCLImage
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // GEMMMatrixMultiplyReshapedOnlyRhsMMUL
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 884b13d..55bbbda 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -163,18 +163,18 @@
             const int m          = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1];
             const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2];
 
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(c.data() + i * n, c.data(), n * sizeof(T));
             }
         }
-        
+
         /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M),
            therefore, A must be pre-transposed before passing it to the fixture. And, we transpose A again in the fixture to make it (B x M x K)
            in order to be able to call reference implementation that works with (B x M x K) input.
            Similarly, if pretranspose_B is set to true, then B is assumed to be (B x N x K), B must be pre-transposed before passing it to the fixture. */
-           
+
         // Define transposed shapes
         TensorShape a_transposed_shape(a.shape().y(), a.shape().x());
         TensorShape b_transposed_shape(b.shape().y(), b.shape().x());
@@ -315,7 +315,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -438,7 +438,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -593,7 +593,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -748,7 +748,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -923,7 +923,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1169,7 +1169,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1361,7 +1361,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1533,7 +1533,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1759,7 +1759,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1941,7 +1941,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -2078,7 +2078,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -2274,7 +2274,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -2421,7 +2421,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -2434,6 +2434,223 @@
     SimpleTensor<T> _reference{};
 };
 
+/** Validation fixture for the reshaped-only-RHS MMUL GEMM operator.
+ *
+ * The RHS matrix is reshaped with ReshapeRHSOperatorType and multiplied with LHS by GEMMOperatorType; the result is
+ * compared against reference::gemm(lhs, rhs, bias, alpha, beta) followed by reference::activation_layer.
+ * If either operator's validate() rejects the configuration, validate_result is set to false and both _target and
+ * _reference are left empty, so the test can report a skip instead of comparing.
+ *
+ * @tparam TensorType             Backend tensor type (e.g. CLTensor)
+ * @tparam AccessorType           Accessor type used to fill backend tensors (e.g. CLAccessor)
+ * @tparam T                      Element type (float or half)
+ * @tparam ReshapeRHSOperatorType Operator that reshapes the RHS matrix
+ * @tparam GEMMOperatorType       GEMM operator under test
+ */
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeRHSOperatorType, typename GEMMOperatorType>
+class GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture : public framework::Fixture
+{
+public:
+    /** Set up the LHS/RHS block configuration and compute both the target and the reference outputs.
+     *
+     * @param[in] m                  Number of LHS rows
+     * @param[in] n                  Number of RHS columns
+     * @param[in] k                  Number of LHS columns (= number of RHS rows)
+     * @param[in] batch_size         Batch size (dimension 2 of LHS and RHS)
+     * @param[in] m0                 Value for lhs_info.m0
+     * @param[in] n0                 Value for rhs_info.n0
+     * @param[in] k0                 Value for both lhs_info.k0 and rhs_info.k0
+     * @param[in] export_to_cl_image Value for rhs_info.export_to_cl_image
+     * @param[in] data_type          Data type of every tensor
+     * @param[in] alpha              Alpha value forwarded to the GEMM operator and to reference::gemm
+     * @param[in] beta               Beta value forwarded to the GEMM operator and to reference::gemm
+     * @param[in] broadcast_bias     If true, the bias is a single row of N values applied to every output row
+     * @param[in] act_info           Activation applied to the GEMM result
+     */
+    template <typename...>
+    void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, bool export_to_cl_image, DataType data_type, float alpha,
+               float beta, bool broadcast_bias,
+               const ActivationLayerInfo &act_info)
+    {
+        GEMMLHSMatrixInfo lhs_info;
+        lhs_info.m0 = m0;
+        lhs_info.k0 = k0;
+
+        GEMMRHSMatrixInfo rhs_info;
+        rhs_info.n0                 = n0;
+        rhs_info.k0                 = k0;
+        rhs_info.interleave         = true;
+        rhs_info.transpose          = false;
+        // NOTE(review): h0 is fixed rather than parameterised; presumably the MMUL kernel expects this value - confirm
+        rhs_info.h0                 = 4;
+        rhs_info.export_to_cl_image = export_to_cl_image;
+
+        // Set the tensor shapes (dimension 0 first): LHS is (K, M, batch) and RHS is (N, K, batch)
+        const TensorShape lhs_shape(k, m, batch_size);
+        const TensorShape rhs_shape(n, k, batch_size);
+        // A broadcast bias is a single row of N values; otherwise it has the full (N, M, batch) output shape
+        const TensorShape bias_shape(n,
+                                     broadcast_bias ? 1 : m,
+                                     broadcast_bias ? 1 : batch_size);
+
+        _target    = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, act_info);
+        _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, broadcast_bias, act_info);
+    }
+
+protected:
+    /** Fill @p tensor with uniform values in [-1, 1] and its borders (if any) with +infinity.
+     *
+     * @p i is forwarded to the library fill calls; target and reference use the same index per operand
+     * (0 = LHS, 1 = RHS, 2 = bias).
+     */
+    template <typename U>
+    void fill(U &&tensor, int i)
+    {
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
+        library->fill(tensor, distribution, i);
+
+        // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
+        DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
+        library->fill_borders_with_garbage(tensor, distribution_inf, i);
+    }
+
+    /** Configure and run the reshape-RHS and GEMM operators on freshly filled tensors.
+     *
+     * Sets validate_result from the operators' validate() calls.
+     *
+     * @return The GEMM output, or a default-constructed tensor if either operator rejects the configuration.
+     */
+    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+                              DataType data_type, float alpha, float beta, bool broadcast_bias, const ActivationLayerInfo &act_info)
+    {
+        // Create tensors
+        TensorType lhs  = create_tensor<TensorType>(lhs_shape, data_type, 1);
+        TensorType rhs  = create_tensor<TensorType>(rhs_shape, data_type, 1);
+        TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1);
+        // No shape is given for rhs_reshaped and dst; presumably the operators' configure() calls initialise their info
+        TensorType rhs_reshaped;
+        TensorType dst;
+
+        const unsigned int M = lhs_shape[1];
+        const unsigned int N = rhs_shape[0];
+        const unsigned int K = lhs_shape[0];
+        GEMMKernelInfo     kernel_info;
+        kernel_info.m                       = M;
+        kernel_info.n                       = N;
+        kernel_info.k                       = K;
+        kernel_info.depth_output_gemm3d     = 0;
+        kernel_info.reinterpret_input_as_3d = false;
+        kernel_info.broadcast_bias          = broadcast_bias;
+        kernel_info.activation_info         = act_info;
+
+        // Create and configure function
+        ReshapeRHSOperatorType reshape_rhs;
+        GEMMOperatorType       gemm;
+
+        // If an operator rejects the configuration, return an empty (default-constructed) tensor; validate_result records the failure
+        validate_result = bool(reshape_rhs.validate(rhs.info(), rhs_reshaped.info(), rhs_info));
+        if(!validate_result)
+        {
+            return TensorType();
+        }
+
+        reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
+
+        validate_result = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info));
+        if(!validate_result)
+        {
+            return TensorType();
+        }
+
+        gemm.configure(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
+
+        ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
+
+        // Allocate tensors
+        lhs.allocator()->allocate();
+        rhs.allocator()->allocate();
+        rhs_reshaped.allocator()->allocate();
+        bias.allocator()->allocate();
+        dst.allocator()->allocate();
+
+        ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+        // Fill tensors
+        fill(AccessorType(lhs), 0);
+        fill(AccessorType(rhs), 1);
+        fill(AccessorType(bias), 2);
+
+        // Compute GEMM
+        ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
+        reshape_rhs.run(reshape_rhs_pack);
+        ITensorPack gemm_pack({ { ACL_SRC_0, &lhs },
+            { ACL_SRC_1, &rhs_reshaped },
+            { ACL_SRC_2, &bias },
+            { ACL_DST, &dst }
+        });
+        gemm.run(gemm_pack);
+
+        return dst;
+    }
+
+    /** Compute the expected output with reference::gemm followed by reference::activation_layer.
+     *
+     * @return The reference output, or an empty tensor if validate_result is false.
+     */
+    SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha, float beta, bool broadcast_bias,
+                                      const ActivationLayerInfo &act_info)
+    {
+        if(!validate_result)
+        {
+            return SimpleTensor<T>();
+        }
+
+        // The output has the LHS shape with dimension 0 replaced by N
+        TensorShape dst_shape = lhs_shape;
+        dst_shape[0]          = rhs_shape[0];
+
+        // Create reference
+        SimpleTensor<T> lhs{ lhs_shape, data_type, 1 };
+        SimpleTensor<T> rhs{ rhs_shape, data_type, 1 };
+        // The reference bias always has the full output shape; a broadcast bias is expanded below
+        SimpleTensor<T> bias{ dst_shape, data_type, 1 };
+
+        const int n          = rhs_shape[0];
+        const int m          = lhs_shape[1];
+        const int batch_size = lhs_shape[2];
+
+        // Fill reference
+        fill(lhs, 0);
+        fill(rhs, 1);
+        fill(bias, 2);
+
+        if(broadcast_bias)
+        {
+            // In case of broadcast, copy the first row of N values into every following row (M * batch_size rows in total)
+            for(int i = 1; i < m * batch_size; i++)
+            {
+                memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
+            }
+        }
+
+        return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+    }
+
+    // Cleared by compute_target() when either operator rejects the configuration; the tests then skip the comparison
+    bool            validate_result = true;
+    TensorType      _target{};
+    SimpleTensor<T> _reference{};
+};
+
 } // namespace validation
 } // namespace test
 } // namespace arm_compute