COMPMID-1698: Implementing CLGEMMLowpMatrixMultiplyReshapedKernel

Change-Id: Ia4db21b394a0b9235393202ce3c00b11cceb94ea
Reviewed-on: https://review.mlplatform.org/568
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 4b72878..2a01db7 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -31,43 +31,25 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/gemm_reshaped/CLGEMMReshapedConfiguration.h"
 
 namespace arm_compute
 {
 using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::cl_gemm;
 
 namespace
 {
-inline bool is_interleaved_transposed(int m, int n, int k, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
+inline bool is_gemm_reshaped(unsigned int m, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
 {
-    bool flag = true;
-
-    if(gpu_target_is_in(gpu_target,
-                        GPUTarget::G71, GPUTarget::G72,
-                        GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT))
-    {
-        // COMPMID-852
-        if(k > 256 && m > 4 && reshape_b_only_on_first_run)
-        {
-            flag = ((0.72f + n * 0.10766f) < (n * 0.1284f));
-        }
-        else
-        {
-            flag = false;
-        }
-    }
-    else
-    {
-        flag = m > 1;
-    }
-
-    return flag;
+    return (get_arch_from_target(gpu_target) != GPUTarget::MIDGARD) && (m > 1) && (reshape_b_only_on_first_run);
 }
 } // namespace
 
 CLGEMMLowpMatrixMultiplyCore::CLGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(std::move(memory_manager)),
       _mm_kernel(),
+      _mm_reshaped_kernel(),
       _mtx_a_reshape_kernel(),
       _mtx_b_reshape_kernel(),
       _mtx_a_reduction_kernel(),
@@ -82,7 +64,7 @@
       _original_b(nullptr),
       _a_offset(0),
       _b_offset(0),
-      _is_interleaved_transposed(true),
+      _is_gemm_reshaped(true),
       _reshape_b_only_on_first_run(false),
       _is_prepared(false),
       _fuse_output_stage(false)
@@ -115,29 +97,17 @@
     // Arguments used by GEMMReshapeInfo
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
-    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    const bool    unroll_block              = dot8_supported(CLKernelLibrary::get().get_device());
-    const int     m                         = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
-    const int     n                         = b->info()->dimension(0);
-    const int     k                         = a->info()->dimension(0);
-    const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
-    constexpr int mult_transpose1xW_width   = 1;
-    constexpr int mult_interleave4x4_height = 1;
-    rhs_info.n0                             = 16 / b->info()->element_size();
-    rhs_info.k0                             = 1;
-    rhs_info.h0                             = mult_transpose1xW_width;
-    rhs_info.interleave                     = false;
-    rhs_info.transpose                      = false;
-    lhs_info.m0                             = 4;
-    lhs_info.k0                             = 4;
-    lhs_info.v0                             = mult_interleave4x4_height;
-    lhs_info.interleave                     = true;
-    lhs_info.transpose                      = !unroll_block;
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
+    const unsigned int n                       = b->info()->dimension(0);
+    const unsigned int k                       = a->info()->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
 
     // Check if we need to reshape the matrix A and matrix B
-    _is_interleaved_transposed = is_interleaved_transposed(m, n, k, _reshape_b_only_on_first_run, gpu_target);
+    _is_gemm_reshaped = is_gemm_reshaped(m, _reshape_b_only_on_first_run, gpu_target);
 
-    if(_is_interleaved_transposed)
+    if(_is_gemm_reshaped)
     {
         // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
         reinterpret_input_as_3d = false;
@@ -151,6 +121,9 @@
             _memory_group.manage(&_tmp_b);
         }
 
+        // Pick up the GEMM configuration
+        std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, DataType::QASYMM8);
+
         // Configure interleave kernel
         _mtx_a_reshape_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
 
@@ -190,10 +163,16 @@
 
         _memory_group.manage(&_mm_result_s32);
 
-        // Configure matrix multiply kernel
-        _mm_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
-                                                                                                              mult_transpose1xW_width, mult_interleave4x4_height,
-                                                                                                              depth_output_gemm3d, reinterpret_input_as_3d));
+        if(_is_gemm_reshaped)
+        {
+            // Configure and tune matrix multiply kernel
+            _mm_reshaped_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+        }
+        else
+        {
+            // Configure matrix multiply kernel
+            _mm_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, false, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+        }
 
         // Configure offset contribution kernel
         _offset_contribution_output_stage_kernel.configure(&_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, output, a->info()->dimension(0),
@@ -203,17 +182,23 @@
     }
     else
     {
-        // Configure matrix multiply kernel
-        _mm_kernel.configure(matrix_a, matrix_b, output, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
-                                                                                                     mult_transpose1xW_width, mult_interleave4x4_height,
-                                                                                                     depth_output_gemm3d, reinterpret_input_as_3d));
+        if(_is_gemm_reshaped)
+        {
+            // Configure and tune matrix multiply kernel
+            _mm_reshaped_kernel.configure(matrix_a, matrix_b, output, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+        }
+        else
+        {
+            // Configure matrix multiply kernel
+            _mm_kernel.configure(matrix_a, matrix_b, output, false, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+        }
 
         // Configure offset contribution kernel
         _offset_contribution_kernel.configure(output, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, a->info()->dimension(0), _a_offset, _b_offset);
     }
 
     // Allocate tensors
-    if(_is_interleaved_transposed)
+    if(_is_gemm_reshaped)
     {
         _tmp_a.allocator()->allocate();
         if(!_reshape_b_only_on_first_run)
@@ -251,26 +236,14 @@
     GEMMRHSMatrixInfo rhs_info;
     GEMMLHSMatrixInfo lhs_info;
 
-    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    const bool    unroll_block              = dot8_supported(CLKernelLibrary::get().get_device());
-    const int     m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
-    const int     n                         = b->dimension(0);
-    const int     k                         = a->dimension(0);
-    constexpr int mult_transpose1xW_width   = 1;
-    constexpr int mult_interleave4x4_height = 1;
-    const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
-    rhs_info.n0                             = 16 / b->element_size();
-    rhs_info.k0                             = 1;
-    rhs_info.h0                             = mult_transpose1xW_width;
-    rhs_info.interleave                     = false;
-    rhs_info.transpose                      = false;
-    lhs_info.m0                             = 4;
-    lhs_info.k0                             = 4;
-    lhs_info.v0                             = mult_interleave4x4_height;
-    lhs_info.interleave                     = true;
-    lhs_info.transpose                      = !unroll_block;
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
 
-    bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
+    bool reshape_matrices = is_gemm_reshaped(m, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
 
     // if reshape_matrices is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
     if(reshape_matrices)
@@ -278,13 +251,16 @@
         reinterpret_input_as_3d = false;
     }
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
 
     if(reshape_matrices)
     {
         matrix_a_info = &tmp_a_info;
         matrix_b_info = &tmp_b_info;
 
+        // Pick up the GEMM configuration
+        std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, DataType::QASYMM8);
+
         // Validate interleave kernel
         auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
@@ -319,12 +295,22 @@
     {
         TensorInfo mm_result_s32_info{};
 
-        // Output tensor auto inizialitation if not yet initialized
-        auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, reshape_matrices, reshape_info)).set_data_type(DataType::S32));
+        if(reshape_matrices)
+        {
+            // Output tensor auto initialization if not yet initialized
+            auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, reshape_info)).set_data_type(DataType::S32));
 
-        // Validate matrix multiply
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, reshape_matrices, reshape_info));
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
+        }
+        else
+        {
+            // Output tensor auto initialization if not yet initialized
+            auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, false, reshape_info)).set_data_type(DataType::S32));
 
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, false, reshape_info));
+        }
         // Validate offset contribution kernel
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOffsetContributionOutputStageKernel::validate(&mm_result_s32_info,
                                                                                             a_offset == 0 ? nullptr : &info_vector_sum_col,
@@ -336,9 +322,16 @@
     }
     else
     {
-        // Validate matrix multiply
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, reshape_matrices, reshape_info));
-
+        if(reshape_matrices)
+        {
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedKernel::validate(matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info));
+        }
+        else
+        {
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, false, reshape_info));
+        }
         // Validate offset contribution kernel
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOffsetContributionKernel::validate(output,
                                                                                  a_offset == 0 ? nullptr : &info_vector_sum_col,
@@ -356,7 +349,7 @@
 
     _memory_group.acquire();
 
-    if(_is_interleaved_transposed)
+    if(_is_gemm_reshaped)
     {
         // Run reshape matrix A
         CLScheduler::get().enqueue(_mtx_a_reshape_kernel, false);
@@ -375,7 +368,14 @@
     }
 
     // Run matrix multiply
-    CLScheduler::get().enqueue(_mm_kernel, false);
+    if(_is_gemm_reshaped)
+    {
+        CLScheduler::get().enqueue(_mm_reshaped_kernel, false);
+    }
+    else
+    {
+        CLScheduler::get().enqueue(_mm_kernel, false);
+    }
 
     // Run matrix A reduction kernel only if _b_offset is not equal to 0
     if(_b_offset != 0)
@@ -401,7 +401,7 @@
 {
     if(!_is_prepared)
     {
-        if(_is_interleaved_transposed && _reshape_b_only_on_first_run)
+        if(_is_gemm_reshaped && _reshape_b_only_on_first_run)
         {
             ARM_COMPUTE_ERROR_ON(!_original_b->is_used());