TOSA: Decomposition of CONCAT

- Added support for an unlimited number of dimensions (rank) for CONCAT.
- Added support for tensors with dimension sizes
  exceeding the maximum limit of the NPU.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I3cc7327ac759e69042a600e686160aeb18a5ec59
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 1012a61..d71e575 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -46,6 +46,10 @@
     activation_ops = relu_ops | set((Op.Table,))
     pad_ops = set((Op.Pad,))
 
+    rank_unlimited_ops = set((Op.Concat,))
+    rank6_limited_ops = elem_wise_ops
+    batch_enabled_ops = elem_wise_ops | set((Op.Concat,))
+    large_tens_dims_enabled_ops = elem_wise_ops | set((Op.Concat,))
     npu_post_ops = activation_ops
 
     supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
@@ -60,8 +64,10 @@
         self.generic_constraints = []
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_batch)  # TODO as not supported yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported for all ops yet
+        self.generic_constraints.append(
+            TosaSupportedOperators.constraint_batch
+        )  # TODO as not supported for all ops yet
 
         # Setup specific constraints. Note: the order matters
         self.specific_constraints = defaultdict(list)
@@ -118,11 +124,11 @@
     @classmethod
     @docstring_format_args(tens_dim_range)
     def constraint_tens_dimension(self, op):
-        "Tensor dimensions must be in the range [{}, {}], if not elementwise"
+        "Tensor dimensions must be in the range [{}, {}]"
         tens_min, tens_max = self.tens_dim_range
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.large_tens_dims_enabled_ops:
             tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
@@ -135,16 +141,20 @@
     # TODO This is for a HW limitation, that is to be resolved in SW later on
     @classmethod
     def constraint_rank(self, op):
-        "Tensor rank must be <= 4, if not elementwise"
+        "Tensor rank must be <= 6 or <= 4 depending on operator"
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.rank_unlimited_ops:
+            if op.type in self.rank6_limited_ops:
+                rank_limit = 6
+            else:
+                rank_limit = 4
             tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
             for tens in tensors:
                 rank = len(tens.shape)
-                if not rank <= 4:
+                if not rank <= rank_limit:
                     valid = False
                     extra.append(f"Tensor '{tens.name}' has rank: {rank}")
         return valid, ", ".join(extra)
@@ -152,10 +162,10 @@
     # TODO This is for a HW limitation, that is to be resolved in SW later on
     @classmethod
     def constraint_batch(self, op):
-        "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise"
+        "If Tensor rank is 4 batch of ifms/ofm must be 1"
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.batch_enabled_ops:
             tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]