COMPMID-2707: add keep_dims parameter to Reduction Operation

The added parameter is used to decide whether or not to keep
the target dimension of the reduction operation. ArgMinMax operations
will always remove the reduced dimension. The following are
updated to support the parameter:

- [CL/NEON] functions and reference kernel
- [CL/NEON] ArgMinMax function to use ReductionOperation function
- [CL/NEON] validation test suite for Reduction and ArgMinMax operations
  to validate the added parameter
- ReductionOperationFixture is modified NOT to pre-populate the output
  tensor and now relies on the underlying kernel/function to initialise it.
- The CL validation test suite for the Reduction operation is adjusted
  to remove excessive test cases with axis values beyond the input
  tensor's dimensions.

Change-Id: I3e24d276ed469a4201f323001708f0c525f11c4f
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2167
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/tests/validation/CL/ArgMinMax.cpp b/tests/validation/CL/ArgMinMax.cpp
index 6de09be..845fdbf 100644
--- a/tests/validation/CL/ArgMinMax.cpp
+++ b/tests/validation/CL/ArgMinMax.cpp
@@ -25,7 +25,9 @@
 #include "arm_compute/runtime/CL/CLTensor.h"
 #include "arm_compute/runtime/CL/CLTensorAllocator.h"
 #include "arm_compute/runtime/CL/functions/CLArgMinMaxLayer.h"
+#include "arm_compute/runtime/CL/functions/CLReductionOperation.h"
 
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "tests/CL/CLAccessor.h"
 #include "tests/datasets/ShapeDatasets.h"
 #include "tests/datasets/SplitDataset.h"
@@ -49,16 +51,18 @@
         framework::dataset::make("InputInfo", { TensorInfo(TensorShape(27U, 3U, 16U, 2U), 1, DataType::F32), // Invalid axis
                                                 TensorInfo(TensorShape(27U, 3U, 16U, 2U), 1, DataType::F32), // Invalid output shape
                                                 TensorInfo(TensorShape(32U, 16U, 16U, 2U), 1, DataType::F32),
-                                                TensorInfo(TensorShape(32U, 16U, 16U, 2U), 1, DataType::F32) // Invalid operation
+                                                TensorInfo(TensorShape(32U, 16U, 16U, 2U), 1, DataType::F32), // Invalid operation
+                                                TensorInfo(TensorShape(32U, 16U, 16U, 2U), 1, DataType::F32) // Not allowed keeping the dimension 
         }),
-        framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(27U, 3U, 1U, 2U), 1, DataType::F32),
-                                                 TensorInfo(TensorShape(27U, 3U, 1U, 2U), 1, DataType::F32),
-                                                 TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::U32),
-                                                 TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::F32)
+        framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(27U, 3U, 2U), 1, DataType::F32),
+                                                 TensorInfo(TensorShape(27U, 3U, 2U), 1, DataType::F32),
+                                                 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::U32),
+                                                 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::F32),
+                                                 TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::U32)
         })),
-        framework::dataset::make("Axis", { 4, 0, 2, 0 })),
-        framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MAX, ReductionOperation::ARG_IDX_MAX, ReductionOperation::ARG_IDX_MAX, ReductionOperation::MEAN_SUM })),
-        framework::dataset::make("Expected", { false, false, true, false })),
+        framework::dataset::make("Axis", { 4, 0, 2, 0, 2 })),
+        framework::dataset::make("Operation", { ReductionOperation::ARG_IDX_MAX, ReductionOperation::ARG_IDX_MAX, ReductionOperation::ARG_IDX_MAX, ReductionOperation::MEAN_SUM, ReductionOperation::ARG_IDX_MAX })),
+        framework::dataset::make("Expected", { false, false, true, false, false })),
         input_info, output_info, axis, operation, expected)
 {
     const Status status = CLArgMinMaxLayer::validate(&input_info.clone()->set_is_resizable(false), axis, &output_info.clone()->set_is_resizable(false), operation);
@@ -76,13 +80,13 @@
     CLTensor ref_src = create_tensor<CLTensor>(shape, data_type);
     CLTensor dst;
 
+    constexpr int axis = 1;
+
     // Create and Configure function
     CLArgMinMaxLayer arg_min_max_layer;
-    arg_min_max_layer.configure(&ref_src, 1, &dst, ReductionOperation::ARG_IDX_MAX);
+    arg_min_max_layer.configure(&ref_src, axis, &dst, ReductionOperation::ARG_IDX_MAX);
 
-    // Validate valid region
-    TensorShape output_shape = shape;
-    output_shape.set(1, 1);
+    const auto        output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(shape, axis, false);
     const ValidRegion valid_region = shape_to_valid_region(output_shape);
     validate(dst.info()->valid_region(), valid_region);
 }
diff --git a/tests/validation/CL/ReductionOperation.cpp b/tests/validation/CL/ReductionOperation.cpp
index 9a3cd99..1dec020 100644
--- a/tests/validation/CL/ReductionOperation.cpp
+++ b/tests/validation/CL/ReductionOperation.cpp
@@ -57,6 +57,7 @@
 
 });
 
+const auto KeepDimensions = framework::dataset::make("KeepDims", { true, false });
 } // namespace
 
 TEST_SUITE(CL)
@@ -64,29 +65,34 @@
 
 // *INDENT-OFF*
 // clang-format off
-DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
     framework::dataset::make("InputInfo",          { TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), // Mismatching data type input/output
                                                      TensorInfo(TensorShape(128U, 64U), 3, DataType::F32), // Number of Input channels != 1
                                                      TensorInfo(TensorShape(128U, 64U), 1, DataType::S16), // DataType != QASYMM8/F16/F32
                                                      TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), // Axis >= num_max_dimensions
                                                      TensorInfo(TensorShape(128U, 64U), 1, DataType::QASYMM8), // Axis == 0 and SUM_SQUARE and QASYMM8
-                                                     TensorInfo(TensorShape(128U, 64U), 1, DataType::F32)
+                                                     TensorInfo(TensorShape(128U, 64U), 1, DataType::F32),
+                                                     TensorInfo(TensorShape(128U, 64U), 1, DataType::F32) // Kept Dimension when keep_dims = false
+
                                                    }),
     framework::dataset::make("OutputInfo",         { TensorInfo(TensorShape(1U, 64U), 1, DataType::F16),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::F32),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::S16),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::F32),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::QASYMM8),
+                                                     TensorInfo(TensorShape(1U, 64U), 1, DataType::F32),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::F32)
                                                    })),
-    framework::dataset::make("Axis",               { 0U, 0U, 0U, static_cast<unsigned int>(TensorShape::num_max_dimensions), 1U, 0U })),
-    framework::dataset::make("Expected",           { false, false, false, false, false, true })),
-    input_info, output_info, axis, expected)
+    framework::dataset::make("Axis",               { 0U, 0U, 0U, static_cast<unsigned int>(TensorShape::num_max_dimensions), 1U, 0U, 0U })),
+    framework::dataset::make("KeepDims",           { true, true, true, true, true, true, false })),
+    framework::dataset::make("Expected",           { false, false, false, false, false, true , false })),
+    input_info, output_info, axis, keep_dims, expected)
 {
     bool is_valid = bool(CLReductionOperation::validate(&input_info.clone()->set_is_resizable(false),
                                                         &output_info.clone()->set_is_resizable(true),
                                                         axis,
-                                                        ReductionOperation::SUM_SQUARE));
+                                                        ReductionOperation::SUM_SQUARE,
+                                                        keep_dims));
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
@@ -97,28 +103,54 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLReductionOperationFixture<half>, framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations))
+FIXTURE_DATA_TEST_CASE(RunSmall2D, CLReductionOperationFixture<half>, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(datasets::Small2DShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1 })), ReductionOperations), KeepDimensions))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLReductionOperationFixture<half>, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(datasets::Small3DShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2 })), ReductionOperations), KeepDimensions))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_f16);
+}
+FIXTURE_DATA_TEST_CASE(RunSmall4D, CLReductionOperationFixture<half>, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations),
+                               KeepDimensions))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, CLReductionOperationFixture<half>, framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations))
+                       combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations), KeepDimensions))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0, tolerance_f16);
 }
 TEST_SUITE_END() // F16
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLReductionOperationFixture<float>, framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations))
+FIXTURE_DATA_TEST_CASE(RunSmall2D, CLReductionOperationFixture<float>, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(datasets::Small2DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1 })), ReductionOperations), KeepDimensions))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLReductionOperationFixture<float>, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(datasets::Small3DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2 })), ReductionOperations), KeepDimensions))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_f32);
+}
+FIXTURE_DATA_TEST_CASE(RunSmall4D, CLReductionOperationFixture<float>, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations),
+                               KeepDimensions))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, CLReductionOperationFixture<float>, framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations))
+                       combine(combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("Axis", { 0, 1, 2, 3 })), ReductionOperations), KeepDimensions))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0, tolerance_f32);
diff --git a/tests/validation/NEON/ArgMinMax.cpp b/tests/validation/NEON/ArgMinMax.cpp
index 71fb39a..642a69b 100644
--- a/tests/validation/NEON/ArgMinMax.cpp
+++ b/tests/validation/NEON/ArgMinMax.cpp
@@ -24,9 +24,11 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/utils/misc/Traits.h"
 #include "arm_compute/runtime/NEON/functions/NEArgMinMaxLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEReductionOperation.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "arm_compute/runtime/TensorAllocator.h"
 
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "tests/NEON/Accessor.h"
 #include "tests/datasets/ShapeDatasets.h"
 #include "tests/datasets/SplitDataset.h"
@@ -54,7 +56,7 @@
         }),
         framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(27U, 3U, 1U, 2U), 1, DataType::F32),
                                                  TensorInfo(TensorShape(27U, 3U, 1U, 2U), 1, DataType::F32),
-                                                 TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::U32),
+                                                 TensorInfo(TensorShape(32U, 16U, 2U), 1, DataType::U32),
                                                  TensorInfo(TensorShape(32U, 16U, 1U, 2U), 1, DataType::F32)
         })),
         framework::dataset::make("Axis", { 4, 0, 2, 0 })),
@@ -74,17 +76,17 @@
                shape, data_type)
 {
     // Create tensors
-    Tensor ref_src = create_tensor<Tensor>(shape, data_type);
-    Tensor dst;
+    Tensor    ref_src = create_tensor<Tensor>(shape, data_type);
+    Tensor    dst;
+    const int axis = 1;
 
     // Create and Configure function
     NEArgMinMaxLayer arg_min_max_layer;
-    arg_min_max_layer.configure(&ref_src, 1, &dst, ReductionOperation::ARG_IDX_MAX);
+    arg_min_max_layer.configure(&ref_src, axis, &dst, ReductionOperation::ARG_IDX_MAX);
 
     // Validate valid region
-    TensorShape output_shape = shape;
-    output_shape.set(1, 1);
-    const ValidRegion valid_region = shape_to_valid_region(output_shape);
+    const auto        expected_output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(shape, axis, false);
+    const ValidRegion valid_region          = shape_to_valid_region(expected_output_shape);
     validate(dst.info()->valid_region(), valid_region);
 }
 
diff --git a/tests/validation/NEON/ReductionOperation.cpp b/tests/validation/NEON/ReductionOperation.cpp
index 5b697a5..3a7f707 100644
--- a/tests/validation/NEON/ReductionOperation.cpp
+++ b/tests/validation/NEON/ReductionOperation.cpp
@@ -66,6 +66,8 @@
 const auto Axises = framework::dataset::make("Axis",
 { 0, 1, 2, 3 });
 
+const auto KeepDims = framework::dataset::make("KeepDims", { true, false });
+
 } // namespace
 
 TEST_SUITE(NEON)
@@ -73,27 +75,31 @@
 
 // *INDENT-OFF*
 // clang-format off
-DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
     framework::dataset::make("InputInfo",          { TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), // Mismatching data type input/output
                                                      TensorInfo(TensorShape(128U, 64U), 2, DataType::F32), // Number of Input channels != 1
                                                      TensorInfo(TensorShape(128U, 64U), 1, DataType::S16), // DataType != F32
                                                      TensorInfo(TensorShape(128U, 64U), 1, DataType::F32), // Axis >= num_max_dimensions
-                                                     TensorInfo(TensorShape(128U, 64U), 1, DataType::F32)
+                                                     TensorInfo(TensorShape(128U, 64U), 1, DataType::F32),
+                                                     TensorInfo(TensorShape(128U, 64U), 1, DataType::F32) // Kept dimension when keep_dims = false
                                                    }),
     framework::dataset::make("OutputInfo",         { TensorInfo(TensorShape(1U, 64U), 1, DataType::F16),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::F32),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::S16),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::F32),
+                                                     TensorInfo(TensorShape(1U, 64U), 1, DataType::F32),
                                                      TensorInfo(TensorShape(1U, 64U), 1, DataType::F32)
                                                    })),
-    framework::dataset::make("Axis",               { 0U, 0U, 0U, static_cast<unsigned int>(TensorShape::num_max_dimensions), 0U })),
-    framework::dataset::make("Expected",           { false, false, false, false, true })),
-    input_info, output_info, axis, expected)
+    framework::dataset::make("Axis",               { 0U, 0U, 0U, static_cast<unsigned int>(TensorShape::num_max_dimensions), 0U, 0U })),
+    framework::dataset::make("KeepDims",           { true, true, true, true, true, false})),
+    framework::dataset::make("Expected",           { false, false, false, false, true, false })),
+    input_info, output_info, axis, keep_dims, expected)
 {
     bool is_valid = bool(NEReductionOperation::validate(&input_info.clone()->set_is_resizable(false),
                                                         &output_info.clone()->set_is_resizable(true),
                                                         axis,
-                                                        ReductionOperation::SUM_SQUARE));
+                                                        ReductionOperation::SUM_SQUARE,
+                                                        keep_dims));
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
 // clang-format on
@@ -104,13 +110,13 @@
 
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEReductionOperationFixture<float>, framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations))
+                       combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations), KeepDims))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f32);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NEReductionOperationFixture<float>, framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations))
+                       combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::F32)), Axises), ReductionOperations), KeepDims))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f32, 0, tolerance_f32);
@@ -122,17 +128,19 @@
 
 TEST_SUITE(QASYMM8)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEReductionOperationQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
-                       combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises),
-                                       ReductionOperations),
-                               QuantizationInfos))
+                       combine(combine(combine(combine(combine(datasets::Small4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises),
+                                               ReductionOperations),
+                                       QuantizationInfos),
+                               KeepDims))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NEReductionOperationQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
-                       combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises),
-                                       ReductionOperations),
-                               QuantizationInfos))
+                       combine(combine(combine(combine(combine(datasets::Large4DShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), Axises),
+                                               ReductionOperations),
+                                       QuantizationInfos),
+                               KeepDims))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h
index ed6b51a..f8fe4ff 100644
--- a/tests/validation/fixtures/ArgMinMaxFixture.h
+++ b/tests/validation/fixtures/ArgMinMaxFixture.h
@@ -26,6 +26,7 @@
 
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "tests/AssetsLibrary.h"
 #include "tests/Globals.h"
@@ -121,8 +122,7 @@
         // Fill reference
         fill(src);
 
-        TensorShape output_shape = src_shape;
-        output_shape.set(axis, 1);
+        TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(src_shape, axis, false);
         return reference::reduction_operation<T, uint32_t>(src, output_shape, axis, op);
     }
 
diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h
index d01f41a..867c08e 100644
--- a/tests/validation/fixtures/ReductionOperationFixture.h
+++ b/tests/validation/fixtures/ReductionOperationFixture.h
@@ -26,6 +26,7 @@
 
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "tests/AssetsLibrary.h"
 #include "tests/Globals.h"
@@ -45,11 +46,15 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info)
+    void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info, bool keep_dims = false)
     {
-        const TensorShape output_shape = get_output_shape(shape, axis);
-        _target                        = compute_target(shape, output_shape, data_type, axis, op, quantization_info);
-        _reference                     = compute_reference(shape, output_shape, data_type, axis, op, quantization_info);
+        const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN);
+        _keep_dims                = keep_dims && !is_arg_min_max;
+
+        const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(shape, axis, _keep_dims);
+
+        _target    = compute_target(shape, data_type, axis, op, quantization_info);
+        _reference = compute_reference(shape, output_shape, data_type, axis, op, quantization_info);
     }
 
 protected:
@@ -70,15 +75,15 @@
         }
     }
 
-    TensorType compute_target(const TensorShape &src_shape, const TensorShape &dst_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info)
+    TensorType compute_target(const TensorShape &src_shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info)
     {
         // Create tensors
         TensorType src = create_tensor<TensorType>(src_shape, data_type, 1, quantization_info);
-        TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, quantization_info);
+        TensorType dst;
 
         // Create and configure function
         FunctionType reduction_func;
-        reduction_func.configure(&src, &dst, axis, op);
+        reduction_func.configure(&src, &dst, axis, op, _keep_dims);
 
         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -114,12 +119,7 @@
     SimpleTensor<T> _reference{};
 
 private:
-    TensorShape get_output_shape(TensorShape shape, unsigned int axis)
-    {
-        TensorShape output_shape(shape);
-        output_shape.set(axis, 1);
-        return output_shape;
-    }
+    bool _keep_dims{ false };
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -127,9 +127,9 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info = QuantizationInfo())
+    void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, QuantizationInfo quantization_info = QuantizationInfo(), bool keep_dims = false)
     {
-        ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info);
+        ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, quantization_info, keep_dims);
     }
 };
 
@@ -138,9 +138,9 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op)
+    void setup(TensorShape shape, DataType data_type, unsigned int axis, ReductionOperation op, bool keep_dims = false)
     {
-        ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo());
+        ReductionOperationValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, axis, op, QuantizationInfo(), keep_dims);
     }
 };
 } // namespace validation