MLBEDSW-3694 Replace padding with enum
Use the Padding Enum (Padding.SAME / Padding.VALID) instead of a bytestring (b"SAME" / b"VALID") to specify VALID or SAME padding
Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Change-Id: I4e87f8c32b3bfac176d822a68de061e85a558fce
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index a1fcf6a..f2c2eb9 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -24,6 +24,7 @@
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
+from .operation import Padding
from .tensor import check_quantized_tens_scaling_equal
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
@@ -569,7 +570,7 @@
@staticmethod
def constraint_tconv_same(op):
"SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride"
- if op.attrs["padding"] == b"SAME":
+ if op.attrs["padding"] == Padding.SAME:
w = op.kernel.stride.x
h = op.kernel.stride.y
ifm_shape = op.ifm.shape
@@ -582,7 +583,7 @@
def constraint_tconv_valid(op):
"""VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,
minus difference between kernel size and stride"""
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
s_w = op.kernel.stride.x
s_h = op.kernel.stride.y
k_w = op.kernel.width
@@ -626,7 +627,7 @@
@docstring_format_args(filter_range)
def constraint_filter_range(cls, op):
"Kernel filter values for both width and height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"SAME":
+ if op.attrs["padding"] == Padding.SAME:
w = op.kernel.width
h = op.kernel.height
filter_min, filter_max = cls.filter_range
@@ -656,7 +657,7 @@
@docstring_format_args(filter_height_range)
def constraint_filter_height_range_valid_pad(op):
"VALID padding: Kernel filter height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
return SupportedOperators.constraint_filter_height_range(op)
return True, "Op has padding=SAME"
@@ -664,7 +665,7 @@
@docstring_format_args(filter_product_range)
def constraint_filter_product_range_valid_pad(op):
"VALID padding: Product of kernel filter width and height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
return SupportedOperators.constraint_filter_product_range(op)
return True, "Op has padding=SAME"