MLBEDSW-3903: Bug fix PAD operator

- Added checks for unsupported pad sizes in PAD operator
- Bug fix right pad/bottom pad calculation when replacing PAD operator
  by hardware padding

Change-Id: Ib84be711277d987052f14352ab386e0e0b774987
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 7755cc3..2d47c26 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -18,6 +18,7 @@
 # split into two parts optimise_graph_a and optimise_graph_b.
 import math
 import uuid
+from typing import Tuple
 
 import numpy as np
 
@@ -183,9 +184,26 @@
     return total_padding
 
 
-def calc_padding_and_skirt(padding_type, kernel_size, stride, input_shape, explicit_padding):
-    ypad = needed_total_padding(int(input_shape.height), int(stride[1]), int(kernel_size[0]))
-    xpad = needed_total_padding(int(input_shape.width), int(stride[2]), int(kernel_size[1]))
+def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
+    """
+    Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
+    that provides equivalent results.
+    """
+    total_padding = needed_total_padding(input_size, stride, filter_size)
+    # The top/left padding can be taken as is from the PAD
+    output_pad_before = pad_before
+    # The bottom/right padding might need downward adjustment depending on stride/input size
+    output_pad_after = pad_after
+    while output_pad_after > 0 and output_pad_after % stride != (total_padding - pad_before) % stride:
+        output_pad_after -= 1
+    return output_pad_before, output_pad_after
+
+
+def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
+    k_w, k_h = kernel.dilated_wh()
+    s_x, s_y = kernel.stride
+    ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
+    xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
     if padding_type == Padding.SAME:
         left_pad = (xpad + 0) // 2
         right_pad = (xpad + 1) // 2
@@ -198,10 +216,9 @@
         bottom_pad = 0
     elif padding_type == Padding.EXPLICIT:
         # Padding is specified in a PAD operator which has been bypassed.
-        # The top and left padding are taken from the PAD; bottom and right are calculated.
-        top_pad, left_pad, _, _ = explicit_padding
-        bottom_pad = ypad - top_pad
-        right_pad = xpad - left_pad
+        top, left, bottom, right = explicit_padding
+        top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
+        left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
     else:
         raise UnsupportedFeatureError(f"Unknown padding")
     padding = (top_pad, left_pad, bottom_pad, right_pad)
@@ -495,14 +512,8 @@
                     op.attrs["padding"], kernel_size, op.attrs["strides"], input_shape, upscaling_factor
                 )
             else:
-                dilation_h, dilation_w = op.get_dilation_h_w()
-                dilated_kernel_size = [dilation_h * (kernel_size[0] - 1) + 1, dilation_w * (kernel_size[1] - 1) + 1]
                 padding, skirt = calc_padding_and_skirt(
-                    op.attrs["padding"],
-                    dilated_kernel_size,
-                    op.attrs["strides"],
-                    input_shape,
-                    op.attrs.get("explicit_padding"),
+                    op.attrs["padding"], op.kernel, input_shape, op.attrs.get("explicit_padding"),
                 )
 
             op.attrs["explicit_padding"] = padding
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 73953ce..963d9e6 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -22,6 +22,7 @@
 from typing import Dict
 from typing import List
 from typing import Optional
+from typing import Tuple
 from typing import TYPE_CHECKING
 
 from .errors import VelaError
@@ -68,6 +69,10 @@
     def area_height(self) -> int:
         return (self.height - 1) * self.dilation.y + 1
 
+    def dilated_wh(self) -> Tuple[int, int]:
+        """Returns the dilated kernel width/height"""
+        return self.dilation.x * (self.width - 1) + 1, self.dilation.y * (self.height - 1) + 1
+
     def __str__(self):
         return f"w={self.width}, h={self.height}, stride={tuple(self.stride)}, dilation={tuple(self.dilation)}"
 
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 99a4ba1..505d4d1 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -260,6 +260,7 @@
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_type)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_constant)
         self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_ofm)
+        self.specific_constraints[Op.Pad].append(SupportedOperators.constraint_pad_size)
 
         # HardSwish specific checks:
         self.specific_constraints[Op.HardSwish].append(SupportedOperators.constraint_input_8bit)
@@ -844,6 +845,39 @@
         return valid, f"PAD operator is followed by: {_optype_formatter(unsupported_consumers)+none_string}"
 
     @staticmethod
+    def __leading_pad_ok(leading_pad, stride, kernel_size):
+        # If kernel size // 2 > stride, then (left, top) padding must be a multiple of stride,
+        # otherwise replacing PAD by hardware padding would iterate the wrong IFM rows/columns
+        max_size = kernel_size // 2
+        return leading_pad == max_size or max_size <= stride or leading_pad % stride == 0
+
+    @staticmethod
+    def constraint_pad_size(op):
+        "Padding must be at most kernel size divided by 2"
+        if SupportedOperators.constraint_pad_ofm(op)[0]:
+            padding = op.inputs[1].values  # 4x2 tensor, first dimension is N, H, W, C
+            top, left, bottom, right = (padding[1][0], padding[2][0], padding[1][1], padding[2][1])
+            for cons in op.ofm.consumers():
+                if cons is not None:
+                    # Note: pre-order graph traversal removes inputs of operators that are in traversal,
+                    # which makes it impossible to calculate kernel size, hence use cached _kernel for those operators
+                    k = cons.kernel if cons.inputs else cons._kernel
+                    k_w, k_h = k.dilated_wh()
+                    if left > k_w // 2:
+                        return False, f"Left padding is {left}, kernel width is {k_w}"
+                    if right > k_w // 2:
+                        return False, f"Right padding is {right}, kernel width is {k_w}"
+                    if top > k_h // 2:
+                        return False, f"Top padding is {top}, kernel height is {k_h}"
+                    if bottom > k_h // 2:
+                        return False, f"Bottom padding is {bottom}, kernel height is {k_h}"
+                    if not SupportedOperators.__leading_pad_ok(top, k.stride.y, k_h):
+                        return False, f"Top padding is {top}, must be {k_h // 2} or multiple of {k.stride.y}"
+                    if not SupportedOperators.__leading_pad_ok(left, k.stride.x, k_w):
+                        return False, f"Left padding is {left}, must be {k_w // 2} or multiple of {k.stride.x}"
+        return True, "Pad size is ok"
+
+    @staticmethod
     def constraint_stridedslice_inputs_const(op):
         "Begin, End and Stride Input tensors must be constant"
         valid = True
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 55980e3..4281d31 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -17,8 +17,10 @@
 # Description:
 # Unit tests for graph_optimiser
 import numpy as np
+import pytest
 
 from ethosu.vela.data_type import DataType
+from ethosu.vela.graph_optimiser import calc_explicit_padding
 from ethosu.vela.graph_optimiser import convert_batched_fc_shape
 from ethosu.vela.graph_optimiser import optimise_graph_a
 from ethosu.vela.graph_optimiser import optimise_pad
@@ -82,6 +84,38 @@
     assert conv_op.ifm.shape == conv_op.ofm.shape
 
 
+explicit_padding_test_data = [
+    # Kernel size 2
+    [(17, 1, 2, 1, 1), (1, 1)],
+    [(18, 1, 2, 0, 1), (0, 1)],
+    [(18, 1, 2, 1, 0), (1, 0)],
+    # Kernel size 3
+    [(18, 2, 3, 1, 1), (1, 0)],
+    [(25, 2, 3, 1, 1), (1, 1)],
+    # Kernel size 4
+    [(18, 1, 4, 1, 2), (1, 2)],
+    [(18, 1, 4, 2, 1), (2, 1)],
+    [(19, 1, 4, 2, 2), (2, 2)],
+    # Kernel size 5
+    [(19, 1, 5, 1, 2), (1, 2)],
+    [(19, 1, 5, 0, 2), (0, 2)],
+    [(19, 1, 5, 1, 0), (1, 0)],
+    # Kernel size 21
+    [(41, 2, 21, 8, 10), (8, 10)],
+    [(41, 3, 21, 10, 10), (10, 9)],
+    [(42, 3, 21, 10, 10), (10, 8)],
+    [(42, 3, 21, 9, 10), (9, 9)],
+    [(41, 3, 21, 10, 6), (10, 6)],
+]
+
+
+@pytest.mark.parametrize("test_input, expected_result", explicit_padding_test_data)
+def test_calc_explicit_padding(test_input, expected_result):
+    input_size, stride, filter_size, explicit_pad_before, explicit_pad_after = test_input
+    before, after = calc_explicit_padding(input_size, stride, filter_size, explicit_pad_before, explicit_pad_after)
+    assert (before, after) == expected_result
+
+
 def test_optimise_pad():
     """
     Tests that the PAD operator is bypassed when followed by a convolution operator,
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index 5c01027..5f64dd9 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -17,6 +17,7 @@
 # Description:
 # Unit tests for support_operators
 import numpy as np
+import pytest
 
 from ethosu.vela.data_type import DataType
 from ethosu.vela.operation import ActivationFunction
@@ -525,6 +526,7 @@
     out_dtype=DataType.int8,
     pad_dtype=DataType.int32,
     pad_setting=Padding.VALID,
+    kernel_size=3,
 ):
     qp = testutil.default_quant_params()
     in0 = Tensor(in_shape, in_dtype, "in")
@@ -535,7 +537,7 @@
     op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
     conv_out_tens = Tensor(in_shape, in_dtype, "output")
     conv_out_tens.quantization = qp.clone()
-    weight_tens = Tensor(in_shape, in_dtype, "weights")
+    weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
     weight_tens.values = np.zeros(weight_tens.shape)
     weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
     weight_tens.quantization = qp.clone()
@@ -609,6 +611,40 @@
     assert not support.is_operator_supported(op)
 
 
+pad_invalid_size_test_data = [
+    (2, 1, 1, 1),
+    (1, 2, 1, 1),
+    (1, 1, 2, 1),
+    (1, 1, 1, 2),
+]
+
+
+@pytest.mark.parametrize("top, left, bottom, right", pad_invalid_size_test_data)
+def test_constraint_pad_size(top, left, bottom, right):
+    # Tests PAD operator with a padding that is too high to be handled by the NPU
+    out_shape = [1, 11 + left + right, 11 + top + bottom, 1]
+    padding = [[0, 0], [top, bottom], [left, right], [0, 0]]
+    op = create_pad_op(in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding,)
+    assert not support.is_operator_supported(op)
+
+
+leading_pad_test_data = [
+    (2, 2, 11, True),
+    (1, 2, 11, False),
+    (2, 1, 11, False),
+    (5, 2, 11, True),
+]
+
+
+@pytest.mark.parametrize("top, left, kernel_size, expected", leading_pad_test_data)
+def test_constraint_leading_pad_size(top, left, kernel_size, expected):
+    # Tests PAD operator with big kernel size; top and left pad must be multiple of stride
+    out_shape = [1, 11 + left, 11 + top, 1]
+    padding = [[0, 0], [top, 0], [left, 0], [0, 0]]
+    op = create_pad_op(in_shape=[1, 11, 11, 1], out_shape=out_shape, padding=padding, kernel_size=kernel_size)
+    assert support.is_operator_supported(op) == expected
+
+
 def create_strided_slice():
     # Creates a valid strided slice operator with some valid inputs/outputs
     op = create_strided_slice_op([1, 10, 10, 10], [1, 5, 5, 10], [127, 2, 2, 0], [0, 7, -3, 0])