COMPMID-1764 NEON: Implement ArgMax/ArgMin

Change-Id: Ibe23aa90b36ffd8553d1d1c35fada5d300fab829
Reviewed-on: https://review.mlplatform.org/475
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h
index 5f5f85c..e263b25 100644
--- a/tests/validation/fixtures/ArgMinMaxFixture.h
+++ b/tests/validation/fixtures/ArgMinMaxFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -42,28 +42,38 @@
 namespace validation
 {
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ArgMinMaxValidationFixture : public framework::Fixture
+class ArgMinMaxValidationBaseFixture : public framework::Fixture
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op)
+    void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
     {
-        _target    = compute_target(shape, data_type, axis, op);
-        _reference = compute_reference(shape, data_type, axis, op);
+        _target    = compute_target(shape, data_type, axis, op, q_info);
+        _reference = compute_reference(shape, data_type, axis, op, q_info);
     }
 
 protected:
     template <typename U>
     void fill(U &&tensor)
     {
-        std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
-        library->fill(tensor, distribution, 0);
+        if(!is_data_type_quantized(tensor.data_type()))
+        {
+            std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+            library->fill(tensor, distribution, 0);
+        }
+        else
+        {
+            std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+            // Draw as uint32_t: uint8_t is not a valid IntType for std::uniform_int_distribution (UB; rejected by MSVC).
+            std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second);
+            library->fill(tensor, distribution, 0);
+        }
     }
 
-    TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op)
+    TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
     {
         // Create tensors
-        TensorType src = create_tensor<TensorType>(src_shape, data_type, 1);
+        TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, q_info);
         TensorType dst;
 
         // Create and configure function
@@ -89,21 +99,46 @@
         return dst;
     }
 
-    SimpleTensor<T> compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op)
+    // Reference ArgMin/ArgMax result: output shape equals src_shape with dimension `axis` set to 1.
+    SimpleTensor<uint32_t> compute_reference(const TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
     {
         // Create reference
-        SimpleTensor<T> src{ src_shape, data_type, 1 };
+        SimpleTensor<T> src{ src_shape, data_type, 1, q_info };
 
         // Fill reference
         fill(src);
 
         TensorShape output_shape = src_shape;
         output_shape.set(axis, 1);
-        return reference::reduction_operation<T>(src, output_shape, axis, op);
+        return reference::reduction_operation<T, uint32_t>(src, output_shape, axis, op);
     }
 
-    TensorType      _target{};
-    SimpleTensor<T> _reference{};
+    TensorType             _target{};
+    SimpleTensor<uint32_t> _reference{};
+};
+
+/** Fixture for quantized data types: forwards the caller-supplied QuantizationInfo to the base fixture. */
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+    template <typename...>
+    void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo quantization_info)
+    {
+        ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info);
+    }
+};
+
+/** Fixture that runs the base fixture with a default-constructed QuantizationInfo (intended for non-quantized data types). */
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+    template <typename...>
+    void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op)
+    {
+        ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo());
+    }
 };
 } // namespace validation
 } // namespace test