COMPMID-719: NEPermuteKernel refactoring

Change-Id: I91b43d9706ac3244ce43684967ace0b022d35bad
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/114988
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CPP/kernels/CPPPermuteKernel.cpp b/src/core/CPP/kernels/CPPPermuteKernel.cpp
index 80b0aba..c7bae87 100644
--- a/src/core/CPP/kernels/CPPPermuteKernel.cpp
+++ b/src/core/CPP/kernels/CPPPermuteKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017, 2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,6 +29,7 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 
 #include <cstddef>
 #include <cstdint>
@@ -37,13 +38,6 @@
 
 namespace
 {
-TensorShape get_output_shape(const ITensorInfo *input, const PermutationVector &perm)
-{
-    TensorShape output_shape = input->tensor_shape();
-    permute(output_shape, perm);
-    return output_shape;
-}
-
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QS8, DataType::QASYMM8,
@@ -57,7 +51,7 @@
                 || (perm[0] != 1 && perm[1] != 2 && perm[2] != 0))),
         "Only [2, 0, 1],[1, 2, 0] and [3, 2, 0, 1] permutation is supported");
 
-    const TensorShape output_shape = get_output_shape(input, perm);
+    const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input, perm);
 
     // Validate configured output
     if(output->total_size() != 0)
@@ -69,59 +63,77 @@
 
     return Status{};
 }
+
+/** Reorder a stride vector according to a permutation.
+ *
+ * For every i < perm.num_dimensions(), the value originally at index i is moved to index perm[i];
+ * indices that no perm[i] refers to are left unchanged. Applied to the output strides, this gives,
+ * for each input dimension k, the output byte stride that dimension k is written along. For example,
+ * [2, 0, 1] turns {stride_x, stride_y, stride_z} into {stride_y, stride_z, stride_x}.
+ *
+ * @param[in,out] dimensions Strides to reorder in place.
+ * @param[in]     perm       Permutation vector, assumed (not checked here) to be a permutation of [0, perm.num_dimensions()).
+ */
+template <typename T>
+inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
+{
+    const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
+    for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
+    {
+        dimensions[perm[i]] = old_dim[i];
+    }
+}
+
 } // namespace
 
 template <typename T>
 void CPPPermuteKernel::run_permute(const Window &window)
 {
-    const int output_stride_x = _output->info()->strides_in_bytes().x();
-    const int output_stride_y = _output->info()->strides_in_bytes().y();
-    const int output_stride_z = _output->info()->strides_in_bytes().z();
-    const int output_stride_w = _output->info()->strides_in_bytes()[3];
-
-    Window window_out(window);
-    window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
-    window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
-    window_out.set(Window::DimZ, Window::Dimension(0, 0, 0));
-    window_out.set(3, Window::Dimension(0, 0, 0));
-
+    // NOTE(review): if num_dimensions() omits trailing dimensions of size 1, a 4-element permutation
+    // applied to a tensor whose dimension 3 has size 1 would trip this check in debug builds - confirm.
+    ARM_COMPUTE_ERROR_ON(_perm.num_dimensions() > _input->info()->num_dimensions());
+
+    // perm_strides[k] is the output byte stride associated with input dimension k; dimensions the
+    // permutation does not cover keep their own output stride.
+    Strides perm_strides = _output->info()->strides_in_bytes();
+    permute_strides(perm_strides, _perm);
+
+    // The output offset is computed explicitly from the input coordinates, so the output iterator must
+    // not advance along the dimensions that offset covers. The bound is inclusive so that, when a
+    // 3-element permutation is applied to a 4D tensor, dimension 3 is fixed as well: its offset is added
+    // explicitly through perm_strides[3] below and would otherwise be counted twice.
+    Window                  window_out(window);
+    const Window::Dimension zero_window = Window::Dimension(0, 0, 0);
+    for(size_t d = 0; d <= _perm.num_dimensions(); ++d)
+    {
+        window_out.set(d, zero_window);
+    }
+
     // Create iterators
     Iterator in(_input, window);
     Iterator out(_output, window_out);
 
-    // Run [2, 0, 1] permute
-    if(_perm[0] == 2 && _perm[1] == 0 && _perm[2] == 1)
+    // Up to three dimensions: only coordinates 0-2 contribute to the output offset.
+    if(_input->info()->num_dimensions() <= 3)
     {
         execute_window_loop(window, [&](const Coordinates & id)
         {
-            const int idx                             = id[3] * output_stride_w + id.y() * output_stride_z + id.x() * output_stride_y + id.z() * output_stride_x;
+            const int idx                             = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2];
             *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
         },
         in, out);
     }
-    // Run [1, 2, 0] permute
-    else if(_perm[0] == 1 && _perm[1] == 2 && _perm[2] == 0)
+    else
     {
+        // Four or more dimensions. For a 3-element permutation, perm_strides[3] is still the output
+        // stride of dimension 3, so the same four-term offset covers that case as well.
+        ARM_COMPUTE_ERROR_ON(_perm.num_dimensions() < 3);
         execute_window_loop(window, [&](const Coordinates & id)
         {
-            const int idx                             = id[3] * output_stride_w + id.x() * output_stride_z + id.z() * output_stride_y + id.y() * output_stride_x;
+            const int idx                             = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * perm_strides[3];
             *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
         },
         in, out);
-    }
-    // Run [3, 2, 0, 1] permute
-    else if(_perm[0] == 3 && _perm[1] == 2 && _perm[2] == 0 && _perm[3] == 1)
-    {
-        execute_window_loop(window, [&](const Coordinates & id)
-        {
-            const int idx                             = id[3] * output_stride_x + id[2] * output_stride_y + id[0] * output_stride_z + id[1] * output_stride_w;
-            *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
-        },
-        in, out);
-    }
-    else
-    {
-        ARM_COMPUTE_ERROR("Not supported.");
     }
 }
 
@@ -133,7 +139,7 @@
 void CPPPermuteKernel::configure(const ITensor *input, ITensor *output, const PermutationVector &perm)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-    const TensorShape output_shape = get_output_shape(input->info(), perm);
+    const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input->info(), perm);
     // Output auto inizialitation if not yet initialized
     auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));