diff --git a/arm_compute/runtime/CL/functions/CLArgMinMaxLayer.h b/arm_compute/runtime/CL/functions/CLArgMinMaxLayer.h
index 2384ebc..28feee0 100644
--- a/arm_compute/runtime/CL/functions/CLArgMinMaxLayer.h
+++ b/arm_compute/runtime/CL/functions/CLArgMinMaxLayer.h
@@ -24,13 +24,16 @@
 #ifndef __ARM_COMPUTE_CLARGMINMAXLAYER_H__
 #define __ARM_COMPUTE_CLARGMINMAXLAYER_H__
 
-#include "arm_compute/core/CL/kernels/CLReductionOperationKernel.h"
 #include "arm_compute/core/Types.h"
-#include "arm_compute/runtime/CL/ICLSimpleFunction.h"
+#include "arm_compute/runtime/IFunction.h"
+#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/MemoryGroup.h"
 
 namespace arm_compute
 {
+class ITensorInfo;
 class ICLTensor;
+class CLReductionOperation;
 
 /** Function to calculate the index of the minimum or maximum values in a
  *  tensor based on an axis.
@@ -39,17 +42,23 @@
  *       responsibility to check that the results do not overflow in case the
  *       output data type is set to signed 32-bit integer (S32).
  */
-class CLArgMinMaxLayer : public ICLSimpleFunction
+class CLArgMinMaxLayer : public IFunction
 {
 public:
+    /** Default Constructor.
+     *
+     * @param[in] memory_manager (Optional) Memory manager.
+     */
+    CLArgMinMaxLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
     /** Set the input and output tensors.
      *
-     * @param[in]  input  Input source tensor. Data types supported: F16/F32.
+     * @param[in]  input  Input source tensor. It may be written to if @ref CLReductionOperation
+     *                    manipulates its border for better performance. Data types supported: F16/F32.
      * @param[in]  axis   Axis to find max/min index.
      * @param[out] output Output source tensor. Data types supported: U32/S32.
      * @param[in]  op     Operation to perform: min or max
      */
-    void configure(const ICLTensor *input, int axis, ICLTensor *output, const ReductionOperation &op);
+    void configure(ICLTensor *input, int axis, ICLTensor *output, const ReductionOperation &op);
     /** Static function to check if given info will lead to a valid configuration of @ref CLArgMinMaxLayer
      *
      * @param[in] input  Input source tensor info. Data types supported: F16/F32.
@@ -60,6 +69,12 @@
      * @return a status
      */
     static Status validate(const ITensorInfo *input, int axis, const ITensorInfo *output, const ReductionOperation &op);
+
+    // Inherited methods overridden:
+    void run() override;
+
+private:
+    std::unique_ptr<CLReductionOperation> _reduction_function;
 };
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLARGMINMAXLAYER_H__ */
diff --git a/arm_compute/runtime/CL/functions/CLReductionOperation.h b/arm_compute/runtime/CL/functions/CLReductionOperation.h
index f71313f..405e117 100644
--- a/arm_compute/runtime/CL/functions/CLReductionOperation.h
+++ b/arm_compute/runtime/CL/functions/CLReductionOperation.h
@@ -26,6 +26,7 @@
 
 #include "arm_compute/core/CL/kernels/CLFillBorderKernel.h"
 #include "arm_compute/core/CL/kernels/CLReductionOperationKernel.h"
+#include "arm_compute/core/CL/kernels/CLReshapeLayerKernel.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/runtime/CL/CLTensor.h"
 #include "arm_compute/runtime/IFunction.h"
@@ -53,35 +54,42 @@
 
     /** Set the input and output tensors.
      *
-     * @param[in]  input  Source tensor. Data types supported: QASYMM8/F16/F32.
-     * @param[out] output Destination tensor. Data types and data layouts supported: Same as @p input.
-     * @param[in]  axis   Axis along which to reduce. Supported reduction axis : 0, 1, 2, 3
-     * @param[in]  op     Reduction operation to perform.
+     * @param[in]  input     Source tensor. Data types supported: QASYMM8/F16/F32.
+     * @param[out] output    Destination tensor. Data types and data layouts supported: Same as @p input.
+     * @param[in]  axis      Axis along which to reduce. Supported reduction axis : 0, 1, 2, 3
+     * @param[in]  op        Reduction operation to perform.
+     * @param[in]  keep_dims (Optional) Whether to keep the reduced dimension after the operation. Defaults to true.
      */
-    void configure(ICLTensor *input, ICLTensor *output, unsigned int axis, ReductionOperation op);
+    void configure(ICLTensor *input, ICLTensor *output, unsigned int axis, ReductionOperation op, bool keep_dims = true);
 
     /** Static function to check if given info will lead to a valid configuration of @ref CLReductionOperation.
      *
-     * @param[in] input  Source tensor info. Data types supported: QASYMM8/F16/F32.
-     * @param[in] output Destination tensor info. Data types and data layouts supported: Same as @p input.
-     * @param[in] axis   Axis along which to reduce. Supported reduction axis : 0, 1, 2, 3
-     * @param[in] op     Reduction operation to perform.
+     * @param[in] input     Source tensor info. Data types supported: QASYMM8/F16/F32.
+     * @param[in] output    Destination tensor info. Data types and data layouts supported: Same as @p input.
+     * @param[in] axis      Axis along which to reduce. Supported reduction axis : 0, 1, 2, 3
+     * @param[in] op        Reduction operation to perform.
+     * @param[in] keep_dims (Optional) Whether to keep the reduced dimension after the operation. Defaults to true.
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op);
+    static Status validate(const ITensorInfo *input, const ITensorInfo *output, unsigned int axis, ReductionOperation op, bool keep_dims = true);
 
     // Inherited methods overridden:
     void run() override;
 
 private:
+    ICLTensor *configure_intermediate_result_vector(ICLTensor *input, ICLTensor *output);
+
     MemoryGroup                             _memory_group;
     std::vector<CLTensor>                   _results_vector;
     std::vector<CLReductionOperationKernel> _reduction_kernels_vector;
     std::vector<CLFillBorderKernel>         _border_handlers_vector;
+    CLReshapeLayerKernel                    _reshape_kernel;
+    ReductionOperation                      _op;
     unsigned int                            _num_of_stages;
     unsigned int                            _reduction_axis;
     bool                                    _is_serial;
+    bool                                    _is_reshape_required;
 };
 } // namespace arm_compute
 #endif /*__ARM_COMPUTE_CLREDUCTIONOPERATION_H__ */
