Optimize CpuSoftmaxKernel for axis != 0 and neon kernels

Resolves: COMPMID-6501
Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I0abd3cbb5f861301f407c443988fb7efaa205b5d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11056
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.h b/src/cpu/kernels/softmax/generic/neon/impl.h
index 60380cd..e417271 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.h
+++ b/src/cpu/kernels/softmax/generic/neon/impl.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -62,8 +62,9 @@
 // The template implementation for float data types is stored in the header file because
 // we need all fp16 instantiated code to live in fp16.cpp files.
 template <typename T, bool IS_LOG>
-void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_float(const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
 {
+    ARM_COMPUTE_UNUSED(axis);
     ARM_COMPUTE_UNUSED(tmp);
 
     const int input_width = in->info()->valid_region().shape.x();
@@ -228,9 +229,199 @@
         },
         in_it, out_it);
 }
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_float(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(tmp); // tmp is never read or written in this function
+    // Softmax (log-softmax when IS_LOG) reduced along dimension `axis` (presumably axis != 0, per the name); each lambda call handles up to vec_size consecutive x elements.
+    Iterator in_it(in, window);
+    Iterator out_it(out, window);
+
+    /** Tag selecting the 128-bit NEON vector type for element type T. */
+    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+    const auto         beta_vec        = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
+    constexpr int      vec_size        = 16 / sizeof(T); // lanes in one 16-byte vector
+    const ITensorInfo *in_info         = in->info();
+    const ITensorInfo *out_info        = out->info();
+    const int          x_width         = in_info->valid_region().shape.x(); // x extent of the input's valid region
+    const unsigned int in_axis_stride  = in_info->strides_in_bytes()[axis]; // input byte stride for dimension `axis`
+    const unsigned int out_axis_stride = out_info->strides_in_bytes()[axis]; // output byte stride for dimension `axis`
+    const int          axis_width      = in_info->dimension(axis); // size of dimension `axis`; trip count of every axis loop below
+
+    execute_window_loop(
+        window,
+        [&](const Coordinates &winCoords)
+        {
+            const bool vector_exceeds_bounds = (winCoords[0] + vec_size) > x_width; // a full vector starting at this x would pass x_width
+
+            /* Current input/output positions as byte pointers; axis offsets are added in bytes */
+            const uint8_t *in_ptr  = in_it.ptr();
+            uint8_t       *out_ptr = out_it.ptr();
+
+            // Seed every lane of the running maximum with the lowest value of T
+            auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+            /* Pass 1: per-lane maximum over the axis (subtracted before exponentiation below) */
+            {
+                if (!vector_exceeds_bounds) // full-vector path
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const auto current_value =
+                            wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+                        vec_max = wrapper::vmax(vec_max, current_value);
+                    }
+                }
+                else // tail path: scalar, touching only the lanes that lie inside x_width
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+                        int            j           = 0;
+                        for (; j < (x_width - winCoords[0]); ++j) // j runs over the in-bounds lanes only
+                        {
+                            const auto current_value = *(base_ptr_in + j);
+                            vec_max[j]               = std::max(vec_max[j], current_value);
+                        }
+                    }
+                }
+            } // compute max
+
+            auto vec_sum_transformed = wrapper::vdup_n(static_cast<T>(0), ExactTagType{}); // set below to 1/sum, or log(sum) when IS_LOG
+
+            auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{}); // scratch for the values at the current axis index
+            /* Init sum to zero */
+            auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+            /* Pass 2: store intermediates to out and accumulate the per-lane sum of exponentials */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    const auto vec_one = wrapper::vdup_n(static_cast<T>(1), ExactTagType{});
+                    /* Walk the axis: subtract max, scale by beta, exponentiate and accumulate the sum */
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        vec_elements = wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+                        vec_elements = wrapper::vsub(vec_elements, vec_max);
+                        if (IS_LOG) // log variant stores (x - max) * beta and sums its exponential
+                        {
+                            vec_elements = wrapper::vmul(vec_elements, beta_vec);
+                            vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
+                        }
+                        else // plain variant stores exp((x - max) * beta) and sums that same value
+                        {
+                            vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
+                            vec_sum      = wrapper::vadd(vec_sum, vec_elements);
+                        }
+
+                        wrapper::vstore(reinterpret_cast<T *>((i * out_axis_stride) + out_ptr), vec_elements);
+                    }
+
+                    if (!IS_LOG)
+                    {
+                        vec_sum_transformed = wrapper::vdiv(vec_one, vec_sum);
+                    }
+                    else
+                    {
+                        vec_sum_transformed = wrapper::vlog(vec_sum);
+                    }
+                }
+                else // tail path: same arithmetic per lane, using scalar std::exp / std::log
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        const T *const base_ptr_in  = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+                        T *const       base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        int            j            = 0;
+                        for (; j < (x_width - winCoords[0]); ++j)
+                        {
+                            vec_elements[j] = *(base_ptr_in + j);
+                            vec_elements[j] -= vec_max[j];
+                            if (IS_LOG)
+                            {
+                                vec_elements[j] *= beta;
+                                vec_sum[j] += std::exp(vec_elements[j]);
+                            }
+                            else
+                            {
+                                vec_elements[j] = std::exp(vec_elements[j] * beta);
+                                vec_sum[j] += vec_elements[j];
+                            }
+                            *(base_ptr_out + j) = vec_elements[j];
+                        }
+                    }
+                    int j = 0;
+                    for (; j < (x_width - winCoords[0]); ++j)
+                    {
+                        if (!IS_LOG)
+                        {
+                            vec_sum_transformed[j] = 1 / vec_sum[j];
+                        }
+                        else
+                        {
+                            vec_sum_transformed[j] = std::log(vec_sum[j]);
+                        }
+                    }
+                }
+            } // Compute exponentials and sum
+
+            /* Pass 3: re-read the intermediates written to out and normalize them in place */
+            {
+                if (!vector_exceeds_bounds)
+                {
+                    /* Walk the axis: subtract log(sum) when IS_LOG, otherwise multiply by 1/sum */
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        auto     vec_in       = wrapper::vloadq(base_ptr_out);
+                        if (IS_LOG)
+                        {
+                            wrapper::vstore(base_ptr_out, wrapper::vsub(vec_in, vec_sum_transformed));
+                        }
+                        else
+                        {
+                            wrapper::vstore(base_ptr_out, wrapper::vmul(vec_in, vec_sum_transformed));
+                        }
+                    }
+                }
+                else // tail path: scalar update of the in-bounds lanes only
+                {
+                    int i = 0;
+                    for (; i < axis_width; ++i)
+                    {
+                        T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+                        int      j            = 0;
+                        for (; j < (x_width - winCoords[0]); ++j)
+                        {
+                            if (IS_LOG)
+                            {
+                                *(base_ptr_out + j) -= vec_sum_transformed[j];
+                            }
+                            else
+                            {
+                                *(base_ptr_out + j) *= vec_sum_transformed[j];
+                            }
+                        }
+                    }
+                }
+            } // Normalize exponentials
+        },
+        in_it, out_it);
+}
+template <typename T, bool IS_LOG>
+void neon_softmax_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); // declaration only; the body is not part of this diff
 
 template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+void neon_softmax_non_x_quantized(
+    const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window); // declaration only; the body is not part of this diff
 } // namespace cpu
 } // namespace arm_compute