COMPMID-1764 NEON: Implement ArgMax/ArgMin

Change-Id: Ibe23aa90b36ffd8553d1d1c35fada5d300fab829
Reviewed-on: https://review.mlplatform.org/475
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
diff --git a/tests/validation/reference/L2NormalizeLayer.cpp b/tests/validation/reference/L2NormalizeLayer.cpp
index fcd6226..43885b2 100644
--- a/tests/validation/reference/L2NormalizeLayer.cpp
+++ b/tests/validation/reference/L2NormalizeLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -54,7 +54,7 @@
     SimpleTensor<T> dst{ src.shape(), src.data_type() };
 
     // Reduce across given axis
-    SimpleTensor<T> sum = reduction_operation(src, get_output_shape(src.shape(), axis), axis, ReductionOperation::SUM_SQUARE);
+    SimpleTensor<T> sum = reduction_operation<T, T>(src, get_output_shape(src.shape(), axis), axis, ReductionOperation::SUM_SQUARE);
 
     // Compute reference
     const int upper_dims     = src.shape().total_size_upper(axis + 1);
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 37a9be8..fc12e31 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -49,20 +49,20 @@
         uint32_t int_res = 0;
         for(int i = 0; i < reduce_elements; ++i)
         {
-            auto elem = static_cast<uint32_t>(*(ptr + stride * i));
+            auto elem = *(ptr + stride * i);
 
             switch(op)
             {
                 case ReductionOperation::ARG_IDX_MIN:
-                    if(static_cast<uint32_t>(*(ptr + stride * static_cast<uint32_t>(res))) > elem)
+                    if(*(ptr + stride * static_cast<uint32_t>(int_res)) > elem)
                     {
-                        res = static_cast<uint32_t>(i);
+                        int_res = static_cast<uint32_t>(i);
                     }
                     break;
                 case ReductionOperation::ARG_IDX_MAX:
-                    if(static_cast<uint32_t>(*(ptr + stride * static_cast<uint32_t>(res))) < elem)
+                    if(*(ptr + stride * static_cast<uint32_t>(int_res)) < elem)
                     {
-                        res = static_cast<uint32_t>(i);
+                        int_res = static_cast<uint32_t>(i);
                     }
                     break;
                 case ReductionOperation::SUM_SQUARE:
@@ -122,13 +122,13 @@
 }
 } // namespace
 
-template <typename T>
-SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
+template <typename T, typename OT>
+SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
 {
     // Create reference
     const bool         is_arg_min_max   = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
     DataType           output_data_type = is_arg_min_max ? DataType::U32 : src.data_type();
-    SimpleTensor<T>    dst{ dst_shape, output_data_type, 1, src.quantization_info() };
+    SimpleTensor<OT>   dst{ dst_shape, output_data_type, 1, src.quantization_info() };
     const unsigned int src_width    = src.shape().x();
     const unsigned int src_height   = src.shape().y();
     const unsigned int src_depth    = src.shape().z();
@@ -143,14 +143,7 @@
             for(unsigned int du = 0; du < upper_dims; ++du)
             {
                 const T *src_row_ptr = src.data() + du * reduce_elems;
-                if(is_arg_min_max)
-                {
-                    dst[du] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, 1);
-                }
-                else
-                {
-                    dst[du] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, 1);
-                }
+                dst[du]              = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
             }
         }
         break;
@@ -164,15 +157,7 @@
                     const int in_offset   = du * src_height * src_width + x;
                     const int out_offset  = du * src_width + x;
                     const T *src_row_ptr = src.data() + in_offset;
-
-                    if(is_arg_min_max)
-                    {
-                        dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_width);
-                    }
-                    else
-                    {
-                        dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_width);
-                    }
+                    dst[out_offset]       = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
                 }
             }
         }
@@ -189,15 +174,7 @@
                         const int in_offset   = du * src_depth * src_height * src_width + y * src_width + x;
                         const int out_offset  = du * src_width * src_height + y * src_width + x;
                         const T *src_row_ptr = src.data() + in_offset;
-
-                        if(is_arg_min_max)
-                        {
-                            dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_height * src_width);
-                        }
-                        else
-                        {
-                            dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_height * src_width);
-                        }
+                        dst[out_offset]       = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_height * src_width);
                     }
                 }
             }
@@ -217,14 +194,7 @@
                             const int in_offset   = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
                             const int out_offset  = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
                             const T *src_row_ptr = src.data() + in_offset;
-                            if(is_arg_min_max)
-                            {
-                                dst[out_offset] = reduce_operation<T, uint32_t>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
-                            }
-                            else
-                            {
-                                dst[out_offset] = reduce_operation<T, T>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
-                            }
+                            dst[out_offset]       = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
                         }
                     }
                 }
@@ -238,6 +208,9 @@
     return dst;
 }
 
+template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+template SimpleTensor<uint32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
 template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
 template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
 template SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
diff --git a/tests/validation/reference/ReductionOperation.h b/tests/validation/reference/ReductionOperation.h
index 859b57a..9f7050f 100644
--- a/tests/validation/reference/ReductionOperation.h
+++ b/tests/validation/reference/ReductionOperation.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,10 +35,10 @@
 {
 namespace reference
 {
-template <typename T>
-SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
+template <typename T, typename OT>
+SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op);
 } // namespace reference
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
-#endif /* __ARM_COMPUTE_TEST_FLOOR_H__ */
+#endif /* __ARM_COMPUTE_TEST_REDUCTION_OPERATION_H__ */