Improved testing for ArgMinMax

* ArgMinMax output was hard-coded to S32; this patch makes the changes required
  to allow other output types such as U64/S64.

* Made changes to the ArgMinMax fixture and tests to allow specifying the output data type.

* Made changes to the reference reduction_operation to allow specifying the output
  type (passed as an extra output_type argument).

* Added test cases with S64 output for the CL backend (FP32 RunSmall_F32_S64).

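* Added the missing RunSmallAxis0 test case to the Neon backend and moved the Neon
  tests onto shared Axis/Operation/QuantizationInfo datasets. Note: the Neon
  QASYMM8_SIGNED tests now use QuantizationInfo(5/255, 20) instead of (5/127, 20).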

* Partially resolves MLCE-1089

Change-Id: I6f1cbc7093669d12c2a3aff6974cf19d83b2ecda
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10003
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/CL/ArgMinMax.cpp b/tests/validation/CL/ArgMinMax.cpp
index 9bfd9d9..8566972 100644
--- a/tests/validation/CL/ArgMinMax.cpp
+++ b/tests/validation/CL/ArgMinMax.cpp
@@ -63,6 +63,10 @@
     TensorShape{ 15U, 2U },
 });
 
+const auto OpsDataset   = framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX });
+const auto AxisDataset  = framework::dataset::make("Axis", { 0, 1, 2, 3 });
+const auto QInfoDataset = framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) });
+
 const auto ArgMinMaxLargeDataset = framework::dataset::make("Shape",
 { TensorShape{ 517U, 123U, 13U, 2U } });
 } // namespace
@@ -95,57 +99,78 @@
 // clang-format on
 // *INDENT-ON*
 
-template <typename T>
-using CLArgMinMaxValidationFixture = ArgMinMaxValidationFixture<CLTensor, CLAccessor, CLArgMinMaxLayer, T>;
+template <typename T1, typename T2>
+using CLArgMinMaxValidationFixture = ArgMinMaxValidationFixture<CLTensor, CLAccessor, CLArgMinMaxLayer, T1, T2>;
+
+using CLArgMinMaxValidationFixture_S32_S32 = CLArgMinMaxValidationFixture<int32_t, int32_t>;
+using CLArgMinMaxValidationFixture_F16_S32 = CLArgMinMaxValidationFixture<half, int32_t>;
+using CLArgMinMaxValidationFixture_F32_S32 = CLArgMinMaxValidationFixture<float, int32_t>;
+using CLArgMinMaxValidationFixture_F32_S64 = CLArgMinMaxValidationFixture<float, int64_t>;
 
 TEST_SUITE(S32)
 FIXTURE_DATA_TEST_CASE(RunSmallAxis0,
-                       CLArgMinMaxValidationFixture<int32_t>,
+                       CLArgMinMaxValidationFixture_S32_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(ArgMinMaxSmallDatasetAxis0, framework::dataset::make("DataType", DataType::S32)), framework::dataset::make("Axis", { 0 })),
-                               framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxSmallDatasetAxis0,
+                                                       framework::dataset::make("DataTypeIn", DataType::S32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       framework::dataset::make("Axis", { 0 })),
+                               OpsDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       CLArgMinMaxValidationFixture<int32_t>,
+                       CLArgMinMaxValidationFixture_S32_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(ArgMinMaxSmallDataset, framework::dataset::make("DataType", DataType::S32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                               framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxSmallDataset,
+                                                       framework::dataset::make("DataTypeIn", DataType::S32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       CLArgMinMaxValidationFixture<int32_t>,
+                       CLArgMinMaxValidationFixture_S32_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(ArgMinMaxLargeDataset, framework::dataset::make("DataType", DataType::S32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                               framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxLargeDataset,
+                                                       framework::dataset::make("DataTypeIn", DataType::S32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
+
 TEST_SUITE_END() // S32
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       CLArgMinMaxValidationFixture<half>,
+                       CLArgMinMaxValidationFixture_F16_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(ArgMinMaxSmallDataset, framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                               framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxSmallDataset,
+                                                       framework::dataset::make("DataTypeIn", DataType::F16)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       CLArgMinMaxValidationFixture<half>,
+                       CLArgMinMaxValidationFixture_F16_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(ArgMinMaxLargeDataset, framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                               framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxLargeDataset,
+                                                       framework::dataset::make("DataTypeIn", DataType::F16)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -154,49 +179,77 @@
 
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       CLArgMinMaxValidationFixture<float>,
+                       CLArgMinMaxValidationFixture_F32_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(ArgMinMaxSmallDataset, framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                               framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxSmallDataset,
+                                                       framework::dataset::make("DataTypeIn", DataType::F32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall_F32_S64,
+                       CLArgMinMaxValidationFixture_F32_S64,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(ArgMinMaxSmallDataset,
+                                                       framework::dataset::make("DataTypeIn", DataType::F32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S64)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       CLArgMinMaxValidationFixture<float>,
+                       CLArgMinMaxValidationFixture_F32_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(ArgMinMaxLargeDataset, framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                               framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxLargeDataset,
+                                                       framework::dataset::make("DataTypeIn", DataType::F32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
+
 TEST_SUITE_END() // FP32
 TEST_SUITE_END() // Float
 
-template <typename T>
-using CLArgMinMaxQuantizedValidationFixture = ArgMinMaxValidationQuantizedFixture<CLTensor, CLAccessor, CLArgMinMaxLayer, T>;
+template <typename T1, typename T2>
+using CLArgMinMaxQuantizedValidationFixture = ArgMinMaxValidationQuantizedFixture<CLTensor, CLAccessor, CLArgMinMaxLayer, T1, T2>;
+
+using CLArgMinMaxQuantizedValidationFixture_U8_S32 = CLArgMinMaxQuantizedValidationFixture<uint8_t, int32_t>;
+using CLArgMinMaxQuantizedValidationFixture_S8_S32 = CLArgMinMaxQuantizedValidationFixture<int8_t, int32_t>;
 
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8)
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       CLArgMinMaxQuantizedValidationFixture<uint8_t>,
+                       CLArgMinMaxQuantizedValidationFixture_U8_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(combine(ArgMinMaxSmallDataset, framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                                       framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })),
-                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })))
+                       combine(combine(combine(combine(combine(ArgMinMaxSmallDataset,
+                                                               framework::dataset::make("DataTypeIn", DataType::QASYMM8)),
+                                                       framework::dataset::make("DataTypeOut", DataType::S32)),
+                                               AxisDataset),
+                                       OpsDataset),
+                               QInfoDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
-
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       CLArgMinMaxQuantizedValidationFixture<uint8_t>,
+                       CLArgMinMaxQuantizedValidationFixture_U8_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(combine(ArgMinMaxLargeDataset, framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                                       framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })),
-                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })))
+                       combine(combine(combine(combine(combine(ArgMinMaxLargeDataset,
+                                                               framework::dataset::make("DataTypeIn", DataType::QASYMM8)),
+                                                       framework::dataset::make("DataTypeOut", DataType::S32)),
+                                               AxisDataset),
+                                       OpsDataset),
+                               QInfoDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -205,28 +258,32 @@
 
 TEST_SUITE(QASYMM8_SIGNED)
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       CLArgMinMaxQuantizedValidationFixture<int8_t>,
+                       CLArgMinMaxQuantizedValidationFixture_S8_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(combine(ArgMinMaxSmallDataset, framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                                       framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })),
-                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })))
+                       combine(combine(combine(combine(combine(ArgMinMaxSmallDataset,
+                                                               framework::dataset::make("DataTypeIn", DataType::QASYMM8_SIGNED)),
+                                                       framework::dataset::make("DataTypeOut", DataType::S32)),
+                                               AxisDataset),
+                                       OpsDataset),
+                               QInfoDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
-
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       CLArgMinMaxQuantizedValidationFixture<int8_t>,
+                       CLArgMinMaxQuantizedValidationFixture_S8_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(combine(ArgMinMaxLargeDataset, framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                                       framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })),
-                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })))
+                       combine(combine(combine(combine(combine(ArgMinMaxLargeDataset,
+                                                               framework::dataset::make("DataTypeIn", DataType::QASYMM8_SIGNED)),
+                                                       framework::dataset::make("DataTypeOut", DataType::S32)),
+                                               AxisDataset),
+                                       OpsDataset),
+                               QInfoDataset))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 TEST_SUITE_END() // QASYMM8_SIGNED
-
 TEST_SUITE_END() // Quantized
 TEST_SUITE_END() // ArgMinMax
 TEST_SUITE_END() // CL
diff --git a/tests/validation/NEON/ArgMinMax.cpp b/tests/validation/NEON/ArgMinMax.cpp
index 0a40710..2e21a7d 100644
--- a/tests/validation/NEON/ArgMinMax.cpp
+++ b/tests/validation/NEON/ArgMinMax.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -43,6 +43,27 @@
 {
 namespace validation
 {
+namespace
+{
+const auto OpsDataset   = framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX });
+const auto AxisDataset  = framework::dataset::make("Axis", { 0, 1, 2, 3 });
+const auto QInfoDataset = framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) }); // shared by the QASYMM8 and QASYMM8_SIGNED suites
+
+const auto ArgMinMaxSmallDatasetAxis0 = framework::dataset::make("Shape", // shapes for the axis-0-only RunSmallAxis0 test
+{
+    TensorShape{ 1U, 5U },
+    TensorShape{ 2U, 3U },
+    TensorShape{ 1U },
+    TensorShape{ 3U },
+    TensorShape{ 2U },
+    TensorShape{ 5U },
+    TensorShape{ 17U },
+    TensorShape{ 15U, 2U },
+});
+using ArgMinMaxSmallDataset = datasets::Small4DShapes;
+using ArgMinMaxLargeDataset = datasets::Large4DShapes;
+} // namespace
+
 TEST_SUITE(NEON)
 TEST_SUITE(ArgMinMax)
 
@@ -70,23 +91,46 @@
 // clang-format on
 // *INDENT-ON*
 
-template <typename T>
-using NEArgMinMaxValidationFixture = ArgMinMaxValidationFixture<Tensor, Accessor, NEArgMinMaxLayer, T>;
+template <typename T1, typename T2>
+using NEArgMinMaxValidationFixture = ArgMinMaxValidationFixture<Tensor, Accessor, NEArgMinMaxLayer, T1, T2>;
 
+using NEArgMinMaxValidationFixture_S32_S32 = NEArgMinMaxValidationFixture<int32_t, int32_t>;
+using NEArgMinMaxValidationFixture_F16_S32 = NEArgMinMaxValidationFixture<half, int32_t>;
+using NEArgMinMaxValidationFixture_F32_S32 = NEArgMinMaxValidationFixture<float, int32_t>;
 TEST_SUITE(S32)
-FIXTURE_DATA_TEST_CASE(RunSmall,
-                       NEArgMinMaxValidationFixture<int32_t>,
+FIXTURE_DATA_TEST_CASE(RunSmallAxis0,
+                       NEArgMinMaxValidationFixture_S32_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::S32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxSmallDatasetAxis0,
+                                                       framework::dataset::make("DataTypeIn", DataType::S32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       framework::dataset::make("Axis", { 0 })),
+                               OpsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
 
+FIXTURE_DATA_TEST_CASE(RunSmall,
+                       NEArgMinMaxValidationFixture_S32_S32,
+                       framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(ArgMinMaxSmallDataset(),
+                                                       framework::dataset::make("DataTypeIn", DataType::S32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
+{
+    // Validate output
+    validate(Accessor(_target), _reference);
+}
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       NEArgMinMaxValidationFixture<int32_t>,
+                       NEArgMinMaxValidationFixture_S32_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::S32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxLargeDataset(),
+                                                       framework::dataset::make("DataTypeIn", DataType::S32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -97,18 +141,26 @@
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       NEArgMinMaxValidationFixture<half>,
+                       NEArgMinMaxValidationFixture_F16_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxSmallDataset(),
+                                                       framework::dataset::make("DataTypeIn", DataType::F16)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
 
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       NEArgMinMaxValidationFixture<half>,
+                       NEArgMinMaxValidationFixture_F16_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxLargeDataset(),
+                                                       framework::dataset::make("DataTypeIn", DataType::F16)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -118,18 +170,26 @@
 
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       NEArgMinMaxValidationFixture<float>,
+                       NEArgMinMaxValidationFixture_F32_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxSmallDataset(),
+                                                       framework::dataset::make("DataTypeIn", DataType::F32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
 
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       NEArgMinMaxValidationFixture<float>,
+                       NEArgMinMaxValidationFixture_F32_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })))
+                       combine(combine(combine(combine(ArgMinMaxLargeDataset(),
+                                                       framework::dataset::make("DataTypeIn", DataType::F32)),
+                                               framework::dataset::make("DataTypeOut", DataType::S32)),
+                                       AxisDataset),
+                               OpsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -137,27 +197,35 @@
 TEST_SUITE_END() // FP32
 TEST_SUITE_END() // Float
 
-template <typename T>
-using NEArgMinMaxQuantizedValidationFixture = ArgMinMaxValidationQuantizedFixture<Tensor, Accessor, NEArgMinMaxLayer, T>;
+template <typename T1, typename T2>
+using NEArgMinMaxQuantizedValidationFixture = ArgMinMaxValidationQuantizedFixture<Tensor, Accessor, NEArgMinMaxLayer, T1, T2>;
+
+using NEArgMinMaxQuantizedValidationFixture_U8_S32 = NEArgMinMaxQuantizedValidationFixture<uint8_t, int32_t>;
+using NEArgMinMaxQuantizedValidationFixture_S8_S32 = NEArgMinMaxQuantizedValidationFixture<int8_t, int32_t>;
 
 TEST_SUITE(QASYMM8)
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       NEArgMinMaxQuantizedValidationFixture<uint8_t>,
+                       NEArgMinMaxQuantizedValidationFixture_U8_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                                       framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })),
-                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })))
+                       combine(combine(combine(combine(combine(ArgMinMaxSmallDataset(),
+                                                               framework::dataset::make("DataTypeIn", DataType::QASYMM8)),
+                                                       framework::dataset::make("DataTypeOut", DataType::S32)),
+                                               AxisDataset),
+                                       OpsDataset),
+                               QInfoDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
-
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       NEArgMinMaxQuantizedValidationFixture<uint8_t>,
+                       NEArgMinMaxQuantizedValidationFixture_U8_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                                       framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })),
-                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 255.f, 20) })))
+                       combine(combine(combine(combine(combine(ArgMinMaxLargeDataset(),
+                                                               framework::dataset::make("DataTypeIn", DataType::QASYMM8)),
+                                                       framework::dataset::make("DataTypeOut", DataType::S32)),
+                                               AxisDataset),
+                                       OpsDataset),
+                               QInfoDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -166,22 +234,27 @@
 
 TEST_SUITE(QASYMM8_SIGNED)
 FIXTURE_DATA_TEST_CASE(RunSmall,
-                       NEArgMinMaxQuantizedValidationFixture<int8_t>,
+                       NEArgMinMaxQuantizedValidationFixture_S8_S32,
                        framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                                       framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })),
-                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 127.f, 20) })))
+                       combine(combine(combine(combine(combine(ArgMinMaxSmallDataset(),
+                                                               framework::dataset::make("DataTypeIn", DataType::QASYMM8_SIGNED)),
+                                                       framework::dataset::make("DataTypeOut", DataType::S32)),
+                                               AxisDataset),
+                                       OpsDataset),
+                               QInfoDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
-
 FIXTURE_DATA_TEST_CASE(RunLarge,
-                       NEArgMinMaxQuantizedValidationFixture<int8_t>,
+                       NEArgMinMaxQuantizedValidationFixture_S8_S32,
                        framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("Axis", { 0, 1, 2, 3 })),
-                                       framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MIN, ReductionOperation::ARG_IDX_MAX })),
-                               framework::dataset::make("QuantizationInfo", { QuantizationInfo(5.f / 127.f, 20) })))
+                       combine(combine(combine(combine(combine(ArgMinMaxLargeDataset(),
+                                                               framework::dataset::make("DataTypeIn", DataType::QASYMM8_SIGNED)),
+                                                       framework::dataset::make("DataTypeOut", DataType::S32)),
+                                               AxisDataset),
+                                       OpsDataset),
+                               QInfoDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference);
diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h
index 9a600b8..7a82356 100644
--- a/tests/validation/fixtures/ArgMinMaxFixture.h
+++ b/tests/validation/fixtures/ArgMinMaxFixture.h
@@ -42,14 +42,14 @@
 {
 namespace validation
 {
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
 class ArgMinMaxValidationBaseFixture : public framework::Fixture
 {
 public:
-    void setup(TensorShape shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
+    void setup(TensorShape shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info)
     {
-        _target    = compute_target(shape, data_type, axis, op, q_info);
-        _reference = compute_reference(shape, data_type, axis, op, q_info);
+        _target    = compute_target(shape, input_type, output_type, axis, op, q_info);
+        _reference = compute_reference(shape, input_type, output_type, axis, op, q_info);
     }
 
 protected:
@@ -97,11 +97,11 @@
         }
     }
 
-    TensorType compute_target(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
+    TensorType compute_target(TensorShape &src_shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info)
     {
         // Create tensors
-        TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, q_info);
-        TensorType dst;
+        TensorType src = create_tensor<TensorType>(src_shape, input_type, 1, q_info);
+        TensorType dst = create_tensor<TensorType>(compute_output_shape(src_shape, axis), output_type, 1, q_info);
 
         // Create and configure function
         FunctionType arg_min_max_layer;
@@ -126,39 +126,45 @@
         return dst;
     }
 
-    SimpleTensor<int32_t> compute_reference(TensorShape &src_shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo q_info)
+    /** Output shape: @p src_shape reduced along @p axis (compute_reduced_shape() called with keep_dims = false). */
+    static TensorShape compute_output_shape(const TensorShape &src_shape, int axis)
+    {
+        return arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false);
+    }
+
+    /** Reference indices of element type T2, stored in a tensor whose data type is @p output_type. */
+    SimpleTensor<T2> compute_reference(const TensorShape &src_shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo q_info)
     {
         // Create reference
-        SimpleTensor<T> src{ src_shape, data_type, 1, q_info };
+        SimpleTensor<T1> src{ src_shape, input_type, 1, q_info };
 
         // Fill reference
         fill(src);
 
-        TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false);
-        return reference::reduction_operation<T, int32_t>(src, output_shape, axis, op);
+        return reference::reduction_operation<T1, T2>(src, compute_output_shape(src_shape, axis), axis, op, output_type);
     }
 
-    TensorType            _target{};
-    SimpleTensor<int32_t> _reference{};
+    TensorType       _target{};
+    SimpleTensor<T2> _reference{};
 };
 
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
+class ArgMinMaxValidationQuantizedFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T1, T2>
 {
 public:
-    void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op, QuantizationInfo quantization_info)
+    void setup(const TensorShape &shape, DataType input_type, DataType output_type, int axis, ReductionOperation op, QuantizationInfo quantization_info)
     {
-        ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info);
+        ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, input_type, output_type, axis, op, quantization_info);
     }
 };
 
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
+class ArgMinMaxValidationFixture : public ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T1, T2>
 {
 public:
-    void setup(const TensorShape &shape, DataType data_type, int axis, ReductionOperation op)
+    void setup(const TensorShape &shape, DataType input_type, DataType output_type, int axis, ReductionOperation op)
     {
-        ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo());
+        ArgMinMaxValidationBaseFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, input_type, output_type, axis, op, QuantizationInfo());
     }
 };
 } // namespace validation
diff --git a/tests/validation/fixtures/ReduceMeanFixture.h b/tests/validation/fixtures/ReduceMeanFixture.h
index 39fdea5..5363d6b 100644
--- a/tests/validation/fixtures/ReduceMeanFixture.h
+++ b/tests/validation/fixtures/ReduceMeanFixture.h
@@ -127,9 +127,9 @@
 
 #ifdef ARM_COMPUTE_OPENCL_ENABLED
             is_opencl = std::is_same<CLTensor, TensorType>::value; // Round down to zero on opencl to match kernel
 #endif /* ARM_COMPUTE_OPENCL_ENABLED */
-            out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, quantization_info_output, is_opencl ? RoundingPolicy::TO_ZERO : RoundingPolicy::TO_NEAREST_UP);
-
+            const RoundingPolicy rounding_policy = is_opencl ? RoundingPolicy::TO_ZERO : RoundingPolicy::TO_NEAREST_UP;
+            out = reference::reduction_operation<T, T>(i == 0 ? src : out, output_shape, axis[i], ReductionOperation::MEAN_SUM, data_type, quantization_info_output, rounding_policy);
         }
 
         if(!keep_dims)
diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h
index 36e6309..b44f299 100644
--- a/tests/validation/fixtures/ReductionOperationFixture.h
+++ b/tests/validation/fixtures/ReductionOperationFixture.h
@@ -134,7 +134,7 @@
         // Fill reference
         fill(src);
 
-        return reference::reduction_operation<T, T>(src, dst_shape, axis, op, quantization_info);
+        return reference::reduction_operation<T, T>(src, dst_shape, axis, op, data_type, quantization_info); // data_type is forwarded as the reference output_type
     }
 
     TensorType      _target{};
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index e2890af..c189bc2 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -181,12 +181,12 @@
 } // namespace
 
 template <typename T, typename OT>
-SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, RoundingPolicy policy)
+SimpleTensor<OT> compute_reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                             DataType output_type, RoundingPolicy policy)
 {
     // Create reference
-    const bool         is_arg_min_max   = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
-    DataType           output_data_type = is_arg_min_max ? DataType::S32 : src.data_type();
-    SimpleTensor<OT>   dst{ dst_shape, output_data_type, 1, src.quantization_info() };
+    const bool         is_arg_min_max = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
+    SimpleTensor<OT>   dst{ dst_shape, output_type, 1, src.quantization_info() };
     const unsigned int src_width    = src.shape().x();
     const unsigned int src_height   = src.shape().y();
     const unsigned int src_depth    = src.shape().z();
@@ -275,74 +275,91 @@
 }
 
 template <typename T, typename OT>
-SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output, RoundingPolicy policy)
+SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                     DataType output_type, QuantizationInfo quantization_info_output, RoundingPolicy policy)
 {
     ARM_COMPUTE_UNUSED(quantization_info_output);
-    return compute_reduction_operation<T, OT>(src, dst_shape, axis, op, policy);
+    return compute_reduction_operation<T, OT>(src, dst_shape, axis, op, output_type, policy);
 }
 
 template <>
-SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output, RoundingPolicy policy)
+SimpleTensor<uint8_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                          DataType output_type, QuantizationInfo quantization_info_output, RoundingPolicy policy)
 {
     if(src.data_type() == DataType::QASYMM8)
     {
         // If the operation is MEAN_SUM, we can directly use the uint8 implementation without taking into account scale and offset
         if(op == ReductionOperation::MEAN_SUM && src.quantization_info() == quantization_info_output)
         {
-            return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op, policy);
+            return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op, output_type, policy);
         }
         else
         {
             SimpleTensor<float> src_f = convert_from_asymmetric(src);
-            SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op);
+            // The intermediate tensor holds floats, so tag it F32 rather than with the (quantized) output_type.
+            SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op, DataType::F32);
             return convert_to_asymmetric<uint8_t>(dst_f, quantization_info_output);
         }
     }
     else
     {
-        return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op, policy);
+        return compute_reduction_operation<uint8_t, uint8_t>(src, dst_shape, axis, op, output_type, policy);
     }
 }
 
 template <>
-SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info_output, RoundingPolicy policy)
+SimpleTensor<int8_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis,
+                                         ReductionOperation op, DataType output_type, QuantizationInfo quantization_info_output, RoundingPolicy policy)
 {
     if(src.data_type() == DataType::QASYMM8_SIGNED)
     {
         // If the operation is MEAN_SUM, we can directly use the int8 implementation without taking into account scale and offset
         if(op == ReductionOperation::MEAN_SUM && src.quantization_info() == quantization_info_output)
         {
-            return compute_reduction_operation<int8_t, int8_t>(src, dst_shape, axis, op, policy);
+            return compute_reduction_operation<int8_t, int8_t>(src, dst_shape, axis, op, output_type, policy);
         }
         else
         {
             SimpleTensor<float> src_f = convert_from_asymmetric(src);
-            SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op);
+            // The intermediate tensor holds floats, so tag it F32 rather than with the (quantized) output_type.
+            SimpleTensor<float> dst_f = reference::reduction_operation<float, float>(src_f, dst_shape, axis, op, DataType::F32);
             return convert_to_asymmetric<int8_t>(dst_f, quantization_info_output);
         }
     }
     else
     {
-        return compute_reduction_operation<int8_t, int8_t>(src, dst_shape, axis, op, policy);
+        return compute_reduction_operation<int8_t, int8_t>(src, dst_shape, axis, op, output_type, policy);
     }
 }
 
 template SimpleTensor<float> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
-                                                 QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
+                                                 DataType output_type = DataType::S32, QuantizationInfo quantization_info_output = QuantizationInfo(),
+                                                 RoundingPolicy policy = RoundingPolicy::TO_ZERO);
+
 template SimpleTensor<half> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                                DataType         output_type              = DataType::S32,
                                                 QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
 
 template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
-                                                   QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
-template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int32_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
-                                                   QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
-template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
-                                                   QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
-template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
-                                                   QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
-template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                                   DataType         output_type              = DataType::S32,
                                                    QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
 
+template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int32_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                                   DataType         output_type              = DataType::S32,
+                                                   QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
+template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<half> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                                   DataType         output_type              = DataType::S32,
+                                                   QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
+template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<uint8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                                   DataType         output_type              = DataType::S32,
+                                                   QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
+template SimpleTensor<int32_t> reduction_operation(const SimpleTensor<int8_t> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                                   DataType         output_type              = DataType::S32,
+                                                   QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
+
+template SimpleTensor<int64_t> reduction_operation(const SimpleTensor<float> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                                   DataType output_type = DataType::S32, QuantizationInfo quantization_info_output = QuantizationInfo(),
+                                                   RoundingPolicy policy = RoundingPolicy::TO_ZERO);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/reference/ReductionOperation.h b/tests/validation/reference/ReductionOperation.h
index dd97778..fb2e7a7 100644
--- a/tests/validation/reference/ReductionOperation.h
+++ b/tests/validation/reference/ReductionOperation.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -24,8 +24,8 @@
 #ifndef ARM_COMPUTE_TEST_REDUCTION_OPERATION_H
 #define ARM_COMPUTE_TEST_REDUCTION_OPERATION_H
 
-#include "tests/SimpleTensor.h"
 #include "arm_compute/core/Rounding.h"
+#include "tests/SimpleTensor.h"
 #include "tests/validation/Helpers.h"
 
 namespace arm_compute
@@ -37,7 +37,23 @@
 namespace reference
 {
+/** Reference implementation of a reduction operation along a single axis.
+ *
+ * @param[in] src                      Input tensor.
+ * @param[in] dst_shape                Shape of the output tensor.
+ * @param[in] axis                     Axis along which to reduce.
+ * @param[in] op                       Reduction operation to perform.
+ * @param[in] output_type              Data type the output tensor is created with. Defaults to S32, an index type for
+ *                                     ARG_IDX_MIN/ARG_IDX_MAX; other operations should pass their output data type
+ *                                     explicitly (typically the input data type).
+ * @param[in] quantization_info_output Output quantization info; only consulted when QASYMM8/QASYMM8_SIGNED inputs are
+ *                                     reduced to the same element type.
+ * @param[in] policy                   Rounding policy forwarded to the reference computation.
+ *
+ * @return The reduced tensor, of shape @p dst_shape.
+ */
 template <typename T, typename OT>
-SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+SimpleTensor<OT> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op,
+                                     DataType output_type = DataType::S32,
                                      QuantizationInfo quantization_info_output = QuantizationInfo(), RoundingPolicy policy = RoundingPolicy::TO_ZERO);
 } // namespace reference
 } // namespace validation