COMPMID-2685: [CL] Use Weights manager
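
Wire an optional IWeightsManager into the CL functions (CLGEMM,
CLGEMMConvolutionLayer, CLFullyConnectedLayer) and add
ITransformWeights-based managed wrappers for the weights reshape and
convert transformations, so that transformed weights can be owned by
the manager and shared across functions instead of being recomputed
and stored per function.

Illustrative wiring (a sketch only, not part of this patch; it assumes
IWeightsManager is default-constructible, as in the runtime backends):

    #include "arm_compute/runtime/CL/functions/CLGEMM.h"
    #include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h"
    #include "arm_compute/runtime/IWeightsManager.h"

    #include <memory>

    // Functions constructed with the same weights manager can share a
    // weights transformation: a transform is identified by its uid()
    // and run once, after which its output tensor is reused.
    auto wm = std::make_shared<arm_compute::IWeightsManager>();
    arm_compute::CLGEMM                 gemm(nullptr, wm.get());
    arm_compute::CLGEMMConvolutionLayer conv(nullptr, wm.get());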

Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1997
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
index 43abb67..e4e6f07 100644
--- a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
+++ b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
@@ -25,7 +25,9 @@
 #define __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__
 
 #include "arm_compute/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
 #include "arm_compute/runtime/CL/ICLSimpleFunction.h"
+#include "arm_compute/runtime/ITransformWeights.h"
 
 namespace arm_compute
 {
@@ -54,5 +56,54 @@
      */
     static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout);
 };
+
+namespace weights_transformations
+{
+/** Basic function to manage the converted weights generated from @ref CLConvertFullyConnectedWeights */
+class CLConvertFullyConnectedWeightsManaged : public ITransformWeights
+{
+public:
+    // Inherited method override
+    void run() override
+    {
+        _output.allocator()->allocate();
+        _func.run();
+        _reshape_run = true;
+    }
+
+    // Inherited method override
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    // Inherited method override
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    // Inherited method override
+    uint32_t uid() override
+    {
+        return _uid;
+    }
+    /** Configures the @ref CLConvertFullyConnectedWeights function
+     *
+     * @param[in] input                Source weights tensor to convert. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32.
+     * @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer).
+     * @param[in] data_layout          The data layout the weights have been trained in.
+     */
+    void configure(const ICLTensor *input, const TensorShape &original_input_shape, DataLayout data_layout)
+    {
+        _func.configure(input, &_output, original_input_shape, data_layout);
+    }
+
+private:
+    static constexpr uint32_t      _uid = 0x5; // Unique id of this weights transformation
+    CLTensor                       _output{};
+    CLConvertFullyConnectedWeights _func{};
+};
+} // namespace weights_transformations
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ */
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index d54304e..9512b22 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -64,6 +64,54 @@
     static Status validate(const ITensorInfo *input, const ITensorInfo *output);
 };
 
+namespace weights_transformations
+{
+/** Basic function to manage the reshaped weights generated from @ref CLFullyConnectedLayerReshapeWeights */
+class CLFullyConnectedLayerReshapeWeightsManaged : public ITransformWeights
+{
+public:
+    // Inherited method override
+    void run() override
+    {
+        _output.allocator()->allocate();
+        _func.run();
+        _reshape_run = true;
+    }
+
+    // Inherited method override
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    // Inherited method override
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    // Inherited method override
+    uint32_t uid() override
+    {
+        return _uid;
+    }
+
+    /** Configures the @ref CLFullyConnectedLayerReshapeWeights function
+     *
+     * @param[in] input Source weights tensor. Data type supported: QASYMM8/F16/F32.
+     */
+    void configure(const ICLTensor *input)
+    {
+        _func.configure(input, &_output);
+    }
+
+private:
+    static constexpr uint32_t           _uid = 0x0; // Unique id of this weights transformation
+    CLTensor                            _output{};
+    CLFullyConnectedLayerReshapeWeights _func{};
+};
+} // namespace weights_transformations
+
 /** Basic function to compute a Fully Connected layer on OpenCL. This function calls the following OpenCL kernels:
  *
  *  -# @ref CLIm2ColKernel (called when the input comes from a convolutional layer)
@@ -130,25 +178,28 @@
     void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights);
     void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights);
 
-    MemoryGroup                                         _memory_group;
-    CLConvertFullyConnectedWeights                      _convert_weights;
-    CLFlattenLayer                                      _flatten_layer;
-    CLFullyConnectedLayerReshapeWeights                 _reshape_weights_kernel;
-    CLGEMM                                              _mm_gemm;
-    CLGEMMLowpMatrixMultiplyCore                        _mm_gemmlowp;
-    CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
-    CLGEMMMatrixAccumulateBiasesKernel                  _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
-    CLTensor                                            _flatten_output;
-    CLTensor                                            _gemmlowp_output;
-    CLTensor                                            _converted_weights_output;
-    CLTensor                                            _reshape_weights_output;
-    bool                                                _are_weights_converted;
-    bool                                                _are_weights_reshaped;
-    bool                                                _is_fc_after_conv;
-    bool                                                _accumulate_biases;
-    bool                                                _is_quantized;
-    bool                                                _is_prepared;
-    const ICLTensor                                    *_original_weights;
+    MemoryGroup                                                         _memory_group;
+    IWeightsManager                                                    *_weights_manager;
+    CLConvertFullyConnectedWeights                                      _convert_weights;
+    weights_transformations::CLConvertFullyConnectedWeightsManaged      _convert_weights_managed;
+    weights_transformations::CLFullyConnectedLayerReshapeWeightsManaged _reshape_weights_managed_function;
+    CLFlattenLayer                                                      _flatten_layer;
+    CLFullyConnectedLayerReshapeWeights                                 _reshape_weights_function;
+    CLGEMM                                                              _mm_gemm;
+    CLGEMMLowpMatrixMultiplyCore                                        _mm_gemmlowp;
+    CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint                 _gemmlowp_output_stage;
+    CLGEMMMatrixAccumulateBiasesKernel                                  _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
+    CLTensor                                                            _flatten_output;
+    CLTensor                                                            _gemmlowp_output;
+    CLTensor                                                            _converted_weights_output;
+    CLTensor                                                            _reshape_weights_output;
+    bool                                                                _are_weights_converted;
+    bool                                                                _are_weights_reshaped;
+    bool                                                                _is_fc_after_conv;
+    bool                                                                _accumulate_biases;
+    bool                                                                _is_quantized;
+    bool                                                                _is_prepared;
+    const ICLTensor                                                    *_original_weights;
 };
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLFULLYCONNECTEDLAYER_H__ */
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index b8e5fa6..3691fe9 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -32,12 +32,63 @@
+#include "arm_compute/runtime/CL/CLScheduler.h"
 #include "arm_compute/runtime/CL/CLTensor.h"
 #include "arm_compute/runtime/IFunction.h"
 #include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 
 namespace arm_compute
 {
 class ICLTensor;
 
+namespace weights_transformations
+{
+/** Basic function to manage the reshaped weights generated from @ref CLGEMMReshapeRHSMatrixKernel */
+class CLGEMMReshapeRHSMatrixKernelManaged : public ITransformWeights
+{
+public:
+    // Inherited method override
+    void run() override
+    {
+        _output.allocator()->allocate();
+        CLScheduler::get().enqueue(_kernel, false);
+        _reshape_run = true;
+    }
+
+    // Inherited method override
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    // Inherited method override
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    // Inherited method override
+    uint32_t uid() override
+    {
+        return _uid;
+    }
+
+    /** Configures the @ref CLGEMMReshapeRHSMatrixKernel kernel
+     *
+     * @param[in] input Input tensor. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+     * @param[in] info  RHS matrix information to be used for reshaping.
+     */
+    void configure(const ICLTensor *input, GEMMRHSMatrixInfo info)
+    {
+        _kernel.configure(input, &_output, info);
+    }
+
+private:
+    static constexpr uint32_t    _uid = 0x15; // Unique id of this weights transformation
+    CLTensor                     _output{};
+    CLGEMMReshapeRHSMatrixKernel _kernel{};
+};
+} // namespace weights_transformations
+
 /** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
  *
  *  -# @ref CLGEMMReshapeLHSMatrixKernel (only if the RESHAPED_V1 is selected by the heuristic model)
@@ -52,9 +103,10 @@
 public:
     /** Default constructor.
      *
-     * @param[in] memory_manager (Optional) Memory manager.
+     * @param[in] memory_manager  (Optional) Memory manager.
+     * @param[in] weights_manager (Optional) Weights manager.
      */
-    CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+    CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
     /** Prevent instances of this class from being copied (As this class contains pointers) */
     CLGEMM(const CLGEMM &) = delete;
     /** Default move constructor */
@@ -123,18 +175,20 @@
     static Status validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
     static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
 
-    MemoryGroup                               _memory_group;
-    CLGEMMMatrixMultiplyKernel                _mm_kernel;
-    CLGEMMReshapeLHSMatrixKernel              _reshape_lhs_kernel;
-    CLGEMMReshapeRHSMatrixKernel              _reshape_rhs_kernel;
-    CLGEMMMatrixMultiplyReshapedKernel        _mm_reshaped_kernel;
-    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel;
-    CLTensor                                  _tmp_a;
-    CLTensor                                  _tmp_b;
-    const ICLTensor                          *_original_b;
-    bool                                      _reshape_b_only_on_first_run;
-    bool                                      _is_prepared;
-    GEMMType                                  _gemm_type;
+    MemoryGroup                                                  _memory_group;
+    IWeightsManager                                             *_weights_manager;
+    CLGEMMMatrixMultiplyKernel                                   _mm_kernel;
+    CLGEMMReshapeLHSMatrixKernel                                 _reshape_lhs_kernel;
+    CLGEMMReshapeRHSMatrixKernel                                 _reshape_rhs_kernel;
+    weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged _reshape_rhs_kernel_managed;
+    CLGEMMMatrixMultiplyReshapedKernel                           _mm_reshaped_kernel;
+    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel                    _mm_reshaped_only_rhs_kernel;
+    CLTensor                                                     _tmp_a;
+    CLTensor                                                     _tmp_b;
+    const ICLTensor                                             *_original_b;
+    bool                                                         _reshape_b_only_on_first_run;
+    bool                                                         _is_prepared;
+    GEMMType                                                     _gemm_type;
 };
 } // namespace arm_compute
 
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
index 0b27c82..017bf78 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
@@ -39,6 +39,8 @@
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
 #include "arm_compute/runtime/CL/functions/CLReshapeLayer.h"
 #include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/ITransformWeights.h"
+#include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 
 #include <memory>
@@ -82,6 +84,59 @@
     CLWeightsReshapeKernel _weights_reshape_kernel;
 };
 
+namespace weights_transformations
+{
+/** Basic function to manage the reshaped weights generated from @ref CLConvolutionLayerReshapeWeights */
+class CLConvolutionLayerReshapeWeightsTransform : public ITransformWeights
+{
+public:
+    /** Configures the @ref CLConvolutionLayerReshapeWeights function
+     *
+     * @param[in] input      Weights tensor to reshape. Data type supported: QASYMM8/F16/F32.
+     * @param[in] biases     Biases tensor. Can be nullptr. Data type supported: Same as @p input.
+     * @param[in] num_groups Number of groups when performing a grouped convolution.
+     */
+    void configure(const ICLTensor *input, const ICLTensor *biases, unsigned int num_groups)
+    {
+        _bias_bit   = (biases != nullptr) ? 1 : 0;
+        _num_groups = num_groups;
+        _func.configure(input, biases, &_output, num_groups);
+    }
+
+    // Inherited method override
+    void run() override
+    {
+        _output.allocator()->allocate();
+        _func.run();
+        _reshape_run = true;
+    }
+
+    // Inherited method override
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    // Inherited method override
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    // Inherited method override: pack the base id with the bias-presence bit and the group count so each configuration gets a unique uid
+    uint32_t uid() override
+    {
+        return ((0x9) | (_bias_bit << 7) | (_num_groups << 8));
+    }
+
+private:
+    CLTensor                         _output{};
+    CLConvolutionLayerReshapeWeights _func{};
+    int32_t                          _bias_bit{ 0 };
+    unsigned int                     _num_groups{ 0 };
+};
+} // namespace weights_transformations
+
 /** Basic function to compute the convolution layer. This function calls the following OpenCL kernels/functions:
  *
  * -# @ref CLIm2ColKernel
@@ -96,9 +151,10 @@
 public:
     /** Constructor
      *
-     * @param[in] memory_manager (Optional) Memory manager.
+     * @param[in] memory_manager  (Optional) Memory manager.
+     * @param[in] weights_manager (Optional) Weights manager.
      */
-    CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+    CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
     /** Prevent instances of this class from being copied (As this class contains pointers) */
     CLGEMMConvolutionLayer(const CLGEMMConvolutionLayer &) = delete;
     /** Default move constructor */
@@ -186,13 +242,15 @@
                               int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info);
 
 private:
-    MemoryGroup                      _memory_group;
-    CLConvolutionLayerReshapeWeights _reshape_weights;
-    CLIm2ColKernel                   _im2col_kernel;
-    CLGEMM                           _mm_gemm;
-    CLGEMMLowpMatrixMultiplyCore     _mm_gemmlowp;
-    CLCol2ImKernel                   _col2im_kernel;
-    CLActivationLayer                _activationlayer_function;
+    MemoryGroup                                                        _memory_group;
+    IWeightsManager                                                   *_weights_manager;
+    CLConvolutionLayerReshapeWeights                                   _reshape_weights;
+    weights_transformations::CLConvolutionLayerReshapeWeightsTransform _reshape_weights_managed;
+    CLIm2ColKernel                                                     _im2col_kernel;
+    CLGEMM                                                             _mm_gemm;
+    CLGEMMLowpMatrixMultiplyCore                                       _mm_gemmlowp;
+    CLCol2ImKernel                                                     _col2im_kernel;
+    CLActivationLayer                                                  _activationlayer_function;
 
     const ICLTensor *_original_weights;