COMPMID-1451: Fix CL/NEPermuteKernel PermuteVection check
COMPMID-1690: Add tests for NEPermute with PermutationVector dimension > 3

Change-Id: I4bfc6ff88cd46863c2e39975b5663c624db1a63d
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/155316
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
diff --git a/src/core/CL/kernels/CLPermuteKernel.cpp b/src/core/CL/kernels/CLPermuteKernel.cpp
index c6f0f4b..a9a2c5c 100644
--- a/src/core/CL/kernels/CLPermuteKernel.cpp
+++ b/src/core/CL/kernels/CLPermuteKernel.cpp
@@ -93,17 +93,17 @@
     build_opts.emplace("-DDEPTH_IN=" + support::cpp11::to_string(input->info()->dimension(2)));
 
     // Run [2, 0, 1] permute
-    if(_perm[0] == 2 && _perm[1] == 0 && _perm[2] == 1)
+    if(_perm == PermutationVector{ 2U, 0U, 1U })
     {
         _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("permute_201", build_opts));
     }
     // Run [1, 2, 0] permute
-    else if(_perm[0] == 1 && _perm[1] == 2 && _perm[2] == 0)
+    else if(_perm == PermutationVector{ 1U, 2U, 0U })
     {
         _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("permute_120", build_opts));
     }
     // Run [3, 2, 0, 1] permute
-    else if(_perm[0] == 3 && _perm[1] == 2 && _perm[2] == 0 && _perm[3] == 1)
+    else if(_perm == PermutationVector{ 3U, 2U, 0U, 1U })
     {
         _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("permute_3201", build_opts));
     }
diff --git a/src/core/NEON/kernels/NEPermuteKernel.cpp b/src/core/NEON/kernels/NEPermuteKernel.cpp
index 8d3fd88..29e6d50 100644
--- a/src/core/NEON/kernels/NEPermuteKernel.cpp
+++ b/src/core/NEON/kernels/NEPermuteKernel.cpp
@@ -50,7 +50,8 @@
                                                          DataType::U16, DataType::S16,
                                                          DataType::U32, DataType::S32,
                                                          DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((perm.num_dimensions() == 3 && !(perm[0] == 2 && perm[1] == 0 && perm[2] == 1) && !(perm[0] == 1 && perm[1] == 2 && perm[2] == 0)),
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((perm != PermutationVector{ 2U, 0U, 1U })
+                                    && (perm != PermutationVector{ 1U, 2U, 0U }),
                                     "Only [2, 0, 1] and [1, 2, 0] permutation is supported");
 
     const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input, perm);
@@ -89,7 +90,7 @@
     Iterator out(_output, window_out);
 
     // CHW -> HWC
-    if((_perm.num_dimensions() == 3) && (_perm[0] == 2) && (_perm[1] == 0) && (_perm[2] == 1))
+    if(_perm == PermutationVector{ 2U, 0U, 1U })
     {
         const int in_row_stride     = _input->info()->strides_in_bytes().y() / sizeof(T);
         const int in_channel_stride = _input->info()->strides_in_bytes().z() / sizeof(T);
@@ -116,7 +117,7 @@
         in, out);
     }
     // HWC -> CHW
-    else if((_perm.num_dimensions() == 3) && (_perm[0] == 1) && (_perm[1] == 2) && (_perm[2] == 0))
+    else if(_perm == PermutationVector{ 1U, 2U, 0U })
     {
         const int in_col_stride   = _input->info()->strides_in_bytes().y() / sizeof(T);
         const int in_row_stride   = _input->info()->strides_in_bytes().z() / sizeof(T);
diff --git a/tests/validation/NEON/Permute.cpp b/tests/validation/NEON/Permute.cpp
index 872a16b..8c172dd 100644
--- a/tests/validation/NEON/Permute.cpp
+++ b/tests/validation/NEON/Permute.cpp
@@ -60,6 +60,8 @@
                                                                                         TensorInfo(TensorShape(1U, 7U), 1, DataType::U8),              // invalid input size
                                                                                         TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),     // valid
                                                                                         TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32),  // valid
+                                                                                        TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),     // permutation not supported
+                                                                                        TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32),  // permutation not supported
                                                                                     }),
                                                 framework::dataset::make("OutputInfo", { 
                                                                                         TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),     
@@ -68,6 +70,8 @@
                                                                                         TensorInfo(TensorShape(5U, 7U), 1, DataType::U8),
                                                                                         TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16), 
                                                                                         TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),  
+                                                                                        TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16), 
+                                                                                        TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),  
                                                                                     })),
                                                 framework::dataset::make("PermutationVector", { 
                                                                                                 PermutationVector(2U, 1U, 0U),
@@ -76,8 +80,10 @@
                                                                                                 PermutationVector(2U, 0U, 1U),
                                                                                                 PermutationVector(2U, 0U, 1U), 
                                                                                                 PermutationVector(1U, 2U, 0U),
+                                                                                                PermutationVector(3U, 2U, 0U, 1U),
+                                                                                                PermutationVector(2U, 3U, 1U, 0U)
                                                                                     })),
-                                                framework::dataset::make("Expected", { false, false, false, false, true, true })),
+                                                framework::dataset::make("Expected", { false, false, false, false, true, true, false, false })),
                                             input_info, output_info, perm_vect, expected)
 {
     ARM_COMPUTE_EXPECT(bool(NEPermute::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), perm_vect)) == expected, framework::LogLevel::ERRORS);