Add negative (ERROR_IF) testing support to fully_connected, matmul and argmax

Change-Id: I75f2a4ab6790dcbdfaec064f42f601d8f44da70b
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 6780aa7..1ec4a47 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -62,7 +62,7 @@
             return testGen.randInt(-128, 128)
         elif dtype == DType.UINT8:
             return testGen.randInt(0, 256)
-        elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
+        elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.WeightZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
             zero_point = testGen.randInt(-128, 128)
             if zero_point == 0:
                 zero_point = 1
@@ -95,17 +95,31 @@
         else:
             # an int, [input, weights, accumulator] dtypes are the same
             dtypeList = [dtype_or_dtypeList] * 3
-        input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
-        weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+
+        if error_name == ErrorIf.InputZeroPointNotZero:  # error_name only affects the input zero point
+            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
+            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+        elif error_name == ErrorIf.WeightZeroPointNotZero:  # error_name only affects the weight zero point
+            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
+            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
+        else:
+            input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
+            weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+
         qinfo.ConvQuantInfo(input_zp, weights_zp)
         return qinfo
 
     @staticmethod
     def qgMatmul(testGen, op, dtype, error_name=None):
         qinfo = ts.TosaSerializerQuantInfo()
-        qinfo.MatMulQuantInfo(
-            TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+        if error_name == ErrorIf.InputZeroPointNotZero:  # error_name is passed for both the A and B zero points
+            qinfo.MatMulQuantInfo(
+                TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype, error_name)
         )
+        else:
+            qinfo.MatMulQuantInfo(
+                TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+            )
         return qinfo
 
     @staticmethod
@@ -196,9 +210,9 @@
         # Constrict the batch size?
         if testGen.args.max_batch_size:
             shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
-        # Constrict dimension size for large ranks
-        if rank > 4:
-            shape[4] = 1
+
+        # Constrict dimension size for large ranks when creating WrongRank tests
+        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)  # NOTE(review): confirm the helper still clamps rank > 4 shapes for non-error tests
 
         shape_list = []
         for i in range(pl + const):
@@ -383,9 +397,14 @@
     def tgFullyConnected(testGen, op, rank, error_name=None):
         pl, const = op["operands"]
 
-        assert rank == 2
+        if error_name != ErrorIf.WrongRank:  # WrongRank tests deliberately use a rank other than 2
+            assert rank == 2
 
         input_shape = testGen.makeShape(rank)
+
+        # Constrict dimension size for large ranks when creating WrongRank tests (rebind input_shape so the result is used)
+        input_shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)
+
         filter_oc = testGen.rng.integers(
             low=testGen.args.tensor_shape_range[0],
             high=testGen.args.tensor_shape_range[1],
@@ -401,10 +420,15 @@
     def tgMatmul(testGen, op, rank, error_name=None):
         pl, const = op["operands"]
 
-        assert rank == 3
+        if error_name != ErrorIf.WrongRank:  # WrongRank tests deliberately use a rank other than 3
+            assert rank == 3
         assert pl == 2 and const == 0
 
         a_shape = testGen.makeShape(rank)
+
+        # Constrict dimension size for large ranks when creating WrongRank tests (rebind a_shape so the result is used)
+        a_shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)
+
         # Get a random number for b_oc even if target shape is defined
         b_oc = np.int32(
             testGen.rng.integers(
@@ -1312,13 +1336,15 @@
 
     @staticmethod
     def evWrongInputType(check=False, **kwargs):
-        all_dtypes = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
+        all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
 
         # Find the unsupported input data types
         assert 'op' in kwargs
         op = kwargs['op']
         input_dtypes = op['types']
-        wrong_input_dtypes = list(set(all_dtypes) - set(input_dtypes))
+        # 'types' entries may be [input, weight, accumulator] lists; only the input dtype (element 0) matters here
+        allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
+        wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)
 
         error_name = ErrorIf.WrongInputType
         param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
@@ -1327,7 +1353,10 @@
 
         if check:
             input_dtype = kwargs['input_dtype']
-            if input_dtype not in input_dtypes:
+            # Check against the flattened input dtypes for every op, so list-typed ops
+            # (e.g. FULLY_CONNECTED) are handled like plain-typed ops; for plain dtypes
+            # this is the same as testing membership in op['types'].
+            if input_dtype not in allowed_input_dtypes:
                 error_result = True
 
         info_dict = {
@@ -1373,6 +1402,16 @@
                 elif input_dtype == DType.UINT8:
                     if output_dtype != DType.INT8:
                         error_result = True
+            elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:  # output must be the accumulator dtype for the input dtype
+                if (
+                    (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
+                    (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
+                    (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+                ):
+                    error_result = True
+            elif op['op'] == Op.ARGMAX:  # the only valid ARGMAX output dtype is INT32
+                if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
+                    error_result = True
             else:
                 if output_dtype != input_dtype:
                     error_result = True
@@ -1408,8 +1447,13 @@
 
         if check:
             input_shape = kwargs['input_shape']
+            # Ops with a single valid input rank: FULLY_CONNECTED needs 2, MATMUL needs 3
             if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
                 error_result = True
+            elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2:
+                error_result = True
+            elif op['op'] == Op.MATMUL and len(input_shape) != 3:
+                error_result = True
             else:
                 if len(input_shape) not in rank_range:
                     error_result = True
@@ -1778,6 +1822,10 @@
     def evInputZeroPointNotZero(check=False, **kwargs):
         op = kwargs['op']
         inputDtypes = op['types'].copy()
+        # 'types' entries may be [input, weight, accumulator] lists: drop those whose input
+        # dtype is INT8/UINT8 (plain INT8/UINT8 entries are removed just below)
+        inputDtypes = [t for t in inputDtypes if not (isinstance(t, list) and t[0] in (DType.INT8, DType.UINT8))]
+
         if DType.INT8 in inputDtypes:
             inputDtypes.remove(DType.INT8)
         if DType.UINT8 in inputDtypes:
@@ -1802,7 +1850,50 @@
                 qinfo = kwargs['qinfo'].ints
                 input_zero_point = qinfo[0][1]
 
-            if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
+            if op['op'] == Op.MATMUL:  # MATMUL has two data inputs (A and B), each with its own zero point
+                input1_dtype = kwargs['input_dtype']
+                input2_dtype = kwargs['input2_dtype']
+                qinfo = kwargs['qinfo'].ints
+                input1_zero_point = qinfo[0][1]
+                input2_zero_point = qinfo[1][1]
+                if (input1_dtype != DType.INT8 and input1_zero_point != 0) or (input2_dtype != DType.INT8 and input2_zero_point != 0):
+                    error_result = True
+            else:
+                if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
+                    error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+
+    @staticmethod
+    def evWeightZeroPointNotZero(check=False, **kwargs):
+        op = kwargs['op']
+
+        # Exclude [input, weight, accumulator] type lists whose weight dtype (element 1) is INT8; plain dtypes are kept
+        inputDtypes = [t for t in op['types']
+                       if not isinstance(t, list) or t[1] != DType.INT8]
+
+        error_name = ErrorIf.WeightZeroPointNotZero
+        param_reqs = {
+            "rank": None,
+            "dtype": inputDtypes,
+            "shape": None
+            }
+        error_result = False
+        error_reason = "Weight DType not INT8 and zero point not 0"
+
+        if check:
+            weight_dtype = kwargs['weight_dtype']
+            # qinfo.ints[0][1] is the input zero point, qinfo.ints[1][1] is the weight zero point
+            qinfo = kwargs['qinfo'].ints
+            weight_zero_point = qinfo[1][1]
+            if weight_dtype != DType.INT8 and weight_zero_point != 0:
                 error_result = True
 
         info_dict = {
@@ -2007,6 +2098,65 @@
         }
         return info_dict
 
+    @staticmethod
+    def evArgmaxOutputShapeMismatch(check=False, **kwargs):
+        error_name = ErrorIf.ArgmaxOutputShapeMismatch
+        param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Mismatch between output shape provided and expected output shape"
+
+        if check:
+            output_shape = kwargs['output_shape']
+            input_shape = kwargs['input_shape']
+            axis = kwargs['axis']
+
+            dimension_match = True
+            axis_shift = 0
+
+            # Check that rank is correct before trying to check dimensions
+            if (len(input_shape) - 1) == len(output_shape):  # rank mismatches are reported by evArgmaxOutputRankMismatch
+                for i in range(len(input_shape)):
+                    if i == axis:
+                        axis_shift = 1
+                        continue
+                    if input_shape[i] != output_shape[i - axis_shift]:
+                        dimension_match = False
+
+                if not dimension_match:
+                    error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+    @staticmethod
+    def evArgmaxOutputRankMismatch(check=False, **kwargs):
+        error_name = ErrorIf.ArgmaxOutputRankMismatch
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Mismatch between output rank provided and expected output rank"
+
+        if check:
+            output_shape = kwargs['output_shape']
+            input_shape = kwargs['input_shape']
+            axis = kwargs['axis']
+            valid_params = axis >= 0 and axis < len(input_shape)  # out-of-range axes are left to the axis validators
+
+            if valid_params and (len(input_shape) - 1) != len(output_shape):
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
 
     @staticmethod
     def evKernelSmallerOne(check=False, **kwargs):
@@ -2525,13 +2675,36 @@
         self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
         return result_tens
 
-    def build_argmax(self, op, a, axis):
-        result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
+    def build_argmax(self, op, a, axis, validator_fcns, error_name):
+        result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)  # error_name lets argmaxOp build a deliberately invalid output
+
+        # Invalidate Input/Output list for error if checks.
+        input_list = [a.name]
+        output_list = [result_tens.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+        TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            axis=axis,
+            input_shape = a.shape,
+            input_dtype = a.dtype,
+            output_shape = result_tens.shape,
+            output_dtype = result_tens.dtype,
+            result_tensor = result_tens,
+            input_list=input_list,
+            output_list=output_list,
+            num_operands=num_operands,
+        )
 
         attr = ts.TosaSerializerAttribute()
         attr.AxisAttribute(axis)
 
-        self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
+        self.ser.addOperator(op['op'], input_list, output_list, attr)  # use the possibly invalidated lists
         return result_tens
 
     def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
@@ -2634,17 +2807,67 @@
         )
         return result_tens
 
-    def build_fully_connected(self, op, ifm, filter, bias, qinfo):
-        result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
+    def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
+        result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)  # error_name may select a wrong output dtype
+
+        # Invalidate Input/Output list for error if checks.
+        input_list = [ifm.name, filter.name, bias.name]
+        output_list = [result_tens.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+        TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            input_shape=ifm.shape,
+            input_dtype=ifm.dtype,
+            weight_dtype=filter.dtype,  # read by evWeightZeroPointNotZero
+            output_shape=result_tens.shape,
+            output_dtype=result_tens.dtype,
+            qinfo = qinfo,
+            result_tensor = result_tens,
+            input_list=input_list,
+            output_list=output_list,
+            num_operands=num_operands,
+        )
 
         self.ser.addOperator(
-            op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
+            op['op'], input_list, output_list, None, qinfo
         )
         return result_tens
 
-    def build_matmul(self, op, a, b, qinfo):
-        result_tens = OutputShaper.matmulOp(self.ser, a, b)
-        self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
+    def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
+        result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)  # error_name may select a wrong output dtype
+
+        # Invalidate Input/Output list for error if checks.
+        input_list = [a.name, b.name]
+        output_list = [result_tens.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+        TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            input_shape=a.shape,
+            input_dtype=a.dtype,
+            input2_shape=b.shape,
+            input2_dtype=b.dtype,  # read by the MATMUL branch of evInputZeroPointNotZero
+            output_shape=result_tens.shape,
+            output_dtype=result_tens.dtype,
+            qinfo = qinfo,
+            result_tensor = result_tens,
+            input_list=input_list,
+            output_list=output_list,
+            num_operands=num_operands,
+        )
+
+        self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
         return result_tens
 
     def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
@@ -3246,7 +3469,6 @@
         for validator in error_if_validators:
             if validator is not None:
                 error_name = validator(check=False, op=op)['error_name']
-                #print("error_name: ", error_name)
             else:
                 error_name = None
 
@@ -3713,8 +3935,12 @@
         "argmax": {
             "op": Op.ARGMAX,
             "operands": (1, 0),
+            "rank": (1, 4),  # input rank range used when generating ARGMAX tests
             "build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
             "types": TYPE_NARROW_INT_FP,
+            "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
+            TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
+            TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)  # axis, output rank/shape, rank, dtype and operand-list checks
         },
         "avg_pool2d": {
             "op": Op.AVG_POOL2D,
@@ -3773,6 +3999,8 @@
             "build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
             "qgen": TosaQuantGen.qgConv,
             "types": TYPE_CONV,
+            "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
+            TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)  # zero-point, rank, dtype and operand-list checks
         },
         "matmul": {
             "op": Op.MATMUL,
@@ -3781,6 +4009,8 @@
             "build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
             "qgen": TosaQuantGen.qgMatmul,
             "types": TYPE_NARROW_INT_FP,
+            "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
+            TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)  # B's zero point is covered by evInputZeroPointNotZero's MATMUL branch
         },
         "max_pool2d": {
             "op": Op.MAX_POOL2D,
@@ -4386,10 +4616,30 @@
         return ser.addOutput(shape, outputDType)
 
     @staticmethod
-    def argmaxOp(ser, a, axis):
+    def argmaxOp(ser, rng, a, axis, error_name=None):
         shape = a.shape.copy()
-        del shape[axis]
-        return ser.addOutput(shape, DType.INT32)
+
+        if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:  # axis is invalid in these tests, so it is not deleted
+            del shape[axis]
+
+        if error_name == ErrorIf.ArgmaxOutputRankMismatch:  # add or drop a dimension so the output rank is wrong
+            remove = rng.choice([True, False])
+            if remove and len(shape) > 1:
+                del shape[0]
+            else:
+                shape.append(1)
+        elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:  # grow every dimension so the output shape is wrong
+            for i in range(len(shape)):
+                shape[i] = shape[i] + rng.integers(1, 10)
+
+        if error_name == ErrorIf.WrongOutputType:
+            all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+            wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
+            outputDType = rng.choice(wrong_dtypes)
+        else:
+            outputDType = DType.INT32
+
+        return ser.addOutput(shape, outputDType)
 
     @staticmethod
     def conv2dOp(ser, ifm, filter, strides, padding, dilations):
@@ -4514,7 +4764,7 @@
     def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
         # input: NHWC
         if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
-            # If an incorrect stride is used set dimensions to 0, test is invalid anyway.
+            # If an invalid stride or a negative padding is used, set dimensions to 1; the test is invalid anyway.
             h = 1
             w = 1
         else:
@@ -4538,40 +4788,62 @@
         return ser.addOutput(ofm_shape, outputDType)
 
     @staticmethod
-    def fullyConnectedOp(ser, input, filter):
+    def fullyConnectedOp(ser, rng, input, filter, error_name=None):
         # input: N, IC
         # filter: OC, IC
         # output: N, OC
 
         output_shape = [input.shape[0], filter.shape[0]]
 
-        if input.dtype == DType.INT8:
+        if error_name == ErrorIf.WrongOutputType:  # choose any dtype except the valid one for input.dtype
+            if input.dtype == DType.INT8:
+                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
+            elif input.dtype == DType.INT16:
+                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
+            elif input.dtype == DType.FLOAT:
+                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
+            out_dtype = rng.choice(a=incorrect_types)  # NOTE(review): incorrect_types is unbound if input.dtype is not INT8/INT16/FLOAT
+        elif input.dtype == DType.INT8:
             out_dtype = DType.INT32
         elif input.dtype == DType.INT16:
             out_dtype = DType.INT48
         elif input.dtype == DType.FLOAT:
             out_dtype = DType.FLOAT
+        elif error_name == ErrorIf.WrongInputType:
+            # Pick some potentially correct output dtype if input type is incorrect
+            out_dtype = DType.INT32
         else:
             raise Exception("Unsupported input dtype: {}".format(input.dtype))
 
         return ser.addOutput(output_shape, out_dtype)
 
     @staticmethod
-    def matmulOp(ser, a, b):
+    def matmulOp(ser, rng, a, b, error_name=None):
         # a: N, H, C
         # b: N, C, W
         # out: N, H, W
 
         output_shape = [a.shape[0], a.shape[1], b.shape[2]]
 
-        if a.dtype == DType.INT8:
+        if error_name == ErrorIf.WrongOutputType:  # choose any dtype except the valid one for a.dtype
+            if a.dtype == DType.INT8:
+                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
+            elif a.dtype == DType.INT16:
+                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
+            elif a.dtype == DType.FLOAT:
+                incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
+            out_dtype = rng.choice(a=incorrect_types)  # NOTE(review): incorrect_types is unbound if a.dtype is not INT8/INT16/FLOAT
+        elif a.dtype == DType.INT8:
             out_dtype = DType.INT32
         elif a.dtype == DType.INT16:
             out_dtype = DType.INT48
         elif a.dtype == DType.FLOAT:
             out_dtype = DType.FLOAT
+        elif error_name == ErrorIf.WrongInputType:
+            # Pick some potentially correct output dtype if input type is incorrect
+            out_dtype = DType.INT32
         else:
-            raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
+            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
 
         return ser.addOutput(output_shape, out_dtype)