COMPMID-1687: Optimize CLGEMMMatrixMultiplyKernel for Mali-G76 - Part1

The current implementation is limited to FP32
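
The new path is built on CLGEMMReshapeLHSMatrixKernel,
CLGEMMReshapeRHSMatrixKernel and CLGEMMMatrixMultiplyReshapedKernel. It is
taken only on Mali-G76, when the reshape heuristic selects the
interleaved/transposed case and the data type is F32; every other case
falls back to the existing interleave/transpose kernels.

The reshape configuration is picked from M and N by
select_gemm_configuration(). As an illustrative worked example of the
heuristic (values follow directly from the code below), a GEMM with
M = 256 and N = 2 takes the N <= 4 branch and selects:

    lhs_info: m0 = 4, k0 = 8, v0 = 2  (since 4 * 16 < 256), interleave = true,  transpose = false
    rhs_info: n0 = 2, k0 = 8, h0 = 16 (since 2 * 16 >= 2),  interleave = false, transpose = true

whereas M = 1024 and N = 128 takes the N > 4 branch and selects m0 = 6
(since (1024 * 128) / 24 > 2048), k0 = 4, v0 = 32 for the LHS and
n0 = 4, k0 = 4, h0 = 32 for the RHS.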

Change-Id: I185ab57e483e879d7c301e9cc3033efc8b41e244
Reviewed-on: https://review.mlplatform.org/389
Reviewed-by: Anthony Barbier <Anthony.barbier@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index baa0cf4..d0db876 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -40,25 +40,32 @@
 
 namespace
 {
-inline bool is_interleaved_transposed(int m, int n, int k, DataType data_type, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
+inline bool is_interleaved_transposed(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
 {
     bool flag = true;
 
     if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
     {
-        // COMPMID-852
-        if(k > 256 && m > 4 && is_data_type_float(data_type) && reshape_b_only_on_first_run)
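+        // Prefer the reshaped GEMM path for outputs with more than one row and fewer than 16 columns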
+        if((m > 1) && (n < 16))
         {
-            constexpr float alpha = 3.2f;
-            constexpr float fact0 = 1.51f;
-            constexpr float fact1 = 1.66f;
-            constexpr float ops   = 12.0f;
-            const float     scale = k > 1024 ? 1.07f : 1.0f;
-            flag                  = alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops);
+            flag = true;
         }
         else
         {
-            flag = false;
+            // COMPMID-852
+            if(k > 256 && m > 4 && is_data_type_float(data_type) && reshape_b_only_on_first_run)
+            {
+                constexpr float alpha = 3.2f;
+                constexpr float fact0 = 1.51f;
+                constexpr float fact1 = 1.66f;
+                constexpr float ops   = 12.0f;
+                const float     scale = k > 1024 ? 1.07f : 1.0f;
+                flag                  = alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops);
+            }
+            else
+            {
+                flag = false;
+            }
         }
     }
     else
@@ -69,6 +76,43 @@
 
     return flag;
 }
+
+inline void select_gemm_configuration(unsigned int m, unsigned int n, GEMMLHSMatrixInfo &lhs_info, GEMMRHSMatrixInfo &rhs_info)
+{
+    // Heuristic selection of the reshaped GEMM configuration, based on M and N
+    if(n <= 4)
+    {
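+        // Narrow output (N <= 4): small RHS block (n0 = 2), interleaved LHS and transposed RHS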
+        // Configure GEMMLHSMatrixInfo
+        lhs_info.m0         = 4;
+        lhs_info.k0         = 8;
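+        // Use a smaller v0 for taller matrices (M > m0 * 16)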
+        lhs_info.v0         = lhs_info.m0 * 16 < m ? 2 : 16;
+        lhs_info.interleave = true;
+        lhs_info.transpose  = false;
+
+        // Configure GEMMRHSMatrixInfo
+        rhs_info.n0         = 2;
+        rhs_info.k0         = lhs_info.k0;
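+        // Use a smaller h0 for wider matrices (N > n0 * 16)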
+        rhs_info.h0         = rhs_info.n0 * 16 < n ? 2 : 16;
+        rhs_info.interleave = false;
+        rhs_info.transpose  = true;
+    }
+    else
+    {
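+        // Wider output (N > 4): larger blocks (n0 = 4), interleaved and transposed RHS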
+        // Configure GEMMLHSMatrixInfo
+        lhs_info.m0         = (m * n) / 24 > 2048 ? 6 : 5;
+        lhs_info.k0         = 4;
+        lhs_info.v0         = 32;
+        lhs_info.interleave = false;
+        lhs_info.transpose  = false;
+
+        // Configure GEMMRHSMatrixInfo
+        rhs_info.n0         = 4;
+        rhs_info.k0         = lhs_info.k0;
+        rhs_info.h0         = 32;
+        rhs_info.interleave = true;
+        rhs_info.transpose  = true;
+    }
+}
 } // namespace
 
 CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager)
@@ -77,13 +121,17 @@
       _transpose_kernel(),
       _mm_kernel(),
       _ma_kernel(),
+      _reshape_lhs_kernel(),
+      _reshape_rhs_kernel(),
+      _mm_reshaped_kernel(),
       _tmp_a(),
       _tmp_b(),
       _original_b(nullptr),
       _is_interleaved_transposed(false),
       _run_addition(false),
       _reshape_b_only_on_first_run(false),
-      _is_prepared(false)
+      _is_prepared(false),
+      _is_G76_path(false)
 {
 }
 
@@ -112,13 +160,14 @@
     // Arguments used by GEMMReshapeInfo
-    // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
+    // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
     // in order to know how the matrices have been reshaped
-    bool      reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    const int m                         = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
-    const int n                         = b->info()->dimension(0);
-    const int k                         = a->info()->dimension(0);
-    const int depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
-    int       mult_transpose1xW_width   = 1;
-    int       mult_interleave4x4_height = 1;
+    DataType           data_type                 = a->info()->data_type();
+    bool               reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                         = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
+    const unsigned int n                         = b->info()->dimension(0);
+    const unsigned int k                         = a->info()->dimension(0);
+    const int          depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
+    int                mult_transpose1xW_width   = 1;
+    int                mult_interleave4x4_height = 1;
 
     if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
     {
@@ -129,6 +178,10 @@
     // Check if we need to reshape the matrix A and matrix B
     _is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
 
+    // Check if we can run the new reshaped GEMM
+    _is_G76_path = (gpu_target == GPUTarget::G76) && _is_interleaved_transposed && (data_type == DataType::F32);
+
     // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
     if(_is_interleaved_transposed)
     {
@@ -145,19 +198,40 @@
         }
         // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
 
-        // Configure interleave kernel
-        _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
+        if(_is_G76_path)
+        {
+            GEMMLHSMatrixInfo lhs_info;
+            GEMMRHSMatrixInfo rhs_info;
 
-        // Configure transpose kernel
-        _transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
+            // Pick up the GEMM configuration based on M and N
+            select_gemm_configuration(m, n, lhs_info, rhs_info);
+
+            _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
+            _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+
+            // Configure and tune matrix multiply kernel
+            _mm_reshaped_kernel.configure(matrix_a, matrix_b, output, alpha, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1,
+                                                                                                                 depth_output_gemm3d, reinterpret_input_as_3d));
+        }
+        else
+        {
+            // Configure interleave kernel
+            _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
+
+            // Configure transpose kernel
+            _transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
+        }
     }
 
-    // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
-                                                                                                        mult_transpose1xW_width, mult_interleave4x4_height,
-                                                                                                        depth_output_gemm3d, reinterpret_input_as_3d),
-                         gemm_info.fp_mixed_precision());
-    CLScheduler::get().tune_kernel_static(_mm_kernel);
+    if(!_is_G76_path)
+    {
+        // Configure and tune matrix multiply kernel
+        _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
+                                                                                                            mult_transpose1xW_width, mult_interleave4x4_height,
+                                                                                                            depth_output_gemm3d, reinterpret_input_as_3d),
+                             gemm_info.fp_mixed_precision());
+        CLScheduler::get().tune_kernel_static(_mm_kernel);
+    }
 
     if(_is_interleaved_transposed)
     {
@@ -197,13 +271,14 @@
     // Arguments used by GEMMReshapeInfo
-    // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
+    // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
     // in order to know how the matrices have been reshaped
-    bool      reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    const int m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
-    const int n                         = b->dimension(0);
-    const int k                         = a->dimension(0);
-    int       mult_transpose1xW_width   = 1;
-    int       mult_interleave4x4_height = 1;
-    const int depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
+    DataType           data_type                 = a->data_type();
+    bool               reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int n                         = b->dimension(0);
+    const unsigned int k                         = a->dimension(0);
+    int                mult_transpose1xW_width   = 1;
+    int                mult_interleave4x4_height = 1;
+    const int          depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
 
     if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
     {
@@ -214,6 +289,9 @@
     // Check if we need to reshape the matrix A and matrix B
     const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
 
+    // Check if we can run the new reshaped GEMM
+    const bool is_G76_path = (gpu_target == GPUTarget::G76) && run_interleave_transpose && (data_type == DataType::F32);
+
     // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
     if(run_interleave_transpose)
     {
@@ -227,17 +305,41 @@
         matrix_a_info = &tmp_a_info;
         matrix_b_info = &tmp_b_info;
 
-        // Validate interleave kernel
-        auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
+        if(is_G76_path)
+        {
+            GEMMLHSMatrixInfo lhs_info;
+            GEMMRHSMatrixInfo rhs_info;
 
-        // Validate transpose kernel
-        auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMTranspose1xWKernel::validate(b, &tmp_b_info, mult_transpose1xW_width));
+            // Pick up the GEMM configuration based on M and N
+            select_gemm_configuration(m, n, lhs_info, rhs_info);
+
+            auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
+
+            auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
+
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(matrix_a_info, matrix_b_info, output, alpha, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1,
+                                                                                     depth_output_gemm3d, reinterpret_input_as_3d)));
+        }
+        else
+        {
+            // Validate interleave kernel
+            auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
+
+            // Validate transpose kernel
+            auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMTranspose1xWKernel::validate(b, &tmp_b_info, mult_transpose1xW_width));
+        }
     }
 
-    // Validate matrix multiply
-    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, alpha, run_interleave_transpose, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
+    if(!is_G76_path)
+    {
+        // Validate matrix multiply
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, alpha, run_interleave_transpose, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
+    }
 
     if(beta != 0 && c != nullptr)
     {
@@ -257,17 +359,38 @@
     if(_is_interleaved_transposed)
     {
-        // Run interleave kernel
+        // Run the LHS reshape kernel (G76 path) or the interleave kernel
-        CLScheduler::get().enqueue(_interleave_kernel, false);
+        if(_is_G76_path)
+        {
+            CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
+        }
+        else
+        {
+            CLScheduler::get().enqueue(_interleave_kernel, false);
+        }
 
         if(!_reshape_b_only_on_first_run)
         {
-            // Run transpose kernel
+            // Run the RHS reshape kernel (G76 path) or the transpose kernel
-            CLScheduler::get().enqueue(_transpose_kernel, false);
+            if(_is_G76_path)
+            {
+                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+            }
+            else
+            {
+                CLScheduler::get().enqueue(_transpose_kernel, false);
+            }
         }
     }
 
-    // Run matrix multiply kernel
+    // Run the reshaped (G76 path) or the standard matrix multiply kernel
-    CLScheduler::get().enqueue(_mm_kernel, !_run_addition);
+    if(_is_G76_path)
+    {
+        CLScheduler::get().enqueue(_mm_reshaped_kernel, !_run_addition);
+    }
+    else
+    {
+        CLScheduler::get().enqueue(_mm_kernel, !_run_addition);
+    }
 
     // Run matrix addition kernel
     if(_run_addition)
@@ -286,7 +409,14 @@
         {
-            // Run transpose kernel and mark original weights tensor as unused
+            // Run the RHS reshape or transpose kernel and mark the original weights tensor as unused
             _tmp_b.allocator()->allocate();
-            CLScheduler::get().enqueue(_transpose_kernel, false);
+            if(_is_G76_path)
+            {
+                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+            }
+            else
+            {
+                CLScheduler::get().enqueue(_transpose_kernel, false);
+            }
             _original_b->mark_as_unused();
         }
         CLScheduler::get().queue().finish();