COMPMID-1337 Implementing Winograd Convolution Layer 1x3 and 3x1 kernels on OpenCL NHWC

Change-Id: Ia07e0dfcbcd07366c4bcb956e298369fb12a0369
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/138759
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/core/CL/cl_kernels/winograd_filter_transform.cl b/src/core/CL/cl_kernels/winograd_filter_transform.cl
index 5f528d4..e53da9b 100644
--- a/src/core/CL/cl_kernels/winograd_filter_transform.cl
+++ b/src/core/CL/cl_kernels/winograd_filter_transform.cl
@@ -285,9 +285,11 @@
 #endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
 }
 
-/** This OpenCL kernel performs Winograd filter transform 3x3 when the data layout is NHWC and the output tile is 4x4
+/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NHWC and the output tile is 4x4/4x1/1x4
  *
  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
+ * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
  *
  * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
  * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
@@ -317,32 +319,26 @@
     const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
 
     // Load the values from the input tensor
-    float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
-    float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
-    float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
-    float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
-    float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
-    float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
-    float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
-    float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
-    float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
-
-    // Transform the 3x3 tile in a 6x6 tile
-    float out00, out01, out02, out03, out04, out05;
-    float out10, out11, out12, out13, out14, out15;
-    float out20, out21, out22, out23, out24, out25;
-    float out30, out31, out32, out33, out34, out35;
-    float out40, out41, out42, out43, out44, out45;
-    float out50, out51, out52, out53, out54, out55;
-
-    out00 = out01 = out02 = out03 = out04 = out05 = 0.f;
-    out10 = out11 = out12 = out13 = out14 = out15 = 0.f;
-    out20 = out21 = out22 = out23 = out24 = out25 = 0.f;
-    out30 = out31 = out32 = out33 = out34 = out35 = 0.f;
-    out40 = out41 = out42 = out43 = out44 = out45 = 0.f;
-    out50 = out51 = out52 = out53 = out54 = out55 = 0.f;
+#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+    float w00 = *((__global float *)(src_addr + 0 * src_stride_z));
+    float w01 = *((__global float *)(src_addr + 1 * src_stride_z));
+    float w02 = *((__global float *)(src_addr + 2 * src_stride_z));
+#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+    float  w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
+    float  w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
+    float  w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+    float  w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
+    float  w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
+    float  w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
+    float  w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
+    float  w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
+    float  w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
 
     // Row 0
+    float out00, out01, out02, out03, out04, out05;
     out00 = (w00) / 16.f;
     out01 = (-w00 - w01 - w02) / 24.f;
     out02 = (-w00 + w01 - w02) / 24.f;
@@ -350,7 +346,9 @@
     out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
     out05 = (w02) / 4.f;
 
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
     // Row 1
+    float out10, out11, out12, out13, out14, out15;
     out10 = (-w00 - w10 - w20) / 24.f;
     out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
     out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
@@ -359,6 +357,7 @@
     out15 = (-w02 - w12 - w22) / 6.f;
 
     // Row 2
+    float out20, out21, out22, out23, out24, out25;
     out20 = (-w00 + w10 - w20) / 24.f;
     out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
     out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
@@ -367,6 +366,7 @@
     out25 = (-w02 + w12 - w22) / 6.f;
 
     // Row 3
+    float out30, out31, out32, out33, out34, out35;
     out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
     out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
     out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
@@ -375,6 +375,7 @@
     out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;
 
     // Row 4
+    float out40, out41, out42, out43, out44, out45;
     out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
     out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
     out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
@@ -383,26 +384,31 @@
     out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;
 
     // Row 5
+    float out50, out51, out52, out53, out54, out55;
     out50 = (w20) / 4.f;
     out51 = (-w20 - w21 - w22) / 6.f;
     out52 = (-w20 + w21 - w22) / 6.f;
     out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
     out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
     out55 = (w22);
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
 
     int x0 = get_global_id(2); // idx filter
     int y0 = get_global_id(0); // idx channel
 
     // Get output address
-    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
+    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
 
     // Store the values across the channels
-    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out00;
-    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out01;
-    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out02;
-    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out03;
-    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out04;
-    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out05;
+    // 36 channels for 3x3 kernels
+    // 6  channels for 3x1 or 1x3 kernels
+    *(__global float *)(dst_addr + 0 * dst_stride_z) = out00;
+    *(__global float *)(dst_addr + 1 * dst_stride_z) = out01;
+    *(__global float *)(dst_addr + 2 * dst_stride_z) = out02;
+    *(__global float *)(dst_addr + 3 * dst_stride_z) = out03;
+    *(__global float *)(dst_addr + 4 * dst_stride_z) = out04;
+    *(__global float *)(dst_addr + 5 * dst_stride_z) = out05;
+#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
     *(__global float *)(dst_addr + 6 * dst_stride_z)  = out10;
     *(__global float *)(dst_addr + 7 * dst_stride_z)  = out11;
     *(__global float *)(dst_addr + 8 * dst_stride_z)  = out12;
@@ -433,7 +439,108 @@
     *(__global float *)(dst_addr + 33 * dst_stride_z) = out53;
     *(__global float *)(dst_addr + 34 * dst_stride_z) = out54;
     *(__global float *)(dst_addr + 35 * dst_stride_z) = out55;
+#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
 }
+
+#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NHWC and the output tile is 4x1
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform. This kernel only forwards its arguments, unchanged, to winograd_filter_transform_4x4_3x3_nhwc
+ *
+ * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
+ * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
+ * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
+ * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
+ * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
+ * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_4x1_3x1_nhwc(
+    TENSOR4D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst))
+{
+    winograd_filter_transform_4x4_3x3_nhwc(src_ptr,
+                                           src_stride_x,
+                                           src_step_x,
+                                           src_stride_y,
+                                           src_step_y,
+                                           src_stride_z,
+                                           src_step_z,
+                                           src_stride_w,
+                                           src_step_w,
+                                           src_offset_first_element_in_bytes,
+                                           dst_ptr,
+                                           dst_stride_x,
+                                           dst_step_x,
+                                           dst_stride_y,
+                                           dst_step_y,
+                                           dst_stride_z,
+                                           dst_step_z,
+                                           dst_offset_first_element_in_bytes);
+}
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
+
+#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
+/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NHWC and the output tile is 1x4
+ *
+ * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
+ * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform. This kernel only forwards its arguments, unchanged, to winograd_filter_transform_4x4_3x3_nhwc
+ *
+ * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
+ * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
+ * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
+ * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
+ * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
+ * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
+ * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_filter_transform_1x4_1x3_nhwc(
+    TENSOR4D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst))
+{
+    winograd_filter_transform_4x4_3x3_nhwc(src_ptr,
+                                           src_stride_x,
+                                           src_step_x,
+                                           src_stride_y,
+                                           src_step_y,
+                                           src_stride_z,
+                                           src_step_z,
+                                           src_stride_w,
+                                           src_step_w,
+                                           src_offset_first_element_in_bytes,
+                                           dst_ptr,
+                                           dst_stride_x,
+                                           dst_step_x,
+                                           dst_stride_y,
+                                           dst_step_y,
+                                           dst_stride_z,
+                                           dst_step_z,
+                                           dst_offset_first_element_in_bytes);
+}
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
 /** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NCHW and the output tile is 4x4/4x1 or 1x4
  *
  * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
@@ -1264,4 +1371,4 @@
                                            dst_step_z,
                                            dst_offset_first_element_in_bytes);
 }
-#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
\ No newline at end of file
+#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)