COMPMID-1708: Improve GEMM test coverage.
Added test cases to exercise the code path where the reshaping of B is performed on the fly.
Change-Id: Ifa4348e1054dc0019be3927f482adf64b18fd554
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index ded5ec6..943fc9f 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -123,16 +123,23 @@
template <typename T>
using NEGEMMFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T>;
+template <typename T>
+using NEGEMMFixtureDisabledC = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true>;
+
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType",
- DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+
+ framework::dataset::make("DataType", DataType::F16)))
{
// Validate output
validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
@@ -141,16 +148,33 @@
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+
+ framework::dataset::make("DataType", DataType::F32)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_f);
}
+TEST_SUITE(DisabledC)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixtureDisabledC<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+ framework::dataset::make("ReshapeWeights", { true, false })),
+
+ framework::dataset::make("DataType", DataType::F32)))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_f);
+}
+TEST_SUITE_END()
+
TEST_SUITE_END()
TEST_SUITE_END()