COMPMID-2243 ArgMinMaxLayer: support new datatypes (adds S32 to NEReductionOperationKernel)

Change-Id: I846e833e0c94090cbbdcd6aee6061cea8295f4f9
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1131
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index aa20d1f..5f0a4dd 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -41,7 +41,8 @@
 {
 namespace
 {
-uint32x4x4_t calculate_index(uint32_t idx, float32x4_t a, float32x4_t b, uint32x4x4_t c, ReductionOperation op, int axis)
+template <typename T>
+uint32x4x4_t calculate_index(uint32_t idx, T a, T b, uint32x4x4_t c, ReductionOperation op, int axis)
 {
     uint32x4_t mask{ 0 };
     if(op == ReductionOperation::ARG_IDX_MIN)
@@ -107,8 +108,8 @@
 
     return res;
 }
-
-uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, float32x4_t vec_res_value, ReductionOperation op)
+template <typename T>
+uint32_t calculate_vector_index(uint32x4x4_t vec_res_idx, T vec_res_value, ReductionOperation op)
 {
     uint32x4_t res_idx_mask{ 0 };
     uint32x4_t mask_ones = vdupq_n_u32(0xFFFFFFFF);
@@ -124,7 +125,7 @@
     {
         auto pmax    = wrapper::vpmax(wrapper::vgethigh(vec_res_value), wrapper::vgetlow(vec_res_value));
         pmax         = wrapper::vpmax(pmax, pmax);
-        auto mask    = vceqq_f32(vec_res_value, wrapper::vcombine(pmax, pmax));
+        auto mask    = wrapper::vceq(vec_res_value, wrapper::vcombine(pmax, pmax));
         res_idx_mask = wrapper::vand(vec_res_idx.val[0], mask);
     }
 
@@ -394,14 +395,14 @@
                 case ReductionOperation::ARG_IDX_MIN:
                 {
                     auto temp_vec_res_value = wrapper::vmin(vec_elements, vec_res_value);
-                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
+                    vec_res_idx             = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                     vec_res_value           = temp_vec_res_value;
                     break;
                 }
                 case ReductionOperation::ARG_IDX_MAX:
                 {
                     auto temp_vec_res_value = wrapper::vmax(vec_elements, vec_res_value);
-                    vec_res_idx             = calculate_index(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
+                    vec_res_idx             = calculate_index<decltype(vec_res_value)>(id.x(), temp_vec_res_value, vec_res_value, vec_res_idx, op, 0);
                     vec_res_value           = temp_vec_res_value;
                     break;
                 }
@@ -446,7 +447,7 @@
             case ReductionOperation::ARG_IDX_MIN:
             case ReductionOperation::ARG_IDX_MAX:
             {
-                auto res                                      = calculate_vector_index(vec_res_idx, vec_res_value, op);
+                auto res                                      = calculate_vector_index<decltype(vec_res_value)>(vec_res_idx, vec_res_value, op);
                 *(reinterpret_cast<uint32_t *>(output.ptr())) = res;
                 break;
             }
@@ -943,6 +944,8 @@
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                 case DataType::F32:
                     return Reducer<RedOpX<float, 4>>::reduceX(window, input, output, RedOpX<float, 4>(), op);
+                case DataType::S32:
+                    return Reducer<RedOpX<int32_t, 4>>::reduceX(window, input, output, RedOpX<int32_t, 4>(), op);
                 default:
                     ARM_COMPUTE_ERROR("Not supported");
             }
@@ -957,6 +960,8 @@
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                 case DataType::F32:
                     return Reducer<RedOpYZW<float, 4>>::reduceY(window, input, output, RedOpYZW<float, 4>(), op);
+                case DataType::S32:
+                    return Reducer<RedOpYZW<int32_t, 4>>::reduceY(window, input, output, RedOpYZW<int32_t, 4>(), op);
                 default:
                     ARM_COMPUTE_ERROR("Not supported");
             }
@@ -971,6 +976,8 @@
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                 case DataType::F32:
                     return Reducer<RedOpYZW<float, 4>>::reduceZ(window, input, output, RedOpYZW<float, 4>(), op);
+                case DataType::S32:
+                    return Reducer<RedOpYZW<int32_t, 4>>::reduceZ(window, input, output, RedOpYZW<int32_t, 4>(), op);
                 default:
                     ARM_COMPUTE_ERROR("Not supported");
             }
@@ -985,6 +992,8 @@
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                 case DataType::F32:
                     return Reducer<RedOpYZW<float, 4>>::reduceW(window, input, output, RedOpYZW<float, 4>(), op);
+                case DataType::S32:
+                    return Reducer<RedOpYZW<int32_t, 4>>::reduceW(window, input, output, RedOpYZW<int32_t, 4>(), op);
                 default:
                     ARM_COMPUTE_ERROR("Not supported");
             }
@@ -1002,7 +1011,7 @@
 
     if(input->num_channels() == 1)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
     }
     else
     {