Reference model changes for fp16 support

Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a65e220..69968d3 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2,10 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0
 import itertools
 import math
+import warnings
 
 import numpy as np
 from generator.tosa_error_if import ErrorIf
 from generator.tosa_error_if import TosaErrorIfArgGen
+from generator.tosa_utils import get_accum_dtype_from_tgTypes
+from generator.tosa_utils import get_wrong_output_type
 from generator.tosa_utils import MAX_RESIZE_DIMENSION
 from serializer.tosa_serializer import DTypeNames
 from tosa.DType import DType
@@ -773,7 +776,7 @@
             ), "Op.MUL must have 2 placeholders, 0 consts"
 
             tens = []
-            if dtypeList[0] == DType.FLOAT:
+            if dtypeList[0] in (DType.FP16, DType.FLOAT):
                 tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
             else:
                 placeholders = []
@@ -982,7 +985,7 @@
         return axes
 
     @staticmethod
-    def agConv(testGen, opName, shapeList, dtype, error_name=None):
+    def agConv(testGen, opName, shapeList, dtypes, error_name=None):
         arg_list = []
 
         ifm_shape = shapeList[0]
@@ -990,6 +993,8 @@
         # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
         k = [int(x) for x in opName.split("_")[-1].split("x")]
 
+        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+
         # Check the rank
         rank = 5 if opName.startswith("conv3d") else 4
         if error_name != ErrorIf.WrongRank:
@@ -1089,12 +1094,13 @@
                         ):
                             arg_list.append(
                                 (
-                                    "st{}_pad{}_dilat{}".format(
+                                    "acc{}_st{}_pad{}_dilat{}".format(
+                                        testGen.typeStr(accum_dtype),
                                         "".join([str(x) for x in s]),
                                         "".join([str(x) for x in p]),
                                         "".join([str(x) for x in d]),
                                     ),
-                                    [s, p, d],
+                                    [accum_dtype, s, p, d],
                                 )
                             )
                     n += 1
@@ -1102,12 +1108,55 @@
         return arg_list
 
     @staticmethod
-    def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
+    def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
+        """Return a one-entry argument list carrying the accumulator dtype."""
+        if isinstance(dtypes, list) or isinstance(dtypes, tuple):
+            input_dtype = dtypes[0]  # a sequence's first entry is the input dtype
+        else:
+            input_dtype = dtypes
+
+        if error_name == ErrorIf.WrongOutputType:
+            accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
+        elif error_name == ErrorIf.WrongInputType:
+            # Input dtype is invalid in this test, so use a plausible fixed accumulator
+            accum_dtype = DType.INT32
+        else:
+            accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+        # The single entry is (label string, [argument values])
+        return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]
+
+    @staticmethod
+    def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
+        """Return one (f"acc{typeStr}", [accum_dtype]) entry per accumulator dtype."""
+        if dtype == DType.INT8:  # valid accumulate type(s) for each input dtype
+            accum_dtypes = [DType.INT32]
+        elif dtype == DType.INT16:
+            accum_dtypes = [DType.INT48]
+        elif dtype == DType.FP16:
+            accum_dtypes = [DType.FP16, DType.FLOAT]
+        elif dtype == DType.FLOAT:
+            accum_dtypes = [DType.FLOAT]
+        elif error_name is None:  # only a non-ErrorIf run rejects an unknown dtype
+            assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
+        # NOTE(review): an unmatched dtype leaves accum_dtypes unset unless set below
+        if error_name == ErrorIf.WrongOutputType:
+            # Get incorrect output dtype for ErrorIf case
+            accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
+        elif error_name == ErrorIf.WrongInputType:
+            # Pick some potentially correct output dtype if input type is incorrect
+            accum_dtypes = [DType.INT32]
+
+        return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]
+
+    @staticmethod
+    def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
         arg_list = []
 
         ifm_shape = shapeList[0]
         filter_shape = shapeList[1]
 
+        accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+
         # Must be rank 4
         if error_name != ErrorIf.WrongRank:
             assert len(ifm_shape) == 4
@@ -1169,12 +1218,13 @@
                     os = [ifm_shape[0], oh, ow, filter_shape[0]]
                     arg_list.append(
                         (
-                            "st{}_pad{}_os{}".format(
+                            "acc{}_st{}_pad{}_os{}".format(
+                                testGen.typeStr(accum_dtype),
                                 "".join([str(x) for x in s]),
                                 "".join([str(x) for x in p]),
                                 "x".join([str(x) for x in os]),
                             ),
-                            [s, p, os],
+                            [accum_dtype, s, p, os],
                         )
                     )
                 n += 1
@@ -1199,18 +1249,38 @@
         if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
             pad_const_int = testGen.getRandNumberDType(dtype)
             pad_const_fp = 0
-        elif dtype == DType.FLOAT:
+        elif dtype in (DType.FP16, DType.FLOAT):
             pad_const_int = 0
             pad_const_fp = testGen.getRandNumberDType(dtype)
         else:
             return []
 
         for paddings in shape_pad_values:
-            name = "pad"
-            for r in range(rank):
-                before, after = paddings[r]
-                name = f"{name}{before}{after}"
-            arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
+            paddings = list(paddings)
+            args_valid = True
+
+            if error_name == ErrorIf.PadSmallerZero:
+                # Prevent negative output shapes while ensuring still testing for negative padding
+                for i in range(rank):
+                    dim_after_padding = (
+                        paddings[i][0] + paddings[i][1] + shapeList[0][i]
+                    )
+                    if dim_after_padding < 1:
+                        paddings[i] = (0, 0)
+                if all([p > -1 for p in paddings[i]]):
+                    args_valid = False
+
+            if args_valid:
+                name = "pad"
+                for r in range(rank):
+                    before, after = paddings[r]
+                    name = f"{name}{before}{after}"
+                arg_list.append(
+                    (name, [np.array(paddings), pad_const_int, pad_const_fp])
+                )
+
+        if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
+            warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
 
         return arg_list
 
@@ -1232,6 +1302,21 @@
         k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
         kernels = {x for x in itertools.product(*([k_vals] * 2))}
 
+        if opName == "max_pool2d":
+            accum_dtypes = [None]  # max_pool has no accumulate dtype
+        elif dtype == DType.INT8 or dtype == DType.INT16:
+            accum_dtypes = [DType.INT32]
+        elif dtype == DType.FP16:
+            accum_dtypes = [DType.FP16, DType.FLOAT]
+        elif dtype == DType.FLOAT:
+            accum_dtypes = [DType.FLOAT]
+        elif error_name is None:
+            assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
+        else:
+            # Set to something for the ErrorIf case which has
+            # incorrect input data-type
+            accum_dtypes = [DType.INT32]
+
         if testGen.args.oversize:
             # add some oversize argument values
             bigStride = 7
@@ -1252,63 +1337,70 @@
         sparsity_factor = 2 if error_name else 500
         sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
 
+        arg_str = (
+            "acc{}_st{}_kern{}_pad{}"
+            if accum_dtypes[0] is not None
+            else "st{}_kern{}_pad{}"
+        )
+
+        def get_arg_list_element(accum, stride, pad, kern):
+            # Return tuple containing the formatted argument string and
+            # the corresponding argument values
+            arg_str_elems = [
+                "".join([str(x) for x in stride]),
+                "".join([str(x) for x in kern]),
+                "".join([str(x) for x in pad]),
+            ]
+            # Note: different order to string
+            arg_val_elems = [stride, pad, kern]
+
+            if accum is not None:
+                arg_str_elems.insert(0, testGen.typeStr(accum))
+                arg_val_elems.insert(0, accum)
+            return (arg_str.format(*arg_str_elems), arg_val_elems)
+
         n = 0
-        for s in sorted(list(strides)):
-            for p in sorted(list(paddings)):
-                for k in sorted(list(kernels)):
-                    if error_name in [
-                        ErrorIf.StrideSmallerOne,
-                        ErrorIf.KernelSmallerOne,
-                        ErrorIf.PadSmallerZero,
-                        ErrorIf.PadLargerEqualKernel,
-                    ]:
-                        sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
-                            testGen, error_name, s, p, k
-                        )
-                        if None not in [sNew, pNew, kNew] and n % sparsity == 0:
-                            arg_list.append(
-                                (
-                                    "st{}_kern{}_pad{}".format(
-                                        "".join([str(x) for x in sNew]),
-                                        "".join([str(x) for x in kNew]),
-                                        "".join([str(x) for x in pNew]),
-                                    ),
-                                    [sNew, pNew, kNew],
-                                )
+        for a in accum_dtypes:
+            for s in sorted(list(strides)):
+                for p in sorted(list(paddings)):
+                    for k in sorted(list(kernels)):
+                        if error_name in [
+                            ErrorIf.StrideSmallerOne,
+                            ErrorIf.KernelSmallerOne,
+                            ErrorIf.PadSmallerZero,
+                            ErrorIf.PadLargerEqualKernel,
+                        ]:
+                            sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
+                                testGen, error_name, s, p, k
                             )
-                    elif (
-                        n % sparsity == 0
-                        # padding must not exceed the kernel size
-                        and p[0] < k[0]
-                        and p[1] < k[0]
-                        and p[2] < k[1]
-                        and p[3] < k[1]
-                        # the padded shape must exceed the kernel size
-                        and (shape[1] + p[0] + p[1]) > k[0]
-                        and (shape[2] + p[2] + p[3]) > k[1]
-                    ):
-                        remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
-                        remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
-                        if (
-                            # the parameters must produce integer exact output
-                            error_name != ErrorIf.PoolingOutputShapeNonInteger
-                            and remainder_h == 0
-                            and remainder_w == 0
-                        ) or (
-                            error_name == ErrorIf.PoolingOutputShapeNonInteger
-                            and (remainder_h != 0 or remainder_w != 0)
+                            if None not in [sNew, pNew, kNew] and n % sparsity == 0:
+                                arg_vals = [a, sNew, pNew, kNew]
+                                arg_list.append(get_arg_list_element(*arg_vals))
+                        elif (
+                            n % sparsity == 0
+                            # padding must not exceed the kernel size
+                            and p[0] < k[0]
+                            and p[1] < k[0]
+                            and p[2] < k[1]
+                            and p[3] < k[1]
+                            # the padded shape must exceed the kernel size
+                            and (shape[1] + p[0] + p[1]) > k[0]
+                            and (shape[2] + p[2] + p[3]) > k[1]
                         ):
-                            arg_list.append(
-                                (
-                                    "st{}_kern{}_pad{}".format(
-                                        "".join([str(x) for x in s]),
-                                        "".join([str(x) for x in k]),
-                                        "".join([str(x) for x in p]),
-                                    ),
-                                    [s, p, k],
-                                )
-                            )
-                    n += 1
+                            remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
+                            remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
+                            if (
+                                # the parameters must produce integer exact output
+                                error_name != ErrorIf.PoolingOutputShapeNonInteger
+                                and remainder_h == 0
+                                and remainder_w == 0
+                            ) or (
+                                error_name == ErrorIf.PoolingOutputShapeNonInteger
+                                and (remainder_h != 0 or remainder_w != 0)
+                            ):
+                                arg_vals = [a, s, p, k]
+                                arg_list.append(get_arg_list_element(*arg_vals))
+                        n += 1
 
         return arg_list
 
@@ -1327,6 +1419,8 @@
             dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
         elif inDtype == DType.BOOL:
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+        elif inDtype == DType.FP16:
+            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif inDtype == DType.FLOAT:
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif error_name == ErrorIf.WrongInputType:
@@ -1734,6 +1828,8 @@
                 outputDTypeList = [DType.INT32]
             elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                 outputDTypeList = [DType.INT48]
+            elif dtype == DType.FP16:
+                outputDTypeList = [DType.FP16]
             elif dtype == DType.FLOAT:
                 outputDTypeList = [DType.FLOAT]
             elif error_name == ErrorIf.WrongInputType: