Add support for FP8 to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 76e7388..31a0ff0 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -27,6 +27,8 @@
DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
+ DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"},
+ DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"},
}
@@ -186,6 +188,16 @@
DType.INT32,
DType.INT48,
)
+ elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ )
else:
# Assume all types but the input type are incorrect
incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
@@ -209,6 +221,31 @@
return f32_bits[16:] == "0" * 16
+def float32_is_valid_float8(f, dtype=None):
+    """Return True if float value f is exactly representable as float8.
+
+    f is first rounded to float32, like the other float32_* helpers here.
+    With dtype=None the value is accepted if FP8E4M3 or FP8E5M2 can hold
+    it; pass DType.FP8E4M3 or DType.FP8E5M2 to check a single format.
+    NaN is accepted; +/-inf only by FP8E5M2 (FP8E4M3 has no infinity).
+
+    Raises ValueError if dtype is neither None nor an FP8 DType.
+    """
+    if dtype is None:
+        converters = (float32_to_fp8e4m3, float32_to_fp8e5m2)
+    elif dtype == DType.FP8E4M3:
+        converters = (float32_to_fp8e4m3,)
+    elif dtype == DType.FP8E5M2:
+        converters = (float32_to_fp8e5m2,)
+    else:
+        raise ValueError(f"Not an FP8 dtype: {dtype}")
+    value = struct.unpack(">f", struct.pack(">f", f))[0]
+    if value != value:
+        return True  # NaN has an encoding in both FP8 formats
+    # Representable exactly when rounding to the format is a no-op.
+    return any(convert(value) == value for convert in converters)
+
+
def get_float32_bitstring(f):
"""Return a big-endian string of bits representing a 32 bit float."""
f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
@@ -232,6 +269,72 @@
return struct.unpack("@f", fp_bytes)[0] # native byteorder
+def _float32_round_to_fp8(f, exp_bits, man_bits, has_inf):
+    """Return f rounded to a 1/exp_bits/man_bits FP8 format, as a float.
+
+    f is first rounded to float32. Rounding is to nearest, ties to even;
+    magnitudes below the smallest normal use the format's subnormal grid.
+    Finite magnitudes beyond the largest normal saturate to +/-max normal.
+    NaN is returned unchanged; infinity is returned unchanged when the
+    format encodes it (has_inf) and saturates otherwise.
+    """
+    packed = struct.pack(">f", f)
+    bits = struct.unpack(">L", packed)[0]
+    value = struct.unpack(">f", packed)[0]
+    if value != value:
+        return value  # NaN
+    negative = bits >> 31
+    bias = (1 << (exp_bits - 1)) - 1
+    if has_inf:
+        # IEEE-style (E5M2): the all-ones exponent is reserved for Inf/NaN.
+        max_exp = (1 << exp_bits) - 2 - bias
+        max_normal = (2.0 - 2.0**-man_bits) * 2.0**max_exp
+    else:
+        # "FN"-style (E4M3): no Inf and only S.1111.111 is NaN, so the
+        # all-ones exponent still holds normal values below that pattern.
+        max_exp = (1 << exp_bits) - 1 - bias
+        max_normal = (2.0 - 2.0 ** (1 - man_bits)) * 2.0**max_exp
+    magnitude = abs(value)
+    if magnitude == float("inf"):
+        if has_inf:
+            return value
+        magnitude = max_normal  # no Inf encoding: saturate
+    else:
+        # float32 exponent of the value's binade, clamped to the FP8 minimum
+        # normal exponent so small magnitudes land on the subnormal grid.
+        exp = max(((bits >> 23) & 0xFF) - 127, 1 - bias)
+        step = 2.0 ** (exp - man_bits)
+        # round() on a float rounds half to even; the scaling is exact.
+        magnitude = min(round(magnitude / step) * step, max_normal)
+    return -magnitude if negative else magnitude
+
+
+def float32_to_fp8e4m3(f):
+    """Turns fp32 value into fp8e4m3 (1 sign, 4 exponent, 3 mantissa bits).
+
+    Returns the nearest FP8E4M3 value (bias 7, max 448, no infinity) as a
+    float; see _float32_round_to_fp8 for rounding and saturation.
+    """
+    return _float32_round_to_fp8(f, exp_bits=4, man_bits=3, has_inf=False)
+
+
+def float32_to_fp8e5m2(f):
+    """Turns fp32 value into fp8e5m2 (1 sign, 5 exponent, 2 mantissa bits).
+
+    Returns the nearest FP8E5M2 value (bias 15, max 57344, with infinity) as
+    a float; see _float32_round_to_fp8 for rounding and saturation.
+    """
+    return _float32_round_to_fp8(f, exp_bits=5, man_bits=2, has_inf=True)
+
+
vect_f32_to_bf16 = np.vectorize(
float32_to_bfloat16, otypes=(np.float32,)
) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e4m3 = np.vectorize(
+    float32_to_fp8e4m3, otypes=(np.float32,)
+)  # NumPy vectorize: element-wise float32 wrapper (a loop, not a speed-up)
+
+vect_f32_to_fp8e5m2 = np.vectorize(
+    float32_to_fp8e5m2, otypes=(np.float32,)
+)  # NumPy vectorize: element-wise float32 wrapper (a loop, not a speed-up)