Add support for FP8 to reference model

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 4ead982..bc931dc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -76,7 +76,7 @@
             return tuple(sorted(vals))
 
         self.random_float_range = {}
-        for dtype in (DType.FP32, DType.FP16, DType.BF16):
+        for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
             self.random_float_range[dtype] = convertFPRange(
                 args.tensor_fp_value_range,
                 TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
@@ -152,7 +152,7 @@
         # Returns dtype value range boundaries (low, high)
         # The high boundary is excluded in the range
         # unless high_inclusive is True
-        if dtype in (DType.FP32, DType.FP16, DType.BF16):
+        if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
             return self.random_float_range[dtype]
         elif dtype == DType.BOOL:
             rng = (0, 2)
@@ -197,7 +197,13 @@
             return np.uint8(self.rng.integers(low=low, high=high, size=shape))
         elif dtype in (DType.INT48, DType.SHAPE):
             return np.int64(self.rng.integers(low=low, high=high, size=shape))
-        elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+        elif dtype in (
+            DType.FP16,
+            DType.BF16,
+            DType.FP32,
+            DType.FP8E4M3,
+            DType.FP8E5M2,
+        ):
             f_tensor = self.rng.uniform(low=low, high=high, size=shape)
 
             if dtype == DType.FP16:
@@ -207,6 +213,10 @@
                 if dtype == DType.BF16:
                     # Floor the last 16 bits of each f32 value
                     return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+                elif dtype == DType.FP8E4M3:  # float32 array holding FP8E4M3 values
+                    return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
+                elif dtype == DType.FP8E5M2:  # float32 array holding FP8E5M2 values
+                    return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
                 else:
                     return f32_tensor
         else:
@@ -266,6 +276,12 @@
         elif dtype == DType.BF16:
             rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
             return gtu.vect_f32_to_bf16(rand_f32)
+        elif dtype == DType.FP8E4M3:
+            rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+            return gtu.vect_f32_to_fp8e4m3(rand_f32)
+        elif dtype == DType.FP8E5M2:
+            rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+            return gtu.vect_f32_to_fp8e5m2(rand_f32)
         elif dtype == DType.BOOL:
             return self.rng.choice([False, True])
         elif dtype == DType.INT48 or dtype == DType.SHAPE:
@@ -1408,8 +1424,11 @@
                 max_val = max_val.astype(np.float32)
 
             attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
-        else:
+        elif a.dtype in (DType.INT8, DType.INT16):
             attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+        else:
+            # Incorrect input dtype: a zeroed attribute avoids an internal error
+            attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
@@ -3190,7 +3209,13 @@
     ]
     TYPE_FI16 = [DType.FP32, DType.INT16]
 
-    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+    TYPE_NARROW_INT_FP = [
+        DType.INT8,
+        DType.INT16,
+        DType.FP16,
+        DType.BF16,
+        DType.FP32,
+    ]
 
     # List of [Input Type 1, Input Type 2, Accumulator Type]
     TYPE_CONV = [
@@ -3201,6 +3226,8 @@
         [DType.FP16, DType.FP16, DType.FP32],
         [DType.BF16, DType.BF16, DType.FP32],
         [DType.FP32, DType.FP32, DType.FP32],
+        [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
+        [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
     ]
 
     DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3217,7 +3244,7 @@
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evAxisSmallerZero,
                 TosaErrorValidator.evAxisLargerRank,
@@ -3244,7 +3271,7 @@
                 TosaArgGen.agPooling,
             ),
             "qgen": TosaQuantGen.qgUnary,
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
             "error_if_validators": (
                 TosaErrorValidator.evKernelSmallerOne,
@@ -3402,7 +3429,7 @@
                 TosaArgGen.agMatMul,
             ),
             "qgen": TosaQuantGen.qgMatmul,
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evInputZeroPointNotZero,
                 TosaErrorValidator.evWrongRank,
@@ -3425,7 +3452,7 @@
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agPooling,
             ),
-            "types": TYPE_NARROW_INT_FP,
+            "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
             "error_if_validators": (
                 TosaErrorValidator.evKernelSmallerOne,
@@ -4389,7 +4416,7 @@
                 TosaTensorValuesGen.tvgConcat,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evAxisLargerRank,
                 TosaErrorValidator.evAxisSmallerZero,
@@ -4413,7 +4440,7 @@
                 TosaTensorValuesGen.tvgPad,
                 TosaArgGen.agPad,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evPadSmallerZero,
@@ -4437,7 +4464,7 @@
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evAxisLargerRank,
                 TosaErrorValidator.evAxisSmallerZero,
@@ -4456,7 +4483,7 @@
                 TosaTensorValuesGen.tvgReshape,
                 TosaArgGen.agReshape,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evTensorSizeInputOutputMismatch,
                 TosaErrorValidator.evWrongInputType,
@@ -4477,7 +4504,7 @@
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agAxis,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evAxisSmallerZero,
                 TosaErrorValidator.evAxisLargerRank,
@@ -4500,7 +4527,7 @@
                 TosaTensorValuesGen.tvgSlice,
                 TosaArgGen.agSlice,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 # TODO Turn off these error categories for now as the reference
                 # model cannot allocate memory space for empty tensor. We probably
@@ -4532,7 +4559,7 @@
                 TosaTensorValuesGen.tvgTile,
                 TosaArgGen.agTile,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -4555,7 +4582,7 @@
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agTranspose,
             ),
-            "types": TYPE_FIB,
+            "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evIndexOutsideBounds,
                 TosaErrorValidator.evIndexUsedTwice,
@@ -4581,7 +4608,7 @@
                 TosaTensorValuesGen.tvgLazyGenDefault,
                 TosaArgGen.agNone,
             ),
-            "types": TYPE_FIB + [DType.INT48],
+            "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
             "data_gen": {
                 "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
             },
@@ -4618,6 +4645,8 @@
                 DType.FP16,
                 DType.BF16,
                 DType.FP32,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ),
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
@@ -4640,7 +4669,7 @@
                 TosaTensorValuesGen.tvgScatter,
                 TosaArgGen.agNone,
             ),
-            "types": TYPE_INT_FP,
+            "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
                 TosaErrorValidator.evWrongOutputType,
@@ -4709,6 +4738,8 @@
                 DType.INT16,
                 DType.INT32,
                 DType.BOOL,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ),
             "error_if_validators": (
                 TosaErrorValidator.evWrongInputType,
@@ -5141,6 +5172,8 @@
                 DType.FP32,
                 DType.FP16,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
             outputDType = rng.choice(wrong_dtypes)
@@ -5194,6 +5227,8 @@
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
                 excludes = [DType.FP16, DType.FP32]
+            elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+                excludes = [DType.FP16]  # FP16 is the only valid FP8 output type
             else:
                 excludes = [out_dtype]
             wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
@@ -5344,6 +5379,8 @@
                 DType.FP32,
                 DType.FP16,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
             outputDType = rng.choice(wrong_dtypes)
@@ -5383,6 +5420,8 @@
                     DType.FP32,
                     DType.FP16,
                     DType.BF16,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
                 )
             elif a.dtype == DType.INT16:
                 incorrect_types = (
@@ -5393,6 +5432,20 @@
                     DType.FP32,
                     DType.FP16,
                     DType.BF16,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
+                )
+            elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
+                incorrect_types = (  # FP16 omitted: expected output for FP8 inputs
+                    DType.INT4,
+                    DType.INT8,
+                    DType.INT16,
+                    DType.INT32,
+                    DType.INT48,
+                    DType.FP32,
+                    DType.BF16,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
                 )
             elif (
                 a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
@@ -5403,6 +5456,8 @@
                     DType.INT16,
                     DType.INT32,
                     DType.INT48,
+                    DType.FP8E4M3,
+                    DType.FP8E5M2,
                 )
             out_dtype = rng.choice(a=incorrect_types)
         elif error_name == ErrorIf.WrongInputType:
@@ -5669,6 +5724,8 @@
                 DType.FP32,
                 DType.FP16,
                 DType.BF16,
+                DType.FP8E4M3,
+                DType.FP8E5M2,
             ]
             wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
             outputDType = rng.choice(wrong_dtypes)