COMPMID-1979: Fuse Activation Function in CLGEMM - part 2

Fuse activation function in:
CLGEMMMatrixMultiplyNativeKernel
CLGEMMMatrixMultiplyReshapedKernel
CLGEMMMatrixMultiplyReshapedOnlyRHSKernel

Change-Id: I033ace2bdc58903594c9f31175e4b23c4b559f6f
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1565
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 854d009..213075d 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -1022,6 +1022,8 @@
  *  - K0 = 2, 3, 4, 8, 16
  *  - H0 >= 1
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is performed after the bias addition
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -1280,6 +1282,10 @@
 #endif // defined(BROADCAST_BIAS)
 #endif // defined(BETA)
 
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
+
     // Store output block
     STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
 
@@ -1397,6 +1403,8 @@
  *  - K0 = 2, 3, 4, 8, 16
  *  - H0 >= 1
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is performed after the bias addition
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -1656,6 +1664,10 @@
 #endif // defined(BROADCAST_BIAS)
 #endif // defined(BETA)
 
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
+
     // Store output block
     STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
 
@@ -1799,6 +1811,8 @@
  *  - V0 >= 1
  *  - H0 >= 1
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is performed after the bias addition
  * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
@@ -2008,6 +2022,10 @@
 #endif // defined(BROADCAST_BIAS)
 #endif // defined(BETA)
 
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
+
     // Store output block
     STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
 
@@ -2115,6 +2133,8 @@
  *  - N0 = 2, 3, 4, 8, 16
  *  - K0 = 2, 3, 4, 8, 16
  *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is performed after the bias addition
  * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
  *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -2371,6 +2391,10 @@
 #endif // defined(BROADCAST_BIAS)
 #endif // defined(BETA)
 
+#if defined(ACTIVATION_TYPE)
+    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
+#endif // defined(ACTIVATION_TYPE)
+
     // Store output block
     STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
 
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 3fd5950..4715fb7 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -21,6 +21,7 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+#include "activation_float_helpers.h"
 #include "helpers.h"
 
 #define LOAD_ROW_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
@@ -619,3 +620,73 @@
  * Supported cases N=1,2,3..16, for variables BASENAME[0..N]
  */
 #define ADD_BLOCK_BROADCAST(N, BASENAME, BIAS) ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS)
+
+#define ACTIVATION_ROW_1(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    BASENAME##0 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##0, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_2(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_1(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##1 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##1, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_3(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_2(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##2 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##2, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_4(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_3(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##3 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##3, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_5(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_4(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##4 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##4, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_6(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_5(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##5 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##5, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_7(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_6(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##6 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##6, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_8(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_7(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##7 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##7, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_9(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_8(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##8 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##8, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_10(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_9(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)      \
+    BASENAME##9 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##9, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_11(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_10(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##A = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##A, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_12(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_11(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##B = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##B, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_13(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_12(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##C = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##C, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_14(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_13(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##D = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##D, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_15(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_14(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##E = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##E, A_VAL, B_VAL);
+
+#define ACTIVATION_ROW_16(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) \
+    ACTIVATION_ROW_15(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)     \
+    BASENAME##F = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, BASENAME##F, A_VAL, B_VAL);
+
+// ACTIVATION_ROW_n applies the activation to the variables BASENAME##0 ... BASENAME##(n-1)
+#define ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) ACTIVATION_ROW_##N(ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)
+/** Apply the activation function to the N variables BASENAME##0 ... BASENAME##(N-1).
+ * Supported cases: N = 1, 2, 3, ..., 16 (indices 10-15 use the hexadecimal suffixes A-F).
+ */
+#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL) ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, BASENAME, A_VAL, B_VAL)
\ No newline at end of file
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
index e5d199d..3c07c1d 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
@@ -262,6 +262,9 @@
     build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0));
     build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
     build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
 
     std::string kernel_name("gemm_mm_native");
 
@@ -275,6 +278,7 @@
     _config_id += (_broadcast_bias ? "broadcast_bias_" : "");
     _config_id += (_reinterpret_input_as_3d ? "3di_" : "");
     _config_id += (_reinterpret_output_as_3d ? "3do_" : "");
+    _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : "");
     _config_id += lower_string(string_from_data_type(input0->info()->data_type()));
     _config_id += "_";
     _config_id += support::cpp11::to_string(output->info()->dimension(1));
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 3ad0ffd..fd6fd7c 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -258,6 +258,9 @@
     build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0));
     build_opts.add_option("-DV0=" + support::cpp11::to_string(lhs_info.v0));
     build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
 
     std::string kernel_name("gemm_mm_reshaped_");
     kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_";
@@ -272,6 +275,7 @@
     _config_id += (_add_bias ? "add_bias_" : "");
     _config_id += (_broadcast_bias ? "broadcast_bias_" : "");
     _config_id += (_reinterpret_output_as_3d ? "3do_" : "");
+    _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : "");
     _config_id += lower_string(string_from_data_type(input0->info()->data_type()));
     _config_id += "_";
     _config_id += support::cpp11::to_string(output->info()->dimension(1));
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index 97c7984..5f92cad 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -267,6 +267,9 @@
     build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
     build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
     build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
 
     std::string kernel_name("gemm_mm_reshaped_only_rhs_");
     kernel_name += rhs_info.transpose ? "t" : "nt";
@@ -281,6 +284,7 @@
     _config_id += (_broadcast_bias ? "broadcast_bias_" : "");
     _config_id += (_reinterpret_input_as_3d ? "3di_" : "");
     _config_id += (_reinterpret_output_as_3d ? "3do_" : "");
+    _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : "");
     _config_id += lower_string(string_from_data_type(input0->info()->data_type()));
     _config_id += "_";
     _config_id += support::cpp11::to_string(output->info()->dimension(1));