diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index bb4f142..9cbda45 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -15,11 +15,14 @@
 # limitations under the License.
 # Description:
 # Contains classes that hold commands for the high-level command stream (one command per DMA or NPU stripe).
+from typing import List
+
 import numpy as np
 
 from .architecture_features import Block
 from .numeric_util import round_up_divide
 from .operation import NpuBlockType
+from .shape4d import Shape4D
 
 
 class Box:
@@ -32,15 +35,15 @@
 
     def transform_with_strides_and_skirt(
         self,
-        strides,
-        skirt,
-        ifm_shape,
-        npu_block_type,
-        concat_axis=0,
-        concat_offset=0,
-        split_offset=None,
-        k_height=1,
-        upscaling_factor=1,
+        strides: List[int],
+        skirt: List[int],
+        ifm_shape: Shape4D,
+        npu_block_type: NpuBlockType,
+        concat_axis: int = 0,
+        concat_offset: int = 0,
+        split_offset: int = None,
+        k_height: int = 1,
+        upscaling_factor: int = 1,
     ):
         new_start_coord = list(self.start_coord)
         new_end_coord = list(self.end_coord)
@@ -58,15 +61,15 @@
         ):
             # these types of operations do a "dot product" or sum over the entire IFM
             new_start_coord[-1] = 0
-            new_end_coord[-1] = ifm_shape[-1]
+            new_end_coord[-1] = ifm_shape.depth
 
-        if npu_block_type == NpuBlockType.ElementWise and min(len(new_end_coord), len(ifm_shape)) >= 1:
-            new_end_coord[-1] = min(new_end_coord[-1], ifm_shape[-1])
-        if min(len(new_end_coord), len(ifm_shape)) >= 2:
-            new_end_coord[-2] = min(new_end_coord[-2], ifm_shape[-2] * upscaling_factor)
-        if min(len(new_end_coord), len(ifm_shape)) >= 3:
+        if npu_block_type == NpuBlockType.ElementWise and len(new_end_coord) >= 1:
+            new_end_coord[-1] = min(new_end_coord[-1], ifm_shape.depth)
+        if len(new_end_coord) >= 2:
+            new_end_coord[-2] = min(new_end_coord[-2], ifm_shape.width * upscaling_factor)
+        if len(new_end_coord) >= 3:
             original_end_coord = list(new_end_coord)
-            new_end_coord[-3] = min(new_end_coord[-3], ifm_shape[-3] * upscaling_factor)
+            new_end_coord[-3] = min(new_end_coord[-3], ifm_shape.height * upscaling_factor)
 
         pad_top = 0
         pad_bottom = 0
@@ -74,7 +77,7 @@
             if len(new_start_coord) >= 2:
                 stride = strides[2]
                 new_start_coord[-2] = max(new_start_coord[-2] * stride - skirt[1], 0)
-                new_end_coord[-2] = min(new_end_coord[-2] * stride + skirt[3], ifm_shape[-2])
+                new_end_coord[-2] = min(new_end_coord[-2] * stride + skirt[3], ifm_shape.width)
 
             if len(new_start_coord) >= 3:
                 stride = strides[1]
@@ -86,23 +89,20 @@
                 pad_top = max(0, 0 - new_start_coord[-3]) + skirt_top_remainder
                 new_start_coord[-3] = max(new_start_coord[-3], 0)
 
-                while len(ifm_shape) < 3:
-                    ifm_shape = [1] + ifm_shape
-
-                if (new_end_coord[-3] * stride + skirt[2]) > (ifm_shape[-3] * upscaling_factor):
+                if (new_end_coord[-3] * stride + skirt[2]) > (ifm_shape.height * upscaling_factor):
                     # pad_bottom is calculated based the diff between the end position of the weight kernel,
                     # after last stride and the ifm height.
-                    if upscaling_factor != 1 and original_end_coord[-3] > ifm_shape[-3] * upscaling_factor:
+                    if upscaling_factor != 1 and original_end_coord[-3] > ifm_shape.height * upscaling_factor:
                         # Special case for Transpose Convolution with VALID padding.
-                        pad_bottom = original_end_coord[-3] - (ifm_shape[-3] * upscaling_factor)
+                        pad_bottom = original_end_coord[-3] - (ifm_shape.height * upscaling_factor)
                     else:
                         k_start = new_start_coord[-3] - pad_top
-                        pad_bottom = max(0, k_start + total_stride + k_height - (ifm_shape[-3] * upscaling_factor))
+                        pad_bottom = max(0, k_start + total_stride + k_height - (ifm_shape.height * upscaling_factor))
 
                 # Adjust for upscaling
                 new_start_coord[-3] = max(new_start_coord[-3] // upscaling_factor, 0)
                 new_end_coord[-3] = new_end_coord[-3] * stride + skirt[2] + (skirt[2] % upscaling_factor)
-                new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape[-3]), 1)
+                new_end_coord[-3] = max(min(new_end_coord[-3] // upscaling_factor, ifm_shape.height), 1)
 
         return Box(new_start_coord, new_end_coord), pad_top, pad_bottom
 
@@ -197,7 +197,7 @@
         self.pad_top = pad_top
         self.pad_bottom = pad_bottom
         for i in range(len(self.ofm_box.end_coord)):
-            assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0][i]
+            assert self.ofm_box.end_coord[i] <= ps.ofm_shapes[0].get_dim(i)
 
     def is_npu_pass_command(self):
         return True
@@ -251,76 +251,6 @@
         assert res >= 0
         return res
 
-    def get_single_block_command(self, block_idx):
-        block_cfg = (self.block_config[0], self.block_config[1], self.block_config[3])
-        dims = self.get_block_dimensions()
-        strides = dims[1] * dims[2], dims[2], 1
-        coord = []
-        idx_left = block_idx
-        for s in strides:
-            c = idx_left // s
-            idx_left -= c * s
-            coord.append(c)
-
-        assert idx_left == 0
-
-        # put in dummy height/widths in case we're dealing with FC layers
-        ofm_start = list(self.ofm_box.start_coord)
-        ofm_end = list(self.ofm_box.end_coord)
-
-        # cut out a nice block shape
-        for idx in (-1, -2, -3):
-            if len(ofm_start) >= -idx:
-                ofm_start[idx] += block_cfg[idx] * coord[idx]
-                ofm_end[idx] = min(ofm_end[idx], ofm_start[idx] + block_cfg[idx])
-
-        ps = self.ps
-        strides = None
-        skirt = None
-        if ps.primary_op is not None:
-            strides = ps.primary_op.attrs.get("strides", None)
-            skirt = ps.primary_op.attrs.get("skirt", None)
-        npu_block_type = ps.npu_block_type
-
-        ofm_box = Box(ofm_start, ofm_end)
-        ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-            strides, skirt, self.ifm_tensor.shape, npu_block_type, self.concat_axis, self.concat_offset
-        )
-
-        weight_box = None
-        if self.weight_tensor is not None:
-            weight_oc_start = ofm_start[-1]
-            weight_oc_end = ofm_end[-1]
-            if self.concat_axis - len(self.weight_tensor.shape) == -1:
-                weight_oc_start -= self.concat_offset
-                weight_oc_end -= self.concat_offset
-
-            weight_box = Box.make_weight_box(
-                self.weight_tensor.shape,
-                npu_block_type,
-                weight_oc_start,
-                weight_oc_end,
-                self.weight_tensor.weight_transpose_depthwise,
-            )
-
-        return NpuStripe(
-            self.ps,
-            self.block_config,
-            self.is_first,
-            self.is_last,
-            self.is_first_h_stripe,
-            self.is_last_h_stripe,
-            self.ifm_tensor,
-            ifm_box,
-            self.ofm_tensor,
-            ofm_box,
-            self.weight_tensor,
-            weight_box,
-            self.scale_tensor,
-            self.concat_axis,
-            self.concat_offset,
-        )
-
 
 class DMA(Command):
     def __init__(self, ps, in_tensor, out_tensor, box):
