COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16

Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1991
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index 99f5ffe..b885bfe 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -60,10 +60,20 @@
 template <typename T>
 using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
 
+// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision: same template arguments as the plain fixture plus a trailing `true` (presumably the mixed-precision switch - TODO confirm against the fixture definition)
+template <typename T>
+using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture =
+    GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
+
 // Fixture for CLGEMMMatrixMultiplyReshaped3D
 template <typename T>
 using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
 
+// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision: same template arguments as the plain 3D fixture plus a trailing `true` (presumably the mixed-precision switch - TODO confirm against the fixture definition)
+template <typename T>
+using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture =
+    GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
+
 namespace
 {
 // *INDENT-OFF*
@@ -71,15 +81,12 @@
 RelativeTolerance<float> rel_tolerance_f32(0.001f);
 constexpr float          abs_tolerance_f32(0.0001f);
 
+RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f); // Relative tolerance for the MixedPrecision suite (currently the same value as rel_tolerance_f16)
+constexpr float          abs_tolerance_f16_mixed_precision(0.01f); // Absolute tolerance for the MixedPrecision suite (currently the same value as abs_tolerance_f16)
+
 RelativeTolerance<float> rel_tolerance_f16(0.001f);
 constexpr float          abs_tolerance_f16(0.01f);
 
-/** Alpha values to test - Precommit */
-const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
-
-/** Beta values to test - Precommit */
-const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} );
-
 /** M values to test */
 const auto m_values = framework::dataset::make("M", 37);
 
@@ -105,6 +112,12 @@
     ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
 });
 
+/** Alpha values to test - Precommit (used by the RunSmall/RunSmall3D datasets) */
+const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} );
+
+/** Beta values to test - Precommit (used by the RunSmall/RunSmall3D datasets) */
+const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} );
+
 /** M0 values to test - Precommit */
 const auto m0_values_precommit = framework::dataset::make("M0", { 4 });
 
@@ -120,6 +133,12 @@
 /** H0 values to test - Precommit */
 const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);
 
+/** Alpha values to test - Nightly (used by the RunLarge/RunLarge3D datasets) */
+const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} );
+
+/** Beta values to test - Nightly (used by the RunLarge/RunLarge3D datasets) */
+const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} );
+
 /** M0 values to test - Nightly */
 const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 });
 
@@ -167,8 +186,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
                                                                    broadcast_bias_values),
                                                                    lhs_transpose_values),
                                                                    act_values))
@@ -191,8 +210,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
                                                                    broadcast_bias_values),
                                                                    lhs_transpose_values),
                                                                    act_values))
@@ -216,8 +235,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
                                                                    lhs_transpose_values),
                                                                    act_values))
 {
@@ -240,8 +259,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
                                                                    lhs_transpose_values),
                                                                    act_values))
 {
@@ -266,8 +285,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
                                                                    broadcast_bias_values),
                                                                    lhs_transpose_values),
                                                                    act_values))
@@ -290,8 +309,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
                                                                    broadcast_bias_values),
                                                                    lhs_transpose_values),
                                                                    act_values))
@@ -315,8 +334,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
                                                                    lhs_transpose_values),
                                                                    act_values))
 {
@@ -339,8 +358,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
                                                                    lhs_transpose_values),
                                                                    act_values))
 {
@@ -348,6 +367,105 @@
     validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
 }
 TEST_SUITE_END() // FP16
+
+TEST_SUITE(MixedPrecision)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   v0_values_precommit),
+                                                                   h0_values_precommit),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
+                                                                   broadcast_bias_values),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output: compare the CL target with the reference using the mixed-precision FP16 relative/absolute tolerances
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_nightly),
+                                                                   n0_values_nightly),
+                                                                   k0_values_nightly),
+                                                                   v0_values_nightly),
+                                                                   h0_values_nightly),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
+                                                                   broadcast_bias_values),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output: compare the CL target with the reference using the mixed-precision FP16 relative/absolute tolerances
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_w_values,
+                                                                   m_h_values),
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   v0_values_precommit),
+                                                                   h0_values_precommit),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output: compare the CL target with the reference using the mixed-precision FP16 relative/absolute tolerances
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_w_values,
+                                                                   m_h_values),
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_nightly),
+                                                                   n0_values_nightly),
+                                                                   k0_values_nightly),
+                                                                   v0_values_nightly),
+                                                                   h0_values_nightly),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output: compare the CL target with the reference using the mixed-precision FP16 relative/absolute tolerances
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+TEST_SUITE_END() // MixedPrecision
 TEST_SUITE_END() // Float
 TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
 TEST_SUITE_END() // CL