COMPMID-425 Port CLBatchnormalization to support QS8/QS16

Change-Id: I46c93305f377666ea0915ff789b7dfdfff596087
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78862
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
diff --git a/src/core/CL/cl_kernels/batchnormalization_layer.cl b/src/core/CL/cl_kernels/batchnormalization_layer.cl
index 13e6702..cb4d0c8 100644
--- a/src/core/CL/cl_kernels/batchnormalization_layer.cl
+++ b/src/core/CL/cl_kernels/batchnormalization_layer.cl
@@ -21,11 +21,31 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
 #include "helpers.h"
 
+#if defined(FIXED_POINT_POSITION)
+#include "fixed_point.h"
+
+#define ADD_OP(a, b) ADD_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE)
+#define SUB_OP(a, b) SUB_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE)
+#define MUL_OP(a, b) MUL_SAT_OP_EXPAND((a), (b), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
+#define INVSQRT_OP(a) INVSQRT_OP_EXPAND((a), DATA_TYPE, VEC_SIZE, FIXED_POINT_POSITION)
+#define SQCVT_SAT(a) SQCVT_SAT_OP_EXPAND((a), DATA_TYPE, FIXED_POINT_POSITION)
+
+#else /* FIXED_POINT_POSITION */
+
+#define ADD_OP(a, b) ((a) + (b))
+#define SUB_OP(a, b) ((a) - (b))
+#define MUL_OP(a, b) ((a) * (b))
+#define INVSQRT_OP(a) rsqrt((a))
+#define SQCVT_SAT(a) (a)
+
+#endif /* FIXED_POINT_POSITION */
+
 /** Apply batch normalization.
  *
- * @param[in]  input_ptr                            Pointer to the first source tensor. Supported data types: F32
+ * @param[in]  input_ptr                            Pointer to the first source tensor. Supported data types: QS8/QS16/F32
  * @param[in]  input_stride_x                       Stride of the first source tensor in X dimension (in bytes)
  * @param[in]  input_step_x                         input_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  input_stride_y                       Stride of the first source tensor in Y dimension (in bytes)
@@ -33,7 +53,7 @@
  * @param[in]  input_stride_z                       Stride of the first source tensor in Z dimension (in bytes)
  * @param[in]  input_step_z                         input_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  input_offset_first_element_in_bytes  The offset of the first element in the first source tensor
- * @param[out] output_ptr                           Pointer to the destination tensor. Supported data types: F32
+ * @param[out] output_ptr                           Pointer to the destination tensor. Supported data types: same as @p input_ptr
  * @param[in]  output_stride_x                      Stride of the destination tensor in X dimension (in bytes)
  * @param[in]  output_step_x                        output_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  output_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
@@ -41,19 +61,19 @@
  * @param[in]  output_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  output_step_z                        output_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  output_offset_first_element_in_bytes The offset of the first element in the destination tensor
- * @param[in]  mean_ptr                             Pointer to the mean source tensor. Supported data types: F32
+ * @param[in]  mean_ptr                             Pointer to the mean source tensor. Supported data types: same as @p input_ptr
  * @param[in]  mean_stride_x                        Stride of the mean source tensor in X dimension (in bytes)
  * @param[in]  mean_step_x                          mean_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  mean_offset_first_element_in_bytes   The offset of the first element in the mean source tensor
- * @param[in]  var_ptr                              Pointer to the var tensor. Supported data types: F32
+ * @param[in]  var_ptr                              Pointer to the var tensor. Supported data types: same as @p input_ptr
  * @param[in]  var_stride_x                         Stride of the var tensor in X dimension (in bytes)
  * @param[in]  var_step_x                           var_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  var_offset_first_element_in_bytes    The offset of the first element in the var source tensor
- * @param[in]  beta_ptr                             Pointer to the beta source tensor. Supported data types: F32
+ * @param[in]  beta_ptr                             Pointer to the beta source tensor. Supported data types: same as @p input_ptr
  * @param[in]  beta_stride_x                        Stride of the beta source tensor in X dimension (in bytes)
  * @param[in]  beta_step_x                          beta_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  beta_offset_first_element_in_bytes   The offset of the first element in the beta source tensor
- * @param[in]  gamma_ptr                            Pointer to the gamma source tensor. Supported data types: F32
+ * @param[in]  gamma_ptr                            Pointer to the gamma source tensor. Supported data types: same as @p input_ptr
  * @param[in]  gamma_stride_x                       Stride of the gamma source tensor in X dimension (in bytes)
  * @param[in]  gamma_step_x                         gamma_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  gamma_offset_first_element_in_bytes  The offset of the first element in the gamma source tensor
@@ -74,26 +94,33 @@
     Vector   beta  = CONVERT_TO_VECTOR_STRUCT(beta);
     Vector   gamma = CONVERT_TO_VECTOR_STRUCT(gamma);
 
-    float4 _in         = 0;
-    float4 denominator = 0;
-    float4 numerator   = 0;
-    float4 x_bar       = 0;
-    float4 gamma_vec   = 0;
-    float4 beta_vec    = 0;
+    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+    _in = 0;
+    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+    denominator = 0;
+    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+    numerator = 0;
+    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+    x_bar = 0;
+    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+    gamma_vec = 0;
+    VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)
+    beta_vec = 0;
 
     const int current_slice = get_global_id(2);
 
-    _in         = vload4(0, (__global float *)in.ptr);
-    denominator = *((__global float *)(var.ptr + current_slice * var.stride_x));
-    denominator = rsqrt(denominator + epsilon);
+    _in         = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)in.ptr);
+    denominator = *((__global DATA_TYPE *)(var.ptr + current_slice * var.stride_x));
+    denominator = INVSQRT_OP(ADD_OP(denominator, SQCVT_SAT(epsilon)));
 
     // Calculate x bar and store results
-    numerator = *((__global float *)(mean.ptr + current_slice * mean.stride_x));
-    numerator = _in - numerator;
-    x_bar     = numerator * denominator;
+    numerator = *((__global DATA_TYPE *)(mean.ptr + current_slice * mean.stride_x));
+    numerator = SUB_OP(_in, numerator);
+    x_bar     = MUL_OP(numerator, denominator);
 
-    gamma_vec = *((__global float *)(gamma.ptr + current_slice * beta.stride_x));
-    beta_vec  = *((__global float *)(beta.ptr + current_slice * beta.stride_x));
+    gamma_vec = *((__global DATA_TYPE *)(gamma.ptr + current_slice * gamma.stride_x));
+    beta_vec  = *((__global DATA_TYPE *)(beta.ptr + current_slice * beta.stride_x));
 
-    vstore4(gamma_vec * x_bar + beta_vec, 0, (__global float *)out.ptr);
+    VSTORE(VEC_SIZE)
+    (ADD_OP(MUL_OP(gamma_vec, x_bar), beta_vec), 0, (__global DATA_TYPE *)out.ptr);
 }
diff --git a/src/core/CL/cl_kernels/fixed_point.h b/src/core/CL/cl_kernels/fixed_point.h
index bb534f5..4de7fc5 100644
--- a/src/core/CL/cl_kernels/fixed_point.h
+++ b/src/core/CL/cl_kernels/fixed_point.h
@@ -471,4 +471,16 @@
 CONVERTQ_UP_IMPL(qs8x16, float16)
 CONVERTQ_UP_IMPL(qs16x16, float16)
 
+#define SQCVT_SAT_IMPL(type)                                                                    \
+    inline type sqcvt_##type##_sat(float a, int fixed_point_position)                           \
+    {                                                                                           \
+        return CONVERT_SAT((a * (1 << fixed_point_position) + ((a < 0) ? -0.5f : 0.5f)), type); \
+    }
+
+SQCVT_SAT_IMPL(qs8)
+SQCVT_SAT_IMPL(qs16)
+
+#define SQCVT_SAT_OP_EXPAND_STR(a, type, position) sqcvt_##type##_sat((a), (position))
+#define SQCVT_SAT_OP_EXPAND(a, type, position) SQCVT_SAT_OP_EXPAND_STR((a), type, position)
+
 #endif // ARM_COMPUTE_FIXED_POINT_H
diff --git a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
index 85d8ab7..02bf35a 100644
--- a/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
+++ b/src/core/CL/kernels/CLBatchNormalizationLayerKernel.cpp
@@ -26,12 +26,15 @@
 #include "arm_compute/core/CL/CLHelpers.h"
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/FixedPoint.h"
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
 
+#include "support/ToolchainSupport.h"
+
 using namespace arm_compute;
 
 CLBatchNormalizationLayerKernel::CLBatchNormalizationLayerKernel()
@@ -42,7 +45,7 @@
 void CLBatchNormalizationLayerKernel::configure(const ICLTensor *input, ICLTensor *output, const ICLTensor *mean, const ICLTensor *var, const ICLTensor *beta, const ICLTensor *gamma,
                                                 float epsilon)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
 
     // Output tensor auto initialization if not yet initialized
@@ -54,10 +57,6 @@
     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(mean, var, beta, gamma);
     ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != mean->info()->dimension(0));
 
-    // Set build options
-    std::set<std::string> build_opts;
-    build_opts.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())));
-
     _input   = input;
     _output  = output;
     _mean    = mean;
@@ -66,17 +65,25 @@
     _gamma   = gamma;
     _epsilon = epsilon;
 
+    const unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
+
+    // Set build options
+    std::set<std::string> build_opts;
+    build_opts.emplace(("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())));
+    build_opts.emplace(("-DVEC_SIZE=" + support::cpp11::to_string(num_elems_processed_per_iteration)));
+    if(is_data_type_fixed_point(input->info()->data_type()))
+    {
+        build_opts.emplace("-DFIXED_POINT_POSITION=" + support::cpp11::to_string(input->info()->fixed_point_position()));
+    }
+
     // Create kernel
-    std::string kernel_name = "batchnormalization_layer";
-    _kernel                 = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts));
+    _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("batchnormalization_layer", build_opts));
 
     // Set kernel static arguments
     unsigned int idx = 2 * num_arguments_per_3D_tensor() + 4 * num_arguments_per_1D_tensor(); // Skip the input and output parameters
     _kernel.setArg<cl_float>(idx++, _epsilon);
 
     // Configure kernel window
-    const unsigned int num_elems_processed_per_iteration = 4;
-
     Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
 
     AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration);