Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 6a689d0..7fa31e7 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -84,3 +84,42 @@
for n in shape:
value *= n
return value
+
+
def get_accum_dtype_from_tgTypes(dtypes):
    """Return the accumulate data-type from the test generator's defined types.

    Args:
        dtypes: either a single dtype, or a list/tuple of dtypes whose last
            element is the accumulate dtype.

    Returns:
        ``dtypes[-1]`` when ``dtypes`` is a list or tuple, otherwise
        ``dtypes`` itself, unchanged.
    """
    # A single isinstance call with a tuple of types replaces the chained
    # ``isinstance(...) or isinstance(...)`` form.
    if isinstance(dtypes, (list, tuple)):
        return dtypes[-1]
    return dtypes
+
+
def get_wrong_output_type(op_name, rng, input_dtype):
    """Return a randomly chosen output dtype that is wrong for an operator.

    Args:
        op_name: operator name; only "fully_connected" and "matmul" are
            supported.
        rng: random generator exposing ``choice(a=...)`` (presumably a
            ``numpy.random.Generator`` -- confirm against callers).
        input_dtype: DType of the operator input; one of DType.INT8,
            DType.INT16, DType.FLOAT or DType.FP16.

    Returns:
        One element drawn via ``rng.choice`` from the candidate "incorrect"
        output types for ``input_dtype``.

    Raises:
        ValueError: if ``op_name`` or ``input_dtype`` is not supported.
            Without this check the function would fall through with
            ``incorrect_types`` unbound and crash with UnboundLocalError.
    """
    if op_name not in ("fully_connected", "matmul"):
        raise ValueError(f"get_wrong_output_type: unsupported op {op_name!r}")

    if input_dtype == DType.INT8:
        # INT32 is deliberately absent: presumably the legal output for INT8.
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT48,
            DType.FLOAT,
            DType.FP16,
        )
    elif input_dtype == DType.INT16:
        # INT48 is deliberately absent: presumably the legal output for INT16.
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.FLOAT,
            DType.FP16,
        )
    elif input_dtype in (DType.FLOAT, DType.FP16):
        # Floating-point types are absent; only integer types are offered.
        incorrect_types = (
            DType.INT4,
            DType.INT8,
            DType.INT16,
            DType.INT32,
            DType.INT48,
        )
    else:
        raise ValueError(
            f"get_wrong_output_type: unsupported input dtype {input_dtype!r} "
            f"for op {op_name!r}"
        )
    return rng.choice(a=incorrect_types)