MLBEDSW-3302: Reject per-channel scaling for unsupported ops

Vela only supports per-channel scaling for
convolution ops. This commit adds a check that
puts ops with per-channel scaling on the CPU.
A caveat worth mentioning is that neither
TensorFlow Lite or TensorFlow Lite Micro support
per-channel scaling for the CPU placed op,
however the problem is moved away from Vela.
This commit also changes a small utility function
in supported_operators.py used for docstring
formatting.

Signed-off-by: Dwight Lidman <dwight.lidman@arm.com>
Change-Id: I9ed090592f1d05dd4566d3e54dba1ef405299383
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 91fcb5a..6dcb27d 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -43,8 +43,8 @@
     output = map(optype_to_builtintype, op_list)
     # Remove UNKNOWNs
     output = (x for x in output if x is not BUILTIN_OPERATOR_UNKNOWN)
-    # Order alphabetically
-    return sorted(output)
+    # Order alphabetically and join into a string representation
+    return ", ".join(str(op) for op in sorted(output))
 
 
 class SupportedOperators:
@@ -94,6 +94,7 @@
     concat_ops = set((Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack,))
     memory_only_ops = set((Op.Squeeze, Op.Reshape, Op.QuantizedReshape,)) | concat_ops | split_ops
     shapeless_input_ops = binary_elem_wise_main_ops | set((Op.Split, Op.SplitV,))
+    per_axis_quant_ops = convolution_like_ops  # per-axis/channel quantization only currently supported for conv ops
     supported_fused_activations = relu_ops | set((Op.Tanh, Op.Sigmoid, Op.LUT,))
     supported_operators = npu_pre_ops | mac_main_ops | elem_wise_main_ops | npu_post_ops | memory_only_ops
     # Supported data types
@@ -113,6 +114,7 @@
     docstring_shapeless_input_ops = _optype_formatter(shapeless_input_ops)
     docstring_supported_int32_tensor_ops = _optype_formatter(supported_int32_tensor_ops)
     docstring_supported_fused_activations = _optype_formatter(supported_fused_activations)
+    docstring_per_axis_quant_ops = _optype_formatter(per_axis_quant_ops)
 
     def __init__(self):
         # Setup the generic constraints. Note: the order matters
@@ -127,6 +129,7 @@
         self.generic_constraints.append(SupportedOperators.constraint_tens_dimension)
         self.generic_constraints.append(SupportedOperators.constraint_tens_quant_none_check)
         self.generic_constraints.append(SupportedOperators.constraint_tens_quant_scale)
+        self.generic_constraints.append(SupportedOperators.constraint_tens_quant_per_axis)
         self.generic_constraints.append(SupportedOperators.constraint_faf)
 
         # Setup specific constraints. Note: the order matters
@@ -391,6 +394,20 @@
         return valid, ", ".join(extra)
 
     @classmethod
+    @docstring_format_args([docstring_per_axis_quant_ops])
+    def constraint_tens_quant_per_axis(cls, op):
+        "Per-axis quantization is only supported for the following op types: {}"
+        valid = True
+        extra = []
+        if op.type not in cls.per_axis_quant_ops:
+            tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+            for tens in tensors:
+                if tens.quantization.is_per_axis():
+                    valid = False
+                    extra.append(tens.name)
+        return valid, "The following tensor(s) have per-axis quantization parameters: " + ", ".join(extra)
+
+    @classmethod
     @docstring_format_args([docstring_supported_fused_activations])
     def constraint_faf(cls, op):
         "The fused activation function (if present) must be one of type: {}"
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 3601c92..b07b4dc 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -246,6 +246,13 @@
 
         return None not in (self.scale_f32, self.zero_point)
 
+    def is_per_axis(self):
+        """Returns True if either the scale, zero point, minimum or maximum values are arrays"""
+        for attr in ("scale_f32", "zero_point", "min", "max"):
+            if isinstance(getattr(self, attr), np.ndarray):
+                return True
+        return False
+
 
 def create_const_tensor(name, shape, dtype, values, value_dtype=None, purpose=TensorPurpose.Unknown, quantization=None):
     # Tensor
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 62de0d1..86d2475 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -100,6 +100,28 @@
     assert not support.is_operator_supported(op)
 
 
+def test_constraint_tens_quant_per_axis_not_supp():
+    # Quantization scale cannot be array-valued for elemwise ops
+    qp = QuantizationParameters()
+    qp.zero_point = np.zeros((1, 3))
+    qp.scale_f32 = np.ones((1, 3))
+    op = testutil.create_elemwise_op(Op.Mul, "op", [1, 8, 8, 8], [], [1, 8, 8, 8], ifm_quant=qp)
+    assert not support.is_operator_supported(op)
+
+
+def test_constraint_tens_quant_per_axis_is_supp():
+    op = testutil.create_op_with_quant_tensors(
+        Op.Conv2DBias, [1, 1, 1, 3], [1, 1, 1, 3], weights_shape=[1, 1, 1, 3], bias_shape=[1, 1, 1, 3]
+    )
+    op.attrs = {"stride_w": 1, "stride_h": 1}
+    assert support.is_operator_supported(op)
+    qp = QuantizationParameters()
+    qp.zero_point = np.zeros((1, 3))
+    qp.scale_f32 = np.ones((1, 3))
+    op.bias.quantization = qp
+    assert support.is_operator_supported(op)
+
+
 def test_constraint_faf():
     # Fused activation functions, if set, must be a valid op type
     op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 8, 8, 8], [1, 8, 8, 8])
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index b06008a..8258827 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -80,7 +80,9 @@
     return op
 
 
-def create_op_with_quant_tensors(op_type, ifm_shape, ofm_shape, weights_shape=None, datatype=DataType.uint8):
+def create_op_with_quant_tensors(
+    op_type, ifm_shape, ofm_shape, weights_shape=None, bias_shape=None, datatype=DataType.uint8
+):
     ifm = Tensor(ifm_shape, datatype, "in")
     ifm.quantization = default_quant_params()
     ofm = Tensor(ofm_shape, datatype, "out")
@@ -102,6 +104,12 @@
             "weights", weights_shape, datatype, np.zeros(weights_shape), np_type, quantization=qp
         )
         op.add_input_tensor(weights)
+    # Optional bias tensor
+    if bias_shape is not None:
+        qp = default_quant_params()
+        qp.zero_point = np.zeros(bias_shape)
+        bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), np.int32, quantization=qp)
+        op.add_input_tensor(bias)
     return op