COMPMID-1710: Fixes in StridedSlice calculations.

Change-Id: I66eb922f1ff15142de278bf4439a61c979f98ba7
Reviewed-on: https://review.mlplatform.org/382
Reviewed-by: Matthew Bentham <matthew.bentham@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
diff --git a/src/core/CL/cl_kernels/slice_ops.cl b/src/core/CL/cl_kernels/slice_ops.cl
index bc3df47..97decee 100644
--- a/src/core/CL/cl_kernels/slice_ops.cl
+++ b/src/core/CL/cl_kernels/slice_ops.cl
@@ -64,7 +64,9 @@
     int offset = 0;
 
     // Offset X
-#if defined(START_0) && defined(STRIDE_0) && defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
+#if defined(SHRINK_0)
+    input.ptr += (int)START_0 * input_stride_x;
+#elif defined(START_0) && defined(STRIDE_0) && defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
     // Check if access on width gets out of bounds
     // If it does shift access vector to access elements within bounds
     const int xi = (int)(get_global_id(0) * VEC_SIZE);
@@ -77,20 +79,46 @@
 #endif // defined(START_0) && defined(STRIDE_0)
 
     // Offset Y
-#if defined(START_1) && defined(STRIDE_1)
+#if defined(SHRINK_1)
+    input.ptr += (int)START_1 * input_stride_y;
+#elif defined(START_1) && defined(STRIDE_1)
+#if defined(SHRINK_0)
+    offset = (int)START_1 + (int)get_global_id(0) * (int)STRIDE_1;
+#else  // defined(SHRINK_0)
     offset = (int)START_1 + (int)get_global_id(1) * (int)STRIDE_1;
+#endif // defined(SHRINK_0)
     input.ptr += offset * input_stride_y;
 #endif // defined(START_1) && defined(STRIDE_1)
 
     // Offset Z
-#if defined(START_2) && defined(STRIDE_2)
+#if defined(SHRINK_2)
+    input.ptr += (int)START_2 * input_stride_z;
+#elif defined(START_2) && defined(STRIDE_2)
+
+#if defined(SHRINK_1) && defined(SHRINK_0)
+    offset = (int)START_2 + (int)get_global_id(0) * (int)STRIDE_2;
+#elif defined(SHRINK_1) || defined(SHRINK_0)
+    offset = (int)START_2 + (int)get_global_id(1) * (int)STRIDE_2;
+#else  // defined(SHRINK_1) && defined(SHRINK_0)
     offset = (int)START_2 + ((int)get_global_id(2) % (int)DST_DEPTH) * (int)STRIDE_2;
+#endif // defined(SHRINK_1) && defined(SHRINK_0)
+
     input.ptr += offset * input_stride_z;
 #endif // defined(START_2) && defined(STRIDE_2)
 
     // Offset depth
-#if defined(START_3) && defined(STRIDE_3)
+#if defined(SHRINK_3)
+    input.ptr += (int)START_3 * input_stride_w;
+#elif defined(START_3) && defined(STRIDE_3)
+#if defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
+    offset = (int)START_3 + (int)get_global_id(0) * (int)STRIDE_3;
+#elif !defined(SHRINK_2) && !defined(SHRINK_1) && !defined(SHRINK_0)
     offset = (int)START_3 + ((int)get_global_id(2) / (int)DST_DEPTH) * (int)STRIDE_3;
+#elif(defined(SHRINK_0) && defined(SHRINK_1)) || (defined(SHRINK_1) && defined(SHRINK_2)) || (defined(SHRINK_0) && defined(SHRINK_2))
+    offset = (int)START_3 + (int)get_global_id(1) * (int)STRIDE_3;
+#else  // defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
+    offset = (int)START_3 + ((int)get_global_id(2) % (int)DST_DEPTH) * (int)STRIDE_3;
+#endif // defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
     input.ptr += offset * input_stride_w;
 #endif // defined(START_3) && defined(STRIDE_3)
 
diff --git a/src/core/CL/kernels/CLStridedSliceKernel.cpp b/src/core/CL/kernels/CLStridedSliceKernel.cpp
index 3828a48..c40f3c9 100644
--- a/src/core/CL/kernels/CLStridedSliceKernel.cpp
+++ b/src/core/CL/kernels/CLStridedSliceKernel.cpp
@@ -32,6 +32,7 @@
 #include "arm_compute/core/Window.h"
 
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/helpers/bit_ops.h"
 #include "arm_compute/core/utils/helpers/tensor_transform.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 
@@ -114,9 +115,11 @@
 
     const TensorShape &input_shape = input->info()->tensor_shape();
 
-    const Coordinates final_strides = arm_compute::helpers::tensor_transform::strided_slice_strides(input_shape, strides);
-    const Coordinates starts_abs    = arm_compute::helpers::tensor_transform::strided_slice_absolute_start_coords(input_shape, starts, final_strides, begin_mask);
-    const Coordinates ends_abs      = arm_compute::helpers::tensor_transform::strided_slice_absolute_end_coords(input_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask);
+    Coordinates starts_abs, ends_abs, final_strides;
+    std::tie(starts_abs, ends_abs, final_strides) = arm_compute::helpers::tensor_transform::calculate_strided_slice_coords(
+                                                        input_shape,
+                                                        starts, ends, strides,
+                                                        begin_mask, end_mask, shrink_axis_mask);
 
     // Configure kernel window
     auto win_config = validate_and_configure_window(input->info(), output->info(), starts, ends, strides, begin_mask, end_mask, shrink_axis_mask);
@@ -125,7 +128,8 @@
     // Enable multiple elements processing along x if stride_x is 1 and output width greater than the access vector size
     const int  vec_size_x     = 16 / input->info()->element_size();
     const int  output_width_x = output->info()->tensor_shape().x();
-    const bool multi_access_x = (final_strides.x() == 1) && (output_width_x / vec_size_x > 0);
+    const bool is_shrink_on_x = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 0);
+    const bool multi_access_x = !is_shrink_on_x && (final_strides.x() == 1) && (output_width_x / vec_size_x > 0);
 
     // Update window if needed
     if(multi_access_x)
@@ -141,8 +145,10 @@
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
     for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
     {
+        const bool is_shrink = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, i);
         build_opts.add_option("-DSTART_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(starts_abs[i]));
         build_opts.add_option("-DSTRIDE_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(final_strides[i]));
+        build_opts.add_option_if(is_shrink, "-DSHRINK_" + support::cpp11::to_string(i));
     }
     build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0)));
     build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x));