COMPMID-1645 NEL2Normalization for FP32/FP16 & NHWC

Change-Id: I29e35024e29781a6b943b568abec9c73649215e6
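
For reference, each of the vectorized X/Y/Z paths added below computes, per
element, out = in * 1/sqrt(max(sum, epsilon)), where sum is the per-slice sum
of squares produced by the reduction stage that feeds this kernel. A minimal
scalar sketch of that per-element math (illustration only, not part of the
patch; the helper name l2_normalize_ref is made up here):

    #include <algorithm>
    #include <cmath>

    // Scalar equivalent of the vectorized inner loops: scale the input by the
    // inverse square root of the clamped sum of squares along the chosen axis.
    template <typename T>
    T l2_normalize_ref(T in, T sum_of_squares, T epsilon)
    {
        return in / std::sqrt(std::max(sum_of_squares, epsilon));
    }
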
diff --git a/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp b/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
index ed03783..cda041d 100644
--- a/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEL2NormalizeLayerKernel.cpp
@@ -32,15 +32,20 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
 
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
 #include <arm_neon.h>
 #include <cmath>
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 namespace
 {
+template <typename T, int S>
 void l2_normalize_X(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window)
 {
+    /** NEON vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
     Window window_sum(window);
     window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
 
@@ -53,30 +58,97 @@
         Iterator sum_it(sum, sum_slice);
         Iterator output_it(out, in_slice);
 
-        const float       sum_value           = *reinterpret_cast<const float *>(sum_it.ptr());
-        const float32x4_t vec_normalize_value = vdupq_n_f32(1.f / std::sqrt(std::max(sum_value, epsilon)));
+        const auto sum_value           = *reinterpret_cast<const T *>(sum_it.ptr());
+        const auto vec_normalize_value = wrapper::vdup_n(static_cast<T>(1.f / std::sqrt(std::max(sum_value, static_cast<T>(epsilon)))), ExactTagType{});
 
         execute_window_loop(in_slice, [&](const Coordinates & id)
         {
-            const auto in_ptr  = reinterpret_cast<const float *>(input_it.ptr());
-            const auto out_ptr = reinterpret_cast<float *>(output_it.ptr());
+            const auto in_ptr  = reinterpret_cast<const T *>(input_it.ptr());
+            const auto out_ptr = reinterpret_cast<T *>(output_it.ptr());
 
-            vst1q_f32(out_ptr, vmulq_f32(vld1q_f32(in_ptr), vec_normalize_value));
+            wrapper::vstore(out_ptr, wrapper::vmul(wrapper::vloadq(in_ptr), vec_normalize_value));
         },
         input_it, output_it);
     }
     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
 }
 
+template <typename T, int S>
+void l2_normalize_Y(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window)
+{
+    /** NEON vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+    Window window_sum(window);
+    window_sum.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+    Window in_slice  = window.first_slice_window_2D();
+    Window sum_slice = window_sum.first_slice_window_2D();
+
+    do
+    {
+        Iterator input_it(in, in_slice);
+        Iterator sum_it(sum, sum_slice);
+        Iterator output_it(out, in_slice);
+
+        auto eps = wrapper::vdup_n(static_cast<T>(epsilon), ExactTagType{});
+
+        execute_window_loop(in_slice, [&](const Coordinates & id)
+        {
+            const auto in_ptr  = reinterpret_cast<const T *>(input_it.ptr());
+            const auto sum_ptr = reinterpret_cast<const T *>(sum_it.ptr());
+            const auto out_ptr = reinterpret_cast<T *>(output_it.ptr());
+
+            const auto vec_normalize_value = wrapper::vinvsqrt(wrapper::vmax(wrapper::vloadq(sum_ptr), eps));
+            wrapper::vstore(out_ptr, wrapper::vmul(wrapper::vloadq(in_ptr), vec_normalize_value));
+        },
+        input_it, sum_it, output_it);
+    }
+    while(window.slide_window_slice_2D(in_slice) && window.slide_window_slice_2D(sum_slice));
+}
+
+template <typename T, int S>
+void l2_normalize_Z(const ITensor *in, const ITensor *sum, ITensor *out, float epsilon, const Window &window)
+{
+    /** NEON vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
+    Window window_sum(window);
+    window_sum.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+    Window in_slice  = window.first_slice_window_3D();
+    Window sum_slice = window_sum.first_slice_window_3D();
+
+    do
+    {
+        Iterator input_it(in, in_slice);
+        Iterator sum_it(sum, sum_slice);
+        Iterator output_it(out, in_slice);
+
+        auto eps = wrapper::vdup_n(static_cast<T>(epsilon), ExactTagType{});
+
+        execute_window_loop(in_slice, [&](const Coordinates & id)
+        {
+            const auto in_ptr  = reinterpret_cast<const T *>(input_it.ptr());
+            const auto sum_ptr = reinterpret_cast<const T *>(sum_it.ptr());
+            const auto out_ptr = reinterpret_cast<T *>(output_it.ptr());
+
+            const auto vec_normalize_value = wrapper::vinvsqrt(wrapper::vmax(wrapper::vloadq(sum_ptr), eps));
+            wrapper::vstore(out_ptr, wrapper::vmul(wrapper::vloadq(in_ptr), vec_normalize_value));
+        },
+        input_it, sum_it, output_it);
+    }
+    while(window.slide_window_slice_3D(in_slice) && window.slide_window_slice_3D(sum_slice));
+}
+
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output, unsigned int axis, float epsilon)
 {
     ARM_COMPUTE_UNUSED(epsilon);
 
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, sum, output);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() != DataLayout::NCHW);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 0, "Unsupported normalization axis, Supported axis is 0");
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis > 2, "Axis greater than 2 is not supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(axis >= TensorShape::num_max_dimensions, "Normalization axis greater than max number of dimensions");
 
     // Reduce shape on axis
@@ -89,7 +161,7 @@
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(input->tensor_shape(), output->tensor_shape());
-        ARM_COMPUTE_RETURN_ERROR_ON(output->data_layout() != DataLayout::NCHW);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
     }
 
     return Status{};
@@ -158,9 +230,52 @@
     switch(_axis)
     {
         case 0:
-            l2_normalize_X(_input, _sum, _output, _epsilon, window);
+            switch(_input->info()->data_type())
+            {
+                case DataType::F32:
+                    l2_normalize_X<float, 4>(_input, _sum, _output, _epsilon, window);
+                    break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+                case DataType::F16:
+                    l2_normalize_X<float16_t, 8>(_input, _sum, _output, _epsilon, window);
+                    break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+                default:
+                    ARM_COMPUTE_ERROR("Not implemented");
+            }
+            break;
+        case 1:
+            switch(_input->info()->data_type())
+            {
+                case DataType::F32:
+                    l2_normalize_Y<float, 4>(_input, _sum, _output, _epsilon, window);
+                    break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+                case DataType::F16:
+                    l2_normalize_Y<float16_t, 8>(_input, _sum, _output, _epsilon, window);
+                    break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+                default:
+                    ARM_COMPUTE_ERROR("Not implemented");
+            }
+            break;
+        case 2:
+            switch(_input->info()->data_type())
+            {
+                case DataType::F32:
+                    l2_normalize_Z<float, 4>(_input, _sum, _output, _epsilon, window);
+                    break;
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+                case DataType::F16:
+                    l2_normalize_Z<float16_t, 8>(_input, _sum, _output, _epsilon, window);
+                    break;
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+                default:
+                    ARM_COMPUTE_ERROR("Not implemented");
+            }
             break;
         default:
             ARM_COMPUTE_ERROR("Unsupported normalization axis");
     }
 }
+} // namespace arm_compute
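
Below is a hedged usage sketch of the new FP16/NHWC path through the public
NEL2NormalizeLayer runtime function. The Tensor/TensorInfo/configure calls are
taken from the library's public API; the shape, axis and epsilon values are
arbitrary illustration choices, not values mandated by this patch.

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/runtime/NEON/functions/NEL2NormalizeLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        Tensor src{};
        Tensor dst{};

        // Half-precision NHWC tensor: 16 channels, 32x32 spatial, batch 1.
        TensorInfo info(TensorShape(16U, 32U, 32U), 1, DataType::F16);
        info.set_data_layout(DataLayout::NHWC);
        src.allocator()->init(info);
        dst.allocator()->init(info);

        // Normalize along axis 0 (channels in NHWC); epsilon clamps the
        // sum of squares before the inverse square root.
        NEL2NormalizeLayer l2_norm{};
        l2_norm.configure(&src, &dst, 0, 1e-12f);

        src.allocator()->allocate();
        dst.allocator()->allocate();

        // ... fill src, then:
        l2_norm.run();
        return 0;
    }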