COMPMID-459 Collapse CL Im2col's higher dimensions

Change-Id: I0ccc39cbcf6926e6810faf3fe264c4af7adc3f7b
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83070
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CL/ICLKernel.cpp b/src/core/CL/ICLKernel.cpp
index 5bd7142..bace631 100644
--- a/src/core/CL/ICLKernel.cpp
+++ b/src/core/CL/ICLKernel.cpp
@@ -122,6 +122,11 @@
     add_tensor_argument<3>(idx, tensor, window);
 }
 
+void ICLKernel::add_4D_tensor_argument(unsigned int &idx, const ICLTensor *tensor, const Window &window)
+{
+    add_tensor_argument<4>(idx, tensor, window);
+}
+
 unsigned int ICLKernel::num_arguments_per_1D_tensor() const
 {
     return num_arguments_per_tensor<1>();
@@ -137,6 +142,11 @@
     return num_arguments_per_tensor<3>();
 }
 
+unsigned int ICLKernel::num_arguments_per_4D_tensor() const
+{
+    return num_arguments_per_tensor<4>();
+}
+
 void ICLKernel::set_target(cl::Device &device)
 {
     _target = get_target_from_device(device);
diff --git a/src/core/CL/cl_kernels/convolution_layer.cl b/src/core/CL/cl_kernels/convolution_layer.cl
index 0dd331f..cd886b1 100644
--- a/src/core/CL/cl_kernels/convolution_layer.cl
+++ b/src/core/CL/cl_kernels/convolution_layer.cl
@@ -120,11 +120,15 @@
  */
 __kernel void im2col_generic(
     TENSOR3D_DECLARATION(src),
-    IMAGE_DECLARATION(dst))
+    IMAGE_DECLARATION(dst),
+    uint filter_depth,
+    uint src_stride_w,
+    uint dst_stride_w)
 {
-    const int xc = get_global_id(0); // x coordinate in the convolved tensor
-    const int yc = get_global_id(1); // y coordinate in the convolved tensor
-    const int ch = get_global_id(2); // input feature map
+    const int xc    = get_global_id(0);                // x coordinate in the convolved tensor
+    const int yc    = get_global_id(1);                // y coordinate in the convolved tensor
+    const int ch    = get_global_id(2) % filter_depth; // input feature map index
+    const int batch = get_global_id(2) / filter_depth; // batch index (get_global_id(2) spans channels x batches)
 
     // Calculate input indeces
     const int xi = xc * STRIDE_X - PAD_X;
@@ -134,8 +138,8 @@
     const int xo = ch * KERNEL_WIDTH * KERNEL_HEIGHT;
     const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
 
-    __global uchar *input_ptr      = src_ptr + src_offset_first_element_in_bytes + ch * src_stride_z;
-    __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y)) + xo;
+    __global uchar *input_ptr      = src_ptr + src_offset_first_element_in_bytes + ch * src_stride_z + batch * src_stride_w;
+    __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + batch * dst_stride_w)) + xo;
 
     // Linearize convolution elements
     for(int y = yi, y_e = yi + KERNEL_HEIGHT; y < y_e; ++y)
@@ -158,7 +162,7 @@
     }
 
 #ifdef HAS_BIAS
-    if(get_global_id(2) == (KERNEL_DEPTH - 1))
+    if(ch == (KERNEL_DEPTH - 1))
     {
 #ifdef FIXED_POINT_POSITION
         *output_ptr = (DATA_TYPE)(1 << FIXED_POINT_POSITION);
@@ -191,11 +195,15 @@
  */
 __kernel void im2col_kernel3x3_padx0_pady0(
     TENSOR3D_DECLARATION(src),
-    IMAGE_DECLARATION(dst))
+    IMAGE_DECLARATION(dst),
+    uint filter_depth,
+    uint src_stride_w,
+    uint dst_stride_w)
 {
-    const int xc = get_global_id(0); // x coordinate in the convolved tensor
-    const int yc = get_global_id(1); // y coordinate in the convolved tensor
-    const int ch = get_global_id(2); // input feature map
+    const int xc    = get_global_id(0);                // x coordinate in the convolved tensor
+    const int yc    = get_global_id(1);                // y coordinate in the convolved tensor
+    const int ch    = get_global_id(2) % filter_depth; // input feature map index
+    const int batch = get_global_id(2) / filter_depth; // batch index (get_global_id(2) spans channels x batches)
 
     // Calculate input indeces
     const int xi = xc * STRIDE_X;
@@ -206,8 +214,9 @@
     const int yo = xc + yc * CONVOLVED_WIDTH; // Index of the convolution
 
     // Get input and output address
-    __global uchar *input_ptr      = src_ptr + src_offset_first_element_in_bytes + xi * src_stride_x + yi * src_stride_y + ch * src_stride_z;
-    __global DATA_TYPE *output_ptr = ((__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y)) + xo;
+    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + xi * src_stride_x + yi * src_stride_y + ch * src_stride_z + batch * src_stride_w;
+
+    __global DATA_TYPE *output_ptr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + yo * dst_stride_y + batch * dst_stride_w) + xo;
 
     VEC_DATA_TYPE(DATA_TYPE, 3)
     row0 = vload3(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
@@ -220,7 +229,7 @@
     *(output_ptr + 8) = row2.s2;
 
 #ifdef HAS_BIAS
-    if(get_global_id(2) == (KERNEL_DEPTH - 1))
+    if(ch == (KERNEL_DEPTH - 1))
     {
 #ifdef FIXED_POINT_POSITION
         *(output_ptr + 9) = (DATA_TYPE)(1 << FIXED_POINT_POSITION);
diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h
index 59b81d7..68af64e 100644
--- a/src/core/CL/cl_kernels/helpers.h
+++ b/src/core/CL/cl_kernels/helpers.h
@@ -72,6 +72,18 @@
     uint        name##_step_z,   \
     uint        name##_offset_first_element_in_bytes
 
+#define TENSOR4D_DECLARATION(name)   \
+    __global uchar *name##_ptr,      \
+    uint        name##_stride_x, \
+    uint        name##_step_x,   \
+    uint        name##_stride_y, \
+    uint        name##_step_y,   \
+    uint        name##_stride_z, \
+    uint        name##_step_z,   \
+    uint        name##_stride_w, \
+    uint        name##_step_w,   \
+    uint        name##_offset_first_element_in_bytes
+
 #define CONVERT_TO_VECTOR_STRUCT(name) \
     update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)
 
@@ -84,6 +96,9 @@
 #define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
     update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)
 
+#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
+    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)
+
 #define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
     update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)
 
@@ -97,6 +112,13 @@
 #define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
     update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)
 
+#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                                                                                 \
+    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
+                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)
+
+#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
+    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)
+
 /** Structure to hold Vector information */
 typedef struct Vector
 {
@@ -124,6 +146,17 @@
     int             stride_z;                      /**< Stride of the image in Z dimension (in bytes) */
 } Tensor3D;
 
+/** Structure to hold 4D tensor information */
+typedef struct Tensor4D
+{
+    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
+    int             offset_first_element_in_bytes; /**< The offset of the first element in the tensor (in bytes) */
+    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
+    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
+    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
+    int             stride_w;                      /**< Stride of the tensor in W dimension (in bytes) */
+} Tensor4D;
+
 /** Wrap vector information into an Vector structure, and make the pointer point at this workitem's data.
  *
  * @param[in] ptr                           Pointer to the starting postion of the buffer
@@ -222,6 +255,43 @@
     return tensor;
 }
 
+/** Wrap 4D tensor information into a Tensor4D structure, and make the pointer point at this workitem's data.
+ *
+ * get_global_id(2) is split into two tensor coordinates:
+ * z = get_global_id(2) % mod_size and w = get_global_id(2) / mod_size.
+ *
+ * @param[in] ptr                           Pointer to the starting position of the buffer
+ * @param[in] offset_first_element_in_bytes The offset of the first element in the tensor (in bytes)
+ * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
+ * @param[in] step_x                        Bytes to advance per unit of get_global_id(0)
+ * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
+ * @param[in] step_y                        Bytes to advance per unit of get_global_id(1)
+ * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
+ * @param[in] step_z                        Bytes to advance per unit of z
+ * @param[in] stride_w                      Stride of the tensor in W dimension (in bytes)
+ * @param[in] step_w                        Bytes to advance per unit of w
+ * @param[in] mod_size                      Size of the Z dimension folded into get_global_id(2); must be non-zero
+ *
+ * @return A Tensor4D object whose ptr points at this workitem's data
+ */
+Tensor4D inline update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
+                                             uint step_w,
+                                             uint mod_size)
+{
+    Tensor4D tensor =
+    {
+        .ptr                           = ptr,
+        .offset_first_element_in_bytes = offset_first_element_in_bytes,
+        .stride_x                      = stride_x,
+        .stride_y                      = stride_y,
+        .stride_z                      = stride_z,
+        .stride_w                      = stride_w
+    };
+
+    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
+    return tensor;
+}
+
 /** Get the pointer position of a Vector
  *
  * @param[in] vec Pointer to the starting position of the buffer
@@ -255,4 +325,19 @@
     return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
 }
 
+/** Get the pointer position of a Tensor4D
+ *
+ * @param[in] tensor Pointer to the Tensor4D structure
+ * @param[in] x      Relative X position
+ * @param[in] y      Relative Y position
+ * @param[in] z      Relative Z position
+ * @param[in] w      Relative W position
+ *
+ * @return Pointer to the element at (x, y, z, w) relative to tensor->ptr
+ */
+__global inline const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
+{
+    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
+}
+
 #endif // _HELPER_H
diff --git a/src/core/CL/kernels/CLIm2ColKernel.cpp b/src/core/CL/kernels/CLIm2ColKernel.cpp
index b72aff2..5147ea0 100644
--- a/src/core/CL/kernels/CLIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLIm2ColKernel.cpp
@@ -103,7 +103,6 @@
         {
             _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("im2col_generic", build_opts));
         }
-
         _run_func = &CLIm2ColKernel::run_generic;
     }
     else
@@ -117,6 +116,12 @@
     Window win = calculate_max_window(*input->info(), Steps());
     // The CLIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
     output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
+    if(!run_img2col_reduced)
+    {
+        // Use the whole Z extent as the Z step so the window can't be split along Z: the generic path collapses Z with the higher dimensions
+        win.set_dimension_step(Window::DimZ, win[Window::DimZ].end() - win[Window::DimZ].start());
+    }
+
     ICLKernel::configure(win);
 }
 
@@ -132,14 +137,17 @@
     ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
 
     // Get initial windows
-    Window slice     = window.first_slice_window_3D();
-    Window slice_in  = window.first_slice_window_3D();
-    Window slice_out = window.first_slice_window_3D();
+    Window window_collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
+    // Restore the Z step to 1 (the configured window uses the whole Z extent as its step to prevent splitting along Z)
+    window_collapsed.set_dimension_step(Window::DimZ, 1);
+
+    Window slice     = window_collapsed.first_slice_window_3D();
+    Window slice_in  = window_collapsed.first_slice_window_3D();
+    Window slice_out = window_collapsed.first_slice_window_3D();
 
     // Setup slice
     slice.set(Window::DimX, Window::Dimension(0, static_cast<int>(_convolved_dims.first), 1));
     slice.set(Window::DimY, Window::Dimension(0, static_cast<int>(_convolved_dims.second), 1));
-    slice.set(Window::DimZ, Window::Dimension(0, static_cast<int>(_input->info()->dimension(2)), 1));
 
     // Setup input slice
     // The first three dimensions of the input are increased by the inner loops
@@ -157,13 +165,15 @@
 
     do
     {
-        // Set inputs
         unsigned int idx = 0;
         add_3D_tensor_argument(idx, _input, slice_in);
         add_2D_tensor_argument(idx, _output, slice_out);
+        _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input->info()->dimension(2)));           // filter_depth
+        _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_input->info()->strides_in_bytes()[3]));  // src_stride_w
+        _kernel.setArg<cl_uint>(idx++, static_cast<unsigned int>(_output->info()->strides_in_bytes()[3])); // dst_stride_w (NOTE(review): confirm batches live on dimension 3 of the output)
         enqueue(queue, *this, slice, _lws_hint);
     }
-    while(window.slide_window_slice_3D(slice) && window.slide_window_slice_3D(slice_out) && window.slide_window_slice_3D(slice_in));
+    while(window_collapsed.slide_window_slice_3D(slice) && window_collapsed.slide_window_slice_3D(slice_out) && window_collapsed.slide_window_slice_3D(slice_in));
 }
 
 void CLIm2ColKernel::run_reduced(const Window &window, cl::CommandQueue &queue)