COMPMID-2571: Add mixed-precision support in CLGEMMReshaped for FP16

Change-Id: I5ba90d4de4594ed784c7230aa6b10503be67c001
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1991
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
index b00faed..a390e34 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
@@ -68,6 +68,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr)
                                     && (!gemm_info.broadcast_bias),
                                     "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported");
 
     const unsigned int m = gemm_info.m;
     const unsigned int n = gemm_info.n;
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index f77ab02..9b9eb12 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -77,6 +77,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr)
                                     && (!gemm_info.broadcast_bias),
                                     "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision && (input0->data_type() == DataType::F32), "Mixed precision only supported for F16 data type");
 
     const unsigned int m = gemm_info.m;
     const unsigned int n = gemm_info.n;
@@ -240,9 +241,11 @@
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     ICLKernel::configure_internal(win_config.second);
 
+    const bool     enable_mixed_precision = gemm_info.fp_mixed_precision;
+    const DataType data_type              = input0->info()->data_type();
+
     // Create build options
     CLBuildOptions build_opts;
-    build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
     build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha));
     build_opts.add_option_if(_input2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta));
     build_opts.add_option_if(helpers::float_ops::is_one(beta), "-DUNIT_BETA");
@@ -255,6 +258,12 @@
     build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE");
     build_opts.add_option_if(lhs_info.transpose, "-DLHS_TRANSPOSE");
     build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS");
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
+    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
+    build_opts.add_option_if(enable_mixed_precision, "-DMIXED_PRECISION");
+    build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
+    build_opts.add_option("-DDATA_TYPE_ACCUMULATOR=" + (enable_mixed_precision ? get_cl_type_from_data_type(DataType::F32) : get_cl_type_from_data_type(data_type)));
     build_opts.add_option("-DM=" + support::cpp11::to_string(gemm_info.m));
     build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n));
     build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0));
@@ -262,9 +271,6 @@
     build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0));
     build_opts.add_option("-DV0=" + support::cpp11::to_string(lhs_info.v0));
     build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
-    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
-    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
-    build_opts.add_option_if(gemm_info.activation_info.enabled(), "-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
 
     std::string kernel_name("gemm_mm_reshaped_");
     kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_";
@@ -282,6 +288,7 @@
     _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : "");
     _config_id += lower_string(string_from_data_type(input0->info()->data_type()));
     _config_id += "_";
+    _config_id += (enable_mixed_precision ? "mixed_precision_" : "");
     _config_id += support::cpp11::to_string(output->info()->dimension(1));
     _config_id += "_";
     _config_id += support::cpp11::to_string(output->info()->dimension(0));
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index fff4da6..3d5e148 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -68,6 +68,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG((gemm_info.reinterpret_input_as_3d || gemm_info.depth_output_gemm3d != 0) && (input2 != nullptr)
                                     && (!gemm_info.broadcast_bias),
                                     "Bias addition only supported with broadcast mode in case the input or output has to be reinterpreted as 3D");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported");
 
     const unsigned int m = gemm_info.m;
     const unsigned int n = gemm_info.n;