MLBEDSW-2654: Convert ResizeBilinear to a number of 2x2 pools
Signed-off-by: Charles Xu <charles.xu@arm.com>
Change-Id: Ida307afc33cd7963bdeb505df400732a3efcc846
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 65588bf..ab7f2db 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -15,6 +15,8 @@
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
+import numpy as np
+
from .data_type import BaseType
from .data_type import DataType
@@ -287,13 +289,15 @@
return True
if op.inputs[0].shape == op.outputs[0].shape:
return True
- upscaled_shape = [op.inputs[0].shape[1] * 2, op.inputs[0].shape[2] * 2]
- out_shape = op.outputs[0].shape[1:3]
- if not op.attrs["align_corners"] and out_shape != upscaled_shape:
- return False
- elif op.attrs["align_corners"] and out_shape != [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
- return False
- return True
+ upscaled_shape = np.array(op.inputs[0].shape[1:3])
+ out_shape = np.array(op.outputs[0].shape[1:3])
+ while (upscaled_shape < out_shape).all():
+ upscaled_shape *= 2
+ if op.attrs["align_corners"]:
+ upscaled_shape -= 1
+ if np.array_equal(out_shape, upscaled_shape):
+ return True
+ return False
def check_vector_product_restrictions(self, op):
# check data type