IVGCVSW-798 Add Softmax NEON support for QASYMM8

Change-Id: I4f2cca52caf210fdb7d6bb7e9436ac51cb5088b4
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112398
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index b13fb0e..13d87a0 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -33,285 +33,433 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/utility.h"
 
 #include <algorithm>
 #include <arm_neon.h>
 #include <cfloat>
+#include <functional>
 
-using namespace arm_compute;
+namespace arm_compute
+{
+template <typename T, int N>
+struct vec_n_type;
+
+#define DECLARE_NEON_VEC_TYPE(T, N, V) \
+    template <>                        \
+    struct vec_n_type<T, N>            \
+    {                                  \
+        using type = V;                \
+    };
+
+DECLARE_NEON_VEC_TYPE(uint8_t, 16, uint8x16_t)
+DECLARE_NEON_VEC_TYPE(uint8_t, 8, uint8x8_t)
+
+DECLARE_NEON_VEC_TYPE(int8_t, 16, int8x16_t)
+DECLARE_NEON_VEC_TYPE(int8_t, 8, int8x8_t)
+
+DECLARE_NEON_VEC_TYPE(uint16_t, 8, uint16x8_t)
+DECLARE_NEON_VEC_TYPE(uint16_t, 4, uint16x4_t)
+
+DECLARE_NEON_VEC_TYPE(int16_t, 8, int16x8_t)
+DECLARE_NEON_VEC_TYPE(int16_t, 4, int16x4_t)
+
+DECLARE_NEON_VEC_TYPE(int32_t, 4, int32x4_t)
+DECLARE_NEON_VEC_TYPE(int32_t, 2, int32x2_t)
+
+DECLARE_NEON_VEC_TYPE(uint32_t, 4, uint32x4_t)
+DECLARE_NEON_VEC_TYPE(uint32_t, 2, uint32x2_t)
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+DECLARE_NEON_VEC_TYPE(float16_t, 8, float16x8_t)
+DECLARE_NEON_VEC_TYPE(float16_t, 4, float16x4_t)
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+
+DECLARE_NEON_VEC_TYPE(float, 4, float32x4_t)
+DECLARE_NEON_VEC_TYPE(float, 2, float32x2_t)
+
+template <typename T, int N>
+using vec_n_t = typename vec_n_type<T, N>::type;
+
+template <typename T, int N>
+using vec_n_byte_t = vec_n_t < T, N / sizeof(T) >;
+
+template <typename T>
+using vec_16_byte_t = vec_n_byte_t<T, 16>;
+
+template <typename T>
+using vec_8_byte_t = vec_n_byte_t<T, 8>;
+
+template <typename T>
+using const_ptr_t = const T *;
+
+template <typename T>
+using ptr_t = T *;
+
+#define FORWARD_DECLARE_VGET_LANE_FOR_TYPE(TYPE) \
+    template <int lane>                          \
+    TYPE vget_lane(vec_8_byte_t<TYPE> vec);      \
+    template <int lane>                          \
+    TYPE vget_lane(vec_16_byte_t<TYPE> vec);
+
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint8_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int8_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint16_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int16_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(uint32_t)
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(int32_t)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float16_t)
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+FORWARD_DECLARE_VGET_LANE_FOR_TYPE(float)
+template <int lane>
+float vget_lane(float32x4x4_t vec);
+
+template <typename V>
+using elem_type_t = decltype(vget_lane<0>(std::declval<V>()));
+
+template <typename V>
+constexpr size_t vec_size_of(const V &vec)
+{
+    return sizeof(vec) / sizeof(elem_type_t<V>);
+}
+
+template <typename V>
+V vdup_n(elem_type_t<V> val);
+template <typename V>
+V vld(const_ptr_t<elem_type_t<V>> ptr);
+
+#define DECLARE_NEON_FUNCTIONS_FOR_TYPE(TYPE, TAG)                                \
+    template <>                                                                   \
+    inline vec_8_byte_t<TYPE> vdup_n<vec_8_byte_t<TYPE>>(TYPE val)                \
+    {                                                                             \
+        return vdup_n_##TAG(val);                                                 \
+    }                                                                             \
+    template <>                                                                   \
+    inline vec_16_byte_t<TYPE> vdup_n<vec_16_byte_t<TYPE>>(TYPE val)              \
+    {                                                                             \
+        return vdupq_n_##TAG(val);                                                \
+    }                                                                             \
+    template <>                                                                   \
+    inline vec_8_byte_t<TYPE> vld<vec_8_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)      \
+    {                                                                             \
+        return vld1_##TAG(ptr);                                                   \
+    }                                                                             \
+    template <>                                                                   \
+    inline vec_16_byte_t<TYPE> vld<vec_16_byte_t<TYPE>>(const_ptr_t<TYPE> ptr)    \
+    {                                                                             \
+        return vld1q_##TAG(ptr);                                                  \
+    }                                                                             \
+    inline void vst(ptr_t<TYPE> ptr, vec_8_byte_t<TYPE> vec)                      \
+    {                                                                             \
+        vst1_##TAG(ptr, vec);                                                     \
+    }                                                                             \
+    inline void vst(ptr_t<TYPE> ptr, vec_16_byte_t<TYPE> vec)                     \
+    {                                                                             \
+        vst1q_##TAG(ptr, vec);                                                    \
+    }                                                                             \
+    inline vec_16_byte_t<TYPE> vmax(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
+    {                                                                             \
+        return vmaxq_##TAG(a, b);                                                 \
+    }                                                                             \
+    inline vec_8_byte_t<TYPE> vpmax(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)   \
+    {                                                                             \
+        return vpmax_##TAG(a, b);                                                 \
+    }                                                                             \
+    inline vec_8_byte_t<TYPE> vget_low(vec_16_byte_t<TYPE> vec)                   \
+    {                                                                             \
+        return vget_low_##TAG(vec);                                               \
+    }                                                                             \
+    inline vec_8_byte_t<TYPE> vget_high(vec_16_byte_t<TYPE> vec)                  \
+    {                                                                             \
+        return vget_high_##TAG(vec);                                              \
+    }                                                                             \
+    template <int lane>                                                           \
+    inline TYPE vget_lane(vec_8_byte_t<TYPE> vec)                                 \
+    {                                                                             \
+        static_assert(lane >= 0, "lane is out of bounds");                        \
+        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
+        return vget_lane_##TAG(vec, lane);                                        \
+    }                                                                             \
+    template <int lane>                                                           \
+    inline TYPE vget_lane(vec_16_byte_t<TYPE> vec)                                \
+    {                                                                             \
+        static_assert(lane >= 0, "lane is out of bounds");                        \
+        static_assert(lane < vec_size_of(vec), "lane is out of bounds");          \
+        return vgetq_lane_##TAG(vec, lane);                                       \
+    }
+
+template <typename T>
+T sqadd(T a, T b);
+template <typename T>
+T sqsub(T a, T b);
+template <typename T>
+T sqmul(T a, T b, int fixed_point_position);
+
+#define DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(TYPET, TYPEU, TAGT, TAGU)                                        \
+    inline vec_8_byte_t<TYPET> vqsub(vec_8_byte_t<TYPET> a, vec_8_byte_t<TYPET> b)                              \
+    {                                                                                                           \
+        return vqsub_##TAGT(a, b);                                                                              \
+    }                                                                                                           \
+    inline vec_8_byte_t<TYPEU> vqadd(vec_8_byte_t<TYPEU> a, vec_8_byte_t<TYPEU> b)                              \
+    {                                                                                                           \
+        return vqadd_##TAGU(a, b);                                                                              \
+    }                                                                                                           \
+    inline vec_16_byte_t<TYPEU> vqadd(vec_16_byte_t<TYPEU> a, vec_16_byte_t<TYPEU> b)                           \
+    {                                                                                                           \
+        return vqaddq_##TAGU(a, b);                                                                             \
+    }                                                                                                           \
+    inline vec_8_byte_t<TYPET> vqexp(vec_8_byte_t<TYPET> vec, int fixed_point_position)                         \
+    {                                                                                                           \
+        return vqexp_q##TAGT(vec, fixed_point_position);                                                        \
+    }                                                                                                           \
+    inline auto vmovl(vec_8_byte_t<TYPET> vec)->decltype(vmovl_##TAGT(vec))                                     \
+    {                                                                                                           \
+        return vmovl_##TAGT(vec);                                                                               \
+    }                                                                                                           \
+    inline vec_16_byte_t<TYPET> vqrecip(vec_16_byte_t<TYPET> vec, int fixed_point_position)                     \
+    {                                                                                                           \
+        return vqrecipq_q##TAGT(vec, fixed_point_position);                                                     \
+    }                                                                                                           \
+    inline vec_16_byte_t<TYPET> vqmul(vec_16_byte_t<TYPET> a, vec_16_byte_t<TYPET> b, int fixed_point_position) \
+    {                                                                                                           \
+        return vqmulq_q##TAGT(a, b, fixed_point_position);                                                      \
+    }                                                                                                           \
+    template <>                                                                                                 \
+    inline TYPEU sqadd<TYPEU>(TYPEU a, TYPEU b)                                                                 \
+    {                                                                                                           \
+        return sqadd_q##TAGU(a, b);                                                                             \
+    }                                                                                                           \
+    inline TYPET sqexp(TYPET val, int fixed_point_position)                                                     \
+    {                                                                                                           \
+        return sqexp_q##TAGT(val, fixed_point_position);                                                        \
+    }                                                                                                           \
+    template <>                                                                                                 \
+    inline TYPET sqsub<TYPET>(TYPET a, TYPET b)                                                                 \
+    {                                                                                                           \
+        return sqsub_q##TAGT(a, b);                                                                             \
+    }                                                                                                           \
+    template <>                                                                                                 \
+    inline TYPET sqmul<TYPET>(TYPET a, TYPET b, int fixed_point_position)                                       \
+    {                                                                                                           \
+        return sqmul_q##TAGT(a, b, fixed_point_position);                                                       \
+    }
+
+#define DECLARE_NEON_FUNCTIONS_FOR_FLOAT(TYPE, TAG)                               \
+    inline vec_8_byte_t<TYPE> vadd(vec_8_byte_t<TYPE> a, vec_8_byte_t<TYPE> b)    \
+    {                                                                             \
+        return vadd_##TAG(a, b);                                                  \
+    }                                                                             \
+    inline vec_16_byte_t<TYPE> vadd(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
+    {                                                                             \
+        return vaddq_##TAG(a, b);                                                 \
+    }                                                                             \
+    inline vec_16_byte_t<TYPE> vsub(vec_16_byte_t<TYPE> a, vec_16_byte_t<TYPE> b) \
+    {                                                                             \
+        return vsubq_##TAG(a, b);                                                 \
+    }                                                                             \
+    inline vec_16_byte_t<TYPE> vexp(vec_16_byte_t<TYPE> vec)                      \
+    {                                                                             \
+        return vexpq_##TAG(vec);                                                  \
+    }                                                                             \
+    inline vec_16_byte_t<TYPE> vmul_n(vec_16_byte_t<TYPE> vec, TYPE val)          \
+    {                                                                             \
+        return vmulq_n_##TAG(vec, val);                                           \
+    }
+
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint8_t, u8)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(int8_t, s8)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint16_t, u16)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(int16_t, s16)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(uint32_t, u32)
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(int32_t, s32)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(float16_t, f16)
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+DECLARE_NEON_FUNCTIONS_FOR_TYPE(float, f32)
+
+DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(int8_t, int16_t, s8, s16)
+DECLARE_NEON_FUNCTIONS_FOR_FIXED_POINT(int16_t, int32_t, s16, s32)
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float16_t, f16)
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+DECLARE_NEON_FUNCTIONS_FOR_FLOAT(float, f32)
+
+template <typename VO, typename VI>
+VO vcvt(VI vec);
+
+template <> // Widen 16 x u8 into four float32x4_t (u8 -> u16 -> u32 -> f32); lane order is preserved
+inline float32x4x4_t vcvt<float32x4x4_t>(uint8x16_t vec)
+{
+    const auto    low  = vmovl_u8(vget_low(vec));  // lanes 0..7 widened to u16
+    const auto    high = vmovl_u8(vget_high(vec)); // lanes 8..15 widened to u16
+    float32x4x4_t res  = { {
+            vcvtq_f32_u32(vmovl_u16(vget_low(low))),
+            vcvtq_f32_u32(vmovl_u16(vget_high(low))),
+            vcvtq_f32_u32(vmovl_u16(vget_low(high))),
+            vcvtq_f32_u32(vmovl_u16(vget_high(high)))
+        }
+    };
+    return res;
+}
+
+template <> // Narrow four float32x4_t into 16 x u8: f32 -> u32 rounds toward zero, each narrowing step saturates
+inline uint8x16_t vcvt<uint8x16_t>(float32x4x4_t vec)
+{
+    uint16x8x2_t resU16 = { {
+            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[0])),
+            vqmovn_u32(vcvtq_u32_f32(vec.val[1]))),
+            vcombine_u16(vqmovn_u32(vcvtq_u32_f32(vec.val[2])),
+            vqmovn_u32(vcvtq_u32_f32(vec.val[3])))
+        }
+    };
+    // Saturating u16 -> u8 narrowing of both halves, lanes kept in input order
+    uint8x16_t res = vcombine_u8(vqmovn_u16(resU16.val[0]), vqmovn_u16(resU16.val[1]));
+    return res;
+}
+
+inline float32x4x4_t vexp(float32x4x4_t vec) // Lane-wise vexpq_f32 applied to each of the four sub-vectors
+{
+    float32x4x4_t res = { {
+            vexpq_f32(vec.val[0]),
+            vexpq_f32(vec.val[1]),
+            vexpq_f32(vec.val[2]),
+            vexpq_f32(vec.val[3])
+        }
+    };
+    return res;
+}
+
+template <> // Broadcast val into all 16 lanes of a float32x4x4_t
+inline float32x4x4_t vdup_n<float32x4x4_t>(float val)
+{
+    float32x4x4_t res = { {
+            vdupq_n_f32(val),
+            vdupq_n_f32(val),
+            vdupq_n_f32(val),
+            vdupq_n_f32(val)
+        }
+    };
+    return res;
+}
+
+inline float32x4x4_t vmul_n(float32x4x4_t vec, float val) // Multiply all 16 lanes by the scalar val
+{
+    float32x4x4_t res = { {
+            vmulq_n_f32(vec.val[0], val),
+            vmulq_n_f32(vec.val[1], val),
+            vmulq_n_f32(vec.val[2], val),
+            vmulq_n_f32(vec.val[3], val)
+        }
+    };
+    return res;
+}
+
+inline float32x4x4_t vadd(float32x4x4_t a, float32x4x4_t b) // Lane-wise a + b over all four sub-vectors
+{
+    float32x4x4_t res = { {
+            vaddq_f32(a.val[0], b.val[0]),
+            vaddq_f32(a.val[1], b.val[1]),
+            vaddq_f32(a.val[2], b.val[2]),
+            vaddq_f32(a.val[3], b.val[3])
+        }
+    };
+    return res;
+}
 
 namespace
 {
-Status validate_arguments_logits_1d_max(const ITensorInfo *input, const ITensorInfo *output)
+Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
-
-    // Checks performed when output is configured
-    if(output->total_size() != 0)
-    {
-        // Softmax across the x dimension
-        TensorShape output_shape{ input->tensor_shape() };
-        output_shape.set(0, 1);
-
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
-    }
-
-    return Status{};
-}
-
-std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo *input, ITensorInfo *output)
-{
-    // Configure kernel window
-    constexpr unsigned int num_elems_written_per_row = 1;
-    const int              input_width               = input->valid_region().shape.x();
-
-    unsigned int           num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
-    Window                 win                               = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
-    AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
-    bool                   window_changed = false;
-
-    if(output->total_size() != 0)
-    {
-        AccessWindowHorizontal output_access(output, 0, num_elems_written_per_row, 1.f / input_width);
-        window_changed = update_window_and_padding(win, input_access, output_access);
-        output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
-    }
-    else
-    {
-        window_changed = update_window_and_padding(win, input_access);
-    }
-
-    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
-    return std::make_pair(err, win);
-}
-
-Status validate_arguments_logits_1d_shift_exp_sum(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
-{
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, max, sum, output);
-    ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input->data_type()));
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
-
-    // Checks performed when output is configured
-    if(output->total_size() != 0)
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
-    }
-
-    // Checks performed when sum is configured
-    if(sum->total_size() != 0)
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, max, sum);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(max, sum);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, max, sum);
-    }
-
-    return Status{};
-}
-
-std::pair<Status, Window> validate_and_configure_window_logits_1d_shift_exp_sum(ITensorInfo *input, ITensorInfo *max, ITensorInfo *output, ITensorInfo *sum)
-{
-    unsigned int num_elems_processed_per_iteration = input->valid_region().shape.x();
-
-    // Configure kernel window
-    Window                 win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
-    AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
-    AccessWindowHorizontal max_access(max, 0, 1);
-    AccessWindowHorizontal sum_access(sum, 0, 1);
-    bool                   window_changed = false;
-
-    if(output->total_size() != 0)
-    {
-        AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
-        window_changed = update_window_and_padding(win, input_access, max_access, output_access, sum_access);
-        output_access.set_valid_region(win, input->valid_region());
-    }
-    else
-    {
-        window_changed = update_window_and_padding(win, input_access, max_access, sum_access);
-    }
-
-    sum_access.set_valid_region(win, ValidRegion(Coordinates(), sum->tensor_shape()));
-
-    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
-    return std::make_pair(err, win);
-}
-
-Status validate_arguments_logits_1d_norm(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
-{
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, sum, output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::S32, DataType::F16, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, sum);
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, sum);
-
-    // Checks performed when output is configured
-    if(output->total_size() != 0)
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
-    }
-
-    return Status{};
-}
-
-std::pair<Status, Window> validate_and_configure_window_logits_1d_norm(ITensorInfo *input, ITensorInfo *sum, ITensorInfo *output)
-{
-    // Configure kernel window
-    unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->data_type());
-    Window       win                               = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
-
-    AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
-    AccessWindowStatic     sum_access(sum, 0, 0, 1, sum->dimension(1));
-    bool                   window_changed = false;
-
-    if(output->total_size() != 0)
-    {
-        AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
-
-        window_changed = update_window_and_padding(win, input_access, sum_access, output_access);
-
-        output_access.set_valid_region(win, input->valid_region());
-    }
-    else
-    {
-        window_changed = update_window_and_padding(win, input_access, sum_access);
-    }
-    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
-    return std::make_pair(err, win);
-}
-
-void logits_1d_max_qs8(const ITensor *in, ITensor *out, const Window &window)
-{
-    Window in_slice = window.first_slice_window_1D();
-
-    Window window_max(window);
-    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
-    Window max_slice = window_max.first_slice_window_1D();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator output(out, max_slice);
-
-        qint8x16_t vec_max = vdupq_n_s8(std::numeric_limits<qint8_t>::lowest());
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            const auto       in_ptr        = reinterpret_cast<const qint8_t *>(input.ptr());
-            const qint8x16_t current_value = vld1q_qs8(in_ptr);
-            vec_max                        = vmaxq_qs8(vec_max, current_value);
-        },
-        input);
-
-        qint8x8_t carry_max = vpmax_qs8(vget_high_s8(vec_max), vget_low_s8(vec_max));
-        carry_max           = vpmax_qs8(carry_max, carry_max);
-        carry_max           = vpmax_qs8(carry_max, carry_max);
-        carry_max           = vpmax_qs8(carry_max, carry_max);
-
-        *(reinterpret_cast<qint8_t *>(output.ptr())) = vget_lane_s8(carry_max, 0);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
-}
-void logits_1d_max_qs16(const ITensor *in, ITensor *out, const Window &window)
-{
-    Window in_slice = window.first_slice_window_1D();
-
-    Window window_max(window);
-    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
-    Window max_slice = window_max.first_slice_window_1D();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator output(out, max_slice);
-
-        qint16x8_t vec_max = vdupq_n_qs16(std::numeric_limits<qint16_t>::lowest());
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            const auto       in_ptr        = reinterpret_cast<const qint16_t *>(input.ptr());
-            const qint16x8_t current_value = vld1q_qs16(in_ptr);
-            vec_max                        = vmaxq_qs16(vec_max, current_value);
-        },
-        input);
-
-        qint16x4_t carry_max = vpmax_qs16(vget_high_qs16(vec_max), vget_low_qs16(vec_max));
-        carry_max            = vpmax_qs16(carry_max, carry_max);
-        carry_max            = vpmax_qs16(carry_max, carry_max);
-
-        *(reinterpret_cast<qint16_t *>(output.ptr())) = vget_lane_s16(carry_max, 0);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
-}
-
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window)
-{
-    Window in_slice = window.first_slice_window_1D();
-
-    Window window_max(window);
-    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
-    Window max_slice = window_max.first_slice_window_1D();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator output(out, max_slice);
-
-        float16x8_t vec_max = vdupq_n_f16(std::numeric_limits<float16_t>::lowest());
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            const auto        in_ptr        = reinterpret_cast<const float16_t *>(input.ptr());
-            const float16x8_t current_value = vld1q_f16(in_ptr);
-            vec_max                         = vmaxq_f16(vec_max, current_value);
-        },
-        input);
-
-        float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max));
-        carry_max             = vpmax_f16(carry_max, carry_max);
-        carry_max             = vpmax_f16(carry_max, carry_max);
-
-        *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
-}
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+#else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F32);
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
-void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
-{
-    Window in_slice = window.first_slice_window_1D();
-
-    Window window_max(window);
-    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
-    Window max_slice = window_max.first_slice_window_1D();
-
-    do
+    // Validate in case of configured output
+    if(output.total_size() != 0)
     {
-        Iterator input(in, in_slice);
-        Iterator output(out, max_slice);
-
-        float32x4_t vec_max = vdupq_n_f32(-FLT_MAX);
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            const auto        in_ptr        = reinterpret_cast<const float *>(input.ptr());
-            const float32x4_t current_value = vld1q_f32(in_ptr);
-            vec_max                         = vmaxq_f32(vec_max, current_value);
-        },
-        input);
-
-        float32x2_t carry_max = vpmax_f32(vget_high_f32(vec_max), vget_low_f32(vec_max));
-        carry_max             = vpmax_f32(carry_max, carry_max);
-
-        *(reinterpret_cast<float *>(output.ptr())) = vget_lane_f32(carry_max, 0);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output.tensor_shape(), TensorShape(input.tensor_shape()).set(0, 1));
     }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
+
+    return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window_logits_1d_max(ITensorInfo &input, ITensorInfo &output)
+{
+    // Softmax across the x dimension
+    const TensorShape output_shape = TensorShape(input.tensor_shape()).set(0, 1);
+    // Output auto initialization if not yet initialized
+    auto_init_if_empty(output, output_shape, 1, input.data_type(), input.fixed_point_position(), input.quantization_info());
+
+    // Configure kernel window
+    const int input_width                       = input.valid_region().shape.x();
+    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input.data_type());
+    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);
+
+    const ValidRegion out_valid_region(ValidRegion(input.valid_region()).set(0, 0, 1));
+    output.set_valid_region(out_valid_region);
+
+    Window win = calculate_max_window(output);
+
+    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), num_elems_read_per_iteration);
+    AccessWindowHorizontal output_access(&output, 0, 1);
+
+    const bool window_changed = update_window_and_padding(win, input_access, output_access);
+
+    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+    return std::make_pair(err, win);
+}
+
+template <typename V>
+auto reduce_max(V vec) -> elem_type_t<V> // horizontal max across every lane of a 128-bit vector
+{
+    constexpr int lanes = vec_size_of(vec);
+    // Fold the two 64-bit halves together: lanes / 2 pairwise maxima remain
+    auto folded = vpmax(vget_high(vec), vget_low(vec));
+    // Every further pairwise max halves the number of candidate lanes
+    for(int live = lanes / 2; live > 1; live /= 2)
+    {
+        folded = vpmax(folded, folded);
+    }
+    // Lane 0 now holds the maximum over all lanes of vec
+    return vget_lane<0>(folded);
+}
+
+template <typename T>
+void logits_1d_max(const ITensor &in, ITensor &out, const Window &window)
+{
+    // Each row spans the X extent of the input's valid region
+    const auto   row_start = in.info()->valid_region().anchor.x();
+    const size_t row_len   = in.info()->valid_region().shape.x();
+    Iterator     src(&in, window);
+    Iterator     dst(&out, window);
+
+    execute_window_loop(window, [&](const Coordinates &)
+    {
+        const T *const row_begin = reinterpret_cast<const T *>(src.ptr()) + row_start;
+        const T *const row_end   = row_begin + row_len;
+
+        // Seed every lane with the lowest value of T, then fold in 16-byte chunks.
+        // The last chunk may read past row_end (up to one whole vector); the window
+        // configuration reserves that padding, and its contents are assumed to be
+        // neutral for max (e.g. a replicated border) - TODO confirm at the call site.
+        auto running_max = vdup_n<vec_16_byte_t<T>>(std::numeric_limits<T>::lowest());
+        for(const T *chunk = row_begin; chunk < row_end; chunk += vec_size_of(running_max))
+        {
+            running_max = vmax(running_max, vld<vec_16_byte_t<T>>(chunk));
+        }
+
+        // Reduce the lanes and write one scalar per row
+        const T row_max                   = reduce_max(running_max);
+        *reinterpret_cast<T *>(dst.ptr()) = row_max;
+    },
+    src, dst);
 }
 } // namespace
 
@@ -328,54 +476,54 @@
 void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
-
-    // Softmax across the x dimension
-    TensorShape output_shape{ input->info()->tensor_shape() };
-    output_shape.set(0, 1);
-
-    // Output auto initialization if not yet initialized
-    auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
-
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), output->info()); // the tensor infos are dereferenced below
     // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(input->info(), output->info()));
-
-    const int    input_width                       = input->info()->valid_region().shape.x();
-    unsigned int num_elems_processed_per_iteration = 16 / data_size_from_type(input->info()->data_type());
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_max(*input->info(), *output->info()));
+    // Configure kernel window (update_window_and_padding inside may extend the tensors' padding)
+    auto win_config = validate_and_configure_window_logits_1d_max(*input->info(), *output->info());
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
 
     switch(input->info()->data_type())
     {
+        case DataType::QASYMM8: // dispatch to the element-type specialisation of logits_1d_max
+            _func = &logits_1d_max<qasymm8_t>;
+            break;
         case DataType::QS8:
-            _func = &logits_1d_max_qs8;
+            _func = &logits_1d_max<qint8_t>;
             break;
         case DataType::QS16:
-            _func = &logits_1d_max_qs16;
+            _func = &logits_1d_max<qint16_t>;
             break;
-        case DataType::F32:
-            _func = &logits_1d_max_f32;
-            break;
-        case DataType::F16:
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            _func = &logits_1d_max_f16;
+        case DataType::F16:
+            _func = &logits_1d_max<float16_t>;
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        case DataType::F32:
+            _func = &logits_1d_max<float>;
+            break;
         default:
             ARM_COMPUTE_ERROR("Unsupported data type.");
     }
 
-    _input       = input;
-    _output      = output;
-    _border_size = BorderSize(0, num_elems_processed_per_iteration - (input_width % num_elems_processed_per_iteration), 0, 0);
+    _input  = input;
+    _output = output;
 
-    // Configure kernel window
-    auto win_config = validate_and_configure_window_logits_1d_max(input->info(), output->info());
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    const int input_width                       = input->info()->valid_region().shape.x();
+    const int num_elems_processed_per_iteration = 16U / data_size_from_type(input->info()->data_type()); // lanes per 16-byte vector
+    const int num_elems_read_per_iteration      = ceil_to_multiple(input_width, num_elems_processed_per_iteration);
+    // logits_1d_max reads whole vectors, so the row is rounded up to a multiple of the vector length and
+    _border_size = BorderSize(0, num_elems_read_per_iteration - input_width, 0, 0);
+    // the surplus is exposed as a right border. NOTE(review): keep in sync with the input access width used by the window configuration.
     INEKernel::configure(win_config.second);
 }
 
 Status NELogits1DMaxKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(input, output));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(input->clone().get(), output->clone().get()).first);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); // NOTE(review): ERROR_ON, not a RETURN_* macro, so null inputs are not reported via the Status
+    // Window validation runs on clones so the caller's tensor infos are not modified.
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_max(*input, *output));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_max(*input->clone(), *output->clone()).first);
 
     return Status{};
 }
@@ -387,297 +535,393 @@
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
     ARM_COMPUTE_ERROR_ON(_func == nullptr);
 
-    (*_func)(_input, _output, window);
+    (*_func)(*_input, *_output, window);
 }
 
 namespace
 {
-void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
+Status validate_arguments_logits_softmax(const ITensorInfo &input, const ITensorInfo &max,
+                                         const ITensorInfo &output, const float beta, const ITensorInfo &tmp)
 {
-    ARM_COMPUTE_UNUSED(beta);
-
-    Window window_max(window);
-    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
-
-    Window max_slice = window_max.first_slice_window_1D();
-    Window in_slice  = window.first_slice_window_1D();
-
-    constexpr int step                 = 8;
-    const int     long_steps           = in->info()->valid_region().shape.x() / step;
-    const int     small_steps          = in->info()->valid_region().shape.x() % step;
-    const int     fixed_point_position = in->info()->fixed_point_position();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator exp(out, in_slice);
-        Iterator _max(max, max_slice);
-        Iterator _sum(sum, max_slice);
-
-        // Get pointers
-        auto in_ptr  = reinterpret_cast<const qint8_t *>(input.ptr());
-        auto exp_ptr = reinterpret_cast<qint8_t *>(exp.ptr());
-
-        // Init sum to zero
-        qint16x8_t vec_sum_value = vdupq_n_qs16(0);
-
-        // Get max value
-        const auto      max_ptr = reinterpret_cast<const qint8_t *>(_max.ptr());
-        const qint8x8_t vec_max = vdup_n_qs8(*max_ptr);
-
-        // Run neon loop
-        for(int i = 0; i < long_steps; ++i)
-        {
-            qint8x8_t vec_elements = vld1_qs8(in_ptr);
-            vec_elements           = vqsub_qs8(vec_elements, vec_max);
-            vec_elements           = vqexp_qs8(vec_elements, fixed_point_position);
-
-            vst1_qs8(exp_ptr, vec_elements);
-            vec_sum_value = vqaddq_qs16(vec_sum_value, vmovl_s8(vec_elements));
-
-            in_ptr += step;
-            exp_ptr += step;
-        }
-        // Reduce sum
-        const qint16x4_t sum_red = vqadd_qs16(vget_low_s16(vec_sum_value), vget_high_s16(vec_sum_value));
-        const qint16_t   sum0    = sqadd_qs16(vget_lane_s16(sum_red, 0), vget_lane_s16(sum_red, 1));
-        const qint16_t   sum1    = sqadd_qs16(vget_lane_s16(sum_red, 2), vget_lane_s16(sum_red, 3));
-        qint16_t         sum     = sqadd_qs16(sum0, sum1);
-
-        // Run remaining elements
-        for(int i = 0; i < small_steps; ++i)
-        {
-            qint8_t element = sqexp_qs8(sqsub_qs8(in_ptr[i], *max_ptr), fixed_point_position);
-            exp_ptr[i]      = element;
-            sum             = sqadd_qs16(sum, element);
-        }
-
-        *(reinterpret_cast<qint8_t *>(_sum.ptr())) = sqmovn_qs16(sum);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
-}
-void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
-{
-    ARM_COMPUTE_UNUSED(beta);
-
-    Window window_max(window);
-    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
-
-    Window max_slice = window_max.first_slice_window_1D();
-    Window in_slice  = window.first_slice_window_1D();
-
-    constexpr int step                 = 4;
-    const int     long_steps           = in->info()->valid_region().shape.x() / step;
-    const int     small_steps          = in->info()->valid_region().shape.x() % step;
-    const int     fixed_point_position = in->info()->fixed_point_position();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator exp(out, in_slice);
-        Iterator _max(max, max_slice);
-        Iterator _sum(sum, max_slice);
-
-        // Get pointers
-        auto in_ptr  = reinterpret_cast<const qint16_t *>(input.ptr());
-        auto exp_ptr = reinterpret_cast<qint16_t *>(exp.ptr());
-
-        // Init sum to zero
-        qint32x4_t vec_sum_value = vdupq_n_qs32(0);
-
-        // Get max value
-        const auto       max_ptr = reinterpret_cast<const qint16_t *>(_max.ptr());
-        const qint16x4_t vec_max = vdup_n_qs16(*max_ptr);
-
-        // Run neon loop
-        for(int i = 0; i < long_steps; ++i)
-        {
-            qint16x4_t vec_elements = vld1_qs16(in_ptr);
-            vec_elements            = vqsub_qs16(vec_elements, vec_max);
-            vec_elements            = vqexp_qs16(vec_elements, fixed_point_position);
-
-            vst1_qs16(exp_ptr, vec_elements);
-            vec_sum_value = vqaddq_qs32(vec_sum_value, vmovl_s16(vec_elements));
-
-            in_ptr += step;
-            exp_ptr += step;
-        }
-        // Reduce sum
-        qint32x2_t carry_addition = vqadd_qs32(vget_high_s32(vec_sum_value), vget_low_s32(vec_sum_value));
-        qint32_t   sum            = vget_lane_s32(carry_addition, 0) + vget_lane_s32(carry_addition, 1);
-
-        // Run remaining elements
-        for(int i = 0; i < small_steps; ++i)
-        {
-            qint16_t element = sqexp_qs16(sqsub_qs16(in_ptr[i], *max_ptr), fixed_point_position);
-            exp_ptr[i]       = element;
-            sum              = sqadd_qs32(sum, element);
-        }
-
-        *(reinterpret_cast<qint16_t *>(_sum.ptr())) = sqmovn_qs32(sum);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
-}
-
+    // Check input: F16 is only accepted when FP16 vector arithmetic is compiled in
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
-{
-    Window window_max(window);
-    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
-
-    Window max_slice = window_max.first_slice_window_1D();
-    Window in_slice  = window.first_slice_window_1D();
-
-    constexpr int step        = 8;
-    const int     long_steps  = in->info()->valid_region().shape.x() / step;
-    const int     small_steps = in->info()->valid_region().shape.x() % step;
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator exp(out, in_slice);
-        Iterator _max(max, max_slice);
-        Iterator _sum(sum, max_slice);
-
-        // Get pointers
-        auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
-        auto exp_ptr = reinterpret_cast<float16_t *>(exp.ptr());
-
-        // Init sum to zero
-        float16x8_t vec_sum_value = vdupq_n_f16(0);
-
-        // Get max value
-        const auto        max_ptr = reinterpret_cast<const float16_t *>(_max.ptr());
-        const float16x8_t vec_max = vdupq_n_f16(*max_ptr);
-
-        // Run neon loop
-        for(int i = 0; i < long_steps; ++i)
-        {
-            float16x8_t vec_elements = vld1q_f16(in_ptr);
-            vec_elements             = vsubq_f16(vec_elements, vec_max);
-            vec_elements             = vmulq_n_f16(vec_elements, beta);
-            vec_elements             = vexpq_f16(vec_elements);
-
-            vst1q_f16(exp_ptr, vec_elements);
-            vec_sum_value = vaddq_f16(vec_sum_value, vec_elements);
-
-            in_ptr += step;
-            exp_ptr += step;
-        }
-        // Reduce sum
-        const float16x4_t sum_red        = vadd_f16(vget_low_f16(vec_sum_value), vget_high_f16(vec_sum_value));
-        const float16x4_t carry_addition = vpadd_f16(sum_red, sum_red);
-        float16_t         sum            = vget_lane_f16(carry_addition, 0) + vget_lane_f16(carry_addition, 1);
-
-        // Run remaining elements
-        for(int i = 0; i < small_steps; ++i)
-        {
-            const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr) * beta);
-            exp_ptr[i]              = element;
-            sum += element;
-        }
-        *(reinterpret_cast<float16_t *>(_sum.ptr())) = sum;
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
-}
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
+#else  /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input, 1, DataType::QASYMM8, DataType::QS8, DataType::QS16, DataType::F32);
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
-void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
-{
-    Window window_max(window);
-    window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
+    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type()); // QASYMM8: fixed output quantization, F32 tmp
 
-    Window max_slice = window_max.first_slice_window_1D();
-    Window in_slice  = window.first_slice_window_1D();
+    // Check max: same type/fixed point/quantization as input, and input's shape with x collapsed to 1
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &max);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(TensorShape(input.tensor_shape()).set(0, 1), max.tensor_shape());
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &max);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(&input, &max);
 
-    constexpr int step        = 4;
-    const int     long_steps  = in->info()->valid_region().shape.x() / step;
-    const int     small_steps = in->info()->valid_region().shape.x() % step;
-
-    do
+    // Check output if configured
+    if(output.total_size() != 0)
     {
-        Iterator input(in, in_slice);
-        Iterator exp(out, in_slice);
-        Iterator _max(max, max_slice);
-        Iterator _sum(sum, max_slice);
-
-        // Get pointers
-        auto in_ptr  = reinterpret_cast<const float *>(input.ptr());
-        auto exp_ptr = reinterpret_cast<float *>(exp.ptr());
-
-        // Init sum to zero
-        float32x4_t vec_sum_value = vdupq_n_f32(0.0f);
-
-        // Get max value
-        const auto        max_ptr = reinterpret_cast<const float *>(_max.ptr());
-        const float32x4_t vec_max = vdupq_n_f32(*max_ptr);
-
-        // Run neon loop
-        for(int i = 0; i < long_steps; ++i)
-        {
-            float32x4_t vec_elements = vld1q_f32(in_ptr);
-            vec_elements             = vsubq_f32(vec_elements, vec_max);
-            vec_elements             = vmulq_n_f32(vec_elements, beta);
-            vec_elements             = vexpq_f32(vec_elements);
-
-            vst1q_f32(exp_ptr, vec_elements);
-            vec_sum_value = vaddq_f32(vec_elements, vec_sum_value);
-
-            in_ptr += step;
-            exp_ptr += step;
-        }
-
-        // Reduce sum
-        float32x2_t carry_addition = vpadd_f32(vget_high_f32(vec_sum_value), vget_low_f32(vec_sum_value));
-        carry_addition             = vpadd_f32(carry_addition, carry_addition);
-        float sum                  = vget_lane_f32(carry_addition, 0);
-
-        // Run remaining elements
-        for(int i = 0; i < small_steps; ++i)
-        {
-            float element = std::exp((in_ptr[i] - *max_ptr) * beta);
-            exp_ptr[i]    = element;
-            sum += element;
-        }
-
-        *(reinterpret_cast<float *>(_sum.ptr())) = sum;
+        const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info(); // QASYMM8 output: scale 1/256, offset 0
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input, &output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &output);
+        ARM_COMPUTE_RETURN_ERROR_ON(output.quantization_info() != output_quantization);
     }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
-}
-} //namespace
 
-NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
-    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr), _beta(1.0f)
-{
+    // Check beta: the fixed-point kernel ignores beta, so only beta == 1 is accepted for fixed-point types
+    ARM_COMPUTE_RETURN_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input.data_type()));
+
+    // Check tmp if configured: F32 for QASYMM8, otherwise the input data type
+    if(tmp.total_size() != 0)
+    {
+        const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
+        ARM_COMPUTE_RETURN_ERROR_ON(tmp.data_type() != tmp_data_type);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(&input, &tmp);
+        // We could potentially reduce tmp memory if we could predict or make an assumption
+        // on the maximum number of threads that will run in parallel.
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&input, &tmp);
+    }
+
+    return Status{};
 }
 
-void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta)
+std::pair<Status, Window> validate_and_configure_window_logits_softmax(ITensorInfo &input, ITensorInfo &max,
+                                                                       ITensorInfo &output, ITensorInfo &tmp)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, sum, output);
+    const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(input.data_type());
 
     // Output auto initialization if not yet initialized
-    auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
-    auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
+    const QuantizationInfo output_quantization = is_quantized_asymmetric ? QuantizationInfo(1.f / 256.f, 0) : output.quantization_info();
+    auto_init_if_empty(output, TensorInfo(input).set_quantization_info(output_quantization).reset_padding());
 
+    // Tmp auto initialization if not yet initialized
+    const DataType tmp_data_type = is_quantized_asymmetric ? DataType::F32 : input.data_type();
+    auto_init_if_empty(tmp, TensorInfo(input).set_data_type(tmp_data_type).reset_padding());
+    // A whole row is handled per window step: input/output/tmp accesses span the full valid width.
+    const int input_width = input.valid_region().shape.x();
+    // Window over max (one element per row), so each step corresponds to one input row.
+    Window win = calculate_max_window(max);
+    // max is read one element per row at x = 0, so its access window must be attached to max itself, not input.
+    AccessWindowHorizontal input_access(&input, input.valid_region().anchor.x(), input_width);
+    AccessWindowHorizontal max_access(&max, 0, 1);
+    AccessWindowHorizontal output_access(&output, input.valid_region().anchor.x(), input_width);
+    AccessWindowHorizontal tmp_access(&tmp, input.valid_region().anchor.x(), input_width);
+    // window_changed means the required padding could not be provided (reported as "Insufficient Padding!").
+    const bool window_changed = update_window_and_padding(win, input_access, max_access, output_access, tmp_access);
+
+    output.set_valid_region(input.valid_region());
+
+    const Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
+    return std::make_pair(err, win);
+}
+
+template <typename T, int N, int S, int E> // Combines lanes [S, E] of an N-lane vector with add_fn by recursive halving.
+struct reduce_add_impl
+{
+    template <typename F>
+    static T reduce(F add_fn, vec_n_t<T, N> vec)
+    {
+        constexpr int H            = (S + E + 1) / 2; // split point: [S, H - 1] and [H, E]
+        const auto    reduced_high = reduce_add_impl < T, N, S, H - 1 >::reduce(add_fn, vec); // lanes [S, H - 1]
+        const auto    reduced_low  = reduce_add_impl<T, N, H, E>::reduce(add_fn, vec); // lanes [H, E]
+        return add_fn(reduced_high, reduced_low);
+    }
+};
+template <typename T, int N, int I> // Base case: a single lane I is returned as-is.
+struct reduce_add_impl<T, N, I, I>
+{
+    template <typename F>
+    static T reduce(F /*add_fn*/, vec_n_t<T, N> vec)
+    {
+        return vget_lane<I>(vec);
+    }
+};
+template <typename V, typename F> // Combines all lanes of vec with the binary functor add_fn (e.g. std::plus, or a saturating add).
+elem_type_t<V> reduce_add(F add_fn, V vec)
+{
+    constexpr int N = vec_size_of(vec); // number of lanes in V
+    return reduce_add_impl < elem_type_t<V>, N, 0, N - 1 >::reduce(add_fn, vec); // reduce over lanes [0, N - 1]
+}
+
+void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
+{
+    const int start_x     = in.info()->valid_region().anchor.x();
+    const int input_width = in.info()->valid_region().shape.x();
+    // Negated so the code can use (max - x), non-negative when max holds the row maximum: exp((max - x) * scale_beta) == exp(beta * scale * (x - max)).
+    const float scale_beta = -beta * in.info()->quantization_info().scale;
+    // tmp is a float scratch row, written for i in [0, input_width) and reused for every row of the window.
+    Iterator in_it(&in, window);
+    Iterator max_it(&max, window);
+    Iterator out_it(&out, window);
+
+    execute_window_loop(window, [&](const Coordinates &)
+    {
+        /* Get pointers (tmp_ptr is not offset per row: the scratch row is reused) */
+        const auto in_ptr  = reinterpret_cast<const qasymm8_t *>(in_it.ptr()) + start_x;
+        const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
+        const auto tmp_ptr = reinterpret_cast<float *>(tmp);
+
+        float sum_inversed; /* 256 / sum: the 256 matches the 1/256 output scale required for QASYMM8 */
+
+        /* Compute exponentials and sum */
+        {
+            /* Get max value */
+            const auto max_val = *reinterpret_cast<const qasymm8_t *>(max_it.ptr());
+            const auto vec_max = vdup_n<vec_16_byte_t<qasymm8_t>>(max_val);
+
+            /* Init sum to zero */
+            auto vec_sum = vdup_n<float32x4x4_t>(0.f);
+
+            /* Loop over row and compute exponentials and sum */
+            int           i        = 0;
+            constexpr int vec_size = vec_size_of(vec_max);
+            for(; i <= (input_width - vec_size); i += vec_size)
+            {
+                auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
+                vec_elements      = vsubq_u8(vec_max, vec_elements); /* max - x */
+
+                auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
+                vec_elements_flt      = vexp(vmul_n(vec_elements_flt, scale_beta));
+
+                vec_sum = vadd(vec_sum, vec_elements_flt);
+
+                vst4q_f32(tmp_ptr + i, vec_elements_flt); /* interleaved store, undone by the matching vld4q_f32 in the normalization loop */
+            }
+            /* Reduce sum */
+            const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
+                                               vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
+            const auto sum_8_byte = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
+            float      sum        = reduce_add(std::plus<float>(), sum_8_byte);
+
+            /* Run remaining elements */
+            for(; i < input_width; ++i)
+            {
+                const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
+                sum += element;
+                tmp_ptr[i] = element;
+            }
+
+            sum_inversed = 256.f / sum;
+        }
+
+        /* Normalize exponentials */
+        {
+            /* Loop over row and compute softmax */
+            int i = 0;
+            {
+                constexpr int vec_size = 16;
+                for(; i <= (input_width - vec_size); i += vec_size)
+                {
+                    float32x4x4_t vec_in           = vld4q_f32(tmp_ptr + i);
+                    auto          normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
+                    vst(out_ptr + i, normalized_value);
+                }
+            }
+            /* Run remaining elements (saturate_cast clamps, e.g. 256 when a single element dominates the row) */
+            for(; i < input_width; ++i)
+            {
+                out_ptr[i] = utility::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
+            }
+        }
+    },
+    in_it, max_it, out_it);
+}
+
+template <typename T, typename U> // T: QS8/QS16 element type, U: wider accumulator type (see NELogits1DSoftmaxKernel::configure)
+void logits_1d_softmax_fixed_point(const ITensor &in, const ITensor &max, void *const tmp,
+                                   ITensor &out, const float /*beta*/, const Window &window)
+{
+    const int start_x     = in.info()->valid_region().anchor.x();
+    const int input_width = in.info()->valid_region().shape.x();
+    // beta is ignored here: validation rejects beta != 1 for fixed-point types.
+    const int fixed_point_position = in.info()->fixed_point_position();
+
+    Iterator in_it(&in, window);
+    Iterator max_it(&max, window);
+    Iterator out_it(&out, window);
+
+    execute_window_loop(window, [&](const Coordinates &)
+    {
+        /* Get pointers (tmp_ptr is not offset per row: the scratch row is reused) */
+        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
+        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
+        const auto tmp_ptr = reinterpret_cast<T *>(tmp);
+
+        vec_16_byte_t<T> vec_sum_inversed; /* reciprocal of the row sum, broadcast to every lane */
+
+        /* Compute exponentials and sum */
+        {
+            /* Get max value */
+            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
+            const auto vec_max = vdup_n<vec_8_byte_t<T>>(max_val);
+
+            /* Init sum to zero */
+            auto vec_sum = vdup_n<vec_16_byte_t<U>>(0);
+
+            /* Loop over row and compute exponentials and sum */
+            int           i        = 0;
+            constexpr int vec_size = vec_size_of(vec_sum);
+            for(; i <= (input_width - vec_size); i += vec_size)
+            {
+                auto vec_elements = vld<vec_8_byte_t<T>>(in_ptr + i);
+                vec_elements      = vqsub(vec_elements, vec_max);
+                vec_elements      = vqexp(vec_elements, fixed_point_position);
+                vec_sum           = vqadd(vec_sum, vmovl(vec_elements)); /* widen to U before the saturating accumulation */
+                vst(tmp_ptr + i, vec_elements);
+            }
+            /* Reduce sum */
+            const vec_8_byte_t<U> sum_8_byte = vqadd(vget_high(vec_sum), vget_low(vec_sum));
+            U                     sum        = reduce_add(sqadd<U>, sum_8_byte);
+
+            /* Run remaining elements */
+            for(; i < input_width; ++i)
+            {
+                T element  = sqexp(sqsub(in_ptr[i], max_val), fixed_point_position);
+                sum        = sqadd<U>(sum, element);
+                tmp_ptr[i] = element;
+            }
+
+            const auto qsum  = utility::saturate_cast<T>(sum); /* narrow the wide sum back to T, saturating */
+            vec_sum_inversed = vqrecip(vdup_n<vec_16_byte_t<T>>(qsum), fixed_point_position);
+        }
+
+        /* Normalize exponentials */
+        {
+            /* Loop over row and compute softmax */
+            int           i        = 0;
+            constexpr int vec_size = vec_size_of(vec_sum_inversed);
+            for(; i <= (input_width - vec_size); i += vec_size)
+            {
+                const auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
+                const vec_16_byte_t<T> normalized_value = vqmul(vec_in, vec_sum_inversed, fixed_point_position);
+                vst(out_ptr + i, normalized_value);
+            }
+
+            const T sum_inversed = vget_lane<0>(vec_sum_inversed); /* scalar copy for the tail loop */
+
+            /* Run remaining elements */
+            for(; i < input_width; ++i)
+            {
+                out_ptr[i] = sqmul(tmp_ptr[i], sum_inversed, fixed_point_position);
+            }
+        }
+    },
+    in_it, max_it, out_it);
+}
+
+template <typename T> // T: float or float16_t (see NELogits1DSoftmaxKernel::configure)
+void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
+                             ITensor &out, const float beta, const Window &window)
+{
+    const int start_x     = in.info()->valid_region().anchor.x();
+    const int input_width = in.info()->valid_region().shape.x();
+    // Per row: tmp[i] = exp(beta * (x[i] - max)), then out[i] = tmp[i] / sum(tmp).
+    Iterator in_it(&in, window);
+    Iterator max_it(&max, window);
+    Iterator out_it(&out, window);
+
+    execute_window_loop(window, [&](const Coordinates &)
+    {
+        /* Get pointers (tmp_ptr is not offset per row: the scratch row is reused) */
+        const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
+        const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
+        const auto tmp_ptr = reinterpret_cast<T *>(tmp);
+
+        T sum_inversed; /* 1 / sum of the row's exponentials */
+
+        /* Compute exponentials and sum */
+        {
+            /* Get max value */
+            const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
+            const auto vec_max = vdup_n<vec_16_byte_t<T>>(max_val);
+
+            /* Init sum to zero */
+            auto vec_sum = vdup_n<vec_16_byte_t<T>>(0);
+
+            /* Loop over row and compute exponentials and sum */
+            int           i        = 0;
+            constexpr int vec_size = vec_size_of(vec_sum);
+            for(; i <= (input_width - vec_size); i += vec_size)
+            {
+                auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
+                vec_elements      = vsub(vec_elements, vec_max);
+                vec_elements      = vexp(vmul_n(vec_elements, beta));
+                vec_sum           = vadd(vec_sum, vec_elements);
+                vst(tmp_ptr + i, vec_elements);
+            }
+            /* Reduce sum */
+            const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
+            T sum                 = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);
+
+            /* Run remaining elements */
+            for(; i < input_width; ++i)
+            {
+                T element = std::exp((in_ptr[i] - max_val) * beta);
+                sum += element;
+                tmp_ptr[i] = element;
+            }
+
+            sum_inversed = T(1) / sum;
+        }
+
+        /* Normalize exponentials */
+        {
+            /* Loop over row and compute softmax */
+            int i = 0;
+            {
+                constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
+                for(; i <= (input_width - vec_size); i += vec_size)
+                {
+                    auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
+                    vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
+                    vst(out_ptr + i, normalized_value);
+                }
+            }
+            /* Run remaining elements */
+            for(; i < input_width; ++i)
+            {
+                out_ptr[i] = tmp_ptr[i] * sum_inversed;
+            }
+        }
+    },
+    in_it, max_it, out_it);
+}
+} // namespace
+
+NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel() // Creates an unconfigured kernel; configure() must be called before run().
+    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr) // NOTE(review): keep in member-declaration order (-Wreorder); header not visible here
+{
+}
+
+void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info()); // the tensor infos are dereferenced below
     // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info(), beta));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_softmax(*input->info(), *max->info(), *output->info(), beta, *tmp->info()));
+    // Configure kernel window (also auto-initialises output and tmp infos when they are empty)
+    auto win_config = validate_and_configure_window_logits_softmax(*input->info(), *max->info(), *output->info(), *tmp->info());
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
 
     switch(input->info()->data_type())
     {
+        case DataType::QASYMM8: // computed in float via an F32 tmp buffer
+            _func = &logits_1d_softmax_qasymm8;
+            break;
         case DataType::QS8:
-            _func = &logits_1d_shift_exp_sum_qs8;
+            _func = &logits_1d_softmax_fixed_point<qint8_t, qint16_t>; // QS8 elements, QS16 accumulator
             break;
         case DataType::QS16:
-            _func = &logits_1d_shift_exp_sum_qs16;
+            _func = &logits_1d_softmax_fixed_point<qint16_t, qint32_t>; // QS16 elements, QS32 accumulator
             break;
-        case DataType::F32:
-            _func = &logits_1d_shift_exp_sum_f32;
-            break;
-        case DataType::F16:
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            _func = &logits_1d_shift_exp_sum_f16;
+        case DataType::F16:
+            _func = &logits_1d_softmax_float<float16_t>;
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
+        case DataType::F32:
+            _func = &logits_1d_softmax_float<float>;
+            break;
         default:
             ARM_COMPUTE_ERROR("Unsupported data type.");
             break;
@@ -686,224 +930,37 @@
     _input  = input;
     _max    = max;
     _output = output;
-    _sum    = sum;
     _beta   = beta;
+    _tmp    = tmp; // scratch for exponentials; run() sizes a per-thread slice of it (presumably indexed by thread id - confirm)
 
-    // Configure kernel window
-    auto win_config = validate_and_configure_window_logits_1d_shift_exp_sum(input->info(), max->info(), output->info(), sum->info());
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     INEKernel::configure(win_config.second);
 }
 
-Status NELogits1DShiftExpSumKernel::validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta)
+Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
+                                         const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_shift_exp_sum(input, max, output, sum, beta));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_shift_exp_sum(input->clone().get(), max->clone().get(), output->clone().get(), sum->clone().get()).first);
+    // A null argument must come back as an error Status (not a throw): the pointers are dereferenced below.
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, max, output, tmp);
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_softmax(*input, *max, *output, beta, *tmp));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_softmax(*input->clone(), *max->clone(), *output->clone(), *tmp->clone()).first);
 
     return Status{};
 }
 
-void NELogits1DShiftExpSumKernel::run(const Window &window, const ThreadInfo &info)
+void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
 {
-    ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
     ARM_COMPUTE_ERROR_ON(_func == nullptr);
 
-    (*_func)(_input, _max, _output, _sum, window, _beta);
+    // Per-thread scratch size in bytes: one tmp element for each element of the input's valid x-extent.
+    const unsigned int row_length          = _input->info()->valid_region().shape.x();
+    const unsigned int tmp_size_for_thread = _tmp->info()->element_size() * row_length;
+    ARM_COMPUTE_ERROR_ON(_tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
+
+    // Offset into the shared tmp buffer by thread_id so each thread gets its own slice.
+    void *tmp_for_thread = _tmp->buffer() + (info.thread_id * tmp_size_for_thread);
+    (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
 }
 
-namespace
-{
-void logits_1d_norm_qs8(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
-{
-    Window window_sum(window);
-    window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
-    Window sum_slice = window_sum.first_slice_window_1D();
-    Window in_slice  = window.first_slice_window_1D();
-
-    const int fixed_point_position = in->info()->fixed_point_position();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator _sum(sum, sum_slice);
-        Iterator output(out, in_slice);
-
-        const int8_t     sum_value        = *reinterpret_cast<const qint8_t *>(_sum.ptr());
-        const qint8x16_t vec_sum_inversed = vqrecipq_qs8(vdupq_n_qs8(sum_value), fixed_point_position);
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            const auto in_ptr  = reinterpret_cast<const qint8_t *>(input.ptr());
-            const auto out_ptr = reinterpret_cast<qint8_t *>(output.ptr());
-
-            const qint8x16_t vec_in           = vld1q_qs8(in_ptr);
-            const qint8x16_t normalized_value = vqmulq_qs8(vec_in, vec_sum_inversed, fixed_point_position);
-
-            vst1q_qs8(out_ptr, normalized_value);
-        },
-        input, output);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
-}
-void logits_1d_norm_qs16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
-{
-    Window window_sum(window);
-    window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
-    Window sum_slice = window_sum.first_slice_window_1D();
-    Window in_slice  = window.first_slice_window_1D();
-
-    const int fixed_point_position = in->info()->fixed_point_position();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator _sum(sum, sum_slice);
-        Iterator output(out, in_slice);
-
-        const int16_t    sum_value        = *reinterpret_cast<const qint16_t *>(_sum.ptr());
-        const qint16x8_t vec_sum_inversed = vqrecipq_qs16(vdupq_n_qs16(sum_value), fixed_point_position);
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            const auto in_ptr  = reinterpret_cast<const qint16_t *>(input.ptr());
-            const auto out_ptr = reinterpret_cast<qint16_t *>(output.ptr());
-
-            const qint16x8_t vec_in           = vld1q_qs16(in_ptr);
-            const qint16x8_t normalized_value = vqmulq_qs16(vec_in, vec_sum_inversed, fixed_point_position);
-
-            vst1q_qs16(out_ptr, normalized_value);
-        },
-        input, output);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
-}
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
-{
-    Window window_sum(window);
-    window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
-    Window sum_slice = window_sum.first_slice_window_1D();
-    Window in_slice  = window.first_slice_window_1D();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator _sum(sum, sum_slice);
-        Iterator output(out, in_slice);
-
-        const float16_t   sum_value        = *reinterpret_cast<const qint16_t *>(_sum.ptr());
-        const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value);
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            const auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
-            const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());
-
-            const float16x8_t vec_in           = vld1q_f16(in_ptr);
-            const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed);
-
-            vst1q_f16(out_ptr, normalized_value);
-        },
-        input, output);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
-{
-    Window window_sum(window);
-    window_sum.set(Window::DimX, Window::Dimension(0, 0, 0));
-    Window sum_slice = window_sum.first_slice_window_1D();
-    Window in_slice  = window.first_slice_window_1D();
-
-    do
-    {
-        Iterator input(in, in_slice);
-        Iterator _sum(sum, sum_slice);
-        Iterator output(out, in_slice);
-
-        const float       sum_value        = *reinterpret_cast<const float *>(_sum.ptr());
-        const float32x4_t vec_sum_inversed = vdupq_n_f32(1.0f / sum_value);
-
-        execute_window_loop(in_slice, [&](const Coordinates & id)
-        {
-            const auto in_ptr  = reinterpret_cast<const float *>(input.ptr());
-            const auto out_ptr = reinterpret_cast<float *>(output.ptr());
-
-            const float32x4_t vec_in           = vld1q_f32(in_ptr);
-            const float32x4_t normalized_value = vmulq_f32(vec_in, vec_sum_inversed);
-
-            vst1q_f32(out_ptr, normalized_value);
-        },
-        input, output);
-    }
-    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
-}
-} // namespace
-
-NELogits1DNormKernel::NELogits1DNormKernel()
-    : _func(nullptr), _input(nullptr), _sum(nullptr), _output(nullptr)
-{
-}
-
-void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
-{
-    ARM_COMPUTE_ERROR_ON_NULLPTR(input, sum, output);
-
-    // Output auto initialization if not yet initialized
-    auto_init_if_empty(*output->info(), input->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
-
-    // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_logits_1d_norm(input->info(), sum->info(), output->info()));
-
-    _input  = input;
-    _sum    = sum;
-    _output = output;
-
-    switch(input->info()->data_type())
-    {
-        case DataType::QS8:
-            _func = &logits_1d_norm_qs8;
-            break;
-        case DataType::QS16:
-            _func = &logits_1d_norm_qs16;
-            break;
-        case DataType::F32:
-            _func = &logits_1d_norm_f32;
-            break;
-        case DataType::F16:
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            _func = &logits_1d_norm_f16;
-            break;
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-        default:
-            ARM_COMPUTE_ERROR("Unsupported data type.");
-            break;
-    }
-
-    // Configure kernel window
-    auto win_config = validate_and_configure_window_logits_1d_norm(input->info(), sum->info(), output->info());
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
-    INEKernel::configure(win_config.second);
-}
-
-Status NELogits1DNormKernel::validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output)
-{
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_logits_1d_norm(input, sum, output));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_logits_1d_norm(input->clone().get(), sum->clone().get(), output->clone().get()).first);
-
-    return Status{};
-}
-
-void NELogits1DNormKernel::run(const Window &window, const ThreadInfo &info)
-{
-    ARM_COMPUTE_UNUSED(info);
-    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
-    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
-    ARM_COMPUTE_ERROR_ON(_func == nullptr);
-
-    (*_func)(_input, _sum, _output, window);
-}
+} // namespace arm_compute