TOSA: Added support for Const output

Added support for a Const operator whose output is a network
(subgraph) output, by inserting an add copy op after the Const.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ia81990a94cc497a58535914124a29e7dbb511247
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index e27dbed..9e72a6c 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -151,6 +151,33 @@
     return op
 
 
+def insert_add_copy_for_const(op, ifm_ofm_shape):
+    assert op.type == Op.Const
+    ofm = op.ofm
+    copy_tens = ofm.clone()
+    op.set_output_tensor(copy_tens)
+    # The Const now outputs the clone; the op built below outputs the original ofm tensor instead
+    name = ofm.name + "_add"
+    ifm2 = create_const_tensor(
+        name + "_zero_scalar",
+        [1],
+        copy_tens.dtype,
+        [0],
+        copy_tens.dtype.as_numpy_type(),
+        quantization=copy_tens.quantization,
+    )
+    copy_op = create_add_nop(name)
+    copy_op.add_input_tensor(copy_tens)
+    copy_op.add_input_tensor(ifm2)
+    copy_op.set_output_tensor(ofm)
+    copy_op.ifm_shapes.append(ifm_ofm_shape)
+    copy_op.ifm_shapes.append(Shape4D(ifm2.shape))
+    copy_op.ofm_shapes.append(ifm_ofm_shape)
+    copy_op.run_on_npu = True
+    # NOTE(review): presumably registers copy_op as derived from op for debug output; confirm in DebugDatabase
+    DebugDatabase.add_optimised(op, copy_op)
+
+
 # TODO can we change to add for both TFLite and TOSA?
 def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
     tens_cons_list_copy = tens.consumer_list.copy()
@@ -184,51 +211,55 @@
     DebugDatabase.add_optimised(tens.ops[0], copy_op)
 
 
+def get_shape_for_copy_op(shape):
+    """Return the Shape4D to use as ifm/ofm shape of a copy op for a tensor of the given shape.
+
+    Dimensions equal to 1 are dropped first. If more than three dimensions remain, all but
+    the last two are folded into the height so that batch becomes 1.
+    """
+    new_shape = [dim for dim in shape if dim != 1]
+    if not new_shape:
+        new_shape = [1]
+
+    rank = len(new_shape)
+    if rank > 3:
+        # Reshape so that batch becomes 1, by moving elements to H dimension.
+        # Fold over new_shape, not shape: removed 1-dims shift the indices of shape.
+        n = rank - 2
+        h = 1
+        for dim in new_shape[:n]:
+            h *= dim
+        return Shape4D(new_shape[n:]).with_height(h)
+    return Shape4D(new_shape)
+
+
 def fix_sg_input_output_tosa(op, arch, nng):
-    if not op.run_on_npu or op.type not in (Op.Reshape, Op.Identity):
-        return op
 
-    # For the Reshape operators we want to remove, tensors are removed.
-    # But in order to to do this, they cannot be outputs of the sg,
-    # this need to be fixed prior to the removal.
-    # Solution is to add a copy op, to maintain the original tensor.
-    # This is also valid when reshape ifm/ofm is produced respectively
-    # consumed by CPU
+    if op.type == Op.Const and any(ofm_cons is None for ofm_cons in op.ofm.consumer_list):
+        # Const operator with sg output, insert copy op before the ofm
+        new_shape = get_shape_for_copy_op(op.ofm.shape.copy())
+        insert_add_copy_for_const(op, new_shape)
+    elif op.run_on_npu and op.type in (Op.Reshape, Op.Identity):
+        # For the Reshape operators we want to remove, tensors are removed.
+        # But in order to do this, they cannot be outputs of the sg,
+        # this needs to be fixed prior to the removal.
+        # Solution is to add a copy op, to maintain the original tensor.
+        # This is also valid when reshape ifm/ofm is produced respectively
+        # consumed by CPU
 
-    # Check if operator ifm/ofm are sg ifm/ofm
-    ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
-    ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
-    ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
-    # Check if ifm/ofm is produced repectivly consumed by CPU
-    ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
-    ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
+        # Check if operator ifm/ofm are sg ifm/ofm
+        ifm_is_sg_ifm = op.ifm.ops[0].type in (Op.Placeholder, Op.SubgraphInput, Op.Const)
+        ifm_is_sg_ofm = any(ifm_cons is None for ifm_cons in op.ifm.consumer_list)
+        ofm_is_sg_ofm = any(ofm_cons is None for ofm_cons in op.ofm.consumer_list)
+        # Check if ifm/ofm is produced respectively consumed by CPU
+        ifm_is_cpu_produced = any(ifm_prod is not None and not ifm_prod.run_on_npu for ifm_prod in op.ifm.ops)
+        ofm_is_cpu_consumed = any(ofm_cons is not None and not ofm_cons.run_on_npu for ofm_cons in op.ofm.consumer_list)
 
-    if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
-        # Both ifm and ofm need to persist, but only ifm need a copy, in order to remove the Reshape
-
-        # Decide on ifm/ofm shapes for the copy op based on ifm
-        shape = op.ifm.shape.copy()
-        # remove dimensions that are set to 1
-        new_shape = []
-        for dim in shape:
-            if dim != 1:
-                new_shape.append(dim)
-        if not new_shape:
-            new_shape = [1]
-
-        rank = len(new_shape)
-        if rank > 3:
-            # Reshape so that batch becomes 1, by moving elements to H dimension
-            n = rank - 2
-            h = 1
-            for i in range(n):
-                h *= shape[i]
-            new_shape = Shape4D(new_shape[n:]).with_height(h)
-        else:
-            new_shape = Shape4D(new_shape)
-
-        insert_add_copy_op_after_tens(op.ifm, new_shape)
-
+        if (ifm_is_sg_ofm or ifm_is_sg_ifm or ifm_is_cpu_produced) and (ofm_is_sg_ofm or ofm_is_cpu_consumed):
+            # Both ifm and ofm need to persist, but only ifm needs a copy, in order to remove the Operator
+            # Decide on ifm/ofm shapes for the copy op based on ifm
+            new_shape = get_shape_for_copy_op(op.ifm.shape.copy())
+            insert_add_copy_op_after_tens(op.ifm, new_shape)
     return op
 
 
@@ -862,7 +893,7 @@
     # Handle sg input output
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=False,
+            nng, sg, arch, [], [fix_sg_input_output_tosa], rewrite_unsupported=True,
         )
 
     # Removal of reshapes