COMPMID-2685: [CL] Use Weights manager

Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1997
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index b8e5fa6..3691fe9 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -32,12 +32,62 @@
 #include "arm_compute/runtime/CL/CLTensor.h"
 #include "arm_compute/runtime/IFunction.h"
 #include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 
 namespace arm_compute
 {
 class ICLTensor;
 
+namespace weights_transformations
+{
+/** Weights transformation that owns a @ref CLGEMMReshapeRHSMatrixKernel and the tensor holding its reshaped output */
+class CLGEMMReshapeRHSMatrixKernelManaged : public ITransformWeights
+{
+public:
+    // Inherited method override: allocates the output tensor, enqueues the reshape kernel and sets _reshape_run
+    void run() override
+    {
+        _output.allocator()->allocate();
+        CLScheduler::get().enqueue(_kernel, false); // NOTE(review): second argument presumably disables the queue flush - confirm
+        _reshape_run = true; // not declared in this class; presumably inherited from ITransformWeights - confirm
+    }
+
+    // Inherited method override: calls free() on the allocator of the output tensor
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    // Inherited method override: returns a pointer to the output tensor owned by this object
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    // Inherited method override: returns the compile-time identifier of this transformation (_uid)
+    uint32_t uid() override
+    {
+        return _uid;
+    }
+
+    /** Configures the internal @ref CLGEMMReshapeRHSMatrixKernel to write its result into the internal output tensor
+     *
+     * @param[in] input Input tensor. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+     * @param[in] info  RHS matrix information to be used for reshaping.
+     */
+    void configure(const ICLTensor *input, GEMMRHSMatrixInfo info)
+    {
+        _kernel.configure(input, &_output, info);
+    }
+
+private:
+    static constexpr uint32_t    _uid = 0x15; // value returned by uid()
+    CLTensor                     _output{};   // destination of the reshape; allocated in run(), freed in release()
+    CLGEMMReshapeRHSMatrixKernel _kernel{};   // kernel set up in configure() and enqueued in run()
+};
+} // namespace weights_transformations
+
 /** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
  *
  *  -# @ref CLGEMMReshapeLHSMatrixKernel (only if the RESHAPED_V1 is selected by the heuristic model)
@@ -52,9 +102,10 @@
 public:
     /** Default constructor.
      *
-     * @param[in] memory_manager (Optional) Memory manager.
+     * @param[in] memory_manager  (Optional) Memory manager.
+     * @param[in] weights_manager (Optional) Weights manager.
      */
-    CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+    CLGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
     /** Prevent instances of this class from being copied (As this class contains pointers) */
     CLGEMM(const CLGEMM &) = delete;
     /** Default move constructor */
@@ -123,18 +174,20 @@
     static Status validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
     static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
 
-    MemoryGroup                               _memory_group;
-    CLGEMMMatrixMultiplyKernel                _mm_kernel;
-    CLGEMMReshapeLHSMatrixKernel              _reshape_lhs_kernel;
-    CLGEMMReshapeRHSMatrixKernel              _reshape_rhs_kernel;
-    CLGEMMMatrixMultiplyReshapedKernel        _mm_reshaped_kernel;
-    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel;
-    CLTensor                                  _tmp_a;
-    CLTensor                                  _tmp_b;
-    const ICLTensor                          *_original_b;
-    bool                                      _reshape_b_only_on_first_run;
-    bool                                      _is_prepared;
-    GEMMType                                  _gemm_type;
+    MemoryGroup                                                  _memory_group;
+    IWeightsManager                                             *_weights_manager;
+    CLGEMMMatrixMultiplyKernel                                   _mm_kernel;
+    CLGEMMReshapeLHSMatrixKernel                                 _reshape_lhs_kernel;
+    CLGEMMReshapeRHSMatrixKernel                                 _reshape_rhs_kernel;
+    weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged _reshape_rhs_kernel_managed;
+    CLGEMMMatrixMultiplyReshapedKernel                           _mm_reshaped_kernel;
+    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel                    _mm_reshaped_only_rhs_kernel;
+    CLTensor                                                     _tmp_a;
+    CLTensor                                                     _tmp_b;
+    const ICLTensor                                             *_original_b;
+    bool                                                         _reshape_b_only_on_first_run;
+    bool                                                         _is_prepared;
+    GEMMType                                                     _gemm_type;
 };
 } // namespace arm_compute