TOSA: Added decomposition of RESHAPE

-Added support for an unlimited number of dimensions
-Added support for tensor dimensions exceeding the max limit of the NPU
-Fixed a regression for PAD

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ib2ce50a30cc5cf396032d85d57dab9968e3fc06a
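
Note (illustrative, not part of the patch): the RESHAPE removal inserts a
copy op whose ifm/ofm shape is derived by dropping singleton dimensions
and, when more than three dimensions remain, folding the leading
dimensions into H so that the batch becomes 1. A rough standalone sketch
of that intent, using plain lists instead of Vela's Shape4D and a
hypothetical helper name:

    def collapse_to_4d(shape):
        # Hypothetical helper for illustration only.
        # Drop dimensions that are 1.
        new_shape = [dim for dim in shape if dim != 1] or [1]
        rank = len(new_shape)
        if rank > 3:
            # Fold the leading dimensions into H so the batch becomes 1.
            n = rank - 2
            h = 1
            for dim in new_shape[:n]:
                h *= dim
            new_shape = [1, h] + new_shape[n:]
        else:
            # Left-pad with 1s up to rank 4 (roughly what wrapping the
            # list in Shape4D achieves).
            new_shape = [1] * (4 - rank) + new_shape
        return new_shape

    assert collapse_to_4d([1, 2, 3, 4, 5]) == [1, 6, 4, 5]
    assert collapse_to_4d([10, 20]) == [1, 1, 10, 20]
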
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 5cd9d21..d32955d 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -149,7 +149,7 @@
 
 
 # TODO can we change to add for both TFLite and TOSA?
-def insert_add_copy_op_after_tens(tens):
+def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
     tens_cons_list_copy = tens.consumer_list.copy()
     copy_tens = tens.clone()
 
@@ -166,7 +166,9 @@
     copy_op.add_input_tensor(tens)
     copy_op.add_input_tensor(ifm2)
     copy_op.set_output_tensor(copy_tens)
-    copy_op.set_ifm_ofm_shapes()
+    copy_op.ifm_shapes.append(ifm_ofm_shape)
+    copy_op.ifm_shapes.append(Shape4D(ifm2.shape))
+    copy_op.ofm_shapes.append(ifm_ofm_shape)
     copy_op.run_on_npu = True
 
     # Set copy_ifm consumers
@@ -200,7 +202,29 @@
 
     if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
         # Both ifm and ofm need to persist, but only ifm needs a copy, in order to remove the Reshape
-        insert_add_copy_op_after_tens(op.ifm)
+
+        # Decide on ifm/ofm shapes for the copy op based on ifm
+        shape = op.ifm.shape.copy()
+        # remove dimensions that are set to 1
+        new_shape = []
+        for dim in shape:
+            if dim != 1:
+                new_shape.append(dim)
+        if not new_shape:
+            new_shape = [1]
+
+        rank = len(new_shape)
+        if rank > 3:
+            # Reshape so that batch becomes 1, by moving elements to H dimension
+            n = rank - 2
+            h = 1
+            for i in range(n):
+                h *= shape[i]
+            new_shape = Shape4D(new_shape[n:]).with_height(h)
+        else:
+            new_shape = Shape4D(new_shape)
+
+        insert_add_copy_op_after_tens(op.ifm, new_shape)
 
     return op
 
@@ -435,16 +459,12 @@
 
     quant = ofm.quantization
     pad_value = ifm.quantization.zero_point
+    ifm.quantization.zero_point = 0
     # Add operations that fill the borders of the OFM
     if top > 0:
         shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
         zero_tens = create_const_tensor(
-            op.name + "_top",
-            shape.as_list(),
-            ofm.dtype,
-            shape.elements() * [pad_value],
-            np.uint8,
-            quantization=quant,  # TODO
+            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant,
         )
         # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
@@ -569,6 +589,16 @@
     return Shape4D(stride_n, stride_y, stride_x, 1)
 
 
+def pad_to_rank(shape, rank):
+    """
+    Pads a shape to the given rank
+    """
+    while len(shape) < rank:
+        shape = [1] + shape
+
+    return shape
+
+
 def get_elem_shapes_removed_singles(op):
     """
     Returns the shapes of ifm(s)/ofms after removing all the dimensions that are 1 for all ifm(s)/ofm
@@ -579,7 +609,12 @@
     if binary:
         ifm2_shape = op.ifm_shapes[1].as_list() if len(op.ofm_shapes) else op.ifm2.shape
 
-    rank = len(ofm_shape)
+    rank = max(len(ofm_shape), len(ifm_shape), len(ifm2_shape) if binary else 0)
+    ofm_shape = pad_to_rank(ofm_shape, rank)
+    ifm_shape = pad_to_rank(ifm_shape, rank)
+    if binary:
+        ifm2_shape = pad_to_rank(ifm2_shape, rank)
+
     new_ofm_shape = []
     new_ifm_shape = []
     new_ifm2_shape = []
@@ -777,6 +812,17 @@
             nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False
         )
 
+    # Handle sg input output
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
+        )
+
+    # Removal of reshapes
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
+        sg.refresh_after_modification()
+
     # Decomposing of elementwise
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
@@ -794,17 +840,6 @@
             nng, sg, arch, [], [remove_const_transpose], rewrite_unsupported=False,
         )
 
-    # Handle sg input output
-    for idx, sg in enumerate(nng.subgraphs):
-        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
-        )
-
-    # Removal of reshapes
-    for sg in nng.subgraphs:
-        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
-        sg.refresh_after_modification()
-
     # TODO, when and where to best handle calc_scaling_avgpool
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index d71e575..2692c05 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -46,10 +46,10 @@
     activation_ops = relu_ops | set((Op.Table,))
     pad_ops = set((Op.Pad,))
 
-    rank_unlimited_ops = set((Op.Concat,))
+    rank_unlimited_ops = set((Op.Concat, Op.Reshape))
     rank6_limited_ops = elem_wise_ops
-    batch_enabled_ops = elem_wise_ops | set((Op.Concat,))
-    large_tens_dims_enabled_ops = elem_wise_ops | set((Op.Concat,))
+    batch_enabled_ops = rank6_limited_ops | rank_unlimited_ops
+    large_tens_dims_enabled_ops = batch_enabled_ops | set((Op.SplitSliceRead,))
     npu_post_ops = activation_ops
 
     supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | elem_wise_ops | pad_ops
@@ -63,11 +63,9 @@
         # Setup the generic constraints. Note: the order matters
         self.generic_constraints = []
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
-        self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)  # TODO as not supported yet
-        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported for all ops yet
-        self.generic_constraints.append(
-            TosaSupportedOperators.constraint_batch
-        )  # TODO as not supported for all ops yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)  # TODO not supported yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO not supported for all ops yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_batch)  # TODO not supported for all ops yet
 
         # Setup specific constraints. Note: the order matters
         self.specific_constraints = defaultdict(list)
@@ -156,7 +154,10 @@
                 rank = len(tens.shape)
                 if not rank <= rank_limit:
                     valid = False
-                    extra.append(f"Tensor '{tens.name}' has rank: {rank}")
+                    extra.append(
+                        f"Tensor '{tens.name}' has rank: {rank}, rank limit is currently {rank_limit}"
+                        f" for op of type {op.type}"
+                    )
         return valid, ", ".join(extra)
 
     # TODO This is for a HW limitation, that is to be resolved in SW later on
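
Note (illustrative, not part of the patch): pad_to_rank simply left-pads a
shape with 1s, which lets get_elem_shapes_removed_singles compare ifm(s)
and ofm dimension by dimension even when their ranks differ, before
dropping the dimensions that are 1 in all of them. A small standalone
usage sketch (only pad_to_rank matches the patch; the surrounding example
is hypothetical):

    def pad_to_rank(shape, rank):
        """Pads a shape to the given rank with leading 1s."""
        while len(shape) < rank:
            shape = [1] + shape
        return shape

    # Elementwise op with a broadcast second input:
    ifm_shape, ifm2_shape, ofm_shape = [2, 1, 3, 4], [3, 4], [2, 1, 3, 4]
    rank = max(len(ofm_shape), len(ifm_shape), len(ifm2_shape))
    ifm2_shape = pad_to_rank(ifm2_shape, rank)  # -> [1, 1, 3, 4]

    # Dimensions that are 1 in every shape can now be dropped consistently.
    keep = [
        i
        for i in range(rank)
        if not (ifm_shape[i] == ifm2_shape[i] == ofm_shape[i] == 1)
    ]
    assert [ifm_shape[i] for i in keep] == [2, 3, 4]
    assert [ifm2_shape[i] for i in keep] == [1, 3, 4]
    assert [ofm_shape[i] for i in keep] == [2, 3, 4]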