COMPMID-2376: (Nightly) NEON ArgMinMax QASYMM8

The maximum index value returned from the reference is computed as a uint8_t
Therefore the maximum value it returns is 255. In order to fix this bug for
tensor shapes with dimension along the reduced axis above 255, I separated the
computation of the arg min max reference so that it cleanly computes the
result as a uint32_t

Change-Id: I96a710177609d97c53ed12f20651d9737b3eb703
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1318
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 571b991..fe128cc 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -71,18 +71,6 @@
 
             switch(op)
             {
-                case ReductionOperation::ARG_IDX_MIN:
-                    if(*(ptr + stride * static_cast<uint32_t>(int_res)) > elem)
-                    {
-                        int_res = static_cast<uint32_t>(i);
-                    }
-                    break;
-                case ReductionOperation::ARG_IDX_MAX:
-                    if(*(ptr + stride * static_cast<uint32_t>(int_res)) < elem)
-                    {
-                        int_res = static_cast<uint32_t>(i);
-                    }
-                    break;
                 case ReductionOperation::MIN:
                     if(static_cast<T>(int_res) > elem)
                     {
@@ -122,18 +110,6 @@
             auto elem = *(ptr + stride * i);
             switch(op)
             {
-                case ReductionOperation::ARG_IDX_MIN:
-                    if(*(ptr + stride * static_cast<uint32_t>(res)) > elem)
-                    {
-                        res = static_cast<uint32_t>(i);
-                    }
-                    break;
-                case ReductionOperation::ARG_IDX_MAX:
-                    if(*(ptr + stride * static_cast<uint32_t>(res)) < elem)
-                    {
-                        res = static_cast<uint32_t>(i);
-                    }
-                    break;
                 case ReductionOperation::MIN:
                     if(res > elem)
                     {
@@ -167,6 +143,35 @@
     }
     return res;
 }
+
+template <typename T, typename OT>
+OT reduce_operation_arg_min_max(const T *ptr, int reduce_elements, ReductionOperation op, int stride)
+{
+    uint32_t res = 0;
+    for(int i = 0; i < reduce_elements; ++i)
+    {
+        auto elem = *(ptr + stride * i);
+        switch(op)
+        {
+            case ReductionOperation::ARG_IDX_MIN:
+                if(*(ptr + stride * res) > elem)
+                {
+                    res = static_cast<uint32_t>(i);
+                }
+                break;
+            case ReductionOperation::ARG_IDX_MAX:
+                if(*(ptr + stride * res) < elem)
+                {
+                    res = static_cast<uint32_t>(i);
+                }
+                break;
+            default:
+                ARM_COMPUTE_ERROR("Operation not supported");
+        }
+    }
+    return static_cast<OT>(res);
+}
+
 } // namespace
 
 template <typename T, typename OT>
@@ -190,7 +195,9 @@
             for(unsigned int du = 0; du < upper_dims; ++du)
             {
                 const T *src_row_ptr = src.data() + du * reduce_elems;
-                dst[du]              = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
+                dst[du]              = is_arg_min_max ?
+                                       reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, 1) :
+                                       reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, 1);
             }
         }
         break;
@@ -204,7 +211,9 @@
                     const int in_offset   = du * src_height * src_width + x;
                     const int out_offset  = du * src_width + x;
                     const T *src_row_ptr = src.data() + in_offset;
-                    dst[out_offset]       = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
+                    dst[out_offset]       = is_arg_min_max ?
+                                            reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width) :
+                                            reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width);
                 }
             }
         }
@@ -221,7 +230,9 @@
                         const int in_offset   = du * src_depth * src_height * src_width + y * src_width + x;
                         const int out_offset  = du * src_width * src_height + y * src_width + x;
                         const T *src_row_ptr = src.data() + in_offset;
-                        dst[out_offset]       = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_height * src_width);
+                        dst[out_offset]       = is_arg_min_max ?
+                                                reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height) :
+                                                reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height);
                     }
                 }
             }
@@ -241,7 +252,9 @@
                             const int in_offset   = du * src_batch * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
                             const int out_offset  = du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x;
                             const T *src_row_ptr = src.data() + in_offset;
-                            dst[out_offset]       = reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
+                            dst[out_offset]       = is_arg_min_max ?
+                                                    reduce_operation_arg_min_max<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth) :
+                                                    reduce_operation<T, OT>(src_row_ptr, reduce_elems, op, src_width * src_height * src_depth);
                         }
                     }
                 }