Optimize CpuSoftmaxKernel for axis=0

Implement a single kernel instead of two consecutive ones. In the previous setup, one kernel computed the maximum value along the axis, and this maximum was then subtracted from each element while the softmax itself was computed in a second kernel, i.e.

softmax(x_i) = exp(x_i - max) / sum_j( exp(x_j - max) )
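
As a reference point, here is a minimal scalar sketch of this numerically stable softmax (illustrative only, not the library's Neon™ kernel code; the helper name and signature are made up for the example):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Softmax over a single row; subtracting the row maximum first keeps
    // exp() from overflowing for large inputs. Assumes x is non-empty.
    std::vector<float> softmax_row(const std::vector<float> &x, float beta = 1.f)
    {
        const float max_val = *std::max_element(x.begin(), x.end());

        std::vector<float> y(x.size());
        float              sum = 0.f;
        for (size_t i = 0; i < x.size(); ++i)
        {
            y[i] = std::exp((x[i] - max_val) * beta);
            sum += y[i];
        }

        const float inv_sum = 1.f / sum;
        for (float &v : y)
        {
            v *= inv_sum;
        }
        return y;
    }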

This patch fuses these two stages into a single Neon™ kernel for all data types. This saves some memory because the max values no longer need to be held in a separate auxiliary tensor.

It also introduces other optimizations that ease memory pressure when the data type is float/half by using the dst tensor as temporary storage for the already exponentiated inputs.
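
A rough sketch of this in-place scheme (hypothetical helper, not the actual kernel code): the exponentials are written directly into the destination row, which is then normalized in a second pass over the same buffer, so no auxiliary tmp buffer is touched for float/half rows.

    #include <algorithm>
    #include <cmath>

    // Illustrative only: exponentiate into dst, then normalize dst in place.
    // dst stands in for a row of the output tensor; src for the input row.
    void softmax_row_inplace(const float *src, float *dst, int width, float beta)
    {
        const float max_val = *std::max_element(src, src + width);

        float sum = 0.f;
        for (int x = 0; x < width; ++x)
        {
            dst[x] = std::exp((src[x] - max_val) * beta); // dst doubles as exp() scratch
            sum += dst[x];
        }

        const float inv_sum = 1.f / sum;
        for (int x = 0; x < width; ++x)
        {
            dst[x] *= inv_sum; // normalize in place
        }
    }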

It removes the references to the SVE and SVE2 implementations and most of the associated files, but leaves the implementations themselves in place, as they may be used in the future.

Resolves: COMPMID-6500

Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Change-Id: Icff9976d1214c4c6cbe15a62ca60b8a77d3784cc
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10688
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
index 2e2adf3..db8f881 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
@@ -31,21 +31,18 @@
 {
 namespace cpu
 {
-void neon_fp16_softmax(const ITensor *in,
-                       const ITensor *max,
-                       void *const    tmp,
-                       ITensor       *out,
-                       const float    beta,
-                       bool           is_log,
-                       const Window  &window)
+
+template <bool IS_LOG>
+void neon_fp16_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
 {
-    return neon_softmax_logits_1d_float<float16_t>(in, max, tmp, out, beta, is_log, window);
+    return neon_softmax_float<float16_t, IS_LOG>(in, tmp, out, beta, window);
 }
 
-void neon_fp16_logits(const ITensor *in, ITensor *out, const Window &window)
-{
-    return neon_logits_1d_max<float16_t>(in, out, window);
-}
+template void
+neon_fp16_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void
+neon_fp16_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+
 } // namespace cpu
 } // namespace arm_compute
 #endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
index 61df40c..c281d1b 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp32.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,20 +29,17 @@
 {
 namespace cpu
 {
-void neon_fp32_softmax(const ITensor *in,
-                       const ITensor *max,
-                       void *const    tmp,
-                       ITensor       *out,
-                       const float    beta,
-                       bool           is_log,
-                       const Window  &window)
+
+template <bool IS_LOG>
+void neon_fp32_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
 {
-    return neon_softmax_logits_1d_float<float>(in, max, tmp, out, beta, is_log, window);
+    return neon_softmax_float<float, IS_LOG>(in, tmp, out, beta, window);
 }
 
-void neon_fp32_logits(const ITensor *in, ITensor *out, const Window &window)
-{
-    return neon_logits_1d_max<float>(in, out, window);
-}
+template void
+neon_fp32_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void
+neon_fp32_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.cpp b/src/cpu/kernels/softmax/generic/neon/impl.cpp
index 5d6e6a4..487f6ae 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/impl.cpp
@@ -29,43 +29,76 @@
 {
 namespace cpu
 {
-template void neon_logits_1d_max<qasymm8_signed_t>(const ITensor *in, ITensor *out, const Window &window);
-template void neon_logits_1d_max<qasymm8_t>(const ITensor *in, ITensor *out, const Window &window);
-
-template <typename T>
-void neon_softmax_logits_1d_quantized(
-    const ITensor *in, const ITensor *max, void *const tmp, ITensor *out, float beta, bool is_log, const Window &window)
+template <typename T, bool IS_LOG>
+void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
 {
     static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
                   "quantized type should be either qasymm8_t or qasymm8_signed_t.");
 
-    const int start_x     = in->info()->valid_region().anchor.x();
     const int input_width = in->info()->valid_region().shape.x();
 
-    const float scale_beta     = -beta * in->info()->quantization_info().uniform().scale;
-    const auto  scale_beta_vec = vdupq_n_f32(scale_beta);
+    const float       scale_beta     = -beta * in->info()->quantization_info().uniform().scale;
+    const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta);
 
-    Iterator      in_it(in, window);
-    Iterator      max_it(max, window);
-    Iterator      out_it(out, window);
+    Iterator in_it(in, window);
+    Iterator out_it(out, window);
+
     constexpr int vec_size = 16;
 
+#ifndef __aarch64__
+    const int sum_stages = log2(vec_size >> 1);
+#endif // __aarch64__
+
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
     execute_window_loop(
         window,
         [&](const Coordinates &)
         {
             /* Get pointers */
-            const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
-            const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
-            const auto tmp_ptr = reinterpret_cast<float *>(tmp);
+            const T *in_ptr  = reinterpret_cast<const T *>(in_it.ptr());
+            T       *out_ptr = reinterpret_cast<T *>(out_it.ptr());
+            float   *tmp_ptr = reinterpret_cast<float *>(tmp);
 
-            float sum{};
-            float sum_inversed{};
+            T max_val;
+
+            /* Compute Max */
+            {
+                // Init max value
+                auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+                int  x       = 0;
+
+                for (; x <= (input_width - vec_size); x += vec_size)
+                {
+                    const auto current_value = wrapper::vloadq(in_ptr + x);
+                    vec_max                  = wrapper::vmax(vec_max, current_value);
+                }
+
+#ifdef __aarch64__
+                max_val = wrapper::vmaxv(vec_max);
+#else  // __aarch64__
+                auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
+
+                for (int i = 0; i < sum_stages; ++i)
+                {
+                    carry_max = wrapper::vpmax(carry_max, carry_max);
+                }
+
+                max_val      = wrapper::vgetlane(carry_max, 0);
+#endif // __aarch64__
+
+                // Compute left-over elements
+                for (; x < input_width; ++x)
+                {
+                    max_val = std::max(*(in_ptr + x), max_val);
+                }
+            } // Compute Max
+
+            float sum_transformed{};
 
             /* Compute exponentials and sum */
             {
                 /* Get max value */
-                const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
                 const auto vec_max = wrapper::vdup_n(max_val, wrapper::traits::vector_128_tag{});
 
                 /* Init sum to zero */
@@ -80,11 +113,11 @@
                 int x = 0;
                 for (; x <= (input_width - vec_size); x += vec_size)
                 {
-                    auto vec_elements     = wrapper::vloadq(in_ptr + x);
-                    vec_elements          = wrapper::vqsub(vec_max, vec_elements);
-                    auto vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+                    auto vec_elements              = wrapper::vloadq(in_ptr + x);
+                    vec_elements                   = wrapper::vqsub(vec_max, vec_elements);
+                    float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
 
-                    if (is_log)
+                    if (IS_LOG)
                     {
                         vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
                         vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
@@ -111,17 +144,24 @@
                 }
 
                 /* Reduce sum */
-                const auto sum_16_byte =
+                const float32x4_t sum_16_byte =
                     vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]), vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
+
+                float sum;
+
+#ifdef __aarch64__
+                sum = wrapper::vaddv(sum_16_byte);
+#else  // __aarch64__
                 auto sum_res = vpadd_f32(vget_high_f32(sum_16_byte), vget_low_f32(sum_16_byte));
                 sum_res      = vpadd_f32(sum_res, sum_res);
                 sum          = wrapper::vgetlane(sum_res, 0);
+#endif // __aarch64__
 
                 /* Run remaining elements */
                 for (; x < input_width; ++x)
                 {
                     float element{};
-                    if (is_log)
+                    if (IS_LOG)
                     {
                         element = (max_val - in_ptr[x]) * scale_beta;
                         sum += std::exp(element);
@@ -135,19 +175,22 @@
                     tmp_ptr[x] = element;
                 }
 
-                if (!is_log)
+                if (!IS_LOG)
                 {
-                    sum_inversed = 256.f / sum;
+                    sum_transformed = 256.f / sum;
                 }
                 else
                 {
-                    sum = std::log(sum);
+                    sum_transformed = std::log(sum);
                 }
-            }
+            } // Compute exponentials and sum
 
             /* Normalize exponentials */
             {
                 constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
+
+                const float32x4_t sum_vec = vdupq_n_f32(sum_transformed);
+
                 /* Loop over row and compute softmax */
                 int x = 0;
                 for (; x <= (input_width - vec_size); x += vec_size)
@@ -155,23 +198,23 @@
                     using int_vec_type   = wrapper::traits::neon_vector_t<T, 16>;
                     float32x4x4_t vec_in = vld4q_f32(tmp_ptr + x);
                     int_vec_type  normalized_value{};
-                    if (is_log)
+                    if (IS_LOG)
                     {
                         const float32x4x4_t sub = {
-                            vsubq_f32(vec_in.val[0], vdupq_n_f32(sum)),
-                            vsubq_f32(vec_in.val[1], vdupq_n_f32(sum)),
-                            vsubq_f32(vec_in.val[2], vdupq_n_f32(sum)),
-                            vsubq_f32(vec_in.val[3], vdupq_n_f32(sum)),
+                            vsubq_f32(vec_in.val[0], sum_vec),
+                            vsubq_f32(vec_in.val[1], sum_vec),
+                            vsubq_f32(vec_in.val[2], sum_vec),
+                            vsubq_f32(vec_in.val[3], sum_vec),
                         };
                         normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
                     }
                     else
                     {
                         float32x4x4_t mul = {
-                            vmulq_f32(vec_in.val[0], vdupq_n_f32(sum_inversed)),
-                            vmulq_f32(vec_in.val[1], vdupq_n_f32(sum_inversed)),
-                            vmulq_f32(vec_in.val[2], vdupq_n_f32(sum_inversed)),
-                            vmulq_f32(vec_in.val[3], vdupq_n_f32(sum_inversed)),
+                            vmulq_f32(vec_in.val[0], sum_vec),
+                            vmulq_f32(vec_in.val[1], sum_vec),
+                            vmulq_f32(vec_in.val[2], sum_vec),
+                            vmulq_f32(vec_in.val[3], sum_vec),
                         };
 
                         if (is_qasymm8_signed)
@@ -190,34 +233,31 @@
                 /* Run remaining elements */
                 for (; x < input_width; ++x)
                 {
-                    if (is_log)
+                    if (IS_LOG)
                     {
-                        out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum);
+                        out_ptr[x] = utils::cast::saturate_cast<T>(tmp_ptr[x] - sum_transformed);
                     }
                     else
                     {
-                        out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_inversed) -
+                        out_ptr[x] = utils::cast::saturate_cast<T>((tmp_ptr[x] * sum_transformed) -
                                                                    (is_qasymm8_signed ? 128.f : 0));
                     }
                 }
-            }
+            } // Normalize exponentials
         },
-        in_it, max_it, out_it);
+        in_it, out_it);
 }
 
-template void neon_softmax_logits_1d_quantized<qasymm8_signed_t>(const ITensor *in,
-                                                                 const ITensor *max,
-                                                                 void *const    tmp,
-                                                                 ITensor       *out,
-                                                                 float          beta,
-                                                                 bool           is_log,
-                                                                 const Window  &window);
-template void neon_softmax_logits_1d_quantized<qasymm8_t>(const ITensor *in,
-                                                          const ITensor *max,
-                                                          void *const    tmp,
-                                                          ITensor       *out,
-                                                          float          beta,
-                                                          bool           is_log,
-                                                          const Window  &window);
+template void neon_softmax_quantized<qasymm8_signed_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+
+template void neon_softmax_quantized<qasymm8_signed_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+
+template void neon_softmax_quantized<qasymm8_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+
+template void neon_softmax_quantized<qasymm8_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.h b/src/cpu/kernels/softmax/generic/neon/impl.h
index 4d9b789..60380cd 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.h
+++ b/src/cpu/kernels/softmax/generic/neon/impl.h
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef SRC_CORE_NEON_KERNELS_SOFTMAX_IMPL_H
-#define SRC_CORE_NEON_KERNELS_SOFTMAX_IMPL_H
+#ifndef ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H
+#define ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H
 
 #include "arm_compute/core/Helpers.h"
 
@@ -33,105 +33,100 @@
 {
 namespace cpu
 {
-template <typename T>
-void neon_logits_1d_max(const ITensor *in, ITensor *out, const Window &window)
+
+#ifdef __aarch64__
+namespace
 {
-    /** SIMD vector tag type. */
-    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
-
-    constexpr int window_step_x  = 16 / sizeof(T);
-    const auto    window_start_x = static_cast<int>(window.x().start());
-    const auto    window_end_x   = static_cast<int>(window.x().end());
-
-    Window win{window};
-    win.set(Window::DimX, Window::Dimension(0, 1, 1));
-    Iterator input(in, win);
-    Iterator output(out, win);
-
-    const int sum_stages = log2(window_step_x / 2);
-    execute_window_loop(
-        win,
-        [&](const Coordinates &)
-        {
-            // Get pointers
-            const auto in_ptr  = reinterpret_cast<const T *>(input.ptr());
-            const auto out_ptr = reinterpret_cast<T *>(output.ptr());
-
-            // Init max value
-            auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
-            int  x       = window_start_x;
-
-            for (; x <= (window_end_x - window_step_x); x += window_step_x)
-            {
-                const auto current_value = wrapper::vloadq(in_ptr + x);
-                vec_max                  = wrapper::vmax(vec_max, current_value);
-            }
-            auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
-
-            for (int i = 0; i < sum_stages; ++i)
-            {
-                carry_max = wrapper::vpmax(carry_max, carry_max);
-            }
-            T max_val = wrapper::vgetlane(carry_max, 0);
-
-            // Compute left-over elements
-            for (; x < window_end_x; ++x)
-            {
-                max_val = *(in_ptr + x) > max_val ? *(in_ptr + x) : max_val;
-            }
-
-            *out_ptr = max_val;
-        },
-        input, output);
+// These helper functions are added because vaddv does not exist for fp16,
+// and, therefore, is not part of the wrapper::vaddv interface.
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+inline float16_t wrapper_vaddv(const float16x8_t &a, int sum_stages)
+{
+    auto sum_res = wrapper::vpadd(wrapper::vgethigh(a), wrapper::vgetlow(a));
+    for (int i = 0; i < sum_stages; ++i)
+    {
+        sum_res = wrapper::vpadd(sum_res, sum_res);
+    }
+    return wrapper::vgetlane(sum_res, 0);
 }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
-template <typename T>
-void neon_softmax_logits_1d_quantized(const ITensor *in,
-                                      const ITensor *max,
-                                      void *const    tmp,
-                                      ITensor       *out,
-                                      float          beta,
-                                      bool           is_log,
-                                      const Window  &window);
-
-template <typename T>
-void neon_softmax_logits_1d_float(const ITensor *in,
-                                  const ITensor *max,
-                                  void *const    tmp,
-                                  ITensor       *out,
-                                  const float    beta,
-                                  bool           is_log,
-                                  const Window  &window)
+inline float wrapper_vaddv(const float32x4_t &a, int sum_stages)
 {
-    const int start_x     = in->info()->valid_region().anchor.x();
+    ARM_COMPUTE_UNUSED(sum_stages);
+    return wrapper::vaddv(a);
+}
+} // namespace
+#endif // __aarch64__
+
+// The template implementation for float data types is stored in the header file because
+// we need all fp16 instantiated code to live in fp16.cpp files.
+template <typename T, bool IS_LOG>
+void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(tmp);
+
     const int input_width = in->info()->valid_region().shape.x();
 
     Iterator in_it(in, window);
-    Iterator max_it(max, window);
     Iterator out_it(out, window);
 
     /** SIMD vector tag type. */
     using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
 
-    constexpr int vec_size   = 16 / sizeof(T);
-    const int     sum_stages = log2(vec_size / 2);
+    constexpr int vec_size = 16 / sizeof(T);
+
+    const int sum_stages = log2(vec_size >> 1);
+
+    const auto beta_vec = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
 
     execute_window_loop(
         window,
         [&](const Coordinates &)
         {
             /* Get pointers */
-            const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
-            const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
-            const auto tmp_ptr = reinterpret_cast<T *>(tmp);
+            const T *in_ptr  = reinterpret_cast<const T *>(in_it.ptr());
+            T       *out_ptr = reinterpret_cast<T *>(out_it.ptr());
 
-            T sum{};
-            T sum_inversed{};
+            T max_val;
+
+            /* Compute Max */
+            {
+                // Init max value
+                auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+                int  x       = 0;
+
+                for (; x <= (input_width - vec_size); x += vec_size)
+                {
+                    const auto current_value = wrapper::vloadq(in_ptr + x);
+                    vec_max                  = wrapper::vmax(vec_max, current_value);
+                }
+
+#ifdef __aarch64__
+                max_val = wrapper::vmaxv(vec_max);
+#else  // __aarch64__
+                auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
+
+                for (int i = 0; i < sum_stages; ++i)
+                {
+                    carry_max = wrapper::vpmax(carry_max, carry_max);
+                }
+
+                max_val      = wrapper::vgetlane(carry_max, 0);
+#endif // __aarch64__
+
+                // Compute left-over elements
+                for (; x < input_width; ++x)
+                {
+                    max_val = std::max(*(in_ptr + x), max_val);
+                }
+            } // compute max
+
+            T sum_transformed{};
 
             /* Compute exponentials and sum */
             {
                 /* Get max value */
-                const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
                 const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});
 
                 /* Init sum to zero */
@@ -143,35 +138,38 @@
                 {
                     auto vec_elements = wrapper::vloadq(in_ptr + x);
                     vec_elements      = wrapper::vsub(vec_elements, vec_max);
-                    if (is_log)
+                    if (IS_LOG)
                     {
-                        vec_elements =
-                            wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
-                        vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
+                        vec_elements = wrapper::vmul(vec_elements, beta_vec);
+                        vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                     }
                     else
                     {
-                        vec_elements = wrapper::vexpq(
-                            wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
-                        vec_sum = wrapper::vadd(vec_sum, vec_elements);
+                        vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
+                        vec_sum      = wrapper::vadd(vec_sum, vec_elements);
                     }
-                    wrapper::vstore(tmp_ptr + x, vec_elements);
+                    wrapper::vstore(out_ptr + x, vec_elements);
                 }
 
                 /* Reduce sum */
+                T sum{};
+#ifdef __aarch64__
+                sum = wrapper_vaddv(vec_sum, sum_stages);
+#else  // __aarch64__
                 auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
                 for (int i = 0; i < sum_stages; ++i)
                 {
                     sum_res = wrapper::vpadd(sum_res, sum_res);
                 }
                 sum = wrapper::vgetlane(sum_res, 0);
+#endif // __aarch64__
 
                 /* Run remaining elements */
                 for (; x < input_width; ++x)
                 {
                     T element{};
 
-                    if (is_log)
+                    if (IS_LOG)
                     {
                         element = (in_ptr[x] - max_val) * beta;
                         sum += std::exp(element);
@@ -181,55 +179,59 @@
                         element = std::exp((in_ptr[x] - max_val) * beta);
                         sum += element;
                     }
-                    tmp_ptr[x] = element;
+
+                    out_ptr[x] = element;
                 }
 
-                if (!is_log)
+                if (!IS_LOG)
                 {
-                    sum_inversed = T(1) / sum;
+                    sum_transformed = T(1) / sum;
                 }
                 else
                 {
-                    sum = static_cast<T>(std::log(sum));
+                    sum_transformed = static_cast<T>(std::log(sum));
                 }
-            }
+            } // Compute exponentials and sum
 
             /* Normalize exponentials */
             {
+                const auto sum_vec = wrapper::vdup_n(static_cast<T>(sum_transformed), ExactTagType{});
+
                 /* Loop over row and compute softmax */
                 int x = 0;
                 for (; x <= (input_width - vec_size); x += vec_size)
                 {
-                    auto vec_in           = wrapper::vloadq(tmp_ptr + x);
-                    auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
-                    if (is_log)
+                    const auto vec_in = wrapper::vloadq(out_ptr + x);
+                    if (IS_LOG)
                     {
-                        normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
+                        wrapper::vstore(out_ptr + x, wrapper::vsub(vec_in, sum_vec));
                     }
                     else
                     {
-                        normalized_value =
-                            wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
+                        wrapper::vstore(out_ptr + x, wrapper::vmul(vec_in, sum_vec));
                     }
-                    wrapper::vstore(out_ptr + x, normalized_value);
                 }
+
                 /* Run remaining elements */
                 for (; x < input_width; ++x)
                 {
-                    if (is_log)
+                    if (IS_LOG)
                     {
-                        out_ptr[x] = tmp_ptr[x] - sum;
+                        out_ptr[x] = out_ptr[x] - sum_transformed;
                     }
                     else
                     {
-                        out_ptr[x] = tmp_ptr[x] * sum_inversed;
+                        out_ptr[x] = out_ptr[x] * sum_transformed;
                     }
                 }
-            }
+            } // Normalize exponentials
         },
-        in_it, max_it, out_it);
+        in_it, out_it);
 }
+
+template <typename T, bool IS_LOG>
+void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
 } // namespace cpu
 } // namespace arm_compute
 
-#endif /* SRC_CORE_NEON_KERNELS_SOFTMAX_IMPL_H */
+#endif // ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
index 40713dc..9589ebc 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,20 +29,16 @@
 {
 namespace cpu
 {
-void neon_qasymm8_softmax(const ITensor *in,
-                          const ITensor *max,
-                          void *const    tmp,
-                          ITensor       *out,
-                          const float    beta,
-                          bool           is_log,
-                          const Window  &window)
+template <bool IS_LOG>
+void neon_qasymm8_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
 {
-    return neon_softmax_logits_1d_quantized<qasymm8_t>(in, max, tmp, out, beta, is_log, window);
+    return neon_softmax_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, window);
 }
 
-void neon_qasymm8_logits(const ITensor *in, ITensor *out, const Window &window)
-{
-    return neon_logits_1d_max<qasymm8_t>(in, out, window);
-}
+template void
+neon_qasymm8_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void
+neon_qasymm8_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
index 2c5e284..0bf6b28 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,20 +29,17 @@
 {
 namespace cpu
 {
-void neon_qasymm8_signed_softmax(const ITensor *in,
-                                 const ITensor *max,
-                                 void *const    tmp,
-                                 ITensor       *out,
-                                 const float    beta,
-                                 bool           is_log,
-                                 const Window  &window)
+template <bool IS_LOG>
+void neon_qasymm8_signed_softmax(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
 {
-    return neon_softmax_logits_1d_quantized<qasymm8_signed_t>(in, max, tmp, out, beta, is_log, window);
+    return neon_softmax_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, window);
 }
 
-void neon_qasymm8_singed_logits(const ITensor *in, ITensor *out, const Window &window)
-{
-    return neon_logits_1d_max<qasymm8_signed_t>(in, out, window);
-}
+template void neon_qasymm8_signed_softmax<true>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_qasymm8_signed_softmax<false>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sve/fp16.cpp b/src/cpu/kernels/softmax/generic/sve/fp16.cpp
deleted file mode 100644
index 5e94f72..0000000
--- a/src/cpu/kernels/softmax/generic/sve/fp16.cpp
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Copyright (c) 2021-2023 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
-#include "arm_compute/core/Helpers.h"
-
-#include "src/cpu/CpuTypes.h"
-#include "src/cpu/kernels/softmax/generic/sve/impl.h"
-namespace arm_compute
-{
-namespace cpu
-{
-void sve_fp16_softmax(const ITensor *in,
-                      const ITensor *max,
-                      void *const    tmp,
-                      ITensor       *out,
-                      const float    beta,
-                      bool           is_log,
-                      const Window  &window)
-{
-    return sve_softmax_logits_1d_float<float16_t>(in, max, tmp, out, beta, is_log, window);
-}
-
-void sve_fp16_logits(const ITensor *in, ITensor *out, const Window &window)
-{
-    return sve_logits_1d_max<float16_t>(in, out, window);
-}
-} // namespace cpu
-} // namespace arm_compute
-#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/softmax/generic/sve/fp32.cpp b/src/cpu/kernels/softmax/generic/sve/fp32.cpp
deleted file mode 100644
index d692cc2..0000000
--- a/src/cpu/kernels/softmax/generic/sve/fp32.cpp
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Copyright (c) 2021-2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm_compute/core/Helpers.h"
-
-#include "src/cpu/kernels/softmax/generic/sve/impl.h"
-
-namespace arm_compute
-{
-namespace cpu
-{
-void sve_fp32_softmax(const ITensor *in,
-                      const ITensor *max,
-                      void *const    tmp,
-                      ITensor       *out,
-                      const float    beta,
-                      bool           is_log,
-                      const Window  &window)
-{
-    return sve_softmax_logits_1d_float<float>(in, max, tmp, out, beta, is_log, window);
-}
-
-void sve_fp32_logits(const ITensor *in, ITensor *out, const Window &window)
-{
-    return sve_logits_1d_max<float>(in, out, window);
-}
-} // namespace cpu
-} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sve/impl.cpp b/src/cpu/kernels/softmax/generic/sve/impl.cpp
index 24f1bb8..0d4b7f4 100644
--- a/src/cpu/kernels/softmax/generic/sve/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/sve/impl.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,6 +30,9 @@
 {
 namespace cpu
 {
+/// TODO: (COMPMID-6505) Similar to Neon(TM), this implementation can be converted to
+/// a single kernel that performs the softmax operation. Leaving the SVE code here for
+/// future reference. The implementation for Neon(TM) is introduced in COMPMID-6500.
 template <typename ScalarType>
 void sve_logits_1d_max(const ITensor *in, ITensor *out, const Window &window)
 {
@@ -172,25 +175,5 @@
         },
         in_it, max_it, out_it);
 }
-
-template void sve_logits_1d_max<float>(const ITensor *in, ITensor *out, const Window &window);
-template void sve_logits_1d_max<float16_t>(const ITensor *in, ITensor *out, const Window &window);
-template void sve_logits_1d_max<qasymm8_t>(const ITensor *in, ITensor *out, const Window &window);
-template void sve_logits_1d_max<qasymm8_signed_t>(const ITensor *in, ITensor *out, const Window &window);
-
-template void sve_softmax_logits_1d_float<float>(const ITensor *in,
-                                                 const ITensor *max,
-                                                 void *const    tmp,
-                                                 ITensor       *out,
-                                                 const float    beta,
-                                                 bool           is_log,
-                                                 const Window  &window);
-template void sve_softmax_logits_1d_float<float16_t>(const ITensor *in,
-                                                     const ITensor *max,
-                                                     void *const    tmp,
-                                                     ITensor       *out,
-                                                     const float    beta,
-                                                     bool           is_log,
-                                                     const Window  &window);
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sve/qasymm8.cpp b/src/cpu/kernels/softmax/generic/sve/qasymm8.cpp
deleted file mode 100644
index 85e5ccf..0000000
--- a/src/cpu/kernels/softmax/generic/sve/qasymm8.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright (c) 2021-2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm_compute/core/Helpers.h"
-
-#include "src/cpu/kernels/softmax/generic/sve/impl.h"
-
-namespace arm_compute
-{
-namespace cpu
-{
-void sve_qasymm8_logits(const ITensor *in, ITensor *out, const Window &window)
-{
-    return sve_logits_1d_max<qasymm8_t>(in, out, window);
-}
-} // namespace cpu
-} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sve/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/sve/qasymm8_signed.cpp
deleted file mode 100644
index 4be2e2e..0000000
--- a/src/cpu/kernels/softmax/generic/sve/qasymm8_signed.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Copyright (c) 2021-2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm_compute/core/Helpers.h"
-
-#include "src/cpu/kernels/softmax/generic/sve/impl.h"
-
-namespace arm_compute
-{
-namespace cpu
-{
-void sve_qasymm8_signed_logits(const ITensor *in, ITensor *out, const Window &window)
-{
-    return sve_logits_1d_max<qasymm8_signed_t>(in, out, window);
-}
-} // namespace cpu
-} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sve2/impl.cpp b/src/cpu/kernels/softmax/generic/sve2/impl.cpp
index 98b2f51..a8fb1d4 100644
--- a/src/cpu/kernels/softmax/generic/sve2/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/sve2/impl.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -32,6 +32,9 @@
 {
 namespace cpu
 {
+/// TODO: (COMPMID-6505) Similar to Neon(TM), this implementation can be converted to
+/// a single kernel that performs the softmax operation. Leaving the SVE2 code here for
+/// future reference. The implementation for Neon(TM) is introduced in COMPMID-6500.
 template <typename ScalarType>
 void sve2_softmax_logits_1d_quantized(
     const ITensor *in, const ITensor *max, void *const tmp, ITensor *out, float beta, bool is_log, const Window &window)
@@ -205,20 +208,5 @@
         },
         in_it, max_it, out_it);
 }
-
-template void sve2_softmax_logits_1d_quantized<qasymm8_signed_t>(const ITensor *in,
-                                                                 const ITensor *max,
-                                                                 void *const    tmp,
-                                                                 ITensor       *out,
-                                                                 float          beta,
-                                                                 bool           is_log,
-                                                                 const Window  &window);
-template void sve2_softmax_logits_1d_quantized<qasymm8_t>(const ITensor *in,
-                                                          const ITensor *max,
-                                                          void *const    tmp,
-                                                          ITensor       *out,
-                                                          float          beta,
-                                                          bool           is_log,
-                                                          const Window  &window);
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sve2/qasymm8.cpp b/src/cpu/kernels/softmax/generic/sve2/qasymm8.cpp
deleted file mode 100644
index 9562378..0000000
--- a/src/cpu/kernels/softmax/generic/sve2/qasymm8.cpp
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Copyright (c) 2021-2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm_compute/core/Helpers.h"
-
-#include "src/cpu/kernels/softmax/generic/sve2/impl.h"
-
-namespace arm_compute
-{
-namespace cpu
-{
-void sve2_qasymm8_softmax(const ITensor *in,
-                          const ITensor *max,
-                          void *const    tmp,
-                          ITensor       *out,
-                          const float    beta,
-                          bool           is_log,
-                          const Window  &window)
-{
-    return sve2_softmax_logits_1d_quantized<qasymm8_t>(in, max, tmp, out, beta, is_log, window);
-}
-} // namespace cpu
-} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/sve2/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/sve2/qasymm8_signed.cpp
deleted file mode 100644
index c20462f..0000000
--- a/src/cpu/kernels/softmax/generic/sve2/qasymm8_signed.cpp
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Copyright (c) 2021-2022 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm_compute/core/Helpers.h"
-
-#include "src/cpu/kernels/softmax/generic/sve2/impl.h"
-
-namespace arm_compute
-{
-namespace cpu
-{
-void sve2_qasymm8_signed_softmax(const ITensor *in,
-                                 const ITensor *max,
-                                 void *const    tmp,
-                                 ITensor       *out,
-                                 const float    beta,
-                                 bool           is_log,
-                                 const Window  &window)
-{
-    return sve2_softmax_logits_1d_quantized<qasymm8_signed_t>(in, max, tmp, out, beta, is_log, window);
-}
-} // namespace cpu
-} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h
index 627ce0c..c143f66 100644
--- a/src/cpu/kernels/softmax/list.h
+++ b/src/cpu/kernels/softmax/list.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2022 Arm Limited.
+ * Copyright (c) 2021-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,41 +21,24 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef SRC_CORE_NEON_KERNELS_SOFTMAX_LIST_H
-#define SRC_CORE_NEON_KERNELS_SOFTMAX_LIST_H
+#ifndef ACL_SRC_CPU_KERNELS_SOFTMAX_LIST_H
+#define ACL_SRC_CPU_KERNELS_SOFTMAX_LIST_H
 
 namespace arm_compute
 {
 namespace cpu
 {
-#define DECLARE_SOFTMAX_KERNEL(func_name)                                                                  \
-    void func_name(const ITensor *in, const ITensor *max, void *const tmp, ITensor *out, const float beta, \
-                   bool is_log, const Window &window)
+#define DECLARE_SOFTMAX_KERNEL(func_name) \
+    template <bool IS_LOG>                \
+    void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
 
 DECLARE_SOFTMAX_KERNEL(neon_fp32_softmax);
 DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax);
 DECLARE_SOFTMAX_KERNEL(neon_qasymm8_softmax);
 DECLARE_SOFTMAX_KERNEL(neon_qasymm8_signed_softmax);
-DECLARE_SOFTMAX_KERNEL(sve_fp32_softmax);
-DECLARE_SOFTMAX_KERNEL(sve_fp16_softmax);
-DECLARE_SOFTMAX_KERNEL(sve2_qasymm8_signed_softmax);
-DECLARE_SOFTMAX_KERNEL(sve2_qasymm8_softmax);
 
 #undef DECLARE_SOFTMAX_KERNEL
-
-#define DECLARE_LOGITS_KERNEL(func_name) void func_name(const ITensor *in, ITensor *out, const Window &window)
-
-DECLARE_LOGITS_KERNEL(neon_fp32_logits);
-DECLARE_LOGITS_KERNEL(neon_fp16_logits);
-DECLARE_LOGITS_KERNEL(neon_qasymm8_logits);
-DECLARE_LOGITS_KERNEL(neon_qasymm8_singed_logits);
-DECLARE_LOGITS_KERNEL(sve_fp32_logits);
-DECLARE_LOGITS_KERNEL(sve_fp16_logits);
-DECLARE_LOGITS_KERNEL(sve_qasymm8_logits);
-DECLARE_LOGITS_KERNEL(sve_qasymm8_signed_logits);
-
-#undef DECLARE_LOGITS_KERNEL
 } // namespace cpu
 } // namespace arm_compute
 
-#endif /* SRC_CORE_NEON_KERNELS_SOFTMAX_LIST_H */
+#endif // ACL_SRC_CPU_KERNELS_SOFTMAX_LIST_H