COMPMID-410 Port BatchNormalization to support 16-bit fixed point (QS16)

Change-Id: I7d3e9ff70c717ef5e6de2bcfbfd277f39006702f
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78956
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
diff --git a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
index d0aec69..d1adfa7 100644
--- a/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEBatchNormalizationLayerKernel.cpp
@@ -47,7 +47,7 @@
     // Only compute denominator and NEON vectors once per feature map.
     int slice = -1;
 
-    int        fixed_point_position = in->info()->fixed_point_position();
+    const int  fixed_point_position = in->info()->fixed_point_position();
     const auto input_mean           = reinterpret_cast<const qint8_t *>(mean->ptr_to_element(Coordinates(0, 0)));
     const auto input_var            = reinterpret_cast<const qint8_t *>(var->ptr_to_element(Coordinates(0, 0)));
     const auto input_gamma          = reinterpret_cast<const qint8_t *>(gamma->ptr_to_element(Coordinates(0, 0)));
@@ -82,6 +82,50 @@
     input, output);
 }
 
+void batch_normalization_q16(const ITensor *in, ITensor *out, const ITensor *mean, const ITensor *var, const ITensor *beta, const ITensor *gamma, float epsilon, const Window &window)
+{
+    Iterator input(in, window);
+    Iterator output(out, window);
+    // QS16 path: out = beta + gamma * (in - mean) / sqrt(var + epsilon), 8 elements per iteration.
+    // mean/var/gamma/beta hold one value per feature map (indexed by id.z()); 'slice' records which
+    // feature map is currently loaded so the broadcast vectors and denominator are rebuilt only on change.
+    int slice = -1;
+
+    const int  fixed_point_position = in->info()->fixed_point_position();
+    const auto input_mean           = reinterpret_cast<const qint16_t *>(mean->ptr_to_element(Coordinates(0, 0)));
+    const auto input_var            = reinterpret_cast<const qint16_t *>(var->ptr_to_element(Coordinates(0, 0)));
+    const auto input_gamma          = reinterpret_cast<const qint16_t *>(gamma->ptr_to_element(Coordinates(0, 0)));
+    const auto input_beta           = reinterpret_cast<const qint16_t *>(beta->ptr_to_element(Coordinates(0, 0)));
+    // NOTE(review): epsilon is converted to QS16 below; confirm a small epsilon does not quantize to 0 for the expected fixed_point_position.
+    qint16x8_t       mean_vec    = vdupq_n_qs16(0);
+    qint16x8_t       var_vec     = vdupq_n_qs16(0);
+    qint16x8_t       gamma_vec   = vdupq_n_qs16(0);
+    qint16x8_t       beta_vec    = vdupq_n_qs16(0);
+    qint16x8_t       denominator = vdupq_n_qs16(0);
+    const qint16x8_t epsilon_vec = vdupq_n_qs16(sqcvt_qs16_f32(epsilon, fixed_point_position));
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        if(slice != id.z())
+        {
+            // Construct vectors: broadcast this feature map's mean, variance, gamma and beta to all 8 lanes
+            mean_vec  = vdupq_n_qs16(*(input_mean + id.z()));
+            var_vec   = vdupq_n_qs16(*(input_var + id.z()));
+            gamma_vec = vdupq_n_qs16(*(input_gamma + id.z()));
+            beta_vec  = vdupq_n_qs16(*(input_beta + id.z()));
+
+            // Calculate denominator = 1 / sqrt(var + epsilon), once per feature map
+            denominator = vqinvsqrtq_qs16(vqaddq_qs16(var_vec, epsilon_vec), fixed_point_position);
+            slice       = id.z();
+        }
+
+        // x_bar = (x - mean) * denominator; store beta + x_bar * gamma using saturating QS16 arithmetic
+        const qint16x8_t numerator = vqsubq_qs16(vld1q_qs16(reinterpret_cast<const qint16_t *>(input.ptr())), mean_vec);
+        const qint16x8_t x_bar     = vqmulq_qs16(numerator, denominator, fixed_point_position);
+        vst1q_qs16(reinterpret_cast<qint16_t *>(output.ptr()), vqmlaq_qs16(beta_vec, x_bar, gamma_vec, fixed_point_position));
+    },
+    input, output);
+}
+
 void batch_normalization_fp32(const ITensor *in, ITensor *out, const ITensor *mean, const ITensor *var, const ITensor *beta, const ITensor *gamma, float epsilon, const Window &window)
 {
     Iterator input(in, window);
@@ -127,7 +171,7 @@
 
 void NEBatchNormalizationLayerKernel::configure(const ITensor *input, ITensor *output, const ITensor *mean, const ITensor *var, const ITensor *beta, const ITensor *gamma, float epsilon)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
 
     // Output tensor auto initialization if not yet initialized
@@ -155,6 +199,10 @@
             _func                             = &batch_normalization_q8;
             num_elems_processed_per_iteration = 16;
             break;
+        case DataType::QS16:
+            _func                             = &batch_normalization_q16;
+            num_elems_processed_per_iteration = 8;
+            break;
         case DataType::F32:
             _func                             = &batch_normalization_fp32;
             num_elems_processed_per_iteration = 4;