MLBEDSW-4075 PACK axis 0 + tanh fails with output diff

The test failed because the tanh operation had a batch size greater than 1.
Replaced the operator-specific batch size checks with a generic batch size
constraint applied to all supported operators, with per-operator exceptions
for the operators that can handle a batch size larger than 1.
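
The batch size is now read from the shape padded to 4D, so tensors with
fewer than four dimensions are always treated as having batch size 1.
A rough sketch of the idea (the simplified full_shape below is a stand-in
for ethosu/vela/numeric_util.full_shape, not the actual implementation):

    def full_shape(dim, shape, fill):
        # Left-pad 'shape' with 'fill' until it has 'dim' entries (simplified stand-in).
        return [fill] * (dim - len(shape)) + list(shape)

    # The batch size is the first entry of the 4D-padded shape:
    assert full_shape(4, [2, 2, 2], 1)[0] == 1     # 3D tensor -> batch 1, accepted
    assert full_shape(4, [2, 2, 2, 2], 1)[0] == 2  # explicit batch of 2, rejected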

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I3570352740c40eb96bd9db965dfa3c91c81ff2ad
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 3872bdc..35fc1a6 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -550,25 +550,25 @@
 
 def test_constraint_elemwise_batch_size():
     # BINARY CASE
-    # Batch can be >1 if dims is <=2D
-    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2], [2, 2], [2, 2])
+    # Batch can be >1 if dims is <=3D
+    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2])
     assert support.is_operator_supported(op)
-    # For dims >2D, batch must be 1
-    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2], [1, 2, 2], [1, 2, 2])
+    # For dims >3D, batch must be 1
+    op = testutil.create_elemwise_op(Op.Add, "op", [1, 2, 2, 2], [1, 2, 2, 2], [1, 2, 2, 2])
     assert support.is_operator_supported(op)
     # invalid case
-    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2], [2, 2, 2], [2, 2, 2])
+    op = testutil.create_elemwise_op(Op.Add, "op", [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2])
     assert not support.is_operator_supported(op)
 
     # UNARY CASE
-    # Batch can be >1 if dims is <=2D
-    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2], None, [2, 2], datatype=DataType.int32)
+    # Batch can be >1 if dims is <=3D
+    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32)
     assert support.is_operator_supported(op)
-    # For dims >2D, batch must be 1
-    op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2], None, [1, 2, 2], datatype=DataType.int32)
+    # For dims >3D, batch must be 1
+    op = testutil.create_elemwise_op(Op.CLZ, "op", [1, 2, 2, 2], None, [1, 2, 2, 2], datatype=DataType.int32)
     assert support.is_operator_supported(op)
     # invalid case
-    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2], None, [2, 2, 2], datatype=DataType.int32)
+    op = testutil.create_elemwise_op(Op.CLZ, "op", [2, 2, 2, 2], None, [2, 2, 2, 2], datatype=DataType.int32)
     assert not support.is_operator_supported(op)
 
 
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index b6f9796..d42caf5 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -20,6 +20,7 @@
 import numpy as np
 
 from .data_type import DataType
+from .numeric_util import full_shape
 from .operation import Op
 from .operation import Padding
 from .supported_operators_util import docstring_format_args
@@ -206,9 +207,20 @@
         self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_int32_ops)
         self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_dimension)
         self.generic_constraints.append(TFLiteSupportedOperators.constraint_tens_quant_per_axis)
+        self.generic_constraints.append(TFLiteSupportedOperators.constraint_batch_size)
         self.generic_constraints.append(TFLiteSupportedOperators.constraint_faf)
         self.generic_constraints.append(TFLiteSupportedOperators.constraint_faf_type)
 
+        # Setup generic constraint exceptions
+        self.generic_constraints_exceptions = defaultdict(list)
+        self.generic_constraints_exceptions[Op.FullyConnected].append(TFLiteSupportedOperators.constraint_batch_size)
+        self.generic_constraints_exceptions[Op.Softmax].append(TFLiteSupportedOperators.constraint_batch_size)
+        self.generic_constraints_exceptions[Op.Reshape].append(TFLiteSupportedOperators.constraint_batch_size)
+        self.generic_constraints_exceptions[Op.Shape].append(TFLiteSupportedOperators.constraint_batch_size)
+        self.generic_constraints_exceptions[Op.Squeeze].append(TFLiteSupportedOperators.constraint_batch_size)
+        for op_type in TFLiteSupportedOperators.split_ops - set((Op.UnpackReshaped,)):
+            self.generic_constraints_exceptions[op_type].append(TFLiteSupportedOperators.constraint_batch_size)
+
         # Setup specific constraints. Note: the order matters
         self.specific_constraints = defaultdict(list)
 
@@ -223,7 +235,6 @@
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_limit)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit)
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_batch_size)
         # Depthwise Conv specific checks:
         for op_type in TFLiteSupportedOperators.depthwise_convolution_ops:
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_depth_multiplier)
@@ -235,7 +246,6 @@
 
         # Pooling checks:
         for op_type in TFLiteSupportedOperators.pooling_ops:
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_batch_size)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_stride_range)
         # AVG pooling specific checks:
         for op_type in TFLiteSupportedOperators.avg_pooling_ops:
@@ -268,9 +278,7 @@
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_type)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bias_40bit)
 
-        # Element-wise checks:
-        for op_type in TFLiteSupportedOperators.elem_wise_main_ops:
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_elemwise_batch_size)
+        # Element-wise checks
         # Binary Min/Max specific checks:
         for op_type in TFLiteSupportedOperators.binary_elem_wise_min_max_ops:
             self.specific_constraints[op_type].append(
@@ -302,7 +310,6 @@
         self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_pad_type)
 
         # Mean specific checks:
-        self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_batch_size)
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_avgpool)
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
         self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8)
@@ -319,7 +326,10 @@
                 print(f"Info: {ext_type} '{op.name}' is a CPU only op")
             return False
 
-        for constraint in self.generic_constraints + self.specific_constraints[op.type]:
+        op_exceptions = self.generic_constraints_exceptions[op.type]
+        generic_constraints = [constraint for constraint in self.generic_constraints if constraint not in op_exceptions]
+
+        for constraint in generic_constraints + self.specific_constraints[op.type]:
             valid, extra = constraint(op)
             if not valid:
                 print(f"Warning: {ext_type} '{op.name}' is not supported on the NPU. Placing on CPU instead")
@@ -497,9 +507,16 @@
     @staticmethod
     def constraint_batch_size(op):
         "IFM Tensor batch size must be 1"
-        ifm = op.ifm
-        valid = ifm.shape[0] == 1
-        return valid, f"Tensor '{ifm.name}' has batch size: {ifm.shape[0]}"
+        valid = True
+        extra = []
+        for tens in (op.ifm, op.ifm2):
+            if tens is not None:
+                batch_size = full_shape(4, tens.shape, 1)[0]
+                if batch_size != 1:
+                    valid = False
+                    extra.append(f"Tensor '{tens.name}' has batch size: {batch_size}")
+        extra = "\n   ".join(extra)
+        return valid, extra
 
     @staticmethod
     def constraint_depth_multiplier(op):
@@ -753,20 +770,6 @@
         return valid, f"Op has tensors with different quantization parameters to the OFM '{op.ofm.name}': {extra}"
 
     @staticmethod
-    def constraint_elemwise_batch_size(op):
-        "Batch size must be 1 for Input tensors with more than 2 dimensions"
-        valid = True
-        extra = []
-        for tens in (op.ifm, op.ifm2):
-            # Unary ops have ifm2 as None
-            if tens is not None:
-                if (len(tens.shape) > 2) and (tens.shape[0] != 1):
-                    valid = False
-                    extra.append(tens.name)
-        extra = ", ".join(extra)
-        return valid, f"Op has invalid input tensors: {extra}"
-
-    @staticmethod
     def constraint_broadcast_shapes(op):
         "Broadcasting is only allowed for rank indices with dimension 1, from either IFM1 or IFM2"
         ifm_shape = op.ifm.shape
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 15e1569..192862e 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -94,6 +94,9 @@
         self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO not supported for all ops yet
         self.generic_constraints.append(TosaSupportedOperators.constraint_batch)  # TODO not supported for all ops yet
 
+        # Setup generic constraint exceptions
+        self.generic_constraints_exceptions = defaultdict(list)
+
         # Setup specific constraints. Note: the order matters
         self.specific_constraints = defaultdict(list)
 
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index a42b218..7740711 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -45,6 +45,7 @@
 from .tflite.Model import Model
 from .tflite_mapping import builtin_operator_map
 from .tflite_mapping import builtin_operator_name_map
+from .tflite_mapping import optype_to_builtintype
 from .tflite_model_semantic import TFLiteSemantic
 from .tflite_supported_operators import TFLiteSupportedOperators
 from .tosa_model_semantic import TosaSemantic
@@ -178,6 +179,12 @@
     # To easily exclude NetworkType from generated documentation.
     exclude_generation_network_type_value = [NetworkType.TOSA.value]
 
+    def _exclude_list_names(constraint, exclude_list):
+        constraints_excluded_names = [
+            optype_to_builtintype(op) for op, exclude_constraint in exclude_list if constraint in exclude_constraint
+        ]
+        return f" - [{', '.join(sorted(constraints_excluded_names))}]" if constraints_excluded_names else ""
+
     lines = [
         "# Supported Ops",
         "",
@@ -256,20 +263,13 @@
         for constraint in semantic_checker.generic_constraints:
             # Markdown needs two spaces at the end of a line to render it as a separate line
             reason = constraint.__doc__.replace("\n", "  \n")
-
             exclude_list = TFLiteSemantic.get_generic_constraint_exclude_list().items()
-            constraints_excluded_names = [
-                op.name for op, exclude_constraint in exclude_list if constraint in exclude_constraint
-            ]
-            excluded_constraint_text = ""
-            if constraints_excluded_names:
-                excluded_constraint_text = f"- [{', '.join(constraints_excluded_names)}]"
-
-            lines.append(f"- {reason} {excluded_constraint_text}")
+            lines.append(f"- {reason}{_exclude_list_names(constraint, exclude_list)}")
         for constraint in supported.generic_constraints:
             # Markdown needs two spaces at the end of a line to render it as a separate line
             reason = constraint.__doc__.replace("\n", "  \n")
-            lines.append(f"- {reason}")
+            exclude_list = supported.generic_constraints_exceptions.items()
+            lines.append(f"- {reason}{_exclude_list_names(constraint, exclude_list)}")
         for op, name in op_constraint_links:
             lines += [
                 "",