MLBEDSW-6830: MLCE: Fix assert on concat op

- The compiler will assert when compiling a faulty concat op.
In the reported use case, there were 3 inputs with shape 1x1x2
but the output shape was 1x1x2 (expected to be 1x1x6)

- The solution is to add constraints to the concat operator.

Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: I94a505c51a9fd54d1aa92531a0415031db52378a
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 90d93d0..5d25e37 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -297,6 +297,13 @@
         # Reshape specific checks:
         self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
 
+        # Concat specific checks:
+        for op_type in (Op.Concat, Op.ConcatTFLite):
+            self.specific_constraints[op_type].append(
+                TFLiteSupportedOperators.constraint_concat_valid_dimensions_non_axis
+            )
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_concat_valid_dimensions_axis)
+
     def is_operator_supported(self, op):
         ext_type = optype_to_builtintype(op.type)
         if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -850,3 +857,42 @@
         extra = ", ".join(extra)
 
         return valid, f"Op has non-const input(s): {extra}"
+
+    @staticmethod
+    def constraint_concat_valid_dimensions_non_axis(op):
+        """All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute"""
+        valid = True
+        extra = []
+        ofm_shape = op.ofm.shape
+        ofm_dim = len(ofm_shape)
+        axis = op.attrs["axis"]
+        axis += ofm_dim if axis < 0 else 0
+
+        tensors = [tens for tens in op.inputs if tens]
+        for tens in tensors:
+            if any(tens.shape[dim] != ofm_shape[dim] for dim in range(ofm_dim) if dim != axis):
+                valid = False
+                extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+
+        extra = ", ".join(extra)
+        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"
+
+    @staticmethod
+    def constraint_concat_valid_dimensions_axis(op):
+        """The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute"""
+        valid = True
+        extra = []
+        ofm_shape = op.ofm.shape
+        ofm_dim = len(ofm_shape)
+        axis = op.attrs["axis"]
+        axis += ofm_dim if axis < 0 else 0
+
+        sum_ifm_axis = 0
+        tensors = [tens for tens in op.inputs if tens]
+        for tens in tensors:
+            sum_ifm_axis += tens.shape[axis]
+            extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
+
+        valid = sum_ifm_axis == ofm_shape[axis]
+        extra = ", ".join(extra)
+        return valid, f"Op has axis={axis}, ofm_shape={ofm_shape} and the list of mismatching inputs are: {extra}"