COMPMID-1708: Improve GEMM test coverage.

Added test cases to exercise the code path where the reshaping of B is performed on the fly.

Change-Id: Ifa4348e1054dc0019be3927f482adf64b18fd554
diff --git a/tests/validation/CL/GEMM.cpp b/tests/validation/CL/GEMM.cpp
index ff2071a..376032c 100644
--- a/tests/validation/CL/GEMM.cpp
+++ b/tests/validation/CL/GEMM.cpp
@@ -108,10 +108,10 @@
 using CLGEMMFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T>;
 
 template <typename T>
-using CLGEMMOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, true>;
+using CLGEMMOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, false, true>;
 
 template <typename T>
-using CLGEMMInputOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, true, true>;
+using CLGEMMInputOutput3DFixture = GEMMValidationFixture<CLTensor, CLAccessor, CLGEMM, T, false, true, true>;
 
 TEST_SUITE(TRANSPOSE_1XW)
 using CLGEMMTranspose1xW        = CLSynthetizeFunctionWithZeroConstantBorder<CLGEMMTranspose1xWKernel, 4>;
@@ -128,13 +128,16 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+                                                                                                         framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                 framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType",
-                                                                                               DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+                                                                                                       framework::dataset::make("ReshapeWeights", { true })),
+                                                                                               framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
@@ -142,12 +145,16 @@
 TEST_SUITE_END()
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+                                                                                                          framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                  framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+                                                                                                        framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
@@ -158,13 +165,15 @@
 TEST_SUITE(INPUT_OUTPUT_3D)
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMInputOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMInputOutput3DDataset(),
+                                                                                                                       framework::dataset::make("ReshapeWeights", { true, false })),
                                                                                                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMInputOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMInputOutput3DDataset(),
+                                                                                                                     framework::dataset::make("ReshapeWeights", { true, false })),
                                                                                                              framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
@@ -173,15 +182,16 @@
 TEST_SUITE_END() // FP32
 
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMInputOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMInputOutput3DDataset(),
+                                                                                                                      framework::dataset::make("ReshapeWeights", { true, false })),
                                                                                                               framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMInputOutput3DDataset(),
-                                                                                                            framework::dataset::make("DataType",
-                                                                                                                    DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMInputOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMInputOutput3DDataset(),
+                                                                                                                    framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                            framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
@@ -194,13 +204,15 @@
 TEST_SUITE(OUTPUT_3D)
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMOutput3DDataset(),
+                                                                                                                  framework::dataset::make("ReshapeWeights", { true, false })),
                                                                                                           framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMOutput3DDataset(),
+                                                                                                                framework::dataset::make("ReshapeWeights", { true, false })),
                                                                                                         framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
@@ -209,15 +221,16 @@
 TEST_SUITE_END() // FP32
 
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMOutput3DDataset(),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMOutput3DFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMOutput3DDataset(),
+                                                                                                                 framework::dataset::make("ReshapeWeights", { true, false })),
                                                                                                          framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMOutput3DDataset(),
-                                                                                                       framework::dataset::make("DataType",
-                                                                                                               DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMOutput3DFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMOutput3DDataset(),
+                                                                                                               framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                       framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
diff --git a/tests/validation/GLES_COMPUTE/GEMM.cpp b/tests/validation/GLES_COMPUTE/GEMM.cpp
index 6417143..e8184fe 100644
--- a/tests/validation/GLES_COMPUTE/GEMM.cpp
+++ b/tests/validation/GLES_COMPUTE/GEMM.cpp
@@ -82,12 +82,16 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, GCGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, GCGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+                                                                                                          framework::dataset::make("ReshapeWeights", { true })),
+                                                                                                  framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f32);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, GCGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, GCGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+                                                                                                        framework::dataset::make("ReshapeWeights", { true })),
+                                                                                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_f32);
diff --git a/tests/validation/NEON/GEMM.cpp b/tests/validation/NEON/GEMM.cpp
index ded5ec6..943fc9f 100644
--- a/tests/validation/NEON/GEMM.cpp
+++ b/tests/validation/NEON/GEMM.cpp
@@ -123,16 +123,23 @@
 template <typename T>
 using NEGEMMFixture = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T>;
 
+template <typename T>
+using NEGEMMFixtureDisabledC = GEMMValidationFixture<Tensor, Accessor, NEGEMM, T, true>;
+
 TEST_SUITE(Float)
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+                                                                                                         framework::dataset::make("ReshapeWeights", { true, false })),
+                                                                                                 framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType",
-                                                                                               DataType::F16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+                                                                                                       framework::dataset::make("ReshapeWeights", { true, false })),
+
+                                                                                               framework::dataset::make("DataType", DataType::F16)))
 {
     // Validate output
     validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16);
@@ -141,16 +148,33 @@
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(datasets::SmallGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+                                                                                                          framework::dataset::make("ReshapeWeights", { true, false })),
+
+                                                                                                  framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(datasets::LargeGEMMDataset(), framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeGEMMDataset(),
+                                                                                                        framework::dataset::make("ReshapeWeights", { true, false })),
+
+                                                                                                framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_f);
 }
+TEST_SUITE(DisabledC)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMFixtureDisabledC<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallGEMMDataset(),
+                                                                                                                   framework::dataset::make("ReshapeWeights", { true, false })),
+
+                                                                                                           framework::dataset::make("DataType", DataType::F32)))
+{
+    // Validate output
+    validate(Accessor(_target), _reference, tolerance_f);
+}
+TEST_SUITE_END()
+
 TEST_SUITE_END()
 TEST_SUITE_END()
 
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 74205e2..0083abf 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -42,29 +42,27 @@
 {
 namespace validation
 {
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool reinterpret_input_as_3d = false, bool reinterpret_ouput_as_3d = false>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool disable_c = false, bool reinterpret_input_as_3d = false, bool reinterpret_ouput_as_3d = false>
 class GEMMValidationFixture : public framework::Fixture
 {
 public:
     template <typename...>
-    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, DataType data_type)
+    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_c, TensorShape output_shape, float alpha, float beta, bool pretranspose, DataType data_type)
     {
-        _data_type = data_type;
-
-        _target    = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type);
+        _target    = compute_target(shape_a, shape_b, shape_c, output_shape, alpha, beta, pretranspose, data_type);
         _reference = compute_reference(shape_a, shape_b, shape_c, output_shape, alpha, beta, data_type);
     }
 
 protected:
     template <typename U>
-    void fill(U &&tensor, int i)
+    void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
     {
         switch(tensor.data_type())
         {
             case DataType::F16:
             case DataType::F32:
             {
-                std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+                std::uniform_real_distribution<> distribution(lo, hi);
                 library->fill(tensor, distribution, i);
                 break;
             }
@@ -74,7 +72,7 @@
     }
 
     TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_c, const TensorShape &output_shape, float alpha, float beta,
-                              DataType data_type)
+                              bool pretranspose, DataType data_type)
     {
         // Create tensors
         TensorType a   = create_tensor<TensorType>(shape_a, data_type, 1);
@@ -87,7 +85,7 @@
         // The GEMMinfo includes the values of the depth in case of reinterpreted 3d output.
         // If the output shape has the same number of dimensions of the input the method called is a 2D matrix multiplication (depth_output_reinterpreted_as_3D = 0),
         // in the other case we have to use the reinterpreted version of GEMM (depth_output_reinterpreted_as_3D = depth of the 3D output).
-        gemm.configure(&a, &b, &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d));
+        gemm.configure(&a, &b, (disable_c) ? nullptr : &c, &dst, alpha, beta, GEMMInfo(false, false, false, (reinterpret_ouput_as_3d ? output_shape[2] : 0), reinterpret_input_as_3d));
         ARM_COMPUTE_EXPECT(a.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(b.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(c.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -107,7 +105,10 @@
         // Fill tensors
         fill(AccessorType(a), 0);
         fill(AccessorType(b), 1);
-        fill(AccessorType(c), 2);
+        if(!disable_c)
+        {
+            fill(AccessorType(c), 2);
+        }
 
         // Compute GEMM function
         gemm.run();
@@ -133,14 +134,21 @@
         // Fill reference
         fill(a, 0);
         fill(b, 1);
-        fill(c, 2);
-
-        return reference::gemm<T>(a, b, c, alpha, beta);
+        if(!disable_c)
+        {
+            fill(c, 2);
+            return reference::gemm<T>(a, b, c, alpha, beta);
+        }
+        else
+        {
+            // Setting beta to 0 will effectively disable C for the
+            // computation of the reference: alpha * A * B + 0 * C
+            return reference::gemm<T>(a, b, c, alpha, 0.f);
+        }
     }
 
     TensorType      _target{};
     SimpleTensor<T> _reference{};
-    DataType        _data_type{};
 };
 
 } // namespace validation