Add support for uint16_t to RESCALE

Update ref-model RESCALE op to support UINT16 conversions
Add testing for RESCALE UINT16 and ERROR_IFs

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ic6e6e53de1f0b054bedb9e6ba3856e7475498aba
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e7e758f..1900d8a 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -68,6 +68,8 @@
     InputListBodyGraphInputMismatch = "InputListBodyGraphInputMismatch"
     InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
     CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
+    U16InputZeroPointNotValid = "U16InputZeroPointNotValid"
+    U16OutputZeroPointNotValid = "U16OutputZeroPointNotValid"
 
 
 class TosaErrorIfArgGen:
@@ -227,14 +229,26 @@
         if input_dtype == DType.INT8:
             if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
                 return True
-        if input_dtype in [DType.INT16, DType.INT32]:
+        elif input_dtype == DType.INT16:
+            if output_dtype not in [
+                DType.UINT8,
+                DType.INT8,
+                DType.UINT16,
+                DType.INT16,
+                DType.INT32,
+            ]:
+                return True
+        elif input_dtype == DType.INT32:
             if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                 return True
         elif input_dtype == DType.INT48:
             if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
                 return True
         elif input_dtype == DType.UINT8:
-            if output_dtype != DType.INT8:
+            if output_dtype not in [DType.INT8, DType.INT16]:
+                return True
+        elif input_dtype == DType.UINT16:
+            if output_dtype != DType.INT16:
                 return True
         return False
 
@@ -418,23 +432,9 @@
                     error_result = True
 
             elif op["op"] == Op.RESCALE:
-                if input_dtype == DType.INT8:
-                    if output_dtype not in [
-                        DType.UINT8,
-                        DType.INT8,
-                        DType.INT16,
-                        DType.INT32,
-                    ]:
-                        error_result = True
-                if input_dtype in [DType.INT16, DType.INT32]:
-                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
-                        error_result = True
-                elif input_dtype == DType.INT48:
-                    if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
-                        error_result = True
-                elif input_dtype == DType.UINT8:
-                    if output_dtype != DType.INT8:
-                        error_result = True
+                error_result = TosaErrorIfArgGen.eiRescaleWrongOutputType(
+                    input_dtype, output_dtype
+                )
 
             elif op["op"] in [Op.FULLY_CONNECTED, Op.MATMUL]:
                 if (
@@ -998,12 +998,25 @@
         return info_dict
 
     @staticmethod
+    def _getZeroPoint(qinfo, index):
+        """Look up a single zero point in the quantization info.
+
+        Index 0 is generally the input zero point, index 1 the output one.
+        """
+        # Tuple qinfo stores the zero points directly
+        if isinstance(qinfo, tuple):
+            return qinfo[index]
+        # Otherwise qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+        entry = qinfo.ints[index]
+        return entry[1]
+
+    @staticmethod
     def evInputZeroPointNotZero(check=False, **kwargs):
         op = kwargs["op"]
         error_result = False
 
         # Quantizable types
-        qTypes = (DType.INT8, DType.UINT8)
+        qTypes = (DType.INT8, DType.UINT8, DType.UINT16)
 
         # This does not apply to quantizable types
         inputDtypes = [
@@ -1015,19 +1028,12 @@
 
         if check:
             input_dtype = kwargs["input_dtype"]
-            if isinstance(kwargs["qinfo"], tuple):
-                qinfo = kwargs["qinfo"]
-                input_zero_point = qinfo[0]
-            else:
-                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
-                qinfo = kwargs["qinfo"].ints
-                input_zero_point = qinfo[0][1]
-
+            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
             if op["op"] == Op.MATMUL:
-                qinfo = kwargs["qinfo"].ints
+                input2_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
                 for dtype, zp in (
-                    (kwargs["input_dtype"], qinfo[0][1]),
-                    (kwargs["input2_dtype"], qinfo[1][1]),
+                    (kwargs["input_dtype"], input_zero_point),
+                    (kwargs["input2_dtype"], input2_zero_point),
                 ):
                     if dtype not in qTypes and zp != 0:
                         error_result = True
@@ -1059,9 +1065,7 @@
 
         if check:
             weight_dtype = kwargs["weight_dtype"]
-            # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = weight_zp
-            qinfo = kwargs["qinfo"].ints
-            weight_zero_point = qinfo[1][1]
+            weight_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
             if weight_dtype != DType.INT8 and weight_zero_point != 0:
                 error_result = True
 
@@ -1076,11 +1080,9 @@
     @staticmethod
     def evOutputZeroPointNotZero(check=False, **kwargs):
         op = kwargs["op"]
-        inputDtypes = op["types"].copy()
-        if DType.INT8 in inputDtypes:
-            inputDtypes.remove(DType.INT8)
-        if DType.UINT8 in inputDtypes:
-            inputDtypes.remove(DType.UINT8)
+        inputDtypes = [
+            t for t in op["types"] if t not in [DType.INT8, DType.UINT8, DType.UINT16]
+        ]
 
         error_name = ErrorIf.OutputZeroPointNotZero
         param_reqs = {"rank": None, "dtype": inputDtypes, "shape": None}
@@ -1090,18 +1092,13 @@
         if check:
             input_dtype = kwargs["input_dtype"]
             output_dtype = kwargs["output_dtype"]
-            if isinstance(kwargs["qinfo"], tuple):
-                qinfo = kwargs["qinfo"]
-                output_zero_point = qinfo[1]
-            else:
-                # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
-                qinfo = kwargs["qinfo"].ints
-                output_zero_point = qinfo[1][1]
+            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
             if op["op"] == Op.AVG_POOL2D:
                 if input_dtype != DType.INT8 and output_zero_point != 0:
                     error_result = True
             elif (
-                output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0
+                output_dtype not in [DType.INT8, DType.UINT8, DType.UINT16]
+                and output_zero_point != 0
             ):
                 error_result = True
 
@@ -1114,6 +1111,53 @@
         return info_dict
 
     @staticmethod
+    def evU16InputZeroPointNotValid(check=False, **kwargs):
+        """Flag a UINT16 input whose zero point is neither 0 nor 32768."""
+        error_name = ErrorIf.U16InputZeroPointNotValid
+        param_reqs = {"rank": None, "dtype": [DType.UINT16], "shape": None}
+        error_result = False
+        error_reason = "Input DType is UINT16 and zero point not 0 or 32768"
+
+        if check:
+            input_dtype = kwargs["input_dtype"]
+            input_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 0)
+            # 0 and 32768 are the only zero points accepted for a UINT16 input
+            is_u16 = input_dtype == DType.UINT16
+            error_result = is_u16 and input_zero_point not in (0, 32768)
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
+    @staticmethod
+    def evU16OutputZeroPointNotValid(check=False, **kwargs):
+        """Flag a UINT16 output whose zero point is neither 0 nor 32768."""
+        error_name = ErrorIf.U16OutputZeroPointNotValid
+        param_reqs = {"rank": None, "dtype": None, "shape": None}
+        error_result = False
+        error_reason = "Output DType is UINT16 and zero point not 0 or 32768"
+
+        if check:
+            output_dtype = kwargs["output_dtype"]
+            output_zero_point = TosaErrorValidator._getZeroPoint(kwargs["qinfo"], 1)
+
+            # 0 and 32768 are the only zero points accepted for a UINT16 output
+            is_u16 = output_dtype == DType.UINT16
+            error_result = is_u16 and output_zero_point not in (0, 32768)
+
+        info_dict = {
+            "error_name": error_name,
+            "error_result": error_result,
+            "error_reason": error_reason,
+            "param_reqs": param_reqs,
+        }
+        return info_dict
+
+    @staticmethod
     def evAxisSmallerZero(check=False, **kwargs):
         error_name = ErrorIf.AxisSmallerZero
         param_reqs = {"rank": None, "dtype": None, "shape": None}