[ONCPUML-951] Variable weight support for Convolution.

API changes for NEGEMMConvolutionLayer and CpuGemmConv2d.

Built with:

    scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \
        build=native -j32 Werror=false validation_tests=1 build_dir=opt \
        standalone=1 asserts=1 experimental_fixed_format_kernels=1 .

Tested with:

    ./build/opt/tests/arm_compute_validation

Hardware where the test executable was run:

Neoverse N1

Test coverage:

* NEGEMMConvolutionLayer, CpuGemmConv2d
* NHWC (the only data layout supported by the fixed-format kernels)
* F16, F32
* Shapes: RunSmall

Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99
Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index 50fc5bd..58e4861 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -33,21 +33,21 @@
 
 #include "kernels/a32_sgemm_8x6.hpp"
 
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp"
 #include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp"
 #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp"
 #include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp"
 #include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp"
 #include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
 #include "kernels/a64_sgemm_8x12.hpp"
 
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp"
 #include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"
 #include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp"
 #include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
@@ -89,7 +89,7 @@
     [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); }
 ),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 GemmImplementation<bfloat16, float>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "sve_ffinterleaved_bf16fp32_mmla_8x3VL",
@@ -106,7 +106,7 @@
     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); }
 ),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #endif // ARM_COMPUTE_ENABLE_SVE
 GemmImplementation<bfloat16, float>::with_estimate(
     GemmMethod::GEMM_HYBRID,
@@ -136,7 +136,7 @@
     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
 ),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 GemmImplementation<bfloat16, float>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "a64_ffinterleaved_bf16fp32_mmla_8x12",
@@ -161,7 +161,7 @@
     [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
     [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
 ),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 GemmImplementation<bfloat16, float>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "a64_sgemm_8x12",
@@ -197,7 +197,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<bfloat16, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
 template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
 
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 2796b0d..d749dce 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -34,17 +34,17 @@
 #include "gemm_interleaved.hpp"
 
 #include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp"
 #include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/a64_hgemm_8x24.hpp"
 #include "kernels/a64_hybrid_fp16_mla_6x32.hpp"
 #include "kernels/a64_sgemm_8x12.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp"
 #include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp"
 #include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp"
 
@@ -66,7 +66,7 @@
     [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
 ),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 GemmImplementation<__fp16, __fp16>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "sve_ffinterleaved_fp16_mla_8x3VL",
@@ -83,7 +83,7 @@
     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
 ),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #endif // ARM_COMPUTE_ENABLE_SVE
 #if defined(__aarch64__)
 GemmImplementation<__fp16, __fp16>::with_estimate(
@@ -100,7 +100,7 @@
     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); }
 ),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 GemmImplementation<__fp16, __fp16>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "a64_ffinterleaved_fp16_mla_8x24",
@@ -117,7 +117,7 @@
     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
 ),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 {
     GemmMethod::GEMM_INTERLEAVED,
     "a64_sgemm_8x12",
@@ -152,7 +152,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
 template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
 
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 4f7e191..0fc9e8b 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -31,12 +31,12 @@
 #include "gemv_pretransposed.hpp"
 
 #include "kernels/a32_sgemm_8x6.hpp"
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp"
 #include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp"
 #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
 #include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp"
 #include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp"
 #include "kernels/a64_hybrid_fp32_mla_4x24.hpp"
@@ -48,12 +48,12 @@
 #include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp"
 #include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp"
 
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp"
 #include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp"
 #include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp"
 #include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp"
 #include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp"
 #include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp"
@@ -165,7 +165,7 @@
     [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); },
     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); }
 ),
- #ifdef ENABLE_FIXED_FORMAT_KERNELS
+ #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #ifdef ARM_COMPUTE_ENABLE_BF16
 GemmImplementation<float, float>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
@@ -200,7 +200,7 @@
     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); },
     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); }
 ),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #endif // ARM_COMPUTE_ENABLE_SVE
 // Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases.
 {
@@ -253,7 +253,7 @@
     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); },
     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); }
 ),
-#ifdef ENABLE_FIXED_FORMAT_KERNELS
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #ifdef ARM_COMPUTE_ENABLE_BF16
 // "fast mode" (BF16) kernels
 GemmImplementation<float, float>::with_estimate(
@@ -289,7 +289,7 @@
     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); },
     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); }
 ),
-#endif // ENABLE_FIXED_FORMAT_KERNELS
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
 #endif // __aarch64__
 
 #ifdef __arm__
@@ -318,7 +318,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<float, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
 template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &);
 
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index c41b0a5..90e2f07 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -450,7 +450,7 @@
         }
 
         /* Make sure we've been set up correctly. */
-        assert(_B_transposed);
+        assert(FixedFormat || _B_transposed);
         static_assert(std::is_same<To, Tloi>::value, "gemm_native: Operand types must be the same.");
 //        static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
 
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 75fb1cb..19c8fca 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -306,9 +306,12 @@
 }
 
 template<typename Top, typename Tret, class OutputStage>
-bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) {
+bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) {
     const GemmImplementation<Top, Tret, OutputStage> *impl;
-    return find_implementation<Top, Tret, OutputStage>(args, os, impl);
+    const bool success =  find_implementation<Top, Tret, OutputStage>(args, os, impl);
+    if (success)
+      wf = UniqueGemmCommon<Top, Tret>(impl->do_instantiate(args, os))->get_config().weight_format;
+    return success;
 }
 
 template<typename Top, typename Tret, class OutputStage>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 3915861..18d8fc9 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -56,7 +56,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<int16_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int16_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<int16_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 0c68e4d..2450748 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -159,7 +159,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<int8_t, int32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<int8_t, int32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<int8_t, int32_t, Nothing> (const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 6b813c7..1d7b9c5 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -230,7 +230,7 @@
 }
 
 template UniqueGemmCommon<int8_t, int8_t> gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
-template bool has_opt_gemm<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<int8_t, int8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
 template std::vector<KernelDescription> get_compatible_kernels<int8_t, int8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index 95139c2..be7a4ee 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -197,7 +197,7 @@
 }
 
 template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
-template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
+template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 20cee55..fc836f9 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -56,7 +56,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint16_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<uint16_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index a2d2cc8..03e9cd6 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -157,7 +157,7 @@
 
 /* Explicitly instantiate the external functions for these types. */
 template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
-template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
+template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
 
 } // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 18e124b..d7b5398 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -24,7 +24,7 @@
 
 #pragma once
 
-#include "arm_gemm.hpp"
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
 
 #include <cstddef>
 #include <limits>