TOSA: Add support for PAD

Added support for the TOSA PAD operator,
in line with legacy support.
Limitations:
- Rank <= 4
- N = 1 if rank == 4, for IFMs/OFM
- Padding only in the W and H dimensions
- bool_t not supported
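
As an illustration (not part of the patch), a minimal NumPy sketch of
the region layout the rewrite produces, for a 1x2x2x1 IFM padded by 1
on each side of H and W (pad_value stands in for the IFM zero point):

    import numpy as np

    ifm = np.arange(1, 5).reshape(1, 2, 2, 1)
    pad_value = 0  # stands in for ifm.quantization.zero_point
    ofm = np.empty((1, 4, 4, 1), dtype=ifm.dtype)

    ofm[:, 1:3, 1:3, :] = ifm        # "_main" add: IFM written at offset (top, left)
    ofm[:, 0:1, :, :] = pad_value    # "_top" strip, spans the full OFM width
    ofm[:, 3:4, :, :] = pad_value    # "_bottom" strip, spans the full OFM width
    ofm[:, 1:3, 0:1, :] = pad_value  # "_left" strip, spans the IFM height only
    ofm[:, 1:3, 3:4, :] = pad_value  # "_right" strip, spans the IFM height only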

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I511608202b4c9bf6d86285b559c517fb41741fdf
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 2d1245b..49fc997 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -35,6 +35,7 @@
 from .operation_util import create_avgpool_nop
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
+from .tensor import create_equivalence_id
 
 
 def replace_rescale_with_avg_pool(rescale_op):
@@ -417,6 +418,96 @@
     return op
 
 
+# TODO This is a modified copy of the TFLite version. The solution for TOSA PAD will change, so reuse has not been considered.
+def convert_pad(op, arch, nng):
+    """
+    Rewrites PAD operator to an add that copies the IFM to the OFM
+    + up to 4 add operators that fill the OFM with zeros at the borders.
+    """
+
+    if op.type != Op.Pad:
+        return op
+
+    # TODO assuming rank <= 4 and N = 1 for rank == 4
+    # This is checked in tosa_supported_operators
+    ifm = op.ifm
+    assert ifm is not None
+    ifm_shape = Shape4D(ifm.shape)
+    ofm = op.ofm
+    assert ofm is not None
+    ofm.ops = []
+    ofm_shape = op.ofm_shapes[0]
+
+    rank = len(ifm.shape)
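+    # The padding input is a [rank, 2] array of (before, after) pad amounts per dimension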
+    padding = op.inputs[1].values
+    pad_depth = padding[-1]
+    if not (pad_depth == 0).all():
+        print("Warning: For PAD, padding in depth not supported yet")
+        assert False
+
+    top, bottom = 0, 0
+    left, right = 0, 0
+    if rank > 1:
+        left, right = padding[-2][0], padding[-2][1]
+    if rank > 2:
+        top, bottom = padding[-3][0], padding[-3][1]
+    if rank == 4 and not (padding[-4] == 0).all():
+        print("Warning: For PAD, padding not supported in first dimension when rank == 4 yet")
+        assert False
+
+    # Add op that copies IFM to the right place inside the OFM
+    shp0 = Shape4D(0, 0, 0, 0)
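+    # These Shape4D values are used as OFM write offsets for the add ops created below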
+    shp_top = shp0.with_height(top)
+    add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
+    add_op.activation = op.activation
+
+    quant = ofm.quantization
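+    # Pad with the IFM zero point, i.e. the quantized representation of real value zero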
+    pad_value = ifm.quantization.zero_point
+    # Add operations that fill the borders of the OFM
+    if top > 0:
+        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_top",
+            shape.as_list(),
+            ofm.dtype,
+            shape.elements() * [pad_value],
+            np.uint8,
+            quantization=quant,  # TODO
+        )
+        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_add_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
+    if bottom > 0:
+        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_bottom",
+            shape.as_list(),
+            ofm.dtype,
+            shape.elements() * [pad_value],
+            np.uint8,
+            quantization=quant,
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_add_for_concat(op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom))
+    if left > 0:
+        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
+    if right > 0:
+        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right))
+
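+    # Retype the original PAD op and return the add op that now writes the IFM into the OFM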
+    op.type = Op.ConcatTFLite
+    return add_op
+
+
 def fixup_quantization(op, arch, nng):
     if op.ifm and op.ifm.quantization.zero_point is None:
         op.ifm.quantization.zero_point = 0
@@ -484,7 +575,7 @@
     # Post-processing step 1
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
+            nng, sg, arch, [], [rewrite_activation, convert_pad, add_padding_fields],
         )
 
     # Removal of Slice, need to be done after optimisation has been performed,