Adding GELU activation

The OpenCL implementation uses the built-in erf function.

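The Neon implementation requires a new vectorized erf, which uses the
following approximation for x >= 0 (maximum absolute error 5e-4):
erf(x) = 1 - 1 / (1 + a1*x + a2*x^2 + a3*x^3 + a4*x^4)^4
a1 = 0.278393, a2 = 0.230389, a3 = 0.000972, a4 = 0.078108
Negative inputs use the odd symmetry erf(-x) = -erf(x).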

From https://en.wikipedia.org/wiki/Error_function#Numerical_approximations

Signed-off-by: Murray Kornelsen <murray.kornelsen@mail.mcgill.ca>
Change-Id: I2d3964b2c26a4334166b17135f9104bc6324fad2
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7921
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Pablo Marquez Tello <pablo.tello@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/activation_float_helpers.h b/src/core/CL/cl_kernels/activation_float_helpers.h
index 91d7197..3f93c8d 100644
--- a/src/core/CL/cl_kernels/activation_float_helpers.h
+++ b/src/core/CL/cl_kernels/activation_float_helpers.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2020, 2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -69,6 +69,9 @@
 // Linear Activation
 #define linear_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (MLA((DATA_TYPE)B_VAL, (DATA_TYPE)A_VAL, x))
 
+// GELU Activation: x * 0.5 * (1 + erf(x / sqrt(2))); x is parenthesized so expression arguments expand correctly
+#define gelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((x) * (DATA_TYPE)0.5 * ((DATA_TYPE)1.0 + erf((x) / (DATA_TYPE)1.41421356237)))
+
 // Identity Activation
 #define identity_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x)
 
diff --git a/src/core/NEON/NEMath.h b/src/core/NEON/NEMath.h
index 8118c47..9e81c38 100644
--- a/src/core/NEON/NEMath.h
+++ b/src/core/NEON/NEMath.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -94,6 +94,14 @@
  */
 float32x4_t vexpq_f32(float32x4_t x);
 
+/** Calculate an approximation of the error function, erf(x), for each lane
+ *
+ * @param[in] x Input vector in F32 format.
+ *
+ * @return The approximated erf of each lane (polynomial approximation, not exact).
+ */
+float32x4_t verfq_f32(float32x4_t x);
+
 /** Calculate logarithm
  *
  * @param[in] x Input vector value in F32 format.
@@ -308,6 +316,14 @@
  */
 float16x8_t vexpq_f16(float16x8_t x);
 
+/** Calculate an approximation of the error function, erf(x), for each lane
+ *
+ * @param[in] x Input vector in F16 format.
+ *
+ * @return The approximated erf of each lane, evaluated in F32 and converted back to F16.
+ */
+float16x8_t verfq_f16(float16x8_t x);
+
 /** Calculate n power of a number.
  *
  * pow(x,n) = e^(n*log(x))
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 05cf301..1b0b894 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -166,6 +166,43 @@
     return poly;
 }
 
+// erf(x) ~= 1 - 1 / (1 + a1*|x| + a2*x^2 + a3*|x|^3 + a4*x^4)^4, with erf(-x) = -erf(x) for negative lanes.
+// Defined for all targets: the AArch64-only intrinsics below have Armv7-compatible fallbacks.
+inline float32x4_t verfq_f32(float32x4_t x)
+{
+    static const float       erffdata[4] = { 0.278393f, 0.230389f, 0.000972f, 0.078108f };
+    static const float32x4_t coeffdata   = vld1q_f32(erffdata);
+    static const float32x4_t onev{ vdupq_n_f32(1.0f) };
+
+    // All-ones in lanes whose input is negative (-0.0f and NaN compare false).
+    const uint32x4_t  selector = vcltq_f32(x, vdupq_n_f32(0.0f));
+    const float32x4_t absx     = vabsq_f32(x);
+    const float32x4_t absx2    = vmulq_f32(x, x);
+    const float32x4_t absx3    = vmulq_f32(absx2, absx);
+    const float32x4_t absx4    = vmulq_f32(absx2, absx2);
+
+#ifdef __aarch64__
+    float32x4_t denom = vfmaq_laneq_f32(onev, absx, coeffdata, 0);
+    denom             = vfmaq_laneq_f32(denom, absx2, coeffdata, 1);
+    denom             = vfmaq_laneq_f32(denom, absx3, coeffdata, 2);
+    denom             = vfmaq_laneq_f32(denom, absx4, coeffdata, 3);
+#else  // __aarch64__
+    float32x4_t denom = vmlaq_lane_f32(onev, absx, vget_low_f32(coeffdata), 0);
+    denom             = vmlaq_lane_f32(denom, absx2, vget_low_f32(coeffdata), 1);
+    denom             = vmlaq_lane_f32(denom, absx3, vget_high_f32(coeffdata), 0);
+    denom             = vmlaq_lane_f32(denom, absx4, vget_high_f32(coeffdata), 1);
+#endif // __aarch64__
+    denom = vmulq_f32(denom, denom); // (...)^2
+    denom = vmulq_f32(denom, denom); // (...)^4
+
+#ifdef __aarch64__
+    const float32x4_t result = vsubq_f32(onev, vdivq_f32(onev, denom));
+#else  // __aarch64__
+    const float32x4_t result = vsubq_f32(onev, vinvq_f32(denom));
+#endif // __aarch64__
+    return vbslq_f32(selector, vnegq_f32(result), result);
+}
+
 inline float32x4_t vlogq_f32(float32x4_t x)
 {
     static const int32x4_t   CONST_127 = vdupq_n_s32(127);           // 127
@@ -517,6 +554,17 @@
     return res;
 }
 
+#ifdef __aarch64__
+inline float16x8_t verfq_f16(float16x8_t x)
+{
+    // Evaluate the F32 approximation on each widened half, then narrow back to F16.
+    const float16x4_t erf_lo = vcvt_f16_f32(verfq_f32(vcvt_f32_f16(vget_low_f16(x))));
+    const float16x4_t erf_hi = vcvt_f16_f32(verfq_f32(vcvt_f32_f16(vget_high_f16(x))));
+
+    return vcombine_f16(erf_lo, erf_hi);
+}
+#endif // #ifdef __aarch64__
+
 inline float16x8_t vlogq_f16(float16x8_t x)
 {
     const float32x4_t x_high = vcvt_f32_f16(vget_high_f16(x));
diff --git a/src/core/NEON/wrapper/intrinsics/erf.h b/src/core/NEON/wrapper/intrinsics/erf.h
new file mode 100644
index 0000000..e220764
--- /dev/null
+++ b/src/core/NEON/wrapper/intrinsics/erf.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ARM_COMPUTE_WRAPPER_ERF_H
+#define ARM_COMPUTE_WRAPPER_ERF_H
+
+#include "src/core/NEON/NEMath.h"
+#include <arm_neon.h>
+
+namespace arm_compute
+{
+namespace wrapper
+{
+#define VERF_IMPL(vtype, prefix, postfix) \
+    inline vtype verf(const vtype &a)     \
+    {                                     \
+        return prefix##_##postfix(a);     \
+    }
+// Overloads of wrapper::verf that forward to the NEMath verfq_f32 / verfq_f16 functions.
+VERF_IMPL(float32x4_t, verfq, f32)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+VERF_IMPL(float16x8_t, verfq, f16)
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+#undef VERF_IMPL
+
+} // namespace wrapper
+} // namespace arm_compute
+
+#endif /* ARM_COMPUTE_WRAPPER_ERF_H */
diff --git a/src/core/NEON/wrapper/intrinsics/intrinsics.h b/src/core/NEON/wrapper/intrinsics/intrinsics.h
index 871d9cc..0256e0a 100644
--- a/src/core/NEON/wrapper/intrinsics/intrinsics.h
+++ b/src/core/NEON/wrapper/intrinsics/intrinsics.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -39,6 +39,7 @@
 #include "src/core/NEON/wrapper/intrinsics/div.h"
 #include "src/core/NEON/wrapper/intrinsics/dup_n.h"
 #include "src/core/NEON/wrapper/intrinsics/eor.h"
+#include "src/core/NEON/wrapper/intrinsics/erf.h"
 #include "src/core/NEON/wrapper/intrinsics/exp.h"
 #include "src/core/NEON/wrapper/intrinsics/ext.h"
 #include "src/core/NEON/wrapper/intrinsics/gethigh.h"
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index 904362e..48eb8b9 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -177,7 +177,8 @@
         { ActivationLayerInfo::ActivationFunction::SQUARE, "SQUARE" },
         { ActivationLayerInfo::ActivationFunction::TANH, "TANH" },
         { ActivationLayerInfo::ActivationFunction::IDENTITY, "IDENTITY" },
-        { ActivationLayerInfo::ActivationFunction::HARD_SWISH, "HARD_SWISH" }
+        { ActivationLayerInfo::ActivationFunction::HARD_SWISH, "HARD_SWISH" },
+        { ActivationLayerInfo::ActivationFunction::GELU, "GELU" }
 
     };
 
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp
index ee9db99..61efcb2 100644
--- a/src/cpu/kernels/CpuActivationKernel.cpp
+++ b/src/cpu/kernels/CpuActivationKernel.cpp
@@ -46,7 +46,8 @@
 static const std::vector<CpuActivationKernel::ActivationKernel> available_kernels =
 {
 #ifdef __aarch64__
-    { // Neon LUT implementantion takes precedence
+    {
+        // Neon LUT implementation takes precedence
         "neon_q8_activation_lut",
         [](const ActivationDataTypeISASelectorData & data) { return ActivationLayerInfo::is_lut_supported(data.f, data.dt); },
         REGISTER_Q8_NEON(arm_compute::cpu::neon_q8_activation_lut)
@@ -54,27 +55,27 @@
 #endif // __aarch64__
     {
         "sve2_qu8_activation",
-        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve2; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
         REGISTER_QASYMM8_SVE2(arm_compute::cpu::sve2_qasymm8_activation)
     },
     {
         "sve2_qs8_activation",
-        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
         REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::sve2_qasymm8_signed_activation)
     },
     {
         "sve2_qs16_activation",
-        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16 && data.isa.sve2; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16 && data.isa.sve2 && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
         REGISTER_QSYMM16_SVE2(arm_compute::cpu::sve2_qsymm16_activation)
     },
     {
         "sve_fp16_activation",
-        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16 && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
         REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_activation)
     },
     {
         "sve_fp32_activation",
-        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve && data.f != ActivationLayerInfo::ActivationFunction::GELU; },
         REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_activation)
     },
     {
@@ -105,7 +106,7 @@
 };
 
 /* Supported activation in the 8-bit integer domain */
-static const std::array<ActivationLayerInfo::ActivationFunction, 7> qasymm8_activations =
+static const std::array<ActivationLayerInfo::ActivationFunction, 8> qasymm8_activations =
 {
     ActivationLayerInfo::ActivationFunction::RELU,
     ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU,
@@ -114,6 +115,7 @@
     ActivationLayerInfo::ActivationFunction::TANH,
     ActivationLayerInfo::ActivationFunction::HARD_SWISH,
     ActivationLayerInfo::ActivationFunction::LEAKY_RELU,
+    ActivationLayerInfo::ActivationFunction::GELU,
 };
 /* Supported activation in the 16-bit integer domain */
 static const std::array<ActivationLayerInfo::ActivationFunction, 4> qsymm16_activations =
@@ -193,7 +195,7 @@
 #ifdef __aarch64__
     if(ActivationLayerInfo::is_lut_supported(activation_info.activation(), src->data_type()))
     {
-        activation_info.init_lut(src->data_type(), src->quantization_info().uniform(), (dst)?dst->quantization_info().uniform():src->quantization_info().uniform());
+        activation_info.init_lut(src->data_type(), src->quantization_info().uniform(), (dst) ? dst->quantization_info().uniform() : src->quantization_info().uniform());
     }
 #endif // __aarch64__
     _act_info = activation_info;
diff --git a/src/cpu/kernels/activation/generic/neon/impl.h b/src/cpu/kernels/activation/generic/neon/impl.h
index 2dd239e..35abcb5 100644
--- a/src/cpu/kernels/activation/generic/neon/impl.h
+++ b/src/cpu/kernels/activation/generic/neon/impl.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -77,7 +77,9 @@
     const auto      const_0           = wrapper::vdup_n(static_cast<T>(0.f), ExactTagType{});
     const auto      const_6           = wrapper::vdup_n(static_cast<T>(6.f), ExactTagType{});
     const auto      const_3           = wrapper::vdup_n(static_cast<T>(3.f), ExactTagType{});
+    const auto      const_inv_2       = wrapper::vdup_n(static_cast<T>(0.5f), ExactTagType{});
     const auto      const_inv_6       = wrapper::vdup_n(static_cast<T>(0.166666667f), ExactTagType{});
+    const auto      const_inv_sqrt_2  = wrapper::vdup_n(static_cast<T>(0.70710678118f), ExactTagType{});
     constexpr float soft_relu_thresh  = 12.f;
     const auto      vsoft_relu_thresh = wrapper::vdup_n(static_cast<T>(soft_relu_thresh), ExactTagType{});
     const auto      va                = wrapper::vdup_n(static_cast<T>(act_info.a()), ExactTagType{});
@@ -146,6 +148,9 @@
                 case ActivationLayerInfo::ActivationFunction::HARD_SWISH:
                     tmp = wrapper::vmul(vin, wrapper::vmul(const_inv_6, wrapper::vmin(const_6, wrapper::vmax(const_0, wrapper::vadd(vin, const_3)))));
                     break;
+                case ActivationLayerInfo::ActivationFunction::GELU:
+                    tmp = wrapper::vmul(vin, wrapper::vmul(const_inv_2, wrapper::vadd(const_1, wrapper::verf(wrapper::vmul(vin, const_inv_sqrt_2)))));
+                    break;
                 default:
                     ARM_COMPUTE_ERROR("Unsupported activation function");
             }
@@ -200,6 +205,9 @@
                 case ActivationLayerInfo::ActivationFunction::HARD_SWISH:
                     tmp = in * ((std::min(std::max((in + 3), 0.0f), 6.0f)) * 0.166666667f);
                     break;
+                case ActivationLayerInfo::ActivationFunction::GELU:
+                    tmp = in * static_cast<T>(0.5f * (1.0f + erff(static_cast<float>(in) / 1.41421356237f)));
+                    break;
                 default:
                     ARM_COMPUTE_ERROR("Unsupported activation function");
             }
diff --git a/src/cpu/kernels/activation/generic/neon/qasymm8.cpp b/src/cpu/kernels/activation/generic/neon/qasymm8.cpp
index 67d9e0a..05a0b50 100644
--- a/src/cpu/kernels/activation/generic/neon/qasymm8.cpp
+++ b/src/cpu/kernels/activation/generic/neon/qasymm8.cpp
@@ -58,9 +58,13 @@
     const qasymm8_t               b        = quantize_qasymm8(act_info.b(), qi_in);
     const qasymm8_t               const_0  = quantize_qasymm8(0.f, qi_in);
     const qasymm8x16_t            vconst_0 = vdupq_n_u8(const_0);
+    const auto                    vconst_1 = vdupq_n_f32(1.f);
+
 #ifndef __aarch64__
-    const auto vconst_1     = vdupq_n_f32(1.f);
     const auto vconst_0_f32 = vdupq_n_f32(0);
+#else  // #ifndef __aarch64__
+    const auto const_inv_2      = vdupq_n_f32(0.5f);
+    const auto const_inv_sqrt_2 = vdupq_n_f32(0.70710678118f);
 #endif // __aarch64__
     const float32x4_t va_f32 = vdupq_n_f32(act_info.a());
     const float32x4_t vb_f32 = vdupq_n_f32(act_info.b());
@@ -193,6 +197,23 @@
 
                 tmp = vquantize(tmp_dep, qi_out);
             }
+#else  // #ifndef __aarch64__
+            else if (act == ActivationLayerInfo::ActivationFunction::GELU)
+            {
+                const auto vin_deq = vdequantize(vin, qi_in);
+                // Perform activation
+                const float32x4x4_t tmp_dep =
+                {
+                    {
+                        wrapper::vmul(vin_deq.val[0], wrapper::vmul(const_inv_2, wrapper::vadd(vconst_1, wrapper::verf(wrapper::vmul(vin_deq.val[0], const_inv_sqrt_2))))),
+                        wrapper::vmul(vin_deq.val[1], wrapper::vmul(const_inv_2, wrapper::vadd(vconst_1, wrapper::verf(wrapper::vmul(vin_deq.val[1], const_inv_sqrt_2))))),
+                        wrapper::vmul(vin_deq.val[2], wrapper::vmul(const_inv_2, wrapper::vadd(vconst_1, wrapper::verf(wrapper::vmul(vin_deq.val[2], const_inv_sqrt_2))))),
+                        wrapper::vmul(vin_deq.val[3], wrapper::vmul(const_inv_2, wrapper::vadd(vconst_1, wrapper::verf(wrapper::vmul(vin_deq.val[3], const_inv_sqrt_2))))),
+                    }
+                };
+                // Re-quantize to new output space
+                tmp = vquantize(tmp_dep, qi_out);
+            }
 #endif // __aarch64__
             else
             {
@@ -248,6 +269,12 @@
                 tmp_f       = tmp_f > 0 ? tmp_f : tmp_f * a_f32;
                 tmp         = quantize_qasymm8(tmp_f, qi_out);
             }
+            else if(act == ActivationLayerInfo::ActivationFunction::GELU)
+            {
+                float tmp_f = dequantize_qasymm8(in, qi_in);
+                tmp         = tmp_f * 0.5f * (1.0f + std::erff(in / 1.41421356237f));
+                tmp         = quantize_qasymm8(tmp_f, qi_out);
+            }
 #endif // __aarch64__
             else
             {