COMPMID-1710: Fixes in StridedSlice calculations.

Change-Id: I66eb922f1ff15142de278bf4439a61c979f98ba7
Reviewed-on: https://review.mlplatform.org/382
Reviewed-by: Matthew Bentham <matthew.bentham@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
diff --git a/src/core/CL/cl_kernels/slice_ops.cl b/src/core/CL/cl_kernels/slice_ops.cl
index bc3df47..97decee 100644
--- a/src/core/CL/cl_kernels/slice_ops.cl
+++ b/src/core/CL/cl_kernels/slice_ops.cl
@@ -64,7 +64,9 @@
     int offset = 0;
 
     // Offset X
-#if defined(START_0) && defined(STRIDE_0) && defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
+#if defined(SHRINK_0)
+    input.ptr += (int)START_0 * input_stride_x;
+#elif defined(START_0) && defined(STRIDE_0) && defined(VEC_SIZE) && defined(LAST_ACCESSED_X)
     // Check if access on width gets out of bounds
     // If it does shift access vector to access elements within bounds
     const int xi = (int)(get_global_id(0) * VEC_SIZE);
@@ -77,20 +79,46 @@
 #endif // defined(START_0) && defined(STRIDE_0)
 
     // Offset Y
-#if defined(START_1) && defined(STRIDE_1)
+#if defined(SHRINK_1)
+    input.ptr += (int)START_1 * input_stride_y;
+#elif defined(START_1) && defined(STRIDE_1)
+#if defined(SHRINK_0)
+    offset = (int)START_1 + (int)get_global_id(0) * (int)STRIDE_1;
+#else  // defined(SHRINK_0)
     offset = (int)START_1 + (int)get_global_id(1) * (int)STRIDE_1;
+#endif // defined(SHRINK_0)
     input.ptr += offset * input_stride_y;
 #endif // defined(START_1) && defined(STRIDE_1)
 
     // Offset Z
-#if defined(START_2) && defined(STRIDE_2)
+#if defined(SHRINK_2)
+    input.ptr += (int)START_2 * input_stride_z;
+#elif defined(START_2) && defined(STRIDE_2)
+
+#if defined(SHRINK_1) && defined(SHRINK_0)
+    offset = (int)START_2 + (int)get_global_id(0) * (int)STRIDE_2;
+#elif defined(SHRINK_1) || defined(SHRINK_0)
+    offset = (int)START_2 + (int)get_global_id(1) * (int)STRIDE_2;
+#else  // defined(SHRINK_1) && defined(SHRINK_0)
     offset = (int)START_2 + ((int)get_global_id(2) % (int)DST_DEPTH) * (int)STRIDE_2;
+#endif // defined(SHRINK_1) && defined(SHRINK_0)
+
     input.ptr += offset * input_stride_z;
 #endif // defined(START_2) && defined(STRIDE_2)
 
     // Offset depth
-#if defined(START_3) && defined(STRIDE_3)
+#if defined(SHRINK_3)
+    input.ptr += (int)START_3 * input_stride_w;
+#elif defined(START_3) && defined(STRIDE_3)
+#if defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
+    offset = (int)START_3 + (int)get_global_id(0) * (int)STRIDE_3;
+#elif !defined(SHRINK_2) && !defined(SHRINK_1) && !defined(SHRINK_0)
     offset = (int)START_3 + ((int)get_global_id(2) / (int)DST_DEPTH) * (int)STRIDE_3;
+#elif(defined(SHRINK_0) && defined(SHRINK_1)) || (defined(SHRINK_1) && defined(SHRINK_2)) || (defined(SHRINK_0) && defined(SHRINK_2))
+    offset = (int)START_3 + (int)get_global_id(1) * (int)STRIDE_3;
+#else  // defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
+    offset = (int)START_3 + ((int)get_global_id(2) % (int)DST_DEPTH) * (int)STRIDE_3;
+#endif // defined(SHRINK_2) && defined(SHRINK_1) && defined(SHRINK_0)
     input.ptr += offset * input_stride_w;
 #endif // defined(START_3) && defined(STRIDE_3)
 
diff --git a/src/core/CL/kernels/CLStridedSliceKernel.cpp b/src/core/CL/kernels/CLStridedSliceKernel.cpp
index 3828a48..c40f3c9 100644
--- a/src/core/CL/kernels/CLStridedSliceKernel.cpp
+++ b/src/core/CL/kernels/CLStridedSliceKernel.cpp
@@ -32,6 +32,7 @@
 #include "arm_compute/core/Window.h"
 
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/helpers/bit_ops.h"
 #include "arm_compute/core/utils/helpers/tensor_transform.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 
@@ -114,9 +115,11 @@
 
     const TensorShape &input_shape = input->info()->tensor_shape();
 
-    const Coordinates final_strides = arm_compute::helpers::tensor_transform::strided_slice_strides(input_shape, strides);
-    const Coordinates starts_abs    = arm_compute::helpers::tensor_transform::strided_slice_absolute_start_coords(input_shape, starts, final_strides, begin_mask);
-    const Coordinates ends_abs      = arm_compute::helpers::tensor_transform::strided_slice_absolute_end_coords(input_shape, starts_abs, ends, final_strides, end_mask, shrink_axis_mask);
+    Coordinates starts_abs, ends_abs, final_strides;
+    std::tie(starts_abs, ends_abs, final_strides) = arm_compute::helpers::tensor_transform::calculate_strided_slice_coords(
+                                                        input_shape,
+                                                        starts, ends, strides,
+                                                        begin_mask, end_mask, shrink_axis_mask);
 
     // Configure kernel window
     auto win_config = validate_and_configure_window(input->info(), output->info(), starts, ends, strides, begin_mask, end_mask, shrink_axis_mask);
@@ -125,7 +128,8 @@
     // Enable multiple elements processing along x if stride_x is 1 and output width greater than the access vector size
     const int  vec_size_x     = 16 / input->info()->element_size();
     const int  output_width_x = output->info()->tensor_shape().x();
-    const bool multi_access_x = (final_strides.x() == 1) && (output_width_x / vec_size_x > 0);
+    const bool is_shrink_on_x = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, 0);
+    const bool multi_access_x = !is_shrink_on_x && (final_strides.x() == 1) && (output_width_x / vec_size_x > 0);
 
     // Update window if needed
     if(multi_access_x)
@@ -141,8 +145,10 @@
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
     for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
     {
+        const bool is_shrink = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, i);
         build_opts.add_option("-DSTART_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(starts_abs[i]));
         build_opts.add_option("-DSTRIDE_" + support::cpp11::to_string(i) + "=" + support::cpp11::to_string(final_strides[i]));
+        build_opts.add_option_if(is_shrink, "-DSHRINK_" + support::cpp11::to_string(i));
     }
     build_opts.add_option_if(multi_access_x, "-DLAST_ACCESSED_X=" + support::cpp11::to_string(std::max<int>(output_width_x - vec_size_x, 0)));
     build_opts.add_option_if(multi_access_x, "-DVEC_SIZE=" + support::cpp11::to_string(vec_size_x));
diff --git a/src/core/utils/helpers/tensor_transform.cpp b/src/core/utils/helpers/tensor_transform.cpp
index a4bce5d..08803c7 100644
--- a/src/core/utils/helpers/tensor_transform.cpp
+++ b/src/core/utils/helpers/tensor_transform.cpp
@@ -23,13 +23,143 @@
  */
 #include "arm_compute/core/utils/helpers/tensor_transform.h"
 
+#include "arm_compute/core/utils/helpers/bit_ops.h"
+
 namespace arm_compute
 {
 namespace helpers
 {
 namespace tensor_transform
 {
-Coordinates slice_absolute_end_coords(TensorShape input_shape, Coordinates ends)
+int calculate_stride_on_index(int index, Coordinates strides)
+{
+    return index >= static_cast<int>(strides.num_dimensions()) ? 1 : strides[index]; // Dimensions past the end of strides get a unit stride
+}
+
+int calculate_start_on_index(TensorShape input_shape, int index, Coordinates starts, Coordinates strides, int32_t begin_mask)
+{
+    // Early exit: dimensions past the end of starts begin at 0
+    if(index >= static_cast<int>(starts.num_dimensions()))
+    {
+        return 0;
+    }
+
+    // Get stride (its sign decides which extreme a masked start is reset to)
+    const int stride = calculate_stride_on_index(index, strides);
+
+    // Get requested start
+    int start = starts[index];
+
+    // Begin mask bit set for this dimension: discard the requested start and push it to the extreme matching the stride direction
+    if(arm_compute::helpers::bit_ops::is_bit_set(begin_mask, index))
+    {
+        start = stride > 0 ? std::numeric_limits<int>::lowest() : std::numeric_limits<int>::max();
+    }
+
+    // Account negative start points (offset them by the dimension size)
+    const int dim_size = input_shape[index];
+    if(start < 0)
+    {
+        start += dim_size;
+    }
+
+    // Final clamp with bounds 0 and dim_size - 1
+    start = utility::clamp(start, 0, dim_size - 1);
+
+    return start;
+}
+
+int calculate_end_on_index(TensorShape input_shape, int index, int start_on_index,
+                           Coordinates ends, Coordinates strides,
+                           int32_t end_mask, int32_t shrink_axis_mask)
+{
+    // Early exit: dimensions past the end of ends run to the full extent of the input
+    if(index >= static_cast<int>(ends.num_dimensions()))
+    {
+        return input_shape[index];
+    }
+
+    const int  stride      = calculate_stride_on_index(index, strides);
+    const bool shrink_axis = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, index);
+
+    // Get requested stop
+    int stop = ends[index];
+
+    // Shrink dimension: the stop is placed one past the already-resolved start
+    if(shrink_axis)
+    {
+        stop = start_on_index + 1;
+    }
+
+    // Reset in case of end mask present (skipped on a shrunk axis, whose stop was fixed above)
+    if(arm_compute::helpers::bit_ops::is_bit_set(end_mask, index) && !shrink_axis)
+    {
+        stop = (stride > 0) ? std::numeric_limits<int>::max() : std::numeric_limits<int>::lowest();
+    }
+
+    // Account negative end points (offset them by the dimension size)
+    const int dim_size = input_shape[index];
+    if(stop < 0)
+    {
+        stop += dim_size;
+    }
+
+    // Final clamp: bounds 0 and dim_size for positive strides, -1 and dim_size - 1 otherwise
+    stop = (stride > 0) ? utility::clamp(stop, 0, dim_size) : utility::clamp(stop, -1, dim_size - 1);
+
+    return stop;
+}
+
+std::tuple<Coordinates, Coordinates, Coordinates> calculate_strided_slice_coords(TensorShape input_shape,
+                                                                                 Coordinates starts, Coordinates ends, Coordinates strides,
+                                                                                 int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask)
+{
+    Coordinates starts_abs, ends_abs, final_strides;
+    for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i) // One entry per input dimension, even past the end of starts/ends/strides
+    {
+        const int start_i = calculate_start_on_index(input_shape, i, starts, strides, begin_mask); // Resolved first: the end calculation takes it as an argument
+        starts_abs.set(i, start_i);
+        ends_abs.set(i, calculate_end_on_index(input_shape, i, start_i, ends, strides, end_mask, shrink_axis_mask));
+        final_strides.set(i, calculate_stride_on_index(i, strides));
+    }
+
+    return std::make_tuple(starts_abs, ends_abs, final_strides); // Order: starts, ends, strides
+}
+
+TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends, Coordinates strides,
+                                               int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask, bool return_unshrinked)
+{
+    unsigned int index = 0; // Next output dimension to write; lags behind i once a shrunk axis has been skipped
+
+    TensorShape output_shape;
+    for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
+    {
+        const int stride = calculate_stride_on_index(i, strides); // Looked up by input dimension i, like start and end below (not by the output counter)
+        const int start  = calculate_start_on_index(input_shape, i, starts, strides, begin_mask);
+        const int end    = calculate_end_on_index(input_shape, i, start, ends, strides, end_mask, shrink_axis_mask);
+        const int range  = end - start;
+
+        const bool is_shrink = arm_compute::helpers::bit_ops::is_bit_set(shrink_axis_mask, i);
+        if(return_unshrinked || !is_shrink) // Shrunk axes are only emitted when return_unshrinked is requested
+        {
+            if((range == 0) ||               // Zero range
+               (range < 0 && stride >= 0) || // Negative range with positive stride
+               (range > 0 && stride <= 0))   // Positive range with negative stride
+            {
+                output_shape.set(index, 0); // Empty slice: zero this dimension and return immediately
+                return output_shape;
+            }
+            else
+            {
+                int dim = range / stride + (range % stride != 0 ? 1 : 0); // range / stride, rounded up when there is a remainder
+                output_shape.set(index++, dim);
+            }
+        }
+    }
+    return output_shape;
+}
+
+int32_t construct_slice_end_mask(Coordinates ends)
 {
     // Create end mask
     int32_t end_mask = 0;
@@ -40,126 +170,8 @@
             end_mask |= 1 << i;
         }
     }
-    // Get unit strides
-    const BiStrides unit_strides = strided_slice_strides(input_shape, BiStrides());
 
-    return strided_slice_absolute_end_coords(input_shape, Coordinates(), ends, unit_strides, end_mask);
-}
-
-TensorShape compute_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends_abs)
-{
-    // Get unit strides
-    const BiStrides unit_strides = strided_slice_strides(input_shape, BiStrides());
-    return compute_strided_slice_output_shape(input_shape, starts, ends_abs, unit_strides);
-}
-
-Coordinates strided_slice_absolute_start_coords(TensorShape input_shape, Coordinates starts, Coordinates strides, int32_t begin_mask)
-{
-    Coordinates starts_abs;
-    for(unsigned int i = 0; i < starts.num_dimensions(); ++i)
-    {
-        // Get start index
-        int start_i = starts[i];
-
-        // Reset in case of begin mask present
-        if((begin_mask & 1 << i) != 0)
-        {
-            start_i = strides[i] > 0 ? std::numeric_limits<int>::lowest() : std::numeric_limits<int>::max();
-        }
-
-        // Account negative start points
-        const int dim_size = input_shape[i];
-        if(start_i < 0)
-        {
-            start_i += dim_size;
-        }
-
-        // Final clamp
-        start_i = utility::clamp(start_i, 0, dim_size - 1);
-        starts_abs.set(i, start_i);
-    }
-
-    // Fill remaining
-    for(unsigned int i = starts_abs.num_dimensions(); i < input_shape.num_dimensions(); ++i)
-    {
-        starts_abs.set(i, 0);
-    }
-
-    return starts_abs;
-}
-
-Coordinates strided_slice_absolute_end_coords(TensorShape input_shape, Coordinates starts_abs, Coordinates ends, Coordinates strides,
-                                              int32_t end_mask, int32_t shrink_axis_mask)
-{
-    Coordinates ends_abs;
-    for(unsigned int i = 0; i < ends.num_dimensions(); ++i)
-    {
-        // Get end index
-        int stop_i = ends[i];
-
-        // Shrink dimension
-        if((shrink_axis_mask & (1 << i)) != 0)
-        {
-            stop_i = starts_abs[i] + 1;
-        }
-
-        // Reset in case of begin mask present
-        if((end_mask & 1 << i) != 0)
-        {
-            stop_i = (strides[i] > 0) ? std::numeric_limits<int>::max() : std::numeric_limits<int>::lowest();
-        }
-
-        // Account negative end points
-        const int dim_size = input_shape[i];
-        if(stop_i < 0)
-        {
-            stop_i += dim_size;
-        }
-
-        // Final clamp
-        stop_i = (strides[i] > 0) ? utility::clamp(stop_i, 0, dim_size) : utility::clamp(stop_i, -1, dim_size - 1);
-        ends_abs.set(i, stop_i);
-    }
-
-    // Fill remaining ends
-    for(unsigned int i = ends_abs.num_dimensions(); i < input_shape.num_dimensions(); ++i)
-    {
-        ends_abs.set(i, input_shape[i]);
-    }
-
-    return ends_abs;
-}
-
-Coordinates strided_slice_strides(TensorShape input_shape, Coordinates strides)
-{
-    for(unsigned int i = strides.num_dimensions(); i < input_shape.num_dimensions(); ++i)
-    {
-        strides.set(i, 1);
-    }
-    return strides;
-}
-
-TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts_abs, Coordinates ends_abs, Coordinates final_strides)
-{
-    TensorShape output_shape = input_shape;
-    for(unsigned int i = 0; i < input_shape.num_dimensions(); ++i)
-    {
-        const int stride_i = final_strides[i];
-        const int range    = ends_abs[i] - starts_abs[i];
-        if((range == 0) ||                 // Zero range
-           (range < 0 && stride_i >= 0) || // Negative range with positive stride
-           (range > 0 && stride_i <= 0))   // Positive range with negative stride
-        {
-            output_shape.set(i, 0);
-            return output_shape;
-        }
-        else
-        {
-            int dim = range / stride_i + (range % stride_i != 0 ? 1 : 0);
-            output_shape.set(i, dim);
-        }
-    }
-    return output_shape;
+    return end_mask;
 }
 } // namespace tensor_transform
 } // namespace helpers
diff --git a/src/graph/nodes/SliceLayerNode.cpp b/src/graph/nodes/SliceLayerNode.cpp
index 3a29e4c..bfc009d 100644
--- a/src/graph/nodes/SliceLayerNode.cpp
+++ b/src/graph/nodes/SliceLayerNode.cpp
@@ -24,7 +24,7 @@
 #include "arm_compute/graph/nodes/SliceLayerNode.h"
 
 #include "arm_compute/core/Utils.h"
-#include "arm_compute/core/utils/helpers/tensor_transform.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/graph/Graph.h"
 #include "arm_compute/graph/INodeVisitor.h"
 
@@ -52,16 +52,12 @@
 TensorDescriptor SliceLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
                                                            const Coordinates &starts, const Coordinates &ends)
 {
-    // Get absolute end coordinates
-    const Coordinates ends_abs = arm_compute::helpers::tensor_transform::slice_absolute_end_coords(input_descriptor.shape, ends);
+    using namespace arm_compute::helpers::tensor_transform; // NOTE(review): every name used below is fully qualified - confirm this directive is still needed
 
-    TensorDescriptor output_descriptor = input_descriptor;
-    for(unsigned int i = 0; i < starts.num_dimensions(); ++i)
-    {
-        output_descriptor.shape.set(i, ends_abs[i] - starts[i]);
-    }
+    TensorDescriptor output_desc = input_descriptor; // Start from a copy of the input descriptor; only its shape is replaced below
+    output_desc.shape            = arm_compute::misc::shape_calculator::compute_slice_shape(input_descriptor.shape, starts, ends);
 
-    return output_descriptor;
+    return output_desc;
 }
 
 bool SliceLayerNode::forward_descriptors()
diff --git a/src/runtime/CL/functions/CLSlice.cpp b/src/runtime/CL/functions/CLSlice.cpp
index bef7eca..f630853 100644
--- a/src/runtime/CL/functions/CLSlice.cpp
+++ b/src/runtime/CL/functions/CLSlice.cpp
@@ -36,10 +36,10 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(input);
 
     // Get absolute end coordinates
-    const Coordinates ends_abs = arm_compute::helpers::tensor_transform::slice_absolute_end_coords(input->info()->tensor_shape(), ends);
+    const int32_t slice_end_mask = arm_compute::helpers::tensor_transform::construct_slice_end_mask(ends);
 
     auto k = arm_compute::support::cpp14::make_unique<CLStridedSliceKernel>();
-    k->configure(input, output, starts, ends_abs, BiStrides(), 0, 0, 0);
+    k->configure(input, output, starts, ends, BiStrides(), 0, slice_end_mask, 0);
     _kernel = std::move(k);
 }
 
@@ -54,8 +54,8 @@
     }));
 
     // Get absolute end coordinates
-    const Coordinates ends_abs = arm_compute::helpers::tensor_transform::slice_absolute_end_coords(input->tensor_shape(), ends);
+    const int32_t slice_end_mask = arm_compute::helpers::tensor_transform::construct_slice_end_mask(ends);
 
-    return CLStridedSliceKernel::validate(input, output, starts, ends_abs, BiStrides(), 0, 0, 0);
+    return CLStridedSliceKernel::validate(input, output, starts, ends, BiStrides(), 0, slice_end_mask, 0);
 }
 } // namespace arm_compute