MLBEDSW-7654: Extend support for Mean where HxW > 4096
* Convert Means with large IFMs to several DepthwiseConv2DBias and Add
operations.
* Update tflite supported operator check with new height and width
constraints.
* Update unit-tests to verify supported operator changes.
* Fix output-diff for 2D IFMs (MLBEDSW-7772)
Signed-off-by: Alexander Hansson <Alexander.Hansson@arm.com>
Change-Id: Ifae6fb1cdac475ae7dac5116c5f13631ff82108a
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index f965d2b..92a7f3c 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -191,7 +191,10 @@
filter_range = (1, 8)
filter_height_range = (1, 256)
filter_product_range = (1, 256 * 256)
- mean_kernel_product = 64 * 64
+ mean_width_size = 64 * 64
+ mean_kernel_product_int8 = 2 ** (24)
+ mean_kernel_product_uint8 = 2 ** (23)
+ mean_kernel_product_int16 = 2 ** (16)
def __init__(self):
# Setup the generic constraints. Note: the order matters
@@ -311,7 +314,7 @@
# Mean specific checks:
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
- self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_single_axis)
+ self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_width)
# Reshape specific checks:
self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
@@ -838,14 +841,35 @@
return valid, f"Op has ifm_shape={ifm_shape} and ifm2_shape={ifm2_shape}"
@classmethod
- @docstring_format_args([mean_kernel_product])
+ @docstring_format_args([mean_kernel_product_int8, mean_kernel_product_uint8, mean_kernel_product_int16])
def constraint_mean_height_width_product(cls, op):
- """Product of height and width must be no greater than {}"""
+ """Product of height and width must be no greater than:
+ - {} for signed 8-bit inputs
+ - {} for unsigned 8-bit inputs
+ - {} for signed 16-bit inputs"""
shape = op.inputs[0].shape
hi = 0 if len(shape) < 4 else 1
h, w = shape[hi : hi + 2]
- max_prod = cls.mean_kernel_product
- return h * w <= max_prod, f"Product of height and width is {h * w}"
+ if op.ifm.dtype == DataType.int16:
+ max_prod = cls.mean_kernel_product_int16
+ datatype = "int16"
+ elif op.ifm.dtype == DataType.uint8:
+ max_prod = cls.mean_kernel_product_uint8
+ datatype = "uint8"
+ else:
+ max_prod = cls.mean_kernel_product_int8
+ datatype = "int8"
+ return h * w <= max_prod, f"Datatype is {datatype}, product of height and width is {h * w}"
+
+ @classmethod
+ @docstring_format_args([mean_width_size])
+ def constraint_mean_width(cls, op):
+ """Width must be no greater than {}"""
+ shape = op.inputs[0].shape
+ hi = 0 if len(shape) < 4 else 1
+ h, w = shape[hi : hi + 2]
+ max_width = cls.mean_width_size
+ return w <= max_width, f"Width is {w}"
@classmethod
@docstring_format_args([dilated_height_range[1]])