TOSA: Decomposition of CONCAT

-Added support for unlimited number of dimensions
-Added support for Tensors with dimension size
 exceeding maximum limit of NPU.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I3cc7327ac759e69042a600e686160aeb18a5ec59
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 1e059cc..5cd9d21 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -265,32 +265,21 @@
             DebugDatabase.add_optimised(op, add_op)
 
 
-def rewrite_concat_ops(op, arch):
+def rewrite_concat(op):
     if not op.run_on_npu or not op.type == Op.Concat:
         return
 
-    axis_4D = 0
-    ofm = op.ofm
-    ofm.ops = []
     offset = 0
-
     inputs = op.inputs
-    axis = op.attrs["axis"]
+    axis_4D = op.attrs["axis4D"]
 
     for idx, inp in enumerate(inputs):
-        op.ifm_shapes[idx] = Shape4D(inp.shape)
-        if axis >= 0:
-            axis_4D = axis + (4 - len(inp.shape))
-        else:
-            axis_4D = axis
         write_offset = [0, 0, 0, 0]
         write_offset[axis_4D] = offset
         concat_end = offset + op.ifm_shapes[idx][axis_4D]
         create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
         offset = concat_end
-    assert ofm.shape[axis] == offset
-
-    return op
+    assert op.ofm_shapes[0][axis_4D] == offset
 
 
 def remove_reshapes(op, arch):
@@ -503,12 +492,15 @@
     return convert_to_lut(op, table.values, "table")
 
 
-def decompose_tensors_hwc(op):
+def decompose_elem_tensors_hwc(op):
+    """
+    Decomposes elementwise op if any of the ifm(s)/ofm are to large in any dimension to be handled by the NPU
+    """
     max_t_size = 65535
-    ofm_shape = op.ofm_shapes[0]
-    ifm_shape = op.ifm_shapes[0]
+    ofm_shape = op.write_shape if op.write_shape is not None else op.ofm_shapes[0]
+    ifm_shape = op.read_shapes[0] if op.read_shapes[0] is not None else op.ifm_shapes[0]
     ifm2_shape = op.ifm_shapes[1] if op.ifm_shapes[1] else None
-
+    ifm2_shape = op.read_shapes[1] if op.read_shapes[1] is not None else ifm2_shape
     limit_shape = Shape4D(1, max_t_size, max_t_size, max_t_size)
 
     if any(dim_size > max_t_size for dim_size in ofm_shape.as_list()):
@@ -536,7 +528,6 @@
                         ifm2_part_shape = ifm2_shape.clip(ifm2_offset, limit_shape)
                         ifm2_cut = (ifm2_offset, ifm2_part_shape)
                     else:
-                        ifm2_offset = None
                         ifm2_cut = (None, None)
 
                     create_elem_part_op(op, ifm_cut, ifm2_cut, ofm_cut)
@@ -582,17 +573,23 @@
     """
     Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm
     """
-    rank = len(op.ofm.shape)
     binary = op.ifm2 is not None
+    ofm_shape = op.ofm_shapes[0].as_list() if len(op.ofm_shapes) > 0 else op.ofm.shape
+    ifm_shape = op.ifm_shapes[0].as_list() if len(op.ifm_shapes) > 0 else op.ifm.shape
+    if binary:
+        ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ofm_shapes) else op.ifm2.shape
+
+    rank = len(ofm_shape)
     new_ofm_shape = []
     new_ifm_shape = []
     new_ifm2_shape = []
     for idx in range(rank):
-        if op.ofm.shape[idx] != 1:
-            new_ofm_shape.append(op.ofm.shape[idx])
-            new_ifm_shape.append(op.ifm.shape[idx])
+        if ofm_shape[idx] != 1:
+            new_ofm_shape.append(ofm_shape[idx])
+            new_ifm_shape.append(ifm_shape[idx])
             if binary:
-                new_ifm2_shape.append(op.ifm2.shape[idx])
+                new_ifm2_shape.append(ifm2_shape[idx])
+
     if new_ofm_shape == []:
         new_ofm_shape = [1]
         new_ifm_shape = [1]
@@ -614,7 +611,6 @@
     ifm2 = op.ifm2
     ofm = op.ofm
     binary = op.ifm2 is not None
-    assert len(ofm.shape) <= 6
 
     # Remove dimensions that are all 1
     new_ofm_shape, new_ifm_shape, new_ifm2_shape = get_elem_shapes_removed_singles(op)
@@ -679,19 +675,74 @@
 
 def decomp_elementwise(tens, arch, nng):
     """
-    Decompose elementwise ops with Rank > 3 (H,W,D).
+    Decompose elementwise ops with Rank > 3 (H,W,C).
     Decompose size of tensors exceeding NPU max size
     """
-    if len(tens.ops) == 1 and tens.ops[0].type.is_elementwise_op():
+    tens_ops = tens.ops.copy()
+    for op in tens_ops:
+        if op.type.is_elementwise_op():
+            decomp_list = decomp_dims_elementwise(op)
+            for part_op in decomp_list:
+                decompose_elem_tensors_hwc(part_op)
+    return tens
+
+
+def reshape_concat_shape(shape, rank, axis):
+    new_h = 1
+    for i in range(axis):
+        new_h *= shape[i]
+    new_c = 1
+    for i in range(axis + 1, rank):
+        new_c *= shape[i]
+    if axis == (rank - 1):
+        new_shape = [new_h, shape[axis], 1]
+    else:
+        new_shape = [new_h, shape[axis], new_c]
+    return new_shape
+
+
+def reshape_concat(op):
+    """
+    Reshapes concat ops with Rank > 3 (H,W,C).
+    """
+    ofm = op.ofm
+    rank = len(ofm.shape)
+    axis = op.attrs["axis"]
+    if axis < 0:
+        axis += rank
+
+    if rank > 3:
+        # Reshape so that axis in to be concatenated is the W dimension
+        # Reshape inputs
+        for inp in op.inputs:
+            new_shape = reshape_concat_shape(inp.shape, rank, axis)
+            op.ifm_shapes.append(Shape4D(new_shape))
+        # Reshape output
+        new_shape = reshape_concat_shape(ofm.shape, rank, axis)
+        op.ofm_shapes.append(Shape4D(new_shape))
+        op.attrs["axis4D"] = 2
+    else:
+        for inp in op.inputs:
+            op.ifm_shapes.append(Shape4D(inp.shape))
+        op.ofm_shapes.append(Shape4D(ofm.shape))
+        op.attrs["axis4D"] = axis + (4 - rank)
+
+
+def decomp_rewrite_concat(tens, arch, nng):
+    """
+    Decompose concat ops with Rank > 3 (H,W,C).
+    Rewrite of concat to elementwise operations
+    """
+    if len(tens.ops) == 1 and tens.ops[0].type == Op.Concat:
         op = tens.ops[0]
-        rank = len(op.ofm.shape)
-        assert rank <= 6
 
-        decomp_list = []
-        decomp_list = decomp_dims_elementwise(op)
+        reshape_concat(op)
+        rewrite_concat(op)
 
-        for part_op in decomp_list:
-            decompose_tensors_hwc(part_op)
+        op.ofm.ops.remove(op)
+        for inp in op.inputs:
+            inp.consumer_list.remove(op)
+
     return tens
 
 
@@ -714,21 +765,27 @@
 
 def tosa_optimise_graph(nng, arch):
 
-    # Decomposing to 4 dimensions
+    # TODO the supported operator checking need to be split in semantic and HW checks
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [supported_operator_check], rewrite_unsupported=False,
+        )
+
+    # Decomposing and rewrite of concat
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False
+        )
+
+    # Decomposing of elementwise
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
             nng, sg, arch, [decomp_elementwise], [], rewrite_unsupported=False
         )
 
-    # Pre-processing step
-    pre_process_list = [
-        supported_operator_check,
-        set_ifm_ofm_op_shapes,
-    ]
-
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], pre_process_list, rewrite_unsupported=False,
+            nng, sg, arch, [], [set_ifm_ofm_op_shapes], rewrite_unsupported=False,
         )
 
     # Removal of Transpose
@@ -743,11 +800,6 @@
             nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
         )
 
-    # Rewrite concat ops
-    for idx, sg in enumerate(nng.subgraphs):
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
-        sg.refresh_after_modification()
-
     # Removal of reshapes
     for sg in nng.subgraphs:
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 1012a61..d71e575 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -46,6 +46,10 @@
     activation_ops = relu_ops | set((Op.Table,))
     pad_ops = set((Op.Pad,))
 
+    rank_unlimited_ops = set((Op.Concat,))
+    rank6_limited_ops = elem_wise_ops
+    batch_enabled_ops = elem_wise_ops | set((Op.Concat,))
+    large_tens_dims_enabled_ops = elem_wise_ops | set((Op.Concat,))
     npu_post_ops = activation_ops
 
     supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
@@ -60,8 +64,10 @@
         self.generic_constraints = []
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_batch)  # TODO as not supported yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported for all ops yet
+        self.generic_constraints.append(
+            TosaSupportedOperators.constraint_batch
+        )  # TODO as not supported for all ops yet
 
         # Setup specific constraints. Note: the order matters
         self.specific_constraints = defaultdict(list)
@@ -118,11 +124,11 @@
     @classmethod
     @docstring_format_args(tens_dim_range)
     def constraint_tens_dimension(self, op):
-        "Tensor dimensions must be in the range [{}, {}], if not elementwise"
+        "Tensor dimensions must be in the range [{}, {}]"
         tens_min, tens_max = self.tens_dim_range
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.large_tens_dims_enabled_ops:
             tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
@@ -135,16 +141,20 @@
     # TODO This is for a HW limitation, that is to be resolved in SW later on
     @classmethod
     def constraint_rank(self, op):
-        "Tensor rank must be <= 4, if not elementwise"
+        "Tensor rank must be <= 6 or <= 4 depending on operator"
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.rank_unlimited_ops:
+            if op.type in self.rank6_limited_ops:
+                rank_limit = 6
+            else:
+                rank_limit = 4
             tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]
             for tens in tensors:
                 rank = len(tens.shape)
-                if not rank <= 4:
+                if not rank <= rank_limit:
                     valid = False
                     extra.append(f"Tensor '{tens.name}' has rank: {rank}")
         return valid, ", ".join(extra)
@@ -152,10 +162,10 @@
     # TODO This is for a HW limitation, that is to be resolved in SW later on
     @classmethod
     def constraint_batch(self, op):
-        "If Tensor rank is 4 batch of ifms/ofm must be 1, if not elementwise"
+        "If Tensor rank is 4 batch of ifms/ofm must be 1"
         valid = True
         extra = []
-        if op.type not in self.binary_elem_wise_add_mul_sub:
+        if op.type not in self.batch_enabled_ops:
             tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
             if not tensors:
                 tensors = [tens for tens in op.inputs if tens]