MLBEDSW-3499: Support for PAD operator

Replaces the PAD operator by hardware padding when possible.

Change-Id: I9dce0885e51a4a73715824d7368637222e39b2b3
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 3759d3b..00edf83 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -156,7 +156,7 @@
     return total_padding
 
 
-def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims):
+def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims, explicit_padding):
     ypad = needed_total_padding(int(input_dims[1]), int(stride[1]), int(kernel_size[0]))
     xpad = needed_total_padding(int(input_dims[2]), int(stride[2]), int(kernel_size[1]))
     if padding_type == Padding.SAME:
@@ -169,6 +169,12 @@
         right_pad = 0
         top_pad = 0
         bottom_pad = 0
+    elif padding_type == Padding.EXPLICIT:
+        # Padding is specified in a PAD operator which has been bypassed.
+        # The top and left padding are taken from the PAD; bottom and right are calculated.
+        top_pad, left_pad, _, _ = explicit_padding
+        bottom_pad = ypad - top_pad
+        right_pad = xpad - left_pad
     else:
         raise UnsupportedFeatureError(f"Unknown padding")
     padding = (top_pad, left_pad, bottom_pad, right_pad)
@@ -537,7 +543,11 @@
                 dilation_h, dilation_w = op.get_dilation_h_w()
                 dilated_kernel_size = [dilation_h * (kernel_size[0] - 1) + 1, dilation_w * (kernel_size[1] - 1) + 1]
                 padding, skirt = calc_padding_and_skirt(
-                    op.attrs["padding"], dilated_kernel_size, op.attrs["strides"], input_shape
+                    op.attrs["padding"],
+                    dilated_kernel_size,
+                    op.attrs["strides"],
+                    input_shape,
+                    op.attrs.get("explicit_padding"),
                 )
 
             op.attrs["explicit_padding"] = padding
@@ -1122,6 +1132,30 @@
     return op
 
 
+def optimise_pad(op, arch, nng):
+    """
+    Converts tens1 -> PAD -> tens2 -> CONV to tens1 -> CONV
+    if both operations can be run on the NPU.
+    """
+    if (
+        (op.type.is_conv2d_op() or op.type.is_depthwise_conv2d_op())
+        and op.run_on_npu
+        and op.attrs["padding"] == Padding.VALID
+    ):
+        pad_op = op.ifm.ops[0]
+        if pad_op.type != Op.Pad or not pad_op.run_on_npu:
+            return op
+        # Bypass the PAD operator
+        op.set_input_tensor(pad_op.ifm, 0)
+        # Adjust the padding attributes of the convolution operator
+        op.attrs["padding"] = Padding.EXPLICIT
+        padding = pad_op.inputs[1].values  # 4x2 tensor, first dimension is N, H, W, C
+        top, left, bottom, right = (padding[1][0], padding[2][0], padding[1][1], padding[2][1])
+        op.attrs["explicit_padding"] = (top, left, bottom, right)
+        op.set_ifm_ofm_shapes()
+    return op
+
+
 def add_attrs_to_resizebilinear(op, arch, nng):
     if op.type == Op.ResizeBilinear and op.run_on_npu:
         input_tensor = op.inputs[0]
@@ -1213,7 +1247,11 @@
     for idx, sg in enumerate(nng.subgraphs):
         # remove passthrough tensors and attempt further optimizations
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields],
+            nng,
+            sg,
+            arch,
+            [remove_passthrough_tensor],
+            [fuse_activation_function_with_prev, optimise_pad, add_padding_fields],
         )
 
     # Post-optimisation operator debug tracing
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index c80e18b..af36587 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -201,7 +201,7 @@
     OneHot = OperatorInfo()
     Pack = OperatorInfo(indices=IFM_INDICES)
     PackReshaped = OperatorInfo(indices=IFM_INDICES)
-    Pad = OperatorInfo()
+    Pad = OperatorInfo(indices=IFM_INDICES)
     PadV2 = OperatorInfo()
     Placeholder = OperatorInfo()  # Only used in CPU subgraphs
     Pow = OperatorInfo()
@@ -335,6 +335,7 @@
 class Padding(Enum):
     SAME = 0
     VALID = 1
+    EXPLICIT = 2  # Padding is specified in a PAD operation (only used for NPU operations)
 
 
 class ActivationFunction:
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 7fdc4bd..b3938bc 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -18,8 +18,12 @@
 # Unit tests for graph_optimiser
 import numpy as np
 
+from ethosu.vela.data_type import DataType
 from ethosu.vela.graph_optimiser import convert_batched_fc_shape
+from ethosu.vela.graph_optimiser import optimise_pad
+from ethosu.vela.nn_graph import Graph
 from ethosu.vela.operation import Op
+from ethosu.vela.operation import Padding
 from ethosu.vela.tensor import create_const_tensor
 from ethosu.vela.tensor import Shape4D
 from ethosu.vela.tensor import Tensor
@@ -73,3 +77,44 @@
     assert conv_op.type == Op.FullyConnected
     assert len(conv_op.ifm.shape) == 2
     assert conv_op.ifm.shape == conv_op.ofm.shape
+
+
+def test_optimise_pad():
+    """
+    Tests that the PAD operator is bypassed when followed by a convolution operator,
+    and that the padding of the convolution operation is correctly updated
+    """
+    # Create Pad operation followed by Conv2D
+    quant = testutil.default_quant_params()
+    in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
+    in_tens.quantization = quant
+    pad_input = create_const_tensor("pad_input", [4, 2], DataType.int32, [[0, 0], [2, 1], [1, 1], [0, 0]])
+    temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
+    temp_tens.quantization = quant.clone()
+    out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
+    out_tens.quantization = quant.clone()
+    weight_tens = Tensor([5, 3, 64, 64], DataType.uint8, "weights")
+    weight_tens.values = np.zeros(weight_tens.shape)
+    weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
+    weight_tens.quantization = quant.clone()
+
+    bias_tens = Tensor([64], DataType.int32, "biases")
+    pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
+    attrs = {"padding": Padding.VALID, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
+    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+    pad_op.run_on_npu = True
+    conv2d_op = testutil.create_op(Op.Conv2D, [temp_tens, weight_tens, bias_tens], out_tens, attrs)
+    conv2d_op.run_on_npu = True
+    nng = Graph()
+    sg = testutil.create_subgraph([pad_op, conv2d_op])
+    nng.subgraphs.append(sg)
+    arch = testutil.create_arch()
+
+    optimise_pad(conv2d_op, nng, arch)
+
+    op = sg.output_tensors[0].ops[0]
+    assert op.type == Op.Conv2D
+    assert op.attrs["padding"] == Padding.EXPLICIT
+    assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
+    assert op.ifm.shape == [1, 76, 75, 64]
+    assert pad_op not in op.ifm.ops
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index c345950..96aeb7e 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -115,8 +115,9 @@
 
 def create_op(op_type, inputs, output, attrs=None):
     op = Operation(op_type, output.name + "_op")
-    op.inputs = inputs
-    op.outputs = [output]
+    for input in inputs:
+        op.add_input_tensor(input)
+    op.set_output_tensor(output)
     if attrs is not None:
         op.attrs = attrs
     op.set_ifm_ofm_shapes()