COMPMID-2927: Add support for mixed precision in
CLInstanceNormalizationLayer

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I91482e2e4b723606aef76afef09a8277813e5d1b
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2668
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
diff --git a/src/core/CL/cl_kernels/instance_normalization.cl b/src/core/CL/cl_kernels/instance_normalization.cl
index de7d57c..043012b 100644
--- a/src/core/CL/cl_kernels/instance_normalization.cl
+++ b/src/core/CL/cl_kernels/instance_normalization.cl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -23,7 +23,7 @@
  */
 #include "helpers.h"
 
-#if defined(VEC_SIZE) && defined(DATA_TYPE) && defined(GAMMA) && defined(BETA) && defined(EPSILON) && defined(DIM_X) && defined(DIM_Y) && defined(DIM_Z)
+#if defined(VEC_SIZE) && defined(DATA_TYPE) && defined(INTERNAL_DATA_TYPE) && defined(GAMMA) && defined(BETA) && defined(EPSILON) && defined(DIM_X) && defined(DIM_Y) && defined(DIM_Z)
 /** This function normalizes the input 2D tensor across the first dimension with respect to mean and standard deviation of the same dimension.
  *
  * @attention Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. e.g. -DVEC_SIZE=16
@@ -63,8 +63,8 @@
     Tensor4D out = CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(output, 0);
 #endif /* IN_PLACE */
 
-    float sum    = 0.f;
-    float sum_sq = 0.f;
+    INTERNAL_DATA_TYPE sum    = 0.f; // Running sum; INTERNAL_DATA_TYPE is "float" when the host requests mixed precision, else DATA_TYPE
+    INTERNAL_DATA_TYPE sum_sq = 0.f; // Running sum of squares, same accumulation type as sum
 
 #if defined(NHWC)
 
@@ -76,7 +76,7 @@
     {
         for(int i_h = 0; i_h < DIM_Z; ++i_h)
         {
-            float data = (float) * ((__global DATA_TYPE *)tensor4D_offset(&in, ch, i_w, i_h, batch));
+            INTERNAL_DATA_TYPE data = (INTERNAL_DATA_TYPE)(*((__global DATA_TYPE *)tensor4D_offset(&in, ch, i_w, i_h, batch))); // Widen one DATA_TYPE element to the accumulation type (same form as the NCHW left-over loop)
             sum += data;
             sum_sq += data * data;
         }
@@ -87,9 +87,9 @@
     const int batch          = get_global_id(2) / DIM_Z; // Current batch
     const int elements_plane = DIM_X * DIM_Y;
 
-    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+    VEC_DATA_TYPE(INTERNAL_DATA_TYPE, VEC_SIZE) // Per-lane partial sums are held in the accumulation type, not DATA_TYPE
     part_sum = 0.f;
-    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+    VEC_DATA_TYPE(INTERNAL_DATA_TYPE, VEC_SIZE)
     part_sum_sq = 0.f;
     // Calculate partial sum
     for(int y = 0; y < DIM_Y; ++y)
@@ -98,15 +98,15 @@
         for(; x <= (DIM_X - VEC_SIZE); x += VEC_SIZE)
         {
             // Load data
-            VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
-            data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)tensor4D_offset(&in, x, y, ch, batch));
+            VEC_DATA_TYPE(INTERNAL_DATA_TYPE, VEC_SIZE)
+            data = CONVERT(VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)tensor4D_offset(&in, x, y, ch, batch)), VEC_DATA_TYPE(INTERNAL_DATA_TYPE, VEC_SIZE)); // Load DATA_TYPE lanes, then widen explicitly (vector types need an explicit convert)
             part_sum += data;
             part_sum_sq += data * data;
         }
         // Left-overs loop
         for(; x < DIM_X; ++x)
         {
-            DATA_TYPE data = *((__global DATA_TYPE *)tensor4D_offset(&in, x, y, ch, batch));
+            INTERNAL_DATA_TYPE data = (INTERNAL_DATA_TYPE)(*((__global DATA_TYPE *)tensor4D_offset(&in, x, y, ch, batch)));
             part_sum.s0 += data;
             part_sum_sq.s0 += data * data;
         }
@@ -127,16 +127,14 @@
     part_sum.s0 += part_sum.s1;
     part_sum_sq.s0 += part_sum_sq.s1;
 
-    sum    = (float)part_sum.s0;
-    sum_sq = (float)part_sum_sq.s0;
+    sum    = (INTERNAL_DATA_TYPE)part_sum.s0;
+    sum_sq = (INTERNAL_DATA_TYPE)part_sum_sq.s0;
 
 #endif // defined(NHWC)
 
-    const float     mean_float   = (sum / elements_plane);
-    const DATA_TYPE mean         = (DATA_TYPE)mean_float;
-    const float     var_float    = (sum_sq / elements_plane) - (mean_float * mean_float);
-    const float     multip_float = GAMMA / sqrt(var_float + EPSILON);
-    const DATA_TYPE multip       = (DATA_TYPE)multip_float;
+    const INTERNAL_DATA_TYPE mean   = (sum / elements_plane); // Mean of the elements_plane values accumulated above
+    const INTERNAL_DATA_TYPE var    = fmax((sum_sq / elements_plane) - (mean * mean), (INTERNAL_DATA_TYPE)0.f); // E[x^2] - E[x]^2 can round slightly below zero (more so in half); clamp so sqrt() never sees var + EPSILON < 0
+    const INTERNAL_DATA_TYPE multip = GAMMA / sqrt(var + EPSILON);
 
 #if defined(NHWC)
 
@@ -150,7 +148,7 @@
 #else  /* !IN_PLACE */
             __global DATA_TYPE *output_address = (__global DATA_TYPE *)tensor4D_offset(&out, ch, i_w, i_h, batch);
 #endif /* IN_PLACE */
-            *(output_address) = (*(input_address) - mean) * multip + (DATA_TYPE)BETA;
+            *(output_address) = (*(input_address) - mean) * multip + (INTERNAL_DATA_TYPE)BETA;
         }
     }
 
@@ -167,13 +165,13 @@
             __global DATA_TYPE *output_address = (__global DATA_TYPE *)tensor4D_offset(&out, x, y, ch, batch);
 #endif /* IN_PLACE */
 
-            VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
-            data = VLOAD(VEC_SIZE)(0, input_address);
+            VEC_DATA_TYPE(INTERNAL_DATA_TYPE, VEC_SIZE)
+            data = CONVERT(VLOAD(VEC_SIZE)(0, input_address), VEC_DATA_TYPE(INTERNAL_DATA_TYPE, VEC_SIZE)); // Widen the loaded DATA_TYPE lanes to INTERNAL_DATA_TYPE
 
-            VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
-            res = (data - mean) * multip + (DATA_TYPE)BETA;
+            VEC_DATA_TYPE(INTERNAL_DATA_TYPE, VEC_SIZE)
+            res = (data - mean) * multip + (INTERNAL_DATA_TYPE)BETA;
             VSTORE(VEC_SIZE)
-            (res, 0, output_address);
+            (CONVERT(res, VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)), 0, output_address); // Narrow back to DATA_TYPE for the store
         }
         // Left-overs loop
         for(; x < DIM_X; ++x)
@@ -184,9 +182,9 @@
 #else  /* !IN_PLACE */
             __global DATA_TYPE *output_address = (__global DATA_TYPE *)tensor4D_offset(&out, x, y, ch, batch);
 #endif /* IN_PLACE */
-            *(output_address)                  = (*(input_address) - mean) * multip + (DATA_TYPE)BETA;
+            *(output_address)                  = (DATA_TYPE)(((INTERNAL_DATA_TYPE)(*(input_address)) - mean) * multip + (INTERNAL_DATA_TYPE)BETA); // Normalize in INTERNAL_DATA_TYPE, then narrow explicitly to DATA_TYPE
         }
     }
 #endif // defined(NHWC)
 }
-#endif /* defined(VEC_SIZE) && defined(DATA_TYPE) && defined(GAMMA) && defined(BETA) && defined(EPSILON) && defined(DIM_X) && defined(DIM_Y) && defined(DIM_Z) */
+#endif /* defined(VEC_SIZE) && defined(DATA_TYPE) && defined(INTERNAL_DATA_TYPE) && defined(GAMMA) && defined(BETA) && defined(EPSILON) && defined(DIM_X) && defined(DIM_Y) && defined(DIM_Z) */
diff --git a/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
index 0f20857..5c2a3d9 100644
--- a/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLInstanceNormalizationLayerKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -38,12 +38,9 @@
 {
 namespace
 {
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
 {
-    ARM_COMPUTE_UNUSED(gamma);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(epsilon == 0.f, "Epsilon must be different than 0");
-
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.epsilon == 0.f, "Epsilon must be different than 0");
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(input, DataType::F16, DataType::F32);
 
     if(output != nullptr && output->total_size() != 0)
@@ -74,33 +71,31 @@
 } // namespace
 
 CLInstanceNormalizationLayerKernel::CLInstanceNormalizationLayerKernel()
-    : _input(nullptr), _output(nullptr), _gamma(1), _beta(0), _epsilon(1e-12), _run_in_place(false)
+    : _input(nullptr), _output(nullptr), _run_in_place(false)
 {
 }
 
-void CLInstanceNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, float gamma, float beta, float epsilon)
+void CLInstanceNormalizationLayerKernel::configure(ICLTensor *input, ICLTensor *output, const InstanceNormalizationLayerKernelInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input);
 
-    _input   = input;
-    _output  = output == nullptr ? input : output;
-    _gamma   = gamma;
-    _beta    = beta;
-    _epsilon = epsilon;
+    _input  = input;
+    _output = output == nullptr ? input : output;
 
     _run_in_place = (output == nullptr) || (output == input);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), gamma, beta, epsilon));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(_input->info(), _output->info(), info));
     const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
 
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+    build_opts.add_option("-DINTERNAL_DATA_TYPE=" + (info.use_mixed_precision ? "float" : get_cl_type_from_data_type(input->info()->data_type()))); // Kernel accumulates and normalizes in INTERNAL_DATA_TYPE: "float" for mixed precision, else the tensor's own CL type
     build_opts.add_option("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration));
     build_opts.add_option("-DDIM_X=" + support::cpp11::to_string(input->info()->dimension(0)));
     build_opts.add_option("-DDIM_Y=" + support::cpp11::to_string(input->info()->dimension(1)));
     build_opts.add_option("-DDIM_Z=" + support::cpp11::to_string(input->info()->dimension(2)));
-    build_opts.add_option("-DGAMMA=" + float_to_string_with_full_precision(gamma));
-    build_opts.add_option("-DBETA=" + float_to_string_with_full_precision(beta));
-    build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(epsilon));
+    build_opts.add_option("-DGAMMA=" + float_to_string_with_full_precision(info.gamma));
+    build_opts.add_option("-DBETA=" + float_to_string_with_full_precision(info.beta));
+    build_opts.add_option("-DEPSILON=" + float_to_string_with_full_precision(info.epsilon));
     build_opts.add_option_if(_run_in_place, "-DIN_PLACE");
     build_opts.add_option_if(_input->info()->data_layout() == DataLayout::NHWC, "-DNHWC");
 
@@ -113,9 +108,9 @@
     ICLKernel::configure_internal(std::get<1>(win_config));
 }
 
-Status CLInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, float gamma, float beta, float epsilon)
+Status CLInstanceNormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const InstanceNormalizationLayerKernelInfo &info)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, gamma, beta, epsilon));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, info));
     ARM_COMPUTE_RETURN_ON_ERROR(std::get<0>(validate_and_configure_window(input->clone().get(), (output == nullptr ? input->clone().get() : output->clone().get()))));
     return Status{};
 }