COMPMID-1581: Collapse windows

Change-Id: Iec56c9a96d9736a63f13b65efa33311950f20661
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/148572
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
index 7188983..b3edc52 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
@@ -45,7 +45,7 @@
 #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
 #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
 
-#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER)
+#if defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS)
 
 #if CONV_STRIDE_X > 3
 #error "Stride X not supported"
@@ -129,18 +129,25 @@
 {
     Image    src     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
     Image    dst     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
-    Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
+    Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
+
+    // Extract channel and linearized batch indices
+    const int channel = get_global_id(2) % DST_CHANNELS;
+    const int batch   = get_global_id(2) / DST_CHANNELS;
+
 #if defined(HAS_BIAS)
     Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
 
-    int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
+    int bias_value = *((__global int *)(vector_offset(&biases, channel)));
 #endif //defined(HAS_BIAS)
 
-    src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+    // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+    src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+    __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
 
-    uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
-    uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
-    uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
+    uchar3 w0 = vload3(0, weights_addr + 0 * weights_stride_y);
+    uchar3 w1 = vload3(0, weights_addr + 1 * weights_stride_y);
+    uchar3 w2 = vload3(0, weights_addr + 2 * weights_stride_y);
 
     int8 values0 = 0;
     int8 sum0    = 0;
@@ -337,18 +344,25 @@
 {
     Image    src     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
     Image    dst     = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
-    Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
-#if defined(HAS_BIAS)
-    Vector   biases  = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+    Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
 
-    const int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
+    // Extract channel and linearized batch indices
+    const int channel = get_global_id(2) % DST_CHANNELS;
+    const int batch   = get_global_id(2) / DST_CHANNELS;
+
+#if defined(HAS_BIAS)
+    Vector    biases  = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+
+    const int bias_value = *((__global int *)(vector_offset(&biases, channel)));
 #endif //defined(HAS_BIAS)
 
-    src.ptr -= (get_global_id(2) - get_global_id(2) / DEPTH_MULTIPLIER) * src_step_z;
+    // Load relevant input and weights data (accounts for the depth multiplier when indexing the input, OFM = IFM * DEPTH_MULTIPLIER)
+    src.ptr -= batch * (DST_CHANNELS / DEPTH_MULTIPLIER) * (DEPTH_MULTIPLIER - 1) * src_step_z + (channel - (channel / DEPTH_MULTIPLIER)) * src_step_z;
+    __global uchar *weights_addr = weights.ptr + get_global_id(0) * weights_step_x + get_global_id(1) * weights_step_y + channel * weights_step_z;
 
-    uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
-    uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
-    uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
+    uchar3 w0 = vload3(0, weights_addr + 0 * weights_stride_y);
+    uchar3 w1 = vload3(0, weights_addr + 1 * weights_stride_y);
+    uchar3 w2 = vload3(0, weights_addr + 2 * weights_stride_y);
 
     uchar8 left0, middle0, right0;
     uchar8 left1, middle1, right1;
@@ -501,7 +515,7 @@
 
 #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
 
-#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) */
+#endif /* defined(CONV_STRIDE_Y) && defined(CONV_STRIDE_X) && defined(DEPTH_MULTIPLIER) && defined(DST_CHANNELS) */
 
 #if defined(VEC_SIZE) && defined(SRC_DIM_1) && defined(SRC_DIM_2) && defined(CONV_PAD_TOP) && defined(CONV_PAD_LEFT)