vela: Improve the scaling is equal check

 - Fixed and documented both tensor and quant params scaling checks
 - Added quant params validity check and tensor quantisation check
 - Added valid tensor checks to some graph optimisation functions

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: I8d6e8f03a603d28886dde511672c8399c85b794c
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index f6209ed..4696446 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -35,6 +35,7 @@
 from .operation import Op
 from .operation import Operation
 from .softmax import SoftMax
+from .tensor import check_quantized_tens_scaling_equal
 from .tensor import create_const_tensor
 from .tensor import create_reshape_tensor
 from .tensor import QuantizationParameters
@@ -341,7 +342,7 @@
                 # There is a preceding Reshape
                 # Compare input of prev_op and input of op, to see if prev_op can be removed
                 ifm_prev_op = prev_op.inputs[0]
-                if ifm_prev_op.shape == ifm.shape and ifm_prev_op.quantization.is_scaling_equal(ifm.quantization):
+                if ifm_prev_op.shape == ifm.shape and check_quantized_tens_scaling_equal(ifm_prev_op, ifm.quantization):
                     # prev_op can be removed
                     op.set_input_tensor(ifm_prev_op, 0)
                 else:
@@ -369,7 +370,7 @@
                 # There is a subsequent Reshape
                 # Compare desired shape and output of consumer op, to see if consumer op can be removed
                 ofm_cons_op = ofm.consumer_list[0].outputs[0]
-                if desired_shape == ofm_cons_op.shape and ofm.quantization.is_scaling_equal(ofm_cons_op.quantization):
+                if desired_shape == ofm_cons_op.shape and check_quantized_tens_scaling_equal(ofm, ofm_cons_op):
                     op.outputs[0] = ofm_cons_op
                     op.outputs[0].ops = [op]
                 else:
@@ -613,7 +614,7 @@
         ofm = op.outputs[0]
         # Relu with differing IFM and OFM scaling cannot be fused with another primary op
         # and requires its own to be inserted
-        if not ifm.is_scaling_equal(ofm):
+        if not check_quantized_tens_scaling_equal(ifm, ofm):
             # Override this op with its own primary op (avgpool)
             relu_fused_op = create_avgpool_nop(op.name + "_avgpool")
             # And fuse the original activation function to it
@@ -727,9 +728,12 @@
         if mul.activation:
             return op
         ifm, ofm = op.get_ifm_ofm()
+        if ifm is None or ofm is None:
+            return op
+
         if ifm.dtype not in (DataType.uint8, DataType.int8) or ifm.dtype != ofm.dtype:
             return op
-        if not ifm.is_scaling_equal(ofm) or not ifm.is_scaling_equal(mul_ofm):
+        if not check_quantized_tens_scaling_equal(ifm, ofm) or not check_quantized_tens_scaling_equal(ifm, mul_ofm):
             # rewrite to LeakyRelu currently only makes sense if the quantization is identical
             return op
 
@@ -780,6 +784,8 @@
     # Converts LeakyRelu to Max(alpha * IFM, identity * IFM)
     # (the opposite of convert_mul_max_to_abs_or_lrelu)
     ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op
 
     # Add multiplication with alpha
     mul_alpha = Operation(Op.Mul, op.name + "_mul_alpha")
@@ -796,7 +802,7 @@
     fm_alpha = ofm.clone(op.name + "_alpha")
     mul_alpha.set_output_tensor(fm_alpha)
 
-    if ifm.is_scaling_equal(ofm):
+    if check_quantized_tens_scaling_equal(ifm, ofm):
         # No identity multiplication is needed
         fm_id = ifm
     else:
@@ -829,6 +835,8 @@
 def convert_to_lut(op, lut_values, lut_name):
     # Rewrite the operation by Add with scalar 0 + LUT activation
     ifm = op.inputs[0]
+    if ifm is None:
+        return op
     assert ifm.dtype.size_in_bytes() == 1
     op.type = Op.Add
     op.name = op.name + "_lut_" + lut_name
@@ -908,10 +916,12 @@
     if op.type != Op.LeakyRelu:
         return op
     ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op
     if ifm.dtype in (DataType.uint8, DataType.int8) and ifm.dtype == ofm.dtype:
         # use LUT for int8/uint8
         return convert_lrelu_to_lut(op, arch)
-    if ifm.is_scaling_equal(ofm) and ifm.dtype == ofm.dtype and ifm.dtype == DataType.int16:
+    if check_quantized_tens_scaling_equal(ifm, ofm) and ifm.dtype == ofm.dtype == DataType.int16:
         # use LeakyRelu unmodified for int16 with equal input/output scaling
         return op
     return convert_lrelu_to_mul_max(op, arch)
@@ -953,9 +963,9 @@
         cons_op_ofm = cons_op.outputs[0]
         if len(prev_op_ifm.shape) == len(cons_op_ofm.shape):
             # Check if quantization is the same in the input and output for the reshape ops
-            if prev_op_ifm.quantization.is_scaling_equal(
-                prev_op_ofm.quantization
-            ) and cons_op_ifm.quantization.is_scaling_equal(cons_op_ofm.quantization):
+            if check_quantized_tens_scaling_equal(prev_op_ifm, prev_op_ofm) and check_quantized_tens_scaling_equal(
+                cons_op_ifm, cons_op_ofm
+            ):
                 op.set_input_tensor(prev_op_ifm, 0)
                 op.set_output_tensor(cons_op_ofm)
     return op
@@ -966,6 +976,8 @@
     if not op.attrs.get("is_nop", False) or op.activation is None:
         return op
     ifm, ofm = op.get_ifm_ofm()
+    if ifm is None or ofm is None:
+        return op
     # finds the input(s) to the operation
     prev_op = ifm.ops[0]
     # Note: the below checks on prev_op require that a first optimize pass on the full graph has been performed
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index f4dd579..dfb7bc7 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -24,6 +24,8 @@
 from .numeric_util import is_integer
 from .operation import get_slice_offsets
 from .operation import Op
+from .tensor import check_quantized_tens_scaling_equal
+from .tensor import check_tens_quantized
 
 
 # Custom decorator function to allow formatting docstrings containing "{}"
@@ -730,17 +732,22 @@
 
     @classmethod
     def check_quantization_restrictions_binary_elem_wise(cls, op):
-        # makes sure IFM1, IFM2 and OFM quantization are equal for binary ops
+        # checks that IFM1, IFM2 and OFM quantization are equal for binary ops
+
         assert len(op.inputs) >= 2 and len(op.outputs) == 1
 
         if (
-            op.inputs[0].quantization is None
-            or not op.inputs[0].is_scaling_equal(op.inputs[1])
-            or not op.inputs[0].is_scaling_equal(op.outputs[0])
+            not check_tens_quantized(op.inputs[0])
+            or not check_tens_quantized(op.inputs[1])
+            or not check_tens_quantized(op.outputs[0])
         ):
-            print(
-                "Warning: Input/output tensors with different quantization is unsupported for the", op.type, "operator"
-            )
+            warn_cpu(op, "has non-quantised input and/or output tensors")
+            return False
+
+        if not check_quantized_tens_scaling_equal(op.inputs[0], op.inputs[1]) or not check_quantized_tens_scaling_equal(
+            op.inputs[0], op.outputs[0]
+        ):
+            warn_cpu(op, "has input/output tensors with different quantisation which is illegal")
             return False
 
         return True
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 98dfa3d..84af8ed 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -23,6 +23,7 @@
 import numpy as np
 
 from . import numeric_util
+from .data_type import BaseType
 from .data_type import DataType
 from .ethos_u55_regs.ethos_u55_regs import resampling_mode
 from .operation import Op
@@ -229,11 +230,22 @@
         return res
 
     def is_scaling_equal(self, other):
-        if other is None or not isinstance(other, QuantizationParameters):
+        # quantisation parameter scaling is not equal if 'other' is None because
+        # it implies that the tensor it belongs to is not quantised. otherwise,
+        # it depends upon whether the scale and zero point are equal
+
+        if other is None:
             return False
 
+        assert isinstance(other, QuantizationParameters)
+
         return self.scale_f32 == other.scale_f32 and self.zero_point == other.zero_point
 
+    def is_valid(self):
+        # quantisation parameters are consider valid if they have a scale and zero point
+
+        return None not in (self.scale_f32, self.zero_point)
+
 
 def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
     # Tensor
@@ -765,9 +777,6 @@
             return True
         return False
 
-    def is_scaling_equal(self, tens):
-        return self.quantization.is_scaling_equal(tens.quantization)
-
     def equivalent(self, tens):
         return self.equivalence_id == tens.equivalence_id
 
@@ -785,7 +794,33 @@
         else:
             return self.shape.copy()
 
+    def is_quantized(self):
+        # a tensor is quantized if it has an integral type and it contains valid quantization params
+
+        if (self.dtype.type & BaseType.Int) == 0 or self.quantization is None:
+            return False
+
+        assert isinstance(self.quantisation, QuantizationParameters)
+        assert self.quantization.is_valid()
+
+        return True
+
     def __str__(self):
         return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
 
     __repr__ = __str__
+
+
+def check_tens_quantized(tens):
+    # checks that a tensor is quantized
+
+    return isinstance(tens, Tensor) and tens.is_quantized()
+
+
+def check_quantized_tens_scaling_equal(tens_a, tens_b):
+    # checks that the scaling of two quantized tensors are equal
+
+    assert check_tens_quantized(tens_a)
+    assert check_tens_quantized(tens_b)
+
+    return tens_a.quantization.is_scaling_equal(tens_b.quantization)