Add negative testing support to RESCALE

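 * Add ERROR_IF negative tests for the RESCALE op
 * Check WrongRank against each op's declared rank range, not only for
   RESIZE and the 2D pooling ops
 * Accept a plain (input_zp, output_zp) tuple in the zero-point validators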

Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
Change-Id: I70aead1c6a67f159c7b7c9a05f7d5f0b92521584
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index a03c66f..6780aa7 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -169,6 +169,9 @@
         pl, const = opName["operands"]
         shape = testGen.makeShape(rank)
 
+        # Keep large-rank WrongRank test shapes small (no-op for other error types)
+        shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)
+
         shape_list = []
         for i in range(pl + const):
             shape_list.append(shape.copy())
@@ -754,21 +757,31 @@
 
         # Enumerate the output types here
         for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
-            if inDtype == DType.UINT8 and dtype != DType.INT8:
+            if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
+                continue  # a non-zero output zp is only an error for non-8-bit outputs
+            if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
                 # The only output dtype for UINT8 is INT8, skip all other combinations
                 continue
-            if inDtype != DType.INT8 and dtype == DType.UINT8:
+            if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
                 # The only input dtype for UINT8 is INT8, skip all other combinations
                 continue
+            if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
+                continue  # WrongOutputType tests only keep illegal input/output dtype pairs
 
             for scale32 in [False, True]:
+                if error_name == ErrorIf.ScaleTrue and not scale32:
+                    continue  # ScaleTrue tests need scale32=True
+                elif error_name == ErrorIf.ScaleNotTrue and scale32:
+                    continue  # ScaleNotTrue tests need scale32=False ...
                 for double_round in [False, True]:
+                    if error_name == ErrorIf.ScaleNotTrue and not double_round:
+                        continue  # ... together with double_round=True
                     for per_channel in [False, True]:

-                        if inDtype == DType.INT48 and scale32:
+                        if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
                             # Illegal condition.  Must be scale32=False
                             continue
-                        if double_round and not scale32:
+                        if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
                             # Illegal condition.  ERROR_IF(!scale32 && double_round)
                             continue
 
@@ -1229,6 +1242,22 @@
         else:
             return None, None, None
 
+    @staticmethod
+    def eiRescaleWrongOutputType(input_dtype, output_dtype):
+        """Return True if RESCALE does not accept output_dtype for input_dtype.
+
+        Input dtypes without an entry below are never reported as wrong.
+        """
+        if input_dtype == DType.INT8:
+            legal_outputs = [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]
+        elif input_dtype in [DType.INT16, DType.INT32, DType.INT48]:
+            legal_outputs = [DType.INT8, DType.INT16, DType.INT32]
+        elif input_dtype == DType.UINT8:
+            legal_outputs = [DType.INT8]
+        else:
+            return False
+        return output_dtype not in legal_outputs
+

     @staticmethod
     def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
@@ -1247,6 +1276,16 @@
                 output_list = []
         return input_list, output_list
 
+    @staticmethod
+    def eiRestrictDimension(shape, error_name):
+        # WrongRank tests with more than four dimensions get dimension 4 pinned
+        # to size 1 so the generated tensors stay small; the rank is unchanged.
+        # The shape is modified in place and returned; other errors pass through.
+        if error_name == ErrorIf.WrongRank and len(shape) > 4:
+            shape[4] = 1
+        return shape
+
+
 class TosaErrorValidator:
 
     @staticmethod
@@ -1321,6 +1360,11 @@
                     (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                 ):
                     error_result = True
+            elif op['op'] == Op.RESCALE:
+                # Share the legal RESCALE dtype pairs with the test generator so the
+                # generator and this validator cannot drift apart.
+                if TosaErrorIfArgGen.eiRescaleWrongOutputType(input_dtype, output_dtype):
+                    error_result = True
             else:
                 if output_dtype != input_dtype:
                     error_result = True
@@ -1343,11 +1395,11 @@
         rmin, rmax = op['rank']
         rank_range = range(rmin, rmax + 1)
         incorrect_ranks = list(set(all_ranks) - set(rank_range))
+        # Drop incorrect ranks below rmin (only ranks above rmax remain) to avoid index errors
+        incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
         # Set minimum incorrect rank to 3 to avoid index error
         if op['op'] in [Op.RESIZE]:
             incorrect_ranks = [3, 5]
-        elif op['op'] in [Op.AVG_POOL2D, Op.MAX_POOL2D]:
-            incorrect_ranks = [5]
 
         error_name = ErrorIf.WrongRank
         param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
@@ -1358,6 +1410,8 @@
             input_shape = kwargs['input_shape']
             if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
                 error_result = True
+            elif len(input_shape) not in rank_range:  # all other cases: use the op's declared rank range
+                error_result = True

         info_dict = {
             "error_name": error_name,
@@ -1739,9 +1794,14 @@
 
         if check:
             input_dtype = kwargs['input_dtype']
-            # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
-            qinfo = kwargs['qinfo'].ints
-            input_zero_point = qinfo[0][1]
+            if isinstance(kwargs['qinfo'], tuple):  # plain (input_zp, output_zp) tuple, e.g. from build_rescale
+                qinfo = kwargs['qinfo']
+                input_zero_point = qinfo[0]
+            else:
+                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+                qinfo = kwargs['qinfo'].ints
+                input_zero_point = qinfo[0][1]
+
             if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
                 error_result = True
 
@@ -1774,10 +1834,18 @@
 
         if check:
             input_dtype = kwargs['input_dtype']
-            # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
-            qinfo = kwargs['qinfo'].ints
-            output_zero_point = qinfo[1][1]
-            if input_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
+            output_dtype = kwargs.get('output_dtype', input_dtype)  # callers that omit it keep the old input-dtype check
+            if isinstance(kwargs['qinfo'], tuple):  # plain (input_zp, output_zp) tuple, e.g. from build_rescale
+                qinfo = kwargs['qinfo']
+                output_zero_point = qinfo[1]
+            else:
+                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+                qinfo = kwargs['qinfo'].ints
+                output_zero_point = qinfo[1][1]
+            if op['op'] == Op.AVG_POOL2D:
+                if input_dtype != DType.INT8 and output_zero_point != 0:
+                    error_result = True
+            elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
                 error_result = True

         info_dict = {
@@ -1980,6 +2048,48 @@
         }
         return info_dict
 
+    @staticmethod
+    def evScaleTrue(check=False, **kwargs):  # RESCALE: ERROR_IF(scale32 && input is INT48)
+        error_name = ErrorIf.ScaleTrue
+        param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}  # only INT48 inputs can trigger it
+        error_result = False
+        error_reason = "Scale set to true but input type is INT48"
+
+        if check:  # otherwise only the metadata (name, reason, param_reqs) is returned
+            input_dtype = kwargs['input_dtype']
+            scale32 = kwargs['scale32']
+            if scale32 and input_dtype == DType.INT48:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+
+    @staticmethod
+    def evScaleNotTrue(check=False, **kwargs):  # RESCALE: ERROR_IF(!scale32 && double_round)
+        error_name = ErrorIf.ScaleNotTrue
+        param_reqs = {"rank": None, "dtype": None, "shape": None}  # no rank/dtype/shape restriction
+        error_result = False
+        error_reason = "Scale set to false but double round set to true"
+
+        if check:  # otherwise only the metadata (name, reason, param_reqs) is returned
+            scale32 = kwargs['scale32']
+            double_round = kwargs['double_round']
+            if not scale32 and double_round:
+                error_result = True
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs
+        }
+        return info_dict
+


class TosaInvalidValidator:
@@ -2276,6 +2386,10 @@
             return 32
         elif t == DType.INT48:
             return 48
+        elif t == DType.FLOAT:  # NOTE(review): presumably reached by ERROR_IF type tests with FLOAT tensors — confirm
+            return 32
+        elif t == DType.BOOL:  # BOOL is reported as 1 bit wide
+            return 1
         else:
             raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
 
@@ -2809,7 +2923,7 @@
         self.ser.addOperator(op['op'], [val.name], [result_tens.name])
         return result_tens
 
-    def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
+    def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):  # last two drive ERROR_IF tests
         result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
 
         if per_channel:
@@ -2826,6 +2940,11 @@
         elif val.dtype == DType.UINT8:
             input_zp = self.randInt(0, 256)
             in_type_width = in_type_width + 1
+        elif error_name == ErrorIf.InputZeroPointNotZero:  # ERROR_IF test: non-zero zp where the else branch would use 0
+            input_zp = self.randInt(-128, 128)
+            if input_zp == 0:  # nudge a zero draw so the zero point is guaranteed non-zero
+                input_zp = input_zp + self.rng.integers(1, 10)
+            in_type_width = in_type_width + 1  # widened like the UINT8 branch above
         else:
             input_zp = 0
 
@@ -2835,6 +2954,11 @@
         elif out_dtype == DType.UINT8:
             output_zp = self.randInt(0, 256)
             out_type_width = out_type_width + 1
+        elif error_name == ErrorIf.OutputZeroPointNotZero:  # ERROR_IF test: non-zero zp where the else branch would use 0
+            output_zp = self.randInt(-128, 128)
+            if output_zp == 0:  # nudge a zero draw so the zero point is guaranteed non-zero
+                output_zp = output_zp + self.rng.integers(1, 10)
+            out_type_width = out_type_width + 1  # widened like the UINT8 branch above
         else:
             output_zp = 0
 
@@ -2864,6 +2988,31 @@
 
         # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
 
+        # ERROR_IF tests may replace the operand lists (see eiInvalidateInputOutputList).
+        input_list = [val.name]
+        output_list = [result_tens.name]
+        pCount, cCount = op["operands"]
+        num_operands = pCount + cCount
+        input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+        qinfo = (input_zp, output_zp)  # the zero-point validators accept this plain tuple
+        TosaErrorValidator.evValidateErrorIfs(
+            self.ser,
+            validator_fcns,
+            error_name,
+            op=op,
+            input_dtype=val.dtype,
+            output_dtype=out_dtype,
+            input_shape=val.shape,
+            qinfo=qinfo,
+            scale32=scale32,
+            double_round=double_round,
+            input_list=input_list,
+            output_list=output_list,
+            result_tensor=result_tens,
+            num_operands=num_operands,
+        )
+
         attr = ts.TosaSerializerAttribute()
         attr.RescaleAttribute(
             input_zp,
@@ -2875,7 +3024,7 @@
             per_channel,
         )
 
-        self.ser.addOperator(op['op'], [val.name], [result_tens.name], attr)
+        self.ser.addOperator(op['op'], input_list, output_list, attr)  # lists may have been invalidated for ERROR_IF tests
         return result_tens
 
     def build_cond_if_const(self, op, then_tens, else_tens, cond):
@@ -4092,8 +4241,20 @@
         "rescale": {
             "op": Op.RESCALE,
             "operands": (1, 0),
+            "rank": (1, 4),
             "build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
             "types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
+            "error_if_validators": (
+                TosaErrorValidator.evInputZeroPointNotZero,
+                TosaErrorValidator.evOutputZeroPointNotZero,
+                TosaErrorValidator.evScaleTrue,
+                TosaErrorValidator.evScaleNotTrue,
+                TosaErrorValidator.evWrongInputType,
+                TosaErrorValidator.evWrongOutputType,
+                TosaErrorValidator.evWrongRank,
+                TosaErrorValidator.evWrongInputList,
+                TosaErrorValidator.evWrongOutputList,
+            ),
         },
         # Custom
         # Not implemented.