Reference model changes for fp16 support

Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index f9a00f9..a766803 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -120,6 +120,7 @@
                     DType.INT32,
                     DType.INT48,
                     DType.FLOAT,
+                    DType.FP16,
                 )
             elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
                 incorrect_types = (
@@ -128,6 +129,7 @@
                     DType.INT32,
                     DType.INT48,
                     DType.FLOAT,
+                    DType.FP16,
                 )
             elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
                 incorrect_types = (
@@ -136,6 +138,7 @@
                     DType.INT16,
                     DType.INT48,
                     DType.FLOAT,
+                    DType.FP16,
                 )
             elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
                 incorrect_types = (
@@ -144,6 +147,16 @@
                     DType.INT16,
                     DType.INT32,
                     DType.FLOAT,
+                    DType.FP16,
+                )
+            elif dtype == DType.FP16:
+                incorrect_types = (
+                    DType.INT4,
+                    DType.INT8,
+                    DType.INT16,
+                    DType.INT32,
+                    DType.INT48,
+                    DType.FLOAT,
                 )
             elif dtype == DType.FLOAT:
                 incorrect_types = (
@@ -152,6 +165,7 @@
                     DType.INT16,
                     DType.INT32,
                     DType.INT48,
+                    DType.FP16,
                 )
             outputDType = testGen.rng.choice(a=incorrect_types)
 
@@ -285,8 +299,8 @@
 
     @staticmethod
     def eiCastErrorIf(testGen, input_dtype):
-        if input_dtype in [DType.BOOL, DType.FLOAT]:
-            outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
+        if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]:
+            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT]
         elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
             outputDType = [DType.INT48]
         else:
@@ -400,6 +414,7 @@
                         and input_dtype == DType.INT16
                         and output_dtype != DType.INT48
                     )
+                    or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
                     or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                 ):
                     error_result = True
@@ -413,19 +428,28 @@
                 if (
                     (input_dtype == DType.INT8 and output_dtype != DType.INT32)
                     or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
+                    or (
+                        input_dtype == DType.FP16
+                        and output_dtype not in (DType.FP16, DType.FLOAT)
+                    )
                     or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
                 ):
                     error_result = True
 
             elif op["op"] == Op.ARGMAX:
                 if (
-                    input_dtype in [DType.INT8, DType.INT16, DType.FLOAT]
+                    input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
                     and output_dtype != DType.INT32
                 ):
                     error_result = True
 
             elif op["op"] == Op.MUL:
-                if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
+                if (
+                    input_dtype not in (DType.FP16, DType.FLOAT)
+                    and output_dtype != DType.INT32
+                ):
+                    error_result = True
+                elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                     error_result = True
                 elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
                     error_result = True
@@ -449,17 +473,39 @@
                     or (
                         input_dtype == DType.INT8
                         and output_dtype
-                        not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+                        not in [
+                            DType.BOOL,
+                            DType.INT16,
+                            DType.INT32,
+                            DType.FLOAT,
+                            DType.FP16,
+                        ]
                     )
                     or (
                         input_dtype == DType.INT16
                         and output_dtype
-                        not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+                        not in [
+                            DType.BOOL,
+                            DType.INT8,
+                            DType.INT32,
+                            DType.FLOAT,
+                            DType.FP16,
+                        ]
                     )
                     or (
                         input_dtype == DType.INT32
                         and output_dtype
-                        not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+                        not in [
+                            DType.BOOL,
+                            DType.INT8,
+                            DType.INT16,
+                            DType.FLOAT,
+                            DType.FP16,
+                        ]
+                    )
+                    or (
+                        input_dtype == DType.FP16
+                        and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
                     )
                     or (
                         input_dtype == DType.FLOAT
@@ -479,6 +525,8 @@
                     and output_dtype != DType.INT32
                     or input_dtype == DType.INT16
                     and output_dtype != DType.INT48
+                    or input_dtype == DType.FP16
+                    and output_dtype not in (DType.FP16, DType.FLOAT)
                     or input_dtype == DType.FLOAT
                     and output_dtype != DType.FLOAT
                 ):
@@ -2257,12 +2305,13 @@
             return (
                 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
+                and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
                 and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
             )
         elif mode == ResizeMode.NEAREST:
             # Invalid output data type / Invalid input datatype
             return (input_dtype != output_dtype) or (
-                input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT]
+                input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
             )
         else:
             # Invalid resize mode
@@ -2276,8 +2325,11 @@
         input_shape = inputShapes[0]
 
         args = kwargs["args"]
-        strides = args[0]
-        padding = args[1]
+
+        # MaxPool2D has no accum_dtype arg
+        stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" else (1, 2)
+        strides = args[stride_idx]
+        padding = args[pad_idx]
 
         if opName.endswith("pool2d"):
             # avg_pool2d, max_pool2d
@@ -2365,7 +2417,7 @@
     @staticmethod
     def ivNonPositiveOutputShape(**kwargs):
         args = kwargs["args"]
-        output_shape = args[2]
+        output_shape = args[3]
         if output_shape[1] <= 0 or output_shape[2] <= 0:
             # Negative output shape
             return True