MLBEDSW-2903 Split mapping to tensor

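Split mapping to tensor: for two-input elementwise passes, assign each
SplitSliceRead's split offset to the ifm or ifm2 tensor it actually reads
from, instead of relying on the order of the ops in the pass. Other passes
keep the order-based assignment.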

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ic143f3b4d37f6904edd8f119eff1d108f70b5026
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index d5a6341..50b913d 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -50,18 +50,29 @@
     npu_block_type = ps.npu_block_type
     split_offsets = [None, None]  # offset for [ifm, ifm2]
 
-    ifm_idx = 0
-    for op in ps.ops:
-        if op.type == "SplitSliceRead":
-            split_offsets[ifm_idx] = op.attrs["split_start"]
-            ps.primary_op.attrs["fused_memory_function"] = op.type
-            ifm_idx += 1
-
     if len(ps.inputs) == 2 and npu_block_type == NpuBlockType.ElementWise:
-        # Ensure correct imf and ifm2 order
+        # Ensure correct ifm and ifm2 order
         if match_tensor(ps.inputs[0], ps.primary_op.inputs[1]) and match_tensor(ps.inputs[1], ps.primary_op.inputs[0]):
             ps.ifm_tensor, ps.ifm2_tensor = ps.ifm2_tensor, ps.ifm_tensor
-            split_offsets[0], split_offsets[1] = split_offsets[1], split_offsets[0]
+
+        for op in ps.ops:
+            if op.type == "SplitSliceRead":
+                ps.primary_op.attrs["fused_memory_function"] = op.type
+                assert len(op.inputs) == 1
+                if match_tensor(ps.ifm_tensor, op.inputs[0]):
+                    split_offsets[0] = op.attrs["split_start"]
+                elif match_tensor(ps.ifm2_tensor, op.inputs[0]):
+                    split_offsets[1] = op.attrs["split_start"]
+                else:
+                    assert False
+    else:
+        ifm_idx = 0
+        for op in ps.ops:
+            if op.type == "SplitSliceRead":
+                assert ifm_idx < 2
+                split_offsets[ifm_idx] = op.attrs["split_start"]
+                ps.primary_op.attrs["fused_memory_function"] = op.type
+                ifm_idx += 1
 
     ifm_tensor = ps.ifm_tensor
     ifm2_tensor = ps.ifm2_tensor