COMPMID-2707: add keep_dims parameter to Reduction Operation

The added parameter controls whether the reduced dimension is kept
in the output of a reduction operation. ArgMinMax operations always
remove the reduced dimension. The following changes support the new
parameter (a brief usage sketch follows the list).

- [CL/NEON] functions and reference kernel
- [CL/NEON] ArgMinMax function to use ReductionOperation function
- [CL/NEON] validation test suite for Reduction and ArgMinMax operations
  to validate the added parameter
- ReductionOperationFixture no longer pre-populates the output
  tensor and instead relies on the underlying kernel/function to
  initialize it.
- Adjust the CL validation test suite for the Reduction operation to
  remove redundant test cases whose axis values exceed the input
  tensor's dimensionality.
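
Usage sketch (hypothetical tensors and shapes, for illustration only;
keep_dims is assumed to default to true in the function headers, which
are not part of this diff):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLReductionOperation.h"

    using namespace arm_compute;

    int main()
    {
        CLScheduler::get().default_init();

        CLTensor input, output;
        // Illustrative 16x8 F32 input.
        input.allocator()->init(TensorInfo(TensorShape(16U, 8U), 1, DataType::F32));

        CLReductionOperation reduction;
        // keep_dims = false: the reduced axis 0 is removed, so the
        // auto-initialized output shape becomes (8) rather than (1, 8).
        reduction.configure(&input, &output, 0, ReductionOperation::SUM, false);

        input.allocator()->allocate();
        output.allocator()->allocate();
        reduction.run();
        return 0;
    }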

Change-Id: I3e24d276ed469a4201f323001708f0c525f11c4f
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2167
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/core/CL/kernels/CLReductionOperationKernel.cpp b/src/core/CL/kernels/CLReductionOperationKernel.cpp
index 8e92b59..a085ab1 100644
--- a/src/core/CL/kernels/CLReductionOperationKernel.cpp
+++ b/src/core/CL/kernels/CLReductionOperationKernel.cpp
@@ -33,6 +33,7 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 
 #include "support/ToolchainSupport.h"
 
@@ -80,17 +81,15 @@
 std::tuple<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, unsigned int axis, ReductionOperation op)
 {
     // Output tensor auto initialization if not yet initialized
-    TensorShape output_shape{ input->tensor_shape() };
-    output_shape.set(axis, 1);
-    const bool is_arg_min_max   = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
-    DataType   output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
+    const bool        is_arg_min_max   = (op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::ARG_IDX_MAX);
+    const TensorShape output_shape     = arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis, !is_arg_min_max);
+    const DataType    output_data_type = is_arg_min_max ? DataType::U32 : input->data_type();
     auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
 
     const unsigned int num_elems_processed_per_iteration = (is_data_type_quantized(input->data_type()) && (axis == 0)) ? 1 : 16;
     Window             win                               = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
     bool               window_changed                    = false;
-    const bool         is_serial_op                      = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN
-                                                            || op == ReductionOperation::MAX || is_data_type_quantized(input->data_type()));
+    const bool         is_serial_op                      = needs_serialized_reduction(op, input->data_type(), axis);
 
     switch(axis)
     {
@@ -198,8 +197,8 @@
     // Create kernel
     cl::NDRange lws_hint = CLKernelLibrary::get().default_ndrange();
     std::string kernel_axis_name;
-    const bool  is_serial_op = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN || op == ReductionOperation::MIN || op == ReductionOperation::MAX
-                                || is_data_type_quantized(input->info()->data_type()));
+    const bool  is_serial_op = needs_serialized_reduction(_op, _input->info()->data_type(), _reduction_axis);
+
     switch(axis)
     {
         case 0:
@@ -264,8 +263,7 @@
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
 
-    const bool is_serial_op = (_op == ReductionOperation::ARG_IDX_MAX || _op == ReductionOperation::ARG_IDX_MIN || _op == ReductionOperation::MIN || _op == ReductionOperation::MAX
-                               || is_data_type_quantized(_input->info()->data_type()));
+    const bool is_serial_op = needs_serialized_reduction(_op, _input->info()->data_type(), _reduction_axis);
     switch(_reduction_axis)
     {
         case 0:
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index 7e1af0e..fa335d7 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -427,6 +427,16 @@
     return std::make_pair(w, h);
 }
 
+bool arm_compute::needs_serialized_reduction(ReductionOperation op, DataType dt, unsigned int axis)
+{
+    const bool is_arg_min_max    = (op == ReductionOperation::ARG_IDX_MAX || op == ReductionOperation::ARG_IDX_MIN);
+    const bool is_min_max        = (op == ReductionOperation::MAX || op == ReductionOperation::MIN);
+    const bool is_quantized_type = is_data_type_quantized(dt);
+    const bool is_first_dim      = (axis == 0);
+
+    return !is_first_dim || is_arg_min_max || is_min_max || is_quantized_type;
+}
+
 #ifdef ARM_COMPUTE_ASSERTS_ENABLED
 void arm_compute::print_consecutive_elements(std::ostream &s, DataType dt, const uint8_t *ptr, unsigned int n, int stream_width, const std::string &element_delim)
 {
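
The helper above centralizes the serialization test previously
duplicated at three call sites in CLReductionOperationKernel. A sketch
of its expected behavior, with illustrative assertions (the declaration
is assumed to live in arm_compute/core/Utils.h, whose hunk is not
shown here):

    #include "arm_compute/core/Utils.h" // assumed declaration site
    #include <cassert>

    using namespace arm_compute;

    int main()
    {
        // Serialized: reduction along a non-X axis.
        assert(needs_serialized_reduction(ReductionOperation::SUM, DataType::F32, 1));
        // Serialized: arg-min/max and min/max, on any axis.
        assert(needs_serialized_reduction(ReductionOperation::ARG_IDX_MAX, DataType::F32, 0));
        assert(needs_serialized_reduction(ReductionOperation::MIN, DataType::F32, 0));
        // Serialized: quantized input types.
        assert(needs_serialized_reduction(ReductionOperation::SUM, DataType::QASYMM8, 0));
        // Parallel: arithmetic reduction on axis 0 of a non-quantized type.
        assert(!needs_serialized_reduction(ReductionOperation::SUM, DataType::F32, 0));
        return 0;
    }
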
diff --git a/src/runtime/CL/functions/CLArgMinMaxLayer.cpp b/src/runtime/CL/functions/CLArgMinMaxLayer.cpp
index a6393c5..fd172d5 100644
--- a/src/runtime/CL/functions/CLArgMinMaxLayer.cpp
+++ b/src/runtime/CL/functions/CLArgMinMaxLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -23,26 +23,33 @@
  */
 
 #include "arm_compute/runtime/CL/functions/CLArgMinMaxLayer.h"
+#include "arm_compute/runtime/CL/functions/CLReductionOperation.h"
 
-#include "arm_compute/core/CL/kernels/CLReductionOperationKernel.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
-#include "arm_compute/runtime/CL/CLScheduler.h"
 
 namespace arm_compute
 {
-void CLArgMinMaxLayer::configure(const ICLTensor *input, int axis, ICLTensor *output, const ReductionOperation &op)
+CLArgMinMaxLayer::CLArgMinMaxLayer(std::shared_ptr<IMemoryManager> memory_manager)
+    : _reduction_function(support::cpp14::make_unique<CLReductionOperation>(std::move(memory_manager)))
 {
-    auto k = arm_compute::support::cpp14::make_unique<CLReductionOperationKernel>();
-    k->configure(input, output, axis, op);
-    _kernel = std::move(k);
+}
+
+void CLArgMinMaxLayer::configure(ICLTensor *input, int axis, ICLTensor *output, const ReductionOperation &op)
+{
+    _reduction_function->configure(input, output, axis, op, false);
 }
 
 Status CLArgMinMaxLayer::validate(const ITensorInfo *input, int axis, const ITensorInfo *output, const ReductionOperation &op)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(op != ReductionOperation::ARG_IDX_MAX && op != ReductionOperation::ARG_IDX_MIN, "Invalid operation");
-    return CLReductionOperationKernel::validate(input, output, axis, op);
+    return CLReductionOperation::validate(input, output, axis, op, false);
+}
+
+void CLArgMinMaxLayer::run()
+{
+    _reduction_function->run();
 }
 } // namespace arm_compute
\ No newline at end of file
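
With this change CLArgMinMaxLayer is a thin wrapper; its configure
call is behaviorally equivalent to the sketch below (continuing the
hypothetical tensor setup from the earlier sketch, with keep_dims
hard-coded to false because ArgMinMax always removes the reduced
dimension):

    // The wrapper call...
    CLArgMinMaxLayer argminmax;
    argminmax.configure(&input, 1, &output, ReductionOperation::ARG_IDX_MAX);

    // ...now delegates to the equivalent reduction call:
    CLReductionOperation reduction;
    reduction.configure(&input, &output, 1, ReductionOperation::ARG_IDX_MAX, false);
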
diff --git a/src/runtime/CL/functions/CLReductionOperation.cpp b/src/runtime/CL/functions/CLReductionOperation.cpp
index 38f0a75..447c15b 100644
--- a/src/runtime/CL/functions/CLReductionOperation.cpp
+++ b/src/runtime/CL/functions/CLReductionOperation.cpp
@@ -26,15 +26,17 @@
 #include "arm_compute/core/CL/ICLTensor.h"
 #include "arm_compute/core/CL/kernels/CLReductionOperationKernel.h"
 #include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/PixelValue.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "support/ToolchainSupport.h"
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 namespace
 {
 unsigned int calculate_number_of_stages(const ITensorInfo *input, unsigned int axis)
@@ -56,17 +58,52 @@
 } // namespace
 
 CLReductionOperation::CLReductionOperation(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _results_vector(), _reduction_kernels_vector(), _border_handlers_vector(), _num_of_stages(), _reduction_axis(), _is_serial()
+    : _memory_group(std::move(memory_manager)), _results_vector(), _reduction_kernels_vector(), _border_handlers_vector(), _reshape_kernel(), _op(), _num_of_stages(), _reduction_axis(), _is_serial(),
+      _is_reshape_required(false)
 {
 }
 
-Status CLReductionOperation::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
+Status CLReductionOperation::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op, bool keep_dims)
 {
-    const unsigned int num_of_stages = calculate_number_of_stages(input, axis);
-    bool               is_serial     = is_data_type_quantized(input->data_type()) || axis != 0;
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
+
+    const unsigned int num_of_stages       = calculate_number_of_stages(input, axis);
+    const bool         is_serial           = needs_serialized_reduction(op, input->data_type(), axis);
+    const bool         is_arg_min_max      = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN);
+    const bool         is_reshape_required = !keep_dims || is_arg_min_max;
+
+    if(is_reshape_required)
+    {
+        const TensorInfo expected_output_shape = output->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis, keep_dims));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output_shape, output);
+    }
+
+    auto *output_internal = output;
+
+    TensorInfo output_before_reshape;
+    const auto input_shape        = input->tensor_shape();
+    const auto input_data_type    = input->data_type();
+    const auto input_num_channels = input->num_channels();
+    const auto input_qinfo        = input->quantization_info();
+    const auto output_data_type   = is_arg_min_max ? DataType::U32 : output->data_type();
+
+    auto initialize_tensorinfo = [](TensorInfo & ti, TensorShape shape, DataType data_type, int num_channels, QuantizationInfo qinfo)
+    {
+        ti.set_data_type(data_type).set_tensor_shape(shape).set_num_channels(num_channels).set_quantization_info(qinfo);
+    };
+
+    if(is_reshape_required)
+    {
+        auto shape_before_reshape = input_shape;
+        shape_before_reshape.set(axis, 1);
+        initialize_tensorinfo(output_before_reshape, shape_before_reshape, output_data_type, input_num_channels, input_qinfo);
+        output_internal = &output_before_reshape;
+    }
+
     if(is_serial)
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, output, axis, op));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(input, output_internal, axis, op));
     }
     else
     {
@@ -74,14 +111,13 @@
         std::vector<TensorInfo> sums_vector(num_of_stages - 1);
 
         // Create intermediate tensor info
-        TensorShape shape{ input->tensor_shape() };
+        TensorShape shape{ input_shape };
+
+        shape.set(0, ceil(shape.x() / 128.f));
 
         for(unsigned int i = 0; i < num_of_stages - 1; i++)
         {
-            shape.set(0, ceil(shape.x() / 128.f));
-            sums_vector[i].set_data_type(input->data_type());
-            sums_vector[i].set_tensor_shape(shape);
-            sums_vector[i].set_num_channels(input->num_channels());
+            initialize_tensorinfo(sums_vector[i], shape, input_data_type, input_num_channels, input_qinfo);
         }
 
         ReductionOperation first_kernel_op;
@@ -130,17 +166,72 @@
 
         // Validate ReductionOperation on the last stage
         const unsigned int last_stage = num_of_stages - 1;
-        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(&sums_vector[last_stage - 1], output, axis, last_kernel_op, input->dimension(0)));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLReductionOperationKernel::validate(&sums_vector[last_stage - 1], output_internal, axis, last_kernel_op, input->dimension(0)));
+    }
+
+    if(is_reshape_required)
+    {
+        ARM_COMPUTE_RETURN_ON_ERROR(CLReshapeLayerKernel::validate(output_internal, output));
     }
 
     return Status{};
 }
 
-void CLReductionOperation::configure(ICLTensor *input, ICLTensor *output, unsigned int axis, ReductionOperation op)
+ICLTensor *CLReductionOperation::configure_intermediate_result_vector(ICLTensor *input, ICLTensor *output)
 {
-    _num_of_stages  = calculate_number_of_stages(input->info(), axis);
-    _reduction_axis = axis;
-    _is_serial      = is_data_type_quantized(input->info()->data_type()) || axis != 0;
+    if(!_is_reshape_required && _is_serial)
+    {
+        return output;
+    }
+
+    auto       intermediate_result_vector_size = _is_serial ? 1 : _num_of_stages;
+    const auto is_arg_min_max                  = (_op == ReductionOperation::ARG_IDX_MAX || _op == ReductionOperation::ARG_IDX_MIN);
+
+    if(!_is_reshape_required)
+    {
+        --intermediate_result_vector_size;
+    }
+
+    _results_vector.resize(intermediate_result_vector_size);
+    auto shape = input->info()->tensor_shape();
+
+    shape.set(_reduction_axis, _is_serial ? 1 : ceil(shape.x() / 128.f));
+
+    for(auto &v : _results_vector)
+    {
+        if(&v == &_results_vector.back() && _is_reshape_required)
+        {
+            shape.set(_reduction_axis, 1);
+        }
+        v.allocator()->init(input->info()->clone()->set_tensor_shape(shape));
+    }
+
+    if(is_arg_min_max)
+    {
+        _results_vector.back().info()->set_data_type(DataType::U32).set_is_resizable(true).reset_padding();
+    }
+
+    return _is_reshape_required ? &_results_vector.back() : output;
+}
+
+void CLReductionOperation::configure(ICLTensor *input, ICLTensor *output, unsigned int axis, ReductionOperation op, bool keep_dims)
+{
+    _op                       = op;
+    _num_of_stages            = calculate_number_of_stages(input->info(), axis);
+    _reduction_axis           = axis;
+    _is_serial                = needs_serialized_reduction(op, input->info()->data_type(), axis);
+    const bool is_arg_min_max = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN);
+    _is_reshape_required      = !keep_dims || is_arg_min_max;
+
+    auto *output_internal = configure_intermediate_result_vector(input, output);
+
+    // ArgMinMax might not provide an initialized output tensor, so initialize it here.
+    if(_is_reshape_required)
+    {
+        const TensorShape output_shape     = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis, false);
+        const auto        output_data_type = is_arg_min_max ? DataType::U32 : input->info()->data_type();
+        auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape).set_data_type(output_data_type).reset_padding().set_is_resizable(true));
+    }
 
     // Configure reduction operation kernels
     _reduction_kernels_vector.resize(_num_of_stages);
@@ -148,20 +239,16 @@
     // Create temporary tensors
     if(_is_serial)
     {
-        _reduction_kernels_vector[0].configure(input, output, axis, op, 0);
+        if(_is_reshape_required)
+        {
+            _memory_group.manage(&_results_vector.back());
+        }
+
+        _reduction_kernels_vector[0].configure(input, output_internal, axis, op, 0);
     }
     else
     {
         _border_handlers_vector.resize(_num_of_stages);
-        _results_vector.resize(_num_of_stages - 1);
-        TensorShape shape{ input->info()->tensor_shape() };
-        for(unsigned int i = 0; i < _num_of_stages - 1; i++)
-        {
-            shape.set(0, ceil(shape.x() / 128.f));
-            _results_vector[i].allocator()->init(input->info()->clone()->set_tensor_shape(shape));
-        }
-
-        // Apply ReductionOperation only on first kernel
         _memory_group.manage(&_results_vector[0]);
 
         ReductionOperation first_kernel_op;
@@ -262,10 +349,22 @@
         // Apply ReductionOperation on the last stage
         const unsigned int last_stage  = _num_of_stages - 1;
         const unsigned int input_width = input->info()->dimension(0);
-        _reduction_kernels_vector[last_stage].configure(&_results_vector[last_stage - 1], output, axis, last_kernel_op, input_width);
+
+        if(_is_reshape_required)
+        {
+            _memory_group.manage(&_results_vector.back());
+        }
+
+        _reduction_kernels_vector[last_stage].configure(&_results_vector[last_stage - 1], output_internal, axis, last_kernel_op, input_width);
         _border_handlers_vector[last_stage].configure(&_results_vector[last_stage - 1], _reduction_kernels_vector[last_stage].border_size(), BorderMode::CONSTANT, pixelValue);
         _results_vector[last_stage - 1].allocator()->allocate();
     }
+
+    if(_is_reshape_required)
+    {
+        _reshape_kernel.configure(&_results_vector.back(), output);
+        _results_vector.back().allocator()->allocate();
+    }
 }
 
 void CLReductionOperation::run()
@@ -284,4 +383,10 @@
             CLScheduler::get().enqueue(_reduction_kernels_vector[i], false);
         }
     }
+
+    if(_is_reshape_required)
+    {
+        CLScheduler::get().enqueue(_reshape_kernel, false);
+    }
 }
+} // namespace arm_compute
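
Both back ends now derive output shapes via compute_reduced_shape. A
simplified standalone re-implementation of its apparent semantics (an
assumption inferred from its usage above, not the library source):

    #include "arm_compute/core/TensorShape.h"

    using namespace arm_compute;

    // Sketch: keep_dims = true collapses the reduced axis to 1;
    // keep_dims = false removes that dimension entirely.
    TensorShape reduced_shape_sketch(TensorShape shape, unsigned int axis, bool keep_dims = true)
    {
        if(keep_dims)
        {
            shape.set(axis, 1);
        }
        else
        {
            shape.remove_dimension(axis);
        }
        return shape;
    }
    // e.g. shape (16, 8), axis 0: true -> (1, 8), false -> (8)
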
diff --git a/src/runtime/NEON/functions/NEArgMinMaxLayer.cpp b/src/runtime/NEON/functions/NEArgMinMaxLayer.cpp
index 6863bb0..ab2d6f0 100644
--- a/src/runtime/NEON/functions/NEArgMinMaxLayer.cpp
+++ b/src/runtime/NEON/functions/NEArgMinMaxLayer.cpp
@@ -23,47 +23,35 @@
  */
 
 #include "arm_compute/runtime/NEON/functions/NEArgMinMaxLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEReductionOperation.h"
 
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
 
 namespace arm_compute
 {
 NEArgMinMaxLayer::NEArgMinMaxLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _reduction_kernel(), _fill_border_kernel(), _run_fill_border(false)
+    : _reduction_function(support::cpp14::make_unique<NEReductionOperation>())
 {
+    ARM_COMPUTE_UNUSED(memory_manager);
 }
 void NEArgMinMaxLayer::configure(ITensor *input, int axis, ITensor *output, const ReductionOperation &op)
 {
-    _reduction_kernel.configure(input, output, axis, op);
-
-    if(axis == 0)
-    {
-        _fill_border_kernel.configure(input, _reduction_kernel.border_size(), BorderMode::REPLICATE);
-        _run_fill_border = true;
-    }
+    _reduction_function->configure(input, output, axis, op, false);
 }
 
 Status NEArgMinMaxLayer::validate(const ITensorInfo *input, int axis, const ITensorInfo *output, const ReductionOperation &op)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(op != ReductionOperation::ARG_IDX_MAX && op != ReductionOperation::ARG_IDX_MIN, "Invalid operation");
-    ARM_COMPUTE_RETURN_ON_ERROR(NEReductionOperationKernel::validate(input, output, axis, op));
-    return Status{};
+    return NEReductionOperation::validate(input, output, axis, op, false);
 }
 
 void NEArgMinMaxLayer::run()
 {
-    MemoryGroupResourceScope scope_mg(_memory_group);
-
-    if(_run_fill_border)
-    {
-        NEScheduler::get().schedule(&_fill_border_kernel, Window::DimY);
-    }
-    NEScheduler::get().schedule(&_reduction_kernel, Window::DimY);
+    _reduction_function->run();
 }
 
 } // namespace arm_compute
\ No newline at end of file
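
The border handling removed here is folded into
NEReductionOperation::configure, which now selects the border mode per
operation. A sketch of that selection, mirroring the NEReductionOperation
hunk below (all names come from that hunk):

    // Arg-min/max replicates edge values so padded elements can never
    // win the comparison; other reductions pad with the operation's
    // neutral pixelValue instead.
    const BorderMode mode = is_arg_min_max ? BorderMode::REPLICATE : BorderMode::CONSTANT;
    _fill_border_kernel.configure(input, fill_border_size, mode, pixelValue);
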
diff --git a/src/runtime/NEON/functions/NEReductionOperation.cpp b/src/runtime/NEON/functions/NEReductionOperation.cpp
index dc6cf59..09cd765 100644
--- a/src/runtime/NEON/functions/NEReductionOperation.cpp
+++ b/src/runtime/NEON/functions/NEReductionOperation.cpp
@@ -24,6 +24,7 @@
 #include "arm_compute/runtime/NEON/functions/NEReductionOperation.h"
 
 #include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 
 namespace arm_compute
@@ -52,25 +53,78 @@
 }
 } // namespace
 
-NEReductionOperation::NEReductionOperation()
-    : _reduction_kernel(), _fill_border_kernel(), _window_split(0), _reduction_axis()
+NEReductionOperation::NEReductionOperation(std::shared_ptr<IMemoryManager> memory_manager)
+    : _memory_group(memory_manager), _reduction_kernel(), _fill_border_kernel(), _reshape_kernel(), _output_internal(), _window_split(0), _reduction_axis(), _is_reshape_required(false)
 {
 }
 
-Status NEReductionOperation::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op)
+Status NEReductionOperation::validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op, bool keep_dims)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(NEReductionOperationKernel::validate(input, output, axis, op));
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Reduction axis greater than max number of dimensions");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 3, "Unsupported reduction axis");
+
+    const auto is_reshape_required = !keep_dims;
+
+    auto *output_internal = output;
+
+    TensorInfo info_before_reshape;
+
+    if(is_reshape_required)
+    {
+        const TensorInfo expected_output_shape = output->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_reduced_shape(input->tensor_shape(), axis, keep_dims));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output_shape, output);
+
+        auto shape_before_reshape = input->tensor_shape();
+        shape_before_reshape.set(axis, 1);
+
+        const auto input_num_channels = input->num_channels();
+        const auto input_qinfo        = input->quantization_info();
+        const auto is_arg_min_max     = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN);
+        const auto output_data_type   = is_arg_min_max ? DataType::U32 : output->data_type();
+
+        info_before_reshape.set_data_type(output_data_type).set_tensor_shape(shape_before_reshape).set_num_channels(input_num_channels).set_quantization_info(input_qinfo);
+
+        output_internal = &info_before_reshape;
+    }
+
+    ARM_COMPUTE_RETURN_ON_ERROR(NEReductionOperationKernel::validate(input, output_internal, axis, op));
+
+    if(is_reshape_required)
+    {
+        ARM_COMPUTE_RETURN_ON_ERROR(NEReshapeLayerKernel::validate(output_internal, output));
+    }
 
     return Status{};
 }
 
-void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op)
+void NEReductionOperation::configure(ITensor *input, ITensor *output, unsigned int axis, ReductionOperation op, bool keep_dims)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-    ARM_COMPUTE_ERROR_THROW_ON(NEReductionOperation::validate(input->info(), output->info(), axis, op));
+
+    _is_reshape_required = !keep_dims;
+
+    auto      *output_internal = output;
+    const auto is_arg_min_max  = (op == ReductionOperation::ARG_IDX_MAX) || (op == ReductionOperation::ARG_IDX_MIN);
+
+    if(_is_reshape_required)
+    {
+        const auto output_internal_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis);
+        const auto output_external_shape = arm_compute::misc::shape_calculator::compute_reduced_shape(input->info()->tensor_shape(), axis, false);
+        const auto output_data_type      = is_arg_min_max ? DataType::U32 : input->info()->data_type();
+        const auto num_channels          = input->info()->num_channels();
+        const auto qinfo                 = input->info()->quantization_info();
+
+        _output_internal.allocator()->init(input->info()->clone()->set_data_type(output_data_type).set_tensor_shape(output_internal_shape).reset_padding().set_is_resizable(true).set_num_channels(
+                                               num_channels).set_quantization_info(qinfo));
+        _memory_group.manage(&_output_internal);
+        output_internal = &_output_internal;
+        auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(output_data_type).set_tensor_shape(output_external_shape).reset_padding().set_is_resizable(true));
+    }
+
+    ARM_COMPUTE_ERROR_THROW_ON(NEReductionOperation::validate(input->info(), output->info(), axis, op, keep_dims));
 
     // Configure reduction kernel
-    _reduction_kernel.configure(input, output, axis, op);
+    _reduction_kernel.configure(input, output_internal, axis, op);
     _window_split   = reduction_window_split_dimension(axis);
     _reduction_axis = axis;
 
@@ -150,7 +204,13 @@
             default:
                 ARM_COMPUTE_ERROR("Reduction Operation unsupported");
         }
-        _fill_border_kernel.configure(input, fill_border_size, BorderMode::CONSTANT, pixelValue);
+        _fill_border_kernel.configure(input, fill_border_size, (is_arg_min_max ? BorderMode::REPLICATE : BorderMode::CONSTANT), pixelValue);
+    }
+
+    if(_is_reshape_required)
+    {
+        _reshape_kernel.configure(output_internal, output);
+        _output_internal.allocator()->allocate();
     }
 }
 
@@ -161,5 +221,9 @@
         NEScheduler::get().schedule(&_fill_border_kernel, Window::DimY);
     }
     NEScheduler::get().schedule(&_reduction_kernel, _window_split);
+    if(_is_reshape_required)
+    {
+        NEScheduler::get().schedule(&_reshape_kernel, Window::DimY);
+    }
 }
 } // namespace arm_compute
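
End to end, a keep_dims = false NEON reduction now runs the reduction
kernel into a managed internal tensor and then a reshape kernel into
the caller's output. A hedged usage sketch (shapes illustrative; the
default constructor is assumed to default memory_manager to nullptr,
per the library's usual convention):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/runtime/NEON/functions/NEReductionOperation.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        Tensor input, output;
        input.allocator()->init(TensorInfo(TensorShape(16U, 8U), 1, DataType::F32));

        NEReductionOperation reduction;
        // keep_dims = false: output auto-initialized to shape (8).
        reduction.configure(&input, &output, 0, ReductionOperation::SUM, false);

        input.allocator()->allocate();
        output.allocator()->allocate();
        reduction.run(); // reduction kernel, then reshape kernel
        return 0;
    }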