COMPMID-2700: Use NEON wrapper on SoftmaxLayer
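
Replace the ad-hoc NEON type traits, the macro-generated intrinsic
overloads and the reduce_max/reduce_add helpers in NESoftmaxLayerKernel
with the shared wrapper:: intrinsics, and move the uint8x16_t <->
float32x4x4_t conversion helpers from NEColorConvertHelper.inl into
NEMath so the QASYMM8 softmax path can reuse them. Also switch the
NEMath include guard away from reserved identifiers and make the FP16
RunSmall4D softmax test use the F16 data type and tolerance.

A minimal sketch (not part of the patch) of the wrapper-based row
reduction the kernel now uses, assuming the traits and intrinsics that
appear in the diff below; the helper name row_max is illustrative, n is
assumed to be a multiple of the vector length, and std::numeric_limits
stands in for the library's support::cpp11::lowest:

    #include "arm_compute/core/NEON/wrapper/wrapper.h"

    #include <cmath>
    #include <limits>

    template <typename T>
    T row_max(const T *ptr, int n)
    {
        using namespace arm_compute;
        // 128-bit vector tag for T, as used by the kernel
        using TagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
        constexpr int step = 16 / sizeof(T);

        // Running vector maximum over the row
        auto vec_max = wrapper::vdup_n(std::numeric_limits<T>::lowest(), TagType{});
        for(int x = 0; x + step <= n; x += step)
        {
            vec_max = wrapper::vmax(vec_max, wrapper::vloadq(ptr + x));
        }

        // Fold with pairwise max until lane 0 holds the row maximum
        auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
        for(int i = 0; i < static_cast<int>(std::log2(step / 2)); ++i)
        {
            carry_max = wrapper::vpmax(carry_max, carry_max);
        }
        return wrapper::vgetlane(carry_max, 0);
    }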

Change-Id: Id8901e865c9f355dcf7b2a1a539493099591377e
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2186
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/NEON/NEColorConvertHelper.inl b/arm_compute/core/NEON/NEColorConvertHelper.inl
index 68f4371..62c6eb5 100644
--- a/arm_compute/core/NEON/NEColorConvertHelper.inl
+++ b/arm_compute/core/NEON/NEColorConvertHelper.inl
@@ -24,6 +24,7 @@
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/IMultiImage.h"
+#include "arm_compute/core/NEON/NEMath.h"
 #include "arm_compute/core/Utils.h"
 
 #include <arm_neon.h>
@@ -49,37 +50,6 @@
 constexpr float rgb2u8_green_coef = 0.7152f;
 constexpr float rgb2u8_blue_coef  = 0.0722f;
 
-inline float32x4x4_t convert_uint8x16_to_float32x4x4(const uint8x16_t &in)
-{
-    float32x4x4_t out;
-    const auto    tmp1 = vmovl_u8(vget_low_u8(in));
-    out.val[0]         = vcvtq_f32_u32(vmovl_u16(vget_low_u16(tmp1)));
-    out.val[1]         = vcvtq_f32_u32(vmovl_u16(vget_high_u16(tmp1)));
-    const auto tmp2    = vmovl_u8(vget_high_u8(in));
-    out.val[2]         = vcvtq_f32_u32(vmovl_u16(vget_low_u16(tmp2)));
-    out.val[3]         = vcvtq_f32_u32(vmovl_u16(vget_high_u16(tmp2)));
-    return out;
-}
-
-inline void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const float32x4x3_t &in2, uint8x8x3_t &out)
-{
-    out.val[0] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[0])),
-                                         vqmovn_u32(vcvtq_u32_f32(in2.val[0]))));
-    out.val[1] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[1])),
-                                         vqmovn_u32(vcvtq_u32_f32(in2.val[1]))));
-    out.val[2] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[2])),
-                                         vqmovn_u32(vcvtq_u32_f32(in2.val[2]))));
-}
-
-inline void convert_float32x4x4_to_unit8x16(const float32x4x4_t &in, uint8x16_t &out)
-{
-    const auto low = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[0])),
-                                  vqmovn_u32(vcvtq_u32_f32(in.val[1])));
-    const auto high = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[2])),
-                                   vqmovn_u32(vcvtq_u32_f32(in.val[3])));
-    out = vcombine_u8(vqmovn_u16(low), vqmovn_u16(high));
-}
-
 inline float32x4_t rgb_to_greyscale_calculation(const float32x4_t &rcolor, const float32x4_t &gcolor, const float32x4_t &bcolor,
                                                 const float rcoef, const float gcoef, const float bcoef)
 {
@@ -94,9 +64,9 @@
     float32x4x4_t out_float32;
 
     //Conversion from 3(RGB) 4 uint8s to 3(RGB) 4 floats
-    const float32x4x4_t r_float32 = convert_uint8x16_to_float32x4x4(in.val[0]);
-    const float32x4x4_t g_float32 = convert_uint8x16_to_float32x4x4(in.val[1]);
-    const float32x4x4_t b_float32 = convert_uint8x16_to_float32x4x4(in.val[2]);
+    const float32x4x4_t r_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[0]);
+    const float32x4x4_t g_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[1]);
+    const float32x4x4_t b_float32 = arm_compute::convert_uint8x16_to_float32x4x4(in.val[2]);
 
     //New grayscale image = ( (RED_COEFF * R) + (GREEN_COEFF * G) + (BLUE_COEFF * B) )
     //Computation of 1(Greyscale) 4 uint8 using 3(RGB) 4 uint8s float
@@ -113,7 +83,7 @@
                                                       rgb2u8_red_coef, rgb2u8_green_coef, rgb2u8_blue_coef);
 
     //Conversion from 1(Greyscale) 4 floats to 1(Greyscale) 4 uint8s
-    convert_float32x4x4_to_unit8x16(out_float32, out);
+    arm_compute::convert_float32x4x4_to_unit8x16(out_float32, out);
 }
 
 inline void rgb_to_yuv_calculation(const float32x4_t &rvec, const float32x4_t &gvec, const float32x4_t &bvec,
@@ -172,7 +142,7 @@
     rgb2.val[2] = vaddq_f32(yyvec_val, blue);
 
     uint8x8x3_t u8_rgb;
-    convert_float32x4x3_to_uint8x8x3(rgb1, rgb2, u8_rgb);
+    arm_compute::convert_float32x4x3_to_uint8x8x3(rgb1, rgb2, u8_rgb);
 
     if(!alpha)
     {
@@ -225,13 +195,13 @@
 inline void rgb_to_yuv_conversion(uint8x16x3_t &vec_top, uint8x16x3_t &vec_bottom)
 {
     // Convert the uint8x16_t to float32x4x4_t
-    const float32x4x4_t frvec_top = convert_uint8x16_to_float32x4x4(vec_top.val[0]);
-    const float32x4x4_t fgvec_top = convert_uint8x16_to_float32x4x4(vec_top.val[1]);
-    const float32x4x4_t fbvec_top = convert_uint8x16_to_float32x4x4(vec_top.val[2]);
+    const float32x4x4_t frvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[0]);
+    const float32x4x4_t fgvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[1]);
+    const float32x4x4_t fbvec_top = arm_compute::convert_uint8x16_to_float32x4x4(vec_top.val[2]);
 
-    const float32x4x4_t frvec_bottom = convert_uint8x16_to_float32x4x4(vec_bottom.val[0]);
-    const float32x4x4_t fgvec_bottom = convert_uint8x16_to_float32x4x4(vec_bottom.val[1]);
-    const float32x4x4_t fbvec_bottom = convert_uint8x16_to_float32x4x4(vec_bottom.val[2]);
+    const float32x4x4_t frvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[0]);
+    const float32x4x4_t fgvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[1]);
+    const float32x4x4_t fbvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(vec_bottom.val[2]);
 
     float32x4x4_t fyvec_top, fuvec_top, fvvec_top;
     float32x4x4_t fyvec_bottom, fuvec_bottom, fvvec_bottom;
@@ -244,12 +214,12 @@
                                fyvec_bottom.val[i], fuvec_bottom.val[i], fvvec_bottom.val[i]);
     }
 
-    convert_float32x4x4_to_unit8x16(fyvec_top, vec_top.val[0]);
-    convert_float32x4x4_to_unit8x16(fuvec_top, vec_top.val[1]);
-    convert_float32x4x4_to_unit8x16(fvvec_top, vec_top.val[2]);
-    convert_float32x4x4_to_unit8x16(fyvec_bottom, vec_bottom.val[0]);
-    convert_float32x4x4_to_unit8x16(fuvec_bottom, vec_bottom.val[1]);
-    convert_float32x4x4_to_unit8x16(fvvec_bottom, vec_bottom.val[2]);
+    arm_compute::convert_float32x4x4_to_unit8x16(fyvec_top, vec_top.val[0]);
+    arm_compute::convert_float32x4x4_to_unit8x16(fuvec_top, vec_top.val[1]);
+    arm_compute::convert_float32x4x4_to_unit8x16(fvvec_top, vec_top.val[2]);
+    arm_compute::convert_float32x4x4_to_unit8x16(fyvec_bottom, vec_bottom.val[0]);
+    arm_compute::convert_float32x4x4_to_unit8x16(fuvec_bottom, vec_bottom.val[1]);
+    arm_compute::convert_float32x4x4_to_unit8x16(fvvec_bottom, vec_bottom.val[2]);
 }
 
 inline void store_rgb_to_nv12(const uint8x16_t &rvec_top, const uint8x16_t &gvec_top, const uint8x16_t &bvec_top,
@@ -316,9 +286,9 @@
                               unsigned char *const __restrict out_v)
 {
     // Convert the uint8x16_t to float32x4x4_t
-    const float32x4x4_t frvec = convert_uint8x16_to_float32x4x4(rvec);
-    const float32x4x4_t fgvec = convert_uint8x16_to_float32x4x4(gvec);
-    const float32x4x4_t fbvec = convert_uint8x16_to_float32x4x4(bvec);
+    const float32x4x4_t frvec = arm_compute::convert_uint8x16_to_float32x4x4(rvec);
+    const float32x4x4_t fgvec = arm_compute::convert_uint8x16_to_float32x4x4(gvec);
+    const float32x4x4_t fbvec = arm_compute::convert_uint8x16_to_float32x4x4(bvec);
 
     float32x4x4_t fyvec, fuvec, fvvec;
     for(auto i = 0; i < 4; ++i)
@@ -328,9 +298,9 @@
     }
 
     uint8x16_t yvec, uvec, vvec;
-    convert_float32x4x4_to_unit8x16(fyvec, yvec);
-    convert_float32x4x4_to_unit8x16(fuvec, uvec);
-    convert_float32x4x4_to_unit8x16(fvvec, vvec);
+    arm_compute::convert_float32x4x4_to_unit8x16(fyvec, yvec);
+    arm_compute::convert_float32x4x4_to_unit8x16(fuvec, uvec);
+    arm_compute::convert_float32x4x4_to_unit8x16(fvvec, vvec);
 
     vst1q_u8(out_y, yvec);
     vst1q_u8(out_u, uvec);
@@ -461,10 +431,10 @@
         //ta.val[3] = V0 V2 V4 V7 ...
 
         // Convert the uint8x16x4_t to float32x4x4_t
-        const float32x4x4_t yvec  = convert_uint8x16_to_float32x4x4(ta.val[0 + shift]);
-        const float32x4x4_t uvec  = convert_uint8x16_to_float32x4x4(ta.val[1 - shift]);
-        const float32x4x4_t yyvec = convert_uint8x16_to_float32x4x4(ta.val[2 + shift]);
-        const float32x4x4_t vvec  = convert_uint8x16_to_float32x4x4(ta.val[3 - shift]);
+        const float32x4x4_t yvec  = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[0 + shift]);
+        const float32x4x4_t uvec  = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[1 - shift]);
+        const float32x4x4_t yyvec = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[2 + shift]);
+        const float32x4x4_t vvec  = arm_compute::convert_uint8x16_to_float32x4x4(ta.val[3 - shift]);
 
         yuyv_to_rgb_calculation(yvec.val[0], uvec.val[0], yyvec.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha);
         yuyv_to_rgb_calculation(yvec.val[1], uvec.val[1], yyvec.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha);
@@ -516,12 +486,12 @@
         //ta_uv.val[1] = V0 V2 V4 V6 ...
 
         // Convert the uint8x16x4_t to float32x4x4_t
-        float32x4x4_t yvec_top     = convert_uint8x16_to_float32x4x4(ta_y_top.val[0]);
-        float32x4x4_t yyvec_top    = convert_uint8x16_to_float32x4x4(ta_y_top.val[1]);
-        float32x4x4_t yvec_bottom  = convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]);
-        float32x4x4_t yyvec_bottom = convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]);
-        float32x4x4_t uvec         = convert_uint8x16_to_float32x4x4(ta_uv.val[0 + shift]);
-        float32x4x4_t vvec         = convert_uint8x16_to_float32x4x4(ta_uv.val[1 - shift]);
+        float32x4x4_t yvec_top     = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[0]);
+        float32x4x4_t yyvec_top    = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[1]);
+        float32x4x4_t yvec_bottom  = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]);
+        float32x4x4_t yyvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]);
+        float32x4x4_t uvec         = arm_compute::convert_uint8x16_to_float32x4x4(ta_uv.val[0 + shift]);
+        float32x4x4_t vvec         = arm_compute::convert_uint8x16_to_float32x4x4(ta_uv.val[1 - shift]);
 
         yuyv_to_rgb_calculation(yvec_top.val[0], uvec.val[0], yyvec_top.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha);
         yuyv_to_rgb_calculation(yvec_top.val[1], uvec.val[1], yyvec_top.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha);
@@ -579,12 +549,12 @@
         //ta_v.val[0] = V0 V2 V4 V6 ...
 
         // Convert the uint8x16x4_t to float32x4x4_t
-        float32x4x4_t yvec_top     = convert_uint8x16_to_float32x4x4(ta_y_top.val[0]);
-        float32x4x4_t yyvec_top    = convert_uint8x16_to_float32x4x4(ta_y_top.val[1]);
-        float32x4x4_t yvec_bottom  = convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]);
-        float32x4x4_t yyvec_bottom = convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]);
-        float32x4x4_t uvec         = convert_uint8x16_to_float32x4x4(ta_u);
-        float32x4x4_t vvec         = convert_uint8x16_to_float32x4x4(ta_v);
+        float32x4x4_t yvec_top     = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[0]);
+        float32x4x4_t yyvec_top    = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_top.val[1]);
+        float32x4x4_t yvec_bottom  = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[0]);
+        float32x4x4_t yyvec_bottom = arm_compute::convert_uint8x16_to_float32x4x4(ta_y_bottom.val[1]);
+        float32x4x4_t uvec         = arm_compute::convert_uint8x16_to_float32x4x4(ta_u);
+        float32x4x4_t vvec         = arm_compute::convert_uint8x16_to_float32x4x4(ta_v);
 
         yuyv_to_rgb_calculation(yvec_top.val[0], uvec.val[0], yyvec_top.val[0], vvec.val[0], out.ptr() + 0 * element_size, alpha);
         yuyv_to_rgb_calculation(yvec_top.val[1], uvec.val[1], yyvec_top.val[1], vvec.val[1], out.ptr() + 1 * element_size, alpha);
diff --git a/arm_compute/core/NEON/NEMath.h b/arm_compute/core/NEON/NEMath.h
index 8593059..aa30543 100644
--- a/arm_compute/core/NEON/NEMath.h
+++ b/arm_compute/core/NEON/NEMath.h
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef __ARM_COMPUTE_NEMATH_H__
-#define __ARM_COMPUTE_NEMATH_H__
+#ifndef ARM_COMPUTE_NEMATH_H
+#define ARM_COMPUTE_NEMATH_H
 
 #include <arm_neon.h>
 
@@ -157,6 +157,29 @@
  */
 int32_t rounding_divide_by_pow2(int32_t x, int exponent);
 
+/** Converts from uint8x16_t to float32x4x4_t
+ *
+ * @param[in] in Vector of uint8 to be converted
+ *
+ * @return Converted vector of float
+ */
+float32x4x4_t convert_uint8x16_to_float32x4x4(const uint8x16_t &in);
+
+/** Converts from two float32x4x3_t to just one uint8x8x3_t
+ *
+ * @param[in]  in1 First input vector of float to be converted
+ * @param[in]  in2 Second input vector of float to be converted
+ * @param[out] out Converted output vector uint8 to store the result
+ */
+void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const float32x4x3_t &in2, uint8x8x3_t &out);
+
+/** Converts from a float32x4x4_t to just one uint8x16_t
+ *
+ * @param[in]  in  Vector of float to be converted
+ * @param[out] out Converted vector of uint8 to store the result
+ */
+void convert_float32x4x4_to_unit8x16(const float32x4x4_t &in, uint8x16_t &out);
+
 /** Calculate sine.
  *
  * @param[in] val Input vector value in radians, F32 format.
@@ -256,4 +279,4 @@
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 } // namespace arm_compute
 #include "arm_compute/core/NEON/NEMath.inl"
-#endif /* __ARM_COMPUTE_NEMATH_H__ */
+#endif /* ARM_COMPUTE_NEMATH_H */
diff --git a/arm_compute/core/NEON/NEMath.inl b/arm_compute/core/NEON/NEMath.inl
index f1c9c20..a3601f6 100644
--- a/arm_compute/core/NEON/NEMath.inl
+++ b/arm_compute/core/NEON/NEMath.inl
@@ -317,6 +317,39 @@
     return (x >> exponent) + ((x & mask) > threshold ? 1 : 0);
 }
 
+inline float32x4x4_t convert_uint8x16_to_float32x4x4(const uint8x16_t &in)
+{
+    float32x4x4_t out;
+
+    const auto tmp1 = vmovl_u8(vget_low_u8(in));
+    out.val[0]      = vcvtq_f32_u32(vmovl_u16(vget_low_u16(tmp1)));
+    out.val[1]      = vcvtq_f32_u32(vmovl_u16(vget_high_u16(tmp1)));
+
+    const auto tmp2 = vmovl_u8(vget_high_u8(in));
+    out.val[2]      = vcvtq_f32_u32(vmovl_u16(vget_low_u16(tmp2)));
+    out.val[3]      = vcvtq_f32_u32(vmovl_u16(vget_high_u16(tmp2)));
+    return out;
+}
+
+inline void convert_float32x4x3_to_uint8x8x3(const float32x4x3_t &in1, const float32x4x3_t &in2, uint8x8x3_t &out)
+{
+    out.val[0] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[0])),
+                                         vqmovn_u32(vcvtq_u32_f32(in2.val[0]))));
+    out.val[1] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[1])),
+                                         vqmovn_u32(vcvtq_u32_f32(in2.val[1]))));
+    out.val[2] = vqmovn_u16(vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in1.val[2])),
+                                         vqmovn_u32(vcvtq_u32_f32(in2.val[2]))));
+}
+
+inline void convert_float32x4x4_to_unit8x16(const float32x4x4_t &in, uint8x16_t &out)
+{
+    const auto low = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[0])),
+                                  vqmovn_u32(vcvtq_u32_f32(in.val[1])));
+    const auto high = vcombine_u16(vqmovn_u32(vcvtq_u32_f32(in.val[2])),
+                                   vqmovn_u32(vcvtq_u32_f32(in.val[3])));
+    out = vcombine_u8(vqmovn_u16(low), vqmovn_u16(high));
+}
+
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 /** Exponent polynomial coefficients */
 /** Logarithm polynomial coefficients */
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index 1003ebd..a3ecce3 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -30,6 +30,7 @@
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/NEON/NEFixedPoint.h"
 #include "arm_compute/core/NEON/NEMath.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
@@ -43,309 +44,6 @@
 
 namespace arm_compute
 {
-template <typename T, int N>
-struct vec_n_type;
-
-#define DECLARE_NEON_VEC_TYPE(T, N, V) \
-    template <>                        \
-    struct vec_n_type<T, N>            \
-    {                                  \
-        using type = V;                \
-    };
-
-DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
-DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)
-
-DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
-DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)
-
-DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
-DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)
-
-DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
-DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)
-
-DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
-DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)
-
-DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
-DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
-DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
-DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)
-
-template <typename T, int N>
-using vec_n_t = typename vec_n_type<T, N>::type;
-
-template <typename T, int N>
-using vec_n_byte_t = vec_n_t < T, N / sizeof(T) >;
-
-template <typename T>
-using vec_16_byte_t = vec_n_byte_t<T, 16>;
-
-template <typename T>
-using vec_8_byte_t = vec_n_byte_t<T, 8>;
-
-template <typename T>
-using const_ptr_t = const T *;
-
-template <typename T>
-using ptr_t = T *;
-
-#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \
-    template <int lane>                          \
-    TYPE vget_lane(vec_8_byte_t<TYPE> vec);      \
-    template <int lane>                          \
-    TYPE vget_lane(vec_16_byte_t<TYPE> vec);
-
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)
-template <int lane>
-float vget_lane(float32x4x4_t vec);
-
-template <typename V>
-using elem_type_t = decltype(vget_lane<0>(std::declval<V>()));
-
-template <typename V>
-constexpr size_t vec_size_of(const V &vec)
-{
-    return sizeof(vec) / sizeof(elem_type_t<V>);
-}
-
-template <typename V>
-V vdup_n(elem_type_t<V> val);
-template <typename V>
-V vld(const_ptr_t<elem_type_t<V>> ptr);
-
-#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG)                                \
-    template <>                                                                   \
-    inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val)                \
-    {                                                                             \
-        return vdup_n_##TAG(val);                                                 \
-    }                                                                             \
-    template <>                                                                   \
-    inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val)              \
-    {                                                                             \
-        return vdupq_n_##TAG(val);                                                \
-    }                                                                             \
-    template <>                                                                   \
-    inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)      \
-    {                                                                             \
-        return vld1_##TAG(ptr);                                                   \
-    }                                                                             \
-    template <>                                                                   \
-    inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)    \
-    {                                                                             \
-        return vld1q_##TAG(ptr);                                                  \
-    }                                                                             \
-    inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec)                      \
-    {                                                                             \
-        vst1_##TAG(ptr, vec);                                                     \
-    }                                                                             \
-    inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec)                     \
-    {                                                                             \
-        vst1q_##TAG(ptr, vec);                                                    \
-    }                                                                             \
-    inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
-    {                                                                             \
-        return vmaxq_##TAG(a, b);                                                 \
-    }                                                                             \
-    inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)   \
-    {                                                                             \
-        return vpmax_##TAG(a, b);                                                 \
-    }                                                                             \
-    inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec)                   \
-    {                                                                             \
-        return vget_low_##TAG(vec);                                               \
-    }                                                                             \
-    inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec)                  \
-    {                                                                             \
-        return vget_high_##TAG(vec);                                              \
-    }                                                                             \
-    template <int lane>                                                           \
-    inline TYPE vget_lane(vec_8_byte_t<TYPE> vec)                                 \
-    {                                                                             \
-        static_assert(lane >= 0, "lane is out of bounds");                        \
-        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
-        return vget_lane_##TAG(vec, lane);                                        \
-    }                                                                             \
-    template <int lane>                                                           \
-    inline TYPE vget_lane(vec_16_byte_t<TYPE> vec)                                \
-    {                                                                             \
-        static_assert(lane >= 0, "lane is out of bounds");                        \
-        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
-        return vgetq_lane_##TAG(vec, lane);                                       \
-    }
-
-template <typename T>
-T sqadd(T a, T b);
-template <typename T>
-T sqsub(T a, T b);
-template <typename T>
-T sqmul(T a, T b);
-
-#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG)                               \
-    inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)    \
-    {                                                                             \
-        return vadd_##TAG(a, b);                                                  \
-    }                                                                             \
-    inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
-    {                                                                             \
-        return vaddq_##TAG(a, b);                                                 \
-    }                                                                             \
-    inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
-    {                                                                             \
-        return vsubq_##TAG(a, b);                                                 \
-    }                                                                             \
-    inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val)          \
-    {                                                                             \
-        return vmulq_n_##TAG(vec, val);                                           \
-    }
-
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)
-
-template <typename VO, typename VI>
-VO vcvt(VI vec);
-
-template <>
-float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
-{
-    const auto    low  = vmovl_u8(vget_low(vec));
-    const auto    high = vmovl_u8(vget_high(vec));
-    float32x4x4_t res  = { {
-            vcvtq_f32_u32(vmovl_u16(vget_low(low))),
-            vcvtq_f32_u32(vmovl_u16(vget_high(low))),
-            vcvtq_f32_u32(vmovl_u16(vget_low(high))),
-            vcvtq_f32_u32(vmovl_u16(vget_high(high)))
-        }
-    };
-    return res;
-}
-
-template <>
-uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
-{
-    uint16x8x2_t resU16 = { {
-            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
-            vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
-            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
-            vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
-        }
-    };
-
-    uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
-    return res;
-}
-
-float32x4x4_t vexp(float32x4x4_t vec)
-{
-    float32x4x4_t res = { {
-            vexpq_f32(vec.val[0]),
-            vexpq_f32(vec.val[1]),
-            vexpq_f32(vec.val[2]),
-            vexpq_f32(vec.val[3])
-        }
-    };
-    return res;
-}
-
-float32x4_t vexp(const float32x4_t &vec)
-{
-    return vexpq_f32(vec);
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-// TODO (COMPMID-1535) : Revisit FP16 approximations
-float16x8_t vexp(const float16x8_t &vec)
-{
-    float16x4x2_t res =
-    {
-        {
-            vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_low_f16(vec)))),
-            vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vget_high_f16(vec))))
-        }
-    };
-    return vcombine_f16(res.val[0], res.val[1]);
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-template <>
-float32x4x4_t vdup_n<float32x4x4_t>(float val)
-{
-    float32x4x4_t res = { {
-            vdupq_n_f32(val),
-            vdupq_n_f32(val),
-            vdupq_n_f32(val),
-            vdupq_n_f32(val)
-        }
-    };
-    return res;
-}
-
-float32x4x4_t vmul_n(float32x4x4_t vec, float val)
-{
-    float32x4x4_t res = { {
-            vmulq_n_f32(vec.val[0], val),
-            vmulq_n_f32(vec.val[1], val),
-            vmulq_n_f32(vec.val[2], val),
-            vmulq_n_f32(vec.val[3], val)
-        }
-    };
-    return res;
-}
-
-float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b)
-{
-    float32x4x4_t res = { {
-            vaddq_f32(a.val[0], b.val[0]),
-            vaddq_f32(a.val[1], b.val[1]),
-            vaddq_f32(a.val[2], b.val[2]),
-            vaddq_f32(a.val[3], b.val[3])
-        }
-    };
-    return res;
-}
-
-float32x4x4_t vsub_n(float32x4x4_t a, float val)
-{
-    auto          scalar_vector = vdup_n<float32x4x4_t>(val);
-    float32x4x4_t res           = { {
-            vsubq_f32(a.val[0], scalar_vector.val[0]),
-            vsubq_f32(a.val[1], scalar_vector.val[1]),
-            vsubq_f32(a.val[2], scalar_vector.val[2]),
-            vsubq_f32(a.val[3], scalar_vector.val[3])
-        }
-    };
-    return res;
-}
-
 namespace
 {
 Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
@@ -390,30 +88,20 @@
     return std::make_pair(err, win);
 }
 
-template <typename V>
-auto reduce_max(V vec) -> elem_type_t<V>
-{
-    constexpr int N = vec_size_of(vec);
-
-    auto carry_max = vpmax(vget_high(vec), vget_low(vec));
-
-    for(int k = N / 2; k > 1; k /= 2)
-    {
-        carry_max = vpmax(carry_max, carry_max);
-    }
-
-    return vget_lane<0>(carry_max);
-}
-
 template <typename T>
 void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
 {
     const auto   start_x     = in.info()->valid_region().anchor.x();
     const size_t input_width = in.info()->valid_region().shape.x();
 
+    /** NEON vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
     Iterator input(&in, window);
     Iterator output(&out, window);
 
+    constexpr int window_step_x = 16 / sizeof(T);
+    const int     sum_stages    = log2(window_step_x / 2);
     execute_window_loop(window, [&](const Coordinates &)
     {
         // Get pointers
@@ -421,16 +109,22 @@
         const auto out_ptr = reinterpret_cast<T *>(output.ptr());
 
         // Init max value
-        auto vec_max = vdup_n<vec_16_byte_t<T>>(support::cpp11::lowest<T>());
+        auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
 
         // Loop over input row
-        for(const T *it = in_ptr; it < (in_ptr + input_width); it += vec_size_of(vec_max))
+        for(const T *it = in_ptr; it < (in_ptr + input_width); it += window_step_x)
         {
-            const auto current_value = vld<vec_16_byte_t<T>>(it);
-            vec_max                  = vmax(vec_max, current_value);
+            const auto current_value = wrapper::vloadq(it);
+            vec_max                  = wrapper::vmax(vec_max, current_value);
         }
 
-        const T max_val = reduce_max(vec_max);
+        auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
+
+        for(int i = 0; i < sum_stages; ++i)
+        {
+            carry_max = wrapper::vpmax(carry_max, carry_max);
+        }
+        const T max_val = wrapper::vgetlane(carry_max, 0);
         *out_ptr        = max_val;
     },
     input, output);
@@ -575,45 +269,19 @@
     return std::make_pair(err, win);
 }
 
-template <typename T, int N, int S, int E>
-struct reduce_add_impl
-{
-    template <typename F>
-    static T reduce(F add_fn, vec_n_t<T, N> vec)
-    {
-        constexpr int H            = (S + E + 1) / 2;
-        const auto    reduced_high = reduce_add_impl < T, N, S, H - 1 >::reduce(add_fn, vec);
-        const auto    reduced_low  = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec);
-        return add_fn(reduced_high, reduced_low);
-    }
-};
-template <typename T, int N, int I>
-struct reduce_add_impl<T, N, I, I>
-{
-    template <typename F>
-    static T reduce(F /*add_fn*/, vec_n_t<T, N> vec)
-    {
-        return vget_lane<I>(vec);
-    }
-};
-template <typename V, typename F>
-elem_type_t<V> reduce_add(F add_fn, V vec)
-{
-    constexpr int N = vec_size_of(vec);
-    return reduce_add_impl < elem_type_t<V>, N, 0, N - 1 >::reduce(add_fn, vec);
-}
-
 template <bool is_log>
 void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
 {
     const int start_x     = in.info()->valid_region().anchor.x();
     const int input_width = in.info()->valid_region().shape.x();
 
-    const float scale_beta = -beta * in.info()->quantization_info().uniform().scale;
+    const float scale_beta     = -beta * in.info()->quantization_info().uniform().scale;
+    const auto  scale_beta_vec = vdupq_n_f32(scale_beta);
 
-    Iterator in_it(&in, window);
-    Iterator max_it(&max, window);
-    Iterator out_it(&out, window);
+    Iterator      in_it(&in, window);
+    Iterator      max_it(&max, window);
+    Iterator      out_it(&out, window);
+    constexpr int vec_size = 16;
 
     execute_window_loop(window, [&](const Coordinates &)
     {
@@ -629,57 +297,73 @@
         {
             /* Get max value */
             const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
-            const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);
+            const auto vec_max = vdupq_n_u8(max_val);
 
             /* Init sum to zero */
-            auto vec_sum = vdup_n<float32x4x4_t>(0.f);
+            float32x4x4_t vec_sum =
+            {
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+            };
 
             /* Loop over row and compute exponentials and sum */
-            int           i        = 0;
-            constexpr int vec_size = vec_size_of(vec_max);
-
-            for(; i <= (input_width - vec_size); i += vec_size)
+            int x = 0;
+            for(; x <= (input_width - vec_size); x += vec_size)
             {
-                auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
-                vec_elements      = vsubq_u8(vec_max, vec_elements);
-
-                auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
+                auto vec_elements     = wrapper::vloadq(in_ptr + x);
+                vec_elements          = vsubq_u8(vec_max, vec_elements);
+                auto vec_elements_flt = convert_uint8x16_to_float32x4x4(vec_elements);
 
                 if(is_log)
                 {
-                    vec_elements_flt = vmul_n(vec_elements_flt, scale_beta);
-                    vec_sum          = vadd(vec_sum, vexp(vec_elements_flt));
+                    vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+                    vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+                    vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+                    vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
                 }
                 else
                 {
-                    vec_elements_flt = vexp(vmul_n(vec_elements_flt, scale_beta));
-                    vec_sum          = vadd(vec_sum, vec_elements_flt);
+                    vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+                    vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+                    vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+                    vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+                    vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+                    vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+                    vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+                    vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
                 }
-                vst4q_f32(tmp_ptr + i, vec_elements_flt);
+
+                vst4q_f32(tmp_ptr + x, vec_elements_flt);
             }
 
             /* Reduce sum */
-            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
-                                               vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
-            const auto sum_8_byte = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
-            sum                   = reduce_add(std::plus<float>(), sum_8_byte);
+            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
+            auto       sum_res     = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
+            sum_res                = vpadd_f32(sum_res, sum_res);
+            sum                    = wrapper::vgetlane(sum_res, 0);
 
             /* Run remaining elements */
-            for(; i < input_width; ++i)
+            for(; x < input_width; ++x)
             {
                 float element{};
                 if(is_log)
                 {
-                    element = (max_val - in_ptr[i]) * scale_beta;
+                    element = (max_val - in_ptr[x]) * scale_beta;
                     sum += std::exp(element);
                 }
                 else
                 {
-                    element = std::exp((max_val - in_ptr[i]) * scale_beta);
+                    element = std::exp((max_val - in_ptr[x]) * scale_beta);
                     sum += element;
                 }
 
-                tmp_ptr[i] = element;
+                tmp_ptr[x] = element;
             }
 
             if(!is_log)
@@ -691,35 +375,45 @@
         /* Normalize exponentials */
         {
             /* Loop over row and compute softmax */
-            int i = 0;
+            int x = 0;
+            for(; x <= (input_width - vec_size); x += vec_size)
             {
-                constexpr int vec_size = 16;
-
-                for(; i <= (input_width - vec_size); i += vec_size)
-                {
-                    float32x4x4_t            vec_in = vld4q_f32(tmp_ptr + i);
-                    vec_16_byte_t<qasymm8_t> normalized_value{};
-                    if(is_log)
-                    {
-                        normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vsub_n(vec_in, sum));
-                    }
-                    else
-                    {
-                        normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
-                    }
-                    vst(out_ptr + i, normalized_value);
-                }
-            }
-            /* Run remaining elements */
-            for(; i < input_width; ++i)
-            {
+                float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
+                uint8x16_t    normalized_value{};
                 if(is_log)
                 {
-                    out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] - sum);
+                    const float32x4x4_t sub =
+                    {
+                        vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
+                        vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
+                        vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
+                        vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
+                    };
+                    convert_float32x4x4_to_unit8x16(sub, normalized_value);
                 }
                 else
                 {
-                    out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
+                    const float32x4x4_t mul =
+                    {
+                        vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
+                        vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
+                        vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
+                        vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
+                    };
+                    convert_float32x4x4_to_unit8x16(mul, normalized_value);
+                }
+                vst1q_u8(out_ptr + x, normalized_value);
+            }
+            /* Run remaining elements */
+            for(; x < input_width; ++x)
+            {
+                if(is_log)
+                {
+                    out_ptr[x] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[x] - sum);
+                }
+                else
+                {
+                    out_ptr[x] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[x] * sum_inversed);
                 }
             }
         }
@@ -738,6 +432,12 @@
     Iterator max_it(&max, window);
     Iterator out_it(&out, window);
 
+    /** NEON vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+    constexpr int vec_size   = 16 / sizeof(T);
+    const int     sum_stages = log2(vec_size / 2);
+
     execute_window_loop(window, [&](const Coordinates &)
     {
         /* Get pointers */
@@ -752,53 +452,54 @@
         {
             /* Get max value */
             const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
-            const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);
+            const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});
 
             /* Init sum to zero */
-            auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);
+            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
 
             /* Loop over row and compute exponentials and sum */
-            int           i        = 0;
-            constexpr int vec_size = vec_size_of(vec_sum);
-
-            for(; i <= (input_width - vec_size); i += vec_size)
+            int x = 0;
+            for(; x <= (input_width - vec_size); x += vec_size)
             {
-                auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
-                vec_elements      = vsub(vec_elements, vec_max);
+                auto vec_elements = wrapper::vloadq(in_ptr + x);
+                vec_elements      = wrapper::vsub(vec_elements, vec_max);
                 if(is_log)
                 {
-                    vec_elements = vmul_n(vec_elements, static_cast<T>(beta));
-                    vec_sum      = vadd(vec_sum, vexp(vec_elements));
+                    vec_elements = wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
+                    vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                 }
                 else
                 {
-                    vec_elements = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
-                    vec_sum      = vadd(vec_sum, vec_elements);
+                    vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
+                    vec_sum      = wrapper::vadd(vec_sum, vec_elements);
                 }
-                vst(tmp_ptr + i, vec_elements);
+                wrapper::vstore(tmp_ptr + x, vec_elements);
             }
 
             /* Reduce sum */
-            const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
-            sum                   = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);
+            auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
+            for(int i = 0; i < sum_stages; ++i)
+            {
+                sum_res = wrapper::vpadd(sum_res, sum_res);
+            }
+            sum = wrapper::vgetlane(sum_res, 0);
 
             /* Run remaining elements */
-
-            for(; i < input_width; ++i)
+            for(; x < input_width; ++x)
             {
                 T element{};
 
                 if(is_log)
                 {
-                    element = (in_ptr[i] - max_val) * beta;
+                    element = (in_ptr[x] - max_val) * beta;
                     sum += std::exp(element);
                 }
                 else
                 {
-                    element = std::exp((in_ptr[i] - max_val) * beta);
+                    element = std::exp((in_ptr[x] - max_val) * beta);
                     sum += element;
                 }
-                tmp_ptr[i] = element;
+                tmp_ptr[x] = element;
             }
 
             if(!is_log)
@@ -810,36 +511,31 @@
         /* Normalize exponentials */
         {
             /* Loop over row and compute softmax */
-            int i = 0;
-
+            int x = 0;
+            for(; x <= (input_width - vec_size); x += vec_size)
             {
-                constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
-
-                for(; i <= (input_width - vec_size); i += vec_size)
-                {
-                    auto             vec_in = vld<vec_16_byte_t<T>>(tmp_ptr + i);
-                    vec_16_byte_t<T> normalized_value{};
-                    if(is_log)
-                    {
-                        normalized_value = vsub(vec_in, vdup_n<vec_16_byte_t<T>>(sum));
-                    }
-                    else
-                    {
-                        normalized_value = vmul_n(vec_in, sum_inversed);
-                    }
-                    vst(out_ptr + i, normalized_value);
-                }
-            }
-            /* Run remaining elements */
-            for(; i < input_width; ++i)
-            {
+                auto vec_in           = wrapper::vloadq(tmp_ptr + x);
+                auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
                 if(is_log)
                 {
-                    out_ptr[i] = tmp_ptr[i] - sum;
+                    normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
                 }
                 else
                 {
-                    out_ptr[i] = tmp_ptr[i] * sum_inversed;
+                    normalized_value = wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
+                }
+                wrapper::vstore(out_ptr + x, normalized_value);
+            }
+            /* Run remaining elements */
+            for(; x < input_width; ++x)
+            {
+                if(is_log)
+                {
+                    out_ptr[x] = tmp_ptr[x] - sum;
+                }
+                else
+                {
+                    out_ptr[x] = tmp_ptr[x] * sum_inversed;
                 }
             }
         }
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index 8f91b51..7f8c622 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -162,12 +162,12 @@
     validate(Accessor(_target), _reference, tolerance_f16);
 }
 FIXTURE_DATA_TEST_CASE(RunSmall4D, NESoftmaxLayerFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::Small4DShapes(),
-                                                                                                                   framework::dataset::make("DataType", DataType::F32)),
+                                                                                                                   framework::dataset::make("DataType", DataType::F16)),
                                                                                                                    framework::dataset::make("Beta", { 1.0f, 2.0f })),
                                                                                                            framework::dataset::make("Axis", { 1, 2, 3 })))
 {
     // Validate output
-    validate(Accessor(_target), _reference, tolerance_f32);
+    validate(Accessor(_target), _reference, tolerance_f16);
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::SoftmaxLayerLargeShapes(),
                                                                                                                        framework::dataset::make("DataType", DataType::F16)),