COMPMID-1979: Fuse Activation Function in CLGEMM - part 4

Fused activation function in CLGEMM

Change-Id: I644fdf09349325c0b3a2cd5fef2a3ea2c974149d
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1640
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/datasets/LargeGEMMDataset.h b/tests/datasets/LargeGEMMDataset.h
index 0876ae1..0ca0b04 100644
--- a/tests/datasets/LargeGEMMDataset.h
+++ b/tests/datasets/LargeGEMMDataset.h
@@ -55,13 +55,13 @@
 public:
     LargeGEMMOutput3DDataset()
     {
-        add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f);
+        add_config(TensorShape(923U, 429U), TensorShape(871U, 923U), TensorShape(871U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(681U, 1025U), TensorShape(213U, 681U), TensorShape(213U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(364U, 3025U), TensorShape(96U, 364U), TensorShape(96U), TensorShape(96U, 605U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(1201U, 729U), TensorShape(128U, 1201U), TensorShape(128U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(2305U, 169U), TensorShape(384U, 2305U), TensorShape(384U), TensorShape(384U, 13U, 13U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 170U), TensorShape(192U, 1729U), TensorShape(192U), TensorShape(192U, 85U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 170U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U), 1.0f, 0.0f);
     }
 };
 
@@ -70,13 +70,13 @@
 public:
     LargeGEMMInputOutput3DDataset()
     {
-        add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U, 143U, 3U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(681U, 205U, 5U), TensorShape(213U, 681U), TensorShape(213U, 205U, 5U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
-        add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U, 605U, 5U), TensorShape(96U, 605U, 5U), 0.2f, 1.2f);
-        add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U, 243U, 3U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U, 13U, 13U), TensorShape(384U, 13U, 13U), 0.4f, 0.7f);
-        add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U), TensorShape(192U, 85U, 2U, 2U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U, 17U, 10U, 3U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
+        add_config(TensorShape(923U, 143U, 3U), TensorShape(871U, 923U), TensorShape(871U), TensorShape(871U, 143U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(681U, 205U, 5U), TensorShape(213U, 681U), TensorShape(213U), TensorShape(213U, 205U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(364U, 605U, 5U), TensorShape(96U, 364U), TensorShape(96U), TensorShape(96U, 605U, 5U), 0.2f, 1.2f);
+        add_config(TensorShape(1201U, 243U, 3U), TensorShape(128U, 1201U), TensorShape(128U), TensorShape(128U, 243U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(2305U, 13U, 13U), TensorShape(384U, 2305U), TensorShape(384U), TensorShape(384U, 13U, 13U), 0.4f, 0.7f);
+        add_config(TensorShape(1729U, 85U, 2U, 2U), TensorShape(192U, 1729U), TensorShape(192U), TensorShape(192U, 85U, 2U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(1729U, 17U, 10U, 3U), TensorShape(128U, 1729U), TensorShape(128U), TensorShape(128U, 17U, 10U, 3U), 1.0f, 0.3f);
     }
 };
 } // namespace datasets
diff --git a/tests/datasets/SmallGEMMDataset.h b/tests/datasets/SmallGEMMDataset.h
index ae3c3ed..45d1a1e 100644
--- a/tests/datasets/SmallGEMMDataset.h
+++ b/tests/datasets/SmallGEMMDataset.h
@@ -55,12 +55,12 @@
 public:
     SmallGEMMOutput3DDataset()
     {
-        add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U, 7U, 2U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U, 1U, 1U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f);
-        add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U, 4U, 3U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f);
-        add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U, 1U, 1U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f);
-        add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U, 8U, 2U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U, 8U, 2U, 5U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f);
+        add_config(TensorShape(21U, 14U), TensorShape(34U, 21U), TensorShape(34U), TensorShape(34U, 7U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(31U, 1U), TensorShape(23U, 31U), TensorShape(23U), TensorShape(23U, 1U, 1U), 1.0f, 0.0f);
+        add_config(TensorShape(38U, 12U), TensorShape(21U, 38U), TensorShape(21U), TensorShape(21U, 4U, 3U), 0.2f, 1.2f);
+        add_config(TensorShape(32U, 1U), TensorShape(17U, 32U), TensorShape(17U), TensorShape(17U, 1U, 1U), 0.4f, 0.7f);
+        add_config(TensorShape(16U, 16U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 8U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(16U, 16U, 5U), TensorShape(8U, 16U, 5U), TensorShape(8U), TensorShape(8U, 8U, 2U, 5U), 1.0f, 0.0f);
     }
 };
 
@@ -69,12 +69,12 @@
 public:
     SmallGEMMInputOutput3DDataset()
     {
-        add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U, 14U, 13U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f);
-        add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U, 1U, 3U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f);
-        add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U, 12U, 2U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f);
-        add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U, 1U, 4U, 3U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f);
-        add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U, 16U, 3U, 2U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f);
-        add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U, 16U, 5U, 3U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.3f);
+        add_config(TensorShape(21U, 14U, 13U), TensorShape(34U, 21U), TensorShape(34U), TensorShape(34U, 14U, 13U), 1.0f, 0.0f);
+        add_config(TensorShape(31U, 1U, 3U), TensorShape(23U, 31U), TensorShape(23U), TensorShape(23U, 1U, 3U), 1.0f, 0.0f);
+        add_config(TensorShape(38U, 12U, 2U), TensorShape(21U, 38U), TensorShape(21U), TensorShape(21U, 12U, 2U), 0.2f, 1.2f);
+        add_config(TensorShape(32U, 1U, 4U, 3U), TensorShape(17U, 32U), TensorShape(17U), TensorShape(17U, 1U, 4U, 3U), 0.4f, 0.7f);
+        add_config(TensorShape(16U, 16U, 3U, 2U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 16U, 3U, 2U), 1.0f, 0.0f);
+        add_config(TensorShape(16U, 16U, 5U, 3U), TensorShape(8U, 16U), TensorShape(8U), TensorShape(8U, 16U, 5U, 3U), 1.0f, 0.3f);
     }
 };
 } // namespace datasets
diff --git a/tests/validation/CL/GEMMMatrixMultiply.cpp b/tests/validation/CL/GEMMMatrixMultiply.cpp
index 21fd712..8f7c0aa 100644
--- a/tests/validation/CL/GEMMMatrixMultiply.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiply.cpp
@@ -67,7 +67,7 @@
 constexpr float         tolerance_num_f16 = 0.02f;
 
 /** Alpha values to test - Precommit */
-const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} );
+const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
 
 /** Beta values to test - Precommit */
 const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
diff --git a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
index cae94b2..5d21cf4 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyInterleavedTransposed.cpp
@@ -77,7 +77,7 @@
 constexpr float         tolerance_num_f16 = 0.02f;
 
 /** Alpha values to test - Precommit */
-const auto alpha_values = framework::dataset::make("alpha", {0.0f, 1.0f, -0.75f} );
+const auto alpha_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
 
 /** Beta values to test - Precommit */
 const auto beta_values = framework::dataset::make("beta", {-0.75f, 0.0f} );
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index b36bb99..a04a901 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -44,7 +44,7 @@
 {
 namespace validation
 {
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_ouput_as_3d = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_output_as_3d = false>
 class GEMMValidationFixture : public framework::Fixture
 {
 public:
@@ -87,7 +87,13 @@
         // The GEMMinfo includes the values of the depth in case of reinterpreted 3d output.
         // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 0),
         // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output).
-        gemm.configure(&a, &b, (disable_c) ? nullptr : &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d));
+        gemm.configure(&a,
+                       &b,
+                       (disable_c) ? nullptr : &c,
+                       &dst,
+                       alpha, beta,
+                       GEMMInfo(false, false, false, (reinterpret_output_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d, false, GEMMLowpOutputStageInfo(), false, (reinterpret_input_as_3d
+                                || reinterpret_output_as_3d)));
         ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -122,6 +128,7 @@
                                       DataType data_type)
     {
         TensorShape shape_a_to_use = shape_a;
+
         if(reinterpret_input_as_3d)
         {
             // Collapse the second and third dimension if the input is 3D
@@ -131,22 +138,29 @@
         // Create reference
         SimpleTensor<T> a{ shape_a_to_use, data_type, 1 };
         SimpleTensor<T> b{ shape_b, data_type, 1 };
-        SimpleTensor<T> c{ shape_c, data_type, 1 };
+        SimpleTensor<T> c{ output_shape, data_type, 1 };
 
         // Fill reference
         fill(a, 0);
         fill(b, 1);
-        if(!disable_c)
+        fill(c, 2);
+
+        if(reinterpret_input_as_3d || reinterpret_output_as_3d)
         {
-            fill(c, 2);
-            return reference::gemm<T>(a, b, c, alpha, beta);
+            const int n          = shape_b[0];
+            const int m          = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1];
+            const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2];
+
+            // In case of broadcast, we simply need to copy the first row (n elements) into each of the remaining "M * batch_size" rows
+            for(int i = 1; i < m * batch_size; i++)
+            {
+                memcpy(c.data() + i * n, c.data(), n * sizeof(T));
+            }
         }
-        else
-        {
-            // Setting beta to 0 will effectively disable C for the
-            // computation of the reference: alpha * A * B + 0 * C
-            return reference::gemm<T>(a, b, c, alpha, 0.f);
-        }
+
+        // When C is disabled, setting beta to 0 effectively removes C from the
+        // computation of the reference: alpha * A * B + 0 * C
+        return reference::gemm<T>(a, b, c, alpha, disable_c ? 0.f : beta);
     }
 
     TensorType      _target{};