TOSA: Add support for PAD

Added support for the TOSA PAD operator,
in line with legacy support.
Limitations:
- Rank <= 4
- N = 1 if rank = 4, for ifms/ofm
- Only padding in the W and H dimensions
- bool_t not supported
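
The rewrite in tosa_graph_optimiser.convert_pad decomposes PAD into one
elementwise add that copies the IFM into the OFM plus up to four adds that
fill the padded borders with the IFM zero point. A minimal NumPy sketch of
the resulting layout (illustrative only; pad_hw is not part of Vela):

    import numpy as np

    def pad_hw(ifm, top, bottom, left, right, pad_value=0):
        """Pad an HxWxC tensor in H and W only, mirroring the five-region rewrite."""
        h, w, c = ifm.shape
        ofm = np.empty((h + top + bottom, w + left + right, c), dtype=ifm.dtype)
        # "Main" add: copy the IFM to its offset inside the OFM
        ofm[top:top + h, left:left + w, :] = ifm
        # Border fills; convert_pad writes these with the IFM zero point
        if top:
            ofm[:top, :, :] = pad_value                   # full OFM width
        if bottom:
            ofm[top + h:, :, :] = pad_value               # full OFM width
        if left:
            ofm[top:top + h, :left, :] = pad_value        # IFM height only
        if right:
            ofm[top:top + h, left + w:, :] = pad_value    # IFM height only
        return ofm

    assert pad_hw(np.ones((2, 2, 1), dtype=np.int8), 1, 0, 1, 1).shape == (3, 4, 1)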

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I511608202b4c9bf6d86285b559c517fb41741fdf
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 681f498..e9d364e 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -592,6 +592,9 @@
     def get_ifm_ifm2_weights_ofm(self):
         return self.ifm, self.ifm2, self.weights, self.ofm
 
+    def get_ifm_ifm2_ofm(self):
+        return self.ifm, self.ifm2, self.ofm
+
     def get_ifm_weights_biases_ofm(self):
         return self.ifm, self.weights, self.bias, self.ofm
 
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 2d1245b..49fc997 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -35,6 +35,7 @@
 from .operation_util import create_avgpool_nop
 from .shape4d import Shape4D
 from .tensor import create_const_tensor
+from .tensor import create_equivalence_id
 
 
 def replace_rescale_with_avg_pool(rescale_op):
@@ -417,6 +418,96 @@
     return op
 
 
+# TODO Modified copy of the TFLite version; the solution for TOSA PAD will change, so reuse has not been considered
+def convert_pad(op, arch, nng):
+    """
+    Rewrites PAD operator to an add that copies the IFM to the OFM
+    + up to 4 add operators that fill the OFM with zeros at the borders.
+    """
+
+    if op.type != Op.Pad:
+        return op
+
+    # TODO assuming rank <= 4 and N = 1 for rank == 4
+    # This is checked in tosa_supported_operators
+    ifm = op.ifm
+    assert ifm is not None
+    ifm_shape = Shape4D(ifm.shape)
+    ofm = op.ofm
+    assert ofm is not None
+    ofm.ops = []
+    ofm_shape = op.ofm_shapes[0]
+
+    rank = len(ifm.shape)
+    padding = op.inputs[1].values
+    pad_depth = padding[-1]
+    if not (pad_depth == 0).all():
+        print("Warning: For PAD, padding in depth not supported yet")
+        assert False
+
+    top, bottom = 0, 0
+    left, right = 0, 0
+    if rank > 1:
+        left, right = padding[-2][0], padding[-2][1]
+    if rank > 2:
+        top, bottom = padding[-3][0], padding[-3][1]
+    if rank == 4 and not (padding[-4] == 0).all():
+        print("Warning: For PAD, padding not supported in first dimension when rank == 4 yet")
+        assert False
+
+    # Add op that copies IFM to the right place inside the OFM
+    shp0 = Shape4D(0, 0, 0, 0)
+    shp_top = shp0.with_height(top)
+    add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
+    add_op.activation = op.activation
+
+    quant = ofm.quantization
+    pad_value = ifm.quantization.zero_point
+    # Add operations that fill the borders of the OFM
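+    # The OFM is covered by five regions in the H x W plane:
+    #
+    #   +--------------------------+
+    #   |           top            |   (full OFM width)
+    #   +------+----------+--------+
+    #   | left |   IFM    | right  |   (IFM height)
+    #   +------+----------+--------+
+    #   |         bottom           |   (full OFM width)
+    #   +--------------------------+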
+    if top > 0:
+        shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_top",
+            shape.as_list(),
+            ofm.dtype,
+            shape.elements() * [pad_value],
+            np.uint8,
+            quantization=quant,  # TODO
+        )
+        # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_add_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
+    if bottom > 0:
+        shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_bottom",
+            shape.as_list(),
+            ofm.dtype,
+            shape.elements() * [pad_value],
+            np.uint8,
+            quantization=quant,
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_add_for_concat(op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom))
+    if left > 0:
+        shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
+    if right > 0:
+        shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
+        zero_tens = create_const_tensor(
+            op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+        )
+        zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+        create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right))
+
+    op.type = Op.ConcatTFLite
+    return add_op
+
+
 def fixup_quantization(op, arch, nng):
     if op.ifm and op.ifm.quantization.zero_point is None:
         op.ifm.quantization.zero_point = 0
@@ -484,7 +575,7 @@
     # Post-processing step 1
     for idx, sg in enumerate(nng.subgraphs):
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            nng, sg, arch, [], [rewrite_activation, add_padding_fields],
+            nng, sg, arch, [], [rewrite_activation, convert_pad, add_padding_fields],
         )
 
     # Removal of Slice, need to be done after optimisation has been performed,
diff --git a/ethosu/vela/tosa_mapping.py b/ethosu/vela/tosa_mapping.py
index 6efc479..ebbaa0a 100644
--- a/ethosu/vela/tosa_mapping.py
+++ b/ethosu/vela/tosa_mapping.py
@@ -174,7 +174,7 @@
 unary_quant_info = QuantSerializer("UnaryQuantInfo", ("input_zp", "output_zp"))
 conv_quant_info = QuantSerializer("ConvQuantInfo", ("input_zp", "weight_zp"))
 matmul_quant_info = QuantSerializer("MatMulQuantInfo", ("a_zp", "b_zp"))
-pad_quant_info = QuantSerializer("PadQuantInfo", ("input_zp"))
+pad_quant_info = QuantSerializer("PadQuantInfo", ("input_zp",))
 
 unsupported_tosa_operators = {
     TosaOp.UNKNOWN,
@@ -218,7 +218,6 @@
     TosaOp.REDUCE_MIN,
     TosaOp.REDUCE_PRODUCT,
     TosaOp.REDUCE_SUM,
-    TosaOp.PAD,
     TosaOp.REVERSE,
     TosaOp.TILE,
     TosaOp.GATHER,
@@ -298,12 +297,19 @@
     # TODO TosaOp.REDUCE_PRODUCT
     # TODO TosaOp.REDUCE_SUM
     TosaOp.CONCAT: (Op.Concat, axis_attrs, None, TOSA_CONCAT_INDICES),
-    # TODO TosaOp.PAD
+    # TODO Is the padding intended to be a dynamic input? The TOSA spec states it as an attribute.
+    # Handled as for TFLite for now
+    TosaOp.PAD: (Op.Pad, None, pad_quant_info, TOSA_IFM_INDICES),
     TosaOp.RESHAPE: (Op.Reshape, reshape_attrs, None, TOSA_IFM_INDICES),
     # TODO TosaOp.REVERSE
     TosaOp.SLICE: (Op.SplitSliceRead, slice_attrs, None, TOSA_IFM_INDICES),
     # TODO TosaOp.TILE
-    TosaOp.TRANSPOSE: (Op.Transpose, None, None, TOSA_IFM_IFM2_INDICES),
+    TosaOp.TRANSPOSE: (
+        Op.Transpose,
+        None,
+        None,
+        TOSA_IFM_IFM2_INDICES,
+    ),  # TODO Are the perms intended to be a dynamic input? The TOSA spec states it as an attribute.
     # TODO TosaOp.GATHER
     # TODO TosaOp.SCATTER
     # TODO TosaOp.RESIZE
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 94ba350..aadcb0a 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -113,7 +113,7 @@
         # Moving permutation to an attribute, to match internal graph representation for now
         perms = None
         if op_code == TosaOp.TRANSPOSE:
-            perms = perms = inputs.pop(1)
+            perms = inputs.pop(1)
             indices = TOSA_IFM_INDICES
 
         name = "unknown_op_name"
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index d368616..a4f822e 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -39,16 +39,15 @@
 
     mac_main_ops = convolution_like_ops | pooling_ops | fc_vector_products
     memory_only_ops = set((Op.Reshape, Op.Transpose, Op.Concat, Op.SplitSliceRead,))
-
     binary_elem_wise_add_mul_sub = set((Op.Add, Op.Mul, Op.RescaleMul, Op.Sub,))
-
     type_conversion_ops = set((Op.Rescale,))
     relu_ops = set((Op.Clamp, Op.ReluN,))
     activation_ops = relu_ops
+    pad_ops = set((Op.Pad,))
 
     npu_post_ops = activation_ops
     supported_operators = (
-        mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub
+        mac_main_ops | type_conversion_ops | npu_post_ops | memory_only_ops | binary_elem_wise_add_mul_sub | pad_ops
     )
 
     # Supported data types
@@ -60,12 +59,15 @@
         # Setup the generic constraints. Note: the order matters
         self.generic_constraints = []
         self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dtype)
-        self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)
+        self.generic_constraints.append(TosaSupportedOperators.constraint_tens_dimension)  # TODO as not supported yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_rank)  # TODO as not supported yet
+        self.generic_constraints.append(TosaSupportedOperators.constraint_batch)  # TODO as not supported yet
 
         # Setup specific constraints. Note: the order matters
         self.specific_constraints = defaultdict(list)
 
         self.specific_constraints[Op.Transpose].append(TosaSupportedOperators.constraint_ifm_producer)
+        self.specific_constraints[Op.Pad].append(TosaSupportedOperators.constraint_padding_producer)
 
         # Depthwise Conv specific checks:
         for op_type in TosaSupportedOperators.depthwise_convolution_ops:
@@ -127,6 +129,38 @@
                 extra.append(f"Tensor '{tens.name}' has shape: {tens.shape}")
         return valid, ", ".join(extra)
 
+    # TODO This is due to a HW limitation that is to be resolved in SW at a later stage
+    @staticmethod
+    def constraint_rank(op):
+        "Tensor rank must be <= 4"
+        valid = True
+        extra = []
+        tensors = [tens for tens in op.get_ifm_ifm2_weights_ofm() if tens]
+        if not tensors:
+            tensors = [tens for tens in op.inputs if tens]
+        for tens in tensors:
+            rank = len(tens.shape)
+            if not rank <= 4:
+                valid = False
+                extra.append(f"Tensor '{tens.name}' has rank: {rank}")
+        return valid, ", ".join(extra)
+
+    # TODO This is due to a HW limitation that is to be resolved in SW at a later stage
+    @staticmethod
+    def constraint_batch(op):
+        "If Tensor rank is 4 batch of ifms/ofm must be 1"
+        valid = True
+        extra = []
+        tensors = [tens for tens in op.get_ifm_ifm2_ofm() if tens]
+        if not tensors:
+            tensors = [tens for tens in op.inputs if tens]
+        for tens in tensors:
+            rank = len(tens.shape)
+            if rank == 4 and tens.shape[0] != 1:
+                valid = False
+                extra.append(f"Tensor '{tens.name}' has rank: 4 and N: {tens.shape[0]}")
+        return valid, ", ".join(extra)
+
     @staticmethod
     def constraint_ifm_producer(cls, op):
         "Input must be constant data"
@@ -143,6 +177,14 @@
 
         return valid, "Avgpool with pad_top {top} and pad_left {left}"
 
+    # TODO For now, limit the padding to be constant data.
+    # For TFLite it is assumed to be constant.
+    @staticmethod
+    def constraint_padding_producer(op):
+        "Input must be constant data"
+        valid = op.inputs[1].ops and op.inputs[1].ops[0].type == Op.Const
+        return valid, "PAD Op with non-constant data padding"
+
     # TODO duplicates tflite_supported operators, but support for depth multiplier should be added at a later stage
     @staticmethod
     def constraint_depth_multiplier(op):