MLBEDSW-3774 Removed ConcatSliceWrite

-Removed ConcatSliceWrite from the optimised graph.
 The concat write is now always executed as an avgpool
 NOP, which is equivalent to the behaviour before this
 patch.
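
 Illustrative only (condensed from the diff below): each concat
 input is now written by an AvgPool NOP that carries the concat
 window as attributes, and downstream code identifies the concat
 write by those attributes rather than by op type.

    # graph_optimiser.py (condensed): one AvgPool NOP per concat input
    avgpool_op = create_avgpool_nop(op.name + str(idx) + "_avgpool")
    avgpool_op.inputs = [inp]
    avgpool_op.outputs = [ofm]
    avgpool_op.attrs["concat_axis"] = axis_4D
    avgpool_op.attrs["concat_start"] = offset
    offset += op.ifm_shapes[idx][axis_4D]
    avgpool_op.attrs["concat_end"] = offset
    avgpool_op.run_on_npu = True

    # high_level_command_stream_generator.py (condensed): the concat write
    # is detected by attribute instead of by op.type == Op.ConcatSliceWrite
    if op.attrs.get("concat_axis", None) is not None:
        ps.primary_op.memory_function = Op.ConcatSliceWrite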

-Added a copy op to enable removal of more reshapes.
 Sg inputs/outputs need to remain. When a Reshape's input
 and output are both sg inputs/outputs, a copy op needs to
 be inserted in order to remove the reshape.
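
 Illustrative only (condensed from the new insert_copy_op_after_tens
 helper in graph_optimiser.py below): the original sg tensor is kept
 alive by copying it through an AvgPool NOP, so the Reshape that
 produces or consumes it can be removed.

    copy_tens = tens.clone()
    copy_op = create_avgpool_nop(tens.name + "_avgpool")
    copy_op.add_input_tensor(tens)
    copy_op.set_output_tensor(copy_tens)
    copy_op.set_ifm_ofm_shapes()
    copy_op.run_on_npu = True
    # Consumers of the original tensor now read from the copy
    for cons in tens.consumer_list.copy():
        if cons is not None:
            for idx, cons_inp in enumerate(cons.inputs):
                if cons_inp == tens:
                    cons.set_input_tensor(copy_tens, idx)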

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Id7be9966673ae34499e8518a5544104493fe326b
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 5c1b90b..eb93106 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -96,19 +96,19 @@
                 axis_4D = axis + (4 - len(inp.shape))
             else:
                 axis_4D = axis
-        new_op = Operation(Op.ConcatSliceWrite, op.name + str(idx))
-        new_op.inputs = [inp]
-        new_op.outputs = [ofm]
-        new_op.attrs["concat_axis"] = axis_4D
-        new_op.attrs["concat_start"] = offset
+        avgpool_op = create_avgpool_nop(op.name + str(idx) + "_avgpool")
+        avgpool_op.inputs = [inp]
+        avgpool_op.outputs = [ofm]
+        avgpool_op.attrs["concat_axis"] = axis_4D
+        avgpool_op.attrs["concat_start"] = offset
         offset += op.ifm_shapes[idx][axis_4D]
 
-        new_op.attrs["concat_end"] = offset
-        new_op.run_on_npu = True
-        ofm.ops.append(new_op)
-        DebugDatabase.add_optimised(op, new_op)
-        new_op.ifm_shapes.append(op.ifm_shapes[idx])
-        new_op.ofm_shapes.append(op.ofm_shapes[0])
+        avgpool_op.attrs["concat_end"] = offset
+        avgpool_op.run_on_npu = True
+        ofm.ops.append(avgpool_op)
+        DebugDatabase.add_optimised(op, avgpool_op)
+        avgpool_op.ifm_shapes.append(op.ifm_shapes[idx])
+        avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
     assert ofm.shape[axis] == offset
 
     # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -177,6 +177,48 @@
     return tens
 
 
+def insert_copy_op_after_tens(tens):
+    tens_cons_list_copy = tens.consumer_list.copy()
+
+    # Create an avgpool NOP with tens as input and its clone as output
+    copy_tens = tens.clone()
+    copy_op = create_avgpool_nop(tens.name + "_avgpool")
+    copy_op.add_input_tensor(tens)
+    copy_op.set_output_tensor(copy_tens)
+    copy_op.set_ifm_ofm_shapes()
+    copy_op.run_on_npu = True
+
+    # Redirect the consumers of tens to copy_tens
+    for tens_cons in tens_cons_list_copy:
+        if tens_cons is not None:
+            for ifm_idx, cons_inp in enumerate(tens_cons.inputs):
+                if cons_inp == tens:
+                    tens_cons.set_input_tensor(copy_tens, ifm_idx)
+
+    DebugDatabase.add_optimised(tens.ops[0], copy_op)
+
+
+def fix_sg_input_output(op, arch, nng):
+    if not op.run_on_npu or op.type != Op.Reshape:
+        return op
+
+    # When a memory operator is removed, its tensors are removed as well.
+    # A tensor cannot be removed if it is an output of the sg, so this
+    # needs to be fixed prior to the removal.
+    # The solution is to add an avgpool NOP to maintain the original tensor.
+
+    # Check if operator ifm/ofm are sg ifm/ofm
+    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
+    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+
+    if op.type == Op.Reshape and (ifm_is_sg_ofm or ifm_is_sg_ifm) and ofm_is_sg_ofm:
+        # Both ifm and ofm need to remain as sg tensors; only ifm needs a copy in order to remove the Reshape
+        insert_copy_op_after_tens(op.ifm)
+
+    return op
+
+
 def needed_total_padding(input_size, stride, filter_size):
     out_size = (input_size + stride - 1) // stride
     needed_input = (out_size - 1) * stride + filter_size
@@ -1020,50 +1062,12 @@
             # or the reshape need to be replace with a NOP.
             return
 
-        # Check if ifm is a sg input
-        if ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const):
-            # put the reshape on CPU
-            op.run_on_npu = False
-            return
-
         # Check if Reshape ifm/ofm are network ifm/ofm
+        ifm_is_sg_ifm = ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
         ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in ifm.consumer_list)
         ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in ofm.consumer_list)
-
-        if ifm_is_sg_ofm and ofm_is_sg_ofm:
-            # Both ifm and ofm are sg outputs,add reshape to the ifm and put it on CPU
-            ifm_cons_list_copy = ifm.consumer_list.copy()
-            ifm_ops_copy = ifm.ops.copy()
-            for ifm_cons in ifm_cons_list_copy:
-                if ifm_cons is None:
-                    # Create a reshape op with ifm as output
-                    name = ifm.name + "_cpu_reshape"
-                    reshape_ifm = ifm.clone()
-                    reshape_op = Operation(Op.Reshape, name)
-                    reshape_op.attrs["new_shape"] = ifm.shape
-                    reshape_op.add_input_tensor(reshape_ifm)
-                    reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, ifm.shape))
-                    reshape_op.set_output_tensor(ifm)
-                    reshape_op.set_ifm_ofm_shapes()
-                    reshape_op.run_on_npu = False
-                    reshape_op.ofm.ops = [reshape_op]
-                    reshape_op.ofm.consumer_list = [None]
-
-                    # Set reshape_ifm producers
-                    for prev_op in ifm_ops_copy:
-                        prev_op.outputs = [reshape_ifm]
-                        reshape_ifm.ops.append(prev_op)
-
-                    # Set reshape_ifm consumers
-                    for ifm_cons in ifm_cons_list_copy:
-                        if ifm_cons is not None:
-                            for ifm_idx, cons_ifm in enumerate(ifm_cons.inputs):
-                                if cons_ifm == ifm:
-                                    ifm_cons.set_input_tensor(reshape_ifm, ifm_idx)
-
-                    ifm = reshape_ifm
-                    break
-            ifm_is_sg_ofm = False
+        # This case should be handled prior to this function
+        assert not ((ifm_is_sg_ifm or ifm_is_sg_ofm) and ofm_is_sg_ofm)
 
         if ofm_is_sg_ofm:
             # Bypassed by replacing ifm with ofm
@@ -1244,6 +1248,12 @@
             nng, sg, arch, [rewrite_split_ops], [], rewrite_unsupported=False,
         )
 
+    # Handle Reshapes at sg inputs/outputs
+    for idx, sg in enumerate(nng.subgraphs):
+        nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
+            nng, sg, arch, [], [fix_sg_input_output], rewrite_unsupported=False,
+        )
+
     # Removal of reshapes
     for sg in nng.subgraphs:
         rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [remove_reshapes])
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 66613ba..97b42ae 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -109,7 +109,7 @@
     concat_offset = 0
 
     for op in ps.ops:
-        if op.type == Op.ConcatSliceWrite:
+        if op.attrs.get("concat_axis", None) is not None:
             concat_axis = op.attrs["concat_axis"]
             concat_start = op.attrs["concat_start"]
             concat_end = op.attrs["concat_end"]
@@ -117,7 +117,7 @@
             ofm_start[concat_axis] = concat_start
             ofm_end[concat_axis] = concat_end
             concat_offset = concat_start
-            ps.primary_op.memory_function = op.type
+            ps.primary_op.memory_function = Op.ConcatSliceWrite
         elif op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid):
             ps.primary_op.activation = create_activation_function(op.type)
 
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index abd235f..b52b159 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -81,7 +81,7 @@
 
 npu_post_fuse_limited_ops = set(
     # Set of post operators that should not be fused with main/elementwise ops
-    (Op.ConcatSliceWrite, Op.Sigmoid, Op.Tanh, Op.Quantize)
+    (Op.Sigmoid, Op.Tanh, Op.Quantize)
 )
 
 elem_wise_ops = elem_wise_main_ops | activation_ops | set((Op.Sigmoid, Op.Tanh))
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 4281d31..40b8cd5 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -178,17 +178,6 @@
         reshape1_op.attrs["new_shape"] = reshape1_ofm_shape
         reshape1_op.run_on_npu = True
 
-        # create reshape2 op
-        reshape2_ofm_shape = [1, 8, 8, 16]
-        reshape2_ofm = create_const_tensor(
-            "reshape2_out", reshape2_ofm_shape, DataType.uint8, np.zeros(reshape2_ofm_shape)
-        )
-        reshape2_ofm.quantization = quant
-        shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, reshape2_ofm_shape)
-        reshape2_op = testutil.create_op(Op.Reshape, [reshape1_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
-        reshape2_op.attrs["new_shape"] = reshape2_ofm_shape
-        reshape2_op.run_on_npu = True
-
         # create conv op
         conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
         conv_ofm.quantization = quant.clone()
@@ -206,40 +195,37 @@
         )
         conv2d_op.run_on_npu = True
 
-        # create reshape3 op
+        # create reshape2 op
         ofm_shape = [8, 8, 16]
-        reshape3_ofm = create_const_tensor("reshape3_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
-        reshape3_ofm.quantization = quant
-        shape_tens = create_const_tensor("reshape3_shape", [1], DataType.int32, ofm_shape)
-        reshape3_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape3_ofm, set_ifm_ofm_shapes=False)
-        reshape3_op.attrs["new_shape"] = ofm_shape
-        reshape3_op.run_on_npu = True
+        reshape2_ofm = create_const_tensor("reshape2_out", ofm_shape, DataType.uint8, np.zeros(ofm_shape))
+        reshape2_ofm.quantization = quant
+        shape_tens = create_const_tensor("reshape2_shape", [1], DataType.int32, ofm_shape)
+        reshape2_op = testutil.create_op(Op.Reshape, [conv_ofm, shape_tens], reshape2_ofm, set_ifm_ofm_shapes=False)
+        reshape2_op.attrs["new_shape"] = ofm_shape
+        reshape2_op.run_on_npu = True
         nng = Graph()
-        sg = testutil.create_subgraph([reshape1_op, reshape2_op, conv2d_op, reshape3_op])
+        sg = testutil.create_subgraph([reshape1_op, conv2d_op, reshape2_op])
         nng.subgraphs.append(sg)
 
-        return nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op
+        return nng, reshape1_op, conv2d_op, reshape2_op
 
-    # Test1 no Reshape op is expected to remain in the NPU subgrapgh
-    # but first one will be put on CPU
-    # Network is Reshape-Reshape-Conv-Reshape
-    # Result is cpu_Reshape-Conv
-    nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op = setup_network()
+    # Test1: no Reshape op is expected to remain in the NPU subgraph
+    # Network is Reshape-Conv-Reshape
+    # Result is Conv
+    nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
     arch = testutil.create_arch()
     assert verify_graph_health(nng)
     nng = optimise_graph_a(nng, arch)
     assert verify_graph_health(nng)
-    assert conv2d_op.ifm == reshape1_op.ofm
-    assert conv2d_op.ofm == reshape3_op.ofm
 
-    # Test2 reshape2 with different quantisation, this Reshape op is expected to remain
-    # Network is Reshape-Reshape-Conv-Reshape
-    # expected is cpu_Reshape-Reshape-Conv
-    nng, reshape1_op, reshape2_op, conv2d_op, reshape3_op = setup_network()
+    # Test2: reshape1 has different quantisation, so this Reshape op is expected to remain
+    # Network is Reshape-Conv-Reshape
+    # Expected result is Reshape-Conv
+    nng, reshape1_op, conv2d_op, reshape2_op = setup_network()
     quant_zp32 = testutil.default_quant_params()
     quant_zp32.zero_point = 32
-    reshape2_op.ofm.quantization = quant_zp32
+    reshape1_op.ofm.quantization = quant_zp32
     assert verify_graph_health(nng)
     nng = optimise_graph_a(nng, arch)
     assert verify_graph_health(nng)
-    assert conv2d_op.ofm == reshape3_op.ofm