IVGCVSW-631 Neon support for Softmax beta parameter (F16/F32 only; fixed-point requires beta == 1)

Change-Id: Ibf6f038b39f1a4e557f5d04feb08e3d5ef54e223
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112019
Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index f102759..a8a0f59 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -251,8 +251,10 @@
 
 namespace
 {
-void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
+void logits_1d_shift_exp_sum_qs8(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
 {
+    ARM_COMPUTE_UNUSED(beta);
+
     Window window_max(window);
     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
 
@@ -313,8 +315,10 @@
     }
     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
 }
-void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
+void logits_1d_shift_exp_sum_qs16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
 {
+    ARM_COMPUTE_UNUSED(beta);
+
     Window window_max(window);
     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
 
@@ -375,7 +379,7 @@
 }
 
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
+void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
 {
     Window window_max(window);
     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
@@ -410,6 +414,7 @@
         {
             float16x8_t vec_elements = vld1q_f16(in_ptr);
             vec_elements             = vsubq_f16(vec_elements, vec_max);
+            vec_elements             = vmulq_n_f16(vec_elements, beta);
             vec_elements             = vexpq_f16(vec_elements);
 
             vst1q_f16(exp_ptr, vec_elements);
@@ -426,7 +431,7 @@
         // Run remaining elements
         for(int i = 0; i < small_steps; ++i)
         {
-            const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr));
+            const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr) * beta);
             exp_ptr[i]              = element;
             sum += element;
         }
@@ -436,7 +441,7 @@
 }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
-void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
+void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta)
 {
     Window window_max(window);
     window_max.set(Window::DimX, Window::Dimension(0, 0, 0));
@@ -471,6 +476,7 @@
         {
             float32x4_t vec_elements = vld1q_f32(in_ptr);
             vec_elements             = vsubq_f32(vec_elements, vec_max);
+            vec_elements             = vmulq_n_f32(vec_elements, beta);
             vec_elements             = vexpq_f32(vec_elements);
 
             vst1q_f32(exp_ptr, vec_elements);
@@ -488,7 +494,7 @@
         // Run remaining elements
         for(int i = 0; i < small_steps; ++i)
         {
-            float element = std::exp(in_ptr[i] - *max_ptr);
+            float element = std::exp((in_ptr[i] - *max_ptr) * beta);
             exp_ptr[i]    = element;
             sum += element;
         }
@@ -500,14 +506,15 @@
 } //namespace
 
 NELogits1DShiftExpSumKernel::NELogits1DShiftExpSumKernel()
-    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr)
+    : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _sum(nullptr), _beta(1.0f)
 {
 }
 
-void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum)
+void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
+    ARM_COMPUTE_ERROR_ON((beta != 1.0f) && is_data_type_fixed_point(input->info()->data_type()));
 
     // Output auto initialization if not yet initialized
     auto_init_if_empty(*sum->info(), max->info()->tensor_shape(), 1, input->info()->data_type(), input->info()->fixed_point_position());
@@ -545,6 +552,7 @@
     _max    = max;
     _output = output;
     _sum    = sum;
+    _beta   = beta;
 
     // Configure kernel window
     Window                 win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));
@@ -568,7 +576,7 @@
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
     ARM_COMPUTE_ERROR_ON(_func == nullptr);
 
-    (*_func)(_input, _max, _output, _sum, window);
+    (*_func)(_input, _max, _output, _sum, window, _beta);
 }
 
 namespace