MLBEDSW-2654: Convert ResizeBilinear to a chain of 2x2 pool ops

Signed-off-by: Charles Xu <charles.xu@arm.com>
Change-Id: Ida307afc33cd7963bdeb505df400732a3efcc846
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 8d920d8..9f92e75 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -210,14 +210,106 @@
     return op
 
 
+# Convert ResizeBilinear to a number of 2x2 pool ops
+def convert_resizebilinear_to_2x2_pool(op):
+    """Rewrite a ResizeBilinear op in place as a chain of 2x2 pool ops.
+
+    Each step of the chain maps a spatial size s to 2 * s - shape_modifier, where
+    shape_modifier is 1 if align_corners is set and 0 otherwise. op itself is the
+    first step; any further steps are clones of op linked through int16
+    intermediate tensors, and the last step writes op's original output tensors.
+
+    Returns op. Raises ValueError if the output height/width cannot be reached
+    from the input height/width by repeating that step.
+    """
+    count = 0
+    pre_op = op
+    outputs = op.outputs
+
+    op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
+    if op.attrs["align_corners"]:
+        shape_modifier = 1
+        op.attrs["padding"] = b"VALID"
+    else:
+        shape_modifier = 0
+        op.attrs["padding"] = b"SAME"
+    op.inputs[0].resampling_mode = resampling_mode.NEAREST
+
+    upscaled_shape = np.array(op.inputs[0].shape[1:3])
+    out_shape = np.array(op.outputs[0].shape[1:3])
+    # A step that leaves the shape unchanged (e.g. 1x1 with align_corners) can
+    # never reach the output and would make the loops below spin forever.
+    if (upscaled_shape == upscaled_shape * 2 - shape_modifier).all():
+        return op
+
+    # Both axes are stepped together, so check up front that repeating the step
+    # lands exactly on the output shape. Otherwise the chain would end on an
+    # intermediate tensor and leave the original output tensor without a producer.
+    reachable_shape = upscaled_shape
+    while (reachable_shape < out_shape).all():
+        reachable_shape = reachable_shape * 2 - shape_modifier
+    if not (reachable_shape == out_shape).all():
+        raise ValueError(
+            "{}: cannot upscale {} to {} with 2x2 pool steps".format(
+                outputs[0].name, op.inputs[0].shape, outputs[0].shape
+            )
+        )
+
+    while (upscaled_shape < out_shape).all():
+        if count == 0:
+            scaled_op = pre_op
+        else:
+            scaled_op = op.clone("_{}".format(count))
+            scaled_op.inputs[0] = pre_op.outputs[0]
+
+        upscaled_shape = upscaled_shape * 2 - shape_modifier
+
+        if (upscaled_shape == out_shape).all():
+            # Last step: write the original output tensor(s)
+            scaled_op.outputs = outputs
+            scaled_op.outputs[0].ops = [scaled_op]
+        else:
+            # Intermediate step: new int16 tensor. Name and quantisation come from
+            # the original output, not op.outputs (presumably repointed at step 0).
+            shape = outputs[0].shape.copy()
+            shape[1:3] = upscaled_shape[0:2]
+            out_tens = Tensor(shape, DataType.int16, "{}_{}".format(outputs[0].name, count))
+            out_tens.quantization = outputs[0].quantization.clone()
+            out_tens.quantization.quant_min = np.iinfo(np.int16).min
+            out_tens.quantization.quant_max = np.iinfo(np.int16).max
+            scaled_op.set_output_tensor(out_tens)
+            pre_op = scaled_op
+            count += 1
+
+        # Set up the scale value for int8 <-> int16 transitions. NOTE(review): the
+        # factor 128 presumably keeps extra fractional bits in the int16 intermediates.
+        if scaled_op.inputs[0].dtype.bits == 8 and scaled_op.outputs[0].dtype.bits == 16:
+            scaled_op.attrs["rescale"] = 128
+        elif scaled_op.inputs[0].dtype.bits == 16 and scaled_op.outputs[0].dtype.bits == 8:
+            scaled_op.attrs["rescale"] = 1 / 128
+        elif "rescale" in scaled_op.attrs:
+            # Clones presumably copy op's attrs; drop a rescale inherited from step 0
+            del scaled_op.attrs["rescale"]
+
+    return op
+
+
 def fixup_resizebilinear(op, arch):
-    if op.type == "ResizeBilinear":
-        if op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
-            convert_resizebilinear_1x1_to_add(op)
-        elif op.inputs[0].shape == op.outputs[0].shape:
+    """Lower a ResizeBilinear op that is mapped to the NPU (op.run_on_npu).
+
+    Equal input/output shapes become an Identity, a 1x1 spatial input is handled
+    by convert_resizebilinear_1x1_to_add and anything else is converted by
+    convert_resizebilinear_to_2x2_pool. Returns op.
+    """
+    if op.type == "ResizeBilinear" and op.run_on_npu:
+        if op.inputs[0].shape == op.outputs[0].shape:
             # Bypass nop resizebilinear
             op.inputs = op.inputs[:1]
             op.type = "Identity"
+        elif op.inputs[0].shape[1] == 1 and op.inputs[0].shape[2] == 1:
+            convert_resizebilinear_1x1_to_add(op)
+        else:
+            convert_resizebilinear_to_2x2_pool(op)
 
     return op
 
@@ -822,8 +878,6 @@
         fixup_pack_input,
         fixup_conv2d_backprop,
         fixup_act_reorder,
-        add_attrs_to_resizebilinear,
-        add_padding_fields,
         mark_npu_block_type,
         fixup_elementwise_with_scalars,
         reorder_depthwise_weights,
@@ -842,7 +896,7 @@
     for idx, sg in enumerate(nng.subgraphs):
         # remove passthrough tensors and attempt further optimizations
         nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
-            sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev]
+            sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields]
         )
 
     if verbose_graph: