COMPMID-2927: Add support for mixed precision in
CLInstanceNormalizationLayer

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I91482e2e4b723606aef76afef09a8277813e5d1b
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2668
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
diff --git a/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
index 0f20857..5c2a3d9 100644
--- a/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -38,12 +38,9 @@
 {
 namespace
 {
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
 {
-    ARM_COMPUTE_UNUSED(gamma);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(epsilon == 0.f, "Epsilon must be different than 0");
-
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.epsilon == 0.f, "Epsilon must be different than 0");
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(input, DataType::F16, DataType::F32);
 
     if(output != nullptr && output->total_size() != 0)
@@ -74,33 +71,31 @@
 } // namespace
 
 CLInstanceNormalizationLayerKernel::CLInstanceNormalizationLayerKernel()
-    : _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12), _run_in_place(false)
+    : _input(nullptr), _output(nullptr), _run_in_place(false)
 {
 }
 
-void CLInstanceNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, float gamma, float beta, float epsilon)
+void CLInstanceNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, const InstanceNormalizationLayerKernelInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input);
 
-    _input   = input;
-    _output  = output == nullptr ? input : output;
-    _gamma   = gamma;
-    _beta    = beta;
-    _epsilon = epsilon;
+    _input  = input;
+    _output = output == nullptr ? input : output;
 
     _run_in_place = (output == nullptr) || (output == input);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), gamma, beta, epsilon));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), info));
     const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
 
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+    build_opts.add_option("-DINTERNAL_DATA_TYPE=" + (info.use_mixed_precision ? "float" : get_cl_type_from_data_type(input->info()->data_type())));
     build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
     build_opts.add_option("-DDIM_X=" + support::cpp11::to_string(input->info()->dimension(0)));
     build_opts.add_option("-DDIM_Y=" + support::cpp11::to_string(input->info()->dimension(1)));
     build_opts.add_option("-DDIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
-    build_opts.add_option("-DGAMMA=" + float_to_string_with_full_precision(gamma));
-    build_opts.add_option("-DBETA=" + float_to_string_with_full_precision(beta));
-    build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(epsilon));
+    build_opts.add_option("-DGAMMA=" + float_to_string_with_full_precision(info.gamma));
+    build_opts.add_option("-DBETA=" + float_to_string_with_full_precision(info.beta));
+    build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(info.epsilon));
     build_opts.add_option_if(_run_in_place, "-DIN_PLACE");
     build_opts.add_option_if(_input->info()->data_layout() == DataLayout::NHWC, "-DNHWC");
 
@@ -113,9 +108,9 @@
     ICLKernel::configure_internal(std::get<1>(win_config));
 }
 
-Status CLInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
+Status CLInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, gamma, beta, epsilon));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info));
     ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), (output == nullptr ? input->clone().get() : output->clone().get()))));
     return Status{};
 }