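Do not generate tests that fail validation checks

evValidateErrorIfs now returns False when any ERROR_IF validator
disagrees with the expected result. The build functions return None
in that case. serializeTest then serializes only tests that produced
a result, and prints the op name and test string of any invalid
ERROR_IF test instead of saving it.
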
Change-Id: I33237ebfd946b9ec91352c2b0dc6298cc113cd77
Signed-off-by: Les Bell <les.bell@arm.com>
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 22886d6..655cdfc 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -1565,7 +1565,17 @@
 
     @staticmethod
     def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
-        # Check ERROR_IF statements
+        """Check ERROR_IF statements are caught and set the expected result.
+
+        Args:
+            serializer: the serializer to set the expected result in
+            validator_fcns: a sequence of validator functions to verify the result
+            error_name: the name of the ERROR_IF condition to check for
+            kwargs: keyword arguments for the validator functions
+        Returns:
+            True if the result matches the expected result; otherwise False
+        """
+        overall_result = True
         for val_fcn in validator_fcns:
             val_result = val_fcn(True, **kwargs)
             validator_name = val_result['error_name']
@@ -1574,6 +1584,7 @@
 
             # expect an error IFF the error_name and validator_name match
             expected_result = error_result == (error_name == validator_name)
+            overall_result &= expected_result
 
             if expected_result and error_result:
                 serializer.setExpectedReturnCode(2, error_reason)
@@ -1591,6 +1602,8 @@
                             v = valueToName(DType, v)
                         print(f'  {k} = {v}')
 
+        return overall_result
+
     @staticmethod
     def evWrongInputType(check=False, **kwargs):
         error_result = False
@@ -3447,7 +3460,7 @@
         elif t == DType.BOOL:
             return 1
         else:
-            raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
+            raise Exception(f"Unknown dtype, cannot determine width: {t}")
 
     # Argument generators
     # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
@@ -3481,7 +3494,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3493,7 +3506,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
         return result_tens
@@ -3509,7 +3523,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3522,7 +3536,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list)
         return result_tens
@@ -3542,7 +3557,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3555,7 +3570,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.ArithmeticRightShiftAttribute(round)
@@ -3582,7 +3598,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3595,7 +3611,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.MulAttribute(shift)
@@ -3616,7 +3633,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3628,7 +3645,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list, attr)
 
@@ -3644,7 +3662,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3659,7 +3677,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list,)
         return result_tens
@@ -3674,7 +3693,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3689,7 +3708,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list,)
         return result_tens
@@ -3704,7 +3724,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3718,7 +3738,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.AxisAttribute(axis)
@@ -3744,7 +3765,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3761,7 +3782,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.PoolAttribute(kernel, stride, pad)
@@ -3788,7 +3810,7 @@
         num_operands = sum(op["operands"])
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3804,7 +3826,8 @@
             stride=strides,
             dilation=dilations,
             input_shape=ifm.shape,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.ConvAttribute(padding, strides, dilations)
@@ -3833,7 +3856,7 @@
         num_operands = sum(op["operands"])
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3849,7 +3872,8 @@
             stride=strides,
             dilation=dilations,
             input_shape=ifm.shape,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.ConvAttribute(padding, strides, dilations)
@@ -3878,7 +3902,7 @@
         num_operands = sum(op["operands"])
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3894,7 +3918,8 @@
             stride=stride,
             dilation=dilation,
             input_shape=ifm.shape,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
@@ -3924,7 +3949,7 @@
         num_operands = sum(op["operands"])
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3940,7 +3965,8 @@
             stride=strides,
             dilation=dilations,
             input_shape=ifm.shape,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.ConvAttribute(padding, strides, dilations)
@@ -3960,7 +3986,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -3975,7 +4001,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(
             op['op'], input_list, output_list, None, qinfo
@@ -3992,7 +4019,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4008,7 +4035,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
         return result_tens
@@ -4023,7 +4051,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4037,7 +4065,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.AxisAttribute(axis)
@@ -4067,7 +4096,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4082,7 +4111,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         if a.dtype == DType.FLOAT:
@@ -4119,7 +4149,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4132,7 +4162,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list)
         return result_tens
@@ -4147,7 +4178,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4160,7 +4191,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list)
         return result_tens
@@ -4186,7 +4218,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4201,7 +4233,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.AxisAttribute(axis)
@@ -4223,7 +4256,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4238,7 +4271,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(
             op['op'], input_list, output_list, attr, qinfo
@@ -4255,7 +4289,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4268,7 +4302,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.ReshapeAttribute(newShape)
@@ -4286,7 +4321,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4300,7 +4335,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.AxisAttribute(axis)
@@ -4321,7 +4357,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4335,7 +4371,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
 
         self.ser.addOperator(op['op'], input_list, output_list, attr)
@@ -4351,7 +4388,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4366,7 +4403,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.SliceAttribute(start, size)
@@ -4384,7 +4422,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4397,7 +4435,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.TileAttribute(multiples)
@@ -4428,7 +4467,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4441,7 +4480,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list)
 
@@ -4468,7 +4508,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4481,7 +4521,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list)
 
@@ -4527,7 +4568,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4546,7 +4587,8 @@
             output_list=output_list,
             result_tensor=result_tens,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
 
@@ -4580,7 +4622,7 @@
         num_operands = pCount + cCount
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4593,7 +4635,8 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         self.ser.addOperator(op['op'], input_list, output_list)
         return result_tens
@@ -4671,7 +4714,7 @@
         input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
 
         qinfo = (input_zp, output_zp)
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4686,7 +4729,8 @@
             output_list=output_list,
             result_tensor=result_tens,
             num_operands=num_operands,
-        )
+        ):
+            return None
 
         attr = ts.TosaSerializerAttribute()
         attr.RescaleAttribute(
@@ -4750,13 +4794,14 @@
             else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
         self.ser.addOutputTensor(else_tens)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
             op=op,
             basicBlocks=self.ser.basicBlocks
-        )
+        ):
+            return None
 
         return result_tens
 
@@ -4814,7 +4859,7 @@
                 tens = self.ser.addOutput(a.shape, a.dtype)
             self.ser.addOperator(op, [a.name, b.name], [tens.name])
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
@@ -4822,7 +4867,8 @@
             a=a,
             b=b,
             basicBlocks=self.ser.basicBlocks
-        )
+        ):
+            return None
 
         return result_tens
 
@@ -4917,13 +4963,14 @@
         self.ser.addOutputTensor(a)
         self.ser.addOutputTensor(acc_body_out)
 
-        TosaErrorValidator.evValidateErrorIfs(
+        if not TosaErrorValidator.evValidateErrorIfs(
             self.ser,
             validator_fcns,
             error_name,
             op=op,
             basicBlocks=self.ser.basicBlocks
-        )
+        ):
+            return None
 
         return acc_out
 
@@ -5156,11 +5203,12 @@
             print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
             raise e
 
-        if resultName is None:
-            print("Invalid ERROR_IF tests created")
-
-        # Save the serialized test
-        self.serialize("test")
+        if resultName:
+            # The test is valid, serialize it
+            self.serialize("test")
+        else:
+            # The test is not valid
+            print(f"Invalid ERROR_IF test created: {opName} {testStr}")
 
 
     def generate_tensors(self, op, dtypeList, shapeList, testArgs, error_name=None):