Add support for FP8 to reference model

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 76e7388..31a0ff0 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -27,6 +27,8 @@
     DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
     DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
     DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
+    DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"},
+    DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"},
 }
 
 
@@ -186,6 +188,16 @@
                 DType.INT32,
                 DType.INT48,
             )
+        elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2:
+            incorrect_types = (
+                DType.INT4,
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.INT48,
+                DType.FP32,
+                DType.BF16,
+            )
     else:
         # Assume all types but the input type are incorrect
         incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
@@ -209,6 +221,12 @@
     return f32_bits[16:] == "0" * 16
 
 
+def float32_is_valid_float8(f):
+    """Return True if the float value fits in the top 8 bits of its fp32 bit pattern (valid float8)."""
+    f32_bits = get_float32_bitstring(f)
+    return f32_bits[8:] == "0" * 24
+
+
 def get_float32_bitstring(f):
     """Return a big-endian string of bits representing a 32 bit float."""
     f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
@@ -232,6 +250,30 @@
     return struct.unpack("@f", fp_bytes)[0]  # native byteorder
 
 
+def float32_to_fp8e4m3(f):
+    """Turn fp32 value into fp8e4m3."""
+    f32_bits = get_float32_bitstring(f)
+    fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24
+    fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+    return struct.unpack("@f", fp_bytes)[0]  # native byteorder
+
+
+def float32_to_fp8e5m2(f):
+    """Turn fp32 value into fp8e5m2."""
+    f32_bits = get_float32_bitstring(f)
+    fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24
+    fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+    return struct.unpack("@f", fp_bytes)[0]  # native byteorder
+
+
 vect_f32_to_bf16 = np.vectorize(
     float32_to_bfloat16, otypes=(np.float32,)
 )  # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e4m3 = np.vectorize(
+    float32_to_fp8e4m3, otypes=(np.float32,)
+)  # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e5m2 = np.vectorize(
+    float32_to_fp8e5m2, otypes=(np.float32,)
+)  # NumPy vectorize: applies function to vector faster than looping