Add extra control flow ERROR_IF tests

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I7276dc686d8d18ba44663b73e35ceca2a1cbaadf
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a850699..c9d35c7 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -73,6 +73,9 @@
     CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
     U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
     U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
+    CondIfCondNotMatchingBool = "CondIfCondNotMatchingBool"
+    CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
+    CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
 
 
 class TosaErrorIfArgGen:
@@ -2191,6 +2194,47 @@
         return info_dict
 
     @staticmethod
+    def evCondIfCondNotMatchingBool(check=False, **kwargs):
+        error_name = ErrorIf.CondIfCondNotMatchingBool
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Conditional tensor does not match bool type"
+
+        if check:
+            cond = kwargs["cond"]
+            if cond.dtype != DType.BOOL:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
+    @staticmethod
+    def evCondIfCondShapeNotSizeOne(check=False, **kwargs):
+        error_name = ErrorIf.CondIfCondShapeNotSizeOne
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Conditional tensor is not equal to a size of one"
+
+        if check:
+            cond = kwargs["cond"]
+            # Size of 1 is equivalent to rank 0
+            if len(cond.shape) != 0:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
+    @staticmethod
     def evInputListOutputListMismatch(check=False, **kwargs):
         error_name = ErrorIf.InputListOutputListMismatch
         param_reqs = {"rank": None, "dtype": None, "shape": None}
@@ -2324,6 +2368,30 @@
         }
         return info_dict
 
+    @staticmethod
+    def evCondGraphOutputShapeNotSizeOne(check=False, **kwargs):
+        error_name = ErrorIf.CondGraphOutputShapeNotSizeOne
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Cond graph output is not a shape of size one"
+
+        if check:
+            basicBlocks = kwargs["basicBlocks"]
+            cond_block = basicBlocks[1]
+            cond_outputs = cond_block.outputs
+            cond_tens = cond_block.tensors
+            # Size of 1 is equivalent to rank 0
+            if len(cond_tens[cond_outputs[0]].shape) != 0:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
 
 class TosaInvalidValidator:
     @staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index f3ca512..515e8bb 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -14,6 +14,7 @@
 from generator.tosa_error_if import TosaErrorValidator
 from generator.tosa_error_if import TosaInvalidValidator
 from generator.tosa_utils import DTYPE_ATTRIBUTES
+from generator.tosa_utils import get_wrong_output_type
 from generator.tosa_utils import MAX_RESIZE_DIMENSION
 from generator.tosa_utils import usableDTypes
 from generator.tosa_utils import vect_f32_to_bf16
@@ -1785,15 +1786,32 @@
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
 
+    def _get_condition_tensor(self, op, cond, error_name):
+        if error_name == ErrorIf.CondIfCondNotMatchingBool:
+            cond_type = get_wrong_output_type(op, self.rng, DType.BOOL)
+        else:
+            cond_type = DType.BOOL
+        if error_name == ErrorIf.CondIfCondShapeNotSizeOne:
+            choice = self.rng.choice([1, 2])
+            if choice == 1:
+                cond_shape = [2]
+            else:
+                cond_shape = [1, 2]
+        else:
+            # Must be of size 1 (rank 0)
+            cond_shape = []
+        cond_tens = self.ser.addConst(cond_shape, cond_type, [cond])
+        return cond_tens
+
     def build_cond_if_const(
         self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
     ):
         # For cond_if with constants, we're supplied with then/else tensors that we ignore
-        # (except for the generated shap) and the condition.  Build Then/Else blocks
+        # (except for the generated shape) and the condition.  Build Then/Else blocks
         # and fill them with const nodes for the body.
 
         # Condition tensor
-        cond_tens = self.ser.addConst([], DType.BOOL, [cond])
+        cond_tens = self._get_condition_tensor(op, cond, error_name)
 
         # Make then/else tensors
         out_shape = then_tens.shape
@@ -1848,6 +1866,7 @@
             error_name,
             op=op,
             basicBlocks=self.ser.basicBlocks,
+            cond=cond_tens,
         ):
             return None
 
@@ -1860,7 +1879,7 @@
         # alternately add or subtract them based on the condition
 
         # Condition tensor
-        cond_tens = self.ser.addConst([], DType.BOOL, [cond])
+        cond_tens = self._get_condition_tensor(op, cond, error_name)
 
         result_tens = self.ser.addOutput(a.shape, a.dtype)
 
@@ -1930,6 +1949,7 @@
             a=a,
             b=b,
             basicBlocks=self.ser.basicBlocks,
+            cond=cond_tens,
         ):
             return None
 
@@ -1997,11 +2017,18 @@
         zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
 
         if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
-            cond_tens = self.ser.addOutput(
-                [], self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
-            )
+            cond_type = self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
         else:
-            cond_tens = self.ser.addOutput([], DType.BOOL)
+            cond_type = DType.BOOL
+        if error_name == ErrorIf.CondGraphOutputShapeNotSizeOne:
+            choice = self.rng.choice([1, 2])
+            if choice == 1:
+                cond_shape = [3]
+            else:
+                cond_shape = [1, 2]
+        else:
+            cond_shape = []
+        cond_tens = self.ser.addOutput(cond_shape, cond_type)
 
         self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
 
@@ -3818,6 +3845,8 @@
             "error_if_validators": (
                 TosaErrorValidator.evOutputListThenGraphMismatch,
                 TosaErrorValidator.evOutputListElseGraphMismatch,
+                TosaErrorValidator.evCondIfCondNotMatchingBool,
+                TosaErrorValidator.evCondIfCondShapeNotSizeOne,
             ),
         },
         "cond_if_binary": {
@@ -3835,6 +3864,8 @@
                 TosaErrorValidator.evInputListElseGraphMismatch,
                 TosaErrorValidator.evOutputListThenGraphMismatch,
                 TosaErrorValidator.evOutputListElseGraphMismatch,
+                TosaErrorValidator.evCondIfCondNotMatchingBool,
+                TosaErrorValidator.evCondIfCondShapeNotSizeOne,
             ),
         },
         # while_loop
@@ -3854,6 +3885,7 @@
                 TosaErrorValidator.evInputListBodyGraphInputMismatch,
                 TosaErrorValidator.evInputListBodyGraphOutputMismatch,
                 TosaErrorValidator.evCondGraphOutputNotMatchingBool,
+                TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
             ),
         },
     }
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index d79ab3c..29ae898 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -142,6 +142,9 @@
                 DType.INT32,
                 DType.INT48,
             )
+    else:
+        # Assume all types but the input type are incorrect
+        incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
     return rng.choice(a=incorrect_types)