Optimize CpuSoftmaxKernel for axis != 0 and Neon kernels

Resolves: COMPMID-6501
Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com>
Change-Id: I0abd3cbb5f861301f407c443988fb7efaa205b5d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11056
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 68bc397..54ff858 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -81,7 +81,7 @@
 };
 
 Status validate_arguments_softmax(
-    const ITensorInfo &src, const ITensorInfo &dst, float beta, const ITensorInfo &tmp, bool is_log)
+    const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log)
 {
     ARM_COMPUTE_UNUSED(beta);
     // Check input
@@ -89,6 +89,8 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                          DataType::F16, DataType::F32);
 
+    ARM_COMPUTE_RETURN_ERROR_ON(axis < 0 || axis > 3);
+
     const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
 
     // Check output if configured
@@ -124,10 +126,13 @@
     return available_kernels;
 }
 
-void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp)
+void CpuSoftmaxKernel::configure(
+    const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp)
 {
+    _axis = axis;
+
     ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
 
     // Configure kernel window
     const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
@@ -154,25 +159,40 @@
     _run_method = uk->ukernel;
     _name       = kernel_name.append("/").append(uk->name);
 
-    Window win = calculate_max_window(*dst, Steps());
+    Window win;
 
-    /// TODO: Check dimensions > 0 for holes only. For this, we need
-    /// a utility function checking if there are holes after some dimension.
-    if (!has_holes(*dst, dst->num_dimensions() - 1))
+    int vec_size = 16 / dst->element_size();
+
+    if (_axis == 0)
     {
-        win = win.collapse(win, Window::DimY);
+        win = calculate_max_window(*dst, Steps());
+
+        /// TODO: Check dimensions > 0 for holes only. For this, we need
+        /// a utility function checking if there are holes after some dimension.
+        if (!has_holes(*dst, dst->num_dimensions() - 1))
+        {
+            win = win.collapse(win, Window::DimY);
+        }
+    }
+    else if (_axis > 0 && _axis <= 3)
+    {
+        win = calculate_max_window(*dst, Steps(vec_size));
+    }
+    else
+    {
+        ARM_COMPUTE_ERROR("Invalid axis");
     }
 
-    win.set(Window::DimX, Window::Dimension(0, 1, 1)); // First dimension is the reduction axis
+    win.set(_axis, Window::Dimension(0, 1, 1));
 
     ICpuKernel<CpuSoftmaxKernel>::configure(win);
 }
 
 Status CpuSoftmaxKernel::validate(
-    const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp)
+    const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
 
     return Status{};
 }
@@ -188,19 +208,25 @@
 
     if (is_data_type_quantized_asymmetric(src->info()->data_type()))
     {
-        auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
-
-        const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
+        auto         tmp = tensors.get_tensor(TensorType::ACL_DST_1);
+        unsigned int num_elems_processed_per_iteration;
+        if (_axis == 0)
+        {
+            num_elems_processed_per_iteration = src->info()->valid_region().shape[_axis];
+        }
+        else
+        {
+            // 16 QASYMM8/QASYMM8_SIGNED elements fit into a 16-byte vector.
+            num_elems_processed_per_iteration = 16;
+        }
         const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
 
-        ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
-
         void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
-        _run_method(src, tmp_for_thread, dst, _beta, window);
+        _run_method(src, tmp_for_thread, dst, _beta, _axis, window);
     }
     else
     {
-        _run_method(src, nullptr, dst, _beta, window);
+        _run_method(src, nullptr, dst, _beta, _axis, window);
     }
 }