Add support for FP8 to reference model

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 9a88acb..7a4d0d6 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -325,12 +325,32 @@
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP32]:
+ # BOOL and FP32 no longer share a branch: the BOOL list now also carries
+ # FP16, BF16 and the FP8 types, while the FP32 list below is unchanged.
+ if input_dtype in [DType.BOOL]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT48,
+ DType.FP32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif input_dtype in [DType.FP32]:
outputDType = [DType.BOOL, DType.INT48, DType.FP32]
elif input_dtype in [DType.FP16, DType.BF16]:
outputDType = [DType.BOOL, DType.INT48]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
+ elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ ]
else:
assert False, f"input_dtype ({input_dtype}) not supported"
return outputDType
@@ -476,13 +496,23 @@
)
or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
+ or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
+ or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
input_dtype
- in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ in [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
and output_dtype != DType.INT32
):
error_result = True
@@ -555,12 +585,26 @@
or (
input_dtype == DType.FP16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.BF16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.FP32
@@ -571,6 +615,17 @@
DType.INT32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ )
+ or (
+ input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
+ and output_dtype
+ not in [
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
]
)
):
@@ -597,6 +652,10 @@
and output_dtype != DType.FP32
or input_dtype == DType.FP32
and output_dtype != DType.FP32
+ or input_dtype == DType.FP8E4M3
+ and output_dtype != DType.FP16
+ or input_dtype == DType.FP8E5M2
+ and output_dtype != DType.FP16
):
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
@@ -2615,6 +2674,11 @@
DType.FP32,
):
error_result = True
+ elif (
+ input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+ and accum_dtype != DType.FP16
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,