MLBEDSW-4223: Full support for PAD operator

- Added full support for PAD operator
- Hardware padding is still used whenever possible
- Fixed bug in Pad followed by max pool when the IFM contains negative values

Change-Id: Ifc64d1943737d94466f5e2821009dab12a49a965
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 285b3ac..d9e171d 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -23,7 +23,7 @@
 from ethosu.vela.graph_optimiser import calc_explicit_padding
 from ethosu.vela.graph_optimiser import convert_batched_fc_shape
 from ethosu.vela.graph_optimiser import optimise_graph_a
-from ethosu.vela.graph_optimiser import optimise_pad
+from ethosu.vela.graph_optimiser import replace_pad_by_hw_pad
 from ethosu.vela.graph_optimiser import rewrite_fully_connected_input
 from ethosu.vela.nn_graph import Graph
 from ethosu.vela.operation import Op
@@ -116,47 +116,92 @@
     assert (before, after) == expected_result
 
 
-def test_optimise_pad():
+def create_pad_and_conv2d(
+    in_shape,
+    out_shape,
+    padding,
+    in_dtype=DataType.int8,
+    out_dtype=DataType.int8,
+    pad_dtype=DataType.int32,
+    pad_setting=Padding.VALID,
+    kernel_size=3,
+):
+    """Creates Pad operator followed by a conv2d operator"""
+    qp = testutil.default_quant_params()
+    in0 = Tensor(in_shape, in_dtype, "in")
+    in0.quantization = qp
+    pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+    out = Tensor(out_shape, out_dtype, "out")
+    out.quantization = qp.clone()
+    op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+    op.run_on_npu = True
+    conv_out_tens = Tensor(in_shape, in_dtype, "output")
+    conv_out_tens.quantization = qp.clone()
+    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
+    weight_tens.values = np.zeros(weight_tens.shape)
+    weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
+    weight_tens.quantization = qp.clone()
+    bias_tens = Tensor(out_shape, pad_dtype, "biases")
+    attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
+    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
+    conv2d_op = testutil.create_op(Op.Conv2DBias, [out, weight_tens, bias_tens], conv_out_tens, attrs)
+    conv2d_op.add_input_tensor(out)
+    conv2d_op.run_on_npu = True
+    return op, conv2d_op
+
+
+def test_pad_followed_by_conv_is_removed():
     """
     Tests that the PAD operator is bypassed when followed by a convolution operator,
     and that the padding of the convolution operation is correctly updated
     """
-    # Create Pad operation followed by Conv2D
-    quant = testutil.default_quant_params()
-    in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
-    in_tens.quantization = quant
-    pad_input = create_const_tensor("pad_input", [4, 2], DataType.int32, [[0, 0], [2, 1], [1, 1], [0, 0]])
-    temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
-    temp_tens.quantization = quant.clone()
-    out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
-    out_tens.quantization = quant.clone()
-    weight_tens = Tensor([5, 3, 64, 64], DataType.uint8, "weights")
-    weight_tens.values = np.zeros(weight_tens.shape)
-    weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
-    weight_tens.quantization = quant.clone()
-
-    bias_tens = Tensor([64], DataType.int32, "biases")
-    pad_op = testutil.create_op(Op.Pad, [in_tens, pad_input], temp_tens)
-    attrs = {"padding": Padding.VALID, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
-    attrs["strides"] = (1, attrs["stride_h"], attrs["stride_w"], 1)
-    pad_op.run_on_npu = True
-    conv2d_op = testutil.create_op(Op.Conv2D, [temp_tens, weight_tens, bias_tens], out_tens, attrs)
-    conv2d_op.run_on_npu = True
-    nng = Graph()
-    sg = testutil.create_subgraph([pad_op, conv2d_op])
-    nng.subgraphs.append(sg)
+    pad_op, conv2d_op = create_pad_and_conv2d(
+        in_shape=[1, 76, 75, 64], out_shape=[1, 76, 75, 64], padding=[[0, 0], [2, 1], [1, 1], [0, 0]], kernel_size=4
+    )
+    nng = testutil.create_graph([pad_op, conv2d_op])
     arch = testutil.create_arch()
 
-    optimise_pad(conv2d_op, nng, arch)
+    replace_pad_by_hw_pad(conv2d_op, nng, arch)
 
-    op = sg.output_tensors[0].ops[0]
-    assert op.type == Op.Conv2D
+    op = nng.subgraphs[0].output_tensors[0].ops[0]
+    assert op.type == Op.Conv2DBias
     assert op.attrs["padding"] == Padding.EXPLICIT
     assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
     assert op.ifm.shape == [1, 76, 75, 64]
     assert pad_op not in op.ifm.ops
 
 
+leading_pad_test_data = [
+    (2, 2, 11, True),
+    (1, 2, 11, False),
+    (2, 1, 11, False),
+    (5, 2, 11, True),
+]
+
+
+@pytest.mark.parametrize("top, left, kernel_size, expect_pad_removed", leading_pad_test_data)
+def test_leading_pad_size(top, left, kernel_size, expect_pad_removed):
+    # Tests PAD operator with big kernel size; top and left pad must be a multiple of the stride
+    out_shape = [1, 11 + top, 11 + left, 1]
+    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
+    pad_op, conv2d_op = create_pad_and_conv2d(
+        in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size
+    )
+    nng = testutil.create_graph([pad_op, conv2d_op])
+    arch = testutil.create_arch()
+    replace_pad_by_hw_pad(conv2d_op, nng, arch)
+    op = nng.subgraphs[0].output_tensors[0].ops[0]
+    if expect_pad_removed:
+        assert op.attrs["padding"] == Padding.EXPLICIT
+        assert "explicit_padding" in op.attrs
+        assert op.ifm.shape == op.ofm.shape
+        assert pad_op not in op.ifm.ops
+    else:
+        assert pad_op in op.ifm.ops
+        assert op.attrs["padding"] == Padding.VALID
+        assert "explicit_padding" not in op.attrs
+
+
 def test_optimise_pad_followed_by_avg_pool():
     """
     Tests that the PAD operator is bypassed when followed by a average pool operator,
@@ -166,7 +211,8 @@
     quant = testutil.default_quant_params()
     in_tens = Tensor([1, 76, 75, 64], DataType.uint8, "input")
     in_tens.quantization = quant
-    pad_input = create_const_tensor("pad_input", [4, 2], DataType.int32, [[0, 0], [2, 1], [1, 1], [0, 0]])
+    # Test with a 3x2 padding tensor (no entry for the batch dimension)
+    pad_input = create_const_tensor("pad_input", [3, 2], DataType.int32, [[2, 2], [1, 1], [0, 0]])
     temp_tens = Tensor([1, 79, 77, 64], DataType.uint8, "pad_out")
     temp_tens.quantization = quant.clone()
     out_tens = Tensor([1, 76, 75, 64], DataType.uint8, "output")
@@ -185,25 +231,99 @@
     pad_op.run_on_npu = True
     conv2d_op = testutil.create_op(Op.AvgPool, [temp_tens], out_tens, attrs)
     conv2d_op.run_on_npu = True
-    nng = Graph()
-    sg = testutil.create_subgraph([pad_op, conv2d_op])
-    nng.subgraphs.append(sg)
+    nng = testutil.create_graph([pad_op, conv2d_op])
     arch = testutil.create_arch()
 
-    optimise_pad(conv2d_op, nng, arch)
+    replace_pad_by_hw_pad(conv2d_op, nng, arch)
 
-    op = sg.output_tensors[0].ops[0]
+    op = nng.subgraphs[0].output_tensors[0].ops[0]
     assert op.type == Op.DepthwiseConv2DBias
     assert op.attrs["padding"] == Padding.EXPLICIT
-    assert op.attrs["explicit_padding"] == (2, 1, 1, 1)
+    assert op.attrs["explicit_padding"] == (2, 1, 2, 1)
     assert op.ifm.shape == [1, 76, 75, 64]
     assert pad_op not in op.ifm.ops
     # Check that bias and weight tensors have been added
     assert op.bias.shape == [64]
-    print("op.weights:", op.weights)
     assert op.weights.shape == [5, 3, 1, 64]
 
 
+pad_avg_pool_test_data = [
+    ((3, 3), (1, 1, 1, 1), True),
+    ((3, 3), (2, 1, 1, 1), False),
+    ((3, 3), (1, 2, 1, 1), False),
+    ((3, 3), (1, 1, 2, 1), False),
+    ((3, 3), (1, 1, 1, 2), False),
+    ((2, 4), (1, 2, 1, 2), True),
+    ((5, 3), (2, 1, 2, 1), True),
+    ((5, 3), (0, 1, 2, 1), True),
+    ((5, 3), (2, 0, 2, 1), True),
+    ((5, 3), (2, 1, 0, 1), True),
+    ((5, 3), (2, 1, 0, 1), True),
+    ((4, 4), (2, 2, 2, 2), True),
+    ((4, 4), (1, 2, 2, 2), False),
+    ((4, 4), (2, 1, 2, 2), False),
+    ((4, 4), (2, 2, 1, 2), False),
+    ((4, 4), (2, 2, 2, 1), False),
+]
+
+
+@pytest.mark.parametrize("k_size, padding, expect_pad_removed", pad_avg_pool_test_data)
+def test_pad_followed_by_avg_pool(k_size, padding, expect_pad_removed):
+    # Tests PAD followed by AvgPool
+    k_w, k_h = k_size
+    top, left, bottom, right = padding
+    pad_values = [[0, 0], [top, bottom], [left, right], [0, 0]]
+    dtype = DataType.int8
+    qp = testutil.default_quant_params()
+    in_shape = [1, 15, 17, 8]
+    out_shape = [1, in_shape[1] + top + bottom, in_shape[2] + left + right, in_shape[3]]
+    in0 = Tensor(in_shape, dtype, "in")
+    in0.quantization = qp
+    pad_tensor = create_const_tensor(
+        name="pad", shape=list(np.shape(pad_values)), values=pad_values, dtype=DataType.int32
+    )
+    out = Tensor(out_shape, dtype, "out")
+    out.quantization = qp.clone()
+    pad_op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
+    pool_out_tens = Tensor(in_shape, dtype, "output")
+    pool_out_tens.quantization = qp.clone()
+    attrs = {
+        "padding": Padding.VALID,
+        "ksize": [1, k_w, k_h, 1],
+        "stride_w": 1,
+        "stride_h": 1,
+        "dilation_w_factor": 1,
+        "dilation_h_factor": 1,
+    }
+    pool_op = testutil.create_op(Op.AvgPool, [out], pool_out_tens, attrs)
+    pool_op.add_input_tensor(out)
+    pad_op.run_on_npu = True
+    pool_op.run_on_npu = True
+    nng = testutil.create_graph([pad_op, pool_op])
+    arch = testutil.create_arch()
+    nng = optimise_graph_a(nng, arch)
+    sg = nng.subgraphs[0]
+    all_ops = sg.get_all_ops()
+    print("all_ops: ", all_ops)
+    # Pad should not be in the graph anymore, it should either have been removed or rewritten
+    assert not any(op.type == Op.Pad for op in all_ops)
+    op = nng.subgraphs[0].output_tensors[0].ops[0]
+    if expect_pad_removed:
+        # Expect rewrite to depthwise, PAD is removed
+        assert op.type == Op.DepthwiseConv2DBias
+        assert op.attrs["padding"] == Padding.EXPLICIT
+        assert any(pad > 0 for pad in op.attrs["explicit_padding"])
+        assert op.ifm.shape == op.ofm.shape
+        # Check that bias and weight tensors have been added
+        assert len(op.bias.shape) > 0
+        assert op.weights.shape is not None
+    else:
+        # Pad should have been rewritten to a number of average pool operations
+        assert all(op.type in (Op.AvgPool, Op.Const) for op in all_ops)
+        assert pool_op.type == Op.AvgPool
+        assert pool_op.attrs["padding"] == Padding.VALID
+
+
 def test_remove_reshape():
     """
     Tests that the expected reshape are removed in graph_optimisation