MLBEDSW-6686: Resize bilinear HPC with tile padding

- Added support for Resize Bilinear with half pixel centers for int8 and
uint8.

- Utilizes the new "TILE" padding mode.

- Utilizes ofm stride multipliers and modified tile base offsets to
write OFMs interleaved.

Signed-off-by: Rickard Bolin <rickard.bolin@arm.com>
Change-Id: I37fa77c022a368f05fda0ead75d8696c9205f833
diff --git a/SUPPORTED_OPS.md b/SUPPORTED_OPS.md
index 6a92e82..36b403a 100644
--- a/SUPPORTED_OPS.md
+++ b/SUPPORTED_OPS.md
@@ -1,7 +1,7 @@
 # Supported Ops
 
 This file was automatically generated by Vela using the `--supported-ops-report` parameter.  
-Vela version: `3.5.0`
+Vela version: `3.5.1.dev14+gc22ad76.d20220921`
 
 This file complies with
 [**Gitiles Markdown syntax**](https://github.com/google/gitiles/blob/master/Documentation/markdown.md)
@@ -36,6 +36,7 @@
 | MUL | [Generic](#tflite-generic-constraints), [Specific](#tflite-mul-constraints) |
 | PACK | [Generic](#tflite-generic-constraints) |
 | PAD | [Generic](#tflite-generic-constraints), [Specific](#tflite-pad-constraints) |
+| PRELU | [Generic](#tflite-generic-constraints) |
 | QUANTIZE | [Generic](#tflite-generic-constraints) |
 | RELU | [Generic](#tflite-generic-constraints) |
 | RELU6 | [Generic](#tflite-generic-constraints) |
@@ -116,7 +117,6 @@
 - Axis attribute must be in the range [0, <ofm_dimensions>)
 - All Input dimensionalities must match OFM dimensionality
 - All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute
-- All Input dimensions must match OFM dimension in all axes except the one defined by the axis attribute
 - The size of the OFM axis must match the sum of all IFM axis defined by the axis attribute
 
 ### TFLite CONV_2D Constraints
@@ -184,7 +184,6 @@
 
 - At least one Input's shape must match the OFM's shape
 - IFM and OFM data types must match
-- Alpha only allowed to be negative if IFM is int8 or uint8
 - Batch size must be 1 for Input tensors with more than 2 dimensions
 
 ### TFLite MAXIMUM Constraints
@@ -268,6 +267,7 @@
 
 - Input and output quantisation must match.
 - Shape must be constant
+- Reshape on NPU not supported before MEAN operator
 
 ### TFLite RESIZE_BILINEAR Constraints
 
@@ -276,11 +276,12 @@
 - The width and height of the IFM and OFM must match one of the following criteria:  
         IFM W and H must both be 1  
         IFM must match OFM  
-        OFM W and H must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True  
-        OFM W and H must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False
+        W and H scaling must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True  
+        W and H scaling must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False
 - The size tensor must match the output tensor shape
 - Both align_corners and half_pixel_centers can't be True
-- half_pixel_centers are not supported
+- Half_pixel_centers are only supported for resize bilinear with IFM dtype int8 or uint8
+- Half_pixel_centers for resize bilinear requires that OFM W and H is 2x IFM W and H
 
 ### TFLite RESIZE_NEAREST_NEIGHBOR Constraints
 
@@ -289,11 +290,11 @@
 - The width and height of the IFM and OFM must match one of the following criteria:  
         IFM W and H must both be 1  
         IFM must match OFM  
-        OFM W and H must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True  
-        OFM W and H must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False
+        W and H scaling must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True  
+        W and H scaling must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False
 - The size tensor must match the output tensor shape
 - Both align_corners and half_pixel_centers can't be True
-- half_pixel_centers are not supported
+- Half_pixel_centers are only supported for resize bilinear with IFM dtype int8 or uint8
 
 ### TFLite SOFTMAX Constraints
 
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 5e7e112..b33851a 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2021-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -111,6 +111,12 @@
     if _avoid_nhcwb16_for_shapes(tens):
         return
 
+    # Resize bilinear half pixel center implementation requires OFM with linear format to
+    # allow stride modification in H/W dimensions.
+    for op in tens.ops:
+        if op.original_type == Op.ResizeBilinear and op.type == Op.DepthwiseConv2DBias:
+            return
+
     for op in tens.consumer_list:
         if op.type == Op.ReduceSum and (
             tens.dtype == DataType.int32 or arch.accelerator_config == Accelerator.Ethos_U65_512
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 6246b37..7923e37 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -189,6 +189,7 @@
             dtype=cmd.ifm_tensor.dtype,
         )
         top, left, bottom, right = 0, 0, 0, 0
+
     return NpuPadding(top=top, left=left, bottom=bottom, right=right)
 
 
@@ -297,6 +298,10 @@
     """Checks if quantization should use 0 as zero point"""
     if tens.dtype == DataType.int32 and is_ifm_tensor:
         return True
+    # Force zero point to 0 for ResizeBilinear when converting to a DepthwiseConv since the reference kernel
+    # will ignore the zero point.
+    if ps.primary_op.original_type == Op.ResizeBilinear and ps.primary_op.type == Op.DepthwiseConv2DBias:
+        return True
     if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
         return False
     if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
@@ -352,6 +357,7 @@
     box: Box,
     arch: ArchitectureFeatures,
     op_shape4D: Shape4D,
+    tile_base_offsets: List[int],
     stride_multiplier: Optional[List[int]] = None,
 ) -> NpuFeatureMap:
     """Creates feature map with common fields populated"""
@@ -380,6 +386,8 @@
         box.start_coord, box.end_coord, strides, op_shape4D
     )
 
+    for idx, offset in enumerate(tile_base_offsets):
+        addresses[idx] += offset
     fm.tiles = NpuTileBox(
         height_0=height_0, height_1=height_1, width_0=width_0, addresses=[int(addr) for addr in addresses]
     )
@@ -475,12 +483,14 @@
     ifm_width = cmd.ps.ifm_shapes[0].width
     ifm_depth = get_ifm_depth(op.type.npu_block_type, cmd.ifm_box, cmd.ofm_box)
 
-    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0])
+    npu_op.ifm = create_feature_map(cmd.ifm_tensor, cmd.ifm_box, arch, ps.ifm_shapes[0], op.tile_base_offsets_ifm[0])
     npu_op.ifm.shape = NpuShape3D(height=ifm_height, width=ifm_width, depth=ifm_depth)
     npu_op.ifm.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm_tensor)
 
     out_block = cmd.ofm_box.get_block()
-    npu_op.ofm = create_feature_map(cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0], op.ofm_stride_multiplier)
+    npu_op.ofm = create_feature_map(
+        cmd.ofm_tensor, cmd.ofm_box, arch, ps.ofm_shapes[0], op.tile_base_offsets_ofm, op.ofm_stride_multiplier
+    )
     npu_op.ofm.shape = NpuShape3D(height=out_block.height, width=out_block.width, depth=out_block.depth)
     npu_op.ofm.quantization = get_ofm_quantization(ps, cmd.ofm_tensor)
 
@@ -559,7 +569,13 @@
             cmd.ifm_box, cmd.ifm2_box = cmd.ifm2_box, cmd.ifm_box
             ps.ifm_shapes[0], ps.ifm_shapes[1] = ps.ifm_shapes[1], ps.ifm_shapes[0]
             npu_op.reversed_operands = True
-        npu_op.ifm2 = create_feature_map(cmd.ifm2_tensor, cmd.ifm2_box, arch, ps.ifm_shapes[1])
+        npu_op.ifm2 = create_feature_map(
+            cmd.ifm2_tensor,
+            cmd.ifm2_box,
+            arch,
+            ps.ifm_shapes[1],
+            op.tile_base_offsets_ifm[1],
+        )
         npu_op.ifm2.quantization = get_ifm_or_ifm2_quantization(ps, cmd.ifm2_tensor)
         if cmd.ifm2_tensor.shape == []:
             # scalar
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index e162204..af2205c 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -474,7 +474,7 @@
 
     __slots__ = (
         "type",
-        "original_type",
+        "_original_type",
         "name",
         "op_index",
         "attrs",
@@ -501,12 +501,14 @@
         "write_offset",
         "write_shape",
         "ifm_resampling_mode",
+        "tile_base_offsets_ifm",
+        "tile_base_offsets_ofm",
         "ofm_stride_multiplier",
     )
 
     def __init__(self, op_type: Op, name: str):
         self.type = op_type
-        self.original_type = op_type
+        self._original_type = op_type  # the original type of the operation. once set this shouldn't be changed
         self.name = name
         self.attrs: Dict[str, Any] = {}
         self.inputs: List[Optional[Tensor]] = []
@@ -546,6 +548,10 @@
         # write_offset 0,9,0,0, write_shape 1,1,8,1
         self.write_shape: Optional[Shape4D] = None
         self.ifm_resampling_mode: resampling_mode = resampling_mode.NONE
+        # ifm (nhwc), ifm2 (nhwc)
+        self.tile_base_offsets_ifm: List[List[int]] = [[0, 0, 0, 0], [0, 0, 0, 0]]
+        # ofm (nhwc)
+        self.tile_base_offsets_ofm: List[int] = [0, 0, 0, 0]
         # For interleaved/sparse outputs - stride is multiplied with the stride factor of the corresponding axis
         # Order is [C, H, W] - default is no multiplication
         self.ofm_stride_multiplier: List[int] = [1, 1, 1]
@@ -553,6 +559,9 @@
     def clone(self, suffix="_clone"):
         res = Operation(self.type, self.name + suffix)
 
+        # maintain the original type, in cases where the type was changed to something different
+        res._original_type = self._original_type
+
         res.attrs = dict(self.attrs)
         res.inputs = list(self.inputs)
         res.outputs = list(self.outputs)
@@ -567,11 +576,15 @@
         res.op_index = None  # not relevant as not part of input network
         res.read_offsets = list(self.read_offsets)
         res.read_shapes = list(self.read_shapes)
+        res.write_offset = Shape4D(*self.write_offset) if self.write_offset else None
+        res.write_shape = Shape4D(*self.write_shape) if self.write_shape else None
         res.rounding_mode = self.rounding_mode
         res.explicit_scaling = self.explicit_scaling
         res.low_precision_scaling = self.low_precision_scaling
         res.rescale = self.rescale
         res.ifm_resampling_mode = self.ifm_resampling_mode
+        res.tile_base_offsets_ifm = [_ifm.copy() for _ifm in self.tile_base_offsets_ifm]
+        res.tile_base_offsets_ofm = self.tile_base_offsets_ofm.copy()
         res.ofm_stride_multiplier = self.ofm_stride_multiplier.copy()
 
         return res
@@ -581,6 +594,10 @@
 
     __repr__ = __str__
 
+    @property
+    def original_type(self):
+        return self._original_type
+
     def get_kernel_size(self):
         weights = self.weights
         if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 9997031..9fbd454 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -213,6 +213,7 @@
         "max",
         "num_bits",
         "narrow_range",
+        "next_after",
         "scale_f32",
         "zero_point",
         "quant_min",
@@ -233,6 +234,10 @@
         self.num_bits = num_bits
         self.narrow_range = narrow_range
 
+        # Use the 'next after' float value of scale_f32 when converting to scale and shift. It can be combined with
+        # natural rounding to perform rounding away from zero. This only affects the ofm scale and bias tensor, it has
+        # no effect on global scaling i.e. the ofm_scale register
+        self.next_after = False
         self.scale_f32: Union[float, np.ndarray, None] = None
         self.zero_point: Union[int, np.ndarray, None] = None
         self.quant_min: Optional[float] = None
@@ -240,12 +245,9 @@
         self.quant_dim: Optional[int] = None
 
     def __str__(self):
-        return "<nng.QuantizationParameters min=%s max=%s, num_bits=%s, scale=%s, zero_point=%s>" % (
-            self.min,
-            self.max,
-            self.num_bits,
-            self.scale_f32,
-            self.zero_point,
+        return (
+            f"<nng.QuantizationParameters min={self.min}, max={self.max}, num_bits={self.num_bits}, "
+            f"scale={self.scale_f32}, zero_point={self.zero_point}, next={self.next_after}>"
         )
 
     __repr__ = __str__
@@ -258,6 +260,7 @@
         res.num_bits = self.num_bits
         res.narrow_range = self.narrow_range
 
+        res.next_after = self.next_after
         res.scale_f32 = self.scale_f32
         res.zero_point = self.zero_point
         res.quant_min = self.quant_min
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index 89c2799..3872bdc 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -383,11 +383,14 @@
 
 def test_constraint_resize_half_pixel_centers():
     for resize_op in Op.op_set(Op.is_resize_op):
-        # Invalid case - half-pixel centers (not supported)
+        # Half-pixel centers is only supported for resize bilinear
         op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
         op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
         op.attrs["half_pixel_centers"] = True
-        assert not support.is_operator_supported(op)
+        if resize_op == Op.ResizeBilinear:
+            assert support.is_operator_supported(op)
+        else:
+            assert not support.is_operator_supported(op)
 
 
 def test_constraint_concat_pass():
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 6b454e3..27513d3 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -464,6 +464,143 @@
     return op
 
 
+def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True):
+    def _compute_interpolation_values(index, input_size, output_size):
+        scale = input_size / output_size
+        scaled_value = (index + 0.5 * half_pixel_centers) * scale - 0.5 * half_pixel_centers
+        lower_bound = max(np.floor(scaled_value), 0)
+
+        return scaled_value, lower_bound
+
+    def _compute_kernels(input_height, input_width, output_height, output_width):
+        kernels = []
+        for y in (1, 2):
+            for x in (1, 2):
+                sv_h, lb_h = _compute_interpolation_values(y, input_height, output_height)
+                sv_w, lb_w = _compute_interpolation_values(x, input_width, output_width)
+
+                # Interpolation values calculated for (x, y) = ([1, 2], [1, 2]) will always generalize to the whole
+                # input for upscale = 2 and input sizes >= 2x2 and be in the correct order for going left-to-right,
+                # top-to-bottom - same as the depthwise convolution strides across each tile
+                kernel = np.zeros((2, 2))
+                kernel[1, 1] = (1 - (sv_h - lb_h)) * (1 - (sv_w - lb_w))
+                kernel[0, 1] = (sv_h - lb_h) * (1 - (sv_w - lb_w))
+                kernel[1, 0] = (1 - (sv_h - lb_h)) * (sv_w - lb_w)
+                kernel[0, 0] = (sv_h - lb_h) * (sv_w - lb_w)
+                kernel *= 16
+                kernels.append(kernel)
+
+        return kernels
+
+    def _build_convolutions(op, kernels):
+        dw_op_attrs = {
+            "padding": Padding.TILE,
+            "stride_h": 1,
+            "stride_w": 1,
+            "strides": (1, 1, 1, 1),
+            "depth_multiplier": 1,
+            "channel_multiplier": 1,
+            "dilation_h_factor": 1,
+            "dilation_w_factor": 1,
+            "dilation": (1, 1, 1, 1),
+        }
+        ifm = op.ifm
+        ofm = op.ofm
+        ofm.ops = []
+        elem_size = 2 if ofm.dtype == DataType.int16 else 1
+
+        n, h, w, c = ifm.shape
+        _, _, ow, _ = ofm.shape
+
+        intermediate_tens = Tensor(ifm.shape, ifm.dtype, "intermediate_tens")
+        intermediate_tens.quantization = op.outputs[0].quantization.clone()
+        avgpool_op = op
+        avgpool_op.name = "rb_init_avgpool"
+        avgpool_op.type = Op.AvgPool
+        avgpool_op.attrs["padding"] = Padding.VALID
+        avgpool_op.attrs["stride_w"] = 1
+        avgpool_op.attrs["stride_h"] = 1
+        avgpool_op.attrs["filter_width"] = 1
+        avgpool_op.attrs["filter_height"] = 1
+        avgpool_op.attrs["strides"] = [1, 1, 1, 1]
+        avgpool_op.attrs["ksize"] = [1, 1, 1, 1]
+
+        avgpool_op.add_input_tensor(ifm)
+        avgpool_op.set_output_tensor(intermediate_tens)
+        avgpool_op.set_ifm_ofm_shapes()
+
+        dw_conv = Operation(Op.DepthwiseConv2DBias, "depthwise_conv")
+        dw_conv._original_type = Op.ResizeBilinear
+        dw_conv.write_shape = Shape4D(n, h, w, c)
+        dw_conv.write_offset = Shape4D(0, 0, 0, 0)
+
+        # Set the output rounding mode. Resize bilinear requires rounding away from zero. Therefore, we need to
+        # adjust the accumulated value by a "small" amount before applying natural rounding. The "small" amount
+        # should be big enough to cause a x.5 to be rounded correctly but small enough not to cause smaller
+        # values to be incorrectly rounded
+        ofm.quantization.next_after = True
+        dw_conv.rounding_mode = NpuRoundingMode.NATURAL
+
+        # Double height and width stride to write the output of each of the four depthwise convolutions below
+        # interleaved with each other when combined with OFM tile base offsets.
+        dw_conv.ofm_stride_multiplier = [1, 2, 2]  # C/H/W
+
+        # Choose tile padding direction - pad by 1 with edge values in two directions.
+        # For example, TL (top left) will pad top and left in H/W-plane in all channels.
+        directions = [[1, 1, 0, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 0, 1, 1]]  # TL, TR, BL, BR
+        for i in (0, 1):
+            for j in (0, 1):
+                index = i * 2 + j
+                dw_conv.name = f"depthwise_conv_{index}"
+                dw_op_attrs["explicit_padding"] = directions[index]
+                dw_conv.attrs.update(dw_op_attrs)
+
+                # This will offset the start of the write by modifying the Tile 0 base address
+                dw_conv.tile_base_offsets_ofm[0] = (i * ow + j) * c * elem_size
+
+                ofm.ops.append(dw_conv)
+                dw_conv.outputs = [ofm]
+
+                kernel = kernels[index]
+                shape = [2, 2, 1, c]
+                kernel = np.dstack([kernel] * c)
+
+                quant = QuantizationParameters()
+                quant.zero_point = 0
+                quant.scale_f32 = 1.0 / 16
+
+                dw_conv.inputs = []
+                dw_conv.add_input_tensor(intermediate_tens)
+                dw_conv.add_input_tensor(
+                    create_const_tensor(
+                        "weights",
+                        shape,
+                        intermediate_tens.dtype,
+                        np.array(kernel).reshape(shape),
+                        value_dtype=np.int8,
+                        quantization=quant,
+                    ),
+                )
+
+                # set up bias tensor by assigning None and then call the fix-up function to create a suitable tensor.
+                # need to append the bias tensor as resize ops only have 2 inputs
+                assert len(dw_conv.inputs) == 2
+                dw_conv.inputs.append(None)
+                fixup_bias_tensors(dw_conv, None, None)
+
+                dw_conv.set_ifm_ofm_shapes()
+                dw_conv = dw_conv.clone(f"_{index}")
+        return op
+
+    _, input_height, input_width, _ = op.ifm.shape
+    _, output_height, output_width, _ = op.ofm.shape
+
+    kernels = _compute_kernels(input_height, input_width, output_height, output_width)
+    op = _build_convolutions(op, kernels)
+
+    return op
+
+
 def fixup_resize(op, arch, nng):
     if op.type.is_resize_op() and op.run_on_npu:
         if op.ifm_shapes[0] == op.ofm_shapes[0]:
@@ -472,6 +609,8 @@
             op.type = Op.Identity
         elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
             convert_resize_1x1_to_add(op)
+        elif op.type == Op.ResizeBilinear and op.attrs.get("half_pixel_centers", False):
+            convert_resizebilinear_to_depthwise_convolutions(op)
         else:
             convert_resize_to_upscale_and_average_pool(op)
 
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index be86e9a..9aa174d 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -255,6 +255,11 @@
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_attrs)
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_half_pixel_centers)
 
+        # Resize Bilinear specific checks:
+        self.specific_constraints[Op.ResizeBilinear].append(
+            TFLiteSupportedOperators.constraint_resizebi_half_pixel_centers_dims
+        )
+
         # Vector Product specific checks:
         for op_type in TFLiteSupportedOperators.fc_vector_products:
             self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_type)
@@ -602,8 +607,8 @@
         """The width and height of the IFM and OFM must match one of the following criteria:
         IFM W and H must both be 1
         IFM must match OFM
-        OFM W and H must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
-        OFM W and H must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False"""
+        W and H scaling must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
+        W and H scaling must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False"""
         # Easier to start with False condition as very few cases result in a supported resize
         valid = False
         ifm_shape = op.ifm.shape
@@ -661,11 +666,30 @@
 
     @staticmethod
     def constraint_resize_half_pixel_centers(op):
-        "half_pixel_centers are not supported"
-        valid = True
-        if op.attrs.get("half_pixel_centers", False):
+        """Half_pixel_centers are only supported for resize bilinear with IFM dtype int8 or uint8"""
+        valid = op.ifm.dtype in (DataType.int8, DataType.uint8)
+        half_pixel_centers = op.attrs.get("half_pixel_centers", False)
+        if half_pixel_centers and op.type != Op.ResizeBilinear:
             valid = False
-        return valid, f"Op has half_pixel_centers set to {not valid}."
+        return valid, f"Op type={op.type}, ifm dtype={op.ifm.dtype} and half_pixel_centers={half_pixel_centers}"
+
+    @staticmethod
+    def constraint_resizebi_half_pixel_centers_dims(op):
+        """Half_pixel_centers for resize bilinear requires that OFM W and H is 2x IFM W and H"""
+        half_pixel_centers = op.attrs.get("half_pixel_centers", False)
+        if not half_pixel_centers:
+            valid = True
+        elif len(op.ifm.shape) >= 3:
+            ifm_h, ifm_w = op.ifm.shape[-3:-1]
+            ofm_h, ofm_w = op.ofm.shape[-3:-1]
+            valid = ofm_h / ifm_h == 2 and ofm_w / ifm_w == 2
+        else:
+            # Unexpected IFM shape
+            valid = False
+        return (
+            valid,
+            f"Op has ifm_shape={op.ifm.shape}, ofm_shape={op.ofm.shape} and half_pixel_centers={half_pixel_centers}",
+        )
 
     @staticmethod
     def constraint_pad_shape(op):
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index db225fb..6f9467e 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 #
@@ -281,6 +281,12 @@
         else:
             quantised_scales = [quantise_scale(scale) for scale in scales]
 
+    # Check the output quantisation to see if the scale value needs increasing to the next one
+    if first_consumer_op.get_output_quantization().next_after:
+        for i, quant_scale in enumerate(quantised_scales):
+            q_scale, q_shift = quant_scale
+            quantised_scales[i] = (q_scale + 1, q_shift)
+
     # If only 1 quantised scale is used, repeat that value for the length of the biases
     if len(quantised_scales) == 1:
         quantised_scales = [quantised_scales[0]] * len(biases)