Optimize CpuSoftmaxKernel for axis=0

Implement a single kernel instead of two consecutive ones. In the previous setup, one kernel calculated the maximum value along the axis, and a second kernel subtracted this maximum from each element while calculating the softmax, i.e.

softmax(x_i) = exp(x_i - max) / sum_j( exp(x_j - max) ), where max = max_j( x_j )

This patch fuses these two stages into a single Neon™ kernel for all data types. This saves memory because the max values no longer need to be held in a separate auxiliary tensor.

It also introduces other optimizations that reduce memory pressure for float/half data types, by using the dst tensor as temporary storage for intermediate values (the exponentiated inputs for softmax, or the shifted and scaled inputs for log-softmax).

It removes the references to the SVE and SVE2 implementations and most of the associated files, but keeps the implementations themselves, as they may be used in the future.

Resolves: COMPMID-6500

Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Change-Id: Icff9976d1214c4c6cbe15a62ca60b8a77d3784cc
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10688
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.h b/src/cpu/kernels/softmax/generic/neon/impl.h
index 4d9b789..60380cd 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.h
+++ b/src/cpu/kernels/softmax/generic/neon/impl.h
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef SRC_CORE_NEON_KERNELS_SOFTMAX_IMPL_H
-#define SRC_CORE_NEON_KERNELS_SOFTMAX_IMPL_H
+#ifndef ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H
+#define ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H
 
 #include "arm_compute/core/Helpers.h"
 
@@ -33,105 +33,100 @@
 {
 namespace cpu
 {
-template <typename T>
-void neon_logits_1d_max(const ITensor *in, ITensor *out, const Window &window)
+
+#ifdef __aarch64__
+namespace softmax_detail // named, not anonymous: header-defined templates below use these helpers
 {
-    /** SIMD vector tag type. */
-    using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
-
-    constexpr int window_step_x  = 16 / sizeof(T);
-    const auto    window_start_x = static_cast<int>(window.x().start());
-    const auto    window_end_x   = static_cast<int>(window.x().end());
-
-    Window win{window};
-    win.set(Window::DimX, Window::Dimension(0, 1, 1));
-    Iterator input(in, win);
-    Iterator output(out, win);
-
-    const int sum_stages = log2(window_step_x / 2);
-    execute_window_loop(
-        win,
-        [&](const Coordinates &)
-        {
-            // Get pointers
-            const auto in_ptr  = reinterpret_cast<const T *>(input.ptr());
-            const auto out_ptr = reinterpret_cast<T *>(output.ptr());
-
-            // Init max value
-            auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
-            int  x       = window_start_x;
-
-            for (; x <= (window_end_x - window_step_x); x += window_step_x)
-            {
-                const auto current_value = wrapper::vloadq(in_ptr + x);
-                vec_max                  = wrapper::vmax(vec_max, current_value);
-            }
-            auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
-
-            for (int i = 0; i < sum_stages; ++i)
-            {
-                carry_max = wrapper::vpmax(carry_max, carry_max);
-            }
-            T max_val = wrapper::vgetlane(carry_max, 0);
-
-            // Compute left-over elements
-            for (; x < window_end_x; ++x)
-            {
-                max_val = *(in_ptr + x) > max_val ? *(in_ptr + x) : max_val;
-            }
-
-            *out_ptr = max_val;
-        },
-        input, output);
+// Horizontal-add helpers: vaddv does not exist for fp16 and is therefore not part of
+// the wrapper::vaddv interface, so the fp16 overload reduces with pairwise adds.
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+inline float16_t wrapper_vaddv(const float16x8_t &a, int sum_stages)
+{
+    auto sum_res = wrapper::vpadd(wrapper::vgethigh(a), wrapper::vgetlow(a));
+    for (int i = 0; i < sum_stages; ++i)
+    {
+        sum_res = wrapper::vpadd(sum_res, sum_res);
+    }
+    return wrapper::vgetlane(sum_res, 0);
 }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
-template <typename T>
-void neon_softmax_logits_1d_quantized(const ITensor *in,
-                                      const ITensor *max,
-                                      void *const    tmp,
-                                      ITensor       *out,
-                                      float          beta,
-                                      bool           is_log,
-                                      const Window  &window);
-
-template <typename T>
-void neon_softmax_logits_1d_float(const ITensor *in,
-                                  const ITensor *max,
-                                  void *const    tmp,
-                                  ITensor       *out,
-                                  const float    beta,
-                                  bool           is_log,
-                                  const Window  &window)
+inline float wrapper_vaddv(const float32x4_t &a, int sum_stages)
 {
-    const int start_x     = in->info()->valid_region().anchor.x();
+    ARM_COMPUTE_UNUSED(sum_stages);
+    return wrapper::vaddv(a);
+}
+} // namespace softmax_detail
+#endif // __aarch64__
+
+// The template implementation for float data types is stored in the header file because
+// we need all fp16 instantiated code to live in fp16.cpp files.
+template <typename T, bool IS_LOG>
+void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+{
+    ARM_COMPUTE_UNUSED(tmp); // dst is used as scratch storage instead of tmp
+
     const int input_width = in->info()->valid_region().shape.x();
 
     Iterator in_it(in, window);
-    Iterator max_it(max, window);
     Iterator out_it(out, window);
 
     /** SIMD vector tag type. */
     using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
 
-    constexpr int vec_size   = 16 / sizeof(T);
-    const int     sum_stages = log2(vec_size / 2);
+    constexpr int vec_size = 16 / sizeof(T);
+
+    const int sum_stages = log2(vec_size >> 1);
+
+    const auto beta_vec = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
 
     execute_window_loop(
         window,
         [&](const Coordinates &)
         {
             /* Get pointers */
-            const auto in_ptr  = reinterpret_cast<const T *>(in_it.ptr()) + start_x;
-            const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
-            const auto tmp_ptr = reinterpret_cast<T *>(tmp);
+            const T *in_ptr  = reinterpret_cast<const T *>(in_it.ptr());
+            T       *out_ptr = reinterpret_cast<T *>(out_it.ptr());
 
-            T sum{};
-            T sum_inversed{};
+            T max_val;
+
+            /* Compute max over the input_width elements; it is subtracted from each element below */
+            {
+                // Init max value
+                auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+                int  x       = 0;
+
+                for (; x <= (input_width - vec_size); x += vec_size)
+                {
+                    const auto current_value = wrapper::vloadq(in_ptr + x);
+                    vec_max                  = wrapper::vmax(vec_max, current_value);
+                }
+
+#ifdef __aarch64__
+                max_val = wrapper::vmaxv(vec_max);
+#else  // __aarch64__
+                auto carry_max = wrapper::vpmax(wrapper::vgethigh(vec_max), wrapper::vgetlow(vec_max));
+
+                for (int i = 0; i < sum_stages; ++i)
+                {
+                    carry_max = wrapper::vpmax(carry_max, carry_max);
+                }
+
+                max_val      = wrapper::vgetlane(carry_max, 0);
+#endif // __aarch64__
+
+                // Compute left-over elements
+                for (; x < input_width; ++x)
+                {
+                    max_val = std::max(*(in_ptr + x), max_val);
+                }
+            } // compute max
+
+            T sum_transformed{}; // 1 / sum for softmax, log(sum) for log-softmax
 
             /* Compute exponentials and sum */
             {
                 /* Get max value */
-                const auto max_val = *reinterpret_cast<const T *>(max_it.ptr());
                 const auto vec_max = wrapper::vdup_n(max_val, ExactTagType{});
 
                 /* Init sum to zero */
@@ -143,35 +138,38 @@
                 {
                     auto vec_elements = wrapper::vloadq(in_ptr + x);
                     vec_elements      = wrapper::vsub(vec_elements, vec_max);
-                    if (is_log)
+                    if (IS_LOG)
                     {
-                        vec_elements =
-                            wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{}));
-                        vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
+                        vec_elements = wrapper::vmul(vec_elements, beta_vec);
+                        vec_sum      = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
                     }
                     else
                     {
-                        vec_elements = wrapper::vexpq(
-                            wrapper::vmul(vec_elements, wrapper::vdup_n(static_cast<T>(beta), ExactTagType{})));
-                        vec_sum = wrapper::vadd(vec_sum, vec_elements);
+                        vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
+                        vec_sum      = wrapper::vadd(vec_sum, vec_elements);
                     }
-                    wrapper::vstore(tmp_ptr + x, vec_elements);
+                    wrapper::vstore(out_ptr + x, vec_elements);
                 }
 
                 /* Reduce sum */
+                T sum{};
+#ifdef __aarch64__
+                sum = softmax_detail::wrapper_vaddv(vec_sum, sum_stages);
+#else  // __aarch64__
                 auto sum_res = wrapper::vpadd(wrapper::vgethigh(vec_sum), wrapper::vgetlow(vec_sum));
                 for (int i = 0; i < sum_stages; ++i)
                 {
                     sum_res = wrapper::vpadd(sum_res, sum_res);
                 }
                 sum = wrapper::vgetlane(sum_res, 0);
+#endif // __aarch64__
 
                 /* Run remaining elements */
                 for (; x < input_width; ++x)
                 {
                     T element{};
 
-                    if (is_log)
+                    if (IS_LOG)
                     {
                         element = (in_ptr[x] - max_val) * beta;
                         sum += std::exp(element);
@@ -181,55 +179,59 @@
                         element = std::exp((in_ptr[x] - max_val) * beta);
                         sum += element;
                     }
-                    tmp_ptr[x] = element;
+
+                    out_ptr[x] = element;
                 }
 
-                if (!is_log)
+                if (!IS_LOG)
                 {
-                    sum_inversed = T(1) / sum;
+                    sum_transformed = T(1) / sum;
                 }
                 else
                 {
-                    sum = static_cast<T>(std::log(sum));
+                    sum_transformed = static_cast<T>(std::log(sum));
                 }
-            }
+            } // Compute exponentials and sum
 
             /* Normalize exponentials */
             {
+                const auto sum_vec = wrapper::vdup_n(sum_transformed, ExactTagType{});
+
                 /* Loop over row and compute softmax */
                 int x = 0;
                 for (; x <= (input_width - vec_size); x += vec_size)
                 {
-                    auto vec_in           = wrapper::vloadq(tmp_ptr + x);
-                    auto normalized_value = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
-                    if (is_log)
+                    const auto vec_in = wrapper::vloadq(out_ptr + x);
+                    if (IS_LOG)
                     {
-                        normalized_value = wrapper::vsub(vec_in, wrapper::vdup_n(static_cast<T>(sum), ExactTagType{}));
+                        wrapper::vstore(out_ptr + x, wrapper::vsub(vec_in, sum_vec));
                     }
                     else
                     {
-                        normalized_value =
-                            wrapper::vmul(vec_in, wrapper::vdup_n(static_cast<T>(sum_inversed), ExactTagType{}));
+                        wrapper::vstore(out_ptr + x, wrapper::vmul(vec_in, sum_vec));
                     }
-                    wrapper::vstore(out_ptr + x, normalized_value);
                 }
+
                 /* Run remaining elements */
                 for (; x < input_width; ++x)
                 {
-                    if (is_log)
+                    if (IS_LOG)
                     {
-                        out_ptr[x] = tmp_ptr[x] - sum;
+                        out_ptr[x] = out_ptr[x] - sum_transformed;
                     }
                     else
                     {
-                        out_ptr[x] = tmp_ptr[x] * sum_inversed;
+                        out_ptr[x] = out_ptr[x] * sum_transformed;
                     }
                 }
-            }
+            } // Normalize exponentials
         },
-        in_it, max_it, out_it);
+        in_it, out_it);
 }
+
+template <typename T, bool IS_LOG>
+void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
 } // namespace cpu
 } // namespace arm_compute
 
-#endif /* SRC_CORE_NEON_KERNELS_SOFTMAX_IMPL_H */
+#endif // ACL_SRC_CPU_KERNELS_SOFTMAX_GENERIC_NEON_IMPL_H