COMPMID-1534: Fix NEActivationLayer for FP16

Simulates the Logistic, Tanh and SoftRelu activations in FP32, as
their FP16 approximations are not accurate enough (to be revisited
under COMPMID-1535). Also adds the missing LEAKY_RELU FP16 kernel
mapping, tightens the FP16 test input ranges and splits the
validation tolerance into relative and absolute components.
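
For reference, the kernel now promotes each 4-lane half of an FP16
vector to FP32, evaluates the transcendental there, and narrows the
result back to FP16. A minimal sketch of that pattern for the
Logistic case, assuming the library's NEMath helpers vexpq_f32 and
vinv_f16 used in the patch below (the include path is illustrative):

    #include <arm_neon.h>
    #include "arm_compute/core/NEON/NEMath.h" // assumed location of vexpq_f32 / vinv_f16

    // logistic(x) = 1 / (1 + exp(-x)), with exp evaluated in FP32 per 4-lane half
    inline float16x8_t logistic_f16_via_f32(float16x8_t x)
    {
        const float16x4_t one = vdup_n_f16(1.f);
        // widen low/high halves to FP32, compute exp, narrow back and finish in FP16
        const float16x4_t lo  = vinv_f16(vadd_f16(one, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_low_f16(x)))))));
        const float16x4_t hi  = vinv_f16(vadd_f16(one, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_high_f16(x)))))));
        return vcombine_f16(lo, hi);
    }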

Change-Id: I9950f7636b8ff2f3e054937e5ef414e45dfe06f5
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145357
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/NEON/kernels/NEActivationLayerKernel.cpp b/src/core/NEON/kernels/NEActivationLayerKernel.cpp
index 1dad531..2163f7b 100644
--- a/src/core/NEON/kernels/NEActivationLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEActivationLayerKernel.cpp
@@ -138,6 +138,7 @@
         { ActivationFunction::RELU, &NEActivationLayerKernel::activation<ActivationFunction::RELU, float16_t> },
         { ActivationFunction::BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::BOUNDED_RELU, float16_t> },
         { ActivationFunction::LU_BOUNDED_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LU_BOUNDED_RELU, float16_t> },
+        { ActivationFunction::LEAKY_RELU, &NEActivationLayerKernel::activation<ActivationFunction::LEAKY_RELU, float16_t> },
         { ActivationFunction::SOFT_RELU, &NEActivationLayerKernel::activation<ActivationFunction::SOFT_RELU, float16_t> },
         { ActivationFunction::SQRT, &NEActivationLayerKernel::activation<ActivationFunction::SQRT, float16_t> },
         { ActivationFunction::SQUARE, &NEActivationLayerKernel::activation<ActivationFunction::SQUARE, float16_t> },
@@ -182,11 +183,14 @@
     Iterator input(_input, window);
     Iterator output(_output, window);
 
-    static const float16x8_t CONST_0 = vdupq_n_f16(0.f);
-    static const float16x8_t CONST_1 = vdupq_n_f16(1.f);
+    static const float16x8_t CONST_0   = vdupq_n_f16(0.f);
+    static const float16x4_t CONST_1_H = vdup_n_f16(1.f);
 
-    const float16x8_t a = vdupq_n_f16(_act_info.a());
-    const float16x8_t b = vdupq_n_f16(_act_info.b());
+    static const float32x4_t CONST_1_F32 = vdupq_n_f32(1.f);
+
+    const float16x8_t a   = vdupq_n_f16(_act_info.a());
+    const float16x4_t a_h = vdup_n_f16(_act_info.a());
+    const float16x8_t b   = vdupq_n_f16(_act_info.b());
 
     execute_window_loop(window, [&](const Coordinates &)
     {
@@ -235,14 +239,29 @@
                 };
                 break;
             case ActivationFunction::LOGISTIC:
+            {
+                // TODO (COMPMID-1535) : Revisit FP16 approximations
+                const float16x4x2_t in0 =
+                {
+                    vinv_f16(vadd_f16(CONST_1_H, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_low_f16(in.val[0]))))))),
+                    vinv_f16(vadd_f16(CONST_1_H, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_high_f16(in.val[0]))))))),
+                };
+
+                const float16x4x2_t in1 =
+                {
+                    vinv_f16(vadd_f16(CONST_1_H, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_low_f16(in.val[1]))))))),
+                    vinv_f16(vadd_f16(CONST_1_H, vcvt_f16_f32(vexpq_f32(vcvt_f32_f16(vneg_f16(vget_high_f16(in.val[1]))))))),
+                };
+
                 tmp =
                 {
                     {
-                        vinvq_f16(vaddq_f16(CONST_1, vexpq_f16(vnegq_f16(in.val[0])))),
-                        vinvq_f16(vaddq_f16(CONST_1, vexpq_f16(vnegq_f16(in.val[1])))),
+                        vcombine_f16(in0.val[0], in0.val[1]),
+                        vcombine_f16(in1.val[0], in1.val[1]),
                     }
                 };
-                break;
+            }
+            break;
             case ActivationFunction::RELU:
                 tmp =
                 {
@@ -262,14 +281,29 @@
                 };
                 break;
             case ActivationFunction::SOFT_RELU:
+            {
+                // TODO (COMPMID-1535) : Revisit FP16 approximations
+                const float16x4x2_t in0 =
+                {
+                    vcvt_f16_f32(vlogq_f32(vaddq_f32(CONST_1_F32, vexpq_f32(vcvt_f32_f16(vget_low_f16(in.val[0])))))),
+                    vcvt_f16_f32(vlogq_f32(vaddq_f32(CONST_1_F32, vexpq_f32(vcvt_f32_f16(vget_high_f16(in.val[0])))))),
+                };
+
+                const float16x4x2_t in1 =
+                {
+                    vcvt_f16_f32(vlogq_f32(vaddq_f32(CONST_1_F32, vexpq_f32(vcvt_f32_f16(vget_low_f16(in.val[1])))))),
+                    vcvt_f16_f32(vlogq_f32(vaddq_f32(CONST_1_F32, vexpq_f32(vcvt_f32_f16(vget_high_f16(in.val[1])))))),
+                };
+
                 tmp =
                 {
                     {
-                        vlogq_f16(vaddq_f16(CONST_1, vexpq_f16(in.val[0]))),
-                        vlogq_f16(vaddq_f16(CONST_1, vexpq_f16(in.val[1]))),
+                        vcombine_f16(in0.val[0], in0.val[1]),
+                        vcombine_f16(in1.val[0], in1.val[1]),
                     }
                 };
-                break;
+            }
+            break;
             case ActivationFunction::SQRT:
                 tmp =
                 {
@@ -289,14 +323,34 @@
                 };
                 break;
             case ActivationFunction::TANH:
+            {
+                // TODO (COMPMID-1535) : Revisit FP16 approximations
+                const float16x8x2_t mul =
+                {
+                    vmulq_f16(b, in.val[0]),
+                    vmulq_f16(b, in.val[1])
+                };
+                const float16x4x2_t in0 =
+                {
+                    vmul_f16(a_h, vcvt_f16_f32(vtanhq_f32(vcvt_f32_f16(vget_low_f16(mul.val[0]))))),
+                    vmul_f16(a_h, vcvt_f16_f32(vtanhq_f32(vcvt_f32_f16(vget_high_f16(mul.val[0]))))),
+                };
+
+                const float16x4x2_t in1 =
+                {
+                    vmul_f16(a_h, vcvt_f16_f32(vtanhq_f32(vcvt_f32_f16(vget_low_f16(mul.val[1]))))),
+                    vmul_f16(a_h, vcvt_f16_f32(vtanhq_f32(vcvt_f32_f16(vget_high_f16(mul.val[1]))))),
+                };
+
                 tmp =
                 {
                     {
-                        vmulq_f16(a, vtanhq_f16(vmulq_f16(b, in.val[0]))),
-                        vmulq_f16(a, vtanhq_f16(vmulq_f16(b, in.val[1]))),
+                        vcombine_f16(in0.val[0], in0.val[1]),
+                        vcombine_f16(in1.val[0], in1.val[1]),
                     }
                 };
-                break;
+            }
+            break;
             default:
                 ARM_COMPUTE_ERROR("Not implemented");
                 break;
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index 814d1f5..e5ba148 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -70,15 +70,16 @@
 
             switch(activation)
             {
+                case ActivationLayerInfo::ActivationFunction::TANH:
                 case ActivationLayerInfo::ActivationFunction::SQUARE:
                 case ActivationLayerInfo::ActivationFunction::LOGISTIC:
                 case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
                     // Reduce range as exponent overflows
-                    bounds = std::make_pair(-10._h, 10._h);
+                    bounds = std::make_pair(-2._h, 2._h);
                     break;
                 case ActivationLayerInfo::ActivationFunction::SQRT:
                     // Reduce range as sqrt should take a non-negative number
-                    bounds = std::make_pair(0._h, 255._h);
+                    bounds = std::make_pair(0._h, 128._h);
                     break;
                 default:
                     bounds = std::make_pair(-255._h, 255._h);
diff --git a/tests/validation/NEON/ActivationLayer.cpp b/tests/validation/NEON/ActivationLayer.cpp
index dee264c..5b16b06 100644
--- a/tests/validation/NEON/ActivationLayer.cpp
+++ b/tests/validation/NEON/ActivationLayer.cpp
@@ -43,14 +43,42 @@
 {
 namespace
 {
-/** Define tolerance of the activation layer.
+/** Define relative tolerance of the activation layer.
  *
  * @param[in] data_type  The data type used.
  * @param[in] activation The activation function used.
  *
- * @return Tolerance depending on the activation function.
+ * @return Relative tolerance depending on the activation function.
  */
-AbsoluteTolerance<float> tolerance(DataType data_type, ActivationLayerInfo::ActivationFunction activation)
+RelativeTolerance<float> relative_tolerance(DataType data_type, ActivationLayerInfo::ActivationFunction activation)
+{
+    switch(activation)
+    {
+        case ActivationLayerInfo::ActivationFunction::LOGISTIC:
+        case ActivationLayerInfo::ActivationFunction::SOFT_RELU:
+        case ActivationLayerInfo::ActivationFunction::SQRT:
+        case ActivationLayerInfo::ActivationFunction::TANH:
+            switch(data_type)
+            {
+                case DataType::F16:
+                    return RelativeTolerance<float>(0.1f);
+                default:
+                    return RelativeTolerance<float>(0.05f);
+            }
+            break;
+        default:
+            return RelativeTolerance<float>(0.f);
+    }
+}
+
+/** Define absolute tolerance of the activation layer.
+ *
+ * @param[in] data_type  The data type used.
+ * @param[in] activation The activation function used.
+ *
+ * @return Absolute tolerance depending on the activation function.
+ */
+AbsoluteTolerance<float> absolute_tolerance(DataType data_type, ActivationLayerInfo::ActivationFunction activation)
 {
     switch(activation)
     {
@@ -163,14 +191,14 @@
                                                                                                                     DataType::F16)))
 {
     // Validate output
-    validate(Accessor(_target), _reference, tolerance(_data_type, _function));
+    validate(Accessor(_target), _reference, relative_tolerance(_data_type, _function), 0.f, absolute_tolerance(_data_type, _function));
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NEActivationLayerFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ActivationDataset),
                                                                                                           framework::dataset::make("DataType",
                                                                                                                   DataType::F16)))
 {
     // Validate output
-    validate(Accessor(_target), _reference, tolerance(_data_type, _function));
+    validate(Accessor(_target), _reference, relative_tolerance(_data_type, _function), 0.f, absolute_tolerance(_data_type, _function));
 }
 TEST_SUITE_END()
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
@@ -181,13 +209,13 @@
 
 {
     // Validate output
-    validate(Accessor(_target), _reference, tolerance(_data_type, _function));
+    validate(Accessor(_target), _reference, relative_tolerance(_data_type, _function), 0.f, absolute_tolerance(_data_type, _function));
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NEActivationLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ActivationDataset),
                                                                                                            framework::dataset::make("DataType", DataType::F32)))
 {
     // Validate output
-    validate(Accessor(_target), _reference, tolerance(_data_type, _function));
+    validate(Accessor(_target), _reference, relative_tolerance(_data_type, _function), 0.f, absolute_tolerance(_data_type, _function));
 }
 TEST_SUITE_END()
 TEST_SUITE_END()
@@ -211,7 +239,7 @@
                                                                                                                         framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
 {
     // Validate output
-    validate(Accessor(_target), _reference, tolerance(_data_type, _function));
+    validate(Accessor(_target), _reference, relative_tolerance(_data_type, _function), 0.f, absolute_tolerance(_data_type, _function));
 }
 FIXTURE_DATA_TEST_CASE(RunLarge, NEActivationLayerQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), QuantizedActivationDataset),
                                                                                                                       framework::dataset::make("DataType",
@@ -219,7 +247,7 @@
                                                                                                                       framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.1f, 128.0f) })))
 {
     // Validate output
-    validate(Accessor(_target), _reference, tolerance(_data_type, _function));
+    validate(Accessor(_target), _reference, relative_tolerance(_data_type, _function), 0.f, absolute_tolerance(_data_type, _function));
 }
 TEST_SUITE_END()
 TEST_SUITE_END()