Reference model changes for fp16 support

Add FP16 support to the TOSA test generator (tosa_test_gen.py): random
FP16 tensor and scalar data generation, FP16 handling in the type-name
and bit-width helpers, and FP16 entries in the per-operator type lists
(FP16 convolutions may accumulate in either FP16 or FP32). A new
accum_dtype argument is threaded through the pooling, convolution,
fully_connected and matmul builders, serializer attributes and output
shapers, so the accumulator type is supplied by the argument generators
rather than derived from the input dtype.

Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
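
For illustration, a minimal sketch of the FP16 data path this change adds:
float randoms are drawn as usual and narrowed to numpy's float16, mirroring
the existing FLOAT path. The snippet is standalone and uses a seeded numpy
Generator directly rather than the generator class wrapped in this file:

    import numpy as np

    rng = np.random.default_rng(seed=0)

    def rand_fp16_tensor(shape):
        # Values in [0.0, 1.0), drawn at higher precision and rounded to fp16
        return np.float16(rng.random(size=shape))

    def rand_fp16_scalar():
        # Single value, same narrowing as the tensor case
        return np.float16(rng.random())
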
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b76b656..9ff6ec5 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -81,6 +81,8 @@
             return np.int64(
                 self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape)
             )
+        elif dtype == DType.FP16:
+            return np.float16(self.rng.random(size=shape))
         elif dtype == DType.FLOAT:
             return np.float32(self.rng.random(size=shape))
         else:
@@ -128,6 +130,9 @@
     def getRandNumberDType(self, dtype):
         if dtype == DType.FLOAT:
             return self.rng.random()
+        elif dtype == DType.FP16:
+            rand_f32 = self.rng.random()
+            return np.float16(rand_f32)
         elif dtype == DType.BOOL:
             return self.rng.choice([False, True])
         # TOSA specific INT4 weight range from -7 to 7
@@ -178,13 +183,15 @@
                 return "i32"
             elif t == DType.INT48:
                 return "i48"
+            elif t == DType.FP16:
+                return "f16"
             elif t == DType.FLOAT:
                 return "float"
             else:
                 raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
 
     def typeWidth(self, t):
-        """Get the datatype width for integer types"""
+        """Get the datatype width for data types"""
         if t == DType.INT4:
             return 4
         elif t == DType.INT8:
@@ -199,6 +206,8 @@
             return 32
         elif t == DType.INT48:
             return 48
+        elif t == DType.FP16:
+            return 16
         elif t == DType.FLOAT:
             return 32
         elif t == DType.BOOL:
@@ -346,7 +355,7 @@
 
         # Special for multiply:
         # Force the result to INT32 for INT types
-        if a.dtype != DType.FLOAT:
+        if a.dtype not in (DType.FP16, DType.FLOAT):
             result_tens.setDtype(DType.INT32)
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -533,6 +542,7 @@
         self,
         op,
         input,
+        accum_dtype,
         stride,
         pad,
         kernel,
@@ -585,17 +595,43 @@
             qinfo = [0, 0]
 
         attr = ts.TosaSerializerAttribute()
-        attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1])
+        attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
 
+    def build_maxpool2d(
+        self,
+        op,
+        input,
+        stride,
+        pad,
+        kernel,
+        validator_fcns=None,
+        error_name=None,
+        qinfo=None,
+    ):
+        # Same as build_pool2d, but passes accum_dtype=DType.UNKNOWN
+        # since MAX_POOL2D has no accumulator
+        return self.build_pool2d(
+            op,
+            input,
+            DType.UNKNOWN,
+            stride,
+            pad,
+            kernel,
+            validator_fcns,
+            error_name,
+            qinfo,
+        )
+
     def build_conv2d(
         self,
         op,
         ifm,
         filter,
         bias,
+        accum_dtype,
         strides,
         padding,
         dilations,
@@ -605,7 +641,15 @@
     ):
         assert len(padding) == 4
         result_tens = OutputShaper.conv2dOp(
-            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+            self.ser,
+            self.rng,
+            ifm,
+            filter,
+            accum_dtype,
+            strides,
+            padding,
+            dilations,
+            error_name,
         )
 
         # Ensure new output type has correct qinfo
@@ -648,7 +692,7 @@
             return None
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
+        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
@@ -659,6 +703,7 @@
         ifm,
         filter,
         bias,
+        accum_dtype,
         strides,
         padding,
         dilations,
@@ -668,7 +713,15 @@
     ):
         assert len(padding) == 6
         result_tens = OutputShaper.conv3dOp(
-            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+            self.ser,
+            self.rng,
+            ifm,
+            filter,
+            accum_dtype,
+            strides,
+            padding,
+            dilations,
+            error_name,
         )
 
         # Ensure new output type has correct qinfo
@@ -711,7 +764,7 @@
             return None
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
+        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
@@ -722,6 +775,7 @@
         ifm,
         filter,
         bias,
+        accum_dtype,
         stride,
         out_pad,
         output_shape,
@@ -731,7 +785,7 @@
     ):
         assert len(out_pad) == 4
         result_tens = OutputShaper.transposeConv2DOp(
-            self.ser, self.rng, ifm, output_shape, error_name
+            self.ser, self.rng, ifm, output_shape, accum_dtype, error_name
         )
 
         # Ensure new output type has correct qinfo
@@ -773,7 +827,9 @@
             return None
 
         attr = ts.TosaSerializerAttribute()
-        attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
+        attr.TransposeConvAttribute(
+            out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype
+        )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
@@ -784,6 +840,7 @@
         ifm,
         filter,
         bias,
+        accum_dtype,
         strides,
         padding,
         dilations,
@@ -792,7 +849,15 @@
         qinfo=None,
     ):
         result_tens = OutputShaper.depthwiseConv2dOp(
-            self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+            self.ser,
+            self.rng,
+            ifm,
+            filter,
+            accum_dtype,
+            strides,
+            padding,
+            dilations,
+            error_name,
         )
 
         # Ensure new output type has correct qinfo
@@ -835,16 +900,24 @@
             return None
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
+        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
 
     def build_fully_connected(
-        self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None
+        self,
+        op,
+        ifm,
+        filter,
+        bias,
+        accum_dtype,
+        validator_fcns=None,
+        error_name=None,
+        qinfo=None,
     ):
         result_tens = OutputShaper.fullyConnectedOp(
-            self.ser, self.rng, ifm, filter, error_name
+            self.ser, self.rng, ifm, filter, accum_dtype, error_name
         )
 
         # Invalidate Input/Output list for error if checks.
@@ -871,17 +944,22 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
+            accum_dtype=accum_dtype,
         ):
             return None
 
         attr = ts.TosaSerializerAttribute()
-        attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
+        attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
 
-    def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
-        result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
+    def build_matmul(
+        self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None
+    ):
+        result_tens = OutputShaper.matmulOp(
+            self.ser, self.rng, a, b, accum_dtype, error_name
+        )
 
         # Invalidate Input/Output list for error if checks.
         input_list = [a.name, b.name]
@@ -908,11 +986,12 @@
             input_list=input_list,
             output_list=output_list,
             num_operands=num_operands,
+            accum_dtype=accum_dtype,
         ):
             return None
 
         attr = ts.TosaSerializerAttribute()
-        attr.MatMulAttribute(qinfo[0], qinfo[1])
+        attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
         return result_tens
@@ -995,7 +1074,7 @@
             return None
 
         attr = ts.TosaSerializerAttribute()
-        if a.dtype == DType.FLOAT:
+        if a.dtype in (DType.FP16, DType.FLOAT):
             attr.ClampAttribute(0, 0, min_val, max_val)
         else:
             attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1811,7 +1890,7 @@
             op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
         )
 
-        if a.dtype in (DType.FLOAT, DType.INT32):
+        if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32):
             then_op, else_op = Op.ADD, Op.SUB
         elif a.dtype in (DType.INT8, DType.INT16):
             then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2350,22 +2429,37 @@
     #    if not specified, defaults to (1, 4)
     #  'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
     #  'types': array of datatypes to be tested
-    TYPE_FP = [DType.FLOAT]
+    TYPE_FP = [DType.FLOAT, DType.FP16]
 
     TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
-    TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT]  # Excludes INT4
+    TYPE_INT_FP = [
+        DType.INT8,
+        DType.INT16,
+        DType.INT32,
+        DType.FP16,
+        DType.FLOAT,
+    ]  # Excludes INT4
 
     TYPE_BOOL = [DType.BOOL]
-    TYPE_FI32 = [DType.FLOAT, DType.INT32]
-    TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL]
+    TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32]  # floating-point types and INT32
+    TYPE_FIB = [
+        DType.FP16,
+        DType.FLOAT,
+        DType.INT8,
+        DType.INT16,
+        DType.INT32,
+        DType.BOOL,
+    ]
     TYPE_FI16 = [DType.FLOAT, DType.INT16]
 
-    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
+    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
 
     TYPE_CONV = [
         [DType.INT8, DType.INT4, DType.INT32],
         [DType.INT8, DType.INT8, DType.INT32],
         [DType.INT16, DType.INT8, DType.INT48],
+        [DType.FP16, DType.FP16, DType.FP16],
+        [DType.FP16, DType.FP16, DType.FLOAT],
         DType.FLOAT,
     ]
 
@@ -2524,7 +2618,7 @@
                 build_fully_connected,
                 TosaTensorGen.tgFullyConnected,
                 TosaTensorValuesGen.tvgDefault,
-                None,
+                TosaArgGen.agFullyConnected,
             ),
             "qgen": TosaQuantGen.qgConv,
             "types": TYPE_CONV,
@@ -2546,7 +2640,7 @@
                 build_matmul,
                 TosaTensorGen.tgMatmul,
                 TosaTensorValuesGen.tvgDefault,
-                None,
+                TosaArgGen.agMatMul,
             ),
             "qgen": TosaQuantGen.qgMatmul,
             "types": TYPE_NARROW_INT_FP,
@@ -2564,7 +2658,7 @@
             "operands": (1, 0),
             "rank": (4, 4),
             "build_fcn": (
-                build_pool2d,
+                build_maxpool2d,
                 TosaTensorGen.tgNHWC,
                 TosaTensorValuesGen.tvgDefault,
                 TosaArgGen.agPooling,
@@ -3384,7 +3478,7 @@
                 TosaTensorValuesGen.tvgReduceSum,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_FI32,
+            "types": (DType.FP16, DType.FLOAT, DType.INT32),
             "error_if_validators": (
                 TosaErrorValidator.evAxisLargerRank,
                 TosaErrorValidator.evAxisSmallerZero,
@@ -3571,7 +3665,7 @@
                 TosaTensorValuesGen.tvgDefault,
                 None,
             ),
-            "types": TYPE_INT_FP,
+            "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT),
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -3612,7 +3706,7 @@
                 TosaTensorValuesGen.tvgDefault,
                 TosaArgGen.agResize,
             ),
-            "types": [DType.INT8, DType.INT16, DType.FLOAT],
+            "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT),
             "invalid_test_validators": (
                 TosaInvalidValidator.ivWrongDataTypeOrModeResize,
             ),
@@ -3646,7 +3740,14 @@
                 TosaTensorValuesGen.tvgDefault,
                 TosaArgGen.agCast,
             ),
-            "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
+            "types": (
+                DType.FP16,
+                DType.FLOAT,
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.BOOL,
+            ),
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -3925,7 +4026,9 @@
         return ser.addOutput(shape, outputDType)
 
     @staticmethod
-    def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
+    def conv2dOp(
+        ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
+    ):
 
         # IFM:    NHWC
         # Filter: OHWI
@@ -3958,26 +4061,26 @@
 
         ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
 
-        if ifm.dtype == DType.INT8:
-            out_dtype = DType.INT32
-        elif ifm.dtype == DType.INT16:
-            out_dtype = DType.INT48
-        elif ifm.dtype == DType.FLOAT:
-            out_dtype = DType.FLOAT
-        elif error_name == ErrorIf.WrongInputType:
+        if error_name == ErrorIf.WrongInputType:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+            out_dtype = accum_dtype
 
         if error_name == ErrorIf.WrongOutputType:
-            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+            if ifm.dtype == DType.FP16:
+                excludes = [DType.FP16, DType.FLOAT]
+            else:
+                excludes = [out_dtype]
+            wrong_dtypes = list(usableDTypes(excludes=excludes))
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(ofm_shape, out_dtype)
 
     @staticmethod
-    def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
+    def conv3dOp(
+        ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
+    ):
 
         # IFM:    NDHWC
         # Filter: ODHWI
@@ -4020,27 +4123,25 @@
 
         ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
 
-        if ifm.dtype == DType.INT8:
-            out_dtype = DType.INT32
-        elif ifm.dtype == DType.INT16:
-            out_dtype = DType.INT48
-        elif ifm.dtype == DType.FLOAT:
-            out_dtype = DType.FLOAT
-        elif error_name == ErrorIf.WrongInputType:
+        if error_name == ErrorIf.WrongInputType:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+            out_dtype = accum_dtype
 
         if error_name == ErrorIf.WrongOutputType:
-            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+            if ifm.dtype == DType.FP16:
+                excludes = [DType.FP16, DType.FLOAT]
+            else:
+                excludes = [out_dtype]
+            wrong_dtypes = list(usableDTypes(excludes=excludes))
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(ofm_shape, out_dtype)
 
     @staticmethod
     def depthwiseConv2dOp(
-        ser, rng, ifm, filter, strides, padding, dilations, error_name=None
+        ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
     ):
         # IFM:    NHWC
         # Filter: HWCM
@@ -4073,20 +4174,18 @@
 
         ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
 
-        if ifm.dtype == DType.INT8:
-            out_dtype = DType.INT32
-        elif ifm.dtype == DType.INT16:
-            out_dtype = DType.INT48
-        elif ifm.dtype == DType.FLOAT:
-            out_dtype = DType.FLOAT
-        elif error_name == ErrorIf.WrongInputType:
+        if error_name == ErrorIf.WrongInputType:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+            out_dtype = accum_dtype
 
         if error_name == ErrorIf.WrongOutputType:
-            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+            if ifm.dtype == DType.FP16:
+                excludes = [DType.FP16, DType.FLOAT]
+            else:
+                excludes = [out_dtype]
+            wrong_dtypes = list(usableDTypes(excludes=excludes))
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(ofm_shape, out_dtype)
@@ -4119,6 +4218,7 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FLOAT,
+                DType.FP16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4128,55 +4228,20 @@
         return ser.addOutput(ofm_shape, outputDType)
 
     @staticmethod
-    def fullyConnectedOp(ser, rng, input, filter, error_name=None):
+    def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None):
         # input: N, IC
         # filter: OC, IC
         # output: N, OC
 
         output_shape = [input.shape[0], filter.shape[0]]
 
-        if error_name == ErrorIf.WrongOutputType:
-            if input.dtype == DType.INT8:
-                incorrect_types = (
-                    DType.INT4,
-                    DType.INT8,
-                    DType.INT16,
-                    DType.INT48,
-                    DType.FLOAT,
-                )
-            elif input.dtype == DType.INT16:
-                incorrect_types = (
-                    DType.INT4,
-                    DType.INT8,
-                    DType.INT16,
-                    DType.INT32,
-                    DType.FLOAT,
-                )
-            elif input.dtype == DType.FLOAT:
-                incorrect_types = (
-                    DType.INT4,
-                    DType.INT8,
-                    DType.INT16,
-                    DType.INT32,
-                    DType.INT48,
-                )
-            out_dtype = rng.choice(a=incorrect_types)
-        elif input.dtype == DType.INT8:
-            out_dtype = DType.INT32
-        elif input.dtype == DType.INT16:
-            out_dtype = DType.INT48
-        elif input.dtype == DType.FLOAT:
-            out_dtype = DType.FLOAT
-        elif error_name == ErrorIf.WrongInputType:
-            # Pick some potentially correct output dtype if input type is incorrect
-            out_dtype = DType.INT32
-        else:
-            raise Exception("Unsupported input dtype: {}".format(input.dtype))
+        # accum_dtype is validated (and, for ErrorIf tests, invalidated) in arg_gen
+        out_dtype = accum_dtype
 
         return ser.addOutput(output_shape, out_dtype)
 
     @staticmethod
-    def matmulOp(ser, rng, a, b, error_name=None):
+    def matmulOp(ser, rng, a, b, accum_dtype, error_name=None):
         # a: N, H, C
         # b: N, C, W
         # out: N, H, W
@@ -4200,7 +4265,7 @@
                     DType.INT32,
                     DType.FLOAT,
                 )
-            elif a.dtype == DType.FLOAT:
+            elif a.dtype == DType.FLOAT or a.dtype == DType.FP16:
                 incorrect_types = (
                     DType.INT4,
                     DType.INT8,
@@ -4209,17 +4274,11 @@
                     DType.INT48,
                 )
             out_dtype = rng.choice(a=incorrect_types)
-        elif a.dtype == DType.INT8:
-            out_dtype = DType.INT32
-        elif a.dtype == DType.INT16:
-            out_dtype = DType.INT48
-        elif a.dtype == DType.FLOAT:
-            out_dtype = DType.FLOAT
         elif error_name == ErrorIf.WrongInputType:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
+            out_dtype = accum_dtype  # Validated in arg_gen
 
         return ser.addOutput(output_shape, out_dtype)
 
@@ -4269,10 +4328,6 @@
             bad_dim = rng.choice(range(len(output_shape)))
             output_shape[bad_dim] -= rng.choice([1, 2])
 
-        # Fix negative output shape if error_if test causes it
-        if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1:
-            output_shape = [i if i >= 1 else 1 for i in output_shape]
-
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [
                 DType.INT8,
@@ -4280,6 +4335,7 @@
                 DType.INT32,
                 DType.INT48,
                 DType.FLOAT,
+                DType.FP16,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -4546,7 +4602,7 @@
         return ser.addOutput(val.shape, out_dtype)
 
     @staticmethod
-    def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
+    def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None):
         if error_name == ErrorIf.ConvOutputShapeMismatch:
             choices = [1, 2, 3]
             change = rng.choice(choices)
@@ -4555,20 +4611,18 @@
             if change in [2, 3]:
                 output_shape[2] = output_shape[2] + rng.choice(choices)
 
-        if ifm.dtype == DType.INT8:
-            out_dtype = DType.INT32
-        elif ifm.dtype == DType.INT16:
-            out_dtype = DType.INT48
-        elif ifm.dtype == DType.FLOAT:
-            out_dtype = DType.FLOAT
-        elif error_name == ErrorIf.WrongInputType:
+        if error_name == ErrorIf.WrongInputType:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+            out_dtype = accum_dtype
 
         if error_name == ErrorIf.WrongOutputType:
-            wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+            if ifm.dtype == DType.FP16:
+                excludes = [DType.FP16, DType.FLOAT]
+            else:
+                excludes = [out_dtype]
+            wrong_dtypes = list(usableDTypes(excludes=excludes))
             out_dtype = rng.choice(wrong_dtypes)
 
         return ser.addOutput(output_shape, out_dtype)
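
For reference, a rough sketch of the accum_dtype values the argument
generators are expected to supply, matching the if/elif chains removed above
and the new TYPE_CONV entries (FP16 may accumulate in FP16 or FP32). The
mapping and helper name below are illustrative only, assuming the
serialization library's tosa.DType bindings:

    from tosa.DType import DType  # serialization_lib Python bindings

    # Accumulator dtype choices per input dtype; the real selection lives in
    # the argument generators (e.g. TosaArgGen.agMatMul), not shown here.
    ACCUM_DTYPES = {
        DType.INT8: [DType.INT32],
        DType.INT16: [DType.INT48],
        DType.FP16: [DType.FP16, DType.FLOAT],
        DType.FLOAT: [DType.FLOAT],
    }

    def pick_accum_dtype(input_dtype, rng):
        # rng is a numpy Generator, as used throughout the test generator
        return rng.choice(ACCUM_DTYPES[input_dtype])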