[MLBEDSW-2845] Improve unit test coverage of fp_math

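Rewrote the fp_math unit tests for saturating_rounding_mul, shift_left,
rounding_divide_by_pot, saturating_rounding_multiply_by_pot and rescale
to build their inputs with fp_math.from_float and to cover saturation,
multiplication by zero, sign combinations and rounding edge cases.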

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I883fd984a1bfa67102826a400380e41a363fc59d
diff --git a/ethosu/vela/test/test_fp_math.py b/ethosu/vela/test/test_fp_math.py
index 8c1ed67..905826f 100644
--- a/ethosu/vela/test/test_fp_math.py
+++ b/ethosu/vela/test/test_fp_math.py
@@ -64,53 +64,107 @@
 
 def test_saturating_rounding_mul():
     i32info = np.iinfo(np.int32)
-    shift = 22
-    multiplier = 1760306048
+    # Saturation
     assert fp_math.saturating_rounding_mul(i32info.min, i32info.min) == i32info.max
-    assert fp_math.saturating_rounding_mul(-255 * 1 << shift, multiplier) == -876714926
-    assert fp_math.saturating_rounding_mul(-128 * 1 << shift, multiplier) == -440076512
-    assert fp_math.saturating_rounding_mul(0, multiplier) == 0
-    assert fp_math.saturating_rounding_mul(128 * 1 << shift, multiplier) == 440076512
-    assert fp_math.saturating_rounding_mul(255 * 1 << shift, multiplier) == 876714926
+    assert fp_math.saturating_rounding_mul(i32info.min, i32info.max) == -i32info.max
+    assert fp_math.saturating_rounding_mul(i32info.max, i32info.min) == -i32info.max
+
+    # Multiply by zero
+    assert fp_math.saturating_rounding_mul(0, fp_math.from_float(1.0)) == 0
+    assert fp_math.saturating_rounding_mul(0, fp_math.from_float(-1.0)) == 0
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), 0) == 0
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), 0) == 0
+
+    # Multiply positive/negative
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(1.0)) == fp_math.from_float(
+        1.0, 5 + 5
+    )
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(1.0)) == fp_math.from_float(
+        -1.0, 5 + 5
+    )
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
+        -1.0, 5 + 5
+    )
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
+        1.0, 5 + 5
+    )
+
+    # Rounding
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0), 1) == 1
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0), 1) == 0
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0) - 1, 1) == 0
+    assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0) - 1, 1) == -1
 
 
 def test_shift_left():
     i32info = np.iinfo(np.int32)
-    assert fp_math.shift_left(np.int32(1), i32info.bits) == i32info.max
-    assert fp_math.shift_left(np.int32(-1), i32info.bits) == i32info.min
-    assert fp_math.shift_left(np.int32(1), i32info.bits - 2) == (i32info.max + 1) / 2
-    assert fp_math.shift_left(np.int32(-1), i32info.bits - 2) == i32info.min // 2
+    assert fp_math.shift_left(1, i32info.bits) == i32info.max
+    assert fp_math.shift_left(-1, i32info.bits) == i32info.min
+    assert fp_math.shift_left(1, i32info.bits - 2) == (i32info.max + 1) / 2
+    assert fp_math.shift_left(-1, i32info.bits - 2) == i32info.min // 2
+
+    assert fp_math.shift_left(fp_math.from_float(1.0), 5) == i32info.max
+    assert fp_math.shift_left(fp_math.from_float(-1.0), 5) == i32info.min
+    assert fp_math.shift_left(fp_math.from_float(1.0), 4) == 16 * fp_math.from_float(1.0)
+    assert fp_math.shift_left(fp_math.from_float(-1.0), 4) == 16 * fp_math.from_float(-1.0)
+
+    with pytest.raises(AssertionError):
+        fp_math.shift_left(1, -1)
 
 
 def test_rounding_divide_by_pot():
-    assert fp_math.rounding_divide_by_pot(1024, 4) == 64
-    assert fp_math.rounding_divide_by_pot(1031, 4) == 64
-    assert fp_math.rounding_divide_by_pot(1032, 4) == 65
-    assert fp_math.rounding_divide_by_pot(1047, 4) == 65
-    assert fp_math.rounding_divide_by_pot(1048, 4) == 66
-    assert fp_math.rounding_divide_by_pot(1056, 4) == 66
-    assert fp_math.rounding_divide_by_pot(-1024, 4) == -64
-    assert fp_math.rounding_divide_by_pot(-1031, 4) == -64
-    assert fp_math.rounding_divide_by_pot(-1032, 4) == -65
-    assert fp_math.rounding_divide_by_pot(-1047, 4) == -65
-    assert fp_math.rounding_divide_by_pot(-1048, 4) == -66
-    assert fp_math.rounding_divide_by_pot(-1056, 4) == -66
+    # No remainder division
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0), 26) == 1
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0), 26) == -1
+
+    # Remainder rounding the result away from zero
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0), 27) == -1
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0), 27) == 1
+
+    # Remainder smaller than threshold to round the result away from zero
+    # Positive and negative edge cases
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0) - 1, 27) == 0
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0) + 1, 27) == 0
+    # Far from the edge
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0), 28) == 0
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0), 28) == 0
+
+    # Regular division - no remainder
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0), 4) == fp_math.from_float(1.0 / 16)
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0), 4) == fp_math.from_float(-1.0 / 16)
+
+    # Rounding/no rounding edge cases
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0) + (1 << 3) - 1, 4) == fp_math.from_float(1.0 / 16)
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(1.0) + (1 << 3), 4) == fp_math.from_float(1.0 / 16) + 1
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0) - (1 << 3) + 1, 4) == fp_math.from_float(-1.0 / 16)
+    assert fp_math.rounding_divide_by_pot(fp_math.from_float(-1.0) - (1 << 3), 4) == fp_math.from_float(-1.0 / 16) - 1
 
 
 def test_saturating_rounding_multiply_by_pot():
     i32info = np.iinfo(np.int32)
-    assert fp_math.saturating_rounding_multiply_by_pot(4, np.int32(1025)) == 16400
-    assert fp_math.saturating_rounding_multiply_by_pot(5, np.int32(67108865)) == i32info.max
-    assert fp_math.saturating_rounding_multiply_by_pot(5, np.int32(-67108865)) == i32info.min
+    assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(1.0), 5) == i32info.max
+    assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(-1.0), 5) == i32info.min
+    assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(1.0) - 1, 5) == i32info.max - 32 + 1
+    assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(-1.0) + 1, 5) == -i32info.max + 32 - 1
+    assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(1.0), 4) == fp_math.from_float(1.0 * 16)
+    assert fp_math.saturating_rounding_multiply_by_pot(fp_math.from_float(-1.0), 4) == fp_math.from_float(-1.0 * 16)
 
 
 def test_rescale():
-    assert fp_math.rescale(5, 0, np.int32(1025)) == 32800
-    assert fp_math.rescale(3, 0, np.int32(1025)) == 8200
-    assert fp_math.rescale(5, 1, np.int32(1025)) == 16400
-    assert fp_math.rescale(3, 1, np.int32(1025)) == 4100
-    with pytest.raises(AssertionError):
-        fp_math.rescale(1, 3, np.int32(1024))
+    assert fp_math.rescale(5, 0, fp_math.from_float(1.0)) == fp_math.from_float(1.0, 0)
+    assert fp_math.rescale(5, 10, fp_math.from_float(1.0)) == fp_math.from_float(1.0, 10)
+    assert fp_math.rescale(5, 0, fp_math.from_float(-1.0)) == fp_math.from_float(-1.0, 0)
+    assert fp_math.rescale(5, 10, fp_math.from_float(-1.0)) == fp_math.from_float(-1.0, 10)
+
+    assert fp_math.rescale(5, 4, fp_math.from_float(32.0)) == fp_math.from_float(32.0, 4)
+    assert fp_math.rescale(5, 6, fp_math.from_float(32.0)) == fp_math.from_float(32.0, 6)
+    assert fp_math.rescale(5, 4, fp_math.from_float(-32.0)) == fp_math.from_float(-32.0, 4)
+    assert fp_math.rescale(5, 6, fp_math.from_float(-32.0)) == fp_math.from_float(-32.0, 6)
+
+    assert fp_math.rescale(5, 4, fp_math.from_float(31.9)) == fp_math.from_float(31.9, 4)
+    assert fp_math.rescale(5, 6, fp_math.from_float(31.9)) == fp_math.from_float(31.9, 6)
+    assert fp_math.rescale(5, 4, fp_math.from_float(-31.9)) == fp_math.from_float(-31.9, 4)
+    assert fp_math.rescale(5, 6, fp_math.from_float(-31.9)) == fp_math.from_float(-31.9, 6)
 
 
 def test_exp():