COMPMID-1698: Implementing CLGEMMLowpMatrixMultiplyReshapedKernel

Change-Id: Ia4db21b394a0b9235393202ce3c00b11cceb94ea
Reviewed-on: https://review.mlplatform.org/568
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 4b72878..2a01db7 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -31,43 +31,25 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
+#include "arm_compute/runtime/CL/gemm_reshaped/CLGEMMReshapedConfiguration.h"
 
 namespace arm_compute
 {
 using namespace arm_compute::misc::shape_calculator;
+using namespace arm_compute::cl_gemm;
 
 namespace
 {
-inline bool is_interleaved_transposed(int m, int n, int k, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
+inline bool is_gemm_reshaped(unsigned int m, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
 {
-    bool flag = true;
-
-    if(gpu_target_is_in(gpu_target,
-                        GPUTarget::G71, GPUTarget::G72,
-                        GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT))
-    {
-        // COMPMID-852
-        if(k > 256 && m > 4 && reshape_b_only_on_first_run)
-        {
-            flag = ((0.72f + n * 0.10766f) < (n * 0.1284f));
-        }
-        else
-        {
-            flag = false;
-        }
-    }
-    else
-    {
-        flag = m > 1;
-    }
-
-    return flag;
+    return (get_arch_from_target(gpu_target) != GPUTarget::MIDGARD) && (m > 1) && (reshape_b_only_on_first_run);
 }
 } // namespace
 
 CLGEMMLowpMatrixMultiplyCore::CLGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(std::move(memory_manager)),
       _mm_kernel(),
+      _mm_reshaped_kernel(),
       _mtx_a_reshape_kernel(),
       _mtx_b_reshape_kernel(),
       _mtx_a_reduction_kernel(),
@@ -82,7 +64,7 @@
       _original_b(nullptr),
       _a_offset(0),
       _b_offset(0),
-      _is_interleaved_transposed(true),
+      _is_gemm_reshaped(true),
       _reshape_b_only_on_first_run(false),
       _is_prepared(false),
       _fuse_output_stage(false)
@@ -115,29 +97,17 @@
     // Arguments used by GEMMReshapeInfo
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
-    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    const bool    unroll_block              = dot8_supported(CLKernelLibrary::get().get_device());
-    const int     m                         = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
-    const int     n                         = b->info()->dimension(0);
-    const int     k                         = a->info()->dimension(0);
-    const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
-    constexpr int mult_transpose1xW_width   = 1;
-    constexpr int mult_interleave4x4_height = 1;
-    rhs_info.n0                             = 16 / b->info()->element_size();
-    rhs_info.k0                             = 1;
-    rhs_info.h0                             = mult_transpose1xW_width;
-    rhs_info.interleave                     = false;
-    rhs_info.transpose                      = false;
-    lhs_info.m0                             = 4;
-    lhs_info.k0                             = 4;
-    lhs_info.v0                             = mult_interleave4x4_height;
-    lhs_info.interleave                     = true;
-    lhs_info.transpose                      = !unroll_block;
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
+    const unsigned int n                       = b->info()->dimension(0);
+    const unsigned int k                       = a->info()->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
 
     // Check if we need to reshape the matrix A and matrix B
-    _is_interleaved_transposed = is_interleaved_transposed(m, n, k, _reshape_b_only_on_first_run, gpu_target);
+    _is_gemm_reshaped = is_gemm_reshaped(m, _reshape_b_only_on_first_run, gpu_target);
 
-    if(_is_interleaved_transposed)
+    if(_is_gemm_reshaped)
     {
         // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
         reinterpret_input_as_3d = false;
@@ -151,6 +121,9 @@
             _memory_group.manage(&_tmp_b);
         }
 
+        // Pick up the GEMM configuration
+        std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, DataType::QASYMM8);
+
         // Configure interleave kernel
         _mtx_a_reshape_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
 
@@ -190,10 +163,16 @@
 
         _memory_group.manage(&_mm_result_s32);
 
-        // Configure matrix multiply kernel
-        _mm_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
-                                                                                                              mult_transpose1xW_width, mult_interleave4x4_height,
-                                                                                                              depth_output_gemm3d, reinterpret_input_as_3d));
+        if(_is_gemm_reshaped)
+        {
+            // Configure and tune matrix multiply kernel
+            _mm_reshaped_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+        }
+        else
+        {
+            // Configure matrix multiply kernel
+            _mm_kernel.configure(matrix_a, matrix_b, &_mm_result_s32, false, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+        }
 
         // Configure offset contribution kernel
         _offset_contribution_output_stage_kernel.configure(&_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, output, a->info()->dimension(0),
@@ -203,17 +182,23 @@
     }
     else
     {
-        // Configure matrix multiply kernel
-        _mm_kernel.configure(matrix_a, matrix_b, output, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
-                                                                                                     mult_transpose1xW_width, mult_interleave4x4_height,
-                                                                                                     depth_output_gemm3d, reinterpret_input_as_3d));
+        if(_is_gemm_reshaped)
+        {
+            // Configure and tune matrix multiply kernel
+            _mm_reshaped_kernel.configure(matrix_a, matrix_b, output, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+        }
+        else
+        {
+            // Configure matrix multiply kernel
+            _mm_kernel.configure(matrix_a, matrix_b, output, false, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+        }
 
         // Configure offset contribution kernel
         _offset_contribution_kernel.configure(output, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, a->info()->dimension(0), _a_offset, _b_offset);
     }
 
     // Allocate tensors
-    if(_is_interleaved_transposed)
+    if(_is_gemm_reshaped)
     {
         _tmp_a.allocator()->allocate();
         if(!_reshape_b_only_on_first_run)
@@ -251,26 +236,14 @@
     GEMMRHSMatrixInfo rhs_info;
     GEMMLHSMatrixInfo lhs_info;
 
-    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    const bool    unroll_block              = dot8_supported(CLKernelLibrary::get().get_device());
-    const int     m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
-    const int     n                         = b->dimension(0);
-    const int     k                         = a->dimension(0);
-    constexpr int mult_transpose1xW_width   = 1;
-    constexpr int mult_interleave4x4_height = 1;
-    const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
-    rhs_info.n0                             = 16 / b->element_size();
-    rhs_info.k0                             = 1;
-    rhs_info.h0                             = mult_transpose1xW_width;
-    rhs_info.interleave                     = false;
-    rhs_info.transpose                      = false;
-    lhs_info.m0                             = 4;
-    lhs_info.k0                             = 4;
-    lhs_info.v0                             = mult_interleave4x4_height;
-    lhs_info.interleave                     = true;
-    lhs_info.transpose                      = !unroll_block;
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
 
-    bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
+    bool reshape_matrices = is_gemm_reshaped(m, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
 
     // if reshape_matrices is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
     if(reshape_matrices)
@@ -278,13 +251,16 @@
         reinterpret_input_as_3d = false;
     }
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
 
     if(reshape_matrices)
     {
         matrix_a_info = &tmp_a_info;
         matrix_b_info = &tmp_b_info;
 
+        // Pick up the GEMM configuration
+        std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, DataType::QASYMM8);
+
         // Validate interleave kernel
         auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
@@ -319,12 +295,22 @@
     {
         TensorInfo mm_result_s32_info{};
 
-        // Output tensor auto inizialitation if not yet initialized
-        auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, reshape_matrices, reshape_info)).set_data_type(DataType::S32));
+        if(reshape_matrices)
+        {
+            // Output tensor auto inizialitation if not yet initialized
+            auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, reshape_info)).set_data_type(DataType::S32));
 
-        // Validate matrix multiply
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, reshape_matrices, reshape_info));
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
+        }
+        else
+        {
+            // Output tensor auto initialization if not yet initialized
+            auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, false, reshape_info)).set_data_type(DataType::S32));
 
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, false, reshape_info));
+        }
         // Validate offset contribution kernel
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOffsetContributionOutputStageKernel::validate(&mm_result_s32_info,
                                                                                             a_offset == 0 ? nullptr : &info_vector_sum_col,
@@ -336,9 +322,16 @@
     }
     else
     {
-        // Validate matrix multiply
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, reshape_matrices, reshape_info));
-
+        if(reshape_matrices)
+        {
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyReshapedKernel::validate(matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info));
+        }
+        else
+        {
+            // Validate matrix multiply
+            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, output, false, reshape_info));
+        }
         // Validate offset contribution kernel
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOffsetContributionKernel::validate(output,
                                                                                  a_offset == 0 ? nullptr : &info_vector_sum_col,
@@ -356,7 +349,7 @@
 
     _memory_group.acquire();
 
-    if(_is_interleaved_transposed)
+    if(_is_gemm_reshaped)
     {
         // Run reshape matrix A
         CLScheduler::get().enqueue(_mtx_a_reshape_kernel, false);
@@ -375,7 +368,14 @@
     }
 
     // Run matrix multiply
-    CLScheduler::get().enqueue(_mm_kernel, false);
+    if(_is_gemm_reshaped)
+    {
+        CLScheduler::get().enqueue(_mm_reshaped_kernel, false);
+    }
+    else
+    {
+        CLScheduler::get().enqueue(_mm_kernel, false);
+    }
 
     // Run matrix A reduction kernel only if _b_offset is not equal to 0
     if(_b_offset != 0)
@@ -401,7 +401,7 @@
 {
     if(!_is_prepared)
     {
-        if(_is_interleaved_transposed && _reshape_b_only_on_first_run)
+        if(_is_gemm_reshaped && _reshape_b_only_on_first_run)
         {
             ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
 
diff --git a/src/runtime/CL/gemm_reshaped/CLGEMMReshapedConfigurationBifrost.cpp b/src/runtime/CL/gemm_reshaped/CLGEMMReshapedConfigurationBifrost.cpp
index 079a52e..cd97849 100644
--- a/src/runtime/CL/gemm_reshaped/CLGEMMReshapedConfigurationBifrost.cpp
+++ b/src/runtime/CL/gemm_reshaped/CLGEMMReshapedConfigurationBifrost.cpp
@@ -32,18 +32,62 @@
 {
 namespace cl_gemm
 {
+namespace
+{
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_gemm_reshaped(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int v0, unsigned int h0,
+                                                                        bool lhs_interleave, bool rhs_interleave)
+{
+    GEMMLHSMatrixInfo lhs_info;
+    GEMMRHSMatrixInfo rhs_info;
+
+    // Configure GEMMLHSMatrixInfo
+    lhs_info.m0         = m0;
+    lhs_info.k0         = k0;
+    lhs_info.v0         = ((m / (lhs_info.m0 * v0)) == 0) ? 1 : v0;
+    lhs_info.interleave = lhs_interleave;
+    lhs_info.transpose  = false;
+
+    // Configure GEMMRHSMatrixInfo
+    rhs_info.n0         = n0;
+    rhs_info.k0         = lhs_info.k0;
+    rhs_info.h0         = ((n / (rhs_info.n0 * h0)) == 0) ? 1 : h0;
+    rhs_info.interleave = rhs_interleave;
+    rhs_info.transpose  = true;
+
+    return std::make_pair(lhs_info, rhs_info);
+}
+
+} // namespace
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
 {
-    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32);
+    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
     ARM_COMPUTE_UNUSED(data_type);
 
     const GPUTarget gpu_target = CLScheduler::get().target();
+
+    using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+
+    // Configurations for Mali-G76
+    static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_reshaped_configs_G76 =
+    {
+        { DataType::F32, &CLGEMMReshapedConfigurationBifrost::configure_G76_f32 },
+        { DataType::QASYMM8, &CLGEMMReshapedConfigurationBifrost::configure_G76_u8 }
+    };
+
+    // Configurations for Mali-G7x
+    static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_reshaped_configs_G7x =
+    {
+        { DataType::F32, &CLGEMMReshapedConfigurationBifrost::configure_G7x_f32 },
+        { DataType::QASYMM8, &CLGEMMReshapedConfigurationBifrost::configure_G7x_u8 }
+    };
+
     switch(gpu_target)
     {
         case GPUTarget::G76:
-            return configure_G76_f32(m, n, k, b);
+            return (this->*gemm_reshaped_configs_G76[data_type])(m, n, k, b);
         default:
-            return configure_G7x_f32(m, n, k, b);
+            return (this->*gemm_reshaped_configs_G7x[data_type])(m, n, k, b);
     }
 }
 
@@ -52,43 +96,43 @@
     ARM_COMPUTE_UNUSED(k);
     ARM_COMPUTE_UNUSED(b);
 
-    GEMMLHSMatrixInfo lhs_info;
-    GEMMRHSMatrixInfo rhs_info;
-
     if(n <= 4)
     {
-        // Configure GEMMLHSMatrixInfo
-        lhs_info.m0         = 4;
-        lhs_info.k0         = 8;
-        lhs_info.v0         = lhs_info.m0 * 16 < m ? 2 : 16;
-        lhs_info.interleave = true;
-        lhs_info.transpose  = false;
-
-        // Configure GEMMRHSMatrixInfo
-        rhs_info.n0         = 2;
-        rhs_info.k0         = lhs_info.k0;
-        rhs_info.h0         = rhs_info.n0 * 16 < n ? 2 : 16;
-        rhs_info.interleave = false;
-        rhs_info.transpose  = true;
+        return configure_gemm_reshaped(m, n, 4, 2, 8, 16, 16, true, false);
     }
     else
     {
-        // Configure GEMMLHSMatrixInfo
-        lhs_info.m0         = 5;
-        lhs_info.k0         = 4;
-        lhs_info.v0         = lhs_info.m0 * 2 < m ? 1 : 2;
-        lhs_info.interleave = false;
-        lhs_info.transpose  = false;
-
-        // Configure GEMMRHSMatrixInfo
-        rhs_info.n0         = 4;
-        rhs_info.k0         = lhs_info.k0;
-        rhs_info.h0         = rhs_info.n0 * 16 < n ? 2 : 16;
-        rhs_info.interleave = true;
-        rhs_info.transpose  = true;
+        return configure_gemm_reshaped(m, n, 5, 4, 4, 2, 16, false, true);
     }
+}
 
-    return std::make_pair(lhs_info, rhs_info);
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k);
+    ARM_COMPUTE_UNUSED(b);
+
+    if(dot8_supported(CLKernelLibrary::get().get_device()))
+    {
+        if(n <= 4)
+        {
+            return configure_gemm_reshaped(m, n, 4, 2, 16, 2, 2, true, false);
+        }
+        else
+        {
+            return configure_gemm_reshaped(m, n, 4, 4, 16, 2, 2, true, false);
+        }
+    }
+    else
+    {
+        if(n <= 4)
+        {
+            return configure_gemm_reshaped(m, n, 4, 2, 8, 2, 2, true, false);
+        }
+        else
+        {
+            return configure_gemm_reshaped(m, n, 6, 4, 4, 2, 2, true, true);
+        }
+    }
 }
 
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedConfigurationBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
@@ -96,43 +140,29 @@
     ARM_COMPUTE_UNUSED(k);
     ARM_COMPUTE_UNUSED(b);
 
-    GEMMLHSMatrixInfo lhs_info;
-    GEMMRHSMatrixInfo rhs_info;
-
     if(n <= 4)
     {
-        // Configure GEMMLHSMatrixInfo
-        lhs_info.m0         = 4;
-        lhs_info.k0         = 8;
-        lhs_info.v0         = lhs_info.m0 * 16 < m ? 2 : 16;
-        lhs_info.interleave = true;
-        lhs_info.transpose  = false;
-
-        // Configure GEMMRHSMatrixInfo
-        rhs_info.n0         = 2;
-        rhs_info.k0         = lhs_info.k0;
-        rhs_info.h0         = rhs_info.n0 * 16 < n ? 2 : 16;
-        rhs_info.interleave = false;
-        rhs_info.transpose  = true;
+        return configure_gemm_reshaped(m, n, 4, 2, 8, 16, 16, true, false);
     }
     else
     {
-        // Configure GEMMLHSMatrixInfo
-        lhs_info.m0         = 4;
-        lhs_info.k0         = 2;
-        lhs_info.v0         = lhs_info.m0 * 8 < m ? 2 : 8;
-        lhs_info.interleave = false;
-        lhs_info.transpose  = false;
-
-        // Configure GEMMRHSMatrixInfo
-        rhs_info.n0         = 4;
-        rhs_info.k0         = lhs_info.k0;
-        rhs_info.h0         = rhs_info.n0 * 16 < n ? 2 : 16;
-        rhs_info.interleave = false;
-        rhs_info.transpose  = true;
+        return configure_gemm_reshaped(m, n, 4, 4, 2, 8, 16, false, false);
     }
+}
 
-    return std::make_pair(lhs_info, rhs_info);
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k);
+    ARM_COMPUTE_UNUSED(b);
+
+    if(n <= 4)
+    {
+        return configure_gemm_reshaped(m, n, 4, 2, 16, 4, 1, false, false);
+    }
+    else
+    {
+        return configure_gemm_reshaped(m, n, 4, 4, 16, 2, 2, false, true);
+    }
 }
 } // namespace cl_gemm
 } // namespace arm_compute
\ No newline at end of file