MLBEDSW-3654 Add/use op ifm/ofm shapes

Add ifm/ofm shapes to op.
Change the code to rely on these shapes instead of the tensor shapes.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I571535a1dcadc2bdb04a3c727a8e1c49703b174d
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 905263d..18a419c 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -56,6 +56,7 @@
         # Ensure correct ifm and ifm2 order
         if match_tensor(ps.inputs[0], ps.primary_op.inputs[1]) and match_tensor(ps.inputs[1], ps.primary_op.inputs[0]):
             ps.ifm_tensor, ps.ifm2_tensor = ps.ifm2_tensor, ps.ifm_tensor
+            ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
 
         for op in ps.ops:
             if op.type == Op.SplitSliceRead:
@@ -77,13 +78,20 @@
                 ifm_idx += 1
 
     ifm_tensor = ps.ifm_tensor
+    ifm_shape = None
+    if ifm_tensor.shape != []:
+        ifm_shape = ps.ifm_shapes[0]
     ifm2_tensor = ps.ifm2_tensor
+    ifm2_shape = None
+    if ifm2_tensor is not None and ifm2_tensor.shape != []:
+        ifm2_shape = ps.ifm_shapes[1]
     ofm_tensor = ps.ofm_tensor
+    ofm_shape = ps.ofm_shapes[0]
     weight_tensor = ps.weight_tensor
     scale_tensor = ps.scale_tensor
 
-    ofm_start = [0] * len(ofm_tensor.shape)
-    ofm_end = list(ofm_tensor.shape)
+    ofm_start = [0] * len(ofm_shape)
+    ofm_end = list(ofm_shape)
 
     strides = None
     skirt = None
@@ -92,9 +100,9 @@
         strides = ps.primary_op.attrs.get("strides", None)
         skirt = ps.primary_op.attrs.get("skirt", None)
         if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
-            upscaling = ofm_tensor.shape[-3] // ifm_tensor.shape[-3]
+            upscaling = ofm_shape[-3] // ifm_shape[-3]
         elif ps.primary_op.type == Op.ResizeBilinear:
-            upscaling = round_up_divide(ofm_tensor.shape[-3], ifm_tensor.shape[-3])
+            upscaling = round_up_divide(ofm_shape[-3], ifm_shape[-3])
 
     concat_axis = 0
     concat_offset = 0
@@ -125,7 +133,7 @@
             ifm_box = None
             ifm2_box = None
 
-            if ifm_tensor.shape != []:
+            if ifm_shape is not None:
                 ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
                     strides,
                     skirt,
@@ -138,16 +146,9 @@
                 )
             else:
                 ifm_box = Box([], [])
-            if ifm2_tensor is not None and ifm2_tensor.shape != []:
+            if ifm2_shape is not None:
                 ifm2_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                    strides,
-                    skirt,
-                    ifm2_tensor.shape,
-                    npu_block_type,
-                    concat_axis,
-                    concat_offset,
-                    split_offsets[1],
-                    upscaling,
+                    strides, skirt, ifm2_shape, npu_block_type, concat_axis, concat_offset, split_offsets[1], upscaling,
                 )
             else:
                 ifm2_box = Box([], [])
@@ -212,19 +213,17 @@
 
     elif strat == SchedulingStrategy.IfmStream:
         y_step = block_config[0]
-        y_start = 0
-        y_dim = 1
-        if len(ofm_tensor.shape) >= 3:
-            y_start = ofm_start[-3]
-            y_dim = ofm_end[-3]
+        y_start = ofm_start[-3]
+        y_dim = ofm_end[-3]
+
         if idx > 0:
             ifm_y_present = 0
             prev_pass = passes[idx - 1]
             prev_pass_gen = generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx - 1)
         else:
             ifm_y_present = 1
-            if len(ifm_tensor.shape) >= 3:
-                ifm_y_present = ifm_tensor.shape[-3]
+            if len(ifm_shape) >= 3:
+                ifm_y_present = ifm_shape[-3]
             prev_pass_gen = []
             prev_pass = None
 
@@ -243,9 +242,8 @@
 
         for start in range(y_start, y_dim, y_step):
             end = min(start + y_step, y_dim)
-            if len(ofm_tensor.shape) >= 3:
-                ofm_start[-3] = start
-                ofm_end[-3] = end
+            ofm_start[-3] = start
+            ofm_end[-3] = end
             ofm_box = Box(ofm_start, ofm_end)
 
             k_height = 1
@@ -259,7 +257,7 @@
             ifm_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
                 strides,
                 skirt,
-                ifm_tensor.shape,
+                ifm_shape,
                 npu_block_type,
                 concat_axis,
                 concat_offset,
@@ -381,11 +379,15 @@
     for cmd in generate_high_level_command_stream_for_pass_list(strat, passes, block_configs):
         if cmd.is_npu_pass_command():
             if cmd.is_first:
-                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(cmd.ifm_box.start_coord, is_top_box=False)
+                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
+                    cmd.ifm_box.start_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=False
+                )
                 if ifm_read is None:
                     return 0
             if cmd.is_last:
-                write_offset = cmd.ofm_tensor.address_offset_for_coordinate(cmd.ofm_box.end_coord, is_top_box=True)
+                write_offset = cmd.ofm_tensor.address_offset_for_coordinate(
+                    cmd.ofm_box.end_coord, shape=cmd.ps.ofm_shapes[0], is_top_box=True
+                )
                 if write_offset is None:
                     return 0
                 highest_ofm_write = max(write_offset, highest_ofm_write)
@@ -396,7 +398,9 @@
                 min_overlap = min(min_overlap, can_overwrite)
 
             if cmd.is_first:
-                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(cmd.ifm_box.end_coord, is_top_box=True)
+                ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
+                    cmd.ifm_box.end_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=True
+                )
 
     min_overlap = max(min_overlap, 0)
     return min_overlap