[ONCPUML-951] Variable weight support for Convolution.

API changes for NEGEMMConvolutionLayer and CpuGemmConv2d

Built with:

    scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \
        build=native -j32 Werror=false validation_tests=1 build_dir=opt \
        standalone=1 asserts=1 experimental_fixed_format_kernels=1 .

Tested with:

    ./build/opt/tests/arm_compute_validation

Hardware where the test executable was run:

Neoverse N1

Test coverage:

* NEGEMMConvolutionLayer, CpuGemmConv2d
* NHWC (the only data layout supported by the fixed-format kernels)
* F16, F32
* Shapes: RunSmall

Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99
Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 247cb1d..48fd7c6 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -47,6 +47,57 @@
     GEMM_HYBRID_QUANTIZED
 };
 
+/** Memory layouts for the weights tensor.
+  *
+  * * UNSPECIFIED is used to select kernels that do not run in
+  *    variable weights mode.
+  *
+  * * ANY is used to query the kernel database to retrieve any of the
+  *   kernels that run in variable weights mode. Once a kernel is
+  *   found, the specific format expected by the kernel can be
+  *   retrieved by the user for reordering the weights tensor
+  *   accordingly.
+  *
+  * The other values OHWIo{interleave_by}i{block_by} describe the
+  * memory layout of a 4D tensor with layout OHWI that has been
+  * transformed into a 4D tensor with dimensions O'HWI' where:
+  *
+  * O' = smallest multiple of {interleave_by} such that O <= O'
+  * I' = smallest multiple of {block_by} such that I <= I'
+  *
+  * The total size of the dst tensor is O' x H x W x I'
+  *
+  * The access function of the tensor with layout
+  * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter
+  * access function, where the 6 parameters are computed as follows:
+  *
+  * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by}
+  *
+  * x4 = h                        RANGE [0, H-1]                   SIZE: H
+  * x3 = w                        RANGE [0, W-1]                   SIZE: W
+  * x2 = floor(i/{block_by})      RANGE [0, I'/{block_by} -1]      SIZE: I'/{block_by}
+  * x1 = o%{interleave_by}        RANGE [0, {interleave_by} -1]    SIZE: {interleave_by}
+  * x0 = i%{block_by}             RANGE [0, {block_by} -1]         SIZE: {block_by}
+  *                                                          TOTAL SIZE: O' * H * W * I'
+  *
+  *        4D                       6D
+  * -----------------   -----------------------------------
+  * value(o, h, w, i) =   x5 * H * W * I' * {interleave_by}
+  *                     + x4 * W * I' * {interleave_by}
+  *                     + x3 * I' * {interleave_by}
+  *                     + x2 * {interleave_by} * {block_by}
+  *                     + x1 * {block_by}
+  *                     + x0
+  *
+  * Note that in arm_gemm the 4D tensor of dimensions O'HWI' created
+  * for the OHWIo{interleave_by}i{block_by} format is in practice treated
+  * as a 2D tensor, where the number of rows is O'/{interleave_by}
+  * and the number of columns is {interleave_by} * H * W * I'.
+  *
+  * The suffix *_bf16 denotes the memory layout needed by the
+  * fast-mode kernels, in which the weights are passed in bfloat16
+  * format.
+  */
 enum class WeightFormat
 {
     UNSPECIFIED    = 0x1,
@@ -87,6 +138,69 @@
     OHWIo64i8      = 0x804000
 };
 
+// Decode the <interleave_by> and <block_by> factors encoded in a WeightFormat value (OHWIo<interleave_by>i<block_by>).
+inline int interleave_by(const WeightFormat wf)
+{
+    return ((int)wf >> 8) & 0xFFF;
+}
+inline int block_by(const WeightFormat wf)
+{
+    return ((int)wf >> 20) & 0xF;
+}
+inline bool is_fixed_format(const WeightFormat wf)
+{
+    return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY;
+}
+
+inline std::string to_string(WeightFormat wf)
+{
+#define __CASE_WEIGHT_FORMAT(wf) \
+case WeightFormat::wf:       \
+    return #wf;
+    switch(wf)
+    {
+            __CASE_WEIGHT_FORMAT(UNSPECIFIED)
+            __CASE_WEIGHT_FORMAT(ANY)
+            __CASE_WEIGHT_FORMAT(OHWI)
+            __CASE_WEIGHT_FORMAT(OHWIo2)
+            __CASE_WEIGHT_FORMAT(OHWIo4)
+            __CASE_WEIGHT_FORMAT(OHWIo8)
+            __CASE_WEIGHT_FORMAT(OHWIo16)
+            __CASE_WEIGHT_FORMAT(OHWIo32)
+            __CASE_WEIGHT_FORMAT(OHWIo64)
+            __CASE_WEIGHT_FORMAT(OHWIo128)
+            __CASE_WEIGHT_FORMAT(OHWIo4i2)
+            __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo8i2)
+            __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo16i2)
+            __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo32i2)
+            __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo64i2)
+            __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo4i4)
+            __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo8i4)
+            __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo16i4)
+            __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo32i4)
+            __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo64i4)
+            __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16)
+            __CASE_WEIGHT_FORMAT(OHWIo2i8)
+            __CASE_WEIGHT_FORMAT(OHWIo4i8)
+            __CASE_WEIGHT_FORMAT(OHWIo8i8)
+            __CASE_WEIGHT_FORMAT(OHWIo16i8)
+            __CASE_WEIGHT_FORMAT(OHWIo32i8)
+            __CASE_WEIGHT_FORMAT(OHWIo64i8)
+        default:
+            return "invalid value";
+    }
+#undef __CASE_WEIGHT_FORMAT
+}
+
 struct KernelDescription
 {
     GemmMethod  method         = GemmMethod::DEFAULT;
@@ -230,6 +344,6 @@
 std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
 
 template <typename Top, typename Tret, class OutputStage = Nothing>
-bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {});
+bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
 
 } // namespace arm_gemm
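
For illustration, the following standalone sketch (not part of the patch) works
through the access function documented in the new WeightFormat comment and
checks the bit decoding against the OHWIo64i8 value (0x804000) shown above.
The tensor sizes and element coordinates are made-up example values, and the
helper names mirror, but are independent of, the arm_gemm ones.

    // Standalone illustration of the OHWIo{interleave_by}i{block_by} layout.
    #include <cassert>
    #include <cstdio>

    // Same bit layout as arm_gemm::interleave_by / arm_gemm::block_by above.
    static int interleave_by(int wf) { return (wf >> 8) & 0xFFF; }
    static int block_by(int wf)      { return (wf >> 20) & 0xF; }

    // Round x up to the next multiple of m (used for O' and I').
    static int round_up(int x, int m) { return ((x + m - 1) / m) * m; }

    // Offset of element (o, h, w, i) in the reordered buffer, following the
    // 6-parameter access function documented in the patch. O is not needed
    // here because it only determines the padding at the end of the buffer.
    static long offset(int o, int h, int w, int i,
                       int H, int W, int I, int ib, int bb)
    {
        const long Ip = round_up(I, bb); // I'
        const long x5 = o / ib;
        const long x4 = h;
        const long x3 = w;
        const long x2 = i / bb;
        const long x1 = o % ib;
        const long x0 = i % bb;
        return x5 * H * W * Ip * ib
             + x4 * W * Ip * ib
             + x3 * Ip * ib
             + x2 * ib * bb
             + x1 * bb
             + x0;
    }

    int main()
    {
        // Decode the factors from OHWIo64i8 = 0x804000.
        assert(interleave_by(0x804000) == 64);
        assert(block_by(0x804000) == 8);

        // Hypothetical OHWIo4i2 reorder of a 10x3x3x7 OHWI weights tensor:
        // O' = 12, I' = 8, so element (o=5, h=1, w=2, i=3) lands at
        // 288 + 96 + 64 + 8 + 2 + 1 = 459.
        std::printf("offset = %ld\n", offset(5, 1, 2, 3, 3, 3, 7, 4, 2));
        return 0;
    }

The signature change at the bottom of the patch fits the same flow: has_opt_gemm
now reports, through the added WeightFormat reference, the format expected by the
kernel it selects, so the caller can reorder the weights tensor accordingly.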