TOSA: Added decomposition of PAD

Added support for:
-Rank > 4 and batch > 1
-Tensor dimensions exceeding the NPU limit
-Padding in any dimension

(The implementation targets functional compliance;
 performance has not been considered)
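
Each PAD is decomposed into one PAD per dimension that needs padding.
For each such dimension the ifm/ofm are reshaped so that the dimension
becomes the width, and the resulting PAD is rewritten into an add that
copies the IFM plus up to two adds that fill the left/right borders
with the pad value.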

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: Ief58fb3233d885f10ba5e68c5374b190efbe9351
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 0fbed46..29caf6d 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -50,6 +50,12 @@
     return op
 
 
+def create_pad_nop(name: str) -> Operation:
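+    """Create a bare PAD operation that is marked to run on the NPU"""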
+    op = Operation(Op.Pad, name)
+    op.run_on_npu = True
+    return op
+
+
 def create_depthwise_maxpool(
     name: str,
     ifm: Tensor,
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 954ac68..e27dbed 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -33,9 +33,12 @@
 from .operation import Op
 from .operation_util import create_add_nop
 from .operation_util import create_avgpool_nop
+from .operation_util import create_pad_nop
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
 from .tensor import create_equivalence_id
+from .tensor import shape_num_elements
+from .tensor import Tensor
 
 
 def replace_rescale_with_avg_pool(rescale_op):
@@ -414,87 +417,44 @@
     return op
 
 
-# TODO modified copy of TFLite, solution for TOSA PAD will change so reuse has not been considered
-def convert_pad(op, arch, nng):
+def convert_pad_in_width(op):
     """
-    Rewrites PAD operator to an add that copies the IFM to the OFM
-    + up to 4 add operators that fill the OFM with zeros at the borders.
+    Rewrites a PAD operator to an add that copies the IFM to the OFM,
+    plus up to 2 add operators that fill the left/right borders of the
+    OFM with the pad value.
     """
-
-    if op.type != Op.Pad:
-        return op
-
-    # TODO assuming rank <= 4 and N = 1 for rank ==4
-    # This is checked in tosa_supported_operators
+    assert op.type == Op.Pad
+    assert op.ifm_shapes[0] is not None and op.ofm_shapes[0] is not None
     ifm = op.ifm
-    assert ifm is not None
-    ifm_shape = Shape4D(ifm.shape)
     ofm = op.ofm
-    assert ofm is not None
+    ifm_shape = op.ifm_shapes[0]
     ofm.ops = []
     ofm_shape = op.ofm_shapes[0]
 
-    rank = len(ifm.shape)
     padding = op.inputs[1].values
-    pad_depth = padding[-1]
-    if not (pad_depth == 0).all():
-        print("Warning: For PAD, padding in depth not supported yet")
-        assert False
-
-    top, bottom = 0, 0
-    left, right = 0, 0
-    if rank > 1:
-        left, right = padding[-2][0], padding[-2][1]
-    if rank > 2:
-        top, bottom = padding[-3][0], padding[-3][1]
-    if rank == 4 and not (padding[-4] == 0).all():
-        print("Warning: For PAD, padding not supported in first dimension when rank == 4 yet")
-        assert False
+    left, right = padding[-2]
 
     # Add op that copies IFM to the right place inside the OFM
     shp0 = Shape4D(0, 0, 0, 0)
-    shp_top = shp0.with_height(top)
-    add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
+    add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp0.with_width(left))
     add_op.activation = op.activation
 
     quant = ofm.quantization
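+    # Pad with the IFM zero point (quantized zero); the zero point is then
+    # cleared so that the copy add passes the IFM values through unchanged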
     pad_value = ifm.quantization.zero_point
     ifm.quantization.zero_point = 0
-    # Add operations that fill the borders of the OFM
-    if top > 0:
-        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
-        zero_tens = create_const_tensor(
-            op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant,
-        )
-        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
-        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
-        create_add_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
-    if bottom > 0:
-        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
-        zero_tens = create_const_tensor(
-            op.name + "_bottom",
-            shape.as_list(),
-            ofm.dtype,
-            shape.elements() * [pad_value],
-            np.uint8,
-            quantization=quant,
-        )
-        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
-        create_add_for_concat(op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom))
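+    # Add operations that fill the left/right borders of the OFM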
     if left > 0:
         shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
         zero_tens = create_const_tensor(
             op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
         )
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
-        create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
+        create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp0)
     if right > 0:
         shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
         zero_tens = create_const_tensor(
             op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
         )
         zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
-        create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right))
+        create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp0.with_width(ofm_shape.width - right))
 
     op.type = Op.ConcatTFLite
     return add_op
@@ -581,6 +541,8 @@
         part_op.read_shapes[1] = ifm2_shape
         part_op.ifm2.consumer_list.append(part_op)
 
+    return part_op
+
 
 def get_nhwc_stride(shape):
     stride_x = shape.depth
@@ -700,6 +662,8 @@
         ifm.consumer_list.remove(op)
         if binary:
             ifm2.consumer_list.remove(op)
+
+        return op_list
     else:
         op.ofm_shapes.append(Shape4D(new_ofm_shape))
         op.ifm_shapes.append(Shape4D(new_ifm_shape))
@@ -781,6 +745,84 @@
     return tens
 
 
+def decomp_rewrite_pad(op, arch):
+    """
+    Decomposition of PAD into elementwise operations.
+    For each dimension that needs padding:
+    -Create a new PAD operator that pads only that dimension.
+     The ifm/ofm are reshaped so that the dimension to be padded becomes
+     the width dimension (rank 3 for each).
+    -Rewrite the new PAD operator so that there is:
+     -1 Add operator that copies the data
+     -1 Add operator for each left/right border that is padded
+    """
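+    # Example: with padding [[1, 1], [2, 2], [0, 0]] on a HxWxC ifm, two
+    # single-dimension PAD operators are created, one for H and one for W;
+    # each is then rewritten by convert_pad_in_width. C needs no padding.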
+    # TODO Several optimisations would be possible here; for instance,
+    # in some cases two dimensions could be padded at the same time.
+    if op.type == Op.Pad:
+        ofm_elements = shape_num_elements(op.ofm.shape)
+        padding = op.inputs[1].values
+
+        rank = len(op.ifm.shape)
+        next_ifm = op.ifm
+        next_ifm_shape = next_ifm.shape.copy()
+
+        first_pad_rewrite_op = None
+        ifm_quant = op.ifm.quantization.clone()
+
+        for dim in range(padding.shape[0]):
+            # Check if padding is to be applied in this dimension
+            dim_pad = padding[dim]
+            if not (dim_pad == 0).all():
+                # Reshape so that width dimension is to be padded
+                new_ifm_shape = reshape_concat_shape(next_ifm_shape, rank, dim)
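+                # Padding spec for the reshaped shape: pad only in the
+                # width dimension (index 2 of the 4D spec)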
+                new_pad_input = np.zeros((4, 2), dtype=np.int32)
+                new_pad_input[2] = dim_pad
+
+                pad_op = create_pad_nop(f"{op.name}_dim_{dim}")
+                pad_op.add_input_tensor(next_ifm)
+
+                name = op.inputs[1].name + f"_dim_{dim}"
+                new_pad_tens = create_const_tensor(
+                    name, list(new_pad_input.shape), DataType.int32, new_pad_input, np.int32
+                )
+                pad_op.add_input_tensor(new_pad_tens)
+
+                new_ofm_shape = new_ifm_shape.copy()
+                new_ofm_shape[-2] = new_ofm_shape[-2] + dim_pad.sum()
+                next_ifm_shape[dim] = next_ifm_shape[dim] + dim_pad.sum()
+
+                if Shape4D(new_ofm_shape).elements() == ofm_elements:
+                    # Last one, use op.ofm
+                    ofm = op.ofm
+                else:
+                    # add a new ofm Tensor
+                    ofm = Tensor(new_ofm_shape, op.ofm.dtype, f"{pad_op.name}_tens")
+                    ofm.quantization = ifm_quant.clone()
+
+                pad_op.set_output_tensor(ofm)
+                pad_op.ifm_shapes.append(Shape4D(new_ifm_shape))
+                pad_op.ofm_shapes.append(Shape4D(new_ofm_shape))
+                DebugDatabase.add_optimised(op, pad_op)
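+                # The OFM of this step becomes the IFM of the next padded dimension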
+                next_ifm = ofm
+
+                # Rewrite the pad op
+                converted_pad_op = convert_pad_in_width(pad_op)
+                if first_pad_rewrite_op is None:
+                    first_pad_rewrite_op = converted_pad_op
+
+        if first_pad_rewrite_op:
+            assert op.ofm.shape == next_ifm_shape
+            for inp in op.inputs:
+                inp.consumer_list.remove(op)
+            return first_pad_rewrite_op
+
+        # No dimension needed padding; change to an Identity operation
+        # (it will be removed)
+        op.type = Op.Identity
+
+    return op
+
+
 def fixup_quantization(op, arch, nng):
     if op.ifm and op.ifm.quantization.zero_point is None:
         op.ifm.quantization.zero_point = 0
@@ -812,6 +854,11 @@
             nng, sg, arch, [decomp_rewrite_concat], [], rewrite_unsupported=False
         )
 
+    # Decomposition of PAD
+    for idx, sg in enumerate(nng.subgraphs):
+        rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [decomp_rewrite_pad])
+        sg.refresh_after_modification()
+
     # Handle sg input output
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
@@ -857,7 +904,7 @@
     # Post-processing step 1
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [rewrite_activation, convert_pad, add_padding_fields],
+            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
         )
 
     # Removal of Slice, need to be done after optimisation has been performed,
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 5a85b0e..e378511 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -46,7 +46,7 @@
     activation_ops = relu_ops | set((Op.Table,))
     pad_ops = set((Op.Pad,))
 
-    rank_unlimited_ops = set((Op.Concat, Op.Reshape, Op.Identity))
+    rank_unlimited_ops = set((Op.Concat, Op.Reshape, Op.Identity, Op.Pad))
     rank6_limited_ops = elem_wise_ops
     batch_enabled_ops = rank6_limited_ops | rank_unlimited_ops
     large_tens_dims_enabled_ops = batch_enabled_ops | set((Op.SplitSliceRead,))