MLBEDSW-3224: Support HardSwish

Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
index 5228f03..21022c2 100644
--- a/ethosu/vela/fp_math.py
+++ b/ethosu/vela/fp_math.py
@@ -35,13 +35,14 @@
     return x / (1 << fractional_bits)
 
 
-def saturating_rounding_mul(a, b):
+def saturating_rounding_mul32(a, b):
     assert np.int32(a) == a
     assert np.int32(b) == b
     if a == b and a == np.iinfo(np.int32).min:
         return np.int32(np.iinfo(np.int32).max)
     divider = 1 << 31
     ab = np.int64(a) * np.int64(b)
+
     if ab >= 0:
         nudge = 1 << 30
         return (ab + nudge) // divider
@@ -56,19 +57,81 @@
         return result
 
 
-def shift_left(a, offset):
-    assert np.int32(a) == a
+def saturating_rounding_mul16(a, b):  # computes a * b / 2**15, rounded to nearest; a and b must fit in int16
+    assert np.int16(a) == a
+    assert np.int16(b) == b
+    if a == b and a == np.iinfo(np.int16).min:  # min * min is the one operand pair answered with int16 max
+        return np.int16(np.iinfo(np.int16).max)
+    divider = 1 << 15
+    ab = np.int32(a) * np.int32(b)  # multiply as int32 values
+
+    if ab >= 0:
+        nudge = 1 << 14  # half of the divider
+        return (ab + nudge) // divider
+    else:
+        nudge = 1 - (1 << 14)
+        ab_plus_nudge = ab + nudge
+        result = ab_plus_nudge // divider
+        # Python's // floors, whereas the reference truncates toward zero,
+        # so add one back whenever the floored quotient fell below the exact value.
+        if result * divider < ab_plus_nudge:
+            result += 1
+        return result
+
+
+# Like saturating_rounding_mul16, but the quotient is rounded toward zero instead of to nearest.
+# Only supports 16-bit operands.
+def saturating_mul16(a, b):
+    assert np.int16(a) == a
+    assert np.int16(b) == b
+    if a == b and a == np.iinfo(np.int16).min:  # min * min is the one operand pair answered with int16 max
+        return np.int16(np.iinfo(np.int16).max)
+    ab = np.int32(a) * np.int32(b)  # multiply as int32 values
+    divider = 1 << 15
+    if ab >= 0:
+        return ab // divider  # floor and truncation agree for non-negative products
+    else:
+        result = ab // divider
+        # Python's // floors, whereas the reference truncates toward zero,
+        # so add one back whenever the floored quotient fell below the exact value.
+        if result * divider < ab:
+            result += 1
+        return result
+
+
+def shift_left32(a, offset):  # a * 2**offset, clamped to the int32 range; returns np.int32
     assert offset >= 0
-    i32_info = np.iinfo(np.int32)
+    assert np.int32(a) == a
-    shifted = a * (1 << offset)
+    shifted = int(a) * (1 << offset)  # Python int product: a numpy int32 operand could wrap before the range checks
-    if shifted < i32_info.min:
-        return np.int32(i32_info.min)
-    elif shifted > i32_info.max:
-        return np.int32(i32_info.max)
+    if shifted < np.iinfo(np.int32).min:
+        return np.int32(np.iinfo(np.int32).min)
+    elif shifted > np.iinfo(np.int32).max:
+        return np.int32(np.iinfo(np.int32).max)
     else:
         return np.int32(shifted)
 
 
+def shift_left16(a, offset):  # a * 2**offset, clamped to the int16 range; returns np.int16
+    assert offset >= 0
+    assert np.int16(a) == a
+    shifted = int(a) * (1 << offset)  # Python int product: a numpy int16 operand could wrap before the range checks
+    if shifted < np.iinfo(np.int16).min:
+        return np.int16(np.iinfo(np.int16).min)
+    elif shifted > np.iinfo(np.int16).max:
+        return np.int16(np.iinfo(np.int16).max)
+    else:
+        return np.int16(shifted)
+
+
+def downscale_multiplier_int32_to_int16(a):  # (a + 2**15) >> 16 as int16, saturated to int16 max near int32 max
+    assert np.int32(a) == a
+    rounding_offset = 1 << 15  # half of the 2**16 removed by the shift below
+    if a >= np.iinfo(np.int32).max - rounding_offset:  # a + rounding_offset would reach int32 max
+        return np.iinfo(np.int16).max
+    else:
+        return np.int16((a + rounding_offset) >> 16)
+
+
 def rounding_divide_by_pot(x, exponent):
     assert np.int32(x) == x
     assert np.int32(exponent) == exponent
@@ -92,7 +155,7 @@
     elif x < -threshold:
         return np.iinfo(np.int32).min
     else:
-        return shift_left(x, exponent)
+        return shift_left32(x, exponent)
 
 
 def rescale(integer_bits_src, integer_bits_dst, x):
@@ -115,16 +178,16 @@
     constant_term = 1895147668
     constant_1_over_3 = 715827883
     x = a + (1 << offset)
-    x2 = saturating_rounding_mul(x, x)
-    x3 = saturating_rounding_mul(x2, x)
-    x4 = saturating_rounding_mul(x2, x2)
+    x2 = saturating_rounding_mul32(x, x)
+    x3 = saturating_rounding_mul32(x2, x)
+    x4 = saturating_rounding_mul32(x2, x2)
     x4_over_4 = rounding_divide_by_pot(x4, 2)
     x4_over_24_plus_x3_over_6_plus_x2_over_2 = rounding_divide_by_pot(
-        saturating_rounding_mul((x4_over_4 + x3), constant_1_over_3) + x2, 1
+        saturating_rounding_mul32((x4_over_4 + x3), constant_1_over_3) + x2, 1
     )
 
     return np.int32(
-        constant_term + saturating_rounding_mul(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
+        constant_term + saturating_rounding_mul32(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
     )
 
 
@@ -144,7 +207,7 @@
         integer_bits = 5
         shift = fractional_bits + exponent if integer_bits > exponent else 0
         if remainder & (1 << shift):
-            return saturating_rounding_mul(result, multiplier)
+            return saturating_rounding_mul32(result, multiplier)
         else:
             return result
 
@@ -168,5 +231,5 @@
     shift = 31 - shift
     left_shift = shift if shift > 0 else 0
     right_shift = -shift if shift < 0 else 0
-    mul = saturating_rounding_mul(x * (1 << left_shift), scale)
+    mul = saturating_rounding_mul32(x * (1 << left_shift), scale)
     return rounding_divide_by_pot(mul, right_shift)