MLBEDSW-6686: Resize bilinear HPC with tile padding

- Added support for Resize Bilinear with half pixel centers for int8 and
uint8.

- Utilizes the new "TILE" padding mode.

- Utilizes ofm stride multipliers and modified tile base offsets to
write OFMs interleaved.

Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I37fa77c022a368f05fda0ead75d8696c9205f833
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 6b454e3..27513d3 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -464,6 +464,143 @@
     return op
 
 
+def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
+    def _compute_interpolation_values(index, input_size, output_size):
+        scale = input_size / output_size
+        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
+        lower_bound = max(np.floor(scaled_value), 0)
+
+        return scaled_value, lower_bound
+
+    def _compute_kernels(input_height, input_width, output_height, output_width):
+        kernels = []
+        for y in (1, 2):
+            for x in (1, 2):
+                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
+                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)
+
+                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
+                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
+                # top-to-bottom - same as the depthwise convolution strides across each tile
+                kernel = np.zeros((2, 2))
+                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
+                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
+                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
+                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
+                kernel *= 16
+                kernels.append(kernel)
+
+        return kernels
+
+    def _build_convolutions(op, kernels):
+        dw_op_attrs = {
+            "padding": Padding.TILE,
+            "stride_h": 1,
+            "stride_w": 1,
+            "strides": (1, 1, 1, 1),
+            "depth_multiplier": 1,
+            "channel_multiplier": 1,
+            "dilation_h_factor": 1,
+            "dilation_w_factor": 1,
+            "dilation": (1, 1, 1, 1),
+        }
+        ifm = op.ifm
+        ofm = op.ofm
+        ofm.ops = []
+        elem_size = 2 if ofm.dtype == DataType.int16 else 1
+
+        n, h, w, c = ifm.shape
+        _, _, ow, _ = ofm.shape
+
+        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
+        intermediate_tens.quantization = op.outputs[0].quantization.clone()
+        avgpool_op = op
+        avgpool_op.name = "rb_init_avgpool"
+        avgpool_op.type = Op.AvgPool
+        avgpool_op.attrs["padding"] = Padding.VALID
+        avgpool_op.attrs["stride_w"] = 1
+        avgpool_op.attrs["stride_h"] = 1
+        avgpool_op.attrs["filter_width"] = 1
+        avgpool_op.attrs["filter_height"] = 1
+        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
+        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]
+
+        avgpool_op.add_input_tensor(ifm)
+        avgpool_op.set_output_tensor(intermediate_tens)
+        avgpool_op.set_ifm_ofm_shapes()
+
+        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
+        dw_conv._original_type = Op.ResizeBilinear
+        dw_conv.write_shape = Shape4D(n, h, w, c)
+        dw_conv.write_offset = Shape4D(0, 0, 0, 0)
+
+        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
+        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
+        # should be big enough to cause an x.5 to be rounded correctly but small enough not to cause smaller
+        # values to be incorrectly rounded
+        ofm.quantization.next_after = True
+        dw_conv.rounding_mode = NpuRoundingMode.NATURAL
+
+        # Double height and width stride to write the output of each of the four depthwise convolutions below
+        # interleaved with each other when combined with OFM tile base offsets.
+        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W
+
+        # Choose tile padding direction - pad by 1 with edge values in two directions.
+        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
+        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
+        for i in (0, 1):
+            for j in (0, 1):
+                index = i * 2 + j
+                dw_conv.name = f"depthwise_conv_{index}"
+                dw_op_attrs["explicit_padding"] = directions[index]
+                dw_conv.attrs.update(dw_op_attrs)
+
+                # This will offset the start of the write by modifying the Tile 0 base address
+                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size
+
+                ofm.ops.append(dw_conv)
+                dw_conv.outputs = [ofm]
+
+                kernel = kernels[index]
+                shape = [2, 2, 1, c]
+                kernel = np.dstack([kernel] * c)
+
+                quant = QuantizationParameters()
+                quant.zero_point = 0
+                quant.scale_f32 = 1.0 / 16
+
+                dw_conv.inputs = []
+                dw_conv.add_input_tensor(intermediate_tens)
+                dw_conv.add_input_tensor(
+                    create_const_tensor(
+                        "weights",
+                        shape,
+                        intermediate_tens.dtype,
+                        np.array(kernel).reshape(shape),
+                        value_dtype=np.int8,
+                        quantization=quant,
+                    ),
+                )
+
+                # Set up bias tensor by assigning None, then calling the fix-up function to create a suitable tensor.
+                # The bias tensor needs to be appended as resize ops only have 2 inputs.
+                assert len(dw_conv.inputs) == 2
+                dw_conv.inputs.append(None)
+                fixup_bias_tensors(dw_conv, None, None)
+
+                dw_conv.set_ifm_ofm_shapes()
+                dw_conv = dw_conv.clone(f"_{index}")
+        return op
+
+    _, input_height, input_width, _ = op.ifm.shape
+    _, output_height, output_width, _ = op.ofm.shape
+
+    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
+    op = _build_convolutions(op, kernels)
+
+    return op
+
+
 def fixup_resize(op, arch, nng):
     if op.type.is_resize_op() and op.run_on_npu:
         if op.ifm_shapes[0] == op.ofm_shapes[0]:
@@ -472,6 +609,8 @@
             op.type = Op.Identity
         elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
             convert_resize_1x1_to_add(op)
+        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
+            convert_resizebilinear_to_depthwise_convolutions(op)
         else:
             convert_resize_to_upscale_and_average_pool(op)