MLBEDSW-7846: Number of CPU Ops reported is wrong

- Added support for multiple npu subgraphs to have the same cpu output tensor

Change-Id: I2e787306dd64af9b03cdf2bacb4c9ff7119f6c49
Signed-off-by: William Isaksson <william.isaksson@arm.com>
diff --git a/ethosu/vela/extract_npu_subgraphs.py b/ethosu/vela/extract_npu_subgraphs.py
index 5e9a5b5..dcc8687 100644
--- a/ethosu/vela/extract_npu_subgraphs.py
+++ b/ethosu/vela/extract_npu_subgraphs.py
@@ -110,17 +110,21 @@
 
 
 def rewrite_tensor_npu_producer_cpu_consumers(
-    orig_tens, call_ps, startup_init_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass
+    orig_tens, call_ps, npu_subgraph, cpu_subgraph, subgraph_for_pass, multiple_npu_sg_have_same_cpu_out_tens
 ):
+    if multiple_npu_sg_have_same_cpu_out_tens:
+        new_tens = orig_tens
+        orig_tens = orig_tens.src_tensor
+    else:
+        new_tens = orig_tens.clone("")
+        orig_tens.name = orig_tens.name + "_cpu"
+        new_tens.ops = []
 
-    new_tens = orig_tens.clone("")
-    orig_tens.name = orig_tens.name + "_cpu"
     npu_subgraph.output_tensors.append(orig_tens)
 
     call_ps.outputs.append(new_tens)
     call_ps.primary_op.outputs.append(new_tens)
-    new_tens.ops = [call_ps.primary_op]
-
+    new_tens.ops.append(call_ps.primary_op)
     # Elementwise op can not overwrite ifm if input is used by many consumers
     if orig_tens in npu_subgraph.input_tensors and len(orig_tens.consumers()) > 1:
         new_tens.ifm_write_protected = True
@@ -235,16 +239,28 @@
 
                 dest_sgs = [subgraph_for_pass[op.scheduled_pass] for op in tens.consumers() if op is not None]
                 need_rewrite = False
+                multiple_npu_sg_have_same_cpu_out_tens = False
+                output_tensor = tens
                 for sg in dest_sgs:
                     if sg != curr_sg:
                         need_rewrite = True
                         break
-                if tens in orig_sg.output_tensors:
-                    need_rewrite = True
+                for orig_out_tens in orig_sg.output_tensors:
+                    if tens == orig_out_tens:
+                        need_rewrite = True
+                    elif tens.equivalence_id == orig_out_tens.equivalence_id:
+                        need_rewrite = True
+                        multiple_npu_sg_have_same_cpu_out_tens = True
+                        output_tensor = orig_out_tens
 
                 if need_rewrite:
                     rewrite_tensor_npu_producer_cpu_consumers(
-                        tens, call_pass[curr_sg], startup_init_passes[curr_sg], curr_sg, orig_sg, subgraph_for_pass
+                        output_tensor,
+                        call_pass[curr_sg],
+                        curr_sg,
+                        orig_sg,
+                        subgraph_for_pass,
+                        multiple_npu_sg_have_same_cpu_out_tens,
                     )
 
         for tens in curr_sg.output_tensors: