COMPMID-1647 NENormalizationLayer: add IN_MAP_2D support for NHWC (FP32/FP16)

Change-Id: Id74cc7ba8e5cabee6acd3798d4779f88b1f00a9b
diff --git a/src/core/NEON/kernels/NENormalizationLayerKernel.cpp b/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
index 27af121..e5f6e4f 100644
--- a/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
@@ -29,6 +29,7 @@
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/NEON/NEFixedPoint.h"
 #include "arm_compute/core/NEON/NEMath.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
@@ -44,8 +45,6 @@
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
 
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC && norm_info.type() == NormType::IN_MAP_2D,
-                                    "Only Cross-map and 1D In-map normalization is supported for NHWC layout");
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_squared);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, input_squared);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(norm_info.norm_size() % 2), "Normalization size should be odd");
@@ -55,6 +54,7 @@
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
     }
 
     return Status{};
@@ -143,16 +143,26 @@
                 {
                     if(norm_info.type() == NormType::IN_MAP_2D)
                     {
-                        _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 0, true>;
+                        _func = &NENormalizationLayerKernel::normalize_float<float, 4, 0, true>;
                     }
                     else
                     {
-                        _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 0, false>;
+                        _func = &NENormalizationLayerKernel::normalize_float<float, 4, 0, false>;
                     }
                     break;
                 }
+                case 1:
+                    if(norm_info.type() == NormType::IN_MAP_2D)
+                    {
+                        _func = &NENormalizationLayerKernel::normalize_float<float, 4, 1, true>;
+                    }
+                    else
+                    {
+                        _func = &NENormalizationLayerKernel::normalize_float<float, 4, 1, false>;
+                    }
+                    break;
                 case 2:
-                    _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 2, false>;
+                    _func = &NENormalizationLayerKernel::normalize_float<float, 4, 2, false>;
                     break;
                 default:
                     break;
@@ -168,16 +178,26 @@
                 {
                     if(norm_info.type() == NormType::IN_MAP_2D)
                     {
-                        _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 0, true>;
+                        _func = &NENormalizationLayerKernel::normalize_float<float16_t, 8, 0, true>;
                     }
                     else
                     {
-                        _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 0, false>;
+                        _func = &NENormalizationLayerKernel::normalize_float<float16_t, 8, 0, false>;
                     }
                     break;
                 }
+                case 1:
+                    if(norm_info.type() == NormType::IN_MAP_2D)
+                    {
+                        _func = &NENormalizationLayerKernel::normalize_float<float16_t, 8, 1, true>;
+                    }
+                    else
+                    {
+                        _func = &NENormalizationLayerKernel::normalize_float<float16_t, 8, 1, false>;
+                    }
+                    break;
                 case 2:
-                    _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 2, false>;
+                    _func = &NENormalizationLayerKernel::normalize_float<float16_t, 8, 2, false>;
                     break;
                 default:
                     break;
@@ -195,14 +215,19 @@
     INEKernel::configure(win_config.second);
 }
 
-template <DataType dt, unsigned int dim, bool do_2D_norm>
+template <typename T, unsigned int S, unsigned int dim, bool do_2D_norm>
 void NENormalizationLayerKernel::normalize_float(const Window &window)
 {
+    /** NEON vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_vector<T, S>::tag_type;
+
     Iterator input(_input, window);
     Iterator input_squared(_input_squared, window);
     Iterator output(_output, window);
 
-    const int dim_y                = 1;
+    const int dim_y                = _input->info()->data_layout() == DataLayout::NCHW ? 1 : 2;
     const int radius               = _norm_info.norm_size() / 2;
     const int input_squared_stride = _input_squared->info()->strides_in_bytes()[dim];
+    // Byte stride between rows (dimension dim_y) of input_squared: fetched once here instead of for every row of every output vector
+    const int input_squared_stride_y = _input_squared->info()->strides_in_bytes()[dim_y];
     // We account padding across X only and we iterate over rows
@@ -210,83 +235,39 @@
     const int max_right  = _input->info()->dimension(dim) - 1;
     const int max_bottom = _input->info()->dimension(dim_y) - 1;
 
-    if(dt == DataType::F32)
-    {
-        const float32x4_t coeff_vec = vdupq_n_f32(_norm_info.scale_coeff());
-        const float32x4_t beta_vec  = vdupq_n_f32(_norm_info.beta());
-        const float32x4_t kappa_vec = vdupq_n_f32(_norm_info.kappa());
+    const auto coeff_vec = wrapper::vdup_n(static_cast<T>(_norm_info.scale_coeff()), ExactTagType{});
+    const auto beta_vec  = wrapper::vdup_n(static_cast<T>(_norm_info.beta()), ExactTagType{});
+    const auto kappa_vec = wrapper::vdup_n(static_cast<T>(_norm_info.kappa()), ExactTagType{});
 
-        execute_window_loop(window, [&](const Coordinates & id)
+    execute_window_loop(window, [&](const Coordinates & id)
+    {
+        // Get range to normalize
+        const int current_row   = do_2D_norm ? id[dim_y] : 0;
+        const int current_slice = id[dim];
+        const int first_row     = do_2D_norm ? std::max(current_row - radius, 0) : 0;
+        const int last_row      = do_2D_norm ? std::min(current_row + radius, max_bottom) : 0;
+        const int first_slice   = std::max(current_slice - radius, min_left);
+        const int last_slice    = std::min(current_slice + radius, max_right);
+
+        // Accumulate the squared values in the normalization window (spans several rows only for 2D In-Map)
+        auto accu = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
         for(int j = first_row; j <= last_row; j++)
         {
-            // Get range to normalize
-            const int current_row   = do_2D_norm ? id[dim_y] : 0;
-            const int current_slice = id[dim];
-            const int first_row     = do_2D_norm ? std::max(current_row - radius, 0) : 0;
-            const int last_row      = do_2D_norm ? std::min(current_row + radius, max_bottom) : 0;
-            const int first_slice   = std::max(current_slice - radius, min_left);
-            const int last_slice    = std::min(current_slice + radius, max_right);
-
-            // Accumulate 2D In-Map values
-            float32x4_t accu = vdupq_n_f32(0.f);
-            for(int j = first_row; j <= last_row; j++)
+            // Compute row displacement
+            const int            row               = (j - current_row) * input_squared_stride_y;
+            const uint8_t *const input_squared_ptr = input_squared.ptr() + row - (current_slice * input_squared_stride);
             for(int i = first_slice; i <= last_slice; ++i)
             {
-                // Compute row displacement
-                const int            row               = (j - current_row) * _input_squared->info()->strides_in_bytes()[dim_y];
-                const uint8_t *const input_squared_ptr = input_squared.ptr() + row - (current_slice * input_squared_stride);
-                for(int i = first_slice; i <= last_slice; ++i)
-                {
-                    accu = vaddq_f32(accu, vld1q_f32(reinterpret_cast<const float *>(input_squared_ptr + i * input_squared_stride)));
-                }
+                accu = wrapper::vadd(accu, wrapper::vloadq(reinterpret_cast<const T *>(input_squared_ptr + i * input_squared_stride)));
             }
+        }
 
-            // Normalize
-            const float32x4_t normalized       = vpowq_f32(vmlaq_f32(kappa_vec, coeff_vec, accu), beta_vec);
-            const float32x4_t normalized_pixel = vmulq_f32(vld1q_f32(reinterpret_cast<const float *>(input.ptr())), vinvq_f32(normalized));
-            vst1q_f32(reinterpret_cast<float *>(output.ptr()), normalized_pixel);
-        },
-        input, input_squared, output);
-    }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-    else if(dt == DataType::F16)
-    {
-        const float16x8_t coeff_vec    = vdupq_n_f16(_norm_info.scale_coeff());
-        const float16x8_t beta_vec_f16 = vdupq_n_f16(_norm_info.beta());
-        const float16x8_t kappa_vec    = vdupq_n_f16(_norm_info.kappa());
-
-        execute_window_loop(window, [&](const Coordinates & id)
-        {
-            // Get range to normalize
-            const int current_row   = do_2D_norm ? id[dim_y] : 0;
-            const int current_slice = id[dim];
-            const int first_row     = do_2D_norm ? std::max(current_row - radius, 0) : 0;
-            const int last_row      = do_2D_norm ? std::min(current_row + radius, max_bottom) : 0;
-            const int first_slice   = std::max(current_slice - radius, min_left);
-            const int last_slice    = std::min(current_slice + radius, max_right);
-
-            // Accumulate 2D In-Map values
-            float16x8_t accu = vdupq_n_f16(0.f);
-            for(int j = first_row; j <= last_row; j++)
-            {
-                // Compute row displacement
-                const int            row               = (j - current_row) * _input_squared->info()->strides_in_bytes()[dim_y];
-                const uint8_t *const input_squared_ptr = input_squared.ptr() + row - (current_slice * input_squared_stride);
-                for(int i = first_slice; i <= last_slice; ++i)
-                {
-                    accu = vaddq_f16(accu, vld1q_f16(reinterpret_cast<const float16_t *>(input_squared_ptr + i * input_squared_stride)));
-                }
-            }
-
-            const float16x8_t norm_f16         = vpowq_f16(vaddq_f16(kappa_vec, vmulq_f16(coeff_vec, accu)), beta_vec_f16);
-            const float16x8_t normalized_pixel = vmulq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), vinvq_f16(norm_f16));
-            vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), normalized_pixel);
-        },
-        input, input_squared, output);
-    }
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-    else
-    {
-        ARM_COMPUTE_ERROR("Not supported");
-    }
+        // Normalize: output = input * 1 / (kappa + coeff * accu)^beta
+        const auto normalized       = wrapper::vpow(wrapper::vmla(kappa_vec, coeff_vec, accu), beta_vec);
+        const auto normalized_pixel = wrapper::vmul(wrapper::vloadq(reinterpret_cast<const T *>(input.ptr())), wrapper::vinv(normalized));
+        wrapper::vstore(reinterpret_cast<T *>(output.ptr()), normalized_pixel);
+    },
+    input, input_squared, output);
 }
 
 Status NENormalizationLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *input_squared, const ITensorInfo *output, const NormalizationLayerInfo norm_info)