COMPMID-2685: [CL] Use weights manager in CLGEMMConvolutionLayer

Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1997
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 594c8ee..831f108 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -27,6 +27,7 @@
 #include "arm_compute/core/Size2D.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/Cast.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
@@ -35,8 +36,10 @@
 #include <memory>
 #include <tuple>
 
-using namespace arm_compute;
+namespace arm_compute
+{
 using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::utils::cast;
 
 CLConvolutionLayerReshapeWeights::CLConvolutionLayerReshapeWeights()
     : _weights_reshape_kernel()
@@ -90,9 +93,14 @@
     CLScheduler::get().enqueue(_weights_reshape_kernel);
 }
 
-CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(),
-      _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
+// memory_manager backs the internal memory group and both GEMM functions.
+// weights_manager may be null (it is null-checked before every use); it is forwarded to the
+// non-quantized GEMM (_mm_gemm) but not to the quantized one (_mm_gemmlowp).
+CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _memory_group(memory_manager), _weights_manager(weights_manager),
+      _reshape_weights(), _reshape_weights_managed(), _im2col_kernel(), _mm_gemm(memory_manager, weights_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(),
+      _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(),
+      _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
 {
 }
 
@@ -238,6 +242,7 @@
     const ICLTensor *biases_to_use = biases;
     bool             append_bias   = false;
 
+    ICLTensor *weights_to_use = &_weights_reshaped;
     if(num_groups != 1 && biases != nullptr)
     {
         // num_groups != 1 can only be for NCHW
@@ -245,11 +250,27 @@
         biases_to_use = nullptr;
         append_bias   = true;
 
-        _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed.configure(weights, biases, num_groups);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed));
+        }
+        else
+        {
+            _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+        }
     }
     else
     {
-        _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed.configure(weights, nullptr, num_groups);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed));
+        }
+        else
+        {
+            _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+        }
     }
 
     // Create tensor to store im2col reshaped inputs
@@ -340,7 +361,7 @@
     // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
     const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
 
-    configure_mm(gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
+    configure_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
 
     if(!_skip_im2col)
     {
@@ -601,10 +622,18 @@
 {
     if(!_is_prepared)
     {
-        // Run weights reshaping and mark original weights tensor as unused
-        _weights_reshaped.allocator()->allocate();
-        _reshape_weights.run();
-        _original_weights->mark_as_unused();
+        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+        {
+            _weights_manager->run(_original_weights, &_reshape_weights_managed);
+        }
+        else
+        {
+            // Run weights reshaping and mark original weights tensor as unused
+            _weights_reshaped.allocator()->allocate();
+            _reshape_weights.run();
+            _original_weights->mark_as_unused();
+        }
 
         // Prepare GEMM
         _is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
@@ -617,3 +646,4 @@
         _is_prepared = true;
     }
 }
+} // namespace arm_compute