Optimize CpuSoftmaxKernel and its Neon kernels for axis != 0
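
Previously, softmax along a non-zero axis was handled by permuting the
input so that the reduction always ran along axis 0, running the axis-0
kernel and permuting the result back. The kernel now takes the axis
directly (axes 0 to 3) and dedicated Neon paths reduce along the
requested axis through tensor strides, removing the two permutes and
their auxiliary tensors.

A minimal usage sketch through the public runtime API (shapes, data
type and axis are arbitrary illustration values):

    #include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    Tensor src, dst;
    src.allocator()->init(TensorInfo(TensorShape(8U, 16U, 4U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(8U, 16U, 4U), 1, DataType::F32));

    NESoftmaxLayer softmax;
    // Softmax along axis 1 now runs natively instead of permute + axis-0 softmax + permute.
    softmax.configure(&src, &dst, /* beta */ 1.0f, /* axis */ 1);

    src.allocator()->allocate();
    dst.allocator()->allocate();
    softmax.run();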

Resolves: COMPMID-6501
Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I0abd3cbb5f861301f407c443988fb7efaa205b5d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11056
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index bc7d2cb..2d46737 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -44,6 +44,8 @@
 v24.04 Public major release
  - Optimize start-up time of @ref NEConvolutionLayer for some input configurations where GeMM is selected as the convolution algorithm
  - Optimize @ref NEConvolutionLayer for input tensor size > 1e7 bytes and weight tensor height > 7
+ - Performance optimizations:
+  - Optimize @ref NESoftmaxLayer for axis != 0 by natively supporting higher axes up to axis 3.
 
 v24.02.1 Public patch release
  - Fix performance regression in fixed-format kernels
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 68bc397..54ff858 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -81,7 +81,7 @@
 };
 
 Status validate_arguments_softmax(
-    const ITensorInfo &src, const ITensorInfo &dst, float beta, const ITensorInfo &tmp, bool is_log)
+    const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log)
 {
     ARM_COMPUTE_UNUSED(beta);
     // Check input
@@ -89,6 +89,8 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                          DataType::F16, DataType::F32);
 
+    ARM_COMPUTE_RETURN_ERROR_ON(axis < 0 || axis > 3);
+
     const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
 
     // Check output if configured
@@ -124,10 +126,13 @@
     return available_kernels;
 }
 
-void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp)
+void CpuSoftmaxKernel::configure(
+    const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp)
 {
+    _axis = axis;
+
     ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
 
     // Configure kernel window
     const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
@@ -154,25 +159,40 @@
     _run_method = uk->ukernel;
     _name       = kernel_name.append("/").append(uk->name);
 
-    Window win = calculate_max_window(*dst, Steps());
+    Window win;
 
-    /// TODO: Check dimensions > 0 for holes only. For this, we need
-    /// a utility function checking if there are holes after some dimension.
-    if (!has_holes(*dst, dst->num_dimensions() - 1))
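+    // Number of elements that fit in a 16-byte Neon register for this data type; used as the window step along X when axis > 0.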
+    int vec_size = 16 / dst->element_size();
+
+    if (_axis == 0)
     {
-        win = win.collapse(win, Window::DimY);
+        win = calculate_max_window(*dst, Steps());
+
+        /// TODO: Check dimensions > 0 for holes only. For this, we need
+        /// a utility function checking if there are holes after some dimension.
+        if (!has_holes(*dst, dst->num_dimensions() - 1))
+        {
+            win = win.collapse(win, Window::DimY);
+        }
+    }
+    else if (_axis > 0 && _axis <= 3)
+    {
+        win = calculate_max_window(*dst, Steps(vec_size));
+    }
+    else
+    {
+        ARM_COMPUTE_ERROR("Invalid axis");
     }
 
-    win.set(Window::DimX, Window::Dimension(0, 1, 1)); // First dimension is the reduction axis
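+    // Collapse the softmax (reduction) axis to a single window iteration; the selected kernel loops over it internally.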
+    win.set(_axis, Window::Dimension(0, 1, 1));
 
     ICpuKernel<CpuSoftmaxKernel>::configure(win);
 }
 
 Status CpuSoftmaxKernel::validate(
-    const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp)
+    const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
 
     return Status{};
 }
@@ -188,19 +208,25 @@
 
     if (is_data_type_quantized_asymmetric(src->info()->data_type()))
     {
-        auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
-
-        const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
+        auto         tmp = tensors.get_tensor(TensorType::ACL_DST_1);
+        unsigned int num_elems_processed_per_iteration;
+        if (_axis == 0)
+        {
+            num_elems_processed_per_iteration = src->info()->valid_region().shape[_axis];
+        }
+        else
+        {
+            // 16 QASYMM8/QASYMM8_SIGNED elements fit into a single 16-byte Neon vector.
+            num_elems_processed_per_iteration = 16;
+        }
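+        // Each thread gets its own slice of the F32 tmp buffer: one full row for axis == 0, or one 16-element vector otherwise.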
         const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
 
-        ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
-
         void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
-        _run_method(src, tmp_for_thread, dst, _beta, window);
+        _run_method(src, tmp_for_thread, dst, _beta, _axis, window);
     }
     else
     {
-        _run_method(src, nullptr, dst, _beta, window);
+        _run_method(src, nullptr, dst, _beta, _axis, window);
     }
 }
 
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h
index 3db1f3d..043ad97 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.h
+++ b/src/cpu/kernels/CpuSoftmaxKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -38,7 +38,7 @@
 {
 private:
     using SoftmaxKernelPtr =
-        std::add_pointer<void(const ITensor *, void *const, ITensor *, float, const Window &)>::type;
+        std::add_pointer<void(const ITensor *, void *const, ITensor *, float, int, const Window &)>::type;
 
 public:
     CpuSoftmaxKernel() = default;
@@ -49,11 +49,12 @@
      * @param[in]  src    Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
      * @param[out] dst    Destination tensor info. Data types supported: same as @p input.
      * @param[in]  beta   A scaling factor for the exponent.
-     * @param[in]  is_log True if the operation is log-softmax
+     * @param[in]  is_log True if the operation is log-softmax.
+     * @param[in]  axis   The axis along which to perform the softmax operation. Supported values are 0 to 3.
      *
      * @param      tmp    Auxiliary tensor info. Must be type F32 and same shape as the input.
      */
-    void configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp);
+    void configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp);
     /** Static function to check if given info will lead to a valid configuration
      *
      * Similar to CpuSoftmaxKernel::configure()
@@ -61,7 +62,7 @@
      * @return a status
      */
     static Status
-    validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp);
+    validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp);
 
     // Inherited methods overridden:
     void        run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
@@ -80,6 +81,7 @@
     float            _beta{1.0f};
     SoftmaxKernelPtr _run_method{nullptr};
     std::string      _name{};
+    int              _axis{};
 };
 } // namespace kernels
 } // namespace cpu
diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
index db8f881..da62d2d 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -33,15 +33,23 @@
 {
 
 template <bool IS_LOG>
-void neon_fp16_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_fp16_softmax(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 {
-    return neon_softmax_float<float16_t, IS_LOG>(in, tmp, out, beta, window);
+    if (axis == 0)
+    {
+        return neon_softmax_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
+    else
+    {
+        return neon_softmax_non_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
 }
 
-template void
-neon_fp16_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_fp16_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_fp16_softmax<true>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp16_softmax<false>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
index c281d1b..0701620 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp32.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -31,15 +31,23 @@
 {
 
 template <bool IS_LOG>
-void neon_fp32_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_fp32_softmax(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 {
-    return neon_softmax_float<float, IS_LOG>(in, tmp, out, beta, window);
+    if (axis == 0)
+    {
+        return neon_softmax_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
+    else
+    {
+        return neon_softmax_non_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
 }
 
-template void
-neon_fp32_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_fp32_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_fp32_softmax<true>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp32_softmax<false>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.cpp b/src/cpu/kernels/softmax/generic/neon/impl.cpp
index 487f6ae..31baf8a 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/impl.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,8 +30,11 @@
 namespace cpu
 {
 template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
 {
+    ARM_COMPUTE_UNUSED(axis);
+
     static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
                   "quantized type should be either qasymm8_t or qasymm8_signed_t.");
 
@@ -248,16 +251,346 @@
         in_it, out_it);
 }
 
-template void neon_softmax_quantized<qasymm8_signed_t, true>(
-    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+    static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
+                  "quantized type should be either qasymm8_t or qasymm8_signed_t.");
 
-template void neon_softmax_quantized<qasymm8_signed_t, false>(
-    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+    const float       scale_beta     = -beta * in->info()->quantization_info().uniform().scale;
+    const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta);
 
-template void neon_softmax_quantized<qasymm8_t, true>(
-    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+    Iterator in_it(in, window);
+    Iterator out_it(out, window);
 
-template void neon_softmax_quantized<qasymm8_t, false>(
-    const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+    /** SIMD vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+    constexpr int      vec_size        = 16;
+    const ITensorInfo *in_info         = in->info();
+    const ITensorInfo *out_info        = out->info();
+    const int          x_width         = in_info->valid_region().shape.x();
+    const int          in_axis_stride  = in_info->strides_in_bytes()[axis];
+    const int          out_axis_stride = out_info->strides_in_bytes()[axis];
+    const int          tmp_axis_stride = in_axis_stride;
+    const int          axis_width      = in_info->dimension(axis);
+    const int          end_actual      = std::min(window[0].end(), x_width);
+
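+    // The window advances 16 elements at a time along X. For each such block the kernel walks the softmax axis
+    // through the byte strides above: max, then exp/sum (staged as F32 in tmp), then normalization.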
+    execute_window_loop(
+        window,
+        [&](const Coordinates &winCoords)
+        {
+            const bool vector_exceeds_bounds = ((winCoords[0] + vec_size) > end_actual);
+
+            int num_remaining         = (end_actual - winCoords[0]);
+            int num_remaining_full    = num_remaining / 4;
+            int num_remaining_partial = num_remaining % 4;
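+            // Tail handling: the remaining X elements are processed lane by lane, in groups of four plus a final partial group.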
+
+            /* Get pointers */
+            const uint8_t *in_ptr  = in_it.ptr();
+            uint8_t       *out_ptr = out_it.ptr();
+            uint8_t       *tmp_ptr = reinterpret_cast<uint8_t *>(tmp);
+
+            auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+            /* Compute Max */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const auto current_value =
+                            wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+                        vec_max = wrapper::vmax(vec_max, current_value);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in = ((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+                        int            j           = 0;
+                        for (; j < num_remaining; ++j)
+                        {
+                            const T current_value = *(base_ptr_in + j);
+                            vec_max[j]            = std::max(vec_max[j], current_value);
+                        }
+                    }
+                }
+            } // Compute Max
+
+            float32x4x4_t vec_sum_transformed = {
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+                vdupq_n_f32(0.f),
+            };
+
+            /* Compute exponentials and sum */
+            {
+                /* Init sum to zero */
+                float32x4x4_t vec_sum = vec_sum_transformed;
+
+                auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+                float32x4x4_t vec_elements_flt;
+
+                if (!vector_exceeds_bounds)
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        vec_elements     = wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+                        vec_elements     = wrapper::vqsub(vec_max, vec_elements);
+                        vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+
+                        if (IS_LOG)
+                        {
+                            vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+                            vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+                            vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+                            vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+                            vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+                            vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+                            vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+                            vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
+                        }
+                        else
+                        {
+                            vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+                            vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+                            vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+                            vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+                            vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+                            vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+                            vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+                            vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
+                        }
+                        vst4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr), vec_elements_flt);
+                    }
+
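+                    // The quantized softmax output uses a fixed 1/256 scale, so the normalization factor is 256/sum rather than 1/sum.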
+                    auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256.f), ExactTagType{});
+                    if (!IS_LOG)
+                    {
+                        vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]);
+                        vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]);
+                        vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]);
+                        vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]);
+                    }
+                    else
+                    {
+                        vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]);
+                        vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]);
+                        vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]);
+                        vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in  = (i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr);
+                        auto           vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+                        // vec_els is functionally redundant but is needed as a workaround for a toolchain bug.
+                        std::vector<T> vec_els(16);
+
+                        for (int k = 0; k < num_remaining_full; ++k)
+                        {
+                            for (int j = 0; j < 4; ++j)
+                            {
+                                vec_els[k * 4 + j] = *(base_ptr_in + (4 * k + j));
+                            }
+                        }
+                        for (int j = 0; j < num_remaining_partial; ++j)
+                        {
+                            vec_els[num_remaining_full * 4 + j] = *(base_ptr_in + (4 * num_remaining_full + j));
+                        }
+                        for (int q = 0; q < 16; q++)
+                        {
+                            vec_elements[q] = vec_els[q];
+                        }
+                        vec_elements                   = wrapper::vqsub(vec_max, vec_elements);
+                        float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+
+                        if (IS_LOG)
+                        {
+                            vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+                            vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+                            vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+                            vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+                            vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+                            vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+                            vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+                            vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
+                        }
+                        else
+                        {
+                            vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+                            vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+                            vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+                            vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+                            vec_sum.val[0]          = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+                            vec_sum.val[1]          = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+                            vec_sum.val[2]          = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+                            vec_sum.val[3]          = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
+                        }
+
+                        float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr);
+                        for (int k = 0; k < num_remaining_full; ++k)
+                        {
+                            for (int j = 0; j < 4; ++j)
+                            {
+                                *(base_ptr_tmp + (4 * k + j)) = vec_elements_flt.val[k][j];
+                            }
+                        }
+
+                        for (int j = 0; j < num_remaining_partial; ++j)
+                        {
+                            *(base_ptr_tmp + (4 * num_remaining_full + j)) =
+                                vec_elements_flt.val[num_remaining_full][j];
+                        }
+                    }
+
+                    auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256), ExactTagType{});
+                    if (!IS_LOG)
+                    {
+                        vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]);
+                        vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]);
+                        vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]);
+                        vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]);
+                    }
+                    else
+                    {
+                        vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]);
+                        vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]);
+                        vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]);
+                        vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]);
+                    }
+                }
+            } // Compute exponentials and sum
+
+            /* Normalize exponentials */
+            {
+                constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
+                if (!vector_exceeds_bounds)
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        using int_vec_type   = wrapper::traits::neon_vector_t<T, 16>;
+                        float32x4x4_t vec_in = vld4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr));
+
+                        int_vec_type normalized_value{};
+
+                        if (IS_LOG)
+                        {
+                            const float32x4x4_t sub = {
+                                vsubq_f32(vec_in.val[0], vec_sum_transformed.val[0]),
+                                vsubq_f32(vec_in.val[1], vec_sum_transformed.val[1]),
+                                vsubq_f32(vec_in.val[2], vec_sum_transformed.val[2]),
+                                vsubq_f32(vec_in.val[3], vec_sum_transformed.val[3]),
+                            };
+                            normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
+                        }
+                        else
+                        {
+                            float32x4x4_t mul = {
+                                vmulq_f32(vec_in.val[0], vec_sum_transformed.val[0]),
+                                vmulq_f32(vec_in.val[1], vec_sum_transformed.val[1]),
+                                vmulq_f32(vec_in.val[2], vec_sum_transformed.val[2]),
+                                vmulq_f32(vec_in.val[3], vec_sum_transformed.val[3]),
+                            };
+
+                            if (is_qasymm8_signed)
+                            {
+                                const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
+                                mul.val[0]            = wrapper::vsub(mul.val[0], offset_vec);
+                                mul.val[1]            = wrapper::vsub(mul.val[1], offset_vec);
+                                mul.val[2]            = wrapper::vsub(mul.val[2], offset_vec);
+                                mul.val[3]            = wrapper::vsub(mul.val[3], offset_vec);
+                            }
+
+                            normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
+                        }
+                        wrapper::vstore((i * out_axis_stride) + reinterpret_cast<T *>(out_ptr), normalized_value);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        T *const     base_ptr_out = (i * out_axis_stride) + reinterpret_cast<T *>(out_ptr);
+                        float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr);
+                        if (IS_LOG)
+                        {
+                            for (int k = 0; k < num_remaining_full; ++k)
+                            {
+                                for (int j = 0; j < 4; ++j)
+                                {
+                                    *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>(
+                                        (*(base_ptr_tmp + (4 * k + j)) - vec_sum_transformed.val[k][j]));
+                                }
+                            }
+                            for (int j = 0; j < num_remaining_partial; ++j)
+                            {
+                                *(base_ptr_out + (4 * num_remaining_full + j)) =
+                                    utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) -
+                                                                  vec_sum_transformed.val[num_remaining_full][j]);
+                            }
+                        }
+                        else
+                        {
+                            for (int k = 0; k < num_remaining_full; ++k)
+                            {
+                                for (int j = 0; j < 4; ++j)
+                                {
+                                    *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>(
+                                        *(base_ptr_tmp + (4 * k + j)) * vec_sum_transformed.val[k][j] -
+                                        (is_qasymm8_signed ? 128.f : 0));
+                                }
+                            }
+                            for (int j = 0; j < num_remaining_partial; ++j)
+                            {
+                                *(base_ptr_out + (4 * num_remaining_full + j)) =
+                                    utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) *
+                                                                      vec_sum_transformed.val[num_remaining_full][j] -
+                                                                  (is_qasymm8_signed ? 128.f : 0));
+                            }
+                        }
+                    }
+                }
+            } // Normalize exponentials
+        },
+        in_it, out_it);
+}
+
+template void neon_softmax_x_quantized<qasymm8_signed_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_signed_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_signed_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_signed_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_t, true>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_t, false>(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.h b/src/cpu/kernels/softmax/generic/neon/impl.h
index 60380cd..e417271 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.h
+++ b/src/cpu/kernels/softmax/generic/neon/impl.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -62,8 +62,9 @@
 // The template implementation for float data types is stored in the header file because
 // we need all fp16 instantiated code to live in fp16.cpp files.
 template <typename T, bool IS_LOG>
-void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_float(const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
 {
+    ARM_COMPUTE_UNUSED(axis);
     ARM_COMPUTE_UNUSED(tmp);
 
     const int input_width = in->info()->valid_region().shape.x();
@@ -228,9 +229,199 @@
         },
         in_it, out_it);
 }
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_float(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(tmp);
+
+    Iterator in_it(in, window);
+    Iterator out_it(out, window);
+
+    /** SIMD vector tag type. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+    const auto         beta_vec        = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
+    constexpr int      vec_size        = 16 / sizeof(T);
+    const ITensorInfo *in_info         = in->info();
+    const ITensorInfo *out_info        = out->info();
+    const int          x_width         = in_info->valid_region().shape.x();
+    const unsigned int in_axis_stride  = in_info->strides_in_bytes()[axis];
+    const unsigned int out_axis_stride = out_info->strides_in_bytes()[axis];
+    const int          axis_width      = in_info->dimension(axis);
+
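+    // Each window position covers vec_size contiguous elements along X; the softmax axis is traversed via the
+    // strides above, so softmax on higher axes no longer needs an input/output permute.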
+    execute_window_loop(
+        window,
+        [&](const Coordinates &winCoords)
+        {
+            const bool vector_exceeds_bounds = (winCoords[0] + vec_size) > x_width;
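+            // When fewer than vec_size elements remain along X, the scalar tail paths below are taken instead of full vector loads/stores.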
+
+            /* Get pointers */
+            const uint8_t *in_ptr  = in_it.ptr();
+            uint8_t       *out_ptr = out_it.ptr();
+
+            // Init max value
+            auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+            /* Compute Max */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const auto current_value =
+                            wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+                        vec_max = wrapper::vmax(vec_max, current_value);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+                        int            j           = 0;
+                        for (; j < (x_width - winCoords[0]); ++j)
+                        {
+                            const auto current_value = *(base_ptr_in + j);
+                            vec_max[j]               = std::max(vec_max[j], current_value);
+                        }
+                    }
+                }
+            } // compute max
+
+            auto vec_sum_transformed = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+            auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+            /* Init sum to zero */
+            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+            /* Compute exponentials and sum */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    const auto vec_one = wrapper::vdup_n(static_cast<T>(1), ExactTagType{});
+                    /* Loop over row and compute exponentials and sum */
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        vec_elements = wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+                        vec_elements = wrapper::vsub(vec_elements, vec_max);
+                        if (IS_LOG)
+                        {
+                            vec_elements = wrapper::vmul(vec_elements, beta_vec);
+                            vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
+                        }
+                        else
+                        {
+                            vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
+                            vec_sum      = wrapper::vadd(vec_sum, vec_elements);
+                        }
+
+                        wrapper::vstore(reinterpret_cast<T *>((i * out_axis_stride) + out_ptr), vec_elements);
+                    }
+
+                    if (!IS_LOG)
+                    {
+                        vec_sum_transformed = wrapper::vdiv(vec_one, vec_sum);
+                    }
+                    else
+                    {
+                        vec_sum_transformed = wrapper::vlog(vec_sum);
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in  = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+                        T *const       base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        int            j            = 0;
+                        for (; j < (x_width - winCoords[0]); ++j)
+                        {
+                            vec_elements[j] = *(base_ptr_in + j);
+                            vec_elements[j] -= vec_max[j];
+                            if (IS_LOG)
+                            {
+                                vec_elements[j] *= beta;
+                                vec_sum[j] += std::exp(vec_elements[j]);
+                            }
+                            else
+                            {
+                                vec_elements[j] = std::exp(vec_elements[j] * beta);
+                                vec_sum[j] += vec_elements[j];
+                            }
+                            *(base_ptr_out + j) = vec_elements[j];
+                        }
+                    }
+                    int j = 0;
+                    for (; j < (x_width - winCoords[0]); ++j)
+                    {
+                        if (!IS_LOG)
+                        {
+                            vec_sum_transformed[j] = 1 / vec_sum[j];
+                        }
+                        else
+                        {
+                            vec_sum_transformed[j] = std::log(vec_sum[j]);
+                        }
+                    }
+                }
+            } // Compute exponentials and sum
+
+            /* Normalize exponentials */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    /* Loop over row and compute softmax */
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        auto     vec_in       = wrapper::vloadq(base_ptr_out);
+                        if (IS_LOG)
+                        {
+                            wrapper::vstore(base_ptr_out, wrapper::vsub(vec_in, vec_sum_transformed));
+                        }
+                        else
+                        {
+                            wrapper::vstore(base_ptr_out, wrapper::vmul(vec_in, vec_sum_transformed));
+                        }
+                    }
+                }
+                else
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        int      j            = 0;
+                        for (; j < (x_width - winCoords[0]); ++j)
+                        {
+                            if (IS_LOG)
+                            {
+                                *(base_ptr_out + j) -= vec_sum_transformed[j];
+                            }
+                            else
+                            {
+                                *(base_ptr_out + j) *= vec_sum_transformed[j];
+                            }
+                        }
+                    }
+                }
+            } // Normalize exponentials
+        },
+        in_it, out_it);
+}
+template <typename T, bool IS_LOG>
+void neon_softmax_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
 
 template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+void neon_softmax_non_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
 } // namespace cpu
 } // namespace arm_compute
 
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
index 9589ebc..d39240b 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,15 +30,23 @@
 namespace cpu
 {
 template <bool IS_LOG>
-void neon_qasymm8_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_qasymm8_softmax(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 {
-    return neon_softmax_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, window);
+    if (axis == 0)
+    {
+        return neon_softmax_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
+    else
+    {
+        return neon_softmax_non_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
 }
 
-template void
-neon_qasymm8_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_qasymm8_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_qasymm8_softmax<true>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_qasymm8_softmax<false>(
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
index 0bf6b28..26fd5db 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -31,15 +31,22 @@
 {
 template <bool IS_LOG>
 void neon_qasymm8_signed_softmax(
-    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 {
-    return neon_softmax_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, window);
+    if (axis == 0)
+    {
+        return neon_softmax_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
+    else
+    {
+        return neon_softmax_non_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
+    }
 }
 
 template void neon_qasymm8_signed_softmax<true>(
-    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 template void neon_qasymm8_signed_softmax<false>(
-    const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+    const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
 
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h
index c143f66..f9295eb 100644
--- a/src/cpu/kernels/softmax/list.h
+++ b/src/cpu/kernels/softmax/list.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,7 +30,7 @@
 {
 #define DECLARE_SOFTMAX_KERNEL(func_name) \
     template <bool IS_LOG>                \
-    void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+    void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
 
 DECLARE_SOFTMAX_KERNEL(neon_fp32_softmax);
 DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax);
diff --git a/src/cpu/operators/CpuSoftmax.cpp b/src/cpu/operators/CpuSoftmax.cpp
index ae14381..fecee7d 100644
--- a/src/cpu/operators/CpuSoftmax.cpp
+++ b/src/cpu/operators/CpuSoftmax.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -41,15 +41,7 @@
 {
 namespace cpu
 {
-CpuSoftmaxGeneric::CpuSoftmaxGeneric()
-    : _permute_input(),
-      _permute_output(),
-      _softmax_kernel(),
-      _tmp(),
-      _input_permuted(),
-      _output_permuted(),
-      _needs_permute(false),
-      _aux_mem(InternalTensorIdx::COUNT)
+CpuSoftmaxGeneric::CpuSoftmaxGeneric() : _softmax_kernel(), _tmp(), _aux_mem(InternalTensorIdx::COUNT)
 {
 }
 
@@ -63,17 +55,9 @@
     const unsigned int actual_axis =
         static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(src->num_dimensions())));
 
-    _needs_permute = actual_axis > 0;
+    _axis = actual_axis;
 
-    if (_needs_permute)
-    {
-        _permute_input.configure(src, &_input_permuted,
-                                 softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
-    }
-
-    // We want to deal with a 2D input. Either it is the permuted version of the original input (4D case)
-    // or it is the original input case (2D case)
-    const ITensorInfo *tmp_input = (_needs_permute ? &_input_permuted : src);
+    const ITensorInfo *tmp_input = src;
 
     TensorInfo tensor_info_tmp;
     if (is_data_type_quantized_asymmetric(src->data_type()))
@@ -88,20 +72,10 @@
 
     // Configure kernels
     auto sm = std::make_unique<kernels::CpuSoftmaxKernel>();
-    if (_needs_permute)
-    {
-        // The normalization kernel stores the result in a permuted output tensor
-        sm->configure(tmp_input, &_output_permuted, beta, is_log, &_tmp);
 
-        // Re-permute the permuted output into the requested (4D) output
-        _permute_output.configure(&_output_permuted, dst,
-                                  softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
-    }
-    else
-    {
-        // Softmax 2D case
-        sm->configure(tmp_input, dst, beta, is_log, &_tmp);
-    }
+    // Softmax 2D case
+    sm->configure(tmp_input, dst, beta, is_log, actual_axis, &_tmp);
+
     _softmax_kernel = std::move(sm);
 
     if (_tmp.total_size() > 0)
@@ -109,11 +83,6 @@
         _aux_mem[InternalTensorIdx::TMP] =
             MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp.total_size());
     }
-
-    _aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC),
-                                                           MemoryLifetime::Temporary, _input_permuted.total_size());
-    _aux_mem[InternalTensorIdx::PERMUTED_DST] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_DST),
-                                                           MemoryLifetime::Temporary, _output_permuted.total_size());
 }
 
 Status
@@ -133,25 +102,11 @@
     {
         tensor_info_tmp = src->clone()->set_data_type(DataType::F32).set_is_resizable(true);
     }
-
     const unsigned int actual_axis =
         static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(src->num_dimensions())));
 
-    const bool needs_permute = actual_axis > 0;
-
-    if (needs_permute)
-    {
-        const PermutationVector permutation_vector =
-            softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
-        const TensorShape permuted_shape =
-            misc::shape_calculator::compute_permutation_output_shape(*src, permutation_vector);
-        TensorInfo input_permuted(src->clone()->set_tensor_shape(permuted_shape));
-        ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(src, &input_permuted, permutation_vector));
-        TensorInfo output_permuted(dst->clone()->set_tensor_shape(permuted_shape));
-        ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(&output_permuted, dst, permutation_vector));
-    }
-
-    ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuSoftmaxKernel::validate(src, dst, beta, is_log, &tensor_info_tmp));
+    ARM_COMPUTE_RETURN_ON_ERROR(
+        kernels::CpuSoftmaxKernel::validate(src, dst, beta, actual_axis, is_log, &tensor_info_tmp));
 
     return Status{};
 }
@@ -165,34 +120,17 @@
 
     CpuAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp, tensors, true);
 
-    CpuAuxTensorHandler input_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _input_permuted, tensors, true);
-    CpuAuxTensorHandler output_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _output_permuted, tensors,
-                                        true);
-
     ITensorPack softmax_pack;
 
-    if (_needs_permute)
-    {
-        ITensorPack permute_in_pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, input_permuted.get()}};
-        _permute_input.run(permute_in_pack);
+    softmax_pack = {{TensorType::ACL_SRC_0, src}, {TensorType::ACL_DST_0, dst}, {TensorType::ACL_DST_1, tmp.get()}};
 
-        softmax_pack = {{TensorType::ACL_SRC_0, input_permuted.get()},
-                        {TensorType::ACL_DST_0, output_permuted.get()},
-                        {TensorType::ACL_DST_1, tmp.get()}};
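+    // The axis-0 kernel reduces along X, so the scheduler splits work over DimY; for higher axes the X dimension is free and is split instead.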
+    if (_axis == 0)
+    {
+        NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
     }
     else
     {
-        softmax_pack = {{TensorType::ACL_SRC_0, src}, {TensorType::ACL_DST_0, dst}, {TensorType::ACL_DST_1, tmp.get()}};
-    }
-
-    NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
-
-    if (_needs_permute)
-    {
-        ITensorPack permute_out_pack;
-        permute_out_pack.add_tensor(TensorType::ACL_SRC, output_permuted.get());
-        permute_out_pack.add_tensor(TensorType::ACL_DST, dst);
-        _permute_output.run(permute_out_pack);
+        NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimX, _softmax_kernel->window(), softmax_pack);
     }
 }
 
diff --git a/src/cpu/operators/CpuSoftmax.h b/src/cpu/operators/CpuSoftmax.h
index 47020e9..6ba3476 100644
--- a/src/cpu/operators/CpuSoftmax.h
+++ b/src/cpu/operators/CpuSoftmax.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -89,16 +89,13 @@
         COUNT
     };
 
-    CpuPermute                  _permute_input;
-    CpuPermute                  _permute_output;
     std::unique_ptr<ICPPKernel> _softmax_kernel;
 
     TensorInfo _tmp;
-    TensorInfo _input_permuted;
-    TensorInfo _output_permuted;
 
-    bool                             _needs_permute;
     experimental::MemoryRequirements _aux_mem{};
+
+    unsigned int _axis = 0;
 };
 
 } // namespace cpu