COMPMID-2685: [CL] Use Weights manager

Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1997
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
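
This patch threads an IWeightsManager through CLFullyConnectedLayer, CLGEMM and
CLGEMMConvolutionLayer so that reshaped/converted weights can be registered with
manage(), acquired once, and reused across runs instead of being recomputed by the
function-local kernels. Below is a minimal caller-side sketch, not part of this patch,
of how the new constructor parameter might be wired up. It assumes IWeightsManager is
default-constructible (as the graph backends construct it) and uses a hypothetical
helper name purely for illustration:

    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"
    #include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h"
    #include "arm_compute/runtime/IWeightsManager.h"

    #include <memory>

    // Hypothetical helper, for illustration only.
    void wire_up_weights_manager()
    {
        arm_compute::CLScheduler::get().default_init();

        // One weights manager owned by the caller; the functions keep only a raw pointer.
        // Assumption: IWeightsManager can be default-constructed, as in the graph backends.
        auto weights_manager = std::make_shared<arm_compute::IWeightsManager>();

        // The updated constructors take the weights manager alongside the memory manager
        // (nullptr here), so weights registered via manage() can be acquired and shared.
        arm_compute::CLGEMMConvolutionLayer conv(nullptr, weights_manager.get());
        arm_compute::CLFullyConnectedLayer  fc(nullptr, weights_manager.get());

        // configure()/run() proceed as before; when the weights are managed, prepare()
        // routes the reshape through IWeightsManager::run() instead of enqueueing the
        // function-local reshape kernels, as shown in the hunks below.
    }
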
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 0452a23..91f722f 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -25,6 +25,7 @@
 
 #include "arm_compute/core/Size2D.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/Cast.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
@@ -32,8 +33,10 @@
 
 #include <algorithm>
 
-using namespace arm_compute;
+namespace arm_compute
+{
 using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::utils::cast;
 
 namespace
 {
@@ -77,9 +80,10 @@
 }
 
 CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
-    : _memory_group(memory_manager), _convert_weights(), _flatten_layer(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
-      _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
-      _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _convert_weights(), _convert_weights_managed(), _reshape_weights_managed_function(), _flatten_layer(), _reshape_weights_function(),
+      _mm_gemm(memory_manager, weights_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(),
+      _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false),
+      _original_weights(nullptr)
 {
 }
 void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights)
@@ -157,6 +161,11 @@
     _is_prepared           = fc_info.retain_internal_weights;
     _original_weights      = weights;
 
+    if(_weights_manager)
+    {
+        _weights_manager->manage(weights);
+    }
+
     // Configure gemmlowp output
     if(_is_quantized)
     {
@@ -199,21 +208,39 @@
     // Reshape weights if needed
     if(!_are_weights_reshaped)
     {
-        // Reshape the weights
-        _reshape_weights_kernel.configure(weights, &_reshape_weights_output);
-        weights_to_use = &_reshape_weights_output;
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed_function.configure(weights);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed_function));
+        }
+        else
+        {
+            // Reshape the weights
+            _reshape_weights_function.configure(weights, &_reshape_weights_output);
+            weights_to_use = &_reshape_weights_output;
+        }
     }
 
     // Convert weights if needed
     if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
     {
-        // Convert weights
-        _convert_weights.configure(weights_to_use,
-                                   &_converted_weights_output,
-                                   input->info()->tensor_shape(),
-                                   fc_info.weights_trained_layout);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights_to_use))
+        {
+            _convert_weights_managed.configure(weights_to_use,
+                                               input->info()->tensor_shape(),
+                                               fc_info.weights_trained_layout);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_convert_weights_managed));
+        }
+        else
+        {
+            // Convert weights
+            _convert_weights.configure(weights_to_use,
+                                       &_converted_weights_output,
+                                       input->info()->tensor_shape(),
+                                       fc_info.weights_trained_layout);
 
-        weights_to_use         = &_converted_weights_output;
+            weights_to_use = &_converted_weights_output;
+        }
         _are_weights_converted = false;
     }
 
@@ -384,7 +411,10 @@
 {
     if(!_is_prepared)
     {
-        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        if(!_weights_manager)
+        {
+            ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        }
 
         auto release_unused = [](CLTensor * w)
         {
@@ -401,22 +431,36 @@
         // Reshape of the weights if needed (happens only once)
         if(!_are_weights_reshaped)
         {
-            // Run reshape weights kernel and mark weights as unused
-            _reshape_weights_output.allocator()->allocate();
-            _reshape_weights_kernel.run();
+            if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+            {
+                cur_weights = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->run(cur_weights, &_reshape_weights_managed_function));
+            }
+            else
+            {
+                // Run reshape weights kernel and mark weights as unused
+                _reshape_weights_output.allocator()->allocate();
+                _reshape_weights_function.run();
 
-            cur_weights->mark_as_unused();
-            cur_weights           = &_reshape_weights_output;
+                cur_weights->mark_as_unused();
+                cur_weights = &_reshape_weights_output;
+            }
             _are_weights_reshaped = true;
         }
 
         // Convert weights if needed (happens only once)
         if(!_are_weights_converted)
         {
-            _converted_weights_output.allocator()->allocate();
-            _convert_weights.run();
+            if(_weights_manager && _weights_manager->are_weights_managed(cur_weights))
+            {
+                _weights_manager->run(cur_weights, &_convert_weights_managed);
+            }
+            else
+            {
+                _converted_weights_output.allocator()->allocate();
+                _convert_weights.run();
+                cur_weights->mark_as_unused();
+            }
 
-            cur_weights->mark_as_unused();
             _are_weights_converted = true;
         }
 
@@ -436,3 +480,4 @@
         _is_prepared = true;
     }
 }
+} // namespace arm_compute
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 762b001..2a027d8 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -36,6 +36,7 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/helpers/float_ops.h"
+#include "arm_compute/core/utils/misc/Cast.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "arm_compute/runtime/ITensorAllocator.h"
@@ -44,12 +45,15 @@
 {
 using namespace arm_compute::misc::shape_calculator;
 using namespace arm_compute::cl_gemm;
+using namespace arm_compute::utils::cast;
 
-CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager)
+CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(std::move(memory_manager)),
+      _weights_manager(weights_manager),
       _mm_kernel(),
       _reshape_lhs_kernel(),
       _reshape_rhs_kernel(),
+      _reshape_rhs_kernel_managed(),
       _mm_reshaped_kernel(),
       _mm_reshaped_only_rhs_kernel(),
       _tmp_a(),
@@ -178,8 +182,12 @@
 
     GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
 
+    const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
+    // Manage intermediate buffers
     _memory_group.manage(&_tmp_a);
-    if(!_reshape_b_only_on_first_run)
+
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _memory_group.manage(&_tmp_b);
     }
@@ -188,16 +196,26 @@
     _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
 
     // Configure transpose kernel
-    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    ICLTensor *reshaped_rhs = &_tmp_b;
+    if(_weights_manager && _weights_manager->are_weights_managed(b))
+    {
+        _reshape_rhs_kernel_managed.configure(b, rhs_info);
+        reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+    }
+    else
+    {
+        _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    }
 
     // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
+    _mm_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
 
     CLScheduler::get().tune_kernel_static(_mm_kernel);
 
     // Allocate intermediate tensors
     _tmp_a.allocator()->allocate();
-    if(!_reshape_b_only_on_first_run)
+
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _tmp_b.allocator()->allocate();
     }
@@ -228,12 +246,16 @@
     _reshape_lhs_kernel.set_target(gpu_target);
     _mm_kernel.set_target(gpu_target);
 
+    const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
     // Manage intermediate buffers
     _memory_group.manage(&_tmp_a);
-    if(!_reshape_b_only_on_first_run)
+
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _memory_group.manage(&_tmp_b);
     }
+
     // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
 
     GEMMLHSMatrixInfo lhs_info{};
@@ -247,14 +269,25 @@
     std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
 
     _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
-    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+
+    ICLTensor *reshaped_rhs = &_tmp_b;
+    if(_weights_manager && _weights_manager->are_weights_managed(b))
+    {
+        _reshape_rhs_kernel_managed.configure(b, rhs_info);
+        reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+    }
+    else
+    {
+        _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    }
 
     // Configure and tune matrix multiply kernel
-    _mm_reshaped_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+    _mm_reshaped_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
 
     // Allocate intermediate tensors
     _tmp_a.allocator()->allocate();
-    if(!_reshape_b_only_on_first_run)
+
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _tmp_b.allocator()->allocate();
     }
@@ -284,8 +317,10 @@
     // Set the target for the kernels
     _mm_kernel.set_target(gpu_target);
 
+    const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
+
     // Manage intermediate buffers
-    if(!_reshape_b_only_on_first_run)
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _memory_group.manage(&_tmp_b);
     }
@@ -300,12 +335,21 @@
     // Configure lhs_info and rhs_info
     std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
 
-    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    ICLTensor *reshaped_rhs = &_tmp_b;
+    if(_weights_manager && _weights_manager->are_weights_managed(b))
+    {
+        _reshape_rhs_kernel_managed.configure(b, rhs_info);
+        reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
+    }
+    else
+    {
+        _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+    }
 
     // Configure and tune matrix multiply kernel
-    _mm_reshaped_only_rhs_kernel.configure(a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+    _mm_reshaped_only_rhs_kernel.configure(a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
 
-    if(!_reshape_b_only_on_first_run)
+    if(!_reshape_b_only_on_first_run && use_mm_b)
     {
         _tmp_b.allocator()->allocate();
     }
@@ -607,7 +651,14 @@
             if(!_reshape_b_only_on_first_run)
             {
                 // Run transpose kernel
-                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+                {
+                    _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+                }
+                else
+                {
+                    CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                }
             }
 
             CLScheduler::get().enqueue(_mm_kernel, true);
@@ -621,7 +672,14 @@
             if(!_reshape_b_only_on_first_run)
             {
                 // Run transpose kernel
-                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+                {
+                    _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+                }
+                else
+                {
+                    CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                }
             }
 
             CLScheduler::get().enqueue(_mm_reshaped_kernel, true);
@@ -632,7 +690,14 @@
             if(!_reshape_b_only_on_first_run)
             {
                 // Run transpose kernel
-                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+                {
+                    _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+                }
+                else
+                {
+                    CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                }
             }
 
             CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
@@ -651,10 +716,17 @@
     {
         if(_gemm_type != GEMMType::NATIVE && _reshape_b_only_on_first_run)
         {
-            // Run transpose kernel and mark original weights tensor as unused
-            _tmp_b.allocator()->allocate();
-            CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
-            _original_b->mark_as_unused();
+            if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
+            {
+                _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
+            }
+            else
+            {
+                // Run transpose kernel and mark original weights tensor as unused
+                _tmp_b.allocator()->allocate();
+                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+                _original_b->mark_as_unused();
+            }
         }
         CLScheduler::get().queue().finish();
         _is_prepared = true;
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 594c8ee..831f108 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -27,6 +27,7 @@
 #include "arm_compute/core/Size2D.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/Cast.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
@@ -35,8 +36,10 @@
 #include <memory>
 #include <tuple>
 
-using namespace arm_compute;
+namespace arm_compute
+{
 using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::utils::cast;
 
 CLConvolutionLayerReshapeWeights::CLConvolutionLayerReshapeWeights()
     : _weights_reshape_kernel()
@@ -90,9 +93,10 @@
     CLScheduler::get().enqueue(_weights_reshape_kernel);
 }
 
-CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(),
-      _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
+CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _reshape_weights(), _reshape_weights_managed(), _im2col_kernel(), _mm_gemm(memory_manager, weights_manager),
+      _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _skip_im2col(false),
+      _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _is_prepared(false)
 {
 }
 
@@ -238,6 +242,7 @@
     const ICLTensor *biases_to_use = biases;
     bool             append_bias   = false;
 
+    ICLTensor *weights_to_use = &_weights_reshaped;
     if(num_groups != 1 && biases != nullptr)
     {
         // num_groups != 1 can only be for NCHW
@@ -245,11 +250,27 @@
         biases_to_use = nullptr;
         append_bias   = true;
 
-        _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed.configure(weights, biases, num_groups);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed));
+        }
+        else
+        {
+            _reshape_weights.configure(weights, biases, &_weights_reshaped, num_groups);
+        }
     }
     else
     {
-        _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed.configure(weights, nullptr, num_groups);
+            weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed));
+        }
+        else
+        {
+            _reshape_weights.configure(weights, nullptr, &_weights_reshaped, num_groups);
+        }
     }
 
     // Create tensor to store im2col reshaped inputs
@@ -340,7 +361,7 @@
     // In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
     const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
 
-    configure_mm(gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
+    configure_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, act_info);
 
     if(!_skip_im2col)
     {
@@ -601,10 +622,18 @@
 {
     if(!_is_prepared)
     {
-        // Run weights reshaping and mark original weights tensor as unused
-        _weights_reshaped.allocator()->allocate();
-        _reshape_weights.run();
-        _original_weights->mark_as_unused();
+        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+        {
+            _weights_manager->run(_original_weights, &_reshape_weights_managed);
+        }
+        else
+        {
+            // Run weights reshaping and mark original weights tensor as unused
+            _weights_reshaped.allocator()->allocate();
+            _reshape_weights.run();
+            _original_weights->mark_as_unused();
+        }
 
         // Prepare GEMM
         _is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
@@ -617,3 +646,4 @@
         _is_prepared = true;
     }
 }
+} // namespace arm_compute