Fix tanh accuracy for small argument values

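For |x| < 0.005, the truncated Taylor series x - x^3/3 is a more accurate
approximation of tanh(x) than (exp(2x) - 1)/(exp(2x) + 1). Near zero,
exp(2x) is close to 1, so the subtraction in the numerator cancels most
significant bits. The implementation now uses x * (1 - x^2/3) below this
threshold and the exponential formula otherwise.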

Resolves: COMPMID-4098

Signed-off-by: Aleksandr Nikolaev <aleksandr.nikolaev@arm.com>
Change-Id: If6f9d7ce4d8d00d36d2dada7ab8f8d9f5b58f5c0
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/321354
Tested-by: bsgcomp <bsgcomp@arm.com>
Comments-Addressed: bsgcomp <bsgcomp@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5563
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/NEMath.inl b/src/core/NEON/NEMath.inl
index 29df543..5ac62ba 100644
--- a/src/core/NEON/NEMath.inl
+++ b/src/core/NEON/NEMath.inl
@@ -190,12 +190,15 @@
     static const float32x4_t CONST_2        = vdupq_n_f32(2.f);
     static const float32x4_t CONST_MIN_TANH = vdupq_n_f32(-10.f);
     static const float32x4_t CONST_MAX_TANH = vdupq_n_f32(10.f);
+    static const float32x4_t CONST_THR      = vdupq_n_f32(5.e-3);
+    static const float32x4_t CONST_1_3      = vdupq_n_f32(0.3333333f);
 
     float32x4_t x     = vminq_f32(vmaxq_f32(val, CONST_MIN_TANH), CONST_MAX_TANH);
-    float32x4_t exp2x = vexpq_f32(vmulq_f32(CONST_2, x));
-    float32x4_t num   = vsubq_f32(exp2x, CONST_1);
-    float32x4_t den   = vaddq_f32(exp2x, CONST_1);
-    float32x4_t tanh  = vmulq_f32(num, vinvq_f32(den));
+    // x * (1 - x^2/3) if |x| < 5.e-3 or (exp2x - 1) / (exp2x + 1) otherwise
+    float32x4_t exp2x = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vexpq_f32(vmulq_f32(CONST_2, x)), vmulq_f32(x, x));
+    float32x4_t num   = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vsubq_f32(exp2x, CONST_1), vmulq_f32(CONST_1_3, exp2x));
+    float32x4_t den   = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vaddq_f32(exp2x, CONST_1), vsubq_f32(CONST_1, num));
+    float32x4_t tanh  = vbslq_f32(vcgtq_f32(vabsq_f32(x), CONST_THR), vmulq_f32(num, vinvq_f32(den)), vmulq_f32(x, den));
     return tanh;
 }