COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16

Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1991
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index 99f5ffe..b885bfe 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -60,10 +60,20 @@
 template <typename T>
 using CLGEMMMatrixMultiplyReshapedFixture = GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
 
+// Fixture for CLGEMMMatrixMultiplyReshaped mixed precision
+template <typename T>
+using CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture =
+    GEMMMatrixMultiplyReshapedValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
+
 // Fixture for CLGEMMMatrixMultiplyReshaped3D
 template <typename T>
 using CLGEMMMatrixMultiplyReshaped3DFixture = GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped>;
 
+// Fixture for CLGEMMMatrixMultiplyReshaped3D mixed precision
+template <typename T>
+using CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture =
+    GEMMMatrixMultiplyReshaped3DValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeLHSMatrix, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshaped, true>;
+
 namespace
 {
 // *INDENT-OFF*
@@ -71,15 +81,12 @@
 RelativeTolerance<float> rel_tolerance_f32(0.001f);
 constexpr float          abs_tolerance_f32(0.0001f);
 
+RelativeTolerance<float> rel_tolerance_f16_mixed_precision(0.001f);
+constexpr float          abs_tolerance_f16_mixed_precision(0.01f);
+
 RelativeTolerance<float> rel_tolerance_f16(0.001f);
 constexpr float          abs_tolerance_f16(0.01f);
 
-/** Alpha values to test - Precommit */
-const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
-
-/** Beta values to test - Precommit */
-const auto beta_values = framework::dataset::make("beta", {-0.35f, 0.0f} );
-
 /** M values to test */
 const auto m_values = framework::dataset::make("M", 37);
 
@@ -105,6 +112,12 @@
     ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 8.f, 2.f),
 });
 
+/** Alpha values to test - Precommit */
+const auto a_values_precommit = framework::dataset::make("alpha", {-0.75f} );
+
+/** Beta values to test - Precommit */
+const auto beta_values_precommit = framework::dataset::make("beta", {-0.35f} );
+
 /** M0 values to test - Precommit */
 const auto m0_values_precommit = framework::dataset::make("M0", { 4 });
 
@@ -120,6 +133,12 @@
 /** H0 values to test - Precommit */
 const auto h0_values_precommit = framework::dataset::make("H0", 1, 3);
 
+/** Alpha values to test - Nightly */
+const auto a_values_nightly = framework::dataset::make("alpha", {1.0f} );
+
+/** Beta values to test - Nightly */
+const auto beta_values_nightly = framework::dataset::make("beta", {1.0f} );
+
 /** M0 values to test - Nightly */
 const auto m0_values_nightly = framework::dataset::make("M0", { 2, 3, 4, 8 });
 
@@ -167,8 +186,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
                                                                    broadcast_bias_values),
                                                                    lhs_transpose_values),
                                                                    act_values))
@@ -191,8 +210,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
                                                                    broadcast_bias_values),
                                                                    lhs_transpose_values),
                                                                    act_values))
@@ -216,8 +235,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
                                                                    lhs_transpose_values),
                                                                    act_values))
 {
@@ -240,8 +259,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F32)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
                                                                    lhs_transpose_values),
                                                                    act_values))
 {
@@ -266,8 +285,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
                                                                    broadcast_bias_values),
                                                                    lhs_transpose_values),
                                                                    act_values))
@@ -290,8 +309,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
                                                                    broadcast_bias_values),
                                                                    lhs_transpose_values),
                                                                    act_values))
@@ -315,8 +334,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
                                                                    lhs_transpose_values),
                                                                    act_values))
 {
@@ -339,8 +358,8 @@
                                                                    i_values_lhs),
                                                                    i_values_rhs),
                                                                    framework::dataset::make("DataType", DataType::F16)),
-                                                                   a_values),
-                                                                   beta_values),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
                                                                    lhs_transpose_values),
                                                                    act_values))
 {
@@ -348,6 +367,105 @@
     validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
 }
 TEST_SUITE_END() // FP16
+
+TEST_SUITE(MixedPrecision)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   v0_values_precommit),
+                                                                   h0_values_precommit),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
+                                                                   broadcast_bias_values),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_nightly),
+                                                                   n0_values_nightly),
+                                                                   k0_values_nightly),
+                                                                   v0_values_nightly),
+                                                                   h0_values_nightly),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
+                                                                   broadcast_bias_values),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_w_values,
+                                                                   m_h_values),
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   v0_values_precommit),
+                                                                   h0_values_precommit),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values_precommit),
+                                                                   beta_values_precommit),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DMixedPrecisionFixture<half>, framework::DatasetMode::NIGHTLY,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_w_values,
+                                                                   m_h_values),
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_nightly),
+                                                                   n0_values_nightly),
+                                                                   k0_values_nightly),
+                                                                   v0_values_nightly),
+                                                                   h0_values_nightly),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values_nightly),
+                                                                   beta_values_nightly),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16_mixed_precision, 0.f, abs_tolerance_f16_mixed_precision);
+}
+TEST_SUITE_END() // MixedPrecision
 TEST_SUITE_END() // Float
 TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
 TEST_SUITE_END() // CL
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 854cc4a..bf919c9 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -667,7 +667,7 @@
     SimpleTensor<T> _reference{};
 };
 
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false>
 class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture
 {
 public:
@@ -734,6 +734,7 @@
         kernel_info.reinterpret_input_as_3d = false;
         kernel_info.broadcast_bias          = broadcast_bias;
         kernel_info.activation_info         = act_info;
+        kernel_info.fp_mixed_precision      = fp_mixed_precision;
 
         // The output tensor will be auto-initialized within the function
 
@@ -807,14 +808,21 @@
             }
         }
 
-        return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+        if(fp_mixed_precision)
+        {
+            return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info);
+        }
+        else
+        {
+            return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+        }
     }
 
     TensorType      _target{};
     SimpleTensor<T> _reference{};
 };
 
-template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType>
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeLHSFunctionType, typename ReshapeRHSFunctionType, typename GEMMFunctionType, bool fp_mixed_precision = false>
 class GEMMMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture
 {
 public:
@@ -879,6 +887,7 @@
         kernel_info.reinterpret_input_as_3d = false;
         kernel_info.broadcast_bias          = true;
         kernel_info.activation_info         = act_info;
+        kernel_info.fp_mixed_precision      = fp_mixed_precision;
 
         // The output tensor will be auto-initialized within the function
 
@@ -951,7 +960,14 @@
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
         }
 
-        return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+        if(fp_mixed_precision)
+        {
+            return reference::activation_layer(reference::gemm_mixed_precision<T>(lhs, rhs, bias, alpha, beta), act_info);
+        }
+        else
+        {
+            return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+        }
     }
 
     TensorType      _target{};
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index 2feab89..3c72b94 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -84,8 +84,61 @@
     return dst;
 }
 
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type>
+SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta)
+{
+    // Mixed-precision GEMM reference: combines F32 accumulators with F16 multiplications
+    // Create the output tensor with the same shape and data type as c
+    SimpleTensor<T> dst{ c.shape(), c.data_type(), 1 };
+
+    // Derive the GEMM dimensions from the shapes of a and b
+    const int M = a.shape().y();
+    const int N = b.shape().x();
+    const int K = a.shape().x();
+    const int D = a.shape().z(); // Number of matrices in a batch
+    const int W = a.shape()[3];  // Number of batched GEMMs (Winograd case)
+
+    const int a_stride_z = K * M;
+    const int a_stride_w = K * M * D;
+
+    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;     // Do not slide the matrix B along the 3rd dimension in case matrix B has fewer than 3 dimensions
+    const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has fewer than 4 dimensions
+
+    const int c_stride_z = N * M;
+    const int c_stride_w = N * M * D;
+
+    for(int w = 0; w < W; ++w)
+    {
+        for(int depth = 0; depth < D; ++depth)
+        {
+            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
+            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
+            const int base_addr_c = depth * c_stride_z + w * c_stride_w;
+
+            for(int row = 0; row < M; ++row)
+            {
+                for(int col = 0; col < N; ++col)
+                {
+                    float acc(0);
+
+                    for(int k = 0; k < K; ++k)
+                    {
+                        acc += static_cast<float>(a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N]);
+                    }
+
+                    // Finalize the result: alpha * A * B + beta * C, converted back to T when stored in dst
+                    dst[base_addr_c + col + row * N] = static_cast<T>(alpha * acc + beta * c[base_addr_c + col + row * N]);
+                }
+            }
+        }
+    }
+
+    return dst;
+}
+
 template SimpleTensor<float> gemm(const SimpleTensor<float> &a, const SimpleTensor<float> &b, const SimpleTensor<float> &c, float alpha, float beta);
 template SimpleTensor<half> gemm(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
+template SimpleTensor<half> gemm_mixed_precision(const SimpleTensor<half> &a, const SimpleTensor<half> &b, const SimpleTensor<half> &c, float alpha, float beta);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/reference/GEMM.h b/tests/validation/reference/GEMM.h
index 39007c6..9bcd640 100644
--- a/tests/validation/reference/GEMM.h
+++ b/tests/validation/reference/GEMM.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -38,6 +38,9 @@
 template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
 SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta);
 
+template <typename T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
+SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const SimpleTensor<T> &c, float alpha, float beta);
+
 } // namespace reference
 } // namespace validation
 } // namespace test