TOSA: Added RESHAPE, SLICE and CONCAT

Added support for the data layout ops
RESHAPE, SLICE and CONCAT.
- No support for bool_t
- Support limited to Rank <= 4 and N = 1

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I487ac494b6506a2a6ba947ee758aa193194dd796
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 169da40..f3cddad 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -22,14 +22,17 @@
 from .graph_optimiser_util import bypass_reshape_and_squeeze_ops
 from .graph_optimiser_util import calc_explicit_padding
 from .graph_optimiser_util import convert_depthwise_to_conv
-from .graph_optimiser_util import fix_sg_input_output
+from .graph_optimiser_util import move_splitsliceread_to_consumer
 from .graph_optimiser_util import needed_total_padding
 from .graph_optimiser_util import set_ifm_ofm_op_shapes
 from .graph_optimiser_util import set_tensor_equivalence
 from .operation import ExplicitScaling
 from .operation import NpuBlockType
 from .operation import Op
+from .operation_util import create_add_nop
 from .operation_util import create_avgpool_nop
+from .shape4d import Shape4D
+from .tensor import create_const_tensor
 
 
 def replace_rescale_with_avg_pool(rescale_op):
@@ -103,12 +106,157 @@
                 removed = True
 
         if not removed:
-            print("Cannot remove Transpose, and handling of Transpose is not supported")
+            print("Warning: Cannot remove Transpose, and handling of Transpose is not supported")
             assert False
 
     return op
 
 
+# TODO Can we change to use an add op for both TFLite and TOSA?
+def insert_add_copy_op_after_tens(tens):
+    tens_cons_list_copy = tens.consumer_list.copy()
+    copy_tens = tens.clone()
+
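+    # A copy is expressed as an elementwise add of the tensor and a zero
+    # scalar, so the original tensor value is preserved in copy_tens.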
+    name = tens.name + "_add"
+    ifm2 = create_const_tensor(
+        name + "_zero_scalar",
+        [1],
+        copy_tens.dtype,
+        [0],
+        copy_tens.dtype.as_numpy_type(),
+        quantization=copy_tens.quantization,
+    )
+    copy_op = create_add_nop(name)
+    copy_op.add_input_tensor(tens)
+    copy_op.add_input_tensor(ifm2)
+    copy_op.set_output_tensor(copy_tens)
+    copy_op.set_ifm_ofm_shapes()
+    copy_op.run_on_npu = True
+
+    # Set copy_ifm consumers
+    for tens_cons in tens_cons_list_copy:
+        if tens_cons is not None:
+            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
+                if cons_inp == tens:
+                    tens_cons.set_input_tensor(copy_tens, ifm_idx)
+
+    DebugDatabase.add_optimised(tens.ops[0], copy_op)
+
+
+def fix_sg_input_output_tosa(op, arch, nng):
+    if not op.run_on_npu or op.type != Op.Reshape:
+        return op
+
+    # For the Reshape operators we want to remove, tensors are removed.
+    # But in order to do this, they cannot be outputs of the sg;
+    # this needs to be fixed prior to the removal.
+    # The solution is to add a copy op, to maintain the original tensor.
+    # This also applies when the reshape ifm is produced, or the ofm
+    # consumed, by the CPU.
+
+    # Check if operator ifm/ofm are sg ifm/ofm
+    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
+    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+    # Check if the ifm is produced, or the ofm consumed, by the CPU
+    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+
+    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+        # Both ifm and ofm need to persist, but only the ifm needs a copy, in order to remove the Reshape
+        insert_add_copy_op_after_tens(op.ifm)
+
+    return op
+
+
+def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
+    """Creates an add op for the given concat op/input feature map"""
+    ofm = concat_op.ofm
+    ifm2 = create_const_tensor(
+        name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
+    )
+    add_op = create_add_nop(name)
+
+    add_op.inputs = [ifm, ifm2]
+    add_op.outputs = [ofm]
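+    # The add writes its result into the concat ofm at write_offset,
+    # covering write_shape (the shape of this input).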
+    add_op.write_offset = write_offset
+    add_op.write_shape = ifm_shape
+    ofm.ops.append(add_op)
+    DebugDatabase.add_optimised(concat_op, add_op)
+    add_op.ifm_shapes.append(ifm_shape)
+    add_op.ifm_shapes.append(Shape4D(ifm2.shape))
+    add_op.ofm_shapes.append(concat_op.ofm_shapes[0])
+    add_op.memory_function = Op.ConcatSliceWrite
+    return add_op
+
+
+# TODO Could be further optimized by checking the type of the consumer,
+# rather than just mimicking the TFLite behaviour depending on type.
+# TOSA bool_t is not considered yet
+def remove_splitsliceread(op, arch):
+
+    if op.type == Op.SplitSliceRead:
+        # Check if it is possible to move the SplitSliceRead to the tensor consumer, or if an add op needs to be inserted to perform the read
+        if (
+            len(op.ofm.consumer_list) == 1
+            and op.ofm.consumer_list[0] is not None
+            and op.ofm.consumer_list[0].run_on_npu
+            and op.ofm.consumer_list[0].type != Op.Reshape
+            and op.ofm_shapes[0] == Shape4D.from_list(op.ofm.shape)
+            and op.ofm.dtype in (DataType.uint8, DataType.int8, DataType.int16)
+        ):
+            # SplitSliceRead can be performed by tensor consumer
+            cons_op = op.ofm.consumer_list[0]
+            move_splitsliceread_to_consumer(op, cons_op)
+        else:
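+            # The read cannot be moved to the consumer; instead replace the
+            # SplitSliceRead with an elementwise add of zero that performs
+            # the read through read_offsets/read_shapes.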
+            name = op.name + "_add"
+            ofm = op.ofm
+            ifm2 = create_const_tensor(
+                name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
+            )
+            add_op = create_add_nop(name)
+            add_op.inputs = [op.ifm, ifm2]
+            add_op.outputs = [ofm]
+            op.ofm.ops.remove(op)
+            op.ofm.ops.append(add_op)
+            add_op.ifm_shapes.append(op.ifm_shapes[0])
+            add_op.ifm_shapes.append(Shape4D(ifm2.shape))
+            add_op.ofm_shapes.append(op.ofm_shapes[0])
+            add_op.read_offsets[0] = op.read_offsets[0]
+            add_op.read_shapes[0] = op.read_shapes[0]
+
+            op.ifm.consumer_list.remove(op)
+            DebugDatabase.add_optimised(op, add_op)
+
+
+def rewrite_concat_ops(op, arch):
+    if not op.run_on_npu or op.type != Op.Concat:
+        return
+
+    axis_4D = 0
+    ofm = op.ofm
+    ofm.ops = []
+    offset = 0
+
+    inputs = op.inputs
+    axis = op.attrs["axis"]
+
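+    # One add op is created per input; each add writes its input into the
+    # ofm at the running offset along the concat axis.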
+    for idx, inp in enumerate(inputs):
+        op.ifm_shapes[idx] = Shape4D(inp.shape)
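+        # A non-negative axis refers to the input's original rank and must be
+        # shifted when the shape is padded up to rank 4; a negative axis
+        # counts from the back and needs no adjustment.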
+        if axis >= 0:
+            axis_4D = axis + (4 - len(inp.shape))
+        else:
+            axis_4D = axis
+        write_offset = [0, 0, 0, 0]
+        write_offset[axis_4D] = offset
+        concat_end = offset + op.ifm_shapes[idx][axis_4D]
+        create_add_for_concat(op, op.name + str(idx) + "_add", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset))
+        offset = concat_end
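+    # The accumulated offset must match the ofm extent along the concat axis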
+    assert ofm.shape[axis] == offset
+
+    return op
+
+
 def remove_reshapes(op, arch):
     if op.run_on_npu and op.type == Op.Reshape:
         bypass_reshape_and_squeeze_ops(op)
@@ -271,9 +419,14 @@
     # Handle sg input output
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
+            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
         )
 
+    # Rewrite concat ops
+    for idx, sg in enumerate(nng.subgraphs):
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [rewrite_concat_ops])
+        sg.refresh_after_modification()
+
     # Removal of reshapes
     for sg in nng.subgraphs:
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
@@ -293,6 +446,12 @@
             nng, sg, arch, [], [rewrite_activation, add_padding_fields],
         )
 
+    # Removal of Slice (SplitSliceRead) needs to be done after the optimisation
+    # has been performed, since the ifm/ofm_shapes are of importance to this function
+    for sg in nng.subgraphs:
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_splitsliceread])
+        sg.refresh_after_modification()
+
     # Post-processing step 2
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)