Add negative testing for cond_if and while_loop

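Add ERROR_IF validators and shape/dtype corruption paths covering:
 - cond_if: input/output list shapes that do not match the then/else graphs
 - while_loop: input list vs output list mismatch, cond graph and body graph
   shape mismatches, and a non-boolean cond graph output
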
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Signed-off-by: Les Bell <les.bell@arm.com>
Change-Id: Ie6c8c8653874f9eed6007a54a3ad526601a4a669
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index cd59898..4e944ea 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -29,6 +29,7 @@
 import traceback
 import math
 import itertools
+from copy import deepcopy
 
 from enum import IntEnum, Enum, unique
 from tosa_ref_run import TosaReturnCode
@@ -48,6 +49,13 @@
 Op = tosa.Op.Op()
 ResizeMode = tosa.ResizeMode.ResizeMode()
 
+
+def product(shape):
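+    """Return the total number of elements for a shape (the product of its dimensions)."""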
+    value = 1
+    for n in shape:
+        value *= n
+    return value
+
 class TosaQuantGen:
     """QuantizedInfo random generator helper functions.  Specify with 'qgen': in the operator defintion"""
 
@@ -185,8 +193,9 @@
         pl, const = opName["operands"]
         shape = testGen.makeShape(rank)
 
-        # Constrict dimension size for large ranks when creating WrongRank tests
-        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)
+        # Constrict the overall size of the shape when creating ERROR_IF tests
+        if error_name:
+            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
 
         shape_list = []
         for i in range(pl + const):
@@ -213,8 +222,9 @@
         if testGen.args.max_batch_size:
             shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
 
-        # Constrict dimension size for large ranks when creating WrongRank tests
-        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)
+        # Constrict the overall size of the shape when creating ERROR_IF tests
+        if error_name:
+            shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
 
         shape_list = []
         for i in range(pl + const):
@@ -404,8 +414,9 @@
 
         input_shape = testGen.makeShape(rank)
 
-        # Constrict dimension size for large ranks when creating WrongRank tests
-        shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)
+        # Constrict the overall size of the shape when creating ERROR_IF tests
+        if error_name:
+            input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
 
         filter_oc = testGen.rng.integers(
             low=testGen.args.tensor_shape_range[0],
@@ -428,8 +439,9 @@
 
         a_shape = testGen.makeShape(rank)
 
-        # Constrict dimension size for large ranks when creating WrongRank tests
-        shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)
+        # Constrict the overall size of the shape when creating ERROR_IF tests
+        if error_name:
+            a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
 
         # Get a random number for b_oc even if target shape is defined
         b_oc = np.int32(
@@ -1405,17 +1417,13 @@
                 output_list = []
         return input_list, output_list
 
-
     @staticmethod
-    def eiRestrictDimension(shape, error_name):
-        # Restrict dimension size if rank is large for WrongRank Error_If
-        # This will keep the test sizes reasonably small
-        if error_name == ErrorIf.WrongRank:
-            if len(shape) > 4:
-                shape[4] = 1
-
-        return shape
-
+    def eiRestrictDimensions(shape, max_dim=32, max_items=100000):
+        """Restrict the dimensions and overall size of a shape to max_dim and max_items."""
+        new_shape = [min(d, max_dim) for d in shape] if max(shape) > max_dim else shape
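+        # Reduce all dimensions together (never below 1) until the element count is at most max_items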
+        while product(new_shape) > max_items:
+            new_shape = [max(d - 1, 1) for d in new_shape]
+        return new_shape
 
     def eiSliceErrorIf(testGen, error_name, input_shape, start, size):
         if error_name == ErrorIf.StartSmallerZero:
@@ -2705,6 +2713,243 @@
         }
         return info_dict
 
+    @staticmethod
+    def evInputListThenGraphMismatch(check=False, **kwargs):
+        error_name = ErrorIf.CondIfInputListThenGraphMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input list shape does not match then-graph shape"
+
+        if check:
+            a = kwargs['a']
+            b = kwargs['b']
+            basicBlocks = kwargs['basicBlocks']
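+            # basicBlocks is [main graph containing cond_if, then graph, else graph]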
+            then_block = basicBlocks[1]
+            then_inputs = then_block.inputs
+            then_tens = then_block.tensors
+            if (a.shape != then_tens[then_inputs[0]].shape) or (b.shape != then_tens[then_inputs[1]].shape):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evInputListElseGraphMismatch(check=False, **kwargs):
+        error_name = ErrorIf.CondIfInputListElseGraphMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input list shape does not match else-graph shape"
+
+        if check:
+            a = kwargs['a']
+            b = kwargs['b']
+            basicBlocks = kwargs['basicBlocks']
+            else_block = basicBlocks[2]
+            else_inputs = else_block.inputs
+            else_tens = else_block.tensors
+            if (a.shape != else_tens[else_inputs[0]].shape) or (b.shape != else_tens[else_inputs[1]].shape):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evOutputListThenGraphMismatch(check=False, **kwargs):
+        error_name = ErrorIf.CondIfOutputListThenGraphMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Output list shape does not match then-graph shape"
+
+        if check:
+            basicBlocks = kwargs['basicBlocks']
+            cond_block = basicBlocks[0]
+            cond_outputs = cond_block.outputs
+            cond_tens = cond_block.tensors
+            then_block = basicBlocks[1]
+            then_outputs = then_block.outputs
+            then_tens = then_block.tensors
+            if then_tens[then_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evOutputListElseGraphMismatch(check=False, **kwargs):
+        error_name = ErrorIf.CondIfOutputListElseGraphMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Output list shape does not match else-graph shape"
+
+        if check:
+            basicBlocks = kwargs['basicBlocks']
+            cond_block = basicBlocks[0]
+            cond_outputs = cond_block.outputs
+            cond_tens = cond_block.tensors
+            else_block = basicBlocks[2]
+            else_outputs = else_block.outputs
+            else_tens = else_block.tensors
+            if else_tens[else_outputs[0]].shape != cond_tens[cond_outputs[0]].shape:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evInputListOutputListMismatch(check=False, **kwargs):
+        error_name = ErrorIf.InputListOutputListMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input list does not match output list"
+
+        if check:
+            basicBlocks = kwargs['basicBlocks']
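+            # basicBlocks is [main graph containing while_loop, cond graph, body graph]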
+            while_block = basicBlocks[0]
+            while_inputs = while_block.inputs
+            while_outputs = while_block.outputs
+            while_tens = while_block.tensors
+            if while_tens[while_inputs[1]].shape != while_tens[while_outputs[0]].shape:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evInputListCondGraphMismatch(check=False, **kwargs):
+        error_name = ErrorIf.InputListCondGraphMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input list does not match cond graph"
+
+        if check:
+            basicBlocks = kwargs['basicBlocks']
+            while_block = basicBlocks[0]
+            while_inputs = while_block.inputs
+            while_tens = while_block.tensors
+            cond_block = basicBlocks[1]
+            cond_inputs = cond_block.inputs
+            cond_tens = cond_block.tensors
+            if ((while_tens[while_inputs[0]].shape != cond_tens[cond_inputs[0]].shape) or
+                (while_tens[while_inputs[1]].shape != cond_tens[cond_inputs[2]].shape)):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evInputListBodyGraphInputMismatch(check=False, **kwargs):
+        error_name = ErrorIf.InputListBodyGraphInputMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input list does not match body graph input"
+
+        if check:
+            basicBlocks = kwargs['basicBlocks']
+            while_block = basicBlocks[0]
+            while_inputs = while_block.inputs
+            while_tens = while_block.tensors
+            body_block = basicBlocks[2]
+            body_inputs = body_block.inputs
+            body_tens = body_block.tensors
+            if ((while_tens[while_inputs[0]].shape != body_tens[body_inputs[0]].shape) or
+                (while_tens[while_inputs[1]].shape != body_tens[body_inputs[2]].shape)):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evInputListBodyGraphOutputMismatch(check=False, **kwargs):
+        error_name = ErrorIf.InputListBodyGraphOutputMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Input list does not match body graph output"
+
+        if check:
+            basicBlocks = kwargs['basicBlocks']
+            while_block = basicBlocks[0]
+            while_inputs = while_block.inputs
+            while_tens = while_block.tensors
+            body_block = basicBlocks[2]
+            body_outputs = body_block.outputs
+            body_tens = body_block.tensors
+            if ((while_tens[while_inputs[0]].shape != body_tens[body_outputs[0]].shape) or
+                (while_tens[while_inputs[1]].shape != body_tens[body_outputs[2]].shape)):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evCondGraphOutputNotMatchingBool(check=False, **kwargs):
+        error_name = ErrorIf.CondGraphOutputNotMatchingBool
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Cond graph output is not a boolean tensor"
+
+        if check:
+            basicBlocks = kwargs['basicBlocks']
+            cond_block = basicBlocks[1]
+            cond_outputs = cond_block.outputs
+            cond_tens = cond_block.tensors
+            if cond_tens[cond_outputs[0]].dtype != DType.BOOL:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
 
 class TosaInvalidValidator:
 
@@ -4131,7 +4376,7 @@
         self.ser.addOperator(op['op'], input_list, output_list, attr)
         return result_tens
 
-    def build_cond_if_const(self, op, then_tens, else_tens, cond):
+    def build_cond_if_const(self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None):
         # For cond_if with constants, we're supplied with then/else tensors that we ignore
         # (except for the generated shap) and the condition.  Build Then/Else blocks
         # and fill them with const nodes for the body.
@@ -4141,6 +4386,14 @@
 
         # Make then/else tensors
         out_shape = then_tens.shape
+
+        # Create an incorrect output shape for error_if tests
+        if error_name in [ErrorIf.CondIfOutputListThenGraphMismatch, ErrorIf.CondIfOutputListElseGraphMismatch]:
+            incorrect_shape = deepcopy(then_tens.shape)
+            for i in range(len(incorrect_shape)):
+                # Keep each perturbed dimension positive so the const data can still be generated
+                choices = [2, 3] if incorrect_shape[i] <= 3 else [-3, -2, 2, 3]
+                incorrect_shape[i] = incorrect_shape[i] + self.rng.choice(choices)
+            incorrect_arr = np.int32(self.rng.integers(0, 256, size=incorrect_shape))
+
         then_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
         else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
 
@@ -4158,16 +4411,30 @@
 
         self.ser.startBasicBlock(then_block)
         # Build the actual then/else tensors inside their blocks
-        then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
+        if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
+            then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
+        else:
+            then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
         self.ser.addOutputTensor(then_tens)
 
         self.ser.startBasicBlock(else_block)
-        else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
+        if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
+            else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
+        else:
+            else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
         self.ser.addOutputTensor(else_tens)
 
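+        # Check the serialized basic blocks against any ERROR_IF validators for this test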
+        TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            basicBlocks=self.ser.basicBlocks
+        )
+
         return result_tens
 
-    def build_cond_if_binary(self, op, a, b, cond):
+    def build_cond_if_binary(self, op, a, b, cond, validator_fcns=None, error_name=None):
         # For cond_if with a binary op in the then/else blocks, take a and b and
         # alternately add or subtract them based on the condition
 
@@ -4182,6 +4449,15 @@
         attr = ts.TosaSerializerAttribute()
         attr.CondIfAttribute(then_block, else_block)
 
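+        # Create a copy of input a with a deliberately mismatched shape, used below to
+        # corrupt the then/else block's input or output tensors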
+        if error_name in [ErrorIf.CondIfInputListThenGraphMismatch, ErrorIf.CondIfInputListElseGraphMismatch,
+                          ErrorIf.CondIfOutputListElseGraphMismatch, ErrorIf.CondIfOutputListThenGraphMismatch]:
+            incorrect_shape = a.shape.copy()
+            for i in range(len(incorrect_shape)):
+                incorrect_shape[i] += self.rng.choice([-3, -2, 2, 3])
+            incorrect_block_input = deepcopy(a)
+            incorrect_block_input.shape = incorrect_shape
+
         # Finally, build the op and the two blocks
         self.ser.addOperator(
             op['op'], [cond_tens.name, a.name, b.name], [result_tens.name], attr
@@ -4196,14 +4472,35 @@
 
         for block, op in ((then_block, then_op), (else_block, else_op)):
             self.ser.startBasicBlock(block)
-            self.ser.addInputTensor(a)
-            self.ser.addInputTensor(b)
-            tens = self.ser.addOutput(a.shape, a.dtype)
+            if ((error_name == ErrorIf.CondIfInputListThenGraphMismatch and block == then_block) or
+                (error_name == ErrorIf.CondIfInputListElseGraphMismatch and block == else_block)):
+                self.ser.addInputTensor(incorrect_block_input)
+                self.ser.addInputTensor(b)
+                tens = self.ser.addOutput(a.shape, a.dtype)
+            elif ((error_name == ErrorIf.CondIfOutputListThenGraphMismatch and block == then_block) or
+                (error_name == ErrorIf.CondIfOutputListElseGraphMismatch and block == else_block)):
+                self.ser.addInputTensor(a)
+                self.ser.addInputTensor(b)
+                tens = self.ser.addOutput(incorrect_block_input.shape, a.dtype)
+            else:
+                self.ser.addInputTensor(a)
+                self.ser.addInputTensor(b)
+                tens = self.ser.addOutput(a.shape, a.dtype)
             self.ser.addOperator(op, [a.name, b.name], [tens.name])
 
+        TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            a=a,
+            b=b,
+            basicBlocks=self.ser.basicBlocks
+        )
+
         return result_tens
 
-    def build_while_loop(self, op, a, iter_val):
+    def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
         iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
 
         cond_block = "COND_BLOCK"
@@ -4220,7 +4517,13 @@
         # Intermediate/output tensors for everything going through the loop
         iter_out = self.ser.addIntermediate(iter.shape, iter.dtype)
         a_out = self.ser.addIntermediate(a.shape, a.dtype)
-        acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
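+        # For InputListOutputListMismatch, register the while_loop output with a shape
+        # that differs from the accumulator in the input list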
+        if error_name == ErrorIf.InputListOutputListMismatch:
+            incorrect_acc = deepcopy(acc)
+            for i in range(len(incorrect_acc.shape)):
+                incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
+            acc_out = self.ser.addIntermediate(incorrect_acc.shape, acc.dtype)
+        else:
+            acc_out = self.ser.addIntermediate(acc.shape, acc.dtype)
 
         # While_loop operator
         self.ser.addOperator(
@@ -4231,30 +4534,71 @@
         )
         self.ser.addOutputTensor(acc_out)
 
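+        # Build copies of iter and acc with deliberately mismatched shapes for the
+        # cond/body graph mismatch tests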
+        if error_name in [ErrorIf.InputListCondGraphMismatch, ErrorIf.InputListBodyGraphInputMismatch, ErrorIf.InputListBodyGraphOutputMismatch]:
+            incorrect_iter = deepcopy(iter)
+            for i in range(len(incorrect_iter.shape)):
+                incorrect_iter.shape[i] += self.rng.choice([-3, -2, 2, 3])
+            if len(incorrect_iter.shape) == 0:
+                incorrect_iter.shape.append(self.rng.choice([-3, -2, 2, 3]))
+
+            incorrect_acc = deepcopy(acc)
+            for i in range(len(incorrect_acc.shape)):
+                incorrect_acc.shape[i] += self.rng.choice([-3, -2, 2, 3])
+
         # COND block (input: iter, output: cond_tens )
         self.ser.startBasicBlock(cond_block)
-        self.ser.addInputTensor(iter)
-        self.ser.addInputTensor(a)
-        self.ser.addInputTensor(acc)
+        if error_name == ErrorIf.InputListCondGraphMismatch:
+            self.ser.addInputTensor(incorrect_iter)
+            self.ser.addInputTensor(a)
+            self.ser.addInputTensor(incorrect_acc)
+        else:
+            self.ser.addInputTensor(iter)
+            self.ser.addInputTensor(a)
+            self.ser.addInputTensor(acc)
         zero_tens = self.ser.addConst([], DType.INT32, [np.int32(0)])
-        cond_tens = self.ser.addOutput([], DType.BOOL)
+
+        if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
+            cond_tens = self.ser.addOutput([], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT]))
+        else:
+            cond_tens = self.ser.addOutput([], DType.BOOL)
+
         self.ser.addOperator(Op.GREATER, [iter.name, zero_tens.name], [cond_tens.name])
 
         # BODY block (input: a, acc, iter, output: a, acc, iter)
         # Note that local intermediate tensors need to be declared here for the outputs
         self.ser.startBasicBlock(body_block)
-        self.ser.addInputTensor(iter)
-        self.ser.addInputTensor(a)
-        self.ser.addInputTensor(acc)
+        if error_name == ErrorIf.InputListBodyGraphInputMismatch:
+            self.ser.addInputTensor(incorrect_iter)
+            self.ser.addInputTensor(a)
+            self.ser.addInputTensor(incorrect_acc)
+        else:
+            self.ser.addInputTensor(iter)
+            self.ser.addInputTensor(a)
+            self.ser.addInputTensor(acc)
+
         one_tens = self.ser.addConst([], DType.INT32, [np.int32(1)])
-        iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
-        acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
+
+        if error_name == ErrorIf.InputListBodyGraphOutputMismatch:
+            iter_body_out = self.ser.addIntermediate(incorrect_iter.shape, incorrect_iter.dtype)
+            acc_body_out = self.ser.addIntermediate(incorrect_acc.shape, incorrect_acc.dtype)
+        else:
+            iter_body_out = self.ser.addIntermediate(iter.shape, iter.dtype)
+            acc_body_out = self.ser.addIntermediate(acc.shape, acc.dtype)
+
         self.ser.addOperator(Op.ADD, [a.name, acc.name], [acc_body_out.name])
         self.ser.addOperator(Op.SUB, [iter.name, one_tens.name], [iter_body_out.name])
         self.ser.addOutputTensor(iter_body_out)
         self.ser.addOutputTensor(a)
         self.ser.addOutputTensor(acc_body_out)
 
+        TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            basicBlocks=self.ser.basicBlocks
+        )
+
         return acc_out
 
     def create_filter_lists(self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None):
@@ -5445,6 +5789,7 @@
                 TosaArgGen.agCondIf,
             ),
             "types": [DType.BOOL],
+            "error_if_validators": (
+                TosaErrorValidator.evOutputListThenGraphMismatch,
+                TosaErrorValidator.evOutputListElseGraphMismatch,
+            ),
         },
         "cond_if_binary": {
             "op": Op.COND_IF,
@@ -5455,6 +5800,8 @@
                 TosaArgGen.agCondIf,
             ),
             "types": TYPE_INT_FP,
+            "error_if_validators": (
+                TosaErrorValidator.evInputListThenGraphMismatch,
+                TosaErrorValidator.evInputListElseGraphMismatch,
+                TosaErrorValidator.evOutputListThenGraphMismatch,
+                TosaErrorValidator.evOutputListElseGraphMismatch,
+            ),
         },
         # while_loop
         "while_loop": {
@@ -5466,6 +5813,9 @@
                 TosaArgGen.agWhileLoop,
             ),
             "types": [DType.INT32],
+            "error_if_validators": (
+                TosaErrorValidator.evInputListOutputListMismatch,
+                TosaErrorValidator.evInputListCondGraphMismatch,
+                TosaErrorValidator.evInputListBodyGraphInputMismatch,
+                TosaErrorValidator.evInputListBodyGraphOutputMismatch,
+                TosaErrorValidator.evCondGraphOutputNotMatchingBool,
+            ),
         },
     }