COMPMID-1979: Fuse Activation Function in CLGEMM - part 1

Implementing a new struct (GEMMKernelInfo) to contain the information
needed by the OpenCL GEMM kernels.
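
A minimal sketch of how a caller is expected to fill the new
descriptor and pass it to one of the kernels (the tensors, the
lhs_info/rhs_info tile descriptors and the dimension values below
are illustrative placeholders, not part of this change):

    // lhs, rhs, bias and dst are assumed to be already-initialised
    // CL tensor objects; lhs_info/rhs_info hold tile sizes chosen
    // elsewhere (placeholders for this sketch).
    GEMMKernelInfo gemm_info{};
    gemm_info.m                       = 64;    // rows of the LHS matrix (M)
    gemm_info.n                       = 128;   // columns of the RHS matrix (N)
    gemm_info.k                       = 256;   // inner accumulation dimension (K)
    gemm_info.depth_output_gemm3d     = 0;     // 0: output is not reinterpreted as 3D
    gemm_info.reinterpret_input_as_3d = false; // input0 kept as a 2D (batched) matrix
    gemm_info.broadcast_bias          = true;  // bias vector broadcast across the rows

    CLGEMMMatrixMultiplyNativeKernel mm_kernel;
    mm_kernel.configure(&lhs, &rhs, &bias, &dst, 1.0f, 1.0f,
                        lhs_info, rhs_info, gemm_info);

The output shape for such a configuration can then be derived with the
new compute_mm_shape(input0, input1, gemm_info) overload added to
ShapeCalculator.h.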

Change-Id: I6c641c312f9c3b025a7c69dd0df3b730d2d2c2cb
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1434
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
index 79689a2..96f412c 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.h
@@ -26,6 +26,8 @@
 
 #include "arm_compute/core/CL/ICLKernel.h"
 
+#include "arm_compute/core/KernelDescriptors.h"
+
 namespace arm_compute
 {
 class ICLTensor;
@@ -62,7 +64,7 @@
      */
     void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
                    const GEMMRHSMatrixInfo &rhs_info,
-                   const GEMMReshapeInfo &gemm_info);
+                   const GEMMKernelInfo    &gemm_info);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyNativeKernel
      *
      * @param[in] input0    Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4.
@@ -83,7 +85,7 @@
      */
     static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
                            const GEMMRHSMatrixInfo &rhs_info,
-                           const GEMMReshapeInfo &gemm_info);
+                           const GEMMKernelInfo    &gemm_info);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
index 68ab94a..47916b3 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
@@ -26,6 +26,8 @@
 
 #include "arm_compute/core/CL/ICLKernel.h"
 
+#include "arm_compute/core/KernelDescriptors.h"
+
 namespace arm_compute
 {
 class ICLTensor;
@@ -69,7 +71,7 @@
      */
     void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
                    const GEMMRHSMatrixInfo &rhs_info,
-                   const GEMMReshapeInfo   &gemm_info);
+                   const GEMMKernelInfo    &gemm_info);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyReshapedKernel
      *
      * @param[in] input0    Input tensor containing the LHS reshaped matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4
@@ -94,7 +96,7 @@
      */
     static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
                            const GEMMRHSMatrixInfo &rhs_info,
-                           const GEMMReshapeInfo   &gemm_info);
+                           const GEMMKernelInfo    &gemm_info);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
index e3b3880..3315331 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h
@@ -26,6 +26,8 @@
 
 #include "arm_compute/core/CL/ICLKernel.h"
 
+#include "arm_compute/core/KernelDescriptors.h"
+
 namespace arm_compute
 {
 class ICLTensor;
@@ -65,7 +67,7 @@
      */
     void configure(const ICLTensor *input0, const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
                    const GEMMRHSMatrixInfo &rhs_info,
-                   const GEMMReshapeInfo   &gemm_info);
+                   const GEMMKernelInfo    &gemm_info);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel
      *
      * @param[in] input0    Input tensor info for the LHS matrix. Data type supported: F32/F16. The number of dimensions for the LHS matrix must be less or equal than 4.
@@ -86,7 +88,7 @@
      */
     static Status validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
                            const GEMMRHSMatrixInfo &rhs_info,
-                           const GEMMReshapeInfo   &gemm_info);
+                           const GEMMKernelInfo    &gemm_info);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h
index 83131f4..fe59365 100644
--- a/arm_compute/core/KernelDescriptors.h
+++ b/arm_compute/core/KernelDescriptors.h
@@ -48,5 +48,16 @@
     unsigned int Nx{ 0 };                 /**< Nx coefficient. */
     bool         is_first_stage{ false }; /**< Flags if the FFT kernels is the first stage of a decomposed FFT. */
 };
+
+/** Descriptor used by the GEMM kernels */
+struct GEMMKernelInfo
+{
+    unsigned int m{ 0 };                           /**< Number of LHS rows */
+    unsigned int n{ 0 };                           /**< Number of RHS columns */
+    unsigned int k{ 0 };                           /**< Number of LHS columns or RHS rows */
+    unsigned int depth_output_gemm3d{ 0 };         /**< Depth of the output tensor in case it has to be reinterpreted as 3D */
+    bool         reinterpret_input_as_3d{ false }; /**< Flag used to reinterpret the input as 3D */
+    bool         broadcast_bias{ false };          /**< Flag used to broadcast the bias addition */
+};
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CORE_KERNEL_DESCRIPTORS_H__ */
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 7eab17b..0105014 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -26,6 +26,7 @@
 
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/KernelDescriptors.h"
 #include "arm_compute/core/Utils.h"
 
 #include "arm_compute/core/utils/helpers/tensor_transform.h"
@@ -851,6 +852,8 @@
 
 /** Calculate the matrix multiplication output shape of two tensors
  *
+ * @note Deprecated. Remove once GEMMReshapeInfo is no longer used by any other kernel
+ *
  * @param[in] input0    First input tensor info
  * @param[in] input1    Second input tensor info
  * @param[in] gemm_info GEMM reshape info
@@ -888,6 +891,43 @@
 
 /** Calculate the matrix multiplication output shape of two tensors
  *
+ * @param[in] input0    First input tensor info
+ * @param[in] input1    Second input tensor info
+ * @param[in] gemm_info GEMM kernel info used to retrieve the original dimensions of the input matrices
+ *
+ * @return the calculated shape
+ */
+inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, const GEMMKernelInfo &gemm_info)
+{
+    ARM_COMPUTE_ERROR_ON_MSG(input0.num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
+
+    const bool         reinterpret_input_as_3d  = gemm_info.reinterpret_input_as_3d;
+    const bool         reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
+    const unsigned int depth_output_gemm3d      = reinterpret_output_as_3d ? gemm_info.depth_output_gemm3d : 1;
+
+    TensorShape output_shape{ input0.tensor_shape() };
+
+    if(!reinterpret_input_as_3d && !reinterpret_output_as_3d)
+    {
+        output_shape.set(0, gemm_info.n);
+        output_shape.set(1, gemm_info.m);
+    }
+    else
+    {
+        // If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained by collapsing the second and third
+        // dimensions of the output tensor
+        const unsigned int batch_size = reinterpret_input_as_3d ? input0.tensor_shape()[3] : input0.tensor_shape()[2];
+        output_shape.set(0, gemm_info.n);
+        output_shape.set(1, gemm_info.m / depth_output_gemm3d);
+        output_shape.set(2, reinterpret_output_as_3d ? depth_output_gemm3d : batch_size);
+        output_shape.set(3, reinterpret_output_as_3d ? batch_size : 1);
+    }
+
+    return output_shape;
+}
+
+/** Calculate the matrix multiplication output shape of two tensors
+ *
  * @param[in] input           Input tensor info
  * @param[in] gemm_3d_depth   (Optional)  GEMM 3d depth
  * @param[in] batch_size_on_z (Optional) True if batch size is on z axis