[ONCPUML-951] Variable weight support for Convolution.

API changes for NEGEMMConvolutionLayer and CpuGemmConv2d to query for,
and run with, fixed-format (variable weight) kernels.
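
The new entry point is the static has_opt_impl() query: the caller asks
for a specific arm_gemm::WeightFormat (or for ANY) and, if a
fixed-format kernel is available, configures the layer with the format
selected by the runtime. A minimal sketch of the flow, distilled from
the validation fixtures below (tensor-info setup omitted, variable
names illustrative):

    NEGEMMConvolutionLayer conv;
    arm_gemm::WeightFormat wf{ arm_gemm::WeightFormat::UNSPECIFIED };
    const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_w, kernel_h,
                                         num_kernels, false, arm_gemm::WeightFormat::ANY);
    const bool found = bool(NEGEMMConvolutionLayer::has_opt_impl(
        wf, &src_info, &wei_info, &bias_info, &dst_info, conv_info, query_weights_info));
    if(found && arm_gemm::is_fixed_format(wf))
    {
        // Configure with the selected format; the weights tensor must already
        // be rearranged into that blocked layout (see rearrange_data below).
        const WeightsInfo weights_info(false, kernel_w, kernel_h, num_kernels, false, wf);
        conv.configure(&src, &weights_transformed, &bias, &dst, conv_info, weights_info, dilation);
    }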

Built with:

    scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \
        build=native -j32 Werror=false validation_tests=1 build_dir=opt \
        standalone=1 asserts=1 experimental_fixed_format_kernels=1 .

Tested with:

    ./build/opt/tests/arm_compute_validation
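
Individual suites can be selected with the test framework's regular
expression filter (assuming the standard framework options), e.g.:

    ./build/opt/tests/arm_compute_validation --filter='.*VariableWeight.*'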

Hardware where the test executable was run:

Neoverse N1

Test coverage:

* NEGEMMConvolutionLayer, CpuGemmConv2d
* NHWC (the only data layout supported by the fixed-format kernels)
* F16, F32
* Shapes: RunSmall

Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99
Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/SConscript b/tests/SConscript
index 8cd13ab..b848f27 100644
--- a/tests/SConscript
+++ b/tests/SConscript
@@ -184,6 +184,10 @@
     if 'ndk_above_r21' in env:
         test_env['LINKFLAGS'].append('-static-openmp')
 
+# Enable testing of the fixed-format GEMM kernels.
+if env['experimental_fixed_format_kernels'] and test_env['validation_tests']:
+    test_env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS'])
+
 if test_env['validation_tests']:
     arm_compute_validation_framework = env.StaticLibrary('arm_compute_validation_framework', Glob('validation/reference/*.cpp') + Glob('validation/*.cpp'), LINKFLAGS=test_env['LINKFLAGS'], CXXFLAGS=test_env['CXXFLAGS'], LIBS= [ arm_compute_test_framework, arm_compute_core_a])
     Depends(arm_compute_validation_framework , arm_compute_test_framework)
diff --git a/tests/framework/Asserts.h b/tests/framework/Asserts.h
index 28d3da9..5f46277 100644
--- a/tests/framework/Asserts.h
+++ b/tests/framework/Asserts.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -42,6 +42,11 @@
     return value;
 }
 
+inline std::string make_printable(arm_gemm::WeightFormat wf)
+{
+    return arm_gemm::to_string(wf);
+}
+
 inline unsigned int make_printable(uint8_t value)
 {
     return value;
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 578921b..3b385d4 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -504,6 +504,220 @@
 #endif           /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 TEST_SUITE_END() // WinogradLayer
 
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+TEST_SUITE(VariableWeightUtils)
+
+// UC2_1_* tests: the user requests a specific fixed format, but there is no kernel that supports it.
+
+FIXTURE_DATA_TEST_CASE(UC2_1_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo2 })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+// UC2_2_* tests: the user requests a specific fixed format, and a
+// kernel that supports that fixed format is found.
+
+FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 })))
+{
+    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::OHWIo4 })))
+{
+    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format == arm_gemm::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+}
+
+// UC3_1_* tests: the user queries for ANY fixed format, but there is
+// no kernel that supports the use case specified by the user (for
+// example, there is no fixed-format kernel for the data type of the
+// problem).
+
+FIXTURE_DATA_TEST_CASE(UC3_1_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::S32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::S32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+// UC3_2_* tests: the user queries for ANY fixed format. The search
+// succeeds and the fixed format found is reported back to the user for
+// consumption. Note that we only check that _computed_weight_format is
+// neither of the non-fixed formats (ANY and UNSPECIFIED): the weight
+// format the runtime selects depends on the width of the vector units
+// of the hardware the test runs on. For example, the OHWIo4 format
+// returned for FP32 data on 128-bit NEON hardware is replaced by
+// OHWIo8 when running on 256-bit SVE.
+
+FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d, HasOptImplFixture<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+{
+    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer, HasOptImplFixture<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_gemm::WeightFormat::ANY })))
+{
+    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::ANY, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format != arm_gemm::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS);
+}
+
+namespace
+{
+using TestCaseType          = std::tuple<TensorShape, TensorShape, arm_gemm::WeightFormat>;
+auto prepare_weights_shapes = framework::dataset::make("TensorShape",
+{
+    // OHWIo<interleave_by>i<block_by>
+    //
+    // OHWI --> O'HWI', where:
+    //
+    //   O'= smallest multiple of <interleave_by> such that O<=O'
+    //   I'= smallest multiple of <block_by> such that I<=I'
+    //
+
+    // Change N for OHWIo4
+    TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, arm_gemm::WeightFormat::OHWIo4 }),
+    // Change N for OHWIo8
+    TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_gemm::WeightFormat::OHWIo8 }),
+    // Change N for OHWIo4 when H, W and C are not 1
+    TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_gemm::WeightFormat::OHWIo4 }),
+
+    // Fix N and move HWI around, with different weight formats
+    TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_gemm::WeightFormat::OHWIo8 }),
+    TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_gemm::WeightFormat::OHWIo8 }),
+
+    // Adding <block_by> on I (=C)
+    TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
+    TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
+    TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_gemm::WeightFormat::OHWIo4i2 }),
+
+    // Additional spot checks with a singleton H or C dimension
+    TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+    TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_gemm::WeightFormat::OHWIo4 }),
+
+});
+} // unnamed namespace
+
+DATA_TEST_CASE(PrepareWeightShape, framework::DatasetMode::ALL,
+               prepare_weights_shapes, shapes)
+{
+    const TensorShape            input_shape    = std::get<0>(shapes);
+    const TensorShape            expected_shape = std::get<1>(shapes);
+    const arm_gemm::WeightFormat wf             = std::get<2>(shapes);
+    const DataType               DT             = DataType::F32;
+    const DataLayout             DL             = DataLayout::NHWC;
+    const auto                   TI             = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL);
+    const TensorInfo             computed       = ::arm_compute::test::validation::prepare_weights(TI, wf);
+    const TensorInfo             expected       = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL);
+    ARM_COMPUTE_EXPECT_EQUAL(computed, expected, framework::LogLevel::ERRORS);
+}
+
+TEST_SUITE_END() // VariableWeightUtils
+
+TEST_SUITE(ExperimentalCpuAPIVariableWeightWithFixtures)
+
+template <typename ScalarType>
+using VarWidth = VariableWeightsFixture<cpu::CpuGemmConv2d, Tensor, Accessor, ScalarType>;
+
+FIXTURE_DATA_TEST_CASE(RunSmallFloat, VarWidth<float>, framework::DatasetMode::ALL,
+                       combine(combine(datasets::SmallConvolutionLayerDataset(),
+                                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                               framework::dataset::make("ACL Scalar type", { DataType::F32 })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallHalf, VarWidth<half>, framework::DatasetMode::ALL,
+                       combine(combine(datasets::SmallConvolutionLayerDataset(),
+                                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                               framework::dataset::make("ACL Scalar type", { DataType::F16 })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, rel_tolerance_f16, 0.f, half(abs_tolerance_f16));
+}
+
+TEST_SUITE_END() // ExperimentalCpuAPIVariableWeightWithFixtures
+
+TEST_SUITE(ExperimentalNEAPIVariableWeightWithFixtures)
+
+template <typename ScalarType>
+using NEGEMMVarWidth = VariableWeightsFixtureNEInterface<NEGEMMConvolutionLayer, Tensor, Accessor, ScalarType>;
+
+FIXTURE_DATA_TEST_CASE(NEGEMMRunSmallFloat, NEGEMMVarWidth<float>, framework::DatasetMode::ALL,
+                       combine(combine(datasets::SmallConvolutionLayerDataset(),
+                                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                               framework::dataset::make("ACL Scalar type", { DataType::F32 })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
+}
+
+FIXTURE_DATA_TEST_CASE(NEGEMMRunSmallHalf, NEGEMMVarWidth<half>, framework::DatasetMode::ALL,
+                       combine(combine(datasets::SmallConvolutionLayerDataset(),
+                                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                               framework::dataset::make("ACL Scalar type", { DataType::F16 })))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, rel_tolerance_f16, 0.f, half(abs_tolerance_f16));
+}
+
+TEST_SUITE_END() // ExperimentalNEAPIVariableWeightWithFixtures
+
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+
 TEST_SUITE(GEMMConvolutionLayer)
 template <typename T>
 using NEGEMMConvolutionLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEConvolutionLayer, T>;
@@ -609,9 +823,7 @@
 #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
 TEST_SUITE(BFLOAT16)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                    framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                    framework::dataset::make("DataType", DataType::BFLOAT16)),
-                                                                                                                    framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                                                                                                                    framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::BFLOAT16)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
                                                                                                             ActivationFunctionsDataset))
 {
     // Validate output
@@ -623,10 +835,7 @@
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                   framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                   framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW })),
-                                                                                                           ActivationFunctionsDataset))
+                                                                                                                   framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataLayout", { DataLayout::NCHW })), ActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
@@ -636,9 +845,7 @@
 
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                    framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                                                                    framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                                                                                                                    framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
                                                                                                             ActivationFunctionsDataset))
 {
     // Validate output
@@ -680,11 +887,8 @@
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                       framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                       framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
-                                                                                                                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
-                                                                                                                       QuantizedActivationFunctionsDataset))
+                                                                                                                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                                                                                                                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -710,11 +914,8 @@
 
 TEST_SUITE(QASYMM8_SIGNED)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                      framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                      framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
-                                                                                                                      framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
-                                                                                                                      framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
-                                                                                                                      QuantizedActivationFunctionsDataset))
+                                                                                                                      framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })),
+                                                                                                                      framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -868,10 +1069,7 @@
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                     framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                     framework::dataset::make("DataType", DataType::F32)),
-                                                                                                                     framework::dataset::make("DataLayout", { DataLayout::NHWC })),
-                                                                                                             ActivationFunctionsDataset))
+                                                                                                                     framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), ActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
@@ -895,11 +1093,8 @@
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                        framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                        framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                                        framework::dataset::make("DataLayout", { DataLayout::NHWC })),
-                                                                                                                        framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
-                                                                                                                        QuantizedActivationFunctionsDataset))
+                                                                                                                        framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                                                                                                                        framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
@@ -908,11 +1103,8 @@
 
 TEST_SUITE(QASYMM8_SIGNED)
 FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
-                                                                                                                       framework::dataset::make("ReshapeWeights", { true })),
-                                                                                                                       framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
-                                                                                                                       framework::dataset::make("DataLayout", { DataLayout::NHWC })),
-                                                                                                                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
-                                                                                                                       QuantizedActivationFunctionsDataset))
+                                                                                                                       framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+                                                                                                                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_qasymm8);
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index bffdc59..d3804ee 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -28,6 +28,7 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/graph/Utils.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
 #include "src/graph/mutators/MutatorUtils.h"
 #include "tests/AssetsLibrary.h"
 #include "tests/Globals.h"
@@ -121,14 +122,14 @@
         {
             case DataType::QASYMM8:
             {
-                std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+                std::pair<int, int>                     bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
                 std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second);
                 library->fill(tensor, distribution, i);
                 break;
             }
             case DataType::QASYMM8_SIGNED:
             {
-                std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+                std::pair<int, int>                    bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
                 std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second);
                 library->fill(tensor, distribution, i);
                 break;
@@ -397,6 +398,296 @@
                                                                                                   quantization_info, QuantizationInfo(weights_scales), act_info);
     }
 };
+
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm::WeightFormat weight_format)
+{
+    const DataLayout data_layout = tensor_info.data_layout();
+    ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
+    const DataType    data_type    = tensor_info.data_type();
+    const TensorShape tensor_shape = tensor_info.tensor_shape();
+    const int         N            = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O
+    const int         H            = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+    const int         W            = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+    const int         C            = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
+
+    const int interleave_by = arm_gemm::interleave_by(weight_format);
+    const int block_by      = arm_gemm::block_by(weight_format);
+    const int Ip            = arm_gemm::roundup<unsigned int>(C, block_by);      // C'=I'
+    const int Op            = arm_gemm::roundup<unsigned int>(N, interleave_by); // O'=N'
+
+    const TensorShape TS(Ip, W, H, Op);
+    return TensorInfo(TS, 1 /*num_channels*/, data_type, data_layout);
+}
+
+template <typename ScalarType, typename AccessorType>
+inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_gemm::WeightFormat weight_format)
+{
+    ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(weight_format), framework::LogLevel::ERRORS);
+    // Data Layout: OHWIo<interleave_by>i<block_by>
+    const int         interleave_by    = arm_gemm::interleave_by(weight_format);
+    const int         block_by         = arm_gemm::block_by(weight_format);
+    const TensorShape src_tensor_shape = src.shape();
+    const DataLayout  data_layout      = src.data_layout();
+    ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
+    const unsigned int O  = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O
+    const unsigned int H  = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+    const unsigned int W  = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+    const unsigned int I  = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
+    const unsigned int Ip = arm_gemm::roundup<unsigned int>(I, block_by);                                                 // C'=I'
+    const unsigned int Op = arm_gemm::roundup<unsigned int>(O, interleave_by);                                            // N'=O'
+
+    ARM_COMPUTE_EXPECT_EQUAL(Op * H * W * Ip, (unsigned)dst.num_elements(), framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(src.num_elements() <= dst.num_elements(), framework::LogLevel::ERRORS);
+
+    const ScalarType *src_ptr = reinterpret_cast<const ScalarType *>(src.data());
+    ScalarType       *dst_ptr = reinterpret_cast<ScalarType *>(dst.data());
+    for(unsigned i = 0; i < I; ++i)
+        for(unsigned w = 0; w < W; ++w)
+            for(unsigned h = 0; h < H; ++h)
+                for(unsigned o = 0; o < O; ++o)
+                {
+                    ScalarType src_element;
+                    switch(data_layout)
+                    {
+                        case DataLayout::NHWC:
+                        {
+                            src_element = src_ptr[o * H * W * I + h * W * I + w * I + i];
+                        }
+                        break;
+                        default:
+                        {
+                            ARM_COMPUTE_ERROR("Unsupported memory layout.");
+                        }
+                    }
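+                    // Coordinates in the blocked OHWIo<interleave_by>i<block_by>
+                    // layout: x5 indexes the O-block, (x4, x3) = (h, w), x2 indexes
+                    // the I-block, and (x1, x0) are the offsets within a block.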
+                    const int x5      = o / interleave_by;
+                    const int x4      = h;
+                    const int x3      = w;
+                    const int x2      = i / block_by;
+                    const int x1      = o % interleave_by;
+                    const int x0      = i % block_by;
+                    unsigned  dst_idx = x5 * H * W * Ip * interleave_by
+                                        + x4 * W * Ip * interleave_by
+                                        + x3 * Ip * interleave_by
+                                        + x2 * interleave_by * block_by
+                                        + x1 * block_by
+                                        + x0;
+                    dst_ptr[dst_idx] = src_element;
+                }
+}
+
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixtureBaseClass : public framework::Fixture
+{
+public:
+    template <typename...>
+    void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataLayout data_layout,
+               const DataType data_type)
+    {
+        conv = std::make_unique<ConvolutionFunction>();
+        // prepare data
+        _data_layout = data_layout;
+        // Fixed format kernels for variable weights can work only with NHWC format.
+        ARM_COMPUTE_EXPECT_EQUAL(_data_layout, DataLayout::NHWC, framework::LogLevel::ERRORS);
+        _data_type = data_type;
+        // run the code
+        compute_target(input_shape, weights_shape, bias_shape, output_shape, info, dilation);
+        compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation);
+    }
+    void teardown()
+    {
+        _target.allocator()->free();
+    }
+
+protected:
+    template <typename U>
+    void fill(U &&tensor, int i)
+    {
+        switch(tensor.data_type())
+        {
+            case DataType::F16:
+            {
+                arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+                library->fill(tensor, distribution, i);
+                break;
+            }
+            case DataType::F32:
+            {
+                std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
+                library->fill(tensor, distribution, i);
+                break;
+            }
+            default:
+                library->fill_tensor_uniform(tensor, i);
+        }
+    }
+
+private:
+    virtual void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+                                              const PadStrideInfo &conv_info,
+                                              const Size2D        &dilation) = 0;
+
+    void compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &conv_info,
+                        const Size2D &dilation)
+    {
+        // The dataset is always in NCHW format - we need to make C the
+        // innermost dimension because the fixed-format kernels work only
+        // with the NHWC layout.
+        permute(input_shape, PermutationVector(2U, 0U, 1U));
+        permute(weights_shape, PermutationVector(2U, 0U, 1U));
+        permute(output_shape, PermutationVector(2U, 0U, 1U));
+        const auto src_tensor_info    = TensorInfo(input_shape, 1, _data_type, _data_layout);
+        const auto weight_tensor_info = TensorInfo(weights_shape, 1, _data_type, _data_layout);
+        const auto bias_tensor_info   = TensorInfo(bias_shape, 1, _data_type, _data_layout);
+        auto       dst_tensor_info    = TensorInfo(output_shape, 1, _data_type, _data_layout);
+
+        const int kernel_height = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT)];
+        const int kernel_width  = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)];
+        const int num_kernels   = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::BATCHES)];
+
+        const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_gemm::WeightFormat::ANY);
+        const bool        kernel_found = bool(ConvolutionFunction::has_opt_impl(_computed_weight_format, &src_tensor_info, &weight_tensor_info,
+                                                                                &bias_tensor_info, &dst_tensor_info, conv_info, query_weights_info));
+        // Make sure that the setup found a fixed-format kernel as requested by the test case.
+        ARM_COMPUTE_EXPECT(kernel_found, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS);
+
+        const WeightsInfo weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, _computed_weight_format);
+        configure_and_execute_kernel(src_tensor_info, weight_tensor_info, bias_tensor_info, dst_tensor_info, weights_info, conv_info,
+                                     dilation);
+    }
+    void compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+                           const Size2D &dilation)
+    {
+        ARM_COMPUTE_UNUSED(input_shape, weights_shape, bias_shape, output_shape, info,
+                           dilation);
+
+        // Create reference
+        SimpleTensor<ScalarType> src{ input_shape, _data_type };
+        SimpleTensor<ScalarType> weights{ weights_shape, _data_type };
+        SimpleTensor<ScalarType> bias{ bias_shape, _data_type };
+        fill(src, 0);
+        fill(bias, 1);
+        fill(weights, 3);
+        _reference = reference::convolution_layer<ScalarType>(src, weights, bias, output_shape, info, dilation, 1 /*num_groups*/);
+    }
+    DataLayout _data_layout{};
+    DataType   _data_type{};
+
+protected:
+    std::unique_ptr<ConvolutionFunction> conv{};
+    arm_gemm::WeightFormat               _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+    TensorClass                          _target{};
+    SimpleTensor<ScalarType>             _reference{};
+};
+
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixture : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType>
+{
+    void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+                                      const PadStrideInfo &conv_info,
+                                      const Size2D        &dilation)
+    {
+        this->conv->configure(&src_tensor_info, &weight_tensor_info, &bias_tensor_info, &dst_tensor_info, conv_info, weights_info, dilation);
+
+        // Allocate input tensors
+        auto             src                 = create_tensor<TensorClass>(src_tensor_info);
+        auto             weights_original    = create_tensor<TensorClass>(weight_tensor_info);
+        const TensorInfo new_tensor_info     = prepare_weights(weight_tensor_info, this->_computed_weight_format);
+        auto             weights_transformed = create_tensor<TensorClass>(new_tensor_info);
+        auto             bias                = create_tensor<TensorClass>(bias_tensor_info);
+        src.allocator()->allocate();
+        weights_original.allocator()->allocate();
+        weights_transformed.allocator()->allocate();
+        bias.allocator()->allocate();
+        // Allocate destination tensor
+        this->_target = create_tensor<TensorClass>(dst_tensor_info);
+        this->_target.allocator()->allocate();
+
+        // Prepare source and biases that are left unchanged.
+        this->fill(AccessorType(src), 0);
+        this->fill(AccessorType(bias), 1);
+
+        // First run
+        this->fill(AccessorType(weights_original), 2);
+        rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
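+        // The experimental operator API binds tensors at run time through an
+        // ITensorPack; the NE function interface binds them at configure time
+        // (see VariableWeightsFixtureNEInterface below for the latter).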
+        ITensorPack run_pack{ { TensorType::ACL_SRC_0, &src }, { TensorType::ACL_SRC_1, &weights_transformed }, { TensorType::ACL_SRC_2, &bias }, { TensorType::ACL_DST, &(this->_target) } };
+        this->conv->run(run_pack);
+        // Second run, with new weights
+        this->fill(AccessorType(weights_original), 3);
+        rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+        this->conv->run(run_pack);
+        src.allocator()->free();
+        weights_original.allocator()->free();
+        weights_transformed.allocator()->free();
+        bias.allocator()->free();
+    }
+};
+
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixtureNEInterface : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType>
+{
+    void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+                                      const PadStrideInfo &conv_info,
+                                      const Size2D        &dilation)
+    {
+        // Allocate input tensors
+        auto             src                 = create_tensor<TensorClass>(src_tensor_info);
+        auto             weights_original    = create_tensor<TensorClass>(weight_tensor_info);
+        const TensorInfo new_tensor_info     = prepare_weights(weight_tensor_info, this->_computed_weight_format);
+        auto             weights_transformed = create_tensor<TensorClass>(new_tensor_info);
+        auto             bias                = create_tensor<TensorClass>(bias_tensor_info);
+        src.allocator()->allocate();
+        weights_original.allocator()->allocate();
+        weights_transformed.allocator()->allocate();
+        bias.allocator()->allocate();
+        // Allocate destination tensor
+        this->_target = create_tensor<TensorClass>(dst_tensor_info);
+        this->_target.allocator()->allocate();
+        this->conv->configure(&src, &weights_transformed, &bias, &(this->_target), conv_info, weights_info, dilation);
+        // Prepare source and biases that are left unchanged.
+        this->fill(AccessorType(src), 0);
+        this->fill(AccessorType(bias), 1);
+
+        // First run
+        this->fill(AccessorType(weights_original), 2);
+        rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+        this->conv->run();
+        // Second run, with new weights
+        this->fill(AccessorType(weights_original), 3);
+        rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+        this->conv->run();
+        src.allocator()->free();
+        weights_original.allocator()->free();
+        weights_transformed.allocator()->free();
+        bias.allocator()->free();
+    }
+};
+
+template <typename ConvolutionClass>
+class HasOptImplFixture : public framework::Fixture
+{
+public:
+    template <typename...>
+    void setup(DataType data_type, arm_gemm::WeightFormat query_weight_format)
+    {
+        auto              conv        = std::make_unique<ConvolutionClass>();
+        const auto        src_info    = TensorInfo(TensorShape(1U, 5U, 2U), 1, data_type, DataLayout::NHWC);
+        const auto        weight_info = TensorInfo(TensorShape(1U, 3U, 2U, 3U), 1, data_type, DataLayout::NHWC);
+        const auto        bias_info   = TensorInfo(TensorShape(3U), 1, data_type, DataLayout::NHWC);
+        auto              dst_info    = TensorInfo(TensorShape(1U, 7U, 3U), 1, data_type, DataLayout::NHWC);
+        const auto        conv_info   = PadStrideInfo(1, 1, 0, 0, 2, 2, DimensionRoundingType::FLOOR);
+        const WeightsInfo weights_info(false, 3U, 3U, 1U, false, query_weight_format);
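+        // has_opt_impl() is a pure query over the tensor metadata: nothing is
+        // allocated or configured at this point.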
+        _kernel_found = bool(ConvolutionClass::has_opt_impl(_computed_weight_format, &src_info, &weight_info,
+                                                            &bias_info, &dst_info, conv_info, weights_info));
+    }
+
+protected:
+    bool                   _kernel_found{ false };
+    arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+};
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+
 } // namespace validation
 } // namespace test
 } // namespace arm_compute