MLBEDSW-6454: Enable ReLu with negative alpha value
Removing constraint for negative alpha value in ReLu
for int8 and uint8.
Signed-off-by: Johan Alfven <johan.alfven@arm.com>
Change-Id: Id7a3a30bf5d1f0a591f990bd04cd0dbbad5819c6
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index 2d6ca15..e290dd2 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -413,12 +413,17 @@
def test_constraint_alpha_valid():
- # Alpha cannot be negative
- op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2])
+ # Alpha can only be negative for int8 and uint8
+ op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2], DataType.int16)
op.attrs["alpha"] = 0
assert semantic_checker.is_operator_semantic_valid(op)
op.attrs["alpha"] = -1
assert not semantic_checker.is_operator_semantic_valid(op)
+ op = testutil.create_elemwise_op(Op.LeakyRelu, "op", [2, 2], None, [2, 2], DataType.int8)
+ op.attrs["alpha"] = 0
+ assert semantic_checker.is_operator_semantic_valid(op)
+ op.attrs["alpha"] = -1
+ assert semantic_checker.is_operator_semantic_valid(op)
def test_constraint_hardswish_dtype():
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index c811a0d..e0541df 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -532,10 +532,11 @@
@staticmethod
def constraint_alpha_valid(op):
- "Alpha must not be negative"
+ "Alpha only allowed to be negative if IFM is int8 or uint8"
alpha = op.attrs["alpha"]
- valid = alpha >= 0
- return valid, f"Op has alpha={alpha}"
+ ifm_dtype = op.ifm.dtype
+ valid = ifm_dtype == DataType.int8 or ifm_dtype == DataType.uint8 or alpha >= 0
+ return valid, f"Op has alpha={alpha} and ifm_dtype={ifm_dtype} "
@staticmethod
def constraint_keep_dim_ifm_ofm(op):