MLBEDSW-6687 Vela crashes in npu_serialisation.py and tflite_graph_optimiser.py

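Fixed static optimisation of the Quantize operator: data type
combinations that cannot be optimised statically are now left
unoptimised and run on the CPU. Also added support for int16 and
corrected the requantisation calculation, including clamping of
the results to the quantized range.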

Change-Id: I861c712aa6258dba53fcf4d5dae45d1d416e6141
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 10ddca6..f2a8c80 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -25,6 +25,7 @@
 from . import rewrite_graph
 from . import scaling
 from .api import NpuRoundingMode
+from .data_type import BaseType
 from .data_type import DataType
 from .debug_database import DebugDatabase
 from .errors import UnsupportedFeatureError
@@ -1408,40 +1409,44 @@
             input_values = np.array([input_values])
 
         # requantized int8 to int8
-        if ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8:
+        if (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8) or (
+            ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16
+        ):
 
             # scale needs to use double precision to match TFLite reference kernel
             effective_scale = np.float64(ifm.quantization.scale_f32) / np.float64(ofm.quantization.scale_f32)
             effective_multiplier, effective_shift = quantise_scale(effective_scale)
 
-            assert effective_shift >= 0
-            assert -31 <= effective_shift <= 30
-            round_val = 1 << (effective_shift - 1)
-
             requantized_vals = []
-            for val in input_values:
+            for val in input_values.flatten():
                 input_val = val - ifm.quantization.zero_point
 
-                output = input_val * effective_multiplier + round_val
-                ofm_val = (output >> effective_shift) + ofm.quantization.zero_point
+                ofm_val = fp_math.multiply_by_quantized_multiplier(input_val, effective_multiplier, effective_shift)
+                ofm_val += ofm.quantization.zero_point
 
-                clamped_ofm_values = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
-                requantized_vals.append(clamped_ofm_values)
+                clamped_ofm_value = max(min(ofm_val, ofm.quantization.quant_max), ofm.quantization.quant_min)
+                requantized_vals.append(clamped_ofm_value)
 
-            ofm.values = np.array(requantized_vals)
+            ofm.values = np.array(requantized_vals, ofm.dtype.as_numpy_type())
+            ofm.values.shape = input_values.shape
 
         # Case: Float input - quantize to int
-        elif np.issubdtype(input_values.dtype, np.float):
+        elif ifm.dtype.type == BaseType.Float:
 
             quantized_vals = []
             for val in input_values:
 
                 # Derive quantized value
                 quant_val = (val / ofm.quantization.scale_f32) + ofm.quantization.zero_point
-                quantized_vals.append(quant_val)
+                clamped_quantized_val = np.clip(quant_val, ofm.quantization.quant_min, ofm.quantization.quant_max)
+                quantized_vals.append(clamped_quantized_val)
 
             # Pass the statically calculated quant val to output tensor
-            ofm.values = np.array(quantized_vals)
+            ofm.values = np.array(quantized_vals, ofm.dtype.as_numpy_type())
+
+        # Unsupported data type
+        else:
+            return op
 
         # Make quantize op const and disconnect from parent node
 
@@ -1493,23 +1498,6 @@
     # Compile time optimisations
     optimisation_list = [optimise_quantize, convert_shape_op_to_constant_tensor]
 
-    for optimisation in optimisation_list:
-        for idx, sg in enumerate(nng.subgraphs):
-            nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-                nng,
-                sg,
-                arch,
-                [],
-                [optimisation],
-                rewrite_unsupported=False,
-            )
-
-    # Pre-processing step
-    pre_process_list = [
-        supported_operator_check,
-        set_ifm_ofm_op_shapes,
-    ]
-
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
             nng,
@@ -1520,6 +1508,12 @@
             rewrite_unsupported=False,
         )
 
+    # Pre-processing step
+    pre_process_list = [
+        supported_operator_check,
+        set_ifm_ofm_op_shapes,
+    ]
+
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
             nng,