Rework gemm_reshape_rhs_(nt,t) with new macros

Resolves COMPMID-4891

Change-Id: Ifdf2a0eaed23347a1b4465ea8d58c11b72083952
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6741
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
diff --git a/src/core/CL/ICLKernel.cpp b/src/core/CL/ICLKernel.cpp
index eb750cb..9bbc710 100644
--- a/src/core/CL/ICLKernel.cpp
+++ b/src/core/CL/ICLKernel.cpp
@@ -116,6 +116,31 @@
     ARM_COMPUTE_UNUSED(idx_start);
 }
 
+void ICLKernel::add_3d_tensor_nhw_argument(unsigned int &idx, const ICLTensor *tensor)
+{
+    ARM_COMPUTE_ERROR_ON(tensor == nullptr);
+
+    const ITensorInfo *info = tensor->info();
+    ARM_COMPUTE_ERROR_ON(info == nullptr);
+    const Strides &strides = info->strides_in_bytes();
+
+    // Tensor poniter
+    _kernel.setArg(idx++, tensor->cl_buffer());
+
+    // Add stride_y, stride_z
+    _kernel.setArg<cl_uint>(idx++, strides[1]);
+    _kernel.setArg<cl_uint>(idx++, strides[2]);
+
+    // Tensor dimensions
+    _kernel.setArg<cl_uint>(idx++, info->dimension(0));
+    _kernel.setArg<cl_uint>(idx++, info->dimension(1));
+    _kernel.setArg<cl_uint>(idx++, info->dimension(2));
+
+    // Offset of first element
+    unsigned int offset_first_element = info->offset_first_element_in_bytes();
+    _kernel.setArg<cl_uint>(idx++, offset_first_element);
+}
+
 void ICLKernel::add_4d_tensor_nhwc_argument(unsigned int &idx, const ICLTensor *tensor)
 {
     ARM_COMPUTE_ERROR_ON(tensor == nullptr);
diff --git a/src/core/CL/ICLKernel.h b/src/core/CL/ICLKernel.h
index a7c979e..bc138e7 100644
--- a/src/core/CL/ICLKernel.h
+++ b/src/core/CL/ICLKernel.h
@@ -226,6 +226,23 @@
         add_tensor_argument<4>(idx, tensor, window);
     }
 
+    /** Add the passed NHW 3D tensor's parameters to the object's kernel's arguments by passing strides, dimensions and the offset to the first valid element in bytes.
+     *
+     * @param[in,out] idx    Index at which to start adding the tensor's arguments. Will be incremented by the number of kernel arguments set.
+     * @param[in]     tensor Tensor to set as an argument of the object's kernel.
+     */
+    void add_3d_tensor_nhw_argument(unsigned int &idx, const ICLTensor *tensor);
+
+    /** Returns the number of arguments enqueued per NHW 3D Tensor object.
+     *
+     * @return The number of arguments enqueued per NHW 3D Tensor object.
+     */
+    constexpr static unsigned int num_arguments_per_3d_tensor_nhw()
+    {
+        constexpr unsigned int no_args_per_3d_tensor_nhw = 7u;
+        return no_args_per_3d_tensor_nhw;
+    }
+
     /** Add the passed NHWC 4D tensor's parameters to the object's kernel's arguments by passing strides, dimensions and the offset to the first valid element in bytes.
      *
      * @param[in,out] idx    Index at which to start adding the tensor's arguments. Will be incremented by the number of kernel arguments set.
diff --git a/src/core/CL/cl_kernels/common/gemm_utils.cl b/src/core/CL/cl_kernels/common/gemm_utils.cl
index 89c00b5..2e49614 100644
--- a/src/core/CL/cl_kernels/common/gemm_utils.cl
+++ b/src/core/CL/cl_kernels/common/gemm_utils.cl
@@ -21,6 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+#include "helpers.h"
+#include "tile_helpers.h"
 #include "gemm_helpers.h"
 #include "repeat.h"
 
@@ -390,12 +392,11 @@
 }
 #endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH) && defined(SRC_HEIGHT) && defined(PARTIAL_LOAD_M0) && defined(PARTIAL_LOAD_K0)
 
-#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
+#if defined(RESHAPE_RHS_NT)
 /** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
  *  the output matrix unrolling the values.
  *
  * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
- * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
  * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
  * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
  * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
@@ -404,25 +405,25 @@
  *                                      K0: 1,2,3,4,8,16
  *                                      H0: greater than 0
  *
- * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: All
- * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
- * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
- * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
- * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
- * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
- * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
- * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
- * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: All
+ * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_w                             The size of the width dimension of the source tensor
+ * @param[in] src_h                             The size of the height dimension of the source tensor
+ * @param[in] src_n                             The size of the depth dimension of the source tensor
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: All
+ * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_w                             The size of the width dimension of the destination tensor
+ * @param[in] dst_h                             The size of the height dimension of the destination tensor
+ * @param[in] dst_n                             The size of the depth dimension of the destination tensor
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] H0                                The number of blocks to place on the same row. It must be greater than 0.
  */
-__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
-                                         TENSOR3D_DECLARATION(dst))
+__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_T(src, BUFFER),
+                                         TENSOR3D_T(dst, BUFFER),
+                                         const int H0)
 {
     // Block size
 #define BLOCK_SIZE ((K0) * (N0))
@@ -441,114 +442,55 @@
 #define OUTPUT_STEP_X (N0)
 #endif // defined(INTERLEAVE)
 
-    // Compute source and destination addresses
-    uint x = get_global_id(0);
-    uint y = get_global_id(1);
-    uint z = get_global_id(2);
+    const int x = GET_SPATIAL_IDX(0, 1, 0);
+    const int y = GET_SPATIAL_IDX(1, 1, 0);
+    const int z = GET_SPATIAL_IDX(2, 1, 0);
 
-    // ------------------ Compute input/output addresses ---------------------------
+    const int xi = x * N0;
+    const int yi = y * K0;
 
-    // Compute the input address
-    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
+    const int xo = y * BLOCK_SIZE * H0 + (x % H0) * OUTPUT_OFFSET_X;
+    const int yo = (x / H0);
 
-    // Compute the output address
-    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
-                                     x / (uint)H0)
-                                 * (uint)dst_stride_y)
-                                 + z * (uint)dst_stride_z;
+    src_offset_first_element_in_bytes += yi * src_stride_y + z * src_stride_z;
+    dst_offset_first_element_in_bytes += yo * dst_stride_y + z * dst_stride_z;
 
-    // ---------------------------Load input values --------------------------------
+    TILE(DATA_TYPE, K0, N0, in);
 
-    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); ////uint a0=0, a1=0, a2=0...a(M0-1)=0;
+    // Initialize the tile to zero
+    for(int i = 0; i < K0; ++i)
+    {
+        in[i].v = 0;
+    }
 
-    // Load values from the RHS matrix
-    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
-#if K0 > 1
-    if(y * (uint)K0 + 1 < SRC_HEIGHT)
+    // Load input tile
+    for(int i = 0; i < K0; ++i)
     {
-        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
+        if(yi + i < src_h)
+        {
+            in[i].v = V_LOAD(DATA_TYPE, N0, BUFFER, src, xi, i, src_stride_y);
+        }
     }
-#endif // K0 > 1
-#if K0 > 2
-    if(y * (uint)K0 + 2 < SRC_HEIGHT)
-    {
-        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
-    }
-#endif // K0 > 2
-#if K0 > 3
-    if(y * (uint)K0 + 3 < SRC_HEIGHT)
-    {
-        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
-    }
-#endif // K0 > 3
-#if K0 > 4
-    if(y * (uint)K0 + 4 < SRC_HEIGHT)
-    {
-        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
-    }
-    if(y * (uint)K0 + 5 < SRC_HEIGHT)
-    {
-        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
-    }
-    if(y * (uint)K0 + 6 < SRC_HEIGHT)
-    {
-        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
-    }
-    if(y * (uint)K0 + 7 < SRC_HEIGHT)
-    {
-        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
-    }
-#endif // K0 > 4
-#if K0 > 8
-    if(y * (uint)K0 + 8 < SRC_HEIGHT)
-    {
-        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
-    }
-    if(y * (uint)K0 + 9 < SRC_HEIGHT)
-    {
-        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
-    }
-    if(y * (uint)K0 + 10 < SRC_HEIGHT)
-    {
-        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
-    }
-    if(y * (uint)K0 + 11 < SRC_HEIGHT)
-    {
-        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
-    }
-    if(y * (uint)K0 + 12 < SRC_HEIGHT)
-    {
-        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
-    }
-    if(y * (uint)K0 + 13 < SRC_HEIGHT)
-    {
-        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
-    }
-    if(y * (uint)K0 + 14 < SRC_HEIGHT)
-    {
-        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
-    }
-    if(y * (uint)K0 + 15 < SRC_HEIGHT)
-    {
-        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
-    }
-#endif // K0 > 8
 
-    // ---------------------------Store output values ------------------------------
-    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
-    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
+    TILE(uint, K0, 1, dst_indirect_y);
+    for(int i = 0; i < K0; ++i)
+    {
+        dst_indirect_y[i].v = i;
+    }
+
+    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, K0, N0, 0, BUFFER, dst, xo, (OUTPUT_STEP_X * sizeof(DATA_TYPE)), false, in, dst_indirect_y);
 
 #undef BLOCK_SIZE
 #undef OUTPUT_OFFSET_X
 #undef OUTPUT_STEP_X
 }
+#endif // defined(RESHAPE_RHS_NT)
 
-#if defined(TRANSPOSE)
+#if defined(RESHAPE_RHS_T)
 /** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
  *  the output matrix unrolling the values.
  *
  * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
- * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
  * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
  * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
  * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
@@ -558,25 +500,25 @@
  *                                      K0: 2,3,4,8,16
  *                                      H0: greater than 0
  *
- * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: All
- * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
- * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
- * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
- * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
- * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
- * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
- * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
- * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
- * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
- * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
- * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: All
+ * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_w                             The size of the width dimension of the source tensor
+ * @param[in] src_h                             The size of the height dimension of the source tensor
+ * @param[in] src_n                             The size of the depth dimension of the source tensor
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: All
+ * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_w                             The size of the width dimension of the destination tensor
+ * @param[in] dst_h                             The size of the height dimension of the destination tensor
+ * @param[in] dst_n                             The size of the depth dimension of the destination tensor
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] H0                                The number of blocks to place on the same row. It must be greater than 0.
  */
-__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
-                                        TENSOR3D_DECLARATION(dst))
+__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_T(src, BUFFER),
+                                        TENSOR3D_T(dst, BUFFER),
+                                        const int H0)
 {
     // Block size
 #define BLOCK_SIZE ((K0) * (N0))
@@ -595,280 +537,57 @@
 #define OUTPUT_STEP_X (K0)
 #endif // defined(INTERLEAVE)
 
-    // Compute source and destination addresses
-    uint x = get_global_id(0);
-    uint y = get_global_id(1);
-    uint z = get_global_id(2);
+    const int x = GET_SPATIAL_IDX(0, 1, 0);
+    const int y = GET_SPATIAL_IDX(1, 1, 0);
+    const int z = GET_SPATIAL_IDX(2, 1, 0);
 
-    // ------------------ Compute input/output addresses ---------------------------
+    const int xi = x * N0;
+    const int yi = y * K0;
 
-    // Compute the input address
-    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
+    const int xo = y * BLOCK_SIZE * H0 + (x % H0) * OUTPUT_OFFSET_X;
+    const int yo = (x / H0);
 
-    // Compute the output address
-    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
-                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;
+    src_offset_first_element_in_bytes += yi * src_stride_y + z * src_stride_z;
+    dst_offset_first_element_in_bytes += yo * dst_stride_y + z * dst_stride_z;
 
-    // ---------------------------Load input values --------------------------------
-    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    a0=0, a1=0, ... a(K0-1)=0;
+    TILE(DATA_TYPE, K0, N0, in);
+    TILE(DATA_TYPE, N0, K0, in_tr);
 
-    // Load values from the RHS matrix
-    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
-    if(y * (uint)K0 + 1 < SRC_HEIGHT)
+    // Initialize the tile to zero
+    for(int i = 0; i < K0; ++i)
     {
-        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
+        in[i].v = 0;
     }
-#if K0 > 2
-    if(y * (uint)K0 + 2 < SRC_HEIGHT)
-    {
-        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
-    }
-#endif // K0 > 2
-#if K0 > 3
-    if(y * (uint)K0 + 3 < SRC_HEIGHT)
-    {
-        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
-    }
-#endif // K0 > 3
-#if K0 > 4
-    if(y * (uint)K0 + 4 < SRC_HEIGHT)
-    {
-        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
-    }
-    if(y * (uint)K0 + 5 < SRC_HEIGHT)
-    {
-        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
-    }
-    if(y * (uint)K0 + 6 < SRC_HEIGHT)
-    {
-        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
-    }
-    if(y * (uint)K0 + 7 < SRC_HEIGHT)
-    {
-        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
-    }
-#endif // K0 > 4
-#if K0 > 8
-    if(y * (uint)K0 + 8 < SRC_HEIGHT)
-    {
-        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
-    }
-    if(y * (uint)K0 + 9 < SRC_HEIGHT)
-    {
-        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
-    }
-    if(y * (uint)K0 + 10 < SRC_HEIGHT)
-    {
-        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
-    }
-    if(y * (uint)K0 + 11 < SRC_HEIGHT)
-    {
-        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
-    }
-    if(y * (uint)K0 + 12 < SRC_HEIGHT)
-    {
-        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
-    }
-    if(y * (uint)K0 + 13 < SRC_HEIGHT)
-    {
-        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
-    }
-    if(y * (uint)K0 + 14 < SRC_HEIGHT)
-    {
-        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
-    }
-    if(y * (uint)K0 + 15 < SRC_HEIGHT)
-    {
-        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
-    }
-#endif // K0 > 8
 
-    // ---------------------------Transpose the block ------------------------------
-    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0)    res0=0, res1=0, res2=0,... res(N0-1)=0;
+    // Load input tile
+    for(int i = 0; i < K0; ++i)
+    {
+        if(yi + i < src_h)
+        {
+            in[i].v = V_LOAD(DATA_TYPE, N0, BUFFER, src, xi, i, src_stride_y);
+        }
+    }
 
-#if K0 == 2
-    // This part computes the following transpositions:
-    // 2x2 -> 2x2
-    // 2x4 -> 4x2
-    // 2x8 -> 8x2
-    // 2x16 -> 16x2
-    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
-    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
-#if N0 > 2
-    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
-#endif // N0 > 2
-#if N0 > 3
-    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
-#endif // N0 > 3
-#if N0 > 4
-    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
-    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
-    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
-    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
-#endif // N0 > 4
-#if N0 > 8
-    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
-    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
-    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
-    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
-    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
-    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
-    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
-    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
-#endif // N0 > 8
+    // Transpose input tile
+    for(int k0 = 0; k0 < K0; ++k0)
+    {
+        for(int n0 = 0; n0 < N0; ++n0)
+        {
+            in_tr[n0].s[k0] = in[k0].s[n0];
+        }
+    }
 
-#elif K0 == 3 // K0 == 2
-    // This part computes the following transpositions:
-    // 3x2 -> 2x3
-    // 3x4 -> 4x3
-    // 3x8 -> 8x3
-    // 3x16 -> 16x3
-    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
-    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
-#if N0 > 2
-    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
-#endif // N0 > 2
-#if N0 > 3
-    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
-#endif // N0 > 3
-#if N0 > 4
-    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
-    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
-    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
-    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
-#endif // N0 > 4
-#if N0 > 8
-    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
-    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
-    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
-    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
-    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
-    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
-    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
-    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
-#endif // N0 > 8
+    TILE(uint, N0, 1, dst_indirect_y);
+    for(int i = 0; i < N0; ++i)
+    {
+        dst_indirect_y[i].v = i;
+    }
 
-#elif K0 == 4 // K0 == 4
-    // This part computes the following transpositions:
-    // 4x2 -> 2x4
-    // 4x4 -> 4x4
-    // 4x8 -> 8x4
-    // 4x16 -> 16x4
-    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
-    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
-#if N0 > 2
-    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
-#endif // N0 > 2
-#if N0 > 3
-    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
-#endif // N0 > 3
-#if N0 > 4
-    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
-    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
-    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
-    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
-#endif // N0 > 4
-#if N0 > 8
-    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
-    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
-    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
-    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
-    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
-    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
-    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
-    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
-#endif // N0 > 8
-
-#elif K0 == 8 // K0 == 8
-    // This part computes the following transpositions:
-    // 8x2 -> 2x8
-    // 8x4 -> 4x8
-    // 8x8 -> 8x8
-    // 8x16 -> 16x8
-    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
-    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
-#if N0 > 2
-    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
-#endif // N0 > 2
-#if N0 > 3
-    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
-#endif // N0 > 3
-#if N0 > 4
-    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
-    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
-    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
-    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
-#endif // N0 > 4
-#if N0 > 8
-    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
-    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
-    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
-    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
-    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
-    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
-    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
-    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
-#endif // N0 > 8
-
-#elif K0 == 16 // K0 == 16
-
-    // This part computes the following transpositions:
-    // 16x2 -> 2x16
-    // 16x4 -> 4x16
-    // 16x8 -> 8x16
-    // 16x16 -> 16x16
-    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
-                                          a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
-    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
-                                          a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
-#if N0 > 2
-    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
-                                          a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
-#endif // N0 > 2
-#if N0 > 3
-    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
-                                          a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
-#endif // N0 > 3
-#if N0 > 4
-    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
-                                          a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
-    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
-                                          a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
-    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
-                                          a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
-    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
-                                          a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
-#endif // N0 > 4
-#if N0 > 8
-    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
-                                          a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
-    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
-                                          a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
-    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
-                                          a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
-    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
-                                          a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
-    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
-                                          a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
-    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
-                                          a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
-    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
-                                          a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
-    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
-                                          a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
-#endif // N0 > 8
-
-#else // N0 == 16
-#error "Not supported N0 value"
-#endif // N0 > 2
-
-    // ---------------------------Store the output values ------------------------------
-    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
-    STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
+    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, N0, K0, 0, BUFFER, dst, xo, (OUTPUT_STEP_X * sizeof(DATA_TYPE)), false, in_tr, dst_indirect_y);
 
 #undef BLOCK_SIZE
 #undef OUTPUT_OFFSET_X
 #undef OUTPUT_STEP_X
 }
-#endif // defined(TRANSPOSE)
-#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
+
+#endif // defined(RESHAPE_RHS_T)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/tile_helpers.h b/src/core/CL/cl_kernels/tile_helpers.h
index cc20616..30c37da 100644
--- a/src/core/CL/cl_kernels/tile_helpers.h
+++ b/src/core/CL/cl_kernels/tile_helpers.h
@@ -130,6 +130,28 @@
 #define TENSOR4D_T_STR(name, type) TENSOR4D_T_##type(name)
 #define TENSOR4D_T(name, type) TENSOR4D_T_STR(name, type)
 
+#define TENSOR3D_T_IMAGE(name)          \
+    __read_only image2d_t name##_img, \
+    __global uchar *name##_ptr,       \
+    uint        name##_stride_y, \
+    uint        name##_stride_z, \
+    uint        name##_w,   \
+    uint        name##_h,   \
+    uint        name##_n,   \
+    uint        name##_offset_first_element_in_bytes
+
+#define TENSOR3D_T_BUFFER(name)    \
+    __global uchar *name##_ptr,  \
+    uint        name##_stride_y, \
+    uint        name##_stride_z, \
+    uint        name##_w,   \
+    uint        name##_h,   \
+    uint        name##_n,   \
+    uint        name##_offset_first_element_in_bytes
+
+#define TENSOR3D_T_STR(name, type) TENSOR3D_T_##type(name)
+#define TENSOR3D_T(name, type) TENSOR3D_T_STR(name, type)
+
 #if !defined(UNROLL_WITH_PRAGMA)
 #define UNROLL_INCR(idx, step, macro) idx += (step); (macro)
 
diff --git a/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp b/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp
index 778b9b9..b3a0388 100644
--- a/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp
+++ b/src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp
@@ -123,10 +123,9 @@
     CLBuildOptions build_opts;
     build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
     build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
-    build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
-    build_opts.add_option_if(rhs_info.transpose, "-DTRANSPOSE");
     build_opts.add_option_if(rhs_info.interleave, "-DINTERLEAVE");
-    build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(src->dimension(1)));
+    build_opts.add_option_if(rhs_info.transpose, "-DRESHAPE_RHS_T");
+    build_opts.add_option_if(!rhs_info.transpose, "-DRESHAPE_RHS_NT");
     build_opts.add_option("-DDATA_TYPE=" + get_cl_unsigned_type_from_element_size(src->element_size()));
 
     std::string kernel_name("gemm_reshape_rhs_matrix_");
@@ -139,6 +138,9 @@
     auto win_config = validate_and_configure_window(src, dst, rhs_info);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     ICLKernel::configure_internal(win_config.second);
+
+    unsigned int idx = 2 * num_arguments_per_3d_tensor_nhw();
+    _kernel.setArg<cl_int>(idx++, rhs_info.h0);
 }
 
 Status ClGemmReshapeRhsMatrixKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMRHSMatrixInfo &rhs_info)
@@ -164,8 +166,8 @@
     do
     {
         unsigned int idx = 0;
-        add_3D_tensor_argument(idx, src, slice);
-        add_3D_tensor_argument(idx, dst, slice);
+        add_3d_tensor_nhw_argument(idx, src);
+        add_3d_tensor_nhw_argument(idx, dst);
         enqueue(queue, *this, slice, lws_hint());
     }
     while(window.slide_window_slice_3D(slice));