COMPMID-3098 Fuse Relu and Bounded Relu with FullyConnected NEON

Change-Id: Id28062445590d6c06b35f7d7434eb38393ae94a7
Signed-off-by: SiCongLi <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2875
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
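
Usage sketch (not part of the patch): with this change the requested activation travels in GEMMInfo::activation_info(), and NEGEMMLowpMatrixMultiplyCore only schedules a trailing NEActivationLayer when the assembly path is not taken or NEGEMMAssemblyDispatch::is_activation_supported() reports the activation cannot be fused into the assembly kernel. A caller would request the fusion roughly as below; the activation_info field on FullyConnectedLayerInfo is an assumption about the caller-side API and may differ between releases.

    // Sketch only: assumes NEFullyConnectedLayer forwards fc_info.activation_info
    // into the GEMMInfo consumed by NEGEMMLowpMatrixMultiplyCore::configure().
    #include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"

    using namespace arm_compute;

    void run_fc_with_bounded_relu(const ITensor *src, const ITensor *weights, const ITensor *bias, ITensor *dst)
    {
        FullyConnectedLayerInfo fc_info{};
        // BOUNDED_RELU clamps the output to [0, 6]; a plain ReLU would use
        // ActivationFunction::RELU with no upper bound.
        fc_info.activation_info = ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);

        NEFullyConnectedLayer fc;
        fc.configure(src, weights, bias, dst, fc_info);
        fc.run();
    }

Whether the activation is folded into the assembly kernel or executed as a separate stage is decided once at configure time, so run() pays no extra cost when the fusion is available.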
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index 3417c72..a6ebcac 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -27,9 +27,6 @@
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/KernelDescriptors.h"
-#include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
-#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
-#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
@@ -42,11 +39,12 @@
 {
 using namespace arm_compute::misc::shape_calculator;
 
-NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _asm_glue(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(),
-      _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(), _convert_to_signed_asymm(), _convert_from_signed_asymm(), _vector_sum_col(), _vector_sum_row(), _tmp_a(),
-      _tmp_b(), _mm_result_s32(), _signed_a(), _signed_output(), _original_b(nullptr), _a_offset(0), _b_offset(0), _run_vector_matrix_multiplication(false), _assembly_path(false),
-      _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false), _run_activation(false), _flip_signedness(false)
+NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _asm_glue(memory_manager, weights_manager), _mm_kernel(), _mtx_a_reshape_kernel(), _mtx_b_reshape_kernel(),
+      _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(), _convert_to_signed_asymm(),
+      _convert_from_signed_asymm(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _signed_a(), _signed_output(), _original_b(nullptr), _a_offset(0), _b_offset(0),
+      _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false),
+      _run_activation(false), _flip_signedness(false)
 {
 }
 
@@ -60,10 +58,6 @@
     const ITensor *matrix_b = b;
     GEMMInfo       info     = gemm_info;
 
-    // Clear state
-    _mtx_a_reshape_kernel = nullptr;
-    _mtx_b_reshape_kernel = nullptr;
-
     // Set internal variables
     _a_offset                         = a->info()->quantization_info().uniform().offset;
     _b_offset                         = b->info()->quantization_info().uniform().offset;
@@ -158,18 +152,10 @@
         }
 
         // Configure interleave kernel
-        {
-            auto k = arm_compute::support::cpp14::make_unique<NEGEMMInterleave4x4Kernel>();
-            k->configure(a_to_use, &_tmp_a);
-            _mtx_a_reshape_kernel = std::move(k);
-        }
+        _mtx_a_reshape_kernel.configure(a_to_use, &_tmp_a);
 
         // Configure transpose kernel
-        {
-            auto k = arm_compute::support::cpp14::make_unique<NEGEMMTranspose1xWKernel>();
-            k->configure(b, &_tmp_b);
-            _mtx_b_reshape_kernel = std::move(k);
-        }
+        _mtx_b_reshape_kernel.configure(b, &_tmp_b);
     }
 
     if(!_fused_assembly_path)
@@ -209,9 +195,7 @@
             // Configure matrix multiply kernel
             if(!_assembly_path)
             {
-                auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpMatrixMultiplyKernel>();
-                k->configure(matrix_a, matrix_b, &_mm_result_s32);
-                _mm_kernel = std::move(k);
+                _mm_kernel.configure(matrix_a, matrix_b, &_mm_result_s32);
             }
 
             _offset_contribution_output_stage_kernel.configure(&_mm_result_s32,
@@ -231,21 +215,19 @@
             // Configure matrix multiply kernel
             if(!_assembly_path)
             {
-                auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpMatrixMultiplyKernel>();
-                k->configure(matrix_a, matrix_b, output);
-                _mm_kernel = std::move(k);
+                _mm_kernel.configure(matrix_a, matrix_b, output);
             }
             // Configure offset contribution kernel
             _offset_contribution_kernel.configure(output, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, a_to_use->info()->dimension(0), _a_offset, _b_offset);
         }
-    }
 
-    // Configure activation
-    const ActivationLayerInfo &activation = gemm_info.activation_info();
-    _run_activation                       = activation.enabled() && (!_assembly_path || (_assembly_path && !NEGEMMAssemblyDispatch::is_activation_supported(activation)));
-    if(_run_activation)
-    {
-        _activation_func.configure(output, nullptr, activation);
+        // Configure activation
+        const ActivationLayerInfo &activation = gemm_info.activation_info();
+        _run_activation                       = activation.enabled() && (!_assembly_path || (_assembly_path && !NEGEMMAssemblyDispatch::is_activation_supported(activation)));
+        if(_run_activation)
+        {
+            _activation_func.configure(output, nullptr, activation);
+        }
     }
 
     // Allocate tensors
@@ -495,16 +477,6 @@
         NEScheduler::get().schedule(&_convert_to_signed_asymm, Window::DimY);
     }
 
-    // Reshape inputs
-    if(_mtx_a_reshape_kernel)
-    {
-        NEScheduler::get().schedule(_mtx_a_reshape_kernel.get(), Window::DimY);
-    }
-    if(_mtx_b_reshape_kernel && !_reshape_b_only_on_first_run)
-    {
-        NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
-    }
-
     // Run GEMM
     if(_asm_glue.is_configured())
     {
@@ -512,7 +484,18 @@
     }
     else
     {
-        NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY);
+        if(!_run_vector_matrix_multiplication)
+        {
+            // Run interleave kernel
+            NEScheduler::get().schedule(&_mtx_a_reshape_kernel, Window::DimY);
+
+            if(!_reshape_b_only_on_first_run)
+            {
+                // Run transpose kernel
+                NEScheduler::get().schedule(&_mtx_b_reshape_kernel, Window::DimY);
+            }
+        }
+        NEScheduler::get().schedule(&_mm_kernel, Window::DimY);
     }
 
     if(!_fused_assembly_path)
@@ -547,8 +530,8 @@
         NEScheduler::get().schedule(&_convert_from_signed_asymm, Window::DimY);
     }
 
-    // Run fused activation
-    if(_run_activation)
+    // Run fused activation unless already run in the fused assembly
+    if(_run_activation && !_fused_assembly_path)
     {
         _activation_func.run();
     }
@@ -558,23 +541,36 @@
 {
     if(!_is_prepared)
     {
+        const bool original_b_managed_by_weights_manager = _weights_manager && _weights_manager->are_weights_managed(_original_b);
         // Run assembly reshape
-        if(_asm_glue.is_configured() && _reshape_b_only_on_first_run)
+        if(_asm_glue.is_configured())
         {
-            ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+            if(!original_b_managed_by_weights_manager)
+            {
+                ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+            }
 
             _asm_glue.prepare();
-            _original_b->mark_as_unused();
+            if(!original_b_managed_by_weights_manager)
+            {
+                _original_b->mark_as_unused();
+            }
         }
         // Run non-assembly reshape
-        else if(_mtx_b_reshape_kernel && _reshape_b_only_on_first_run)
+        else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue.is_configured())
         {
-            ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+            if(!original_b_managed_by_weights_manager)
+            {
+                ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+            }
 
             // Run reshape kernel and mark original weights tensor as unused
             _tmp_b.allocator()->allocate();
-            NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
-            _original_b->mark_as_unused();
+            NEScheduler::get().schedule(&_mtx_b_reshape_kernel, Window::DimY);
+            if(!original_b_managed_by_weights_manager)
+            {
+                _original_b->mark_as_unused();
+            }
         }
 
         // Run matrix B reduction kernel only if _a_offset is not equal to 0