MLBEDSW-2688: LUT calculation with different in/out scale

Enables the LUT implementation of LeakyRelu for int8/uint8 even if
the input scale differs from the output scale.

Fusing the LUT with a preceding operator in this situation still
requires further work.
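
For reference, a minimal sketch of the per-entry LUT calculation
(int8 case with alpha_scalar == 1; the helper name lrelu_lut_entry
is illustrative and not part of the patch; the fused Mul/Max case
additionally folds the stored alpha_scalar into the negative side):

    from ethosu.vela import fp_math, scaling

    def lrelu_lut_entry(x, zp_in, zp_out, ifm_scale, ofm_scale, alpha):
        # Quantised multipliers for rescaling from the IFM to the OFM
        # domain: identity scale for x >= zp_in, alpha scale below it
        identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
        alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
        if x < zp_in:
            y = fp_math.multiply_by_quantized_multiplier(x - zp_in, alpha_scale, alpha_shift)
        else:
            y = fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
        # Clamp to the int8 quantised range around the output zero point
        return min(127, max(-128, zp_out + y))

    # Example: LeakyRelu(alpha=0.1) with differing in/out quantisation
    values = [lrelu_lut_entry(x, zp_in=-5, zp_out=10, ifm_scale=0.02, ofm_scale=0.05, alpha=0.1)
              for x in range(-128, 128)]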

Change-Id: I9eddfe36f457e763d44eb3e05fbe240eac7cfec9
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/fp_math.py b/ethosu/vela/fp_math.py
index 2055879..eaeb84a 100644
--- a/ethosu/vela/fp_math.py
+++ b/ethosu/vela/fp_math.py
@@ -136,3 +136,13 @@
         return np.iinfo(np.int32).max
     else:
         return result
+
+
+def multiply_by_quantized_multiplier(x, scale, shift):
+    # Multiplies x (int32) by the quantised multiplier (scale, shift), as obtained by a call to
+    # scaling.quantise_scale, and returns the rounded result
+    shift = 31 - shift
+    left_shift = shift if shift > 0 else 0
+    right_shift = -shift if shift < 0 else 0
+    mul = saturating_rounding_mul(x * (1 << left_shift), scale)
+    return rounding_divide_by_pot(mul, right_shift)
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index aaccce2..7ab009f 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -20,8 +20,10 @@
 
 import numpy as np
 
+from . import fp_math
 from . import lut
 from . import rewrite_graph
+from . import scaling
 from .data_type import DataType
 from .errors import UnsupportedFeatureError
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
@@ -637,7 +639,8 @@
             return op
 
         # make sure the Mul doesn't have any other consumers
-        if len(mul.outputs[0].consumers()) != 1:
+        mul_ofm = mul.outputs[0]
+        if len(mul_ofm.consumers()) != 1:
             return op
         # make sure the Mul doesn't have a faf
         if mul.attrs["fused_activation_function"]:
@@ -645,7 +648,7 @@
         ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
         if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
             return op
-        if not ifm.is_scaling_equal(ofm):
+        if not ifm.is_scaling_equal(ofm) or not ifm.is_scaling_equal(mul_ofm):
             # rewrite to LeakyRelu currently only makes sense if the quantization is identical
             return op
 
@@ -671,6 +674,15 @@
         if val >= 0:
             new_op = "LeakyRelu"
             op.attrs["alpha"] = val
+            # To produce bit-exact results, the alpha value alone is not enough;
+            # save the additional scaling info in attr "alpha_scaling", to be used as input
+            # to the LUT construction
+            alpha_scalar = const_tens.quant_values - const_tens.quantization.zero_point
+            mul_ifm_scale = np.double(ifm.quantization.scale_f32)
+            mul_ifm2_scale = np.double(const_tens.quantization.scale_f32)
+            mul_ofm_scale = np.double(mul_ofm.quantization.scale_f32)
+            alpha_scale, alpha_shift = scaling.elementwise_mul_scale(mul_ifm_scale, mul_ifm2_scale, mul_ofm_scale)
+            op.attrs["alpha_scaling"] = (alpha_scalar, alpha_scale, alpha_shift)
         elif val == -1:
             new_op = "Abs"
         else:
@@ -744,15 +756,39 @@
     op.attrs["is_nop"] = True
     # Create an input tensor containing scalar zero
     quantization = QuantizationParameters(0.0, 255.0)
-    quantization.scale_f32 = 1.0
+    quantization.scale_f32 = ifm.quantization.scale_f32
     quantization.zero_point = 0
     tens = create_const_tensor(op.inputs[0].name + "_add", [], ifm.dtype, [0], np.uint8, quantization=quantization)
     op.add_input_tensor(tens)
-    alpha = op.attrs["alpha"]
-    zp = ofm.quantization.zero_point
     # Generate the LUT
+    alpha = op.attrs["alpha"]
+    ifm_scale = np.double(ifm.quantization.scale_f32)
+    ofm_scale = np.double(ofm.quantization.scale_f32)
+    zp_in = ifm.quantization.zero_point
+    zp_out = ofm.quantization.zero_point
+    identity_scale, identity_shift = scaling.elementwise_mul_scale(ifm_scale, 1, ofm_scale)
+    alpha_scalar = 1
+    alpha_scale, alpha_shift = scaling.elementwise_mul_scale(ifm_scale, alpha, ofm_scale)
+    if "alpha_scaling" in op.attrs:
+        # The LeakyRelu is the result of convert_mul_max_to_abs_or_lrelu
+        alpha_scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
+    values = []
     ix = range(256) if ifm.dtype == DataType.uint8 else range(-128, 128)
-    values = [int(x) if x >= zp else int(round(zp - alpha * (zp - x))) for x in ix]
+    quantized_min = min(ix)
+    quantized_max = max(ix)
+    for x in ix:
+        if x < zp_in:
+            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(
+                alpha_scalar * (x - zp_in), alpha_scale, alpha_shift
+            )
+        else:
+            lut_result = zp_out + fp_math.multiply_by_quantized_multiplier(x - zp_in, identity_scale, identity_shift)
+        lut_result = min(quantized_max, max(quantized_min, lut_result))
+        values.append(lut_result)
+    # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
+    # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
+    # must be based on the IFM quantization
+    op.attrs["forced_output_quantization"] = ifm.quantization
     lut_tensor = lut.create_lut_tensor(op.name + "_lut", values, DataType.int8)
     op.set_activation_lut(lut_tensor)
     return op
@@ -763,13 +799,12 @@
     if op.type != "LeakyRelu":
         return op
     ifm, _, _, ofm = op.get_ifm_weights_biases_ofm()
-    if ifm.is_scaling_equal(ofm) and ifm.dtype == ofm.dtype:
-        if ifm.dtype in (DataType.uint8, DataType.int8):
-            # use LUT
-            return convert_lrelu_to_lut(op, arch)
-        elif ifm.dtype == DataType.int16:
-            # use LeakyRelu unmodified
-            return op
+    if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
+        # use LUT for int8/uint8
+        return convert_lrelu_to_lut(op, arch)
+    if ifm.is_scaling_equal(ofm) and ifm.dtype == ofm.dtype and ifm.dtype == DataType.int16:
+        # use LeakyRelu unmodified for int16 with equal input/output scaling
+        return op
     return convert_lrelu_to_mul_max(op, arch)
 
 
@@ -802,7 +837,7 @@
     if not fuse:
         return op
     # Move the fused activation function + corresponding info to prev_op
-    for attr in ("fused_activation_function", "alpha"):
+    for attr in ("fused_activation_function", "alpha", "forced_output_quantization"):
         if attr in op.attrs:
             prev_op.attrs[attr] = op.attrs[attr]
     if op.activation_lut is not None:
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 8d9f918..609fcc6 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -442,6 +442,13 @@
             fmf = primary_op.attrs.get("fused_memory_function", None)
             faf = primary_op.attrs.get("fused_activation_function", None)
             fused_quantize = any(op.type == "Quantize" for op in ps.ops)
+            # Forced output quantization, used in operations with a fused LUT
+            # Note: with the current LUT support, forced_ofm_quantization is always equal to
+            # cmd.ofm_tensor.quantization, except when primary_op is AddAct + 0 (no-op) + LUT
+            forced_ofm_quantization = primary_op.attrs.get("forced_output_quantization", None)
+            ofm_quant = cmd.ofm_tensor.quantization
+            if forced_ofm_quantization is not None:
+                ofm_quant = forced_ofm_quantization
 
             # Specifies which operand to apply scaling to in bitexact elementwise ADD/SUB
             op_to_scale = 0
@@ -476,7 +483,7 @@
                 if primary_op.type in set(("AddAct", "MulAct", "SubAct",)):
                     input_scale = cmd.ifm_tensor.quantization.scale_f32
                     input2_scale = cmd.ifm2_tensor.quantization.scale_f32
-                    output_scale = cmd.ofm_tensor.quantization.scale_f32
+                    output_scale = ofm_quant.scale_f32
                     use_global_scale = True
 
                     if output_scale is not None and faf in ("Sigmoid", "Tanh"):
@@ -491,7 +498,7 @@
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
                     else:  # AddAct/SubAct
                         # Force output scale same as the input scale for
-                        # resizebiliner 1x1 that is converted to add
+                        # resizebilinear 1x1 that is converted to add
                         if "resizebilinear" in primary_op.attrs:
                             output_scale = input2_scale
 
@@ -529,7 +536,7 @@
                         emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
 
                 elif primary_op.type in set(("LeakyRelu", "Abs",)):
-                    output_scale = cmd.ofm_tensor.quantization.scale_f32
+                    output_scale = ofm_quant.scale_f32
                     use_global_scale = True
 
                     if primary_op.type == "LeakyRelu":
@@ -664,7 +671,7 @@
                         elif fused_quantize:
                             # Quantize op requires different scaling
                             ifm_scale_f64 = np.double(cmd.ifm_tensor.quantization.scale_f32)
-                            ofm_scale_f64 = np.double(cmd.ofm_tensor.quantization.scale_f32)
+                            ofm_scale_f64 = np.double(ofm_quant.scale_f32)
                             scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64)
                         elif primary_op.type == "ResizeBilinear" and "rescale" in primary_op.attrs:
                             rescale = primary_op.attrs["rescale"]
@@ -676,11 +683,8 @@
                             # k_height == k_width == 1 is allways true in this case
                             # Normally the scale is maximised, to get maximum precision, which means that
                             # if rescale != 1, scale need to consider the number of bits needed for rescaling
-                            if None not in (
-                                cmd.ofm_tensor.quantization.scale_f32,
-                                cmd.ifm_tensor.quantization.scale_f32,
-                            ):
-                                rescale = cmd.ifm_tensor.quantization.scale_f32 / cmd.ofm_tensor.quantization.scale_f32
+                            if None not in (ofm_quant.scale_f32, cmd.ifm_tensor.quantization.scale_f32,):
+                                rescale = cmd.ifm_tensor.quantization.scale_f32 / ofm_quant.scale_f32
                                 rescale_bits = 0
                                 if k_height == k_width == 1:
                                     if fmf == "ConcatSliceWrite":
@@ -797,9 +801,8 @@
                     scale_region = base_ptr_idx_map[cmd.scale_tensor.mem_type]
                     emit.cmd0_with_param(cmd0.NPU_SET_SCALE_REGION, scale_region)
 
-            ofm_quant = cmd.ofm_tensor.quantization
-            ofm_quant_qmin = cmd.ofm_tensor.quantization.quant_min
-            ofm_quant_qmax = cmd.ofm_tensor.quantization.quant_max
+            ofm_quant_qmin = ofm_quant.quant_min
+            ofm_quant_qmax = ofm_quant.quant_max
             ifm_min = cmd.ifm_tensor.quantization.min
             ifm_max = cmd.ifm_tensor.quantization.max
 
@@ -912,13 +915,15 @@
                     emit.cmd0_with_param(zero_point_op, 0)
                 else:
                     assert tens.quantization.zero_point is not None, "need an actual zero point set"
-                    if (
+                    if cmd0.NPU_SET_OFM_ZERO_POINT == zero_point_op and forced_ofm_quantization is not None:
+                        zero_point = forced_ofm_quantization.zero_point
+                    elif (
                         "resizebilinear" in primary_op.attrs
                         and primary_op.type == "AddAct"
                         and cmd0.NPU_SET_OFM_ZERO_POINT == zero_point_op
                     ):
                         # Force output zero point same as the input zero point
-                        # for resizebiliner 1x1 that is converted to add
+                        # for resizebilinear 1x1 that is converted to add
                         zero_point = cmd.ifm2_tensor.quantization.zero_point
                     else:
                         zero_point = tens.quantization.zero_point
diff --git a/ethosu/vela/test/test_fp_math.py b/ethosu/vela/test/test_fp_math.py
index 2dde1e4..8c1ed67 100644
--- a/ethosu/vela/test/test_fp_math.py
+++ b/ethosu/vela/test/test_fp_math.py
@@ -19,6 +19,7 @@
 import pytest
 
 from ethosu.vela import fp_math
+from ethosu.vela import scaling
 from ethosu.vela.softmax import SoftMax
 
 # Turn off black formatting for EXP_LUT to keep it compact
@@ -116,3 +117,39 @@
     sm = SoftMax(None)
     for (expected, actual) in zip(EXP_LUT, sm.generate_exp_table(1.0, np.float32(0.05123165))):
         assert actual == expected
+
+
+multiply_test_data = [
+    (0, 0, 0),
+    (0, 0.7, 0),
+    (0, 55.8, 0),
+    (6, 0.3, 2),
+    (200, 0, 0),
+    (1, 1, 1),
+    (1, 0.1, 0),
+    (1, 3.49, 3),
+    (1, 3.51, 4),
+    (27, 1, 27),
+    (13, 0.9, 12),
+    (3, 21.2, 64),
+    (1000, 2000, 2000000),
+    (32767, 32767, 32767 * 32767),  # extreme values
+]
+
+
+@pytest.mark.parametrize("x, factor, expected", multiply_test_data)
+def test_multiply_by_quantized_multiplier(x, factor, expected):
+    scale, shift = scaling.quantise_scale(factor)
+    assert fp_math.multiply_by_quantized_multiplier(x, scale, shift) == expected
+    assert fp_math.multiply_by_quantized_multiplier(-x, scale, shift) == -expected
+    assert fp_math.multiply_by_quantized_multiplier(x, -scale, shift) == -expected
+    assert fp_math.multiply_by_quantized_multiplier(-x, -scale, shift) == expected
+
+
+def test_multiply_by_quantized_multiplier_int16_limits():
+    # Tests the min/max limits of the foreseen practical usage of multiply_by_quantized_multiplier
+    # for the purpose of calculating LUTs
+    for x in [-32768, 32767]:
+        for y in [-32768, 32767]:
+            scale, shift = scaling.quantise_scale(y)
+            assert fp_math.multiply_by_quantized_multiplier(x, scale, shift) == x * y