COMPMID-459 Collapse CL Im2col's higher dimensions

Change-Id: I0ccc39cbcf6926e6810faf3fe264c4af7adc3f7b
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83070
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CL/cl_kernels/convolution_layer.cl b/src/core/CL/cl_kernels/convolution_layer.cl
index 0dd331f..cd886b1 100644
--- a/src/core/CL/cl_kernels/convolution_layer.cl
+++ b/src/core/CL/cl_kernels/convolution_layer.cl
@@ -120,11 +120,15 @@
  */
 __kernel void im2col_generic(
     TENSOR3D_DECLARATION(src),
-    IMAGE_DECLARATION(dst))
+    IMAGE_DECLARATION(dst),
+    uint filter_depth,
+    uint src_stride_w,
+    uint dst_stride_w)
 {
-    const int xc = get_global_id(0); // x coordinate in the convolved tensor
-    const int yc = get_global_id(1); // y coordinate in the convolved tensor
-    const int ch = get_global_id(2); // input feature map
+    const int xc    = get_global_id(0);                // x coordinate in the convolved tensor
+    const int yc    = get_global_id(1);                // y coordinate in the convolved tensor
+    const int ch    = get_global_id(2) % filter_depth; // input feature map
+    const int batch = get_global_id(2) / filter_depth; // the batch
 
     // Calculate input indeces
     const int xi = xc * STRIDE_X - PAD_X;
@@ -134,8 +138,8 @@
     const int xo = ch * KERNEL_WIDTH * KERNEL_HEIGHT;
     const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
 
-    __global uchar *input_ptr      = src_ptr + src_offset_first_element_in_bytes + ch * src_stride_z;
-    __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y)) + xo;
+    __global uchar *input_ptr      = src_ptr + src_offset_first_element_in_bytes + ch * src_stride_z + batch * src_stride_w;
+    __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + batch * dst_stride_w)) + xo;
 
     // Linearize convolution elements
     for(int y = yi, y_e = yi + KERNEL_HEIGHT; y < y_e; ++y)
@@ -158,7 +162,7 @@
     }
 
 #ifdef HAS_BIAS
-    if(get_global_id(2) == (KERNEL_DEPTH - 1))
+    if(ch == (KERNEL_DEPTH - 1))
     {
 #ifdef FIXED_POINT_POSITION
         *output_ptr = (DATA_TYPE)(1 << FIXED_POINT_POSITION);
@@ -191,11 +195,15 @@
  */
 __kernel void im2col_kernel3x3_padx0_pady0(
     TENSOR3D_DECLARATION(src),
-    IMAGE_DECLARATION(dst))
+    IMAGE_DECLARATION(dst),
+    uint filter_depth,
+    uint src_stride_w,
+    uint dst_stride_w)
 {
-    const int xc = get_global_id(0); // x coordinate in the convolved tensor
-    const int yc = get_global_id(1); // y coordinate in the convolved tensor
-    const int ch = get_global_id(2); // input feature map
+    const int xc    = get_global_id(0);                // x coordinate in the convolved tensor
+    const int yc    = get_global_id(1);                // y coordinate in the convolved tensor
+    const int ch    = get_global_id(2) % filter_depth; // input feature map
+    const int batch = get_global_id(2) / filter_depth; // the batch
 
     // Calculate input indeces
     const int xi = xc * STRIDE_X;
@@ -206,8 +214,9 @@
     const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
 
     // Get input and output address
-    __global uchar *input_ptr      = src_ptr + src_offset_first_element_in_bytes + xi * src_stride_x + yi * src_stride_y + ch * src_stride_z;
-    __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y)) + xo;
+    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + xi * src_stride_x + yi * src_stride_y + ch * src_stride_z + batch * src_stride_w;
+
+    __global DATA_TYPE *output_ptr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + batch * dst_stride_w) + xo;
 
     VEC_DATA_TYPE(DATA_TYPE, 3)
     row0 = vload3(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
@@ -220,7 +229,7 @@
     *(output_ptr + 8) = row2.s2;
 
 #ifdef HAS_BIAS
-    if(get_global_id(2) == (KERNEL_DEPTH - 1))
+    if(ch == (KERNEL_DEPTH - 1))
     {
 #ifdef FIXED_POINT_POSITION
         *(output_ptr + 9) = (DATA_TYPE)(1 << FIXED_POINT_POSITION);
diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h
index 59b81d7..68af64e 100644
--- a/src/core/CL/cl_kernels/helpers.h
+++ b/src/core/CL/cl_kernels/helpers.h
@@ -72,6 +72,18 @@
     uint        name##_step_z,   \
     uint        name##_offset_first_element_in_bytes
 
+#define TENSOR4D_DECLARATION(name) /* 4D tensor kernel arguments: uchar buffer pointer, stride/step pairs for x, y, z and w, then the first-element offset */ \
+    __global uchar *name##_ptr,      \
+    uint        name##_stride_x, \
+    uint        name##_step_x,   \
+    uint        name##_stride_y, \
+    uint        name##_step_y,   \
+    uint        name##_stride_z, \
+    uint        name##_step_z,   \
+    uint        name##_stride_w, \
+    uint        name##_step_w,   \
+    uint        name##_offset_first_element_in_bytes
+
 #define CONVERT_TO_VECTOR_STRUCT(name) \
     update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)
 
@@ -84,6 +96,9 @@
 #define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
     update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)
 
+#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) /* Stepped counterpart of CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP: forwards step_x and step_y as well as step_z */ \
+    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)
+
 #define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
     update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)
 
@@ -97,6 +112,13 @@
 #define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
     update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)
 
+#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size) /* Tensor4D moved to this workitem's element; mod_size splits global id 2 into z and w */ \
+    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
+                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)
+
+#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) /* All steps are 0, so ptr only moves by the first-element offset; mod_size must still be non-zero (it is used in % and /) */ \
+    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)
+
 /** Structure to hold Vector information */
 typedef struct Vector
 {
@@ -124,6 +146,17 @@
     int             stride_z;                      /**< Stride of the image in Z dimension (in bytes) */
 } Tensor3D;
 
+/** Structure to hold 4D tensor information: buffer pointer, first-element offset and per-dimension byte strides */
+typedef struct Tensor4D
+{
+    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
+    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
+    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
+    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
+    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
+    int             stride_w;                      /**< Stride of the tensor in W dimension (in bytes) */
+} Tensor4D;
+
 /** Wrap vector information into an Vector structure, and make the pointer point at this workitem's data.
  *
  * @param[in] ptr                           Pointer to the starting postion of the buffer
@@ -222,6 +255,24 @@
     return tensor;
 }
 
+/** Wrap 4D tensor information into a Tensor4D structure and make its pointer point at this workitem's data (mod_size must be non-zero) */
+Tensor4D inline update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w, uint step_w, uint mod_size)
+{
+    Tensor4D result;
+    result.ptr                           = ptr;
+    result.offset_first_element_in_bytes = offset_first_element_in_bytes;
+    result.stride_x                      = stride_x;
+    result.stride_y                      = stride_y;
+    result.stride_z                      = stride_z;
+    result.stride_w                      = stride_w;
+
+    // Global id 2 packs the Z and W coordinates together as z + w * mod_size
+    const size_t z_idx = get_global_id(2) % mod_size;
+    const size_t w_idx = get_global_id(2) / mod_size;
+    result.ptr += result.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + z_idx * step_z + w_idx * step_w;
+    return result;
+}
+
 /** Get the pointer position of a Vector
  *
  * @param[in] vec Pointer to the starting position of the buffer
@@ -255,4 +306,17 @@
     return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
 }
 
+/** Get the address of element (x, y, z, w) relative to the current pointer of a Tensor4D
+ * @param[in] tensor Tensor4D providing the base pointer and the per-dimension strides
+ * @param[in] x      Relative X position
+ * @param[in] y      Relative Y position
+ * @param[in] z      Relative Z position
+ * @param[in] w      Relative W position
+ * @return tensor->ptr advanced by x * stride_x + y * stride_y + z * stride_z + w * stride_w bytes
+ */
+__global inline const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
+{
+    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
+}
+
 #endif // _HELPER_H