MLBEDSW-3224: Support HardSwish
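
HardSwish (hardswish(x) = x * relu6(x + 3) / 6) is converted to an 8-bit LUT
that is generated with 16-bit fixed-point arithmetic, matching the TensorFlow
Lite Micro reference kernel. This adds the required 16-bit helpers to fp_math
and restricts the operator to int8/uint8 IFM and OFM.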

Change-Id: If49abc31f093f1bd3393bee86f821fd35972086f
Signed-off-by: Diqing Zhong <diqing.zhong@arm.com>
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
index 5228f03..21022c2 100644
--- a/ethosu/vela/fp_math.py
+++ b/ethosu/vela/fp_math.py
@@ -35,13 +35,14 @@
     return x / (1 << fractional_bits)
 
 
-def saturating_rounding_mul(a, b):
+def saturating_rounding_mul32(a, b):
     assert np.int32(a) == a
     assert np.int32(b) == b
     if a == b and a == np.iinfo(np.int32).min:
         return np.int32(np.iinfo(np.int32).max)
     divider = 1 << 31
     ab = np.int64(a) * np.int64(b)
+
     if ab >= 0:
         nudge = 1 << 30
         return (ab + nudge) // divider
@@ -56,19 +57,85 @@
         return result
 
 
-def shift_left(a, offset):
-    assert np.int32(a) == a
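+# 16-bit counterpart of saturating_rounding_mul32: returns (a * b) / 2^15,
+# rounded to nearest and saturating the int16_min * int16_min case to int16_max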
+def saturating_rounding_mul16(a, b):
+    assert np.int16(a) == a
+    assert np.int16(b) == b
+    if a == b and a == np.iinfo(np.int16).min:
+        return np.int16(np.iinfo(np.int16).max)
+    divider = 1 << 15
+    ab = np.int32(a) * np.int32(b)
+
+    if ab >= 0:
+        nudge = 1 << 14
+        return (ab + nudge) // divider
+    else:
+        nudge = 1 - (1 << 14)
+        ab_plus_nudge = ab + nudge
+        result = ab_plus_nudge // divider
+        # Python uses floor, the reference uses truncation
+        # so we need to compensate for that.
+        if result * divider < ab_plus_nudge:
+            result += 1
+        return result
+
+
+# Similar to saturating_rounding_mul16, except that the result is rounded
+# towards zero (truncated) instead of to nearest. Only supports 16-bit values
+def saturating_mul16(a, b):
+    assert np.int16(a) == a
+    assert np.int16(b) == b
+    if a == b and a == np.iinfo(np.int16).min:
+        return np.int16(np.iinfo(np.int16).max)
+    ab = np.int32(a) * np.int32(b)
+    divider = 1 << 15
+    if ab >= 0:
+        return ab // divider
+    else:
+        result = ab // divider
+        # Python uses floor, the reference uses truncation
+        # so we need to compensate for that.
+        if result * divider < ab:
+            result += 1
+        return result
+
+
+def shift_left32(a, offset):
     assert offset >= 0
-    i32_info = np.iinfo(np.int32)
+    assert np.int32(a) == a
     shifted = a * (1 << offset)
-    if shifted < i32_info.min:
-        return np.int32(i32_info.min)
-    elif shifted > i32_info.max:
-        return np.int32(i32_info.max)
+    if shifted < np.iinfo(np.int32).min:
+        return np.int32(np.iinfo(np.int32).min)
+    elif shifted > np.iinfo(np.int32).max:
+        return np.int32(np.iinfo(np.int32).max)
     else:
         return np.int32(shifted)
 
 
+def shift_left16(a, offset):
+    assert offset >= 0
+    assert np.int16(a) == a
+    shifted = a * (1 << offset)
+    if shifted < np.iinfo(np.int16).min:
+        return np.int16(np.iinfo(np.int16).min)
+    elif shifted > np.iinfo(np.int16).max:
+        return np.int16(np.iinfo(np.int16).max)
+    else:
+        return np.int16(shifted)
+
+
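+# Convert a 32-bit fixed-point multiplier to 16 bits by taking the rounded
+# top 16 bits, saturating values that would round up past int16_max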
+def downscale_multiplier_int32_to_int16(a):
+    assert np.int32(a) == a
+    rounding_offset = 1 << 15
+    if a >= np.iinfo(np.int32).max - rounding_offset:
+        return np.iinfo(np.int16).max
+    else:
+        return np.int16((a + rounding_offset) >> 16)
+
+
 def rounding_divide_by_pot(x, exponent):
     assert np.int32(x) == x
     assert np.int32(exponent) == exponent
@@ -92,7 +159,7 @@
     elif x < -threshold:
         return np.iinfo(np.int32).min
     else:
-        return shift_left(x, exponent)
+        return shift_left32(x, exponent)
 
 
 def rescale(integer_bits_src, integer_bits_dst, x):
@@ -115,16 +182,16 @@
     constant_term = 1895147668
     constant_1_over_3 = 715827883
     x = a + (1 << offset)
-    x2 = saturating_rounding_mul(x, x)
-    x3 = saturating_rounding_mul(x2, x)
-    x4 = saturating_rounding_mul(x2, x2)
+    x2 = saturating_rounding_mul32(x, x)
+    x3 = saturating_rounding_mul32(x2, x)
+    x4 = saturating_rounding_mul32(x2, x2)
     x4_over_4 = rounding_divide_by_pot(x4, 2)
     x4_over_24_plus_x3_over_6_plus_x2_over_2 = rounding_divide_by_pot(
-        saturating_rounding_mul((x4_over_4 + x3), constant_1_over_3) + x2, 1
+        saturating_rounding_mul32((x4_over_4 + x3), constant_1_over_3) + x2, 1
     )
 
     return np.int32(
-        constant_term + saturating_rounding_mul(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
+        constant_term + saturating_rounding_mul32(constant_term, x + x4_over_24_plus_x3_over_6_plus_x2_over_2)
     )
 
 
@@ -144,7 +211,7 @@
         integer_bits = 5
         shift = fractional_bits + exponent if integer_bits > exponent else 0
         if remainder & (1 << shift):
-            return saturating_rounding_mul(result, multiplier)
+            return saturating_rounding_mul32(result, multiplier)
         else:
             return result
 
@@ -168,5 +235,5 @@
     shift = 31 - shift
     left_shift = shift if shift > 0 else 0
     right_shift = -shift if shift < 0 else 0
-    mul = saturating_rounding_mul(x * (1 << left_shift), scale)
+    mul = saturating_rounding_mul32(x * (1 << left_shift), scale)
     return rounding_divide_by_pot(mul, right_shift)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index ab4d916..7755cc3 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -823,6 +823,60 @@
     return op
 
 
+def convert_hardswish_to_lut(op, arch, nng):
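+    # Convert HardSwish (hardswish(x) = x * relu6(x + 3) / 6) to a LUT by
+    # evaluating it for every input value using 16-bit fixed-point arithmetic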
+    if op.type == Op.HardSwish:
+        ifm, ofm = op.get_ifm_ofm()
+        # Generate the LUT
+        ifm_scale = np.double(ifm.quantization.scale_f32)
+        ofm_scale = np.double(ofm.quantization.scale_f32)
+        zp_in = ifm.quantization.zero_point
+        zp_out = ofm.quantization.zero_point
+        ifm_scale_hires = (1 / 128) * ifm_scale
+        relu_multiplier = np.double(3 / 32768)
+        out_scale, out_shift = scaling.quantise_scale(ifm_scale_hires / ofm_scale)
+        relu_scale, relu_shift = scaling.quantise_scale(ifm_scale_hires / relu_multiplier)
+        # Use 16-bit scales
+        out_scale_16 = fp_math.downscale_multiplier_int32_to_int16(out_scale)
+        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
+
+        values = []
+        ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
+        quantized_min = min(ix)
+        quantized_max = max(ix)
+        for x in ix:
+            input_value = x - zp_in
+            input_value_hires = input_value * 128
+            # Compute the input value on essentially the output scale, not shifted yet
+            input_value_preshift = fp_math.saturating_rounding_mul16(input_value_hires, out_scale_16)
+            # Compute the "relu-ish multiplier". This matches the code in TensorFlow Lite Micro kernel
+            relu_value = np.int16(input_value_hires)
+            if relu_shift < 31:
+                relu_value = fp_math.shift_left16(relu_value, 30 - relu_shift)
+
+            relu_value = fp_math.saturating_rounding_mul16(relu_value, relu_scale_16)
+
+            if relu_shift < 31:
+                relu_value = fp_math.shift_left16(relu_value, 1)
+
+            if relu_shift > 31:
+                relu_value = fp_math.rounding_divide_by_pot(relu_value, relu_shift - 31)
+
+            # relu_value is now a 16-bit fixed-point value in [-1, 1]
+            # Convert it to a 16-bit fixed-point value in [0, 1]
+            relu_value = (relu_value + (1 << 15)) >> 1
+            lut_result = fp_math.saturating_mul16(relu_value, input_value_preshift)
+            shift = 31 - out_shift
+            shift = -shift if shift < 0 else 0
+            # Finally apply the output shift
+            lut_result = fp_math.rounding_divide_by_pot(lut_result, shift) + zp_out
+            lut_result = min(quantized_max, max(quantized_min, lut_result))
+            values.append(lut_result)
+        return convert_to_lut(op, values, "hardswish")
+    return op
+
+
 def convert_lrelu_to_mul_max(op, arch):
     # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
     # (the opposite of convert_mul_max_to_abs_or_lrelu)
@@ -1245,6 +1299,7 @@
         convert_conv_to_fc,
         convert_softmax,
         optimise_strided_conv,
+        convert_hardswish_to_lut,
         rewrite_fully_connected_input,
         convert_batched_fc_shape,
         fixup_conv2d_backprop,
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 8d54d65..73953ce 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -165,7 +165,7 @@
     GatherV2 = OperatorInfo()
     Greater = OperatorInfo()
     GreaterEqual = OperatorInfo()
-    HardSwish = OperatorInfo()
+    HardSwish = OperatorInfo(indices=IFM_INDICES)
     HashtableLookup = OperatorInfo()
     Identity = OperatorInfo()
     If = OperatorInfo()
@@ -305,7 +305,7 @@
         return self in (Op.Relu, Op.Relu6, Op.ReluN1To1, Op.Clip)
 
     def is_activation_op(self):
-        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
+        return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT, Op.HardSwish)
 
     def is_split_op(self):
         return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped, Op.Unpack)
@@ -372,6 +372,8 @@
     elif op_type == Op.Sigmoid:
         act.min = 0.0
         act.max = 1.0
+    elif op_type == Op.HardSwish:
+        act.min = 0.0
     return act
 
 
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 656a7e6..c3b0611 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -203,7 +203,7 @@
         for x in range(256):
             input_diff = x - 255
             if input_diff >= diff_min:
-                rescale = fp_math.saturating_rounding_mul(input_diff * (1 << shift), scale)
+                rescale = fp_math.saturating_rounding_mul32(input_diff * (1 << shift), scale)
                 lut.append(fp_math.exp_on_negative_values(rescale))
             else:
                 lut.append(0)
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 1bebe9a..99a4ba1 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -87,7 +87,7 @@
         set((Op.ReduceSum, Op.CLZ,)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops
     )
     relu_ops = Op.op_set(Op.is_relu_op)
-    activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax,))
+    activation_ops = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.Softmax, Op.HardSwish))
     npu_post_ops = (
         # activation functions
         activation_ops
@@ -261,6 +261,10 @@
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm)
 
+        # HardSwish specific checks:
+        self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
+        self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_matching_in_out_types)
+
     def is_operator_supported(self, op):
         ext_type = optype_to_builtintype(op.type)
         if op.type not in SupportedOperators.supported_operators:
@@ -934,6 +938,13 @@
         return valid, f"Op has ofm_dtype={ofm_dtype}"
 
     @staticmethod
+    def constraint_input_8bit(op):
+        "IFM must be int8 or uint8"
+        ifm_dtype = op.ifm.dtype
+        valid = (ifm_dtype == DataType.int8) or (ifm_dtype == DataType.uint8)
+        return valid, f"Op has ifm_dtype={ifm_dtype}"
+
+    @staticmethod
     def constraint_matching_quantization_parameters(op):
         "Both Input quantization parameters must match OFM quantization parameters"
         valid = True
diff --git a/ethosu/vela/test/test_fp_math.py b/ethosu/vela/test/test_fp_math.py
index 905826f..355d3ae 100644
--- a/ethosu/vela/test/test_fp_math.py
+++ b/ethosu/vela/test/test_fp_math.py
@@ -64,52 +64,81 @@
 
 def test_saturating_rounding_mul():
     i32info = np.iinfo(np.int32)
+    i16info = np.iinfo(np.int16)
+
     # Saturation
-    assert fp_math.saturating_rounding_mul(i32info.min, i32info.min) == i32info.max
-    assert fp_math.saturating_rounding_mul(i32info.min, i32info.max) == -i32info.max
-    assert fp_math.saturating_rounding_mul(i32info.max, i32info.min) == -i32info.max
+    assert fp_math.saturating_rounding_mul32(i32info.min, i32info.min) == i32info.max
+    assert fp_math.saturating_rounding_mul32(i32info.min, i32info.max) == -i32info.max
+    assert fp_math.saturating_rounding_mul32(i32info.max, i32info.min) == -i32info.max
+
+    assert fp_math.saturating_rounding_mul16(i16info.min, i16info.min) == i16info.max
+    assert fp_math.saturating_rounding_mul16(i16info.min, i16info.max) == -i16info.max
+    assert fp_math.saturating_rounding_mul16(i16info.max, i16info.min) == -i16info.max
 
     # Multiply by zero
-    assert fp_math.saturating_rounding_mul(0, fp_math.from_float(1.0)) == 0
-    assert fp_math.saturating_rounding_mul(0, fp_math.from_float(-1.0)) == 0
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), 0) == 0
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), 0) == 0
+    assert fp_math.saturating_rounding_mul32(0, fp_math.from_float(1.0)) == 0
+    assert fp_math.saturating_rounding_mul32(0, fp_math.from_float(-1.0)) == 0
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), 0) == 0
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), 0) == 0
+
+    assert fp_math.saturating_rounding_mul16(0, i16info.max) == 0
+    assert fp_math.saturating_rounding_mul16(0, i16info.min) == 0
+    assert fp_math.saturating_rounding_mul16(i16info.max, 0) == 0
+    assert fp_math.saturating_rounding_mul16(i16info.min, 0) == 0
 
     # Multiply positive/negative
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(1.0)) == fp_math.from_float(
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), fp_math.from_float(1.0)) == fp_math.from_float(
         1.0, 5 + 5
     )
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(1.0)) == fp_math.from_float(
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), fp_math.from_float(1.0)) == fp_math.from_float(
         -1.0, 5 + 5
     )
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
         -1.0, 5 + 5
     )
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-1.0), fp_math.from_float(-1.0)) == fp_math.from_float(
         1.0, 5 + 5
     )
 
     # Rounding
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0), 1) == 1
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0), 1) == 0
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(16.0) - 1, 1) == 0
-    assert fp_math.saturating_rounding_mul(fp_math.from_float(-16.0) - 1, 1) == -1
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(16.0), 1) == 1
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-16.0), 1) == 0
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(16.0) - 1, 1) == 0
+    assert fp_math.saturating_rounding_mul32(fp_math.from_float(-16.0) - 1, 1) == -1
+
+    assert fp_math.saturating_rounding_mul16(fp_math.from_float(16.0, 21), 1) == 1
+    assert fp_math.saturating_rounding_mul16(fp_math.from_float(-16.0, 21), 1) == 0
+    assert fp_math.saturating_rounding_mul16(fp_math.from_float(16.0, 21) - 1, 1) == 0
+    assert fp_math.saturating_rounding_mul16(fp_math.from_float(-16.0, 21) - 1, 1) == -1
 
 
 def test_shift_left():
     i32info = np.iinfo(np.int32)
-    assert fp_math.shift_left(1, i32info.bits) == i32info.max
-    assert fp_math.shift_left(-1, i32info.bits) == i32info.min
-    assert fp_math.shift_left(1, i32info.bits - 2) == (i32info.max + 1) / 2
-    assert fp_math.shift_left(-1, i32info.bits - 2) == i32info.min // 2
+    i16info = np.iinfo(np.int16)
+    assert fp_math.shift_left32(1, i32info.bits) == i32info.max
+    assert fp_math.shift_left32(-1, i32info.bits) == i32info.min
+    assert fp_math.shift_left32(1, i32info.bits - 2) == (i32info.max + 1) / 2
+    assert fp_math.shift_left32(-1, i32info.bits - 2) == i32info.min // 2
 
-    assert fp_math.shift_left(fp_math.from_float(1.0), 5) == i32info.max
-    assert fp_math.shift_left(fp_math.from_float(-1.0), 5) == i32info.min
-    assert fp_math.shift_left(fp_math.from_float(1.0), 4) == 16 * fp_math.from_float(1.0)
-    assert fp_math.shift_left(fp_math.from_float(-1.0), 4) == 16 * fp_math.from_float(-1.0)
+    assert fp_math.shift_left16(1, i16info.bits) == i16info.max
+    assert fp_math.shift_left16(-1, i16info.bits) == i16info.min
+    assert fp_math.shift_left16(1, i16info.bits - 2) == (i16info.max + 1) / 2
+    assert fp_math.shift_left16(-1, i16info.bits - 2) == i16info.min // 2
+
+    assert fp_math.shift_left32(fp_math.from_float(1.0), 5) == i32info.max
+    assert fp_math.shift_left32(fp_math.from_float(-1.0), 5) == i32info.min
+    assert fp_math.shift_left32(fp_math.from_float(1.0), 4) == 16 * fp_math.from_float(1.0)
+    assert fp_math.shift_left32(fp_math.from_float(-1.0), 4) == 16 * fp_math.from_float(-1.0)
+
+    assert fp_math.shift_left16(fp_math.from_float(1.0, 21), 5) == i16info.max
+    assert fp_math.shift_left16(fp_math.from_float(-1.0, 21), 5) == i16info.min
+    assert fp_math.shift_left16(fp_math.from_float(1.0, 21), 4) == 16 * fp_math.from_float(1.0, 21)
+    assert fp_math.shift_left16(fp_math.from_float(-1.0, 21), 4) == 16 * fp_math.from_float(-1.0, 21)
 
     with pytest.raises(AssertionError):
-        fp_math.shift_left(1, -1)
+        fp_math.shift_left32(1, -1)
+    with pytest.raises(AssertionError):
+        fp_math.shift_left16(1, -1)
 
 
 def test_rounding_divide_by_pot():
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 36213b7..5c01027 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -834,3 +834,26 @@
     assert support.is_operator_supported(op)
     op.attrs["alpha"] = -1
     assert not support.is_operator_supported(op)
+
+
+def test_constraint_hardswish_dtype():
+    # HardSwish operator dtype should be int8 or uint8, and input dtype must match output
+    # UINT8
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8])
+    assert support.is_operator_supported(op)
+    # INT8
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int8)
+    assert support.is_operator_supported(op)
+
+    # Invalid
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int16)
+    assert not support.is_operator_supported(op)
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.uint16)
+    assert not support.is_operator_supported(op)
+    op = testutil.create_op_with_quant_tensors(Op.HardSwish, [1, 8, 8, 8], [1, 8, 8, 8], datatype=DataType.int32)
+    assert not support.is_operator_supported(op)
+
+    in_tens = Tensor([1, 8, 8, 8], DataType.int8, "in")
+    out_tens = Tensor([1, 8, 8, 8], DataType.uint8, "out")
+    op = testutil.create_op(Op.HardSwish, [in_tens], out_tens)
+    assert not support.is_operator_supported(op)