Reimplement erf function

* The current implementation has significant inaccuracy
  and the issue cascades to GELU.
* Use the implementation from Arm® Optimized Routines.
  The maximum error is 1.93 ULP.

Resolves: COMPMID-6554
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: If80131e164b7a078e34dd8e05b1506698f31d17a
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10395
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: TeresaARM <teresa.charlinreyes@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index f875917..a5aba0b 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -21,6 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
+#include "src/core/utils/Math.h"
 #include "support/ToolchainSupport.h"
 
 #include <cmath>
@@ -224,35 +226,76 @@
 #ifdef __aarch64__
 inline float32x4_t verfq_f32(float32x4_t x)
 {
-    static const float       erffdata[4] = {0.278393f, 0.230389f, 0.000972f, 0.078108f};
-    static const float32x4_t coeffdata   = vld1q_f32(erffdata);
-    static const float32x4_t onev{vdupq_n_f32(1.0f)};
+    const float32x4_t max_value = vdupq_n_f32(3.9375f);   // 4 - 8/128
+    const float32x4_t shift     = vdupq_n_f32(65536.f);   // 2^16
+    const float32x4_t third     = vdupq_n_f32(1.f / 3.f); // 1/3
+    const float32x4_t one       = vdupq_n_f32(1.f);
+    const uint32x4_t  max_index = vdupq_n_u32(512);
+    const uint32x4_t  abs_mask  = vdupq_n_u32(0x7fffffff); // selects all bits except the sign bit
 
-    uint32x4_t selector = vcltzq_f32(x);
+    const float32x4_t x_abs = vabsq_f32(x);
 
-    float32x4_t absx  = vabsq_f32(x);
-    float32x4_t absx2 = vmulq_f32(x, x);
-    float32x4_t absx3 = vmulq_f32(absx2, absx);
-    float32x4_t absx4 = vmulq_f32(absx2, absx2);
+    // For a = |x| in [0, 3.9375], erf(a) is approximated as follows:
+    //
+    //   erf(a) = erf(r) + scale(r) * d * (1 - r * d - 1/3 * d^2)
+    //
+    // where:
+    //   r = a rounded to the nearest multiple of 1/128, i.e. round(a * 128) / 128
+    //   d = a - r
+    //
+    // erf(r) and scale(r) are stored in a 513-entry lookup table.
+    // The LUT covers the range from 0 to 4 with the step of 1/128.
+    //
+    // Special cases:
+    //   erf(x) =  1 for x >  3.9375
+    //   erf(x) = -1 for x < -3.9375
+    //
+    // erf is an odd function, so the result for negative x is obtained by copying
+    // the sign of x onto erf(|x|).
 
-    float32x4_t denom = onev;
-    denom             = vfmaq_laneq_f32(denom, absx, coeffdata, 0);
-    denom             = vfmaq_laneq_f32(denom, absx2, coeffdata, 1);
-    denom             = vfmaq_laneq_f32(denom, absx3, coeffdata, 2);
-    denom             = vfmaq_laneq_f32(denom, absx4, coeffdata, 3);
+    // Find the LUT indices by rounding |x| to the step of 1/128.
+    //
+    // For |x| < 4, z = |x| + 2^16 lies in [2^16, 2^16 + 4), where the float ULP is 2^-7 = 1/128.
+    // The addition therefore rounds |x| to the nearest multiple of 1/128 (in the default
+    // round-to-nearest mode) and leaves round(|x| * 128) in the low mantissa bits of z.
+    // Subtracting `shift` as a float recovers r; subtracting its bit pattern yields the index.
+    const float32x4_t z = vaddq_f32(x_abs, shift);
+    const float32x4_t r = vsubq_f32(z, shift);
 
-    denom = vmulq_f32(denom, denom);
-    denom = vmulq_f32(denom, denom);
+    // Larger inputs (including inf and NaN) map past the end of the LUT; clamp the index so
+    // the lookup stays in bounds.
+    uint32x4_t index = vsubq_u32(vreinterpretq_u32_f32(z), vreinterpretq_u32_f32(shift));
+    index            = vminq_u32(index, max_index);
 
-    float32x4_t fract = onev;
-    fract             = vdivq_f32(fract, denom);
+    // Lookup erf(r) and scale(r). Each LUT entry is read as two consecutive floats
+    // {erf(r), scale(r)}. vld1_f32 is used so that the table is accessed through float
+    // (no type punning through a wider type) and only 4-byte alignment is required; it
+    // also keeps erf(r) in the lower lane regardless of endianness.
+    const float *entry_0 = reinterpret_cast<const float *>(&erf_f32_lut[vgetq_lane_u32(index, 0)]);
+    const float *entry_1 = reinterpret_cast<const float *>(&erf_f32_lut[vgetq_lane_u32(index, 1)]);
+    const float *entry_2 = reinterpret_cast<const float *>(&erf_f32_lut[vgetq_lane_u32(index, 2)]);
+    const float *entry_3 = reinterpret_cast<const float *>(&erf_f32_lut[vgetq_lane_u32(index, 3)]);
 
-    float32x4_t result = onev;
-    result             = vsubq_f32(result, fract);
+    const float32x4_t entry_01 = vcombine_f32(vld1_f32(entry_0), vld1_f32(entry_1));
+    const float32x4_t entry_23 = vcombine_f32(vld1_f32(entry_2), vld1_f32(entry_3));
 
-    float32x4_t inverse = vnegq_f32(result);
+    // entry_01 = {erf(r0), scale(r0), erf(r1), scale(r1)}, likewise entry_23: de-interleave.
+    const float32x4_t erf_r   = vuzp1q_f32(entry_01, entry_23);
+    const float32x4_t scale_r = vuzp2q_f32(entry_01, entry_23);
 
-    result = vbslq_f32(selector, inverse, result);
+    // Approximate erf(a) = erf(r) + scale(r) * d * (1 - r * d - 1/3 * d^2).
+    const float32x4_t d  = vsubq_f32(x_abs, r);
+    const float32x4_t d2 = vmulq_f32(d, d);
+
+    const float32x4_t t0    = vfmaq_f32(r, third, d); // t0 = r + 1/3 * d.
+    const float32x4_t t1    = vfmsq_f32(d, d2, t0);   // t1 = d - d2 * t0 = d * (1 - r * d - 1/3 * d^2).
+    const float32x4_t erf_a = vfmaq_f32(erf_r, scale_r, t1);
+
+    // Saturate to 1 beyond the LUT range; this also covers |x| = inf.
+    const float32x4_t clamped = vbslq_f32(vcgtq_f32(x_abs, max_value), one, erf_a);
+
+    // Copy the sign of x: magnitude bits come from `clamped`, the sign bit from x.
+    const float32x4_t result = vbslq_f32(abs_mask, clamped, x);
 
     return result;
 }