Revert "Revert "MLBEDSW-3645 4D class for op ifm/ofm shapes""

This reverts commit df0a5905177f3a1b836076bc3f9f39b2e86f1794.

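Reason for revert: Re-apply MLBEDSW-3645 (4D class for op ifm/ofm shapes),
which was backed out by commit df0a5905177f3a1b836076bc3f9f39b2e86f1794.
This makes Operation/Pass ifm_shapes and ofm_shapes lists of Shape4D
objects.

set_ifm_ofm_shapes() now resets both lists before recomputing them.
Graph rewrites call it explicitly after rewiring an op's inputs or
outputs, instead of relying on a separate set_ifm_ofm_op_shapes rewrite
pass.

TODO: state why the problem that caused the original revert no longer
applies.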

Change-Id: I891c66fb29db9d25e942947e8d1c29a10610de51
diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py
index 203503f..77e13eb 100644
--- a/ethosu/vela/debug_database.py
+++ b/ethosu/vela/debug_database.py
@@ -23,7 +23,7 @@
 
 from . import numeric_util
 from .operation import Operation
-
+from .shape4d import Shape4D
 
 UntypedDict = Dict[Any, Any]
 UntypedList = List[Any]
@@ -79,9 +79,18 @@
                 src_uid = cls._sourceUID[parent]
             uid = len(cls._optimisedUID)
             cls._optimisedUID[op] = (uid, src_uid)
-            ofm_shape = op.ofm_shapes[0] if op.ofm_shapes else numeric_util.full_shape(3, op.outputs[0].shape, 1)
+            ofm_shape = op.ofm_shapes[0] if op.ofm_shapes else Shape4D(op.outputs[0].shape)
             cls._optimisedTable.append(
-                [uid, src_uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]]
+                [
+                    uid,
+                    src_uid,
+                    op.type,
+                    op.kernel.width,
+                    op.kernel.height,
+                    ofm_shape.width,
+                    ofm_shape.height,
+                    ofm_shape.depth,
+                ]
             )
 
     @classmethod
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index fdb0fae..1128a31 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -37,6 +37,7 @@
 from .operation import Operation
 from .operation import Padding
 from .operation_util import create_avgpool_nop
+from .shape4d import Shape4D
 from .softmax import SoftMax
 from .tensor import check_quantized_tens_scaling_equal
 from .tensor import create_const_tensor
@@ -82,6 +83,7 @@
             new_op.run_on_npu = True
             tens.ops.append(new_op)
             DebugDatabase.add_optimised(concat_op, new_op)
+            new_op.set_ifm_ofm_shapes()
         assert tens.shape[axis] == offset
 
         # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -121,7 +123,8 @@
                 if out == tens:
                     break
                 axis_4D = axis + (4 - len(out.shape))
-                offset_start[axis_4D] += split_op.ofm_shapes[idx][axis_4D]
+
+                offset_start[axis_4D] += split_op.ofm_shapes[idx].get_dim(axis_4D)
 
                 # If start offset is not a multiple of 16 in the C-dimension, NHCWB16 need to be avoided in the input
                 if (offset_start[-1] % 16) != 0:
@@ -132,6 +135,7 @@
         new_op.attrs["split_start"] = offset_start
         new_op.run_on_npu = True
         new_op.set_output_tensor(tens)
+        new_op.set_ifm_ofm_shapes()
         DebugDatabase.add_optimised(split_op, new_op)
 
     return tens
@@ -189,6 +193,7 @@
     if op.type == Op.Conv2DBackpropInput:
         # flip the inputs
         op.inputs[0], op.inputs[2] = op.inputs[2], op.inputs[0]
+        op.set_ifm_ofm_shapes()
         op.type = Op.Conv2DBackpropInputSwitchedBias
 
         # Update strides
@@ -216,8 +221,7 @@
     # Set the add inputs
     op.inputs[1] = op.inputs[0]
     op.inputs[0] = tens
-    op.ifm_shapes = []
-    op.ofm_shapes = []
+    op.set_ifm_ofm_shapes()
 
     return op
 
@@ -323,14 +327,14 @@
         ofm = op.outputs[0]
         # Check if the FC is 2D and first dimension indicates batching
         # TOD0 op.ifm_shape[0] > 1 is enough when refactory is complete
-        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0][0] > 1:
+        if len(ifm.shape) == len(ofm.shape) == 2 and ifm.shape[0] > 1 and op.ifm_shapes[0].batch > 1:
             n = ifm.shape[0]
             batching_split = {4: (2, 2), 8: (2, 4), 16: (4, 4)}
             h, w = batching_split.get(n, (1, n))
 
             prev_op = ifm.ops[0]
             desired_shape = [1, h, w, ifm.shape[-1]]
-            op.ifm_shapes[0] = desired_shape
+            op.ifm_shapes[0] = Shape4D(desired_shape)
 
             if len(ifm.consumer_list) == 1 and prev_op is not None and prev_op.type == Op.Reshape:
                 # There is a preceding Reshape
@@ -356,7 +360,7 @@
             weight_tensor.set_all_shapes(list(weight_tensor.quant_values.shape))
 
             desired_shape = [1, h, w, ofm.shape[-1]]
-            op.ofm_shapes[0] = desired_shape
+            op.ofm_shapes[0] = Shape4D(desired_shape)
 
             if (
                 len(ofm.consumer_list) == 1
@@ -395,6 +399,7 @@
             reshape_op.attrs["new_shape"] = desired_shape
             reshape_op.inputs = [inp, new_shape_tens]
             reshape_op.set_output_tensor(reshape_out)
+            reshape_op.set_ifm_ofm_shapes()
             DebugDatabase.add_optimised(op, reshape_op)
 
             op.inputs[idx] = reshape_out
@@ -413,6 +418,7 @@
         act_op.set_output_tensor(out_tens)
         act_op.add_input_tensor(intermediate_tens)
         op.set_output_tensor(intermediate_tens)
+        act_op.set_ifm_ofm_shapes()
 
     return op
 
@@ -457,7 +463,7 @@
         new_shape_tens = create_const_tensor(op.name + "_reshape_shape", [1], DataType.int32, tens.shape)
 
         for idx, out_tens in enumerate(op.outputs):
-            op.ofm_shapes[idx] = new_shape_tens
+            op.ofm_shapes[idx] = Shape4D(new_shape_tens.shape)
             reshape_in = out_tens.clone("_reshaped")
             reshape_in.set_all_shapes(reshape_input_shape)
             reshape_in.ops = [op]
@@ -466,6 +472,7 @@
             reshape_op.attrs["new_shape"] = reshape_input_shape
             reshape_op.inputs = [reshape_in, new_shape_tens]
             reshape_op.set_output_tensor(out_tens)
+            reshape_op.set_ifm_ofm_shapes()
 
             op.outputs[idx] = reshape_in
 
@@ -493,6 +500,7 @@
             reshape_op.attrs["new_shape"] = reshape_input_shape
             reshape_op.inputs = [reshape_in, new_shape_tens]
             reshape_op.set_output_tensor(out_tens)
+            reshape_op.set_ifm_ofm_shapes()
             DebugDatabase.add_optimised(op, reshape_op)
 
             op.outputs[idx] = reshape_in
@@ -588,7 +596,8 @@
     # caching/double buffering for the weights.
     # (Weights dont need to be reloaded for convs when IFM H and W are 1)
     if op.type == Op.Conv2DBias:
-        _, h, w, _ = op.ifm_shapes[0]
+        h = op.ifm_shapes[0].height
+        w = op.ifm_shapes[0].width
         kh, kw, _, _ = op.inputs[1].shape
         if h == 1 and w == 1 and kh == 1 and kw == 1:
             # Overwrite this op as a Fully Connected Op
@@ -616,9 +625,11 @@
             reshape_op.attrs["new_shape"] = orig_ofm_tensor.shape
             reshape_op.inputs = [fc_ofm_tensor, new_shape_tens]
             reshape_op.set_output_tensor(orig_ofm_tensor)
+            reshape_op.set_ifm_ofm_shapes()
 
             # Replace this ops OFM to point to the 2D tensor
             op.outputs[0] = fc_ofm_tensor
+            op.set_ifm_ofm_shapes()
             # Record optimisation in debug database
             DebugDatabase.add_optimised(op, reshape_op)
             DebugDatabase.add_optimised(op, op)
@@ -649,6 +660,7 @@
 
             relu_fused_op.add_input_tensor(ifm)
             relu_fused_op.set_output_tensor(ofm)
+            relu_fused_op.set_ifm_ofm_shapes()
             op = relu_fused_op
     return op
 
@@ -668,8 +680,8 @@
             act_op_out = act_op.inputs[0].clone("_acted")
             act_op_out.quantization = op.outputs[0].quantization.clone()
             act_op.set_output_tensor(act_op_out)
-            act_op.ifm_shapes[0] = full_shape(4, prep_op.inputs[0].shape, 1)
-            act_op.ofm_shapes[0] = full_shape(4, act_op_out.shape, 1)
+            act_op.ifm_shapes[0] = Shape4D(prep_op.inputs[0].shape)
+            act_op.ofm_shapes[0] = Shape4D(act_op_out.shape)
 
             # Update the consumer list
             act_op_out.consumer_list = op.outputs[0].consumer_list.copy()
@@ -839,6 +851,7 @@
     mul_alpha.add_input_tensor(alpha_tens)
     fm_alpha = ofm.clone(op.name + "_alpha")
     mul_alpha.set_output_tensor(fm_alpha)
+    mul_alpha.set_ifm_ofm_shapes()
     DebugDatabase.add_optimised(op, mul_alpha)
 
     if check_quantized_tens_scaling_equal(ifm, ofm):
@@ -860,6 +873,7 @@
         mul_identity.add_input_tensor(identity_tens)
         fm_id = ofm.clone(op.name + "_id")
         mul_identity.set_output_tensor(fm_id)
+        mul_identity.set_ifm_ofm_shapes()
         DebugDatabase.add_optimised(op, mul_identity)
 
     # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs
@@ -890,7 +904,7 @@
     quantization.zero_point = 0
     tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
     op.add_input_tensor(tens)
-    op.ifm_shapes.append(full_shape(4, tens.shape, 1))
+    op.ifm_shapes.append(Shape4D(tens.shape))
 
     # The LUT must be applied without any preceding rescaling (the LUT itself performs the rescale),
     # so even if the OFM has a different scale than the IFM, the generated OFM scale instructions
@@ -1158,11 +1172,7 @@
     for idx, sg in enumerate(nng.subgraphs):
         # combined rewrite graph pass
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng,
-            sg,
-            arch,
-            [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split],
-            [set_ifm_ofm_op_shapes],
+            nng, sg, arch, [fixup_unpack_output, fixup_stridedslice_output, rewrite_concat, rewrite_split], [],
         )
 
     if verbose_graph:
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index bb4f142..9cbda45 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -15,11 +15,14 @@
 # limitations under the License.
 # Description:
 # Contains classes that hold commands for the high-level command stream (one command per DMA or NPU stripe).
+from typing import List
+
 import numpy as np
 
 from .architecture_features import Block
 from .numeric_util import round_up_divide
 from .operation import NpuBlockType
+from .shape4d import Shape4D
 
 
 class Box:
@@ -32,15 +35,15 @@
 
     def transform_with_strides_and_skirt(
         self,
-        strides,
-        skirt,
-        ifm_shape,
-        npu_block_type,
-        concat_axis=0,
-        concat_offset=0,
-        split_offset=None,
-        k_height=1,
-        upscaling_factor=1,
+        strides: List[int],
+        skirt: List[int],
+        ifm_shape: Shape4D,
+        npu_block_type: NpuBlockType,
+        concat_axis: int = 0,
+        concat_offset: int = 0,
+        split_offset: int = None,
+        k_height: int = 1,
+        upscaling_factor: int = 1,
     ):
         new_start_coord = list(self.start_coord)
         new_end_coord = list(self.end_coord)
@@ -58,15 +61,15 @@
         ):
             # these types of operations do a "dot product" or sum over the entire IFM
             new_start_coord[-1] = 0
-            new_end_coord[-1] = ifm_shape[-1]
+            new_end_coord[-1] = ifm_shape.depth
 
-        if npu_block_type == NpuBlockType.ElementWise and min(len(new_end_coord), len(ifm_shape)) >= 1:
-            new_end_coord[-1] = min(new_end_coord[-1], ifm_shape[-1])
-        if min(len(new_end_coord), len(ifm_shape)) >= 2:
-            new_end_coord[-2] = min(new_end_coord[-2], ifm_shape[-2] * upscaling_factor)
-        if min(len(new_end_coord), len(ifm_shape)) >= 3:
+        if npu_block_type == NpuBlockType.ElementWise and len(new_end_coord) >= 1:
+            new_end_coord[-1] = min(new_end_coord[-1], ifm_shape.depth)
+        if len(new_end_coord) >= 2:
+            new_end_coord[-2] = min(new_end_coord[-2], ifm_shape.width * upscaling_factor)
+        if len(new_end_coord) >= 3:
             original_end_coord = list(new_end_coord)
-            new_end_coord[-3] = min(new_end_coord[-3], ifm_shape[-3] * upscaling_factor)
+            new_end_coord[-3] = min(new_end_coord[-3], ifm_shape.height * upscaling_factor)
 
         pad_top = 0
         pad_bottom = 0
@@ -74,7 +77,7 @@
             if len(new_start_coord) >= 2:
                 stride = strides[2]
                 new_start_coord[-2] = max(new_start_coord[-2] * stride - skirt[1], 0)
-                new_end_coord[-2] = min(new_end_coord[-2] * stride + skirt[3], ifm_shape[-2])
+                new_end_coord[-2] = min(new_end_coord[-2] * stride + skirt[3], ifm_shape.width)
 
             if len(new_start_coord) >= 3:
                 stride = strides[1]
@@ -86,23 +89,20 @@
                 pad_top = max(0, 0 - new_start_coord[-3]) + skirt_top_remainder
                 new_start_coord[-3] = max(new_start_coord[-3], 0)
 
-                while len(ifm_shape) < 3:
-                    ifm_shape = [1] + ifm_shape
-
-                if (new_end_coord[-3] * stride + skirt[2]) > (ifm_shape[-3] * upscaling_factor):
+                if (new_end_coord[-3] * stride + skirt[2]) > (ifm_shape.height * upscaling_factor):
                     # pad_bottom is calculated based the diff between the end position of the weight kernel,
                     # after last stride and the ifm height.
-                    if upscaling_factor != 1 and original_end_coord[-3] > ifm_shape[-3] * upscaling_factor:
+                    if upscaling_factor != 1 and original_end_coord[-3] > ifm_shape.height * upscaling_factor:
                         # Special case for Transpose Convolution with VALID padding.
-                        pad_bottom = original_end_coord[-3] - (ifm_shape[-3] * upscaling_factor)
+                        pad_bottom = original_end_coord[-3] - (ifm_shape.height * upscaling_factor)
                     else:
                         k_start = new_start_coord[-3] - pad_top
-                        pad_bottom = max(0, k_start + total_stride + k_height - (ifm_shape[-3] * upscaling_factor))
+                        pad_bottom = max(0, k_start + total_stride + k_height - (ifm_shape.height * upscaling_factor))
 
                 # Adjust for upscaling
                 new_start_coord[-3] = max(new_start_coord[-3] // upscaling_factor, 0)
                 new_end_coord[-3] = new_end_coord[-3] * stride + skirt[2] + (skirt[2] % upscaling_factor)
-                new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape[-3]), 1)
+                new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape.height), 1)
 
         return Box(new_start_coord, new_end_coord), pad_top, pad_bottom
 
@@ -197,7 +197,7 @@
         self.pad_top = pad_top
         self.pad_bottom = pad_bottom
         for i in range(len(self.ofm_box.end_coord)):
-            assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0][i]
+            assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0].get_dim(i)
 
     def is_npu_pass_command(self):
         return True
@@ -251,76 +251,6 @@
         assert res >= 0
         return res
 
-    def get_single_block_command(self, block_idx):
-        block_cfg = (self.block_config[0], self.block_config[1], self.block_config[3])
-        dims = self.get_block_dimensions()
-        strides = dims[1] * dims[2], dims[2], 1
-        coord = []
-        idx_left = block_idx
-        for s in strides:
-            c = idx_left // s
-            idx_left -= c * s
-            coord.append(c)
-
-        assert idx_left == 0
-
-        # put in dummy height/widths in case we're dealing with FC layers
-        ofm_start = list(self.ofm_box.start_coord)
-        ofm_end = list(self.ofm_box.end_coord)
-
-        # cut out a nice block shape
-        for idx in (-1, -2, -3):
-            if len(ofm_start) >= -idx:
-                ofm_start[idx] += block_cfg[idx] * coord[idx]
-                ofm_end[idx] = min(ofm_end[idx], ofm_start[idx] + block_cfg[idx])
-
-        ps = self.ps
-        strides = None
-        skirt = None
-        if ps.primary_op is not None:
-            strides = ps.primary_op.attrs.get("strides", None)
-            skirt = ps.primary_op.attrs.get("skirt", None)
-        npu_block_type = ps.npu_block_type
-
-        ofm_box = Box(ofm_start, ofm_end)
-        ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-            strides, skirt, self.ifm_tensor.shape, npu_block_type, self.concat_axis, self.concat_offset
-        )
-
-        weight_box = None
-        if self.weight_tensor is not None:
-            weight_oc_start = ofm_start[-1]
-            weight_oc_end = ofm_end[-1]
-            if self.concat_axis - len(self.weight_tensor.shape) == -1:
-                weight_oc_start -= self.concat_offset
-                weight_oc_end -= self.concat_offset
-
-            weight_box = Box.make_weight_box(
-                self.weight_tensor.shape,
-                npu_block_type,
-                weight_oc_start,
-                weight_oc_end,
-                self.weight_tensor.weight_transpose_depthwise,
-            )
-
-        return NpuStripe(
-            self.ps,
-            self.block_config,
-            self.is_first,
-            self.is_last,
-            self.is_first_h_stripe,
-            self.is_last_h_stripe,
-            self.ifm_tensor,
-            ifm_box,
-            self.ofm_tensor,
-            ofm_box,
-            self.weight_tensor,
-            weight_box,
-            self.scale_tensor,
-            self.concat_axis,
-            self.concat_offset,
-        )
-
 
 class DMA(Command):
     def __init__(self, ps, in_tensor, out_tensor, box):
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 18a419c..60e62aa 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -27,6 +27,7 @@
 from .operation import create_activation_function
 from .operation import NpuBlockType
 from .operation import Op
+from .shape4d import Shape4D
 from .tensor import TensorPurpose
 
 
@@ -90,8 +91,8 @@
     weight_tensor = ps.weight_tensor
     scale_tensor = ps.scale_tensor
 
-    ofm_start = [0] * len(ofm_shape)
-    ofm_end = list(ofm_shape)
+    ofm_start = [0, 0, 0, 0]
+    ofm_end = ofm_shape.as_list()
 
     strides = None
     skirt = None
@@ -100,9 +101,9 @@
         strides = ps.primary_op.attrs.get("strides", None)
         skirt = ps.primary_op.attrs.get("skirt", None)
         if ps.primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
-            upscaling = ofm_shape[-3] // ifm_shape[-3]
+            upscaling = ofm_shape.height // ifm_shape.height
         elif ps.primary_op.type == Op.ResizeBilinear:
-            upscaling = round_up_divide(ofm_shape[-3], ifm_shape[-3])
+            upscaling = round_up_divide(ofm_shape.height, ifm_shape.height)
 
     concat_axis = 0
     concat_offset = 0
@@ -135,14 +136,7 @@
 
             if ifm_shape is not None:
                 ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                    strides,
-                    skirt,
-                    ifm_tensor.shape,
-                    npu_block_type,
-                    concat_axis,
-                    concat_offset,
-                    split_offsets[0],
-                    upscaling,
+                    strides, skirt, ifm_shape, npu_block_type, concat_axis, concat_offset, split_offsets[0], upscaling,
                 )
             else:
                 ifm_box = Box([], [])
@@ -163,7 +157,7 @@
                         intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt(
                             strides,
                             skirt,
-                            intermediate.shape,
+                            Shape4D(intermediate.shape),
                             npu_block_type,
                             concat_axis,
                             concat_offset,
@@ -212,6 +206,7 @@
             )
 
     elif strat == SchedulingStrategy.IfmStream:
+        assert ifm_shape is not None
         y_step = block_config[0]
         y_start = ofm_start[-3]
         y_dim = ofm_end[-3]
@@ -222,8 +217,7 @@
             prev_pass_gen = generate_high_level_command_stream_for_pass(strat, passes, block_configs, idx - 1)
         else:
             ifm_y_present = 1
-            if len(ifm_shape) >= 3:
-                ifm_y_present = ifm_shape[-3]
+            ifm_y_present = ifm_shape.height
             prev_pass_gen = []
             prev_pass = None
 
@@ -276,7 +270,7 @@
                         intermediate_box, _, _ = ofm_box.transform_with_strides_and_skirt(
                             strides,
                             skirt,
-                            intermediate.shape,
+                            Shape4D(intermediate.shape),
                             npu_block_type,
                             concat_axis,
                             concat_offset,
@@ -380,13 +374,13 @@
         if cmd.is_npu_pass_command():
             if cmd.is_first:
                 ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
-                    cmd.ifm_box.start_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=False
+                    cmd.ifm_box.start_coord, cmd.ps.ifm_shapes[0].as_list(), is_top_box=False
                 )
                 if ifm_read is None:
                     return 0
             if cmd.is_last:
                 write_offset = cmd.ofm_tensor.address_offset_for_coordinate(
-                    cmd.ofm_box.end_coord, shape=cmd.ps.ofm_shapes[0], is_top_box=True
+                    cmd.ofm_box.end_coord, cmd.ps.ofm_shapes[0].as_list(), is_top_box=True
                 )
                 if write_offset is None:
                     return 0
@@ -399,7 +393,7 @@
 
             if cmd.is_first:
                 ifm_read = cmd.ifm_tensor.address_offset_for_coordinate(
-                    cmd.ifm_box.end_coord, shape=cmd.ps.ifm_shapes[0], is_top_box=True
+                    cmd.ifm_box.end_coord, cmd.ps.ifm_shapes[0].as_list(), is_top_box=True
                 )
 
     min_overlap = max(min_overlap, 0)
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 9380374..0711702 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -58,6 +58,7 @@
 from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
 from .register_command_stream_util import to_npu_kernel
 from .register_command_stream_util import UNARY_ELEMWISE_OPS
+from .shape4d import Shape4D
 from .tensor import MemType
 from .tensor import Tensor
 from .tensor import TensorBlockTraversal
@@ -231,7 +232,7 @@
     return NpuQuantization(scale_f32=ofm_quant.scale_f32, zero_point=zero_point)
 
 
-def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: List[int]) -> NpuFeatureMap:
+def create_feature_map(tens: Tensor, box: Box, arch: ArchitectureFeatures, fm_shape: Shape4D) -> NpuFeatureMap:
     """Creates feature map with common fields populated"""
     fm = NpuFeatureMap()
     fm.region = get_region(tens, arch)
diff --git a/ethosu/vela/nn_graph.py b/ethosu/vela/nn_graph.py
index 6792517..d2c848a 100644
--- a/ethosu/vela/nn_graph.py
+++ b/ethosu/vela/nn_graph.py
@@ -21,8 +21,10 @@
 # Subgraph - Holds a neural network subgraph, pointing at Tensors, Operations, Passes, and CascadedPasses.
 # Graph - A full neural network graph with one or more Subgraphs.
 import enum
+from typing import List
 
 from .operation import Op
+from .shape4d import Shape4D
 
 
 class PassPlacement(enum.Enum):
@@ -58,8 +60,8 @@
         self.name = name
         self.cascade = None
         self.placement = placement
-        self.ifm_shapes = []
-        self.ofm_shapes = []
+        self.ifm_shapes: List[Shape4D] = []
+        self.ofm_shapes: List[Shape4D] = []
 
         # TODO: rename is_element_wise because it is not the same as an ElementWise operator. It is used by the tensor
         # allocation and requires that the OFM and IFM has the exact same address. Essentially complete overlap.
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index c2ec442..4ca4683 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -48,7 +48,7 @@
 
     if ps2.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct):
         op = ps2.primary_op
-        ifm_block_depth = arch.calc_ifm_block_depth(op.ifm_shapes[0][-1], op.ifm.dtype.size_in_bits())
+        ifm_block_depth = arch.calc_ifm_block_depth(op.ifm_shapes[0].depth, op.ifm.dtype.size_in_bits())
     else:
         ifm_block_depth = block_config_ps2[-1]
 
@@ -231,9 +231,9 @@
         arch.config.ofm_ublock.height == 2
         and npu_block_type
         in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise, NpuBlockType.VectorProduct)
-        and ofm_tens_shape[1] == 1
+        and ofm_tens_shape.height == 1
         # Optimisation only applies for even width tensors
-        and ofm_tens_shape[2] % 2 == 0
+        and ofm_tens_shape.width % 2 == 0
         and kernel_dims[0] == 1
     ):
         ofm_ublock.width = 4
@@ -319,14 +319,14 @@
         cycles_dpu_blk += delay_cycles
 
     if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
-        cycles_dpu_blk *= numeric_util.round_up_divide(ifm_tens_shape[3], ifm_block.depth)
+        cycles_dpu_blk *= numeric_util.round_up_divide(ifm_tens_shape.depth, ifm_block.depth)
 
     cycles_dpu_blk /= arch.ncores
 
     num_ofm_blk = (
-        numeric_util.round_up_divide(ofm_tens_shape[1], ofm_block.height)
-        * numeric_util.round_up_divide(ofm_tens_shape[2], ofm_block.width)
-        * numeric_util.round_up_divide(ofm_tens_shape[3], ofm_block.depth)
+        numeric_util.round_up_divide(ofm_tens_shape.height, ofm_block.height)
+        * numeric_util.round_up_divide(ofm_tens_shape.width, ofm_block.width)
+        * numeric_util.round_up_divide(ofm_tens_shape.depth, ofm_block.depth)
     )
 
     cycles_output_blk = estimate_output_cycles(
@@ -336,7 +336,7 @@
     if scale_tensor:
         cycles_bias_blk = (
             10
-            * min(ofm_block.depth, ofm_tens_shape[3])
+            * min(ofm_block.depth, ofm_tens_shape.depth)
             * arch.memory_latency[scale_tensor.mem_area][BandwidthDirection.Read]
             / 256
         )
@@ -420,8 +420,8 @@
         npu_block_type = primary_op.type.npu_block_type
 
         ifm_tensor, _, weight_tensor, ofm_tensor = ps.get_primary_op_ifm_ifm2_weights_ofm()
-        ifm_tensor_shape = list(ps.primary_op.ifm_shapes[0])
-        ofm_tensor_shape = list(ps.primary_op.ofm_shapes[0])
+        ifm_tensor_shape = ps.primary_op.ifm_shapes[0].clone()
+        ofm_tensor_shape = ps.primary_op.ofm_shapes[0].clone()
 
         if npu_block_type == NpuBlockType.ReduceSum:
             block_traversal = TensorBlockTraversal.DepthFirst
@@ -434,7 +434,7 @@
         else:
             block_traversal = TensorBlockTraversal.Default
         ifm_block_depth = get_ifm_block_depth(
-            npu_block_type, ifm_tensor_shape[3], ifm_tensor.dtype.size_in_bits(), block_traversal, ofm_block.depth
+            npu_block_type, ifm_tensor_shape.depth, ifm_tensor.dtype.size_in_bits(), block_traversal, ofm_block.depth
         )
         ifm_block = arch.get_ifm_block_size(
             ifm_block_depth, ofm_block, primary_op.kernel, ifm_resampling_mode=ifm_tensor.resampling_mode
@@ -448,11 +448,12 @@
             NpuBlockType.ReduceSum,
         ):
             # extent the ifm to full dimension
-            batch_size = ifm_tensor_shape[0]
+
+            batch_size = ifm_tensor_shape.batch
 
             # add in padding
-            ifm_tensor_shape[1] += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
-            ifm_tensor_shape[2] += explicit_padding[1] + explicit_padding[3]  # width  += left and right
+            ifm_tensor_shape.height += explicit_padding[0] + explicit_padding[2]  # height += top and bottom
+            ifm_tensor_shape.width += explicit_padding[1] + explicit_padding[3]  # width  += left and right
 
             if npu_block_type != NpuBlockType.Pooling:
                 if npu_block_type == NpuBlockType.ReduceSum:
@@ -468,9 +469,9 @@
                     weight_tensor_bandwidth_compression_scale = weight_tensor.bandwidth_compression_scale
 
                 nn_ops = (
-                    int(ofm_tensor_shape[0])
-                    * int(ofm_tensor_shape[1])
-                    * int(ofm_tensor_shape[2])
+                    int(ofm_tensor_shape.batch)
+                    * int(ofm_tensor_shape.height)
+                    * int(ofm_tensor_shape.width)
                     * int(weight_tensor_shape[0])
                     * int(weight_tensor_shape[1])
                     * int(weight_tensor_shape[2])
@@ -481,7 +482,7 @@
                     primary_op.attrs["ksize"][1],
                     primary_op.attrs["ksize"][2],
                     1,
-                    ifm_tensor_shape[3],
+                    ifm_tensor_shape.depth,
                 ]
                 weight_tensor_bandwidth_shape = weight_tensor_shape
                 weight_tensor_element_size = 0
@@ -504,8 +505,8 @@
             replacement_read_bws[ifm_tensor] = ifm_tensor.bandwidth() * ifm_read_multiple
 
             weight_read_multiple = numeric_util.round_up_divide(
-                ofm_tensor_shape[1], ofm_block.height
-            ) * numeric_util.round_up_divide(ofm_tensor_shape[2], ofm_block.width)
+                ofm_tensor_shape.height, ofm_block.height
+            ) * numeric_util.round_up_divide(ofm_tensor_shape.width, ofm_block.width)
             replacement_read_bws[weight_tensor] = (
                 batch_size
                 * shape_num_elements(weight_tensor_bandwidth_shape)
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index be26a26..c80e18b 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -26,6 +26,7 @@
 
 from .errors import VelaError
 from .numeric_util import full_shape
+from .shape4d import Shape4D
 
 
 if TYPE_CHECKING:
@@ -372,7 +373,7 @@
     return act
 
 
-def get_slice_offsets(input_shape, offset_tens, offset_mask, is_begin=True):
+def get_slice_offsets(input_shape: List[int], offset_tens: int, offset_mask: int, is_begin: bool = True):
     # For strided slice operator: get start or end offsets
     offsets = len(input_shape) * [0] if is_begin else input_shape[:]
     for idx in range(len(input_shape)):
@@ -427,8 +428,8 @@
         self.op_index = None  # input network operator index
         self.activation_lut = None
         self._kernel = None
-        self.ifm_shapes = []
-        self.ofm_shapes = []
+        self.ifm_shapes: List[Shape4D] = []  # may hold None entries, see set_ifm_ofm_shapes()
+        self.ofm_shapes: List[Shape4D] = []  # may hold None entries, see set_ifm_ofm_shapes()
 
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
@@ -707,6 +708,9 @@
         raise VelaError("\n".join(lines))
 
     def set_ifm_ofm_shapes(self):
+        self.ifm_shapes = []  # reset so that repeated calls do not accumulate entries
+        self.ofm_shapes = []  # NOTE(review): the elif below tests is_split_op without calling it; confirm
+
         ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor = self.get_ifm_ifm2_weights_ofm()
 
         # set all shapes to op, as 4D
@@ -716,24 +720,24 @@
             batch_size = elms // n_in_elems
             assert batch_size * n_in_elems == elms
 
-            self.ifm_shapes.append([batch_size, 1, 1, n_in_elems])
-            self.ofm_shapes.append(ofm_tensor.get_full_shape())
+            self.ifm_shapes.append(Shape4D([batch_size, 1, 1, n_in_elems]))
+            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
         elif self.type == Op.Softmax:
-            self.ifm_shapes.append(ifm_tensor.get_full_shape())
-            self.ofm_shapes.append(ofm_tensor.get_full_shape())
+            self.ifm_shapes.append(Shape4D(ifm_tensor.get_full_shape()))
+            self.ofm_shapes.append(Shape4D(ofm_tensor.get_full_shape()))
         elif self.type.is_split_op or self.type.is_concat_op():
             for inp in self.inputs:
                 if inp is not None:
-                    self.ifm_shapes.append(full_shape(4, inp.shape, 1))
+                    self.ifm_shapes.append(Shape4D(full_shape(4, inp.shape, 1)))
                 else:
                     self.ifm_shapes.append(None)
             for out in self.outputs:
                 if out is not None:
-                    self.ofm_shapes.append(full_shape(4, out.shape, 1))
+                    self.ofm_shapes.append(Shape4D(full_shape(4, out.shape, 1)))
                 else:
                     self.ofm_shapes.append(None)
         else:
-            self.ifm_shapes.append(full_shape(4, ifm_tensor.shape, 1))
+            self.ifm_shapes.append(Shape4D(full_shape(4, ifm_tensor.shape, 1)))
             if ifm2_tensor is not None:
-                self.ifm_shapes.append(full_shape(4, ifm2_tensor.shape, 1))
-            self.ofm_shapes.append(full_shape(4, ofm_tensor.shape, 1))
+                self.ifm_shapes.append(Shape4D(full_shape(4, ifm2_tensor.shape, 1)))
+            self.ofm_shapes.append(Shape4D(full_shape(4, ofm_tensor.shape, 1)))
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 095a78d..8f6660c 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -231,9 +231,9 @@
                 ofm_tensor = op.ofm
                 if ofm_tensor is None:
                     ofm_tensor = op.outputs[0]
-                build_pass((op,), ofm_tensor)
+                build_pass((op,), ofm_tensor, op.ofm_shapes[0].clone())  # needs op.ofm_shapes[0] to be set
 
-    def build_pass(start_ops_to_process, ofm_tensor=None):
+    def build_pass(start_ops_to_process, ofm_tensor=None, ofm_shapes=None):
         reverse_ops_list = []
         curr_flags = PassFlags.Empty
         npu_block_type = NpuBlockType.Default
@@ -416,8 +416,7 @@
                 ps.ifm_shapes.append(ps.primary_op.ifm_shapes[0])
 
         ps.ofm_tensor = ofm_tensor
-        if ps.primary_op is not None:
-            ps.ofm_shapes.append(ps.primary_op.ofm_shapes[0])
+        ps.ofm_shapes.append(ofm_shapes)  # a single shape, or None when the caller passed none
 
         assert ps.placement != PassPlacement.Npu or ps.ofm_tensor is not None
         ps.weight_tensor = ps.get_primary_op_ifm_weights()[1]
@@ -453,11 +452,11 @@
             avgpool_out = inp.clone("_avgpooled")
             avgpool_out.consumer_list.append(op)
             avgpool_op.set_output_tensor(avgpool_out)
-            avgpool_op.ifm_shapes = op.ifm_shapes
-            avgpool_op.ofm_shapes = op.ofm_shapes
+            avgpool_op.set_ifm_ofm_shapes()  # derive from avgpool_op's own tensors, not op's
 
             op.inputs[0] = avgpool_out
             op_list.insert(0, avgpool_op)
+            op.set_ifm_ofm_shapes()  # op.inputs[0] was replaced above, so recompute
 
             DebugDatabase.add_optimised(op, avgpool_op)
             return avgpool_op
diff --git a/ethosu/vela/shape4d.py b/ethosu/vela/shape4d.py
new file mode 100644
index 0000000..a1b4fea
--- /dev/null
+++ b/ethosu/vela/shape4d.py
@@ -0,0 +1,91 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Defines the class Shape4D.
+from .numeric_util import full_shape
+
+
+class Shape4D:
+    """
+    4D Shape (in NHWC format)
+
+    Shapes with fewer than four dimensions are expanded to four by full_shape(),
+    using `base` as the fill value. Instances are mutable through the dimension
+    setters and, since __eq__ is defined without __hash__, are not hashable.
+    """
+
+    def __init__(self, shape, base=1):
+        assert shape is not None
+        assert len(shape) <= 4
+        self._shape4D = tuple(full_shape(4, shape, base))
+
+    def __str__(self):
+        return f"<Shape4D {self.as_list()}>"
+
+    def __repr__(self):
+        return f"Shape4D({self.as_list()})"
+
+    def __eq__(self, other):
+        # Return NotImplemented for other types so that comparisons such as
+        # `shape == None` or `shape != []` give a bool instead of raising AttributeError
+        if not isinstance(other, Shape4D):
+            return NotImplemented
+        return self._shape4D == other._shape4D
+
+    def clone(self):
+        """Return an independent copy of this shape."""
+        return Shape4D(self.as_list())
+
+    @property
+    def batch(self):
+        return self._shape4D[0]
+
+    @property
+    def height(self):
+        return self._shape4D[1]
+
+    @property
+    def width(self):
+        return self._shape4D[2]
+
+    @property
+    def depth(self):
+        return self._shape4D[3]
+
+    @batch.setter
+    def batch(self, new_batch):
+        self._shape4D = (new_batch, self._shape4D[1], self._shape4D[2], self._shape4D[3])
+
+    @height.setter
+    def height(self, new_height):
+        self._shape4D = (self._shape4D[0], new_height, self._shape4D[2], self._shape4D[3])
+
+    @width.setter
+    def width(self, new_width):
+        self._shape4D = (self._shape4D[0], self._shape4D[1], new_width, self._shape4D[3])
+
+    @depth.setter
+    def depth(self, new_depth):
+        self._shape4D = (self._shape4D[0], self._shape4D[1], self._shape4D[2], new_depth)
+
+    def get_dim(self, dim):
+        """Return dimension `dim` (N=0, H=1, W=2, C=3; negative indices count back from C)."""
+        assert -4 <= dim < 4
+        return self._shape4D[dim]
+
+    def as_list(self):
+        """Return the shape as a new [N, H, W, C] list."""
+        return list(self._shape4D)
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index 1f027d6..d8faf36 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -32,6 +32,7 @@
 from .operation import NpuBlockType
 from .range_set import MemoryRangeSet
 from .register_command_stream_util import to_kernel
+from .shape4d import Shape4D
 from .tensor import MemArea
 
 
@@ -195,14 +196,14 @@
         ifm_bits = ifm_tensor.dtype.size_in_bits()
         ifm_shape = ps.primary_op.ifm_shapes[0]
 
-        if ifm_shape != []:
-            ifm_depth = ifm_shape[-1]
+        if ifm_tensor.shape != []:  # test the tensor: ifm_shape is a Shape4D, always 4D
+            ifm_depth = ifm_shape.depth
 
         if is_elementwise:
             ifm_count = 2
             if ifm_tensor.shape == []:  # Scalar in ifm1
                 assert ifm2_tensor
-                ifm_depth = ps.primary_op.ifm_shapes[1][-1]
+                ifm_depth = ps.primary_op.ifm_shapes[1].depth
                 ifm_count = 1
             elif not ifm2_tensor or ifm2_tensor.shape == []:  # Scalar in ifm2
                 ifm_count = 1
@@ -251,7 +252,7 @@
         ifm_bits=ifm_bits,
         ifm_depth=ifm_depth,
         ifm_count=ifm_count,
-        ofm_shape=ofm_shape,
+        ofm_shape=Shape4D(ofm_shape),
     )
 
 
@@ -265,14 +266,9 @@
 
     # Constrain the search space if the OFM is smaller than the max block size
     # - Add other block search constraints here if required
-    if len(alloc.ofm_shape) <= 2:
-        max_block_height = max_block_width = alloc.ofm_shape[0]
-    else:
-        max_block_width = alloc.ofm_shape[-2]
-        max_block_height = alloc.ofm_shape[-3]
-
-    # Common block depth
-    max_block_depth = alloc.ofm_shape[-1]
+    max_block_width = alloc.ofm_shape.width
+    max_block_height = alloc.ofm_shape.height  # NOTE(review): confirm H/W for OFMs with fewer than 3 dims
+    max_block_depth = alloc.ofm_shape.depth
 
     # Constrain to valid ranges before search
     max_block_width = min(arch.ofm_block_max.width, max_block_width)
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 9849653..3b4bace 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -213,7 +213,7 @@
         ofm = self.op.outputs[0]
 
         # Reshape ifm/ofm (if needed)
-        full_shape = self.op.ifm_shapes[0]
+        full_shape = self.op.ifm_shapes[0].as_list()  # new list: mutated below (batch folded into H)
         if full_shape[0] > 1:
             full_shape[1] *= full_shape[0]
             full_shape[0] = 1
@@ -414,6 +414,7 @@
         shr30_op.add_input_tensor(scaled_exp)
         shr30_op.add_input_tensor(right_shift)
         shr30_op.set_output_tensor(ofm)
+        shr30_op.set_ifm_ofm_shapes()  # derive 4D shapes from the tensors connected above
         DebugDatabase.add_optimised(self.op, shr30_op)
 
         return shr30_op
@@ -535,6 +536,7 @@
         shr13_op.add_input_tensor(mul_ofm)
         shr13_op.add_input_tensor(reciprocal_right_shift)
         shr13_op.set_output_tensor(ofm)
+        shr13_op.set_ifm_ofm_shapes()  # derive 4D shapes from the tensors connected above
         DebugDatabase.add_optimised(self.op, shr13_op)
 
         return shr13_op
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index df8f886..093e877 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -40,6 +40,7 @@
 from .numeric_util import full_shape
 from .operation import Op
 from .operation import Operation
+from .shape4d import Shape4D
 
 Shape = List
 
@@ -304,6 +305,7 @@
     # Operator
     const_op = Operation(Op.Const, name)
     const_op.set_output_tensor(const_tensor)
+    const_op.set_ifm_ofm_shapes()  # NOTE(review): Const op has no inputs here - confirm this is supported
     return const_tensor
 
 
@@ -323,8 +325,7 @@
     reshape_op.add_input_tensor(reshape_ifm)
     reshape_op.add_input_tensor(create_const_tensor(name + "_shape", [1], DataType.int32, shape))
     reshape_op.set_output_tensor(reshape_ofm)
-    reshape_op.ifm_shapes.append(full_shape(4, reshape_ifm.shape, 1))
-    reshape_op.ofm_shapes.append(full_shape(4, reshape_ofm.shape, 1))
+    reshape_op.set_ifm_ofm_shapes()  # derive 4D shapes from reshape_ifm / reshape_ofm
     return reshape_ofm if ifm_reshape else reshape_ifm
 
 
@@ -608,7 +609,7 @@
     def consumers(self) -> List[Operation]:
         return self.consumer_list
 
-    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape) -> Tuple:
+    def addresses_for_rolling_buffer(self, start_coord: Shape, end_coord: Shape, fm_shape: Shape4D) -> Tuple:
         # returns ( box_height0, box_height1, box_width, [address_tl, address_tr, address_bl, address_br] )
 
         if self.storage_shape == []:
@@ -616,7 +617,7 @@
                 1,
                 1,
                 1,
-                [self.address_for_coordinate(start_coord, shape=fm_shape), None, None, None],
+                [self.address_for_coordinate(start_coord, shape=fm_shape.as_list()), None, None, None],
             )
 
         storage_shape_4D = full_shape(4, self.storage_shape, 1)
@@ -630,20 +631,20 @@
         box_width = crossing_x - start_coord[2]
 
         addresses: List = [None] * 4
-        addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape)
+        addresses[0] = self.address_for_coordinate(start_coord, shape=fm_shape.as_list())
 
         if end_coord[2] > crossing_x:
             addresses[1] = self.address_for_coordinate(
-                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape
+                [start_coord[0], start_coord[1], crossing_x, start_coord[3]], shape=fm_shape.as_list()
             )
             raise UnsupportedFeatureError("Striping in vertical direction is not supported")
         if end_coord[1] > crossing_y:
             addresses[2] = self.address_for_coordinate(
-                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape
+                [start_coord[0], crossing_y, start_coord[2], start_coord[3]], shape=fm_shape.as_list()
             )
         if end_coord[1] > crossing_y and end_coord[2] > crossing_x:
             addresses[3] = self.address_for_coordinate(
-                [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape
+                [start_coord[0], crossing_y, crossing_x, start_coord[3]], shape=fm_shape.as_list()
             )
 
         return box_height0, box_height0, box_width, addresses
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 4537741..7fdc4bd 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -21,4 +21,5 @@
 from ethosu.vela.graph_optimiser import convert_batched_fc_shape
 from ethosu.vela.operation import Op
+from ethosu.vela.shape4d import Shape4D
 from ethosu.vela.tensor import create_const_tensor
 from ethosu.vela.tensor import Tensor
@@ -35,8 +36,8 @@
 
     ifm.consumer_list.append(op)
 
-    op.ifm_shapes.append([4, 1, 1, 8])
-    op.ofm_shapes.append([4, 1, 1, 8])
+    op.ifm_shapes.append(Shape4D([4, 1, 1, 8]))
+    op.ofm_shapes.append(Shape4D([4, 1, 1, 8]))
 
     prev_op = op.clone()
     prev_op.ifm_shapes = op.ifm_shapes
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 583821a..973b820 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -62,7 +62,7 @@
 
 def test_constraint_tens_shape_size():
     # Tensors cannot be > 4D
-    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8])
+    op = testutil.create_op_with_quant_tensors(Op.Relu, [1, 1, 8, 8, 8], [1, 1, 8, 8, 8], set_ifm_ofm_shapes=False)
     assert not support.is_operator_supported(op)
 
 
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index 63f841b..c345950 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -75,7 +75,7 @@
 
 
 def create_op_with_quant_tensors(
-    op_type, ifm_shape, ofm_shape, weights_shape=None, bias_shape=None, datatype=DataType.uint8
+    op_type, ifm_shape, ofm_shape, weights_shape=None, bias_shape=None, datatype=DataType.uint8, set_ifm_ofm_shapes=True
 ):
     ifm = Tensor(ifm_shape, datatype, "in")
     ifm.quantization = default_quant_params()
@@ -107,7 +107,9 @@
         bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), np.int32, quantization=qp)
         op.add_input_tensor(bias)
 
-    op.set_ifm_ofm_shapes()
+    if set_ifm_ofm_shapes:  # False allows >4D test tensors (Shape4D asserts rank <= 4)
+        op.set_ifm_ofm_shapes()
+
     return op