COMPMID-2161 [NEON] Create IWeightsManager class

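Adds an IWeightsManager runtime class so that functions can share
transformed (reshaped/converted) weights instead of each keeping a
private copy. The NEON graph backend now creates a weights manager
per graph context and threads it through NEFullyConnectedLayer,
NEGEMM, NEGEMMConvolutionLayer and NEGEMMAssemblyDispatch.

Illustrative sketch of the intended flow from a function's point of
view (member names follow NEFullyConnectedLayer and
NEGEMMConvolutionLayer; this is not a complete function):

    // configure(): register the original weights and acquire the
    // tensor that will hold the transformed weights
    _weights_manager->manage(weights);
    if(_weights_manager->are_weights_managed(weights))
    {
        _reshape_weights_managed.configure(weights);
        weights_to_use = _weights_manager->acquire(weights, &_reshape_weights_managed);
    }

    // prepare(): run the transform, or reuse the already transformed
    // tensor if another function has run the same transform on these
    // weights
    _weights_manager->run(weights, &_reshape_weights_managed);
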
Change-Id: I1a9a46da2f98e896b825099151b56d1d8271dd31
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1915
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/graph/GraphContext.cpp b/src/graph/GraphContext.cpp
index 037b40b..c959d5e 100644
--- a/src/graph/GraphContext.cpp
+++ b/src/graph/GraphContext.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -32,13 +32,14 @@
 namespace graph
 {
 GraphContext::GraphContext()
-    : _config(), _memory_managers()
+    : _config(), _memory_managers(), _weights_managers()
 {
 }
 
 GraphContext::~GraphContext()
 {
     _memory_managers.clear();
+    _weights_managers.clear();
     release_default_graph_context(*this);
 }
 
@@ -74,6 +75,30 @@
     return _memory_managers;
 }
 
+bool GraphContext::insert_weights_management_ctx(WeightsManagerContext &&weights_managers)
+{
+    Target target = weights_managers.target;
+
+    if(target != Target::NEON || _weights_managers.find(target) != std::end(_weights_managers))
+    {
+        return false;
+    }
+
+    _weights_managers[target] = std::move(weights_managers);
+
+    return true;
+}
+
+WeightsManagerContext *GraphContext::weights_management_ctx(Target target)
+{
+    return (_weights_managers.find(target) != std::end(_weights_managers)) ? &_weights_managers[target] : nullptr;
+}
+
+std::map<Target, WeightsManagerContext> &GraphContext::weights_managers()
+{
+    return _weights_managers;
+}
+
 void GraphContext::finalize()
 {
     const size_t num_pools = 1;
diff --git a/src/graph/backends/CL/CLDeviceBackend.cpp b/src/graph/backends/CL/CLDeviceBackend.cpp
index 9971e4f..9b7c879 100644
--- a/src/graph/backends/CL/CLDeviceBackend.cpp
+++ b/src/graph/backends/CL/CLDeviceBackend.cpp
@@ -204,6 +204,11 @@
 
     return mm;
 }
+
+std::shared_ptr<arm_compute::IWeightsManager> CLDeviceBackend::create_weights_manager()
+{
+    return nullptr;
+}
 } // namespace backends
 } // namespace graph
 } // namespace arm_compute
diff --git a/src/graph/backends/GLES/GCDeviceBackend.cpp b/src/graph/backends/GLES/GCDeviceBackend.cpp
index 058f779..83e2436 100644
--- a/src/graph/backends/GLES/GCDeviceBackend.cpp
+++ b/src/graph/backends/GLES/GCDeviceBackend.cpp
@@ -154,6 +154,11 @@
 
     return mm;
 }
+
+std::shared_ptr<arm_compute::IWeightsManager> GCDeviceBackend::create_weights_manager()
+{
+    return nullptr;
+}
 } // namespace backends
 } // namespace graph
 } // namespace arm_compute
diff --git a/src/graph/backends/NEON/NEDeviceBackend.cpp b/src/graph/backends/NEON/NEDeviceBackend.cpp
index f94cd97..017b4f0 100644
--- a/src/graph/backends/NEON/NEDeviceBackend.cpp
+++ b/src/graph/backends/NEON/NEDeviceBackend.cpp
@@ -37,6 +37,7 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/runtime/Allocator.h"
 #include "arm_compute/runtime/BlobLifetimeManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 #include "arm_compute/runtime/MemoryManagerOnDemand.h"
 #include "arm_compute/runtime/OffsetLifetimeManager.h"
@@ -90,6 +91,16 @@
 
         ctx.insert_memory_management_ctx(std::move(mm_ctx));
     }
+
+    // Create function level weights manager
+    if(ctx.weights_management_ctx(Target::NEON) == nullptr)
+    {
+        WeightsManagerContext wm_ctx;
+        wm_ctx.target = Target::NEON;
+        wm_ctx.wm     = create_weights_manager();
+
+        ctx.insert_weights_management_ctx(std::move(wm_ctx));
+    }
 }
 
 bool NEDeviceBackend::is_backend_supported()
@@ -159,6 +170,12 @@
 
     return mm;
 }
+
+std::shared_ptr<arm_compute::IWeightsManager> NEDeviceBackend::create_weights_manager()
+{
+    auto weights_mgr = std::make_shared<IWeightsManager>();
+    return weights_mgr;
+}
 } // namespace backends
 } // namespace graph
 } // namespace arm_compute
diff --git a/src/graph/backends/NEON/NEFunctionFactory.cpp b/src/graph/backends/NEON/NEFunctionFactory.cpp
index 852de54..45e9727 100644
--- a/src/graph/backends/NEON/NEFunctionFactory.cpp
+++ b/src/graph/backends/NEON/NEFunctionFactory.cpp
@@ -115,6 +115,7 @@
     std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, Target::NEON);
     std::unique_ptr<IFunction>      func;
     std::string                     func_name;
+
     if(conv_algorithm == ConvolutionMethod::Direct)
     {
         std::tie(func, func_name) = create_named_memory_managed_function<NEDirectConvolutionLayer>(
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index c5da649..0452a23 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -76,7 +76,7 @@
     return CLTransposeKernel::validate(input, output);
 }
 
-CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
+CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(memory_manager), _convert_weights(), _flatten_layer(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
       _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
       _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
diff --git a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
index a208545..4ccda88 100644
--- a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
+++ b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
@@ -38,7 +38,7 @@
     _kernel = std::move(k);
 }
 
-GCFullyConnectedLayer::GCFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
+GCFullyConnectedLayer::GCFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_kernel(), _mm_kernel(), _accumulate_biases_kernel(), _im2col_output(), _reshape_weights_output(),
       _original_weights(nullptr), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false)
 {
diff --git a/src/runtime/IWeightsManager.cpp b/src/runtime/IWeightsManager.cpp
new file mode 100644
index 0000000..6dfb925
--- /dev/null
+++ b/src/runtime/IWeightsManager.cpp
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/IWeightsManager.h"
+
+namespace arm_compute
+{
+IWeightsManager::IWeightsManager()
+    : _managed_weights(), _managed_weights_parents()
+{
+}
+
+void IWeightsManager::manage(const ITensor *weights, ITransformWeights *parent)
+{
+    if(!are_weights_managed(weights))
+    {
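+        // operator[] default-constructs an empty transform list, which marks the weights as managed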
+        _managed_weights[weights];
+    }
+
+    // In case the weights are an output of a previous reshape function
+    // store the parent's link
+    if(parent != nullptr)
+    {
+        if(_managed_weights_parents.find(weights) == _managed_weights_parents.end())
+        {
+            _managed_weights_parents[weights] = parent;
+        }
+    }
+}
+
+ITensor *IWeightsManager::run(const ITensor *weights, ITransformWeights *weights_transform)
+{
+    ARM_COMPUTE_ERROR_ON_MSG(!are_weights_managed(weights), "Cannot run function. Weights are not managed");
+
+    // Check whether a matching transform has already been run for these weights; if so, skip the reshape and reuse its output
+    auto     item = _managed_weights.find(weights);
+    bool     perform_run{ true };
+    ITensor *weights_tensor{ nullptr };
+
+    // Look for the requested transform among those whose reshape has already been run
+    for(auto it : item->second)
+    {
+        if(it->is_reshape_run() && (it->uid() == weights_transform->uid()))
+        {
+            weights_tensor = it->get_weights();
+            perform_run    = false;
+            break;
+        }
+    }
+
+    if(perform_run)
+    {
+        weights_transform->run();
+        weights_tensor = weights_transform->get_weights();
+    }
+
+    // Check if we can release memory from parent
+    auto parent_item = _managed_weights_parents.find(weights);
+    if(parent_item != _managed_weights_parents.end())
+    {
+        int32_t refcount = parent_item->second->decrease_refcount();
+        if(refcount == 0)
+        {
+            parent_item->second->release();
+        }
+    }
+
+    return weights_tensor;
+}
+
+bool IWeightsManager::are_weights_managed(const ITensor *weights)
+{
+    return (_managed_weights.find(weights) != _managed_weights.end());
+}
+
+ITensor *IWeightsManager::acquire(const ITensor *weights, ITransformWeights *weights_transform)
+{
+    ARM_COMPUTE_ERROR_ON_MSG(!are_weights_managed(weights), "Cannot acquire weights. Weights are not managed");
+
+    ITensor *transformed_weights{ nullptr };
+    auto     item = _managed_weights.find(weights);
+
+    // Check if I already have the requested transform. If I do,
+    // increase the refcount of the transformed weights object and
+    // reuse the tensor
+    for(auto it : item->second)
+    {
+        if(it->uid() == weights_transform->uid())
+        {
+            transformed_weights = it->get_weights();
+            it->increase_refcount();
+            break;
+        }
+    }
+
+    if(transformed_weights == nullptr)
+    {
+        transformed_weights = weights_transform->get_weights();
+        weights_transform->increase_refcount();
+        item->second.emplace_back(weights_transform);
+    }
+
+    // Manage the weights and store link to the parent node
+    manage(transformed_weights, weights_transform);
+
+    return transformed_weights;
+}
+} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp b/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp
index b5b159a..f65c035 100644
--- a/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp
+++ b/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -23,8 +23,8 @@
  */
 #include "arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h"
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 NEConvertFullyConnectedWeights::NEConvertFullyConnectedWeights()
     : _kernel()
 {
@@ -46,3 +46,4 @@
 {
     NEScheduler::get().schedule(&_kernel, Window::DimZ);
 }
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp b/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
index bbb91b4..0411b41 100644
--- a/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
@@ -91,10 +91,10 @@
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(Window::DimZ) != output_shape.z(), "Output's depth is invalid.");
     }
 
-    unsigned int        deconv_pad_x = 0;
-    unsigned int        deconv_pad_y = 0;
-    const unsigned int  stride_x = info.stride().first;
-    const unsigned int  stride_y = info.stride().second;
+    unsigned int        deconv_pad_x    = 0;
+    unsigned int        deconv_pad_y    = 0;
+    const unsigned int  stride_x        = info.stride().first;
+    const unsigned int  stride_y        = info.stride().second;
     const TensorShape   scale_out_shape = compute_deconvolution_upsampled_shape(*input, *weights, stride_x, stride_y, out_dims, deconv_pad_x, deconv_pad_y);
     TensorInfo          scale_out_info(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(scale_out_shape));
     const PadStrideInfo conv_info(1, 1, 0, 0, 0, 0, DimensionRoundingType::CEIL);
@@ -127,8 +127,8 @@
     const unsigned int pad_right  = info.pad_right();
     const unsigned int pad_top    = info.pad_top();
     const unsigned int pad_bottom = info.pad_bottom();
-    const unsigned int stride_x = info.stride().first;
-    const unsigned int stride_y = info.stride().second;
+    const unsigned int stride_x   = info.stride().first;
+    const unsigned int stride_y   = info.stride().second;
 
     const unsigned int width_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
     const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
@@ -166,14 +166,14 @@
         unsigned int deconv_pad_right = pad_left > pad_right ? pad_left - pad_right : 0;
         deconv_pad_x -= deconv_pad_left + deconv_pad_right;
         ARM_COMPUTE_ERROR_ON((deconv_pad_x % 2) != 0);
-        deconv_pad_left  += deconv_pad_x / 2;
+        deconv_pad_left += deconv_pad_x / 2;
         deconv_pad_right += deconv_pad_x / 2;
 
         unsigned int deconv_pad_top    = pad_bottom > pad_top ? pad_bottom - pad_top : 0;
         unsigned int deconv_pad_bottom = pad_top > pad_bottom ? pad_top - pad_bottom : 0;
         deconv_pad_y -= deconv_pad_top + deconv_pad_bottom;
         ARM_COMPUTE_ERROR_ON((deconv_pad_y % 2) != 0);
-        deconv_pad_top    += deconv_pad_y / 2;
+        deconv_pad_top += deconv_pad_y / 2;
         deconv_pad_bottom += deconv_pad_y / 2;
 
         TensorInfo scale_out_info(scale_out_shape, 1, _permuted_input.info()->data_type(), _permuted_input.info()->quantization_info());
@@ -212,14 +212,14 @@
         unsigned int deconv_pad_right = pad_left > pad_right ? pad_left - pad_right : 0;
         deconv_pad_x -= deconv_pad_left + deconv_pad_right;
         ARM_COMPUTE_ERROR_ON((deconv_pad_x % 2) != 0);
-        deconv_pad_left  += deconv_pad_x / 2;
+        deconv_pad_left += deconv_pad_x / 2;
         deconv_pad_right += deconv_pad_x / 2;
 
         unsigned int deconv_pad_top    = pad_bottom > pad_top ? pad_bottom - pad_top : 0;
         unsigned int deconv_pad_bottom = pad_top > pad_bottom ? pad_top - pad_bottom : 0;
         deconv_pad_y -= deconv_pad_top + deconv_pad_bottom;
         ARM_COMPUTE_ERROR_ON((deconv_pad_y % 2) != 0);
-        deconv_pad_top    += deconv_pad_y / 2;
+        deconv_pad_top += deconv_pad_y / 2;
         deconv_pad_bottom += deconv_pad_y / 2;
 
         TensorInfo scale_out_info(scale_out_shape, 1, input->info()->data_type(), input->info()->quantization_info());
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 12a5a1d..7adc3bc 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -74,10 +74,11 @@
     return NETransposeKernel::validate(input, output);
 }
 
-NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _flatten_kernel(), _convert_weights(), _reshape_weights_function(), _mm_gemm(), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(),
-      _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_converted(true), _are_weights_reshaped(false),
-      _is_fc_after_conv(false), _accumulate_biases(false), _is_quantized(false), _is_prepared(false)
+NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _memory_group(std::move(memory_manager)), _weights_manager(weights_manager), _flatten_kernel(), _convert_weights(), _convert_weights_managed(), _reshape_weights_function(),
+      _reshape_weights_managed_function(), _mm_gemm(nullptr, weights_manager), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(),
+      _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_converted(true), _are_weights_reshaped(false), _is_fc_after_conv(false), _accumulate_biases(false),
+      _is_quantized(false), _is_prepared(false)
 {
 }
 
@@ -155,6 +156,11 @@
     _is_quantized          = is_data_type_quantized_asymmetric(input->info()->data_type());
     _original_weights      = weights;
 
+    if(_weights_manager)
+    {
+        _weights_manager->manage(weights);
+    }
+
     // Configure gemmlowp output
     if(_is_quantized)
     {
@@ -194,21 +200,39 @@
     // Reshape weights if needed
     if(!_are_weights_reshaped)
     {
-        // Reshape the weights
-        _reshape_weights_function.configure(weights, &_reshape_weights_output);
-        weights_to_use = &_reshape_weights_output;
+        if(_weights_manager && _weights_manager->are_weights_managed(weights))
+        {
+            _reshape_weights_managed_function.configure(weights);
+            weights_to_use = _weights_manager->acquire(weights, &_reshape_weights_managed_function);
+        }
+        else
+        {
+            // Reshape the weights
+            _reshape_weights_function.configure(weights, &_reshape_weights_output);
+            weights_to_use = &_reshape_weights_output;
+        }
     }
 
     // Convert weights if needed
     if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
     {
-        // Convert weights
-        _convert_weights.configure(weights_to_use,
-                                   &_converted_weights_output,
-                                   input->info()->tensor_shape(),
-                                   fc_info.weights_trained_layout);
+        if(_weights_manager && _weights_manager->are_weights_managed(weights_to_use))
+        {
+            _convert_weights_managed.configure(weights_to_use,
+                                               input->info()->tensor_shape(),
+                                               fc_info.weights_trained_layout);
+            weights_to_use = _weights_manager->acquire(weights, &_convert_weights_managed);
+        }
+        else
+        {
+            // Convert weights
+            _convert_weights.configure(weights_to_use,
+                                       &_converted_weights_output,
+                                       input->info()->tensor_shape(),
+                                       fc_info.weights_trained_layout);
 
-        weights_to_use         = &_converted_weights_output;
+            weights_to_use = &_converted_weights_output;
+        }
         _are_weights_converted = false;
     }
 
@@ -381,7 +405,10 @@
 {
     if(!_is_prepared)
     {
-        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        if(!_weights_manager)
+        {
+            ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+        }
 
         auto release_unused = [](Tensor * w)
         {
@@ -397,20 +424,38 @@
         // Reshape of the weights (happens only once)
         if(!_are_weights_reshaped)
         {
-            // Run reshape weights kernel and mark weights as unused
-            _reshape_weights_output.allocator()->allocate();
-            _reshape_weights_function.run();
-
-            cur_weights->mark_as_unused();
-            cur_weights           = &_reshape_weights_output;
+            if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+            {
+                cur_weights->mark_as_unused();
+                cur_weights = _weights_manager->run(cur_weights, &_reshape_weights_managed_function);
+            }
+            else
+            {
+                // Reshape of the weights (happens only once)
+                if(!_are_weights_reshaped)
+                {
+                    // Run reshape weights kernel and mark weights as unused
+                    _reshape_weights_output.allocator()->allocate();
+                    _reshape_weights_function.run();
+                }
+                cur_weights->mark_as_unused();
+                cur_weights = &_reshape_weights_output;
+            }
             _are_weights_reshaped = true;
         }
 
         // Convert weights if needed (happens only once)
         if(!_are_weights_converted)
         {
-            _converted_weights_output.allocator()->allocate();
-            _convert_weights.run();
+            if(_weights_manager && _weights_manager->are_weights_managed(cur_weights))
+            {
+                _weights_manager->run(cur_weights, &_convert_weights_managed);
+            }
+            else
+            {
+                _converted_weights_output.allocator()->allocate();
+                _convert_weights.run();
+            }
 
             cur_weights->mark_as_unused();
             _are_weights_converted = true;
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 37d0e09..df92b79 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -42,9 +42,9 @@
 
 namespace arm_compute
 {
-NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager), _ma_kernel(), _tmp_a(), _tmp_b(), _original_b(nullptr),
-      _run_vector_matrix_multiplication(false), _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
+NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager, weights_manager), _ma_kernel(), _tmp_a(),
+      _tmp_b(), _original_b(nullptr), _run_vector_matrix_multiplication(false), _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
 {
 }
 
@@ -276,13 +276,19 @@
     {
         if(_asm_glue.is_configured())
         {
-            ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+            if(!_weights_manager || !_weights_manager->are_weights_managed(_original_b))
+            {
+                ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+            }
 
             _asm_glue.prepare();
         }
         else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue.is_configured())
         {
-            ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+            if(!_weights_manager || !_weights_manager->are_weights_managed(_original_b))
+            {
+                ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+            }
 
             _tmp_b.allocator()->allocate();
             NEScheduler::get().schedule(&_transpose_kernel, Window::DimY);
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 2a4498b..956ded5 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -38,7 +38,8 @@
 std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
                                                      const ITensor *a, const ITensor *b, ITensor *d,
                                                      float alpha, float beta, const GEMMInfo &gemm_info,
-                                                     std::shared_ptr<IMemoryManager> memory_manager)
+                                                     std::shared_ptr<IMemoryManager> memory_manager,
+                                                     IWeightsManager                *weights_manager)
 
 {
     // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
@@ -50,7 +51,7 @@
             {
                 return nullptr;
             }
-            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager, weights_manager);
             function->configure(a, b, d, alpha, beta, gemm_info);
             return std::move(function);
         }
@@ -73,25 +74,95 @@
     }
 }
 
+template <typename TypeInput, typename TypeOutput>
+class FallbackTransform : public ITransformWeights
+{
+public:
+    void run() override
+    {
+        _output.allocator()->allocate();
+        ARM_COMPUTE_ERROR_ON(_output.buffer() == nullptr);
+        _gemm_kernel_asm->pretranspose_B_array(_output.buffer(), _in1_ptr, _ldb, _multi_stride_b);
+        _reshape_run = true;
+    }
+
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    ITensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    uint32_t uid() override
+    {
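+        // Derive this transform's uid from the pretranspose buffer size, with the top bit set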
+        uint32_t id = (_B_pretranspose_size | 0x80000000);
+        return id;
+    }
+
+    void configure(size_t B_pretranspose_size, unsigned int alignment)
+    {
+        _output.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+        _B_pretranspose_size = B_pretranspose_size;
+    }
+
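+    // Hand an externally pretransposed buffer to the assembly kernel; skipped if this transform has already run the reshape itself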
+    void set_pretranspose(ITensor *tensor)
+    {
+        if(!_reshape_run)
+        {
+            _gemm_kernel_asm->set_pretransposed_B_data(tensor->buffer());
+        }
+    }
+
+    void set_args(const int ldb, const TypeInput *in1_ptr, const int multi_stride_b, std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> gemm_kernel_asm)
+    {
+        _ldb             = ldb;
+        _in1_ptr         = in1_ptr;
+        _multi_stride_b  = multi_stride_b;
+        _gemm_kernel_asm = gemm_kernel_asm;
+    }
+
+private:
+    Tensor           _output{};
+    int              _ldb{};
+    const TypeInput *_in1_ptr{};
+    int              _multi_stride_b{};
+    size_t           _B_pretranspose_size{};
+    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+};
+
 /** Fallback in case ACL doesn't have a function */
 template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
 class Fallback : public NEGEMMAssemblyDispatch::IFallback
 {
 public:
+    /** Destructor */
+    ~Fallback()
+    {
+        // Release memory only if we allocated it ourselves
+        if(_pretranspose && !(_weights_manager && _weights_manager->are_weights_managed(_b)))
+        {
+            delete _pretranspose;
+        }
+    }
+
     /** Initialise the functions's input and output.
      *
-     * @param[in]  a            Input tensor containing the Matrix A.
-     * @param[in]  b            Input tensor containing the Matrix B.
-     * @param[in]  c            Input tensor containing the Matrix C.
-     * @param[out] d            Output tensor to store the result of matrix multiplication.
-     * @param[in]  args         Matrix multiplication information.
-     * @param[in]  gemm_info    GEMM meta-data
-     * @param[in]  memory_group Memory group to be used by the function.
-     * @param[in]  os           Output stage meta-data.
+     * @param[in]  a               Input tensor containing the Matrix A.
+     * @param[in]  b               Input tensor containing the Matrix B.
+     * @param[in]  c               Input tensor containing the Matrix C.
+     * @param[out] d               Output tensor to store the result of matrix multiplication.
+     * @param[in]  args            Matrix multiplication information.
+     * @param[in]  gemm_info       GEMM meta-data
+     * @param[in]  memory_group    Memory group to be used by the function.
+     * @param[in]  weights_manager Weights manager to be used by the function.
+     * @param[in]  os              Output stage meta-data.
      */
     void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
                    arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
-                   MemoryGroup &memory_group, const OutputStage &os = {});
+                   MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
 
     // Inherited methods overridden:
     void run() override;
@@ -108,7 +179,7 @@
     void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
 
     /** Assembly Gemm kernel */
-    std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
     /** Optimised NEON kernel */
     std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
     /** Input A */
@@ -130,20 +201,25 @@
     /** GEMM workspace */
     Tensor _workspace{};
     /** Pre-transpose tensor */
-    Tensor _pretranspose{};
+    ITensor *_pretranspose{ nullptr };
     /** Prepared flag */
     bool _is_prepared{ false };
     /** GEMM meta-data */
     GEMMInfo _gemm_info{};
+    /** Weights manager */
+    IWeightsManager *_weights_manager{ nullptr };
+    /** Weights transform object */
+    FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
                                                              arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
-                                                             MemoryGroup &memory_group, const OutputStage &os)
+                                                             MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
 {
     arm_gemm::GemmConfig              gemm_cfg;
     const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
+    _weights_manager                                   = weights_manager;
     if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
     {
         gemm_cfg.filter = gemm_kernel_info.name;
@@ -190,7 +266,16 @@
         // Forcing 128-byte alignment (required by 32-bit kernels)
         const unsigned int alignment           = 128;
         const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
-        _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+        if(weights_manager && _weights_manager->are_weights_managed(b))
+        {
+            _weights_transform.configure(B_pretranspose_size, alignment);
+            _pretranspose = _weights_manager->acquire(b, &_weights_transform);
+        }
+        else
+        {
+            _pretranspose = new Tensor();
+            static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+        }
     }
 }
 
@@ -208,14 +293,28 @@
         // Pretranspose B if required
         if(_gemm_kernel_asm->B_pretranspose_required())
         {
-            _pretranspose.allocator()->allocate();
-            ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
             const int  ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
             const int  multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
 
-            _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
-            _b->mark_as_unused();
+            if(_weights_manager && _weights_manager->are_weights_managed(_b))
+            {
+                _weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
+                _weights_manager->run(_b, &_weights_transform);
+
+                // If we didn't run the reshape function, set the pretransposed buffer
+                if(!_weights_transform.is_reshape_run())
+                {
+                    _weights_transform.set_pretranspose(_pretranspose);
+                }
+            }
+            else
+            {
+                static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
+                ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
+                _gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
+                _b->mark_as_unused();
+            }
         }
 
         _is_prepared = true;
@@ -294,7 +393,7 @@
 template <typename TypeInput, typename TypeOutput>
 void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
                                  const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
-                                 std::shared_ptr<IMemoryManager> memory_manager)
+                                 std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
 {
     INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
     const CPUInfo               &ci          = NEScheduler::get().cpu_info();
@@ -304,14 +403,14 @@
 
     // Try to create an ACL function:
     const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
-    acl_function                                       = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
+    acl_function                                       = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
 
     // If we still don't have an ACL function:
     if(acl_function == nullptr)
     {
         //Fallback onto arm_gemm function if ACL doesn't support this method.
         auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
-        fallback->configure(a, b, c, d, args, gemm_info, memory_group);
+        fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
         arm_gemm = std::move(fallback);
     }
 }
@@ -319,7 +418,7 @@
 template <typename TypeInput, typename TypeOutput>
 void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
                                        const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
-                                       std::shared_ptr<IMemoryManager> memory_manager)
+                                       std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
 {
     INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
     const CPUInfo               &ci          = NEScheduler::get().cpu_info();
@@ -339,22 +438,22 @@
 
     // Try to create an ACL function:
     const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args, gemm_requant_info);
-    acl_function                                       = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
+    acl_function                                       = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
 
     // If we still don't have an ACL function:
     if(acl_function == nullptr)
     {
         // Fallback onto arm_gemm function if ACL doesn't support this method.
         auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
-        fallback->configure(a, b, c, d, args, gemm_info, memory_group, gemm_requant_info);
+        fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
         arm_gemm = std::move(fallback);
     }
 }
 
 } //namespace
 
-NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
-    : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
+NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager), _weights_manager(weights_manager)
 {
 }
 
@@ -390,27 +489,27 @@
     switch(a->info()->data_type())
     {
         case DataType::F32:
-            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
             if(d->info()->data_type() == DataType::S32)
             {
-                create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+                create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
             }
             else
             {
-                create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+                create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
             }
             break;
         case DataType::S8:
-            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
             break;
 #endif /* __aarch64__ */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index e94c893..a39e4c5 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -50,7 +50,6 @@
     ARM_COMPUTE_ERROR_THROW_ON(NEConvolutionLayerReshapeWeights::validate(weights->info(),
                                                                           (biases != nullptr) ? biases->info() : nullptr,
                                                                           output->info()));
-
     const bool     append_biases = (biases != nullptr) && !is_data_type_quantized_asymmetric(weights->info()->data_type());
     const ITensor *biases_to_use = (append_biases) ? biases : nullptr;
 
@@ -89,10 +88,10 @@
     NEScheduler::get().schedule(&_weights_reshape_kernel, 3);
 }
 
-NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
-    : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(),
-      _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false),
-      _skip_col2im(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
+NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager, IWeightsManager *weights_manager)
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _reshape_weights(), _reshape_weights_managed(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager),
+      _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(),
+      _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
 {
 }
 
@@ -309,7 +308,18 @@
 
     // _weights_reshaped will be auto configured in the kernel.
     // Just append biases and do not transpose 1xW as it will be reshaped in NEGEMM
-    _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped);
+    const ITensor *weights_to_use = weights;
+
+    if(_weights_manager && _weights_manager->are_weights_managed(weights))
+    {
+        _reshape_weights_managed.configure(weights, biases_to_use);
+        weights_to_use = _weights_manager->acquire(weights, &_reshape_weights_managed);
+    }
+    else
+    {
+        _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped);
+        weights_to_use = &_weights_reshaped;
+    }
 
     // Create tensor to store im2col reshaped inputs
     if(!_skip_im2col)
@@ -351,7 +361,7 @@
     // Configure GEMM
     // In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
     const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
-    configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, gemm_3d_depth);
+    configure_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, gemm_3d_depth);
 
     if(!_skip_im2col)
     {
@@ -493,7 +503,7 @@
     ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases_to_use, nullptr));
     weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, (append_bias && !skip_im2col)), 1, data_type);
     weights_reshaped_info.set_quantization_info(weights->quantization_info());
-    weights_to_use        = &weights_reshaped_info;
+    weights_to_use = &weights_reshaped_info;
 
     if(!skip_im2col)
     {
@@ -603,10 +613,17 @@
     {
         ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
 
-        // Run weights reshaping and mark original weights tensor as unused
-        _weights_reshaped.allocator()->allocate();
-        _reshape_weights.run();
-        _original_weights->mark_as_unused();
+        if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+        {
+            _weights_manager->run(_original_weights, &_reshape_weights_managed);
+        }
+        else
+        {
+            // Run weights reshaping and mark original weights tensor as unused
+            _weights_reshaped.allocator()->allocate();
+            _reshape_weights.run();
+            _original_weights->mark_as_unused();
+        }
 
         // Prepare GEMM
         _is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
diff --git a/src/runtime/NEON/functions/NERNNLayer.cpp b/src/runtime/NEON/functions/NERNNLayer.cpp
index 9ca7ded..67f4064 100644
--- a/src/runtime/NEON/functions/NERNNLayer.cpp
+++ b/src/runtime/NEON/functions/NERNNLayer.cpp
@@ -34,8 +34,8 @@
 namespace arm_compute
 {
 NERNNLayer::NERNNLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected_kernel(), _copy_kernel(), _fully_connected_out(), _gemm_output(), _add_output(),
-      _is_prepared(false)
+    : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected(memory_manager), _copy_kernel(), _fully_connected_out(), _gemm_output(),
+      _add_output(), _is_prepared(false)
 {
 }
 
@@ -81,7 +81,7 @@
 
     // Manage intermediate buffers and configure
     _memory_group.manage(&_fully_connected_out);
-    _fully_connected_kernel.configure(input, weights, bias, &_fully_connected_out);
+    _fully_connected.configure(input, weights, bias, &_fully_connected_out);
 
     _memory_group.manage(&_gemm_output);
     _gemm_state_f.configure(hidden_state, recurrent_weights, nullptr, &_gemm_output, 1.f, 0.f);
@@ -106,7 +106,7 @@
 
     MemoryGroupResourceScope scope_mg(_memory_group);
 
-    _fully_connected_kernel.run();
+    _fully_connected.run();
 
     _gemm_state_f.run();
 
@@ -121,7 +121,7 @@
 {
     if(!_is_prepared)
     {
-        _fully_connected_kernel.prepare();
+        _fully_connected.prepare();
         _gemm_state_f.prepare();
 
         _is_prepared = true;
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
index ac809fa..41d7d1f 100644
--- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -180,7 +180,7 @@
     }
 };
 
-NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager)
+NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(std::move(memory_manager))
 {
 }