[ONCPUML-951] Variable weight support for Convolution.

API changes to NEGEMMConvolutionLayer and CpuGemmConv2d so that convolutions can be run with weights supplied in the fixed format expected by the selected assembly kernel (variable weights).
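
The core of the change visible in the CpuGemmAssemblyDispatch diff below is
that has_opt_impl now reports, through an output parameter, the weight format
expected by the selected kernel. A minimal sketch of the intended call pattern
follows; the tensor infos, asm_info and the omitted includes/namespace
qualification are illustrative assumptions, not part of this patch:

    // Ask the dispatcher whether an optimized kernel exists and, if so,
    // which fixed weight format it expects for the weights (B) tensor.
    arm_gemm::WeightFormat expected_weight_format = arm_gemm::WeightFormat::ANY;
    const Status status = CpuGemmAssemblyDispatch::has_opt_impl(
        expected_weight_format, &src_info, &weights_info, nullptr, &dst_info, asm_info);
    if(bool(status) && expected_weight_format != arm_gemm::WeightFormat::ANY)
    {
        // The caller is expected to reorder the weights into
        // expected_weight_format before configuring the operator.
    }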

Built with:

    scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \
        build=native -j32 Werror=false validation_tests=1 build_dir=opt \
        standalone=1 asserts=1 experimental_fixed_format_kernels=1 .

Tested with:

    ./build/opt/tests/arm_compute_validation

Hardware where the test executable was run:

Neoverse N1

Test coverage:

* NEGEMMConvolutionLayer, CpuGemmConv2d
* NHWC (the only data layout supported by the fixed-format kernels; see the note after this list)
* F16, F32
* Shapes: RunSmall
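
Note on the fixed-format weights handled by the new code in
CpuGemmAssemblyDispatch (second hunk of the diff below): for an
OHWIo<interleave_by>i<block_by> layout, the 4D O'HWI' weights tensor is seen
by arm_gemm as a 2D matrix, so the leading dimension of B is derived from the
tensor shape rather than from the byte strides. A minimal worked sketch with
hypothetical dimension values (not taken from this patch):

    // Hypothetical OHWIo4i4-style weights: 3x3 kernel, I' = 16 padded
    // input channels, interleave_by = 4.
    const int interleave_by = 4;
    const int H             = 3;
    const int W             = 3;
    const int Ip            = 16;
    // arm_gemm views the tensor as a 2D matrix with O'/interleave_by rows,
    // so the leading dimension of B becomes:
    const int ldb = interleave_by * H * W * Ip; // = 576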

Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99
Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 787ea95..5694a3d 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -160,6 +160,13 @@
     void                             prepare(ITensorPack &tensors) override;
     bool                             is_configured() const override;
     experimental::MemoryRequirements workspace() const override;
+    bool                             isVarWeightsKernel() const override
+    {
+        if(!_gemm_kernel_asm)
+            return false;
+        const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+        return wf != arm_gemm::WeightFormat::UNSPECIFIED && wf != arm_gemm::WeightFormat::ANY;
+    }
 
 private:
     enum AuxTensorIdx
@@ -420,6 +427,8 @@
         // Pretranspose B if required
         if(_gemm_kernel_asm->B_pretranspose_required())
         {
+            // Fixed format kernels need no pretranspose.
+            ARM_COMPUTE_ERROR_ON(arm_gemm::is_fixed_format(_gemm_kernel_asm->get_config().weight_format));
             const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
             const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -483,7 +492,24 @@
     // Check if B is pre-transposed and de-reference if not
     if(!_gemm_kernel_asm->B_is_pretransposed())
     {
-        ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+        ldb                             = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+        const arm_gemm::WeightFormat wf = _gemm_kernel_asm->get_config().weight_format;
+        if(is_fixed_format(wf))
+        {
+            // The 4D tensor of dimension O'HWI' created for the
+            // OHWIo<interleave_by>i<block_by> format is in reality seen
+            // as a 2D tensor at arm_gemm level, where the rows are
+            // O'/<interleave_by> and the columns are <interleave_by> *
+            // H * W * I'.
+            ITensorInfo      *tensor_info   = b->info();
+            const DataLayout  data_layout   = tensor_info->data_layout();
+            const TensorShape tensor_shape  = tensor_info->tensor_shape();
+            const int         H             = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+            const int         W             = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+            const int         Ip            = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)];
+            const int         interleave_by = arm_gemm::interleave_by(wf);
+            ldb                             = (interleave_by * H * W * Ip);
+        }
         multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
         in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
     }
@@ -576,7 +602,9 @@
     const CPUInfo &ci          = NEScheduler::get().cpu_info();
     unsigned int   num_threads = NEScheduler::get().num_threads();
 
-    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode);
+    arm_gemm::GemmConfig cfg;
+    cfg.weight_format = info.weight_format;
+    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
 
     // Create arm_gemm fallback
     auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -594,7 +622,9 @@
     const CPUInfo     &ci          = NEScheduler::get().cpu_info();
     const unsigned int num_threads = NEScheduler::get().num_threads();
 
-    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode);
+    arm_gemm::GemmConfig cfg;
+    cfg.weight_format = info.weight_format;
+    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
 
     // Create arm_gemm fallback
     auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -635,7 +665,8 @@
 {
 }
 
-Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
+Status CpuGemmAssemblyDispatch::has_opt_impl(arm_gemm::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
+                                             const AsmGemmInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
     ARM_COMPUTE_UNUSED(c);
@@ -643,12 +674,14 @@
     Params               p           = extract_parameters(a, b, d, info);
     const CPUInfo       &ci          = NEScheduler::get().cpu_info();
     unsigned int         num_threads = NEScheduler::get().num_threads();
+    arm_gemm::GemmConfig cfg;
+    cfg.weight_format = info.weight_format;
 
-    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode);
+    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
     switch(a->data_type())
     {
         case DataType::F32:
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(args, {})),
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(expected_weight_format, args, {})),
                                             "We could not find an optimized kernel for F32 input");
             break;
 #ifdef __aarch64__
@@ -656,12 +689,12 @@
         case DataType::QASYMM8:
             if(d->data_type() == DataType::S32)
             {
-                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(args, {})),
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
                                                 "We could not find an optimized kernel for U8/QASYMM8 input and S32 output");
             }
             else
             {
-                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(args, {})),
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
                                                 "We could not find an optimized kernel for U8 input and U8 output");
             }
             break;
@@ -669,12 +702,12 @@
         case DataType::QASYMM8_SIGNED:
             if(d->data_type() == DataType::S32)
             {
-                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(args, {})),
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
                                                 "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
             }
             else
             {
-                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(args, {})),
+                ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(expected_weight_format, args, {})),
                                                 "We could not find an optimized kernel for S8 input and S32 output");
             }
             break;
@@ -689,7 +722,7 @@
 #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(args, {})),
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(expected_weight_format, args, {})),
                                             "We could not find an optimized kernel for BFLOAT16 input and F32 output");
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -729,7 +762,17 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input");
-    return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info);
+    arm_gemm::WeightFormat expected_weight_format;
+    const Status           ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
+    if((bool)ret && expected_weight_format != arm_gemm::WeightFormat::ANY)
+    {
+        // Correctness check: if the format expected by the kernel is
+        // not "any", make sure that the one found matches the format
+        // intended by the caller.
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format),
+                                        "The format expected by the kernel does not correspond with the one requested by the user.");
+    }
+    return ret;
 }
 
 bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
@@ -801,7 +844,7 @@
 
 bool CpuGemmAssemblyDispatch::is_configured() const
 {
-    return _arm_gemm != nullptr && _arm_gemm->is_configured();
+    return _arm_gemm && _arm_gemm->is_configured();
 }
 
 void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)