COMPMID-2265 add support for Log Softmax to NEON

The kernel (NEON/reference), validation tests, function, and fixture are
updated to add support for Log Softmax.
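
For a row x with maximum m, the two variants are related by

    softmax(x_i)     = exp((x_i - m) * beta) / sum_j exp((x_j - m) * beta)
    log_softmax(x_i) = (x_i - m) * beta - log(sum_j exp((x_j - m) * beta))

so the log path keeps the shifted, beta-scaled values in the temporary
buffer and subtracts log(sum) in the normalization pass instead of
multiplying by the inverse sum.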

Change-Id: I641dbf1552f4128c691af8875949ebf88da71ee8
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2075
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index 4144a18..1003ebd 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -333,6 +333,19 @@
     return res;
 }
 
+float32x4x4_t vsub_n(float32x4x4_t a, float val)
+{
+    auto          scalar_vector = vdup_n<float32x4x4_t>(val);
+    float32x4x4_t res           = { {
+            vsubq_f32(a.val[0], scalar_vector.val[0]),
+            vsubq_f32(a.val[1], scalar_vector.val[1]),
+            vsubq_f32(a.val[2], scalar_vector.val[2]),
+            vsubq_f32(a.val[3], scalar_vector.val[3])
+        }
+    };
+    return res;
+}
+
 namespace
 {
 Status validate_arguments_logits_1d_max(const ITensorInfo &input, const ITensorInfo &output)
@@ -590,6 +603,7 @@
     return reduce_add_impl < elem_type_t<V>, N, 0, N - 1 >::reduce(add_fn, vec);
 }
 
+template <bool is_log>
 void logits_1d_softmax_qasymm8(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, const Window &window)
 {
     const int start_x     = in.info()->valid_region().anchor.x();
@@ -608,7 +622,8 @@
         const auto out_ptr = reinterpret_cast<qasymm8_t *>(out_it.ptr()) + start_x;
         const auto tmp_ptr = reinterpret_cast<float *>(tmp);
 
-        float sum_inversed;
+        float sum{};
+        float sum_inversed{};
 
         /* Compute exponentials and sum */
         {
@@ -622,33 +637,55 @@
             /* Loop over row and compute exponentials and sum */
             int           i        = 0;
             constexpr int vec_size = vec_size_of(vec_max);
+
             for(; i <= (input_width - vec_size); i += vec_size)
             {
                 auto vec_elements = vld<vec_16_byte_t<qasymm8_t>>(in_ptr + i);
                 vec_elements      = vsubq_u8(vec_max, vec_elements);
 
                 auto vec_elements_flt = vcvt<float32x4x4_t>(vec_elements);
-                vec_elements_flt      = vexp(vmul_n(vec_elements_flt, scale_beta));
 
-                vec_sum = vadd(vec_sum, vec_elements_flt);
-
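+                // Log softmax stores the beta-scaled difference in tmp and accumulates its exponential;
+                // standard softmax stores and accumulates the exponential itself.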
+                if(is_log)
+                {
+                    vec_elements_flt = vmul_n(vec_elements_flt, scale_beta);
+                    vec_sum          = vadd(vec_sum, vexp(vec_elements_flt));
+                }
+                else
+                {
+                    vec_elements_flt = vexp(vmul_n(vec_elements_flt, scale_beta));
+                    vec_sum          = vadd(vec_sum, vec_elements_flt);
+                }
                 vst4q_f32(tmp_ptr + i, vec_elements_flt);
             }
+
             /* Reduce sum */
             const auto sum_16_byte = vaddq_f32(vaddq_f32(vec_sum.val[0], vec_sum.val[1]),
                                                vaddq_f32(vec_sum.val[2], vec_sum.val[3]));
             const auto sum_8_byte = vadd_f32(vget_low(sum_16_byte), vget_high(sum_16_byte));
-            float      sum        = reduce_add(std::plus<float>(), sum_8_byte);
+            sum                   = reduce_add(std::plus<float>(), sum_8_byte);
 
             /* Run remaining elements */
             for(; i < input_width; ++i)
             {
-                const float element = std::exp((max_val - in_ptr[i]) * scale_beta);
-                sum += element;
+                float element{};
+                if(is_log)
+                {
+                    element = (max_val - in_ptr[i]) * scale_beta;
+                    sum += std::exp(element);
+                }
+                else
+                {
+                    element = std::exp((max_val - in_ptr[i]) * scale_beta);
+                    sum += element;
+                }
+
                 tmp_ptr[i] = element;
             }
 
-            sum_inversed = 256.f / sum;
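+            // The standard path precomputes 256/sum; the log path instead converts the
+            // accumulated sum to log(sum) for the subtraction in the normalization pass.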
+            if(!is_log)
+            {
+                sum_inversed = 256.f / sum;
+            }
+            else
+            {
+                sum = std::log(sum);
+            }
         }
 
         /* Normalize exponentials */
@@ -657,24 +694,40 @@
             int i = 0;
             {
                 constexpr int vec_size = 16;
+
                 for(; i <= (input_width - vec_size); i += vec_size)
                 {
-                    float32x4x4_t vec_in           = vld4q_f32(tmp_ptr + i);
-                    auto          normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
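+                    // Output: tmp - log(sum) for log softmax, tmp * (256 / sum) for standard
+                    // softmax, saturated back to qasymm8.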
+                    float32x4x4_t            vec_in = vld4q_f32(tmp_ptr + i);
+                    vec_16_byte_t<qasymm8_t> normalized_value{};
+                    if(is_log)
+                    {
+                        normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vsub_n(vec_in, sum));
+                    }
+                    else
+                    {
+                        normalized_value = vcvt<vec_16_byte_t<qasymm8_t>>(vmul_n(vec_in, sum_inversed));
+                    }
                     vst(out_ptr + i, normalized_value);
                 }
             }
             /* Run remaining elements */
             for(; i < input_width; ++i)
             {
-                out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
+                if(is_log)
+                {
+                    out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] - sum);
+                }
+                else
+                {
+                    out_ptr[i] = utils::cast::saturate_cast<qasymm8_t>(tmp_ptr[i] * sum_inversed);
+                }
             }
         }
     },
     in_it, max_it, out_it);
 }
 
-template <typename T>
+template <typename T, bool is_log = false>
 void logits_1d_softmax_float(const ITensor &in, const ITensor &max, void *const tmp,
                              ITensor &out, const float beta, const Window &window)
 {
@@ -692,7 +745,8 @@
         const auto out_ptr = reinterpret_cast<T *>(out_it.ptr()) + start_x;
         const auto tmp_ptr = reinterpret_cast<T *>(tmp);
 
-        T sum_inversed;
+        T sum{};
+        T sum_inversed{};
 
         /* Compute exponentials and sum */
         {
@@ -706,46 +760,87 @@
             /* Loop over row and compute exponentials and sum */
             int           i        = 0;
             constexpr int vec_size = vec_size_of(vec_sum);
+
             for(; i <= (input_width - vec_size); i += vec_size)
             {
                 auto vec_elements = vld<vec_16_byte_t<T>>(in_ptr + i);
                 vec_elements      = vsub(vec_elements, vec_max);
-                vec_elements      = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
-                vec_sum           = vadd(vec_sum, vec_elements);
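+                // Log softmax stores (x - max) * beta in tmp and accumulates its exponential;
+                // standard softmax stores and accumulates the exponential.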
+                if(is_log)
+                {
+                    vec_elements = vmul_n(vec_elements, static_cast<T>(beta));
+                    vec_sum      = vadd(vec_sum, vexp(vec_elements));
+                }
+                else
+                {
+                    vec_elements = vexp(vmul_n(vec_elements, static_cast<T>(beta)));
+                    vec_sum      = vadd(vec_sum, vec_elements);
+                }
                 vst(tmp_ptr + i, vec_elements);
             }
+
             /* Reduce sum */
             const auto sum_8_byte = vadd(vget_high(vec_sum), vget_low(vec_sum));
-            T sum                 = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);
+            sum                   = reduce_add([](T a, T b) -> T { return a + b; }, sum_8_byte);
 
             /* Run remaining elements */
+
             for(; i < input_width; ++i)
             {
-                T element = std::exp((in_ptr[i] - max_val) * beta);
-                sum += element;
+                T element{};
+
+                if(is_log)
+                {
+                    element = (in_ptr[i] - max_val) * beta;
+                    sum += std::exp(element);
+                }
+                else
+                {
+                    element = std::exp((in_ptr[i] - max_val) * beta);
+                    sum += element;
+                }
                 tmp_ptr[i] = element;
             }
 
-            sum_inversed = T(1) / sum;
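+            // As above: only the standard path needs the reciprocal; the log path works with log(sum).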
+            if(!is_log)
+            {
+                sum_inversed = T(1) / sum;
+            }
+            else
+            {
+                sum = static_cast<T>(std::log(sum));
+            }
         }
 
         /* Normalize exponentials */
         {
             /* Loop over row and compute softmax */
             int i = 0;
+
             {
                 constexpr int vec_size = vec_size_of(vec_16_byte_t<T> {});
+
                 for(; i <= (input_width - vec_size); i += vec_size)
                 {
-                    auto             vec_in           = vld<vec_16_byte_t<T>>(tmp_ptr + i);
-                    vec_16_byte_t<T> normalized_value = vmul_n(vec_in, sum_inversed);
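+                    // Output: tmp - log(sum) for log softmax, tmp * (1 / sum) for standard softmax.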
+                    auto             vec_in = vld<vec_16_byte_t<T>>(tmp_ptr + i);
+                    vec_16_byte_t<T> normalized_value{};
+                    if(is_log)
+                    {
+                        normalized_value = vsub(vec_in, vdup_n<vec_16_byte_t<T>>(sum));
+                    }
+                    else
+                    {
+                        normalized_value = vmul_n(vec_in, sum_inversed);
+                    }
                     vst(out_ptr + i, normalized_value);
                 }
             }
             /* Run remaining elements */
             for(; i < input_width; ++i)
             {
-                out_ptr[i] = tmp_ptr[i] * sum_inversed;
+                if(is_log)
+                {
+                    out_ptr[i] = tmp_ptr[i] - sum;
+                }
+                else
+                {
+                    out_ptr[i] = tmp_ptr[i] * sum_inversed;
+                }
             }
         }
     },
@@ -753,12 +848,14 @@
 }
 } // namespace
 
-NELogits1DSoftmaxKernel::NELogits1DSoftmaxKernel()
+template <bool IS_LOG>
+NELogits1DSoftmaxKernel<IS_LOG>::NELogits1DSoftmaxKernel()
     : _func(nullptr), _input(nullptr), _max(nullptr), _output(nullptr), _beta(1.0f), _tmp(nullptr)
 {
 }
 
-void NELogits1DSoftmaxKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
+template <bool IS_LOG>
+void NELogits1DSoftmaxKernel<IS_LOG>::configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
     ARM_COMPUTE_ERROR_ON_NULLPTR(input->info(), max->info(), output->info(), tmp->info());
@@ -771,15 +868,15 @@
     switch(input->info()->data_type())
     {
         case DataType::QASYMM8:
-            _func = &logits_1d_softmax_qasymm8;
+            _func = &logits_1d_softmax_qasymm8<IS_LOG>;
             break;
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            _func = &logits_1d_softmax_float<float16_t>;
+            _func = &logits_1d_softmax_float<float16_t, IS_LOG>;
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         case DataType::F32:
-            _func = &logits_1d_softmax_float<float>;
+            _func = &logits_1d_softmax_float<float, IS_LOG>;
             break;
         default:
             ARM_COMPUTE_ERROR("Unsupported data type.");
@@ -795,8 +892,9 @@
     INEKernel::configure(win_config.second);
 }
 
-Status NELogits1DSoftmaxKernel::validate(const ITensorInfo *input, const ITensorInfo *max,
-                                         const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
+template <bool IS_LOG>
+Status NELogits1DSoftmaxKernel<IS_LOG>::validate(const ITensorInfo *input, const ITensorInfo *max,
+                                                 const ITensorInfo *output, const float beta, const ITensorInfo *tmp)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, max, output, tmp);
 
@@ -806,7 +904,8 @@
     return Status{};
 }
 
-void NELogits1DSoftmaxKernel::run(const Window &window, const ThreadInfo &info)
+template <bool IS_LOG>
+void NELogits1DSoftmaxKernel<IS_LOG>::run(const Window &window, const ThreadInfo &info)
 {
     ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
@@ -822,4 +921,7 @@
     (*_func)(*_input, *_max, tmp_for_thread, *_output, _beta, window);
 }
 
+template class NELogits1DSoftmaxKernel<true>;
+template class NELogits1DSoftmaxKernel<false>;
+
 } // namespace arm_compute
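
For reference, a minimal scalar sketch of the per-row computation that the
vectorized kernels above implement, assuming a float row of non-zero width;
the function name and signature are illustrative only, not part of the
library API:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Illustrative scalar reference: same max shift, beta scaling and temporary
    // buffer usage as the vectorized kernels, selected by the is_log parameter.
    // Assumes width > 0.
    template <bool is_log>
    void softmax_row(const float *in, float *out, int width, float beta)
    {
        const float        max_val = *std::max_element(in, in + width);
        std::vector<float> tmp(width);
        float              sum = 0.f;
        for(int i = 0; i < width; ++i)
        {
            const float shifted = (in[i] - max_val) * beta;
            tmp[i] = is_log ? shifted : std::exp(shifted); // what the kernel stores in tmp
            sum += is_log ? std::exp(shifted) : tmp[i];    // what the kernel accumulates
        }
        const float norm = is_log ? std::log(sum) : 1.f / sum;
        for(int i = 0; i < width; ++i)
        {
            out[i] = is_log ? (tmp[i] - norm) : (tmp[i] * norm);
        }
    }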