Add support for FP8 to reference model

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 9a88acb..7a4d0d6 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -325,12 +325,32 @@
 
     @staticmethod
     def eiCastErrorIf(testGen, input_dtype):
-        if input_dtype in [DType.BOOL, DType.FP32]:
+        """Return the invalid CAST output dtypes to use for ``input_dtype``."""
+        if input_dtype in [DType.BOOL]:
+            outputDType = [
+                DType.BOOL,
+                DType.INT48,
+                DType.FP32,
+                DType.FP16,
+                DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
+            ]
+        elif input_dtype in [DType.FP32]:
             outputDType = [DType.BOOL, DType.INT48, DType.FP32]
         elif input_dtype in [DType.FP16, DType.BF16]:
             outputDType = [DType.BOOL, DType.INT48]
         elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
             outputDType = [DType.INT48]
+        elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+            # FP8 casts only to FP16/BF16/FP32, so BOOL/integer outputs are invalid.
+            outputDType = [
+                DType.BOOL,
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.INT48,
+            ]
         else:
             assert False, f"input_dtype ({input_dtype}) not supported"
         return outputDType
@@ -476,13 +496,23 @@
                     )
                     or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                     or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
+                    or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
+                    or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
                 ):
                     error_result = True
 
             elif op["op"] == Op.ARGMAX:
                 if (
                     input_dtype
-                    in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+                    in [
+                        DType.INT8,
+                        DType.INT16,
+                        DType.FP16,
+                        DType.BF16,
+                        DType.FP32,
+                        DType.FP8E4M3,
+                        DType.FP8E5M2,
+                    ]
                     and output_dtype != DType.INT32
                 ):
                     error_result = True
@@ -555,12 +585,26 @@
                     or (
                         input_dtype == DType.FP16
                         and output_dtype
-                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+                        not in [
+                            DType.INT8,
+                            DType.INT16,
+                            DType.INT32,
+                            DType.FP32,
+                            DType.FP8E4M3,
+                            DType.FP8E5M2,
+                        ]
                     )
                     or (
                         input_dtype == DType.BF16
                         and output_dtype
-                        not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+                        not in [
+                            DType.INT8,
+                            DType.INT16,
+                            DType.INT32,
+                            DType.FP32,
+                            DType.FP8E4M3,
+                            DType.FP8E5M2,
+                        ]
                     )
                     or (
                         input_dtype == DType.FP32
@@ -571,6 +615,17 @@
                             DType.INT32,
                             DType.FP16,
                             DType.BF16,
+                            DType.FP8E4M3,
+                            DType.FP8E5M2,
+                        ]
+                    )
+                    or (
+                        input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
+                        and output_dtype
+                        not in [
+                            DType.FP16,
+                            DType.BF16,
+                            DType.FP32,
                         ]
                     )
                 ):
@@ -597,6 +652,10 @@
                     and output_dtype != DType.FP32
                     or input_dtype == DType.FP32
                     and output_dtype != DType.FP32
+                    or input_dtype == DType.FP8E4M3
+                    and output_dtype != DType.FP16
+                    or input_dtype == DType.FP8E5M2
+                    and output_dtype != DType.FP16
                 ):
                     error_result = True
                 # invalid input types are ignored, to avoid reporting multiple errors
@@ -2615,6 +2674,11 @@
                     DType.FP32,
                 ):
                     error_result = True
+                elif (
+                    input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+                    and accum_dtype != DType.FP16
+                ):
+                    error_result = True
 
         info_dict = {
             "error_name": error_name,