Optimize CpuSoftmaxKernel for axis != 0 and its Neon kernels
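
The Neon softmax kernels are split into an x-axis path (axis == 0), which
keeps the existing implementations (renamed to neon_softmax_x_*), and new
neon_softmax_non_x_* paths that vectorize across the x dimension while
iterating along the requested axis. The softmax axis is now passed down to
every kernel so each data type (FP32, FP16, QASYMM8, QASYMM8_SIGNED)
dispatches to the appropriate variant.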

Resolves: COMPMID-6501
Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I0abd3cbb5f861301f407c443988fb7efaa205b5d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11056
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
index db8f881..da62d2d 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -33,15 +33,23 @@
 {
 
 template <bool IS_LOG>
-void neon_fp16_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_fp16_softmax(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 {
-    return neon_softmax_float<float16_t, IS_LOG>(in, tmp, out, beta, window);
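+    // axis == 0 reduces along x, so the existing kernel vectorizes along the reduction axis itself;
+    // for any other axis the non-x kernel vectorizes across x and loops along the requested axis.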
+    if (axis == 0)
+    {
+        return neon_softmax_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
+    else
+    {
+        return neon_softmax_non_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
 }
 
-template void
-neon_fp16_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_fp16_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_fp16_softmax<true>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp16_softmax<false>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
index c281d1b..0701620 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp32.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -31,15 +31,23 @@
 {
 
 template <bool IS_LOG>
-void neon_fp32_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_fp32_softmax(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 {
-    return neon_softmax_float<float, IS_LOG>(in, tmp, out, beta, window);
+    if (axis == 0)
+    {
+        return neon_softmax_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
+    else
+    {
+        return neon_softmax_non_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
 }
 
-template void
-neon_fp32_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_fp32_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_fp32_softmax<true>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp32_softmax<false>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.cpp b/src/cpu/kernels/softmax/generic/neon/impl.cpp
index 487f6ae..31baf8a 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/impl.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,8 +30,11 @@
 namespace cpu
 {
 template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
 {
+    ARM_COMPUTE_UNUSED(axis);
+
     static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
                   "quantized type should be either qasymm8_t or qasymm8_signed_t.");
 
@@ -248,16 +251,346 @@
         in_it, out_it);
 }
 
-template void neon_softmax_quantized<qasymm8_signed_t, true>(
-    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+    static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
+                  "quantized type should be either qasymm8_t or qasymm8_signed_t.");
 
-template void neon_softmax_quantized<qasymm8_signed_t, false>(
-    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+    const float       scale_beta     = -beta * in->info()->quantization_info().uniform().scale;
+    const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta);
 
-template void neon_softmax_quantized<qasymm8_t, true>(
-    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+    Iterator in_it(in, window);
+    Iterator out_it(out, window);
 
-template void neon_softmax_quantized<qasymm8_t, false>(
-    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+    /** SIMD vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+    constexpr int      vec_size        = 16;
+    const ITensorInfo *in_info         = in->info();
+    const ITensorInfo *out_info        = out->info();
+    const int          x_width         = in_info->valid_region().shape.x();
+    const int          in_axis_stride  = in_info->strides_in_bytes()[axis];
+    const int          out_axis_stride = out_info->strides_in_bytes()[axis];
+    const int          tmp_axis_stride = in_axis_stride;
+    const int          axis_width      = in_info->dimension(axis);
+    const int          end_actual      = std::min(window[0].end(), x_width);
+
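+    // Each window iteration covers up to 16 adjacent x positions. The loops below walk element i
+    // along the softmax axis, keeping a per-lane running maximum and sum, and stage the exponentials
+    // as floats in the tmp buffer so they can be re-read during the final normalization pass.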
+    execute_window_loop(
+        window,
+        [&](const Coordinates &winCoords)
+        {
+            const bool vector_exceeds_bounds = ((winCoords[0] + vec_size) > end_actual);
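+            // A partial block at the end of the x dimension cannot be loaded as a full vector, so
+            // those iterations take the scalar element-by-element paths below.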
+
+            int num_remaining         = (end_actual - winCoords[0]);
+            int num_remaining_full    = num_remaining / 4;
+            int num_remaining_partial = num_remaining % 4;
+
+            /* Get pointers */
+            const uint8_t *in_ptr  = in_it.ptr();
+            uint8_t       *out_ptr = out_it.ptr();
+            uint8_t       *tmp_ptr = reinterpret_cast<uint8_t *>(tmp);
+
+            auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+            /* Compute Max */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const auto current_value =
+                            wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+                        vec_max = wrapper::vmax(vec_max, current_value);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in = ((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+                        int            j           = 0;
+                        for (; j < num_remaining; ++j)
+                        {
+                            const T current_value = *(base_ptr_in + j);
+                            vec_max[j]            = std::max(vec_max[j], current_value);
+                        }
+                    }
+                }
+            } // Compute Max
+
+            float32x4x4_t vec_sum_transformed = {
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+            };
+
+            /* Compute exponentials and sum */
+            {
+                /* Init sum to zero */
+                float32x4x4_t vec_sum = vec_sum_transformed;
+
+                auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+                float32x4x4_t vec_elements_flt;
+
+                if (!vector_exceeds_bounds)
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        vec_elements     = wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+                        vec_elements     = wrapper::vqsub(vec_max, vec_elements);
+                        vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+
+                        if (IS_LOG)
+                        {
+                            vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+                            vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+                            vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+                            vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+                            vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+                            vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+                            vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+                            vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
+                        }
+                        else
+                        {
+                            vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+                            vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+                            vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+                            vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+                            vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+                            vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+                            vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+                            vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
+                        }
+                        vst4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr), vec_elements_flt);
+                    }
+
+                    auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256.f), ExactTagType{});
+                    if (!IS_LOG)
+                    {
+                        vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]);
+                        vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]);
+                        vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]);
+                        vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]);
+                    }
+                    else
+                    {
+                        vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]);
+                        vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]);
+                        vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]);
+                        vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in  = (i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr);
+                        auto           vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+                        // vec_els is functionally redundant but is needed as a workaround for a toolchain bug.
+                        std::vector<T> vec_els(16);
+
+                        for (int k = 0; k < num_remaining_full; ++k)
+                        {
+                            for (int j = 0; j < 4; ++j)
+                            {
+                                vec_els[k * 4 + j] = *(base_ptr_in + (4 * k + j));
+                            }
+                        }
+                        for (int j = 0; j < num_remaining_partial; ++j)
+                        {
+                            vec_els[num_remaining_full * 4 + j] = *(base_ptr_in + (4 * num_remaining_full + j));
+                        }
+                        for (int q = 0; q < 16; q++)
+                        {
+                            vec_elements[q] = vec_els[q];
+                        }
+                        vec_elements                   = wrapper::vqsub(vec_max, vec_elements);
+                        float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+
+                        if (IS_LOG)
+                        {
+                            vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+                            vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+                            vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+                            vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+                            vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+                            vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+                            vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+                            vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
+                        }
+                        else
+                        {
+                            vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+                            vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+                            vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+                            vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+                            vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+                            vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+                            vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+                            vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
+                        }
+
+                        float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr);
+                        for (int k = 0; k < num_remaining_full; ++k)
+                        {
+                            for (int j = 0; j < 4; ++j)
+                            {
+                                *(base_ptr_tmp + (4 * k + j)) = vec_elements_flt.val[k][j];
+                            }
+                        }
+
+                        for (int j = 0; j < num_remaining_partial; ++j)
+                        {
+                            *(base_ptr_tmp + (4 * num_remaining_full + j)) =
+                                vec_elements_flt.val[num_remaining_full][j];
+                        }
+                    }
+
+                    auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256.f), ExactTagType{});
+                    if (!IS_LOG)
+                    {
+                        vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]);
+                        vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]);
+                        vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]);
+                        vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]);
+                    }
+                    else
+                    {
+                        vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]);
+                        vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]);
+                        vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]);
+                        vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]);
+                    }
+                }
+            } // Compute exponentials and sum
+
+            /* Normalize exponentials */
+            {
+                constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
+                if (!vector_exceeds_bounds)
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        using int_vec_type   = wrapper::traits::neon_vector_t<T, 16>;
+                        float32x4x4_t vec_in = vld4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr));
+
+                        int_vec_type normalized_value{};
+
+                        if (IS_LOG)
+                        {
+                            const float32x4x4_t sub = {
+                                vsubq_f32(vec_in.val[0], vec_sum_transformed.val[0]),
+                                vsubq_f32(vec_in.val[1], vec_sum_transformed.val[1]),
+                                vsubq_f32(vec_in.val[2], vec_sum_transformed.val[2]),
+                                vsubq_f32(vec_in.val[3], vec_sum_transformed.val[3]),
+                            };
+                            normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
+                        }
+                        else
+                        {
+                            float32x4x4_t mul = {
+                                vmulq_f32(vec_in.val[0], vec_sum_transformed.val[0]),
+                                vmulq_f32(vec_in.val[1], vec_sum_transformed.val[1]),
+                                vmulq_f32(vec_in.val[2], vec_sum_transformed.val[2]),
+                                vmulq_f32(vec_in.val[3], vec_sum_transformed.val[3]),
+                            };
+
+                            if (is_qasymm8_signed)
+                            {
+                                const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
+                                mul.val[0]            = wrapper::vsub(mul.val[0], offset_vec);
+                                mul.val[1]            = wrapper::vsub(mul.val[1], offset_vec);
+                                mul.val[2]            = wrapper::vsub(mul.val[2], offset_vec);
+                                mul.val[3]            = wrapper::vsub(mul.val[3], offset_vec);
+                            }
+
+                            normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
+                        }
+                        wrapper::vstore((i * out_axis_stride) + reinterpret_cast<T *>(out_ptr), normalized_value);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        T *const     base_ptr_out = (i * out_axis_stride) + reinterpret_cast<T *>(out_ptr);
+                        float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr);
+                        if (IS_LOG)
+                        {
+                            for (int k = 0; k < num_remaining_full; ++k)
+                            {
+                                for (int j = 0; j < 4; ++j)
+                                {
+                                    *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>(
+                                        (*(base_ptr_tmp + (4 * k + j)) - vec_sum_transformed.val[k][j]));
+                                }
+                            }
+                            for (int j = 0; j < num_remaining_partial; ++j)
+                            {
+                                *(base_ptr_out + (4 * num_remaining_full + j)) =
+                                    utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) -
+                                                                  vec_sum_transformed.val[num_remaining_full][j]);
+                            }
+                        }
+                        else
+                        {
+                            for (int k = 0; k < num_remaining_full; ++k)
+                            {
+                                for (int j = 0; j < 4; ++j)
+                                {
+                                    *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>(
+                                        *(base_ptr_tmp + (4 * k + j)) * vec_sum_transformed.val[k][j] -
+                                        (is_qasymm8_signed ? 128.f : 0));
+                                }
+                            }
+                            for (int j = 0; j < num_remaining_partial; ++j)
+                            {
+                                *(base_ptr_out + (4 * num_remaining_full + j)) =
+                                    utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) *
+                                                                      vec_sum_transformed.val[num_remaining_full][j] -
+                                                                  (is_qasymm8_signed ? 128.f : 0));
+                            }
+                        }
+                    }
+                }
+            } // Normalize exponentials
+        },
+        in_it, out_it);
+}
+
+template void neon_softmax_x_quantized<qasymm8_signed_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_signed_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_signed_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_signed_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.h b/src/cpu/kernels/softmax/generic/neon/impl.h
index 60380cd..e417271 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.h
+++ b/src/cpu/kernels/softmax/generic/neon/impl.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -62,8 +62,9 @@
 // The template implementation for float data types is stored in the header file because
 // we need all fp16 instantiated code to live in fp16.cpp files.
 template <typename T, bool IS_LOG>
-void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_float(const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
 {
+    ARM_COMPUTE_UNUSED(axis);
     ARM_COMPUTE_UNUSED(tmp);
 
     const int input_width = in->info()->valid_region().shape.x();
@@ -228,9 +229,199 @@
         },
         in_it, out_it);
 }
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_float(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(tmp);
+
+    Iterator in_it(in, window);
+    Iterator out_it(out, window);
+
+    /** SIMD vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+    const auto         beta_vec        = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
+    constexpr int      vec_size        = 16 / sizeof(T);
+    const ITensorInfo *in_info         = in->info();
+    const ITensorInfo *out_info        = out->info();
+    const int          x_width         = in_info->valid_region().shape.x();
+    const unsigned int in_axis_stride  = in_info->strides_in_bytes()[axis];
+    const unsigned int out_axis_stride = out_info->strides_in_bytes()[axis];
+    const int          axis_width      = in_info->dimension(axis);
+
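+    // Unlike the x-axis kernel, this path does not reduce along x: it vectorizes across up to
+    // vec_size adjacent x positions and iterates along the softmax axis. The exponentials are
+    // written straight to the output tensor and normalized in place, so no tmp buffer is needed.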
+    execute_window_loop(
+        window,
+        [&](const Coordinates &winCoords)
+        {
+            const bool vector_exceeds_bounds = (winCoords[0] + vec_size) > x_width;
+
+            /* Get pointers */
+            const uint8_t *in_ptr  = in_it.ptr();
+            uint8_t       *out_ptr = out_it.ptr();
+
+            // Init max value
+            auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+            /* Compute Max */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const auto current_value =
+                            wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+                        vec_max = wrapper::vmax(vec_max, current_value);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+                        int            j           = 0;
+                        for (; j < (x_width - winCoords[0]); ++j)
+                        {
+                            const auto current_value = *(base_ptr_in + j);
+                            vec_max[j]               = std::max(vec_max[j], current_value);
+                        }
+                    }
+                }
+            } // compute max
+
+            auto vec_sum_transformed = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+            auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+            /* Init sum to zero */
+            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+            /* Compute exponentials and sum */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    const auto vec_one = wrapper::vdup_n(static_cast<T>(1), ExactTagType{});
+                    /* Loop along the softmax axis and compute exponentials and sum */
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        vec_elements = wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+                        vec_elements = wrapper::vsub(vec_elements, vec_max);
+                        if (IS_LOG)
+                        {
+                            vec_elements = wrapper::vmul(vec_elements, beta_vec);
+                            vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
+                        }
+                        else
+                        {
+                            vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
+                            vec_sum      = wrapper::vadd(vec_sum, vec_elements);
+                        }
+
+                        wrapper::vstore(reinterpret_cast<T *>((i * out_axis_stride) + out_ptr), vec_elements);
+                    }
+
+                    if (!IS_LOG)
+                    {
+                        vec_sum_transformed = wrapper::vdiv(vec_one, vec_sum);
+                    }
+                    else
+                    {
+                        vec_sum_transformed = wrapper::vlog(vec_sum);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in  = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+                        T *const       base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        int            j            = 0;
+                        for (; j < (x_width - winCoords[0]); ++j)
+                        {
+                            vec_elements[j] = *(base_ptr_in + j);
+                            vec_elements[j] -= vec_max[j];
+                            if (IS_LOG)
+                            {
+                                vec_elements[j] *= beta;
+                                vec_sum[j] += std::exp(vec_elements[j]);
+                            }
+                            else
+                            {
+                                vec_elements[j] = std::exp(vec_elements[j] * beta);
+                                vec_sum[j] += vec_elements[j];
+                            }
+                            *(base_ptr_out + j) = vec_elements[j];
+                        }
+                    }
+                    int j = 0;
+                    for (; j < (x_width - winCoords[0]); ++j)
+                    {
+                        if (!IS_LOG)
+                        {
+                            vec_sum_transformed[j] = 1 / vec_sum[j];
+                        }
+                        else
+                        {
+                            vec_sum_transformed[j] = std::log(vec_sum[j]);
+                        }
+                    }
+                }
+            } // Compute exponentials and sum
+
+            /* Normalize exponentials */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    /* Loop along the softmax axis and compute softmax */
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        auto     vec_in       = wrapper::vloadq(base_ptr_out);
+                        if (IS_LOG)
+                        {
+                            wrapper::vstore(base_ptr_out, wrapper::vsub(vec_in, vec_sum_transformed));
+                        }
+                        else
+                        {
+                            wrapper::vstore(base_ptr_out, wrapper::vmul(vec_in, vec_sum_transformed));
+                        }
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        int      j            = 0;
+                        for (; j < (x_width - winCoords[0]); ++j)
+                        {
+                            if (IS_LOG)
+                            {
+                                *(base_ptr_out + j) -= vec_sum_transformed[j];
+                            }
+                            else
+                            {
+                                *(base_ptr_out + j) *= vec_sum_transformed[j];
+                            }
+                        }
+                    }
+                }
+            } // Normalize exponentials
+        },
+        in_it, out_it);
+}
+template <typename T, bool IS_LOG>
+void neon_softmax_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
 
 template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+void neon_softmax_non_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
 } // namespace cpu
 } // namespace arm_compute
 
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
index 9589ebc..d39240b 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,15 +30,23 @@
 namespace cpu
 {
 template <bool IS_LOG>
-void neon_qasymm8_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_qasymm8_softmax(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 {
-    return neon_softmax_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, window);
+    if (axis == 0)
+    {
+        return neon_softmax_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
+    else
+    {
+        return neon_softmax_non_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
 }
 
-template void
-neon_qasymm8_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_qasymm8_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_qasymm8_softmax<true>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_qasymm8_softmax<false>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
index 0bf6b28..26fd5db 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -31,15 +31,22 @@
 {
 template <bool IS_LOG>
 void neon_qasymm8_signed_softmax(
-    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 {
-    return neon_softmax_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, window);
+    if (axis == 0)
+    {
+        return neon_softmax_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
+    else
+    {
+        return neon_softmax_non_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
 }
 
 template void neon_qasymm8_signed_softmax<true>(
-    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 template void neon_qasymm8_signed_softmax<false>(
-    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 
 } // namespace cpu
 } // namespace arm_compute