COMPMID-421: Added FP16 support to Softmax.

Change-Id: If48178689e7cdadf1858556438c7292128be5b92
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/80436
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
diff --git a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
index fe62d7b..79fcba1 100644
--- a/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
+++ b/src/core/NEON/kernels/NESoftmaxLayerKernel.cpp
@@ -106,6 +106,41 @@
     }
     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
 }
+
+#ifdef ARM_COMPUTE_ENABLE_FP16
+void logits_1d_max_f16(const ITensor *in, ITensor *out, const Window &window) // FP16 path: stores the maximum of each 1D slice of `in` into `out`
+{
+    Window in_slice = window.first_slice_window_1D();

+    Window window_max(window);
+    window_max.set(Window::DimX, Window::Dimension(0, 0, 0)); // x range forced to (0, 0, 0) for the output window; presumably one max per input slice - confirm
+    Window max_slice = window_max.first_slice_window_1D();

+    do
+    {
+        Iterator input(in, in_slice);
+        Iterator output(out, max_slice);

+        float16x8_t vec_max = vdupq_n_f16(static_cast<float16_t>(-65504.0f)); // lowest finite binary16 value; std::numeric_limits<float16_t>::lowest() falls back to T() == 0 when no specialisation exists, which breaks all-negative rows

+        execute_window_loop(in_slice, [&](const Coordinates & id)
+        {
+            const auto        in_ptr        = reinterpret_cast<const float16_t *>(input.ptr());
+            const float16x8_t current_value = vld1q_f16(in_ptr); // 8 half-precision values per step
+            vec_max                         = vmaxq_f16(vec_max, current_value); // lane-wise running maximum
+        },
+        input);

+        float16x4_t carry_max = vpmax_f16(vget_high_f16(vec_max), vget_low_f16(vec_max)); // pairwise max: 8 lanes -> 4 candidates
+        carry_max             = vpmax_f16(carry_max, carry_max); // 4 -> 2 distinct candidates
+        carry_max             = vpmax_f16(carry_max, carry_max); // 2 -> 1, every lane now holds the slice maximum

+        *(reinterpret_cast<float16_t *>(output.ptr())) = vget_lane_f16(carry_max, 0);
+    }
+    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice)); // advance input and output slices in lock-step
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 void logits_1d_max_f32(const ITensor *in, ITensor *out, const Window &window)
 {
     Window in_slice = window.first_slice_window_1D();
@@ -150,7 +185,7 @@
 
 void NELogits1DMaxKernel::configure(const ITensor *input, ITensor *output)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(output);
 
     // Softmax across the x dimension
@@ -178,6 +213,11 @@
         case DataType::F32:
             _func = &logits_1d_max_f32;
             break;
+        case DataType::F16:
+#ifdef ARM_COMPUTE_ENABLE_FP16
+            _func = &logits_1d_max_f16;
+            break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
         default:
             ARM_COMPUTE_ERROR("Unsupported data type.");
     }
@@ -333,6 +373,69 @@
     }
     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice));
 }
+
+#ifdef ARM_COMPUTE_ENABLE_FP16
+void logits_1d_shift_exp_sum_f16(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window) // FP16 path: per 1D slice writes exp(in - max) to `out` and the total of those exponentials to `sum`
+{
+    Window window_max(window);
+    window_max.set(Window::DimX, Window::Dimension(0, 0, 0)); // x range forced to (0, 0, 0) for the max/sum windows; presumably one value per input slice - confirm
+
+    Window max_slice = window_max.first_slice_window_1D();
+    Window in_slice  = window.first_slice_window_1D();

+    constexpr int step        = 8; // elements handled per vector iteration (one float16x8_t)
+    const int     long_steps  = in->info()->valid_region().shape.x() / step; // number of full 8-element vectors in the valid x extent
+    const int     small_steps = in->info()->valid_region().shape.x() % step; // leftover elements processed one at a time below

+    do
+    {
+        Iterator input(in, in_slice);
+        Iterator exp(out, in_slice);
+        Iterator _max(max, max_slice);
+        Iterator _sum(sum, max_slice); // `sum` is still the tensor parameter here; it is shadowed further down

+        // Get pointers
+        auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
+        auto exp_ptr = reinterpret_cast<float16_t *>(exp.ptr());

+        // Init sum to zero
+        float16x8_t vec_sum_value = vdupq_n_f16(0);

+        // Get max value
+        const auto        max_ptr = reinterpret_cast<const float16_t *>(_max.ptr());
+        const float16x8_t vec_max = vdupq_n_f16(*max_ptr); // broadcast the slice maximum to all 8 lanes

+        // Run neon loop
+        for(int i = 0; i < long_steps; ++i)
+        {
+            float16x8_t vec_elements = vld1q_f16(in_ptr);
+            vec_elements             = vsubq_f16(vec_elements, vec_max); // shift by the slice maximum
+            vec_elements             = vexpq_f16(vec_elements); // vexpq_f16 is defined outside this view; presumably a lane-wise exponential - confirm

+            vst1q_f16(exp_ptr, vec_elements);
+            vec_sum_value = vaddq_f16(vec_sum_value, vec_elements); // lane-wise running sum of the stored values

+            in_ptr += step;
+            exp_ptr += step;
+        }
+        // Reduce sum
+        const float16x4_t sum_red        = vadd_f16(vget_low_f16(vec_sum_value), vget_high_f16(vec_sum_value)); // 8 lanes -> 4 partial sums
+        const float16x4_t carry_addition = vpadd_f16(sum_red, sum_red); // pairwise add: lanes 0 and 1 hold the two remaining partial sums
+        float16_t         sum            = vget_lane_f16(carry_addition, 0) + vget_lane_f16(carry_addition, 1); // NOTE(review): this local shadows the `sum` tensor parameter for the rest of the loop body

+        // Run remaining elements
+        for(int i = 0; i < small_steps; ++i)
+        {
+            const float16_t element = std::exp(static_cast<float>(in_ptr[i] - *max_ptr)); // scalar tail: exponent evaluated in float, then narrowed to half
+            exp_ptr[i]              = element;
+            sum += element;
+        }
+        *(reinterpret_cast<float16_t *>(_sum.ptr())) = sum; // stored as float16_t; any reader of this buffer must use the same type
+    }
+    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(max_slice)); // advance input and max/sum slices in lock-step
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 void logits_1d_shift_exp_sum_f32(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window)
 {
     Window window_max(window);
@@ -403,7 +506,7 @@
 
 void NELogits1DShiftExpSumKernel::configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(max, sum, output);
 
     // Output auto initialization if not yet initialized
@@ -428,8 +531,14 @@
         case DataType::F32:
             _func = &logits_1d_shift_exp_sum_f32;
             break;
+        case DataType::F16:
+#ifdef ARM_COMPUTE_ENABLE_FP16
+            _func = &logits_1d_shift_exp_sum_f16;
+            break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
         default:
             ARM_COMPUTE_ERROR("Unsupported data type.");
+            break;
     }
 
     _input  = input;
@@ -527,6 +636,39 @@
     }
     while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice));
 }
+#ifdef ARM_COMPUTE_ENABLE_FP16
+void logits_1d_norm_f16(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window) // FP16 path: multiplies every element of each 1D slice of `in` by 1 / sum of that slice and writes it to `out`
+{
+    Window window_sum(window);
+    window_sum.set(Window::DimX, Window::Dimension(0, 0, 0)); // x range forced to (0, 0, 0) for the sum window; presumably one sum per input slice - confirm
+    Window sum_slice = window_sum.first_slice_window_1D();
+    Window in_slice  = window.first_slice_window_1D();

+    do
+    {
+        Iterator input(in, in_slice);
+        Iterator _sum(sum, sum_slice);
+        Iterator output(out, in_slice);

+        const float16_t   sum_value        = *reinterpret_cast<const float16_t *>(_sum.ptr()); // read as float16_t, the type the FP16 sum kernel stores; a qint16_t read would reinterpret the half bit pattern as an integer
+        const float16x8_t vec_sum_inversed = vdupq_n_f16(1.0f / sum_value); // reciprocal computed once per slice and broadcast to all 8 lanes

+        execute_window_loop(in_slice, [&](const Coordinates & id)
+        {
+            const auto in_ptr  = reinterpret_cast<const float16_t *>(input.ptr());
+            const auto out_ptr = reinterpret_cast<float16_t *>(output.ptr());

+            const float16x8_t vec_in           = vld1q_f16(in_ptr); // 8 half-precision values per step
+            const float16x8_t normalized_value = vmulq_f16(vec_in, vec_sum_inversed); // multiply by the reciprocal instead of dividing per element

+            vst1q_f16(out_ptr, normalized_value);
+        },
+        input, output);
+    }
+    while(window.slide_window_slice_1D(in_slice) && window.slide_window_slice_1D(sum_slice)); // advance input and sum slices in lock-step
+}
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
+
 void logits_1d_norm_f32(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window)
 {
     Window window_sum(window);
@@ -566,7 +708,7 @@
 
 void NELogits1DNormKernel::configure(const ITensor *input, const ITensor *sum, ITensor *output)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_ERROR_ON_NULLPTR(sum, output);
 
     // Output auto initialization if not yet initialized
@@ -594,8 +736,14 @@
         case DataType::F32:
             _func = &logits_1d_norm_f32;
             break;
+        case DataType::F16:
+#ifdef ARM_COMPUTE_ENABLE_FP16
+            _func = &logits_1d_norm_f16;
+            break;
+#endif /* ARM_COMPUTE_ENABLE_FP16 */
         default:
             ARM_COMPUTE_ERROR("Unsupported data type.");
+            break;
     }
 
     Window win = calculate_max_window(*input->info(), Steps(num_elems_processed_per_iteration));