MLBEDSW-4157: Add RESIZE_NEAREST_NEIGHBOR support

 - Changed ResizeBilinear to support ResizeNearestNeighbor as well for
1x1 IFM, IFM equal OFM, and non-align corners
 - Added support for ResizeNearestNeighbor with align corners by
converting to a DepthwiseConv
 - Updated supported operator unit tests
 - Added is_resize() helper function and some associated refactoring

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Id5bdf2a25e8aa6a4f28b7236250abf768141ce37
diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py
index 399fd46..26ca291 100644
--- a/ethosu/vela/api.py
+++ b/ethosu/vela/api.py
@@ -374,7 +374,7 @@
     def __init__(self, pooling_op_type: NpuPoolingOp):
         super().__init__(NpuOperationType.Pooling)
         self.sub_op_type: NpuPoolingOp = pooling_op_type
-        # Set to a float value for ResizeBilinear operations (affects scaling), else to None
+        # Set to a float value for ResizeBilinear/NearestNeighbor operations (affects scaling), else to None
         self.rescale: Optional[float] = None
 
 
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index a52bdc3..7e13b62 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -85,7 +85,7 @@
     upscaling = 1
     if sched_op.op_type == Op.Conv2DBackpropInputSwitchedBias:
         upscaling = ofm_shape.height // ifm.shape.height
-    elif sched_op.op_type == Op.ResizeBilinear:
+    elif sched_op.op_type.is_resize_op():
         upscaling = round_up_divide(ofm_shape.height, ifm.shape.height)
 
     # Get kernel height and height dilation
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index e6bfc1c..2ce150f 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -129,7 +129,7 @@
 def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
     """Specifies type of rounding to be used"""
     rounding_mode = NpuRoundingMode.TFL
-    if op.type == Op.ResizeBilinear:
+    if op.type.is_resize_op():
         rounding_mode = NpuRoundingMode.NATURAL
     elif (
         op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
@@ -201,17 +201,6 @@
     return mem_limits
 
 
-def get_upscale(op: Operation) -> NpuResamplingMode:
-    upscale = NpuResamplingMode.NONE
-    if op.type == Op.ResizeBilinear:
-        # perform nearest neighbor upscale
-        upscale = NpuResamplingMode.NEAREST
-    elif op.type == Op.Conv2DBackpropInputSwitchedBias:
-        # perform insert zero upscale
-        upscale = NpuResamplingMode.TRANSPOSE
-    return upscale
-
-
 def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
     if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
         block = ifm_box.get_block()
@@ -224,7 +213,7 @@
     """Checks if quantization should use 0 as zero point"""
     if tens.dtype == DataType.int32 and is_ifm_tensor:
         return True
-    if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
+    if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
         return False
     if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
         return False
@@ -435,10 +424,9 @@
     """Converts the command to NpuPoolingOperation"""
     ps = cmd.ps
     op = ps.primary_op
-    pool_op = NpuPoolingOp.AVERAGE
     if op.type.is_maxpool_op():
         pool_op = NpuPoolingOp.MAX
-    elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
+    elif op.type.is_avgpool_op() or op.type.is_resize_op():
         pool_op = NpuPoolingOp.AVERAGE
     elif op.type == Op.ReduceSum:
         pool_op = NpuPoolingOp.REDUCE_SUM
@@ -485,18 +473,18 @@
     set_common_op_fields(npu_op, cmd, arch)
     # Check if output scale needs to be overridden
     output_scale = None
-    if op.type == Op.Add and "resizebilinear" in op.attrs:
+    if op.type == Op.Add and op.original_type.is_resize_op():
         # Force output scale same as the input scale for
-        # resizebilinear 1x1 that is converted to add
+        # resizebilinear/nearestneighbor 1x1 that is converted to add
         output_scale = npu_op.ifm2.quantization.scale_f32
-    if op.type == Op.Abs:
+    elif op.type == Op.Abs:
         output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
-    if op.type == Op.LeakyRelu:
+    elif op.type == Op.LeakyRelu:
         output_scale = op.attrs["alpha"]
-    if op.type in (Op.RescaleAdd, Op.RescaleMul):
+    elif op.type in (Op.RescaleAdd, Op.RescaleMul):
         assert op.rescale is not None, f"{op.type} must have rescale"
         npu_op.rescale = op.rescale
-    if op.type in (Op.Add, Op.Mul, Op.Sub):
+    elif op.type in (Op.Add, Op.Mul, Op.Sub):
         if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
             output_scale = 1 / 0x3000
     if output_scale is not None:
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index f3eace7..1a34d0e 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -248,8 +248,9 @@
     RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     RescaleMul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
     Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
+    # resize ops map to pooling operations unless explicitly converted to other operations in the graph optimiser
     ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
-    ResizeNearestNeighbor = OperatorInfo()
+    ResizeNearestNeighbor = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
     ReverseSequence = OperatorInfo()
     ReverseV2 = OperatorInfo()
     Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
@@ -364,6 +365,9 @@
     def is_concat_op(self):
         return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)
 
+    def is_resize_op(self):
+        return self in (Op.ResizeBilinear, Op.ResizeNearestNeighbor)
+
     def needs_bias(self):
         return bool(self.info.indices.biases)
 
@@ -467,6 +471,7 @@
 
     __slots__ = (
         "type",
+        "original_type",
         "name",
         "op_index",
         "attrs",
@@ -497,6 +502,7 @@
 
     def __init__(self, op_type: Op, name: str):
         self.type = op_type
+        self.original_type = op_type
         self.name = name
         self.attrs: Dict[str, Any] = {}
         self.inputs: List[Optional[Tensor]] = []
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 050b096..988e52e 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -61,10 +61,9 @@
         Op.AvgPool,
         Op.MaxPool,
         Op.ReduceSum,
-        # deconvolution
-        Op.ResizeBilinear,
     )
-)
+    # resize ops use pooling operations unless explicitly converted to other operations prior to pass packing
+) | Op.op_set(Op.is_resize_op)
 
 binary_elem_wise_main_ops = Op.op_set(Op.is_binary_elementwise_op)
 
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 12a36ca..a8d1ddf 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -706,7 +706,7 @@
             scale = explicit_scaling.multiplier[0]
             shift = explicit_scaling.shift[0]
         else:
-            # for ResizeBilinear operations with rescale
+            # for ResizeBilinear/NearestNeighbor operations with rescale
             rescale = pool_op.rescale
             rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
             scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index ab12e41..89c2799 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -306,84 +306,88 @@
     assert not support.is_operator_supported(op)
 
 
-def test_constraint_bilinear_resize():
-    # IFM W and H == 1
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 1, 1, 8], [1, 8, 8, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
-    assert support.is_operator_supported(op)
+def test_constraint_resize():
+    for resize_op in Op.op_set(Op.is_resize_op):
+        # IFM W and H == 1
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 1, 1, 8], [1, 8, 8, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        assert support.is_operator_supported(op)
 
-    # IFM == OFM
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 8, 8, 8], [1, 8, 8, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
-    assert support.is_operator_supported(op)
+        # IFM == OFM
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 8, 8, 8], [1, 8, 8, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        assert support.is_operator_supported(op)
 
-    # IFM x2 == OFM ; align_corners = False
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
-    assert support.is_operator_supported(op)
+        # IFM x2 == OFM ; align_corners = False
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        assert support.is_operator_supported(op)
 
-    # IFM x4 == OFM ; align_corners = False
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 16, 16, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [16, 16], np.int32))
-    assert support.is_operator_supported(op)
+        # IFM x4 == OFM ; align_corners = False
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 16, 16, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [16, 16], np.int32))
+        assert support.is_operator_supported(op)
 
-    # IFM x8 == OFM ; align_corners = False
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 32, 32, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [32, 32], np.int32))
-    assert support.is_operator_supported(op)
+        # IFM x8 == OFM ; align_corners = False
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 32, 32, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [32, 32], np.int32))
+        assert support.is_operator_supported(op)
 
-    # IFM -1 x2 == OFM -1 ; align_corners = True
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 7, 7, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
-    op.attrs["align_corners"] = True
-    assert support.is_operator_supported(op)
+        # IFM -1 x2 == OFM -1 ; align_corners = True
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 7, 7, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
+        op.attrs["align_corners"] = True
+        assert support.is_operator_supported(op)
 
-    # IFM -1 x4 == OFM -1 ; align_corners = True
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 13, 13, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [13, 13], np.int32))
-    op.attrs["align_corners"] = True
-    assert support.is_operator_supported(op)
+        # IFM -1 x4 == OFM -1 ; align_corners = True
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 13, 13, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [13, 13], np.int32))
+        op.attrs["align_corners"] = True
+        assert support.is_operator_supported(op)
 
-    # IFM -1 x8 == OFM -1 ; align_corners = True
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 25, 25, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [25, 25], np.int32))
-    op.attrs["align_corners"] = True
-    assert support.is_operator_supported(op)
+        # IFM -1 x8 == OFM -1 ; align_corners = True
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 25, 25, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [25, 25], np.int32))
+        op.attrs["align_corners"] = True
+        assert support.is_operator_supported(op)
 
-    # Invalid case - upscale size
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 17, 17, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [17, 17], np.int32))
-    assert not support.is_operator_supported(op)
+        # Invalid case - upscale size
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 17, 17, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [17, 17], np.int32))
+        assert not support.is_operator_supported(op)
 
-    # Invalid case - upscale size with align corners
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 15, 15, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [15, 15], np.int32))
-    op.attrs["align_corners"] = True
-    assert not support.is_operator_supported(op)
+        # Invalid case - upscale size with align corners
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 15, 15, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [15, 15], np.int32))
+        op.attrs["align_corners"] = True
+        assert not support.is_operator_supported(op)
 
 
-def test_constraint_bilinear_resize_size():
-    # Invalid case - size != ofm size
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
-    assert not support.is_operator_supported(op)
+def test_constraint_resize_size():
+    for resize_op in Op.op_set(Op.is_resize_op):
+        # Invalid case - size != ofm size
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
+        assert not support.is_operator_supported(op)
 
 
-def test_constraint_bilinear_resize_attrs():
-    # Invalid case - both align corners and half-pixel centers
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
-    op.attrs["align_corners"] = True
-    op.attrs["half_pixel_centers"] = True
-    assert not support.is_operator_supported(op)
+def test_constraint_resize_attrs():
+    for resize_op in Op.op_set(Op.is_resize_op):
+        # Invalid case - both align corners and half-pixel centers
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        op.attrs["align_corners"] = True
+        op.attrs["half_pixel_centers"] = True
+        assert not support.is_operator_supported(op)
 
 
-def test_constraint_bilinear_resize_hpc():
-    # Invalid case - half-pixel centers (not supported)
-    op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
-    op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
-    op.attrs["half_pixel_centers"] = True
-    assert not support.is_operator_supported(op)
+def test_constraint_resize_half_pixel_centers():
+    for resize_op in Op.op_set(Op.is_resize_op):
+        # Invalid case - half-pixel centers (not supported)
+        op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
+        op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+        op.attrs["half_pixel_centers"] = True
+        assert not support.is_operator_supported(op)
 
 
 def test_constraint_concat_pass():
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index d2899c4..ed8fa1e 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -279,10 +279,9 @@
 
 
 # Convert the op to an elementwise add
-def convert_resizebilinear_1x1_to_add(op):
-    op.type = Op.Add
+def convert_resize_1x1_to_add(op):
+    op.type = Op.Add  # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
     op.name = op.name + "_add"
-    op.attrs["resizebilinear"] = True
     # Create an input tensor filled with zeros
     shape = op.ofm_shapes[0].as_list()
     tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
@@ -301,12 +300,103 @@
     return op
 
 
-# Convert ResizeBilinear to a number of 2x2 nearest neighbor upscaling and one avgpool op with kernel size dependent
-# on the upscaling factor. Avgpool kernel limit of 8x8 when padding is applied limits upscaling to 8x8.
-def convert_resizebilinear_to_upscale_and_average_pool(op):
+# Convert ResizeNearestNeightbor with align corners to a depthwise convolution. The IFM will already have been upscaled
+# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
+# to select the appropriate nearest neighbor value
+def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
+    ifm = op.ifm
+    ofm = op.ofm
+    output_depth = ofm.shape[-1]
+    dw_op_attrs = {
+        "padding": Padding.VALID,
+        "stride_h": 1,
+        "stride_w": 1,
+        "strides": (1, 1, 1, 1),
+        "depth_multiplier": 1,
+        "channel_multiplier": 1,
+        "dilation_h_factor": 1,
+        "dilation_w_factor": 1,
+        "dilation": (1, 1, 1, 1),
+    }
+
+    # change resizebilinear to depthwise
+    op.type = Op.DepthwiseConv2DBias
+    op.attrs.update(dw_op_attrs)
+    op.set_input_tensor(ifm, 0)  # ifm tensor index
+    op.activation = None
+
+    # add input resample to resize by x2
+    op.ifm_resampling_mode = resampling_mode.NEAREST
+
+    # don't care about the rounding mode as it is nearest neighbor
+
+    # setup weight tensor
+    weight_quant = QuantizationParameters()
+    weight_quant.scale_f32 = 1.0  # no scaling as only a single non-zero coeff to select the desired value
+    weight_quant.zero_point = 0
+    weight_quant.quant_dim = 0
+    ofm_dtype = ofm.dtype
+    if ofm_dtype == DataType.uint8:
+        weight_value_dtype = np.uint8
+        weight_quant.quant_min = 0
+        weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
+    else:
+        if ofm_dtype == DataType.int8:
+            weight_value_dtype = np.int8
+        else:
+            assert ofm_dtype == DataType.int16
+            weight_value_dtype = np.int16
+
+        weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
+        weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
+
+    weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth]  # HWIO
+
+    # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
+    # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
+    # below-and-right (i.e. next) to it (D).
+    # 0---1---2
+    # | A | B |
+    # 1---*---+
+    # | C | D |
+    # 2---+---+
+    weight_values = [0] * (upscale_factor * upscale_factor)
+    centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
+    weight_values[centre_coeff] = 1
+
+    # add weight tensor, this will discard the size tensor of the resize op
+    op.set_input_tensor(
+        create_const_tensor(
+            "weights",
+            weight_shape,
+            ofm.dtype,
+            np.array(weight_values).reshape(weight_shape),
+            value_dtype=weight_value_dtype,
+            quantization=weight_quant,
+        ),
+        1,  # inputs tensor weight index
+    )
+
+    # setup bias tensor by assign None and then call the fix-up function to create a suitable tensor.
+    # need to append the bias tensor as resize ops only have 2 inputs
+    assert len(op.inputs) == 2
+    op.inputs.append(None)
+    fixup_bias_tensors(op, None, None)
+
+    # finally update the shape incase we've change the tensor shapes or connections
+    op.set_ifm_ofm_shapes()
+
+    return op
+
+
+# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
+# final average pool with a kernel size that depends upon the resize ops upscaling factor (x2, x4 or x8). The maximum
+# upscale factor is limited to x8 because of the limit 8x8 kernel size limit for average pool with padding.
+def convert_resize_to_upscale_and_average_pool(op):
     pre_op = op
     outputs = op.outputs
     dtype = op.ifm.dtype
+
     op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
     op.attrs["padding"] = Padding.SAME  # doesn't really matter as the kernel is 1x1
     op.ifm_resampling_mode = resampling_mode.NEAREST
@@ -321,14 +411,14 @@
     # between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
     n = int(np.log2(upscale_factor))
 
-    # Perform 2x2 upscaling n-1 times
+    # Perform x2 upscaling n-1 times
     scaled_op = pre_op
     for count in range(n - 1):
         if count > 0:
             scaled_op = op.clone(f"_{count}")
             scaled_op.inputs[0] = pre_op.outputs[0]
 
-        # Nearest neighbor 2x2 upscaling
+        # Nearest neighbor x2 upscaling
         upscaled_shape = upscaled_shape * 2
         shape = op.ofm_shapes[0].as_list()
         shape[1:3] = upscaled_shape
@@ -339,17 +429,30 @@
 
         scaled_op.set_ifm_ofm_shapes()
 
-    # Last 2x2 upscaling also applies avgpool with kernel size dependent on the upscaling factor and adds
-    # padding to the right and bottom.
+    # Last x2 upscaling
     if n > 1:
         scaled_op = op.clone(f"_{n-1}")
         scaled_op.inputs[0] = pre_op.outputs[0]
-    if op.attrs["align_corners"]:
-        scaled_op.attrs["padding"] = Padding.VALID
-    else:
-        scaled_op.attrs["padding"] = Padding.EXPLICIT
-        scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
-    scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
+
+    if scaled_op.original_type == Op.ResizeBilinear:
+        if scaled_op.attrs["align_corners"]:
+            # no padding
+            scaled_op.attrs["padding"] = Padding.VALID
+        else:
+            # padding to the right and bottom (limits average pool to 8x8 kernel)
+            scaled_op.attrs["padding"] = Padding.EXPLICIT
+            scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
+
+        # kernal size dependent on the upscaling factor
+        scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
+    else:  # Op.ResizeNearestNeighbor
+        if scaled_op.attrs["align_corners"]:
+            # use depthwise conv to select the correct value
+            scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
+        else:
+            # keep 1x1 kernel and average pool
+            pass
+
     scaled_op.outputs = outputs
     scaled_op.outputs[0].ops = [scaled_op]
     scaled_op.set_ifm_ofm_shapes()
@@ -357,16 +460,16 @@
     return op
 
 
-def fixup_resizebilinear(op, arch, nng):
-    if op.type == Op.ResizeBilinear and op.run_on_npu:
+def fixup_resize(op, arch, nng):
+    if op.type.is_resize_op() and op.run_on_npu:
         if op.ifm_shapes[0] == op.ofm_shapes[0]:
-            # Bypass nop resizebilinear
+            # Bypass the resize op which is essentially a NOP
             op.inputs = op.inputs[:1]
             op.type = Op.Identity
         elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
-            convert_resizebilinear_1x1_to_add(op)
+            convert_resize_1x1_to_add(op)
         else:
-            convert_resizebilinear_to_upscale_and_average_pool(op)
+            convert_resize_to_upscale_and_average_pool(op)
 
     return op
 
@@ -1130,31 +1233,6 @@
     return avgpool_op
 
 
-def add_attrs_to_resizebilinear(op, arch, nng):
-    if op.type == Op.ResizeBilinear and op.run_on_npu:
-        input_shape = op.ifm_shapes[0]
-        upscaled_height = input_shape.height * 2
-        upscaled_width = input_shape.width * 2
-        out_shape = op.ofm_shapes[0]
-        if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
-            # this means the output is supposed to be a x2 upscale,
-            # so we need to do SAME padding
-            op.attrs["padding"] = Padding.SAME
-        elif (
-            op.attrs["align_corners"]
-            and out_shape.height == (upscaled_height - 1)
-            and out_shape.width == (upscaled_width - 1)
-        ):
-            # here we can just run the avg pool without padding and
-            # produce a (M * 2 - 1, N * 2 - 1) sized output
-            op.attrs["padding"] = Padding.VALID
-        else:
-            return op
-        op.ifm_resampling_mode = resampling_mode.NEAREST
-        op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
-    return op
-
-
 def fixup_bias_tensors(op, arch, nng):
     if op.type.needs_bias() and op.bias is None:
         # Op has no bias, add bias tensor filled with zeros
@@ -1577,7 +1655,7 @@
         fixup_conv2d_backprop,
         fixup_relus_with_differing_ifm_ofm_scaling,
         reorder_depthwise_weights,
-        fixup_resizebilinear,
+        fixup_resize,
         fixup_bias_tensors,
         fixup_asymmetric_weights,
         convert_mul_max_to_abs_or_lrelu,
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index bf155b9..39b08b9 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -799,7 +799,7 @@
     BuiltinOperator.RESIZE_NEAREST_NEIGHBOR: (
         Op.ResizeNearestNeighbor,
         OptionsSerializer("ResizeNearestNeighborOptions", ("align_corners", "half_pixel_centers")),
-        TFLITE_NO_INDICES,
+        TFLITE_IFM_INDICES,
     ),
     BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",)), TFLITE_IFM_INDICES),
     BuiltinOperator.SQUARED_DIFFERENCE: (
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 01d2e61..90d93d0 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -58,7 +58,7 @@
     max_pooling_ops = Op.op_set(Op.is_maxpool_op)
     avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
     pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
-    resizing_ops = set((Op.ResizeBilinear,))
+    resizing_ops = Op.op_set(Op.is_resize_op)
     fc_vector_products = set(
         (
             Op.QuantizedMatMul,
@@ -242,10 +242,10 @@
 
         # Resizing specific checks:
         for op_type in TFLiteSupportedOperators.resizing_ops:
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bilinear_resize)
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bilinear_resize_size)
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bilinear_resize_attrs)
-            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bilinear_resize_hpc)
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize)
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_size)
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_attrs)
+            self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_half_pixel_centers)
 
         # Vector Product specific checks:
         for op_type in TFLiteSupportedOperators.fc_vector_products:
@@ -589,7 +589,7 @@
         return True, "Op has padding=SAME"
 
     @staticmethod
-    def constraint_bilinear_resize(op):
+    def constraint_resize(op):
         """The width and height of the IFM and OFM must match one of the following criteria:
         IFM W and H must both be 1
         IFM must match OFM
@@ -625,7 +625,7 @@
         return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and align_corners={align_corners}"
 
     @staticmethod
-    def constraint_bilinear_resize_size(op):
+    def constraint_resize_size(op):
         "The size tensor must match the output tensor shape"
         valid = False
         ofm_shape = op.ofm.shape
@@ -640,7 +640,7 @@
         return valid, f"Op has size={size_h}x{size_w} and ofm_shape={ofm_shape}."
 
     @staticmethod
-    def constraint_bilinear_resize_attrs(op):
+    def constraint_resize_attrs(op):
         "Both align_corners and half_pixel_centers can't be True"
         valid = True
         align_corners = op.attrs.get("align_corners", False)
@@ -651,7 +651,7 @@
         return valid, "Op has both align_corners and half_pixel_centers set to True."
 
     @staticmethod
-    def constraint_bilinear_resize_hpc(op):
+    def constraint_resize_half_pixel_centers(op):
         "half_pixel_centers are not supported"
         valid = True
         if op.attrs.get("half_pixel_centers", False):