COMPMID-421: Added FP16 support to NENormalizationLayer and NEPixelWiseMultiplication.

Change-Id: If174f8071502fc5cc94b27cd44a9b1d5e451a9e2
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79553
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/arm_compute/core/NEON/NEMath.h b/arm_compute/core/NEON/NEMath.h
index bb8a330..8dd9d60 100644
--- a/arm_compute/core/NEON/NEMath.h
+++ b/arm_compute/core/NEON/NEMath.h
@@ -91,6 +91,26 @@
  * @return The calculated power.
  */
 float32x4_t vpowq_f32(float32x4_t val, float32x4_t n);
+
+#ifdef ARM_COMPUTE_ENABLE_FP16
+/** Calculate exponential
+ *
+ * @param[in] x Input vector value in F16 format.
+ *
+ * @return The calculated exponential.
+ */
+float16x8_t vexpq_f16(float16x8_t x);
+/** Calculate the n-th power of a number.
+ *
+ * pow(x,n) = e^(n*log(x))
+ *
+ * @param[in] val Input vector value in F16 format.
+ * @param[in] n   Powers to raise the input to.
+ *
+ * @return The calculated power.
+ */
+float16x8_t vpowq_f16(float16x8_t val, float16x8_t n);
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
 }
 #include "arm_compute/core/NEON/NEMath.inl"
 #endif /* __ARM_COMPUTE_NEMATH_H__ */
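
Note: the new intrinsics mirror the existing F32 helpers. vexpq_f16 uses the
usual range-reduction scheme, e^x = 2^m * e^r with m = trunc(x / ln(2)) and
r = x - m * ln(2), and vpowq_f16 composes it with vlogq_f16 through
pow(x, n) = e^(n * log(x)). A minimal scalar sketch of the exp scheme,
illustrative only and not part of this patch (binary32 used for readability;
overflow/underflow handling omitted):

    #include <cstdint>
    #include <cstring>

    float exp_sketch(float x)
    {
        // Range reduction: m = trunc(x / ln(2)), r = x - m * ln(2), r in (-ln(2), ln(2))
        const int   m = static_cast<int>(x * 1.4426950408f);
        const float r = x - static_cast<float>(m) * 0.6931471805f;

        // Degree-7 Taylor approximation of e^r around 0
        float p = 1.f, term = 1.f;
        for(int k = 1; k <= 7; ++k)
        {
            term *= r / static_cast<float>(k);
            p += term;
        }

        // Scale by 2^m: add m to the biased exponent field of the encoding
        // (shift by 23 mantissa bits for binary32; the F16 kernels shift by 10)
        std::uint32_t bits;
        std::memcpy(&bits, &p, sizeof(bits));
        bits += static_cast<std::uint32_t>(m) << 23;
        std::memcpy(&p, &bits, sizeof(p));
        return p;
    }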
diff --git a/arm_compute/core/NEON/NEMath.inl b/arm_compute/core/NEON/NEMath.inl
index 9a49493..c73c545 100644
--- a/arm_compute/core/NEON/NEMath.inl
+++ b/arm_compute/core/NEON/NEMath.inl
@@ -141,4 +141,104 @@
 {
     return vexpq_f32(vmulq_f32(n, vlogq_f32(val)));
 }
-}
\ No newline at end of file
+
+#ifdef ARM_COMPUTE_ENABLE_FP16
+/* Exponent polynomial coefficients */
+const std::array<float16x8_t, 8> exp_tab_f16 =
+{
+    {
+        vdupq_n_f16(1.f),
+        vdupq_n_f16(0.0416598916054f),
+        vdupq_n_f16(0.500000596046f),
+        vdupq_n_f16(0.0014122662833f),
+        vdupq_n_f16(1.00000011921f),
+        vdupq_n_f16(0.00833693705499f),
+        vdupq_n_f16(0.166665703058f),
+        vdupq_n_f16(0.000195780929062f),
+    }
+};
+
+/* Logarithm polynomial coefficients */
+const std::array<float16x8_t, 8> log_tab_f16 =
+{
+    {
+        vdupq_n_f16(-2.29561495781f),
+        vdupq_n_f16(-2.47071170807f),
+        vdupq_n_f16(-5.68692588806f),
+        vdupq_n_f16(-0.165253549814f),
+        vdupq_n_f16(5.17591238022f),
+        vdupq_n_f16(0.844007015228f),
+        vdupq_n_f16(4.58445882797f),
+        vdupq_n_f16(0.0141278216615f),
+    }
+};
+
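+/* Newton-Raphson reciprocal: refine the initial vrecpeq_f16 estimate twice with
+ * vrecpsq_f16, which computes (2 - x * estimate). */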
+inline float16x8_t vinvq_f16(float16x8_t x)
+{
+    float16x8_t recip = vrecpeq_f16(x);
+    recip             = vmulq_f16(vrecpsq_f16(x, recip), recip);
+    recip             = vmulq_f16(vrecpsq_f16(x, recip), recip);
+    return recip;
+}
+
+inline float16x8_t vtaylor_polyq_f16(float16x8_t x, const std::array<float16x8_t, 8> &coeffs)
+{
+    const float16x8_t A   = vaddq_f16(coeffs[0], vmulq_f16(coeffs[4], x));
+    const float16x8_t B   = vaddq_f16(coeffs[2], vmulq_f16(coeffs[6], x));
+    const float16x8_t C   = vaddq_f16(coeffs[1], vmulq_f16(coeffs[5], x));
+    const float16x8_t D   = vaddq_f16(coeffs[3], vmulq_f16(coeffs[7], x));
+    const float16x8_t x2  = vmulq_f16(x, x);
+    const float16x8_t x4  = vmulq_f16(x2, x2);
+    const float16x8_t res = vaddq_f16(vaddq_f16(A, vmulq_f16(B, x2)), vmulq_f16(vaddq_f16(C, vmulq_f16(D, x2)), x4));
+    return res;
+}
+
+inline float16x8_t vexpq_f16(float16x8_t x)
+{
+    static const float16x8_t CONST_LN2         = vdupq_n_f16(0.6931471805f); // ln(2)
+    static const float16x8_t CONST_INV_LN2     = vdupq_n_f16(1.4426950408f); // 1/ln(2)
+    static const float16x8_t CONST_0           = vdupq_n_f16(0.f);
+    static const int16x8_t   CONST_NEGATIVE_14 = vdupq_n_s16(-14); // Smallest normal exponent of binary16
+
+    // Perform range reduction [-log(2),log(2)]
+    const int16x8_t   m   = vcvtq_s16_f16(vmulq_f16(x, CONST_INV_LN2));
+    const float16x8_t val = vsubq_f16(x, vmulq_f16(vcvtq_f16_s16(m), CONST_LN2));
+
+    // Polynomial Approximation
+    float16x8_t poly = vtaylor_polyq_f16(val, exp_tab_f16);
+
+    // Reconstruct
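+    // Scale by 2^m: add m to the biased exponent field (bits 10-14 of a binary16)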
+    poly = vreinterpretq_f16_s16(vqaddq_s16(vreinterpretq_s16_f16(poly), vqshlq_n_s16(m, 10)));
+    poly = vbslq_f16(vcltq_s16(m, CONST_NEGATIVE_14), CONST_0, poly);
+
+    return poly;
+}
+
+inline float16x8_t vlogq_f16(float16x8_t x)
+{
+    static const int16x8_t   CONST_15  = vdupq_n_s16(15);            // Exponent bias of binary16
+    static const float16x8_t CONST_LN2 = vdupq_n_f16(0.6931471805f); // ln(2)
+
+    // Extract exponent
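+    // Unbiased exponent: shift past the 10 mantissa bits, then subtract the bias (15)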
+    const int16x8_t   m   = vsubq_s16(vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_f16(x), 10)), CONST_15);
+    const float16x8_t val = vreinterpretq_f16_s16(vsubq_s16(vreinterpretq_s16_f16(x), vshlq_n_s16(m, 10)));
+
+    // Polynomial Approximation
+    float16x8_t poly = vtaylor_polyq_f16(val, log_tab_f16);
+
+    // Reconstruct
+    poly = vaddq_f16(poly, vmulq_f16(vcvtq_f16_s16(m), CONST_LN2));
+
+    return poly;
+}
+
+inline float16x8_t vpowq_f16(float16x8_t val, float16x8_t n)
+{
+    return vexpq_f16(vmulq_f16(n, vlogq_f16(val)));
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
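
Note: vtaylor_polyq_f16 evaluates the degree-7 polynomial in Estrin form, which
is why exp_tab_f16 and log_tab_f16 store their coefficients interleaved rather
than in ascending-degree order: with a table {c0, ..., c7} the result is
c0 + c4*x + c2*x^2 + c6*x^3 + c1*x^4 + c5*x^5 + c3*x^6 + c7*x^7 (for
exp_tab_f16 these are close to the Taylor coefficients 1/k!). A scalar
equivalent, illustrative only:

    float taylor_poly_sketch(float x, const float c[8])
    {
        const float A  = c[0] + c[4] * x; // degrees 0-1
        const float B  = c[2] + c[6] * x; // times x^2: degrees 2-3
        const float C  = c[1] + c[5] * x; // times x^4: degrees 4-5
        const float D  = c[3] + c[7] * x; // times x^6: degrees 6-7
        const float x2 = x * x;
        const float x4 = x2 * x2;
        return (A + B * x2) + (C + D * x2) * x4;
    }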
diff --git a/arm_compute/core/NEON/kernels/NENormalizationLayerKernel.h b/arm_compute/core/NEON/kernels/NENormalizationLayerKernel.h
index d4e36d5..b1bc594 100644
--- a/arm_compute/core/NEON/kernels/NENormalizationLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NENormalizationLayerKernel.h
@@ -73,8 +73,8 @@
      *
      * @param[in] window Region on which to execute the kernel.
      */
-    template <unsigned int dim, bool do_2D_norm>
-    void normalize(const Window &window);
+    template <DataType dt, unsigned int dim, bool do_2D_norm>
+    void normalize_float(const Window &window);
 
     /** Function to perform normalization for fixed-point values depending on
      * the given template dimension. The second template parameter specifies
diff --git a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
index 7e402cd..433a20e 100644
--- a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
+++ b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
@@ -52,9 +52,9 @@
      * @note For @p scale equal to 1/255 only round to nearest even (implemented as round half up) is supported.
      *       For all other scale values only round to zero (implemented as round towards minus infinity) is supported.
      *
-     * @param[in]  input1          An input tensor. Data types supported: U8/QS8/S16/F32.
-     * @param[in]  input2          An input tensor. Data types supported: U8/QS8/S16/F32.
-     * @param[out] output          The output tensor. Data types supported: U8 (Only if both inputs are U8) /S16/F32.
+     * @param[in]  input1          An input tensor. Data types supported: U8/QS8/S16/F16/F32.
+     * @param[in]  input2          An input tensor. Data types supported: U8/QS8/S16/F16/F32.
+     * @param[out] output          The output tensor. Data types supported: U8 (Only if both inputs are U8) /S16/F16/F32.
      * @param[in]  scale           Scale to apply after multiplication.
      *                             Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15.
      * @param[in]  overflow_policy Overflow policy.
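
Note: with this patch the multiplication can be configured directly on F16
tensors. A hypothetical usage sketch (assumes src1, src2 and dst are allocated
F16 tensors of matching shape; scale = 1 satisfies the 1/2^n constraint):

    NEPixelWiseMultiplication mul;
    mul.configure(&src1, &src2, &dst, 1.f, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO);
    mul.run();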
diff --git a/arm_compute/runtime/NEON/functions/NENormalizationLayer.h b/arm_compute/runtime/NEON/functions/NENormalizationLayer.h
index 3202867..4cfea22 100644
--- a/arm_compute/runtime/NEON/functions/NENormalizationLayer.h
+++ b/arm_compute/runtime/NEON/functions/NENormalizationLayer.h
@@ -52,7 +52,7 @@
     /** Set the input and output tensors.
      *
      * @param[in]  input     Source tensor. 3 lower dims represent a single input with dimensions [width, height, IFM],
-     *                       and an optional 4th dimension for batch of inputs. Data type supported: QS8/F32
+     *                       and an optional 4th dimension for batch of inputs. Data type supported: QS8/F16/F32
      * @param[out] output    Destination with the same dimensions, data type and number of channels of  @p input
      * @param[in]  norm_info Normalization layer information like the normalization type, normalization size and other parameters.
      */
diff --git a/scripts/check_clang-tidy.py b/scripts/check_clang-tidy.py
index 6ab1747..e80b460 100755
--- a/scripts/check_clang-tidy.py
+++ b/scripts/check_clang-tidy.py
@@ -20,6 +20,7 @@
                     ("cl2.hpp" in line and "cast from pointer to smaller type 'cl_context_properties' (aka 'int') loses information" in line) or
                     ("arm_fp16.h" in line) or
                     ("memory" in line and "cast from pointer to smaller type 'uintptr_t' (aka 'unsigned int') loses information" in line) or
+                    ("NEMath.inl" in line and "statement expression not allowed at file scope" in line) or
                     "3rdparty" in line):
                     continue
 
diff --git a/src/core/NEON/kernels/NENormalizationLayerKernel.cpp b/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
index 0183e54..76ace91 100644
--- a/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NENormalizationLayerKernel.cpp
@@ -46,12 +46,10 @@
 
 void NENormalizationLayerKernel::configure(const ITensor *input, const ITensor *input_squared, ITensor *output, NormalizationLayerInfo norm_info)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QS8);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
-
     // Output tensor auto initialization if not yet initialized
     auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
-
     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, input_squared, output);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, input_squared, output);
     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, input_squared, output);
@@ -68,27 +66,79 @@
     _norm_info     = norm_info;
     _border_size   = BorderSize(0, border_width);
 
-    const bool is_dt_f32 = _input->info()->data_type() == DataType::F32;
+    unsigned int num_elems_processed_per_iteration = 16 / input->info()->element_size();
 
-    switch(norm_info.type())
+    switch(_input->info()->data_type())
     {
-        case NormType::IN_MAP_1D:
-            _func = (is_dt_f32) ? &NENormalizationLayerKernel::normalize<0, false> : &NENormalizationLayerKernel::normalize_fixed_point<0, false>;
+        case DataType::F32:
+        {
+            num_elems_processed_per_iteration = 4;
+            switch(norm_info.type())
+            {
+                case NormType::IN_MAP_1D:
+                    _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 0, false>;
+                    break;
+                case NormType::IN_MAP_2D:
+                    // Normalize over X and Y
+                    _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 0, true>;
+                    break;
+                case NormType::CROSS_MAP:
+                    _func = &NENormalizationLayerKernel::normalize_float<DataType::F32, 2, false>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
             break;
-        case NormType::IN_MAP_2D:
-            // Normalize over X and Y
-            _func = (is_dt_f32) ? &NENormalizationLayerKernel::normalize<0, true> : &NENormalizationLayerKernel::normalize_fixed_point<0, true>;
+        }
+        case DataType::F16:
+        {
+            num_elems_processed_per_iteration = 8;
+            switch(norm_info.type())
+            {
+                case NormType::IN_MAP_1D:
+                    _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 0, false>;
+                    break;
+                case NormType::IN_MAP_2D:
+                    // Normalize over X and Y
+                    _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 0, true>;
+                    break;
+                case NormType::CROSS_MAP:
+                    _func = &NENormalizationLayerKernel::normalize_float<DataType::F16, 2, false>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
             break;
-        case NormType::CROSS_MAP:
-            _func = (is_dt_f32) ? &NENormalizationLayerKernel::normalize<2, false> : &NENormalizationLayerKernel::normalize_fixed_point<2, false>;
+        }
+        case DataType::QS8:
+        {
+            num_elems_processed_per_iteration = 16;
+            switch(norm_info.type())
+            {
+                case NormType::IN_MAP_1D:
+                    _func = &NENormalizationLayerKernel::normalize_fixed_point<0, false>;
+                    break;
+                case NormType::IN_MAP_2D:
+                    // Normalize over X and Y
+                    _func = &NENormalizationLayerKernel::normalize_fixed_point<0, true>;
+                    break;
+                case NormType::CROSS_MAP:
+                    _func = &NENormalizationLayerKernel::normalize_fixed_point<2, false>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
             break;
+        }
         default:
             ARM_COMPUTE_ERROR("NOT SUPPORTED!");
     }
 
-    const unsigned int num_elems_processed_per_iteration = (is_dt_f32) ? 4 : 16;
-    const unsigned int num_elems_read_per_iteration      = num_elems_processed_per_iteration + 2 * (norm_info.norm_size() / 2);
-    const unsigned int num_rows                          = (norm_info.type() == NormType::IN_MAP_2D) ? norm_info.norm_size() : 1;
+    const unsigned int num_elems_read_per_iteration = num_elems_processed_per_iteration + 2 * (norm_info.norm_size() / 2);
+    const unsigned int num_rows                     = (norm_info.type() == NormType::IN_MAP_2D) ? norm_info.norm_size() : 1;
 
     // Configure window
     Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
@@ -104,8 +154,8 @@
     INEKernel::configure(win);
 }
 
-template <unsigned int dim, bool do_2D_norm>
-void NENormalizationLayerKernel::normalize(const Window &window)
+template <DataType dt, unsigned int dim, bool do_2D_norm>
+void NENormalizationLayerKernel::normalize_float(const Window &window)
 {
     Iterator input(_input, window);
     Iterator input_squared(_input_squared, window);
@@ -121,39 +171,83 @@
     const int min_top    = 0;
     const int max_bottom = _input->info()->dimension(dim_y) - 1;
 
-    const float32x4_t coeff_vec = vdupq_n_f32(_norm_info.scale_coeff());
-    const float32x4_t beta_vec  = vdupq_n_f32(_norm_info.beta());
-    const float32x4_t kappa_vec = vdupq_n_f32(_norm_info.kappa());
-
-    execute_window_loop(window, [&](const Coordinates & id)
+    if(dt == DataType::F32)
     {
-        // Get range to normalize
-        const int current_row   = do_2D_norm ? id[dim_y] : 0;
-        const int current_slice = id[dim];
-        const int first_row     = do_2D_norm ? std::max(current_row - radius, min_top) : 0;
-        const int last_row      = do_2D_norm ? std::min(current_row + radius, max_bottom) : 0;
-        const int first_slice   = std::max(current_slice - radius, min_left);
-        const int last_slice    = std::min(current_slice + radius, max_right);
+        const float32x4_t coeff_vec = vdupq_n_f32(_norm_info.scale_coeff());
+        const float32x4_t beta_vec  = vdupq_n_f32(_norm_info.beta());
+        const float32x4_t kappa_vec = vdupq_n_f32(_norm_info.kappa());
 
-        // Accumulate 2D In-Map values
-        float32x4_t accu = vdupq_n_f32(0.f);
-        for(int j = first_row; j <= last_row; j++)
+        execute_window_loop(window, [&](const Coordinates & id)
         {
-            // Compute row displacement
-            const int            row               = (j - current_row) * _input_squared->info()->strides_in_bytes()[dim_y];
-            const uint8_t *const input_squared_ptr = input_squared.ptr() + row - (current_slice * input_squared_stride);
-            for(int i = first_slice; i <= last_slice; ++i)
-            {
-                accu = vaddq_f32(accu, vld1q_f32(reinterpret_cast<const float *>(input_squared_ptr + i * input_squared_stride)));
-            }
-        }
+            // Get range to normalize
+            const int current_row   = do_2D_norm ? id[dim_y] : 0;
+            const int current_slice = id[dim];
+            const int first_row     = do_2D_norm ? std::max(current_row - radius, min_top) : 0;
+            const int last_row      = do_2D_norm ? std::min(current_row + radius, max_bottom) : 0;
+            const int first_slice   = std::max(current_slice - radius, min_left);
+            const int last_slice    = std::min(current_slice + radius, max_right);
 
-        // Normalize
-        const float32x4_t normalized       = vpowq_f32(vmlaq_f32(kappa_vec, coeff_vec, accu), beta_vec);
-        const float32x4_t normalized_pixel = vmulq_f32(vld1q_f32(reinterpret_cast<const float *>(input.ptr())), vinvq_f32(normalized));
-        vst1q_f32(reinterpret_cast<float *>(output.ptr()), normalized_pixel);
-    },
-    input, input_squared, output);
+            // Accumulate 2D In-Map values
+            float32x4_t accu = vdupq_n_f32(0.f);
+            for(int j = first_row; j <= last_row; j++)
+            {
+                // Compute row displacement
+                const int            row               = (j - current_row) * _input_squared->info()->strides_in_bytes()[dim_y];
+                const uint8_t *const input_squared_ptr = input_squared.ptr() + row - (current_slice * input_squared_stride);
+                for(int i = first_slice; i <= last_slice; ++i)
+                {
+                    accu = vaddq_f32(accu, vld1q_f32(reinterpret_cast<const float *>(input_squared_ptr + i * input_squared_stride)));
+                }
+            }
+
+            // Normalize
+            const float32x4_t normalized       = vpowq_f32(vmlaq_f32(kappa_vec, coeff_vec, accu), beta_vec);
+            const float32x4_t normalized_pixel = vmulq_f32(vld1q_f32(reinterpret_cast<const float *>(input.ptr())), vinvq_f32(normalized));
+            vst1q_f32(reinterpret_cast<float *>(output.ptr()), normalized_pixel);
+        },
+        input, input_squared, output);
+    }
+#ifdef ARM_COMPUTE_ENABLE_FP16
+    else if(dt == DataType::F16)
+    {
+        const float16x8_t coeff_vec    = vdupq_n_f16(_norm_info.scale_coeff());
+        const float16x8_t beta_vec_f16 = vdupq_n_f16(_norm_info.beta());
+        const float16x8_t kappa_vec    = vdupq_n_f16(_norm_info.kappa());
+
+        execute_window_loop(window, [&](const Coordinates & id)
+        {
+            // Get range to normalize
+            const int current_row   = do_2D_norm ? id[dim_y] : 0;
+            const int current_slice = id[dim];
+            const int first_row     = do_2D_norm ? std::max(current_row - radius, min_top) : 0;
+            const int last_row      = do_2D_norm ? std::min(current_row + radius, max_bottom) : 0;
+            const int first_slice   = std::max(current_slice - radius, min_left);
+            const int last_slice    = std::min(current_slice + radius, max_right);
+
+            // Accumulate 2D In-Map values
+            float16x8_t accu = vdupq_n_f16(0.f);
+            for(int j = first_row; j <= last_row; j++)
+            {
+                // Compute row displacement
+                const int            row               = (j - current_row) * _input_squared->info()->strides_in_bytes()[dim_y];
+                const uint8_t *const input_squared_ptr = input_squared.ptr() + row - (current_slice * input_squared_stride);
+                for(int i = first_slice; i <= last_slice; ++i)
+                {
+                    accu = vaddq_f16(accu, vld1q_f16(reinterpret_cast<const float16_t *>(input_squared_ptr + i * input_squared_stride)));
+                }
+            }
+
+            const float16x8_t norm_f16         = vpowq_f16(vaddq_f16(kappa_vec, vmulq_f16(coeff_vec, accu)), beta_vec_f16);
+            const float16x8_t normalized_pixel = vmulq_f16(vld1q_f16(reinterpret_cast<const float16_t *>(input.ptr())), vinvq_f16(norm_f16));
+            vst1q_f16(reinterpret_cast<float16_t *>(output.ptr()), normalized_pixel);
+        },
+        input, input_squared, output);
+    }
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+    else
+    {
+        ARM_COMPUTE_ERROR("Not supported");
+    }
 }
 
 template <unsigned int dim, bool do_2D_norm>
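
Note: both branches implement the same response normalization,
out = in / (kappa + coeff * sum(in^2 over the window))^beta, reading the
squared inputs from the precomputed input_squared tensor. A scalar sketch of
the 1D cross-map case, illustrative only:

    #include <algorithm>
    #include <cmath>

    float normalize_sketch(const float *in, const float *in_squared, int num_channels,
                           int c, int radius, float kappa, float coeff, float beta)
    {
        float accu = 0.f;
        const int first = std::max(c - radius, 0);
        const int last  = std::min(c + radius, num_channels - 1);
        for(int j = first; j <= last; ++j)
        {
            accu += in_squared[j];
        }
        return in[c] / std::pow(kappa + coeff * accu, beta);
    }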
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index c3f61ac..83d6d82 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -38,6 +38,10 @@
 #include <cstdint>
 #include <cstdlib>
 
+#ifdef ARM_COMPUTE_ENABLE_FP16
+#include <arm_fp16.h> // needed for float16_t
+#endif                /* ARM_COMPUTE_ENABLE_FP16 */
+
 using namespace arm_compute;
 
 namespace arm_compute
@@ -249,6 +253,34 @@
 }
 
 template <bool is_scale255, bool is_sat>
+void mul_F16_F16_F16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale)
+{
+    ARM_COMPUTE_UNUSED(input1_ptr);
+    ARM_COMPUTE_UNUSED(input2_ptr);
+    ARM_COMPUTE_UNUSED(output_ptr);
+    ARM_COMPUTE_UNUSED(scale);
+#ifdef ARM_COMPUTE_ENABLE_FP16
+    const auto          input1    = static_cast<const float16_t *__restrict>(input1_ptr);
+    const auto          input2    = static_cast<const float16_t *__restrict>(input2_ptr);
+    const auto          output    = static_cast<float16_t *__restrict>(output_ptr);
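+    // Each call handles 16 half-precision elements, de-interleaved into two 8-lane registers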
+    const float16x8x2_t ta1       = vld2q_f16(input1);
+    const float16x8x2_t ta2       = vld2q_f16(input2);
+    const float16x8_t   scale_vec = vdupq_n_f16(scale);
+    const float16x8x2_t result =
+    {
+        {
+            vmulq_f16(vmulq_f16(ta1.val[0], ta2.val[0]), scale_vec),
+            vmulq_f16(vmulq_f16(ta1.val[1], ta2.val[1]), scale_vec),
+        }
+    };
+    vst2q_f16(output, result);
+#else  /* ARM_COMPUTE_ENABLE_FP16 */
+    ARM_COMPUTE_ERROR("Not supported. Recompile the library with arch=arm64-v8.2-a.");
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+}
+
+template <bool is_scale255, bool is_sat>
 void mul_U8_U8_S16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, int n)
 {
     const auto input1 = static_cast<const uint8_t *__restrict>(input1_ptr);
@@ -347,6 +377,10 @@
         {
             set_format_if_unknown(*output->info(), Format::F32);
         }
+        else if(input1->info()->data_type() == DataType::F16 || input2->info()->data_type() == DataType::F16)
+        {
+            set_format_if_unknown(*output->info(), Format::F16);
+        }
         else if(input1->info()->data_type() == DataType::QS8 && input2->info()->data_type() == DataType::QS8)
         {
             set_data_type_if_unknown(*output->info(), DataType::QS8);
@@ -355,9 +389,9 @@
     }
 
     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::S16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
                              "Output can only be U8 if both inputs are U8");
     if(input1->info()->data_type() == DataType::QS8)
@@ -479,6 +513,11 @@
             _func_q_int = is_sat ? &mul_QS8_QS8_QS8_n<false, true> : &mul_QS8_QS8_QS8_n<false, false>;
         }
     }
+    else if(DataType::F16 == dt_input1 && DataType::F16 == dt_input2 && DataType::F16 == dt_output)
+    {
+        _func_float = &mul_F16_F16_F16_n<false, false>;
+        _func_int   = nullptr;
+    }
     else if(DataType::F32 == dt_input1 && DataType::F32 == dt_input2 && DataType::F32 == dt_output)
     {
         _func_float = &mul_F32_F32_F32_n<false, false>;
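
Note: for the float paths (the new F16 case and F32 alike) the kernel reduces to
out = in1 * in2 * scale per element; the overflow and rounding policies only
affect the integer paths. A scalar sketch, illustrative only:

    void mul_float_sketch(const float *in1, const float *in2, float *out, int n, float scale)
    {
        for(int i = 0; i < n; ++i)
        {
            out[i] = in1[i] * in2[i] * scale; // the F16 kernel does this on float16x8_t lanes
        }
    }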
diff --git a/tests/TensorLibrary.h b/tests/TensorLibrary.h
index 3fb593c..b0a0556 100644
--- a/tests/TensorLibrary.h
+++ b/tests/TensorLibrary.h
@@ -501,11 +501,6 @@
         }
 #if ARM_COMPUTE_ENABLE_FP16
         case DataType::F16:
-        {
-            std::uniform_real_distribution<float> distribution_f16(-1000.f, 1000.f);
-            fill(tensor, distribution_f16, seed_offset);
-            break;
-        }
 #endif /* ARM_COMPUTE_ENABLE_FP16 */
         case DataType::F32:
         {
diff --git a/tests/dataset/NormalizationTypeDataset.h b/tests/dataset/NormalizationTypeDataset.h
index 9edadbf..756772e 100644
--- a/tests/dataset/NormalizationTypeDataset.h
+++ b/tests/dataset/NormalizationTypeDataset.h
@@ -73,7 +73,7 @@
     }
 
 private:
-    std::array<NormType, 3> _types{ { NormType::IN_MAP_1D, NormType::IN_MAP_2D, NormType::CROSS_MAP } };
+    const std::array<NormType, 3> _types{ { NormType::IN_MAP_1D, NormType::IN_MAP_2D, NormType::CROSS_MAP } };
 };
 } // namespace test
 } // namespace arm_compute
diff --git a/tests/validation/NEON/NormalizationLayer.cpp b/tests/validation/NEON/NormalizationLayer.cpp
index a8ba7da..60c2646 100644
--- a/tests/validation/NEON/NormalizationLayer.cpp
+++ b/tests/validation/NEON/NormalizationLayer.cpp
@@ -52,6 +52,8 @@
     {
         case DataType::QS8:
             return 2.0f;
+        case DataType::F16:
+            return 0.001f;
         case DataType::F32:
             return 1e-05;
         default:
@@ -108,6 +110,29 @@
 BOOST_AUTO_TEST_SUITE(NEON)
 BOOST_AUTO_TEST_SUITE(NormalizationLayer)
 
+#ifdef ARM_COMPUTE_ENABLE_FP16
+BOOST_AUTO_TEST_SUITE(Float16)
+BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
+BOOST_DATA_TEST_CASE(RunSmall,
+                     SmallShapes() * DataType::F16 * NormalizationTypes() * boost::unit_test::data::xrange(3, 9, 2) * boost::unit_test::data::make({ 0.5f, 1.0f, 2.0f }),
+                     shape, dt, norm_type, norm_size, beta)
+{
+    // Provide normalization layer information
+    NormalizationLayerInfo norm_info(norm_type, norm_size, 5, beta);
+
+    // Compute function
+    Tensor dst = compute_normalization_layer(shape, dt, norm_info);
+
+    // Compute reference
+    RawTensor ref_dst = Reference::compute_reference_normalization_layer(shape, dt, norm_info);
+
+    // Validate output
+    validate(NEAccessor(dst), ref_dst, normalization_layer_tolerance(DataType::F16));
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 BOOST_AUTO_TEST_SUITE(Float)
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
 BOOST_DATA_TEST_CASE(RunSmall,
diff --git a/tests/validation/NEON/PixelWiseMultiplication.cpp b/tests/validation/NEON/PixelWiseMultiplication.cpp
index 5641705..26ea38a 100644
--- a/tests/validation/NEON/PixelWiseMultiplication.cpp
+++ b/tests/validation/NEON/PixelWiseMultiplication.cpp
@@ -122,7 +122,6 @@
 BOOST_AUTO_TEST_SUITE(PixelWiseMultiplication)
 
 BOOST_AUTO_TEST_SUITE(U8)
-
 BOOST_AUTO_TEST_SUITE(Scale255)
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))
 BOOST_DATA_TEST_CASE(Configuration, (SmallShapes() + LargeShapes()) * (1.f / 255.f) * ConvertPolicies()
@@ -314,6 +313,27 @@
 BOOST_AUTO_TEST_SUITE_END()
 BOOST_AUTO_TEST_SUITE_END()
 
+#ifdef ARM_COMPUTE_ENABLE_FP16
+BOOST_AUTO_TEST_SUITE(F16)
+BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
+
+BOOST_DATA_TEST_CASE(RunSmall, SmallShapes() * (1.f / 255.f) * ConvertPolicies() * RoundingPolicy::TO_NEAREST_UP,
+                     shape, scale, convert_policy, rounding_policy)
+{
+    // Compute function
+    Tensor dst = compute_pixel_wise_multiplication(shape, DataType::F16, DataType::F16, DataType::F16, scale, convert_policy, rounding_policy);
+
+    // Compute reference
+    RawTensor ref_dst = Reference::compute_reference_pixel_wise_multiplication(shape, DataType::F16, DataType::F16, DataType::F16, scale, convert_policy, rounding_policy);
+
+    // Validate output
+    // Allow tolerance value of 1.f to counteract imprecision due to 32-bit float conversion
+    validate(NEAccessor(dst), ref_dst, 1.f, 0.f, std::numeric_limits<int16_t>::max());
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 BOOST_AUTO_TEST_SUITE(Float)
 BOOST_AUTO_TEST_SUITE(Scale255)
 BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit") * boost::unit_test::label("nightly"))