Add extra control flow ERROR_IF tests
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I7276dc686d8d18ba44663b73e35ceca2a1cbaadf
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a850699..c9d35c7 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -73,6 +73,9 @@
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
+ CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
+ CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
+ CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
class TosaErrorIfArgGen:
@@ -2191,6 +2194,47 @@
return info_dict
@staticmethod
+    def evCondIfCondNotMatchingBool(check=False, **kwargs):
+        """ERROR_IF validator: the COND_IF condition tensor must be BOOL.
+
+        Args:
+            check: when falsy, only the error metadata is produced and
+                "error_result" stays False; when truthy, kwargs["cond"]
+                is inspected.
+            **kwargs: must provide "cond" (an object exposing ``dtype``)
+                whenever ``check`` is truthy.
+
+        Returns:
+            dict with keys "error_name", "error_result", "error_reason"
+            and "param_reqs" (rank/dtype/shape all None). "error_result"
+            is True only when the condition dtype is not DType.BOOL.
+        """
+        name = ErrorIf.CondIfCondNotMatchingBool
+        requirements = {"rank": None, "dtype": None, "shape": None}
+        reason = "Conditional tensor does not match bool type"
+
+        # kwargs["cond"] is only read when a check is requested.
+        failed = False
+        if check and kwargs["cond"].dtype != DType.BOOL:
+            failed = True
+
+        return {
+            "error_name": name,
+            "error_result": failed,
+            "error_reason": reason,
+            "param_reqs": requirements,
+        }
+
+    @staticmethod
+    def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
+        """ERROR_IF validator: the COND_IF condition tensor must be rank 0.
+
+        Args:
+            check: when falsy, only the error metadata is produced and
+                "error_result" stays False; when truthy, kwargs["cond"]
+                is inspected.
+            **kwargs: must provide "cond" (an object exposing a ``shape``
+                sequence) whenever ``check`` is truthy.
+
+        Returns:
+            dict with keys "error_name", "error_result", "error_reason"
+            and "param_reqs" (rank/dtype/shape all None). "error_result"
+            is True when the condition shape is non-empty.
+        """
+        info = {
+            "error_name": ErrorIf.CondIfCondShapeNotSizeOne,
+            "error_result": False,
+            "error_reason": "Conditional tensor is not equal to a size of one",
+            "param_reqs": {"rank": None, "dtype": None, "shape": None},
+        }
+
+        if check:
+            # NOTE(review): this tests rank (an empty shape), which is
+            # stricter than "size of one" - a shape such as [1] also holds
+            # a single element yet is flagged here. Confirm this matches
+            # what the reference model rejects.
+            if len(kwargs["cond"].shape) != 0:
+                info["error_result"] = True
+
+        return info
+
+ @staticmethod
def evInputListOutputListMismatch(check=False, **kwargs):
error_name = ErrorIf.InputListOutputListMismatch
param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -2324,6 +2368,30 @@
}
return info_dict
+    @staticmethod
+    def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
+        """ERROR_IF validator: the cond graph's first output must be rank 0.
+
+        Args:
+            check: when falsy, only the error metadata is produced and
+                "error_result" stays False; when truthy,
+                kwargs["basicBlocks"] is inspected.
+            **kwargs: must provide "basicBlocks" whenever ``check`` is
+                truthy. Element 1 is treated as the cond graph; its first
+                ``outputs`` name is looked up in its ``tensors`` mapping.
+
+        Returns:
+            dict with keys "error_name", "error_result", "error_reason"
+            and "param_reqs" (rank/dtype/shape all None). "error_result"
+            is True when that output tensor has a non-empty shape.
+        """
+        error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_reason = "Cond graph output is not a shape of size one"
+
+        error_result = False
+        if check:
+            # NOTE(review): index 1 is assumed to be the cond graph - confirm
+            # the block ordering against the test builder that calls this.
+            cond_graph = kwargs["basicBlocks"][1]
+            outputs, tensors = cond_graph.outputs, cond_graph.tensors
+            # Only an empty (rank-0) shape is accepted.
+            error_result = len(tensors[outputs[0]].shape) != 0
+
+        return {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+
class TosaInvalidValidator:
@staticmethod