COMPMID-2699: Add support for QASYMM16 in NEQuantizationLayer

Change-Id: Icb968e37551a9048040e9aaff5329e874c53a2ee
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2016
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp b/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
index 0aa34cd..6a9c4ae 100644
--- a/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEQuantizationLayerKernel.cpp
@@ -34,9 +34,10 @@
 #include "arm_compute/core/CPP/Validate.h"
 
 #include <arm_neon.h>
+#include <map>
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 namespace
 {
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
@@ -45,7 +46,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON(output->tensor_shape().total_size() == 0);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM16);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
 
     return Status{};
@@ -71,7 +72,7 @@
 } // namespace
 
 NEQuantizationLayerKernel::NEQuantizationLayerKernel()
-    : _input(nullptr), _output(nullptr)
+    : _input(nullptr), _output(nullptr), _func(nullptr) // _func stays null until a quantization routine is assigned to it
 {
 }
 
@@ -83,6 +84,33 @@
     _input  = input;
     _output = output;
 
+    static std::map<DataType, QuantizationFunctionExecutorPtr> quant_map_f32 =
+    {
+        { DataType::QASYMM8, &NEQuantizationLayerKernel::run_quantize_qasymm8<float> },
+        { DataType::QASYMM16, &NEQuantizationLayerKernel::run_quantize_qasymm16<float> },
+    };
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    static std::map<DataType, QuantizationFunctionExecutorPtr> quant_map_f16 =
+    {
+        { DataType::QASYMM8, &NEQuantizationLayerKernel::run_quantize_qasymm8<float16_t> },
+        { DataType::QASYMM16, &NEQuantizationLayerKernel::run_quantize_qasymm16<float16_t> },
+    };
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
+    switch(input->info()->data_type())
+    {
+        case DataType::F32:
+            _func = quant_map_f32[output->info()->data_type()];
+            break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+        case DataType::F16:
+            _func = quant_map_f16[output->info()->data_type()];
+            break;
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        default:
+            ARM_COMPUTE_ERROR("Unsupported input data type.");
+    }
+
     // Configure kernel window
     Window win_config = calculate_max_window(*input->info(), Steps());
 
@@ -96,18 +124,17 @@
 Status NEQuantizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
 {
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
-
     return Status{};
 }
 
 template <typename T>
-void NEQuantizationLayerKernel::quantize(const Window &window, const QuantizationInfo &qinfo)
+void NEQuantizationLayerKernel::run_quantize_qasymm8(const Window &window)
 {
     constexpr auto window_step    = 16;
     const auto     window_start_x = static_cast<int>(window.x().start());
     const auto     window_end_x   = static_cast<int>(window.x().end());
 
-    const UniformQuantizationInfo uqinfo = qinfo.uniform();
+    const UniformQuantizationInfo uqinfo = _output->info()->quantization_info().uniform();
 #ifdef __aarch64__
     constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN;
 #else  //__aarch64__
@@ -139,25 +166,54 @@
     input, output);
 }
 
+template <typename T> // T: input element type; this file instantiates it for float and float16_t
+void NEQuantizationLayerKernel::run_quantize_qasymm16(const Window &window) // Quantizes _input into 16-bit unsigned values in _output over the given window
+{
+    constexpr auto window_step    = 16; // elements consumed per vectorized iteration (two 8-lane stores below)
+    const auto     window_start_x = static_cast<int>(window.x().start());
+    const auto     window_end_x   = static_cast<int>(window.x().end());
+
+    const UniformQuantizationInfo uqinfo = _output->info()->quantization_info().uniform(); // quantization parameters come from the output tensor
+#ifdef __aarch64__
+    constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_NEAREST_EVEN; // only passed to the scalar left-over path
+#else  //__aarch64__
+    constexpr RoundingPolicy rounding_policy = RoundingPolicy::TO_ZERO; // only passed to the scalar left-over path
+#endif //__aarch64__
+
+    // Collapse the window where possible and reduce X to a single step: the X range is walked by hand below so the tail can be handled separately
+    Window win_collapsed = window.collapse_if_possible(window, Window::DimZ);
+    win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    Iterator input(_input, win_collapsed);
+    Iterator output(_output, win_collapsed);
+    execute_window_loop(win_collapsed, [&](const Coordinates &)
+    {
+        auto input_ptr  = reinterpret_cast<const T *>(input.ptr());
+        auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr());
+
+        int x = window_start_x;
+        for(; x <= (window_end_x - window_step); x += window_step) // vectorized part: full groups of window_step elements
+        {
+            uint16x8x2_t tmp = vquantize_qasymm16(load_value(&input_ptr[x]), uqinfo); // 16 quantized values held as two 8-lane vectors
+            vst1q_u16(&output_ptr[x], tmp.val[0]); // elements x .. x + 7
+            vst1q_u16(&output_ptr[x + 8], tmp.val[1]); // elements x + 8 .. x + 15
+        }
+        // Quantize the remaining (fewer than window_step) elements one at a time
+        for(; x < window_end_x; ++x)
+        {
+            output_ptr[x] = quantize_qasymm16(input_ptr[x], uqinfo, rounding_policy);
+        }
+    },
+    input, output);
+}
+
 void NEQuantizationLayerKernel::run(const Window &window, const ThreadInfo &info)
 {
     ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+    ARM_COMPUTE_ERROR_ON(_func == nullptr); // _func is chosen from the input/output data types when the kernel is set up
 
-    const QuantizationInfo &qinfo = _output->info()->quantization_info();
-
-    switch(_input->info()->data_type())
-    {
-        case DataType::F32:
-            NEQuantizationLayerKernel::quantize<float>(window, qinfo);
-            break;
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        case DataType::F16:
-            NEQuantizationLayerKernel::quantize<float16_t>(window, qinfo);
-            break;
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        default:
-            ARM_COMPUTE_ERROR("Unsupported data type.");
-    }
+    (this->*_func)(window); // dispatch through the stored pointer-to-member quantization routine
 }
+} // namespace arm_compute