MLBEDSW-6314 Static optimisation for quantise OP

*Quantise op becomes constant if its input is known at compile time
*Quantised values are calculated when the input of the op is const and float
*Const int inputs to the quantise op are requantised (sketch below)
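
Illustration (hypothetical values; clamp/ifm_zp/ofm_zp are shorthand,
not Vela identifiers): requantising a const int8 value x with ifm
scale 0.5 and ofm scale 0.25 reduces to

    multiplier, shift = quantise_scale(0.5 / 0.25)  # effective scale 2.0
    y = ((x - ifm_zp) * multiplier + (1 << (shift - 1))) >> shift
    y = clamp(y + ofm_zp, quant_min, quant_max)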

Change-Id: Ic94a72a392af709fe6a640d7dacbb5dc2334f16f
Signed-off-by: Ayaan Masood <Ayaan.Masood@arm.com>
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index cf3985e..10ddca6 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -49,6 +49,7 @@
 from .operation import Padding
 from .operation_util import create_avgpool_nop
 from .operation_util import get_pad_values_from_input
+from .scaling import quantise_scale
 from .shape4d import Shape4D
 from .softmax import SoftMax
 from .tensor import check_quantized_tens_scaling_equal
@@ -1391,6 +1392,81 @@
     return op
 
 
+def optimise_quantize(op: Operation, arch, nng):
+    """Static optimisation for QUANTIZE operator output value known at compile time"""
+
+    if op.type == Op.Quantize and op.run_on_npu:
+
+        ifm, ofm = op.get_ifm_ofm()
+        input_values = ifm.values
+
+        # Guard clause - skip if the input is not const or there are no values to quantise
+        if input_values is None or not ifm.ops or ifm.ops[0].type != Op.Const:
+            return op
+
+        # Singular val in numpy array, convert to indexable array
+        if input_values.ndim == 0:
+            input_values = np.array([input_values])
+
+        # Case: int8 input - requantise int8 to int8
+        if ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8:
+
+            # scale needs to use double precision to match TFLite reference kernel
+            effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
+            effective_multiplier, effective_shift = quantise_scale(effective_scale)
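+            # quantise_scale() folds the scale into a 32-bit fixed-point multiplier
+            # and a right shift: effective_scale ~= effective_multiplier * 2^-effective_shift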
+
+            # A positive shift is required: round_val below is undefined for
+            # shift 0, and quantise_scale() never returns a negative shift
+            assert 0 < effective_shift <= 30
+            round_val = 1 << (effective_shift - 1)
+
+            requantized_vals = []
+            # Flatten so multi-dimensional const inputs are handled too; int()
+            # promotes to Python ints so the fixed-point multiply cannot overflow
+            for val in input_values.flatten():
+                input_val = int(val) - ifm.quantization.zero_point
+
+                output = input_val * effective_multiplier + round_val
+                ofm_val = (output >> effective_shift) + ofm.quantization.zero_point
+
+                clamped_ofm_val = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
+                requantized_vals.append(clamped_ofm_val)
+
+            ofm.values = np.array(requantized_vals, dtype=np.int8).reshape(input_values.shape)
+
+        # Case: Float input - quantize to int
+        elif np.issubdtype(input_values.dtype, np.floating):
+
+            # Scale, offset by the zero point, then round to nearest and clamp
+            # to the quantised output range
+            quantized_vals = np.round(input_values / ofm.quantization.scale_f32) + ofm.quantization.zero_point
+            quantized_vals = np.clip(quantized_vals, ofm.quantization.quant_min, ofm.quantization.quant_max)
+
+            # Pass the statically calculated quant values to the output tensor
+            ofm.values = quantized_vals
+
+        # Any other dtype combination is left for the runtime kernel to handle
+        else:
+            return op
+
+        # Make quantize op const and disconnect from parent node
+
+        # Remove reference of the current quant op from the parent tensor's consumer list
+        ifm.consumer_list = [consumer for consumer in ifm.consumer_list if consumer is not op]
+
+        # Clear any references to parent node
+        op.inputs = []
+
+        # Convert this quantize op to const
+        op.type = Op.Const
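+        # The op is now a const producer; its precomputed values are already
+        # attached to the output tensor above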
+
+    return op
+
+
 def convert_shape_op_to_constant_tensor(op: Operation, arch, nng):
     """Static optimisation for SHAPE operator output value known at compile time"""
 
@@ -1424,9 +1500,20 @@
 
 
 def tflite_optimise_graph(nng, arch):
-
     # Compile time optimisations
-    optimisation_list = [convert_shape_op_to_constant_tensor]
+    optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
+
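+    # Run each optimisation over every subgraph with a pre-order graph rewrite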
+    for optimisation in optimisation_list:
+        for idx, sg in enumerate(nng.subgraphs):
+            nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+                nng,
+                sg,
+                arch,
+                [],
+                [optimisation],
+                rewrite_unsupported=False,
+            )
 
     # Pre-processing step
     pre_process_list = [