MLBEDSW-4223: Full support for PAD operator

- Added full support for the PAD operator
- Hardware padding is still used whenever possible; PAD operators that
  cannot be replaced by hardware padding are rewritten to average pool
  operations that fill the OFM (see the sketch below)
- Fixed a bug with PAD followed by max pool when the IFM contains
  negative values
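
For reference, a minimal NumPy sketch (not Vela code; the function name
pad_by_decomposition is illustrative) of the fall-back decomposition that
convert_pad performs: the IFM is copied into the interior of the OFM, and
up to four constant border regions are filled with the zero point. The real
implementation expresses these as average pool ops with write_offset and
write_shape rather than array slices, but the five writes cover the same
regions.

import numpy as np

def pad_by_decomposition(ifm, top, bottom, left, right, zero_point):
    """ifm has shape (N, H, W, C); returns the padded OFM."""
    n, h, w, c = ifm.shape
    ofm = np.empty((n, h + top + bottom, w + left + right, c), dtype=ifm.dtype)

    # Main write: IFM copied to write_offset (0, top, left, 0)
    ofm[:, top:top + h, left:left + w, :] = ifm

    # Border writes: constant regions filled with the zero point
    if top:
        ofm[:, :top, :, :] = zero_point                 # full-width top strip
    if bottom:
        ofm[:, top + h:, :, :] = zero_point             # full-width bottom strip
    if left:
        ofm[:, top:top + h, :left, :] = zero_point      # left strip, IFM height only
    if right:
        ofm[:, top:top + h, left + w:, :] = zero_point  # right strip, IFM height only
    return ofm

# Example: 1x2x2x1 IFM padded by 1 on every side, zero point 128 (uint8)
print(pad_by_decomposition(np.full((1, 2, 2, 1), 200, np.uint8), 1, 1, 1, 1, 128)[0, :, :, 0])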

Change-Id: Ifc64d1943737d94466f5e2821009dab12a49a965
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index e48ebf5..3c90e20 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `2.1.0`
+Vela version: `2.1.2.dev0+g41c006a.d20210309`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -239,15 +239,11 @@
 
 This is a list of constraints that the PAD operator must satisfy in order to be scheduled on the NPU.
 
-- IFM and OFM data types must match
-- Both Input quantization parameters must match OFM quantization parameters
 - Number of input tensors must be exactly 2
-- The padding tensor must have the shape [4,2]
+- The padding tensor must have the shape [3,2] or [4,2]
 - The pad tensor can only pad width and height
 - Pad tensor must be of type: int32, int64
 - The padding tensor must be constant
-- Must be followed by one of the following operator types: AVERAGE_POOL_2D, CONV_2D, DEPTHWISE_CONV_2D, MAX_POOL_2D
-- Padding must be at most kernel size divided by 2
 
 ## RESIZE_BILINEAR Constraints
 
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 1e890bb..3084117 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -41,10 +41,12 @@
 from .operation import Operation
 from .operation import Padding
 from .operation_util import create_avgpool_nop
+from .operation_util import get_pad_values_from_input
 from .shape4d import Shape4D
 from .softmax import SoftMax
 from .tensor import check_quantized_tens_scaling_equal
 from .tensor import create_const_tensor
+from .tensor import create_equivalence_id
 from .tensor import QuantizationParameters
 from .tensor import Tensor
 from .tensor import TensorPurpose
@@ -55,6 +57,23 @@
 memory_only_ops = (Op.Reshape,)
 
 
+def create_avg_pool_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
+    """Creates an average pool for the given concat op/input feature map"""
+    ofm = concat_op.ofm
+    avgpool_op = create_avgpool_nop(name)
+    avgpool_op.inputs = [ifm]
+    avgpool_op.outputs = [ofm]
+
+    avgpool_op.write_offset = write_offset
+    avgpool_op.write_shape = ifm_shape
+    ofm.ops.append(avgpool_op)
+    DebugDatabase.add_optimised(concat_op, avgpool_op)
+    avgpool_op.ifm_shapes.append(ifm_shape)
+    avgpool_op.ofm_shapes.append(concat_op.ofm_shapes[0])
+    avgpool_op.memory_function = Op.ConcatSliceWrite
+    return avgpool_op
+
+
 def remove_passthrough_tensor(tens, arch, nng):
     if len(tens.ops) == 1 and tens.ops[0].type in passthrough_nodes:
         assert len(tens.ops[0].inputs) == 1
@@ -64,7 +83,7 @@
 
 def rewrite_concat_ops(op, arch):
     if not op.run_on_npu or not op.type.is_concat_op():
-        return op
+        return
 
     axis_4D = 0
     ofm = op.ofm
@@ -90,7 +109,6 @@
         op.type = Op.PackReshaped
 
     inputs, axis = op.get_concat_inputs_axis()
-
     for idx, inp in enumerate(inputs):
         if op.type != Op.PackReshaped:
             op.ifm_shapes[idx] = Shape4D(inp.shape)
@@ -98,20 +116,13 @@
                 axis_4D = axis + (4 - len(inp.shape))
             else:
                 axis_4D = axis
-        avgpool_op = create_avgpool_nop(op.name + str(idx) + "_avgpool")
-        avgpool_op.inputs = [inp]
-        avgpool_op.outputs = [ofm]
-        avgpool_op.attrs["concat_axis"] = axis_4D
-        avgpool_op.attrs["concat_start"] = offset
-        offset += op.ifm_shapes[idx][axis_4D]
-
-        avgpool_op.attrs["concat_end"] = offset
-        avgpool_op.run_on_npu = True
-        ofm.ops.append(avgpool_op)
-        DebugDatabase.add_optimised(op, avgpool_op)
-        avgpool_op.ifm_shapes.append(op.ifm_shapes[idx])
-        avgpool_op.ofm_shapes.append(op.ofm_shapes[0])
-        avgpool_op.memory_function = Op.ConcatSliceWrite
+        write_offset = [0, 0, 0, 0]
+        write_offset[axis_4D] = offset
+        concat_end = offset + op.ifm_shapes[idx][axis_4D]
+        create_avg_pool_for_concat(
+            op, op.name + str(idx) + "_avgpool", inp, op.ifm_shapes[idx], Shape4D.from_list(write_offset)
+        )
+        offset = concat_end
     assert ofm.shape[axis] == offset
 
     # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a
@@ -119,11 +130,7 @@
     # aligned. For other values of axis the address offsets will be 16 byte aligned, as they are all based on c = 0
     # and those addresses are always 16 byte aligned due to the NHCWB16 format.
     if axis == -1 or axis == (len(ofm.shape) - 1):
-        for op in ofm.ops:
-            if op.attrs["concat_start"] % 16 != 0:
-                ofm.avoid_NHCWB16 = True
-                break
-    return op
+        ofm.avoid_NHCWB16 = any(op2.write_offset.depth % 16 != 0 for op2 in ofm.ops if op2.write_offset is not None)
 
 
 def rewrite_split_ops(tens, arch, nng):
@@ -1177,20 +1184,53 @@
     return op
 
 
-def optimise_pad(op: Operation, arch, nng):
+def _leading_pad_ok(leading_pad, stride, kernel_size):
+    # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
+    # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
+    max_size = kernel_size // 2
+    return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
+
+
+def replace_pad_by_hw_pad(op: Operation, arch, nng):
     """
+    Tries to completely remove a PAD operator by using hardware padding.
+    E.g. a PAD operation that pads 1, followed by a CONV with VALID padding and kernel size 3
+    is rewritten such that the PAD is removed, and the CONV uses SAME padding.
     Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
     if both operations can be run on the NPU.
+    This is the most efficient way to implement PAD, but cannot be done for all pad sizes.
     """
     if (
-        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_pool_op())
+        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op() or op.type.is_avgpool_op())
         and op.run_on_npu
         and op.attrs["padding"] == Padding.VALID
     ):
         pad_op = op.ifm.ops[0]
         if pad_op.type != Op.Pad or not pad_op.run_on_npu:
             return op
+        if pad_op.ifm.dtype != pad_op.ofm.dtype or not check_quantized_tens_scaling_equal(pad_op.ofm, pad_op.ifm):
+            return op
+        top, left, bottom, right = get_pad_values_from_input(pad_op.inputs[1].values)
+        k = op.kernel
+        k_w, k_h = k.dilated_wh()
+
+        # Check if the PAD operator can be replaced by hardware padding
+        if left > k_w // 2 or right > k_w // 2 or top > k_h // 2 or bottom > k_h // 2:
+            # Too much padding, it would require hardware padding to actually insert zeros
+            return op
+        if not _leading_pad_ok(top, k.stride.y, k_h) or not _leading_pad_ok(left, k.stride.x, k_w):
+            return op
+
         if op.type.is_avgpool_op():
+            # For average pool, hardware padding can only be used if padding is 0 or kernel size / 2
+            for pad, k_size in (
+                (left, k_w),
+                (right, k_w),
+                (top, k_h),
+                (bottom, k_h),
+            ):
+                if pad not in (0, k_size // 2):
+                    return op
             # Average pool is converted to depthwise, because NPU average pool + same padding
             # has a special implementation that is different from PAD followed by average pool with
             # valid padding.
@@ -1230,13 +1270,80 @@
         op.set_input_tensor(pad_op.ifm, 0)
         # Adjust the padding attributes of the convolution operator
         op.attrs["padding"] = Padding.EXPLICIT
-        padding = pad_op.inputs[1].values  # 4x2 tensor, first dimension is N, H, W, C
-        top, left, bottom, right = (padding[1][0], padding[2][0], padding[1][1], padding[2][1])
         op.attrs["explicit_padding"] = (top, left, bottom, right)
         op.set_ifm_ofm_shapes()
     return op
 
 
+def convert_pad(op: Operation, arch, nng):
+    """
+    Rewrites PAD operator to an average pool that copies the IFM to the OFM
+    + up to 4 average pool operators that fill the OFM with zeros at the borders.
+    This is done as a fall-back for the PAD operators that remain after replace_pad_by_hw_pad
+    """
+    if op.type != Op.Pad or not op.run_on_npu:
+        return op
+    top, left, bottom, right = get_pad_values_from_input(op.inputs[1].values)
+
+    ifm = op.ifm
+    assert ifm is not None
+    ifm_shape = Shape4D(ifm.shape)
+    ofm = op.ofm
+    assert ofm is not None
+    ofm.ops = []
+    ofm_shape = op.ofm_shapes[0]
+
+    # Average pool op that copies IFM to the right place inside the OFM
+    shp0 = Shape4D(0, 0, 0, 0)
+    shp_top = shp0.with_height(top)
+    avgpool_op = create_avg_pool_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
+    avgpool_op.activation = op.activation
+    quant = ofm.quantization
+    pad_value = quant.zero_point
+    # Add operations that fill the borders of the OFM
+    if top > 0:
+        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_avg_pool_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
+    if bottom > 0:
+        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_bottom",
+            shape.as_list(),
+            ofm.dtype,
+            shape.elements() * [pad_value],
+            np.uint8,
+            quantization=quant,
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_avg_pool_for_concat(
+            op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom)
+        )
+    if left > 0:
+        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
+    if right > 0:
+        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_avg_pool_for_concat(
+            op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right)
+        )
+    ofm.avoid_NHCWB16 = True
+    op.type = Op.ConcatTFLite
+    return avgpool_op
+
+
 def add_attrs_to_resizebilinear(op, arch, nng):
     if op.type == Op.ResizeBilinear and op.run_on_npu:
         input_tensor = op.inputs[0]
@@ -1497,6 +1604,7 @@
         convert_mul_max_to_abs_or_lrelu,
         convert_lrelu,
         convert_tanh_sigmoid_to_lut,
+        replace_pad_by_hw_pad,
     ]
 
     for idx, sg in enumerate(nng.subgraphs):
@@ -1512,7 +1620,7 @@
             sg,
             arch,
             [remove_passthrough_tensor],
-            [fuse_activation_function_with_prev, optimise_pad, add_padding_fields],
+            [fuse_activation_function_with_prev, convert_pad, add_padding_fields],
         )
 
     # Removal of SplitSliceRead, need to be done after optimisation has been performed,
diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py
index 0ce8fac..075574e 100644
--- a/ethosu/vela/high_level_command_stream.py
+++ b/ethosu/vela/high_level_command_stream.py
@@ -39,8 +39,7 @@
         skirt: List[int],
         ifm_shape: Shape4D,
         npu_block_type: NpuBlockType,
-        concat_axis: int = 0,
-        concat_offset: int = 0,
+        concat_offsets: List[int],
         split_offset: Shape4D = None,
         k_height: int = 1,
         upscaling_factor: int = 1,
@@ -48,8 +47,8 @@
         new_start_coord = list(self.start_coord)
         new_end_coord = list(self.end_coord)
 
-        new_start_coord[concat_axis] -= concat_offset
-        new_end_coord[concat_axis] -= concat_offset
+        new_start_coord = np.subtract(new_start_coord, concat_offsets)
+        new_end_coord = np.subtract(new_end_coord, concat_offsets)
 
         if split_offset is not None:
             for idx in range(len(split_offset)):
@@ -170,8 +169,6 @@
         weight_tensor=None,
         weight_box=None,
         scale_tensor=None,
-        concat_axis=0,
-        concat_offset=0,
         ifm2_tensor=None,
         ifm2_box=None,
         pad_top=0,
@@ -192,8 +189,6 @@
         self.weight_tensor = weight_tensor
         self.scale_tensor = scale_tensor
         self.weight_box = weight_box
-        self.concat_axis = concat_axis
-        self.concat_offset = concat_offset
         self.pad_top = pad_top
         self.pad_bottom = pad_bottom
         for i in range(len(self.ofm_box.end_coord)):
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index 1ce7e7e..23d3a4f 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -79,19 +79,14 @@
         elif ps.primary_op.type == Op.ResizeBilinear:
             upscaling = round_up_divide(ofm_shape.height, ifm_shape.height)
 
-    concat_axis = 0
-    concat_offset = 0
+    concat_offset = [0, 0, 0, 0]
 
     for op in ps.ops:
-        if op.attrs.get("concat_axis", None) is not None:
-            concat_axis = op.attrs["concat_axis"]
-            concat_start = op.attrs["concat_start"]
-            concat_end = op.attrs["concat_end"]
-
-            ofm_start[concat_axis] = concat_start
-            ofm_end[concat_axis] = concat_end
-            concat_offset = concat_start
-        elif op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid):
+        if op.write_offset is not None:
+            concat_offset = op.write_offset.as_list()
+            ofm_start = concat_offset
+            ofm_end = (op.write_offset + op.write_shape).as_list()
+        if op.type.is_relu_op() or op.type in (Op.Tanh, Op.Sigmoid):
             ps.primary_op.activation = create_activation_function(op.type)
 
     if strat == SchedulingStrategy.WeightStream:
@@ -109,13 +104,13 @@
 
             if ifm_shape is not None:
                 ifm_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                    strides, skirt, ifm_shape, npu_block_type, concat_axis, concat_offset, split_offsets[0], upscaling,
+                    strides, skirt, ifm_shape, npu_block_type, concat_offset, split_offsets[0], upscaling,
                 )
             else:
                 ifm_box = Box([], [])
             if ifm2_shape is not None:
                 ifm2_box, _, _ = ofm_box.transform_with_strides_and_skirt(
-                    strides, skirt, ifm2_shape, npu_block_type, concat_axis, concat_offset, split_offsets[1], upscaling,
+                    strides, skirt, ifm2_shape, npu_block_type, concat_offset, split_offsets[1], upscaling,
                 )
             else:
                 ifm2_box = Box([], [])
@@ -132,7 +127,6 @@
                             skirt,
                             Shape4D(intermediate.shape),
                             npu_block_type,
-                            concat_axis,
                             concat_offset,
                             split_offsets[0],
                             upscaling,
@@ -143,11 +137,9 @@
 
             weight_box = None
             if weight_tensor is not None:
-                weight_oc_start = start
-                weight_oc_end = end
-                if concat_axis - len(weight_tensor.shape) == -1:
-                    weight_oc_start -= concat_offset
-                    weight_oc_end -= concat_offset
+                weight_offset = concat_offset[len(weight_tensor.shape) - 1]
+                weight_oc_start = start - weight_offset
+                weight_oc_end = end - weight_offset
 
                 weight_box = Box.make_weight_box(
                     weight_tensor.shape,
@@ -172,8 +164,6 @@
                 weight_tensor,
                 weight_box,
                 scale_tensor,
-                concat_axis,
-                concat_offset,
                 ifm2_tensor=ifm2_tensor,
                 ifm2_box=ifm2_box,
             )
@@ -222,15 +212,7 @@
                     k_height = weight_tensor.shape[0]
 
             ifm_box, pad_top, pad_bottom = ofm_box.transform_with_strides_and_skirt(
-                strides,
-                skirt,
-                ifm_shape,
-                npu_block_type,
-                concat_axis,
-                concat_offset,
-                split_offsets[0],
-                k_height,
-                upscaling,
+                strides, skirt, ifm_shape, npu_block_type, concat_offset, split_offsets[0], k_height, upscaling,
             )
 
             ifm_y_needed = 1
@@ -257,7 +239,6 @@
                             skirt,
                             Shape4D(intermediate.shape),
                             npu_block_type,
-                            concat_axis,
                             concat_offset,
                             split_offsets[0],
                             upscaling,
@@ -294,8 +275,6 @@
                 weight_tensor,
                 weight_box,
                 scale_tensor,
-                concat_axis,
-                concat_offset,
                 None,
                 None,
                 pad_top,
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index d2b08b5..a5a58e8 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -424,6 +424,8 @@
         "read_offsets",
         "rounding_mode",
         "low_precision_scaling",
+        "write_offset",
+        "write_shape",
     )
 
     def __init__(self, op_type: Op, name: str):
@@ -438,7 +440,7 @@
         # Fused activation function. If not none: operator code.
         self.activation: Optional[ActivationFunction] = None
         # Fused memory function, if not None: operator code
-        self.memory_function = None
+        self.memory_function: Optional[Op] = None
         # If not none: contains QuantizationParameters to be used as output quantization
         # (which overrides the ofm tensor's quantization), used in LUT
         self.forced_input_quantization = None
@@ -457,6 +459,12 @@
         # The Mean operator (implemented as a depthwise convolution) requires scaling
         # to be calculated differently in one case. In that case, this is set to True.
         self.low_precision_scaling = False
+        # Write offset, for operations that only produce a part of the OFM
+        self.write_offset: Optional[Shape4D] = None
+        # The amount of OFM that is produced by the operation (only if write_offset is not None).
+        # E.g. an operation that only fills the bottom row of an OFM of size 1x10x8x1 would have
+        # write_offset 0,9,0,0, write_shape 1,1,8,1
+        self.write_shape: Optional[Shape4D] = None
 
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 82a7fb8..417f27e 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -40,6 +40,7 @@
     op.attrs["ksize"] = [1, 1, 1, 1]
     op.attrs["skirt"] = [0, 0, 0, 0]
     op.attrs["explicit_padding"] = [0, 0, 0, 0]
+    op.run_on_npu = True
     return op
 
 
@@ -259,3 +260,8 @@
     op.set_output_tensor(ofm)
     op.ofm_shapes.append(ofm_shape)
     return op
+
+
+def get_pad_values_from_input(padding) -> Tuple:
+    """Returns top, left, bottom, right padding from input values in a Pad input tensor"""
+    return (padding[-3][0], padding[-2][0], padding[-3][1], padding[-2][1])
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index a82f812..2319706 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -259,15 +259,11 @@
         self.specific_constraints[Op.FullyConnected].append(SupportedOperators.constraint_keep_dim_ifm_ofm)
 
         # Pad specific checks:
-        self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_matching_in_out_types)
-        self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_matching_quantization_parameters)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_input_count)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_shape)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_padding_dimensions)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_type)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant)
-        self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm)
-        self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_size)
 
         # HardSwish specific checks:
         self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
@@ -830,8 +826,8 @@
 
     @staticmethod
     def constraint_pad_shape(op):
-        "The padding tensor must have the shape [4,2]"
-        valid = op.inputs[1].shape == [4, 2]
+        "The padding tensor must have the shape [3,2] or [4,2]"
+        valid = op.inputs[1].shape in ([3, 2], [4, 2])
         return valid, f"The pad tensor has the shape: {op.inputs[1].shape}"
 
     @classmethod
@@ -846,7 +842,10 @@
     def constraint_padding_dimensions(op):
         "The pad tensor can only pad width and height"
         pad_tensor = op.inputs[1].values
-        valid = sum(pad_tensor[0, :]) + sum(pad_tensor[-1, :]) == 0
+
+        valid = sum(pad_tensor[-1, :]) == 0
+        if valid and len(pad_tensor) > 3:
+            valid = sum(pad_tensor[0, :]) == 0
         return valid, f"First dimension padding: {pad_tensor[0,:]}, last dimension padding: {pad_tensor[-1,:]}"
 
     @staticmethod
@@ -856,65 +855,6 @@
         valid = pad_tensor is not None
         return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"
 
-    @classmethod
-    @docstring_format_args([_optype_formatter(supported_pad_consumers)])
-    def constraint_pad_ofm(cls, op):
-        "Must be followed by one of the following operator types: {}"
-        consumers = op.ofm.consumers()
-        unsupported_consumers = [
-            cons.type
-            for cons in consumers
-            if cons is not None
-            if cons.type not in cls.supported_pad_consumers or cons.attrs["padding"] != Padding.VALID
-        ] + [None for cons in consumers if cons is None]
-        none_string = ", ".join(["NoneType" for cons in consumers if cons is None])
-        valid = len(unsupported_consumers) == 0
-        return valid, f"PAD operator is followed by: {_optype_formatter(unsupported_consumers)+none_string}"
-
-    @staticmethod
-    def __leading_pad_ok(leading_pad, stride, kernel_size):
-        # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
-        # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
-        max_size = kernel_size // 2
-        return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
-
-    @staticmethod
-    def constraint_pad_size(op):
-        "Padding must be at most kernel size divided by 2"
-        if SupportedOperators.constraint_pad_ofm(op)[0]:
-            padding = op.inputs[1].values  # 4x2 tensor, first dimension is N, H, W, C
-            top, left, bottom, right = (padding[1][0], padding[2][0], padding[1][1], padding[2][1])
-            for cons in op.ofm.consumers():
-                if cons is not None:
-                    # Note: pre-order graph traversal removes inputs of operators that are in traversal,
-                    # which makes it impossible to calculate kernel size, hence use cached _kernel for those operators
-                    k = cons.kernel if cons.inputs else cons._kernel
-                    k_w, k_h = k.dilated_wh()
-                    if cons.type.is_avgpool_op():
-                        # For average pool, padding works different on the NPU; more restrictions apply
-                        for name, pad, k_size in (
-                            ("Left", left, k_w),
-                            ("Right", right, k_w),
-                            ("Top", top, k_h),
-                            ("Bottom", bottom, k_h),
-                        ):
-                            if pad not in (0, k_size // 2):
-                                return False, f"{name} padding is {pad}, only 0 or {k_size // 2} are supported"
-                    else:
-                        if left > k_w // 2:
-                            return False, f"Left padding is {left}, kernel width is {k_w}"
-                        if right > k_w // 2:
-                            return False, f"Right padding is {right}, kernel width is {k_w}"
-                        if top > k_h // 2:
-                            return False, f"Top padding is {top}, kernel height is {k_h}"
-                        if bottom > k_h // 2:
-                            return False, f"Bottom padding is {bottom}, kernel height is {k_h}"
-                        if not SupportedOperators.__leading_pad_ok(top, k.stride.y, k_h):
-                            return False, f"Top padding is {top}, must be {k_h // 2} or multiple of {k.stride.y}"
-                        if not SupportedOperators.__leading_pad_ok(left, k.stride.x, k_w):
-                            return False, f"Left padding is {left}, must be {k_w // 2} or multiple of {k.stride.x}"
-        return True, "Pad size is ok"
-
     @staticmethod
     def constraint_stridedslice_inputs_const(op):
         "Begin, End and Stride Input tensors must be constant"
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 285b3ac..d9e171d 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -23,7 +23,7 @@
 from ethosu.vela.graph_optimiser import calc_explicit_padding
 from ethosu.vela.graph_optimiser import convert_batched_fc_shape
 from ethosu.vela.graph_optimiser import optimise_graph_a
-from ethosu.vela.graph_optimiser import optimise_pad
+from ethosu.vela.graph_optimiser import replace_pad_by_hw_pad
 from ethosu.vela.graph_optimiser import rewrite_fully_connected_input
 from ethosu.vela.nn_graph import Graph
 from ethosu.vela.operation import Op
@@ -116,47 +116,92 @@
     assert (before, after) == expected_result
 
 
-def test_optimise_pad():
+def create_pad_and_conv2d(
+    in_shape,
+    out_shape,
+    padding,
+    in_dtype=DataType.int8,
+    out_dtype=DataType.int8,
+    pad_dtype=DataType.int32,
+    pad_setting=Padding.VALID,
+    kernel_size=3,
+):
+    """Creates Pad operator followed by a conv2d operator"""
+    qp = testutil.default_quant_params()
+    in0 = Tensor(in_shape, in_dtype, "in")
+    in0.quantization = qp
+    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+    out = Tensor(out_shape, out_dtype, "out")
+    out.quantization = qp.clone()
+    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+    op.run_on_npu = True
+    conv_out_tens = Tensor(in_shape, in_dtype, "output")
+    conv_out_tens.quantization = qp.clone()
+    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
+    weight_tens.values = np.zeros(weight_tens.shape)
+    weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
+    weight_tens.quantization = qp.clone()
+    bias_tens = Tensor(out_shape, pad_dtype, "biases")
+    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
+    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
+    conv2d_op.add_input_tensor(out)
+    conv2d_op.run_on_npu = True
+    return op, conv2d_op
+
+
+def test_pad_followed_by_conv_is_removed():
     """
     Tests that the PAD operator is bypassed when followed by a convolution operator,
     and that the padding of the convolution operation is correctly updated
     """
-    # Create Pad operation followed by Conv2D
-    quant = testutil.default_quant_params()
-    in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
-    in_tens.quantization = quant
-    pad_input = create_const_tensor("pad_input", [4, 2], DataType.int32, [[0, 0], [2, 1], [1, 1], [0, 0]])
-    temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
-    temp_tens.quantization = quant.clone()
-    out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
-    out_tens.quantization = quant.clone()
-    weight_tens = Tensor([5, 3, 64, 64], DataType.uint8, "weights")
-    weight_tens.values = np.zeros(weight_tens.shape)
-    weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
-    weight_tens.quantization = quant.clone()
-
-    bias_tens = Tensor([64], DataType.int32, "biases")
-    pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
-    attrs = {"padding": Padding.VALID, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
-    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
-    pad_op.run_on_npu = True
-    conv2d_op = testutil.create_op(Op.Conv2D, [temp_tens, weight_tens, bias_tens], out_tens, attrs)
-    conv2d_op.run_on_npu = True
-    nng = Graph()
-    sg = testutil.create_subgraph([pad_op, conv2d_op])
-    nng.subgraphs.append(sg)
+    pad_op, conv2d_op = create_pad_and_conv2d(
+        in_shape=[1, 76, 75, 64], out_shape=[1, 76, 75, 64], padding=[[0, 0], [2, 1], [1, 1], [0, 0]], kernel_size=4
+    )
+    nng = testutil.create_graph([pad_op, conv2d_op])
     arch = testutil.create_arch()
 
-    optimise_pad(conv2d_op, nng, arch)
+    replace_pad_by_hw_pad(conv2d_op, nng, arch)
 
-    op = sg.output_tensors[0].ops[0]
-    assert op.type == Op.Conv2D
+    op = nng.subgraphs[0].output_tensors[0].ops[0]
+    assert op.type == Op.Conv2DBias
     assert op.attrs["padding"] == Padding.EXPLICIT
     assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
     assert op.ifm.shape == [1, 76, 75, 64]
     assert pad_op not in op.ifm.ops
 
 
+leading_pad_test_data = [
+    (2, 2, 11, True),
+    (1, 2, 11, False),
+    (2, 1, 11, False),
+    (5, 2, 11, True),
+]
+
+
+@pytest.mark.parametrize("top, left, kernel_size, expect_pad_removed", leading_pad_test_data)
+def test_leading_pad_size(top, left, kernel_size, expect_pad_removed):
+    # Tests PAD operator with big kernel size; top and left pad must be multiple of stride
+    out_shape = [1, 11 + left, 11 + top, 1]
+    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
+    pad_op, conv2d_op = create_pad_and_conv2d(
+        in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size
+    )
+    nng = testutil.create_graph([pad_op, conv2d_op])
+    arch = testutil.create_arch()
+    replace_pad_by_hw_pad(conv2d_op, nng, arch)
+    op = nng.subgraphs[0].output_tensors[0].ops[0]
+    if expect_pad_removed:
+        assert op.attrs["padding"] == Padding.EXPLICIT
+        assert "explicit_padding" in op.attrs
+        assert op.ifm.shape == op.ofm.shape
+        assert pad_op not in op.ifm.ops
+    else:
+        assert pad_op in op.ifm.ops
+        assert op.attrs["padding"] == Padding.VALID
+        assert "explicit_padding" not in op.attrs
+
+
 def test_optimise_pad_followed_by_avg_pool():
     """
     Tests that the PAD operator is bypassed when followed by a average pool operator,
@@ -166,7 +211,8 @@
     quant = testutil.default_quant_params()
     in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
     in_tens.quantization = quant
-    pad_input = create_const_tensor("pad_input", [4, 2], DataType.int32, [[0, 0], [2, 1], [1, 1], [0, 0]])
+    # Test with 3x2 input tensor
+    pad_input = create_const_tensor("pad_input", [3, 2], DataType.int32, [[2, 2], [1, 1], [0, 0]])
     temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
     temp_tens.quantization = quant.clone()
     out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
@@ -185,25 +231,99 @@
     pad_op.run_on_npu = True
     conv2d_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
     conv2d_op.run_on_npu = True
-    nng = Graph()
-    sg = testutil.create_subgraph([pad_op, conv2d_op])
-    nng.subgraphs.append(sg)
+    nng = testutil.create_graph([pad_op, conv2d_op])
     arch = testutil.create_arch()
 
-    optimise_pad(conv2d_op, nng, arch)
+    replace_pad_by_hw_pad(conv2d_op, nng, arch)
 
-    op = sg.output_tensors[0].ops[0]
+    op = nng.subgraphs[0].output_tensors[0].ops[0]
     assert op.type == Op.DepthwiseConv2DBias
     assert op.attrs["padding"] == Padding.EXPLICIT
-    assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
+    assert op.attrs["explicit_padding"] == (2, 1, 2, 1)
     assert op.ifm.shape == [1, 76, 75, 64]
     assert pad_op not in op.ifm.ops
     # Check that bias and weight tensors have been added
     assert op.bias.shape == [64]
-    print("op.weights:", op.weights)
     assert op.weights.shape == [5, 3, 1, 64]
 
 
+pad_avg_pool_test_data = [
+    ((3, 3), (1, 1, 1, 1), True),
+    ((3, 3), (2, 1, 1, 1), False),
+    ((3, 3), (1, 2, 1, 1), False),
+    ((3, 3), (1, 1, 2, 1), False),
+    ((3, 3), (1, 1, 1, 2), False),
+    ((2, 4), (1, 2, 1, 2), True),
+    ((5, 3), (2, 1, 2, 1), True),
+    ((5, 3), (0, 1, 2, 1), True),
+    ((5, 3), (2, 0, 2, 1), True),
+    ((5, 3), (2, 1, 0, 1), True),
+    ((5, 3), (2, 1, 0, 1), True),
+    ((4, 4), (2, 2, 2, 2), True),
+    ((4, 4), (1, 2, 2, 2), False),
+    ((4, 4), (2, 1, 2, 2), False),
+    ((4, 4), (2, 2, 1, 2), False),
+    ((4, 4), (2, 2, 2, 1), False),
+]
+
+
+@pytest.mark.parametrize("k_size, padding, expect_pad_removed", pad_avg_pool_test_data)
+def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
+    # Tests PAD followed by AvgPool
+    k_w, k_h = k_size
+    top, left, bottom, right = padding
+    pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
+    dtype = DataType.int8
+    qp = testutil.default_quant_params()
+    in_shape = [1, 15, 17, 8]
+    out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
+    in0 = Tensor(in_shape, dtype, "in")
+    in0.quantization = qp
+    pad_tensor = create_const_tensor(
+        name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
+    )
+    out = Tensor(out_shape, dtype, "out")
+    out.quantization = qp.clone()
+    pad_op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+    pool_out_tens = Tensor(in_shape, dtype, "output")
+    pool_out_tens.quantization = qp.clone()
+    attrs = {
+        "padding": Padding.VALID,
+        "ksize": [1, k_w, k_h, 1],
+        "stride_w": 1,
+        "stride_h": 1,
+        "dilation_w_factor": 1,
+        "dilation_h_factor": 1,
+    }
+    pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
+    pool_op.add_input_tensor(out)
+    pad_op.run_on_npu = True
+    pool_op.run_on_npu = True
+    nng = testutil.create_graph([pad_op, pool_op])
+    arch = testutil.create_arch()
+    nng = optimise_graph_a(nng, arch)
+    sg = nng.subgraphs[0]
+    all_ops = sg.get_all_ops()
+    print("all_ops: ", all_ops)
+    # Pad should not be in the graph anymore, it should either have been removed or rewritten
+    assert not any(op.type == Op.Pad for op in all_ops)
+    op = nng.subgraphs[0].output_tensors[0].ops[0]
+    if expect_pad_removed:
+        # Expect rewrite to depthwise, PAD is removed
+        assert op.type == Op.DepthwiseConv2DBias
+        assert op.attrs["padding"] == Padding.EXPLICIT
+        assert any(pad > 0 for pad in op.attrs["explicit_padding"])
+        assert op.ifm.shape == op.ofm.shape
+        # Check that bias and weight tensors have been added
+        assert len(op.bias.shape) > 0
+        assert op.weights.shape is not None
+    else:
+        # Pad should have been rewritten to a number of average pool operations
+        assert all(op.type in (Op.AvgPool, Op.Const) for op in all_ops)
+        assert pool_op.type == Op.AvgPool
+        assert pool_op.attrs["padding"] == Padding.VALID
+
+
 def test_remove_reshape():
     """
     Tests that the expected reshape are removed in graph_optimisation
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index cd331fd..34ddb90 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -17,7 +17,6 @@
 # Description:
 # Unit tests for support_operators
 import numpy as np
-import pytest
 
 from ethosu.vela.data_type import DataType
 from ethosu.vela.operation import ActivationFunction
@@ -529,14 +528,7 @@
 
 
 def create_pad_op(
-    in_shape,
-    out_shape,
-    padding,
-    in_dtype=DataType.int8,
-    out_dtype=DataType.int8,
-    pad_dtype=DataType.int32,
-    pad_setting=Padding.VALID,
-    kernel_size=3,
+    in_shape, out_shape, padding, in_dtype=DataType.int8, out_dtype=DataType.int8, pad_dtype=DataType.int32,
 ):
     qp = testutil.default_quant_params()
     in0 = Tensor(in_shape, in_dtype, "in")
@@ -545,17 +537,6 @@
     out = Tensor(out_shape, out_dtype, "out")
     out.quantization = qp.clone()
     op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
-    conv_out_tens = Tensor(in_shape, in_dtype, "output")
-    conv_out_tens.quantization = qp.clone()
-    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
-    weight_tens.values = np.zeros(weight_tens.shape)
-    weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
-    weight_tens.quantization = qp.clone()
-    bias_tens = Tensor(out_shape, pad_dtype, "biases")
-    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
-    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
-    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
-    conv2d_op.add_input_tensor(out)
     return op
 
 
@@ -571,10 +552,16 @@
     # Incorrect padding dimensions, can only pad width and height
     op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [1, 1], [0, 0]],)
     assert not support.is_operator_supported(op)
+    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 0]],)
+    assert support.is_operator_supported(op)
+    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 1]],)
+    assert not support.is_operator_supported(op)
 
 
 def test_constraint_pad_shape():
-    # PAD operator must be of shape (4,2)
+    # PAD operator must be of shape (3,2) or (4,2)
+    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[1, 1], [1, 1], [0, 0]])
+    assert support.is_operator_supported(op)
     op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0], [0, 0]],)
     assert not support.is_operator_supported(op)
 
@@ -595,108 +582,6 @@
     assert not support.is_operator_supported(op)
 
 
-def test_constraint_pad_consumer():
-    # PAD operator must be followed by a valid consumer with Padding.VALID attribute
-    op = create_pad_op(in_shape=[1, 1, 1, 1], out_shape=[1, 3, 3, 1], padding=[[0, 0], [1, 1], [1, 1], [0, 0]],)
-    assert support.is_operator_supported(op)
-    op = create_pad_op(
-        in_shape=[1, 1, 1, 1],
-        out_shape=[1, 3, 3, 1],
-        padding=[[0, 0], [1, 1], [1, 1], [0, 0]],
-        pad_setting=Padding.SAME,
-    )
-    assert not support.is_operator_supported(op)
-    op_consumer = testutil.create_op_with_quant_tensors(Op.ConcatTFLite, [1, 1, 1, 4], [1, 1, 1, 8])
-    op.ofm.consumer_list = [op_consumer]
-    assert not support.is_operator_supported(op)
-    op_consumer = testutil.create_elemwise_op(Op.Add, "op", [1, 3, 3, 1], [1, 3, 3, 1], [1, 3, 3, 1])
-    op.ofm.consumer_list = [op_consumer]
-    assert not support.is_operator_supported(op)
-
-
-pad_invalid_size_test_data = [
-    (2, 1, 1, 1),
-    (1, 2, 1, 1),
-    (1, 1, 2, 1),
-    (1, 1, 1, 2),
-]
-
-
-@pytest.mark.parametrize("top, left, bottom, right", pad_invalid_size_test_data)
-def test_constraint_pad_size(top, left, bottom, right):
-    # Tests PAD operator with a padding that is too high to be handled by the NPU
-    out_shape = [1, 11 + left + right, 11 + top + bottom, 1]
-    padding = [[0, 0], [top, bottom], [left, right], [0, 0]]
-    op = create_pad_op(in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding,)
-    assert not support.is_operator_supported(op)
-
-
-leading_pad_test_data = [
-    (2, 2, 11, True),
-    (1, 2, 11, False),
-    (2, 1, 11, False),
-    (5, 2, 11, True),
-]
-
-
-@pytest.mark.parametrize("top, left, kernel_size, expected", leading_pad_test_data)
-def test_constraint_leading_pad_size(top, left, kernel_size, expected):
-    # Tests PAD operator with big kernel size; top and left pad must be multiple of stride
-    out_shape = [1, 11 + left, 11 + top, 1]
-    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
-    op = create_pad_op(in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size)
-    assert support.is_operator_supported(op) == expected
-
-
-pad_avg_pool_test_data = [
-    ((3, 3), (1, 1, 1, 1), True),
-    ((2, 4), (1, 2, 1, 2), True),
-    ((5, 3), (2, 1, 2, 1), True),
-    ((5, 3), (0, 1, 2, 1), True),
-    ((5, 3), (2, 0, 2, 1), True),
-    ((5, 3), (2, 1, 0, 1), True),
-    ((5, 3), (2, 1, 0, 1), True),
-    ((4, 4), (2, 2, 2, 2), True),
-    ((4, 4), (1, 2, 2, 2), False),
-    ((4, 4), (2, 1, 2, 2), False),
-    ((4, 4), (2, 2, 1, 2), False),
-    ((4, 4), (2, 2, 2, 1), False),
-]
-
-
-@pytest.mark.parametrize("k_size, padding, expected", pad_avg_pool_test_data)
-def test_pad_followed_by_avg_pool(k_size, padding, expected):
-    # Tests PAD followed by AvgPool
-    k_w, k_h = k_size
-    top, left, bottom, right = padding
-    pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
-    dtype = DataType.int8
-    qp = testutil.default_quant_params()
-    in_shape = [1, 15, 17, 8]
-    out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
-    in0 = Tensor(in_shape, dtype, "in")
-    in0.quantization = qp
-    pad_tensor = create_const_tensor(
-        name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
-    )
-    out = Tensor(out_shape, dtype, "out")
-    out.quantization = qp.clone()
-    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
-    pool_out_tens = Tensor(in_shape, dtype, "output")
-    pool_out_tens.quantization = qp.clone()
-    attrs = {
-        "padding": Padding.VALID,
-        "ksize": [1, k_w, k_h, 1],
-        "stride_w": 1,
-        "stride_h": 1,
-        "dilation_w_factor": 1,
-        "dilation_h_factor": 1,
-    }
-    pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
-    pool_op.add_input_tensor(out)
-    assert support.is_operator_supported(op) == expected
-
-
 def create_strided_slice():
     # Creates a valid strided slice operator with some valid inputs/outputs
     op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index aef5f61..25dc801 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -19,6 +19,8 @@
 
 from ethosu.vela import architecture_features
 from ethosu.vela.data_type import DataType
+from ethosu.vela.nn_graph import Graph
+from ethosu.vela.nn_graph import PassPlacement
 from ethosu.vela.nn_graph import Subgraph
 from ethosu.vela.operation import Op
 from ethosu.vela.operation import Operation
@@ -128,6 +130,7 @@
 def create_subgraph(op_list):
     # Creates subgraph using the given list of operations
     sg = Subgraph()
+    sg.placement = PassPlacement.Npu
     all_inputs = set(tens for op in op_list for tens in op.inputs)
     # Reversing, so that the resulting subgraph has same order as op_list
     for op in op_list[::-1]:
@@ -135,3 +138,11 @@
             if tens not in all_inputs and tens not in sg.output_tensors:
                 sg.output_tensors.append(tens)
     return sg
+
+
+def create_graph(op_list):
+    # Creates a Graph with a single subgraph containing the given list of operations
+    nng = Graph()
+    sg = create_subgraph(op_list)
+    nng.subgraphs.append(sg)
+    return nng