COMPMID-2378: Sanitize GEMM configuration for NEON
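
Thread a single GEMMInfo object through the NEON GEMM call chain instead of
the loose boolean flags (pretranspose_hint, reshape_b_only_on_first_run) that
were previously passed positionally, so that NEGEMM, NEGEMMAssemblyDispatch,
the GEMMLowp cores and NEGEMMInterleavedWrapper all consume one consistent
description of the GEMM being configured.

As part of this, derive the batch and multi stride indices of A and D in the
assembly fallback from the 3D-reinterpretation flags carried by GEMMInfo
(reinterpret_input_as_3d, depth_output_gemm3d) instead of inferring them from
an NHWC data-layout check.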

Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
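
Reviewer note: a minimal caller-side sketch of the new interface shape. The
tensors, alpha/beta and the _asm_glue member are elided patch context, not new
API; the GEMMInfo constructor arguments mirror the call added in
NEGEMMLowpAssemblyMatrixMultiplyCore.cpp below.

    // Before: positional booleans, easy to misroute between callers.
    //   _asm_glue.configure(a, b, d, alpha, beta, /*pretranspose_hint=*/true);
    //
    // After: one GEMMInfo carries the whole configuration, and validate()
    // and configure() can no longer disagree about the B-reshape policy.
    GEMMInfo info(false /* is_a_reshaped */,
                  false /* is_b_reshaped */,
                  true  /* reshape_b_only_on_first_run */);
    Status s = NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(),
                                                alpha, beta, info);
    if(bool(s))
    {
        _asm_glue.configure(a, b, d, alpha, beta, info);
    }
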
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 55bcc45..2f36397 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -58,17 +58,19 @@
     _run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
     _original_b                       = b;
 
-    bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run));
+    bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info));
 
     if(run_optimised)
     {
         if(MEMInfo::get_policy() == MemoryPolicy::MINIMIZE)
         {
-            _asm_glue.configure(a, b, d, alpha, beta, false);
+            GEMMInfo gemm_info_ntb = gemm_info;
+            gemm_info_ntb.set_pretranpose_B(false);
+            _asm_glue.configure(a, b, d, alpha, beta, gemm_info_ntb);
         }
         else
         {
-            _asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run);
+            _asm_glue.configure(a, b, d, alpha, beta, gemm_info);
         }
         ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured());
     }
@@ -176,7 +178,7 @@
     }
 
     // Check if we need to run the optimized assembly kernel
-    const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true));
+    const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, gemm_info));
 
     if(!run_optimised)
     {
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 55e067f..2de7d2b 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -36,21 +36,22 @@
 namespace
 {
 std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
-                                                     const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                     const ITensor *a, const ITensor *b, ITensor *d,
+                                                     float alpha, float beta, const GEMMInfo &gemm_info,
                                                      std::shared_ptr<IMemoryManager> memory_manager)
 
 {
-    //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
+    // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
     switch(gemm_kernel_info.method)
     {
         case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
         {
-            if(!pretranspose_hint)
+            if(!gemm_info.pretranpose_B())
             {
                 return nullptr;
             }
             auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
-            function->configure(a, b, d, alpha, beta, pretranspose_hint);
+            function->configure(a, b, d, alpha, beta, gemm_info);
             return std::move(function);
         }
 #if defined(__aarch64__)
@@ -59,7 +60,7 @@
             if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos)
             {
                 auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
-                kernel->configure(a, b, d, alpha, beta);
+                kernel->configure(a, b, d, alpha, beta, gemm_info);
                 auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
                 function->configure(std::move(kernel));
                 return std::move(function);
@@ -83,9 +84,11 @@
      * @param[in]  b            Input tensor containing the Matrix B.
      * @param[out] d            Output tensor to store the result of matrix multiplication.
      * @param[in]  args         Matrix multiplication information.
+     * @param[in]  gemm_info    GEMM meta-data.
      * @param[in]  memory_group Memory group to be used by the function.
      */
-    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);
+    void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args,
+                   const GEMMInfo &gemm_info, MemoryGroup &memory_group);
 
     // Inherited methods overridden:
     void run() override;
@@ -123,10 +126,13 @@
     Tensor _pretranspose{};
     /** Prepared flag */
     bool _is_prepared{ false };
+    /** GEMM meta-data */
+    GEMMInfo _gemm_info{};
 };
 
 template <typename TypeInput, typename TypeOutput>
-void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args,
+                                                const GEMMInfo &gemm_info, MemoryGroup &memory_group)
 {
     arm_gemm::GemmConfig              gemm_cfg;
     const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
@@ -168,6 +174,7 @@
     _a                = a;
     _b                = b;
     _d                = d;
+    _gemm_info        = gemm_info;
     // Check for pre-transposed support
     if(_gemm_kernel_asm->B_pretranspose_required())
     {
@@ -222,17 +229,17 @@
     int       ldb = 0;
     const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
 
-    // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
-    // the relevant multiple of the row stride.
-    const bool is_nhwc           = _a->info()->data_layout() == DataLayout::NHWC;
-    const int  stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();
+    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d() != 0 ? 3 : 2;
+    const size_t a_multi_idx = a_batch_idx + 1;
+    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d() != 0 ? 3 : 2;
+    const size_t d_multi_idx = d_batch_idx + 1;
 
-    const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
-    const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
+    const int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
+    const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
 
-    const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+    const int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
     int       multi_stride_b = 0;
-    const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
+    const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
 
     const auto       in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
     const TypeInput *in1_ptr = nullptr;
@@ -270,24 +277,27 @@
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
-                                 ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
+void create_function_or_arm_gemm(std::unique_ptr<IFunction>                         &acl_function,
+                                 std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm,
+                                 MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
+                                 ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
+                                 std::shared_ptr<IMemoryManager> memory_manager)
 {
-    INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+    INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
     const CPUInfo               &ci          = NEScheduler::get().cpu_info();
     unsigned int                 num_threads = NEScheduler::get().num_threads();
 
-    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B());
 
     //Try to create an ACL function:
-    acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager));
+    acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
 
     //If we still don't have an ACL function:
     if(acl_function == nullptr)
     {
         //Fallback onto arm_gemm function if ACL doesn't support this method.
         auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
-        fallback->configure(a, b, d, args, memory_group);
+        fallback->configure(a, b, d, args, gemm_info, memory_group);
         arm_gemm = std::move(fallback);
     }
 }
@@ -299,11 +309,11 @@
 {
 }
 
-Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint)
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    ARM_COMPUTE_UNUSED(gemm_info);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
 #ifndef __aarch64__
@@ -319,14 +329,14 @@
     return Status{};
 }
 
-void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a);
     ARM_COMPUTE_ERROR_ON_NULLPTR(b);
     ARM_COMPUTE_ERROR_ON_NULLPTR(d);
 
     //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
-    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
+    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info))
     {
         return;
     }
@@ -334,20 +344,20 @@
     switch(a->info()->data_type())
     {
         case DataType::F32:
-            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
-            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
             break;
         case DataType::S8:
-            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
             break;
 #endif /* __aarch64__ */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
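
Reviewer note on the stride selection introduced in Fallback::run() above: a
self-contained C++ sketch of the indexing logic, with invented stride values;
only the index arithmetic mirrors the patch.

    #include <array>
    #include <cstddef>
    #include <cstdio>

    int main()
    {
        // Strides in bytes of A, one per dimension, as
        // TensorInfo::strides_in_bytes() would report them (values invented).
        const std::array<std::size_t, 5> strides_a{ 4, 256, 16384, 262144, 1048576 };

        // When GEMMInfo::reinterpret_input_as_3d() is set, M spans two
        // dimensions of A, pushing the batch dimension from index 2 to 3;
        // the multi dimension always follows the batch dimension.
        const bool        reinterpret_input_as_3d = true;
        const std::size_t a_batch_idx             = reinterpret_input_as_3d ? 3 : 2;
        const std::size_t a_multi_idx             = a_batch_idx + 1;

        // Convert byte strides to element strides, as the patch does.
        const std::size_t batch_stride_a = strides_a[a_batch_idx] / sizeof(float);
        const std::size_t multi_stride_a = strides_a[a_multi_idx] / sizeof(float);
        std::printf("batch stride %zu, multi stride %zu (elements)\n",
                    batch_stride_a, multi_stride_a);
        return 0;
    }
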
diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
index ede89bf..5b70c87 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
@@ -59,7 +59,7 @@
         case DataType::QASYMM8:
         case DataType::U8:
         {
-            _asm_glue.configure(a, b, output, 1.f, 0.f, true);
+            _asm_glue.configure(a, b, output, 1.f, 0.f, GEMMInfo(false, false, true));
             run_optimised = _asm_glue.is_configured();
             break;
         }
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index d8773e3..f10f114 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -87,7 +87,7 @@
         case DataType::U8:
         case DataType::S8:
         {
-            _asm_glue.configure(a, b, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, _reshape_b_only_on_first_run);
+            _asm_glue.configure(a, b, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, gemm_info);
             _dot_product_path = _asm_glue.is_configured();
             break;
         }
@@ -224,9 +224,8 @@
     TensorInfo tmp_b_info{};
     TensorInfo mm_result_s32_info{};
 
-    int32_t    a_offset                    = a->quantization_info().uniform().offset;
-    int32_t    b_offset                    = b->quantization_info().uniform().offset;
-    const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+    int32_t a_offset = a->quantization_info().uniform().offset;
+    int32_t b_offset = b->quantization_info().uniform().offset;
 
     bool fuse_output_stage = gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
     if(fuse_output_stage)
@@ -235,7 +234,7 @@
     }
 
     // Check if we need to run the optimized assembly kernel
-    const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, reshape_b_only_on_first_run));
+    const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, gemm_info));
 
     if(run_optimised)
     {
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
index 20aa149..ac809fa 100644
--- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -339,19 +339,19 @@
     }
 }
 
-void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b)
+void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info)
 {
-    _params         = INEGEMMWrapperKernel::extract_parameters(a, b, c);
+    _params         = INEGEMMWrapperKernel::extract_parameters(a, b, c, gemm_info);
     _a              = a;
     _b              = b;
     _c              = c;
-    _pretranspose_b = pretranspose_b;
+    _pretranspose_b = gemm_info.pretranpose_B();
 
     const DataType     input_type  = a->info()->data_type();
     const CPUInfo     &ci          = NEScheduler::get().cpu_info();
     const unsigned int num_threads = NEScheduler::get().num_threads();
 
-    const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, pretranspose_b);
+    const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, _pretranspose_b);
     ARM_COMPUTE_ERROR_ON(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMM_INTERLEAVED);
 
     // Forcing 128-byte alignment (required by 32-bit kernels)
@@ -411,8 +411,8 @@
     _memory_group.manage(&_transformed_a);
     _memory_group.manage(&_tmp_c);
 
-    _transform_a     = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params);
-    _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, pretranspose_b, num_threads);
+    _transform_a     = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params, gemm_info);
+    _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, gemm_info, num_threads);
     ARM_COMPUTE_ERROR_ON(_transform_a == nullptr);
     ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr);