Add conversions of U8 to/from BF16 and FP8

Add a type field to PadAttribute and ClampAttribute so that their
pad_const and min_val/max_val can be deserialized according to that type.

Add functions that convert U8 arrays to/from BF16/FP8 values.
Also refactor the tensor data serialization into a public static method,
TosaSerializer.convertDataToUint8Vec, which converts dtype/data into a
uint8 list for serialization, and modify it to serialize BF16 values
into 2 bytes each and FP8 values into a single byte each.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I05659e8187c76d359f1cc9f71c8c23cafd0e877f
diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py
index e6ab3d0..298907e 100644
--- a/python/serializer/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -17,6 +17,7 @@
 import json
 import flatbuffers
 import numpy as np
+import struct
 from enum import IntEnum, unique
 from tosa import (
     TosaGraph,
@@ -204,7 +205,7 @@
         self.bools.append((a.AddLocalBound, local_bound))
         self.ints.append((a.AddAccType, acc_type))
 
-    def PadAttribute(self, serializer_builder, pad_const_val_as_bytes):
+    def PadAttribute(self, serializer_builder, pad_const_val_as_bytes, dtype):
         from tosa import PadAttribute as a, Attribute
 
         self.utype = Attribute.Attribute().PadAttribute
@@ -216,6 +217,7 @@
         )
 
         self.floats.append((a.AddPadConst, serialized_pad_const_val))
+        self.ints.append((a.AddType, dtype))
 
     def AxisAttribute(self, axis):
         from tosa import AxisAttribute as a, Attribute
@@ -236,7 +238,9 @@
         self.int16vecs.append((a.AddBorder, border))
         self.ints.append((a.AddMode, mode))
 
-    def ClampAttribute(self, serializer_builder, min_val_as_bytes, max_val_as_bytes):
+    def ClampAttribute(
+        self, serializer_builder, min_val_as_bytes, max_val_as_bytes, dtype
+    ):
         from tosa import ClampAttribute as a, Attribute
 
         self.utype = Attribute.Attribute().ClampAttribute
@@ -252,6 +256,7 @@
 
         self.floats.append((a.AddMinVal, serialized_min_val))
         self.floats.append((a.AddMaxVal, serialized_max_val))
+        self.ints.append((a.AddType, dtype))
 
     def RescaleAttribute(
         self,
@@ -439,87 +444,7 @@
         fb_name = builder.CreateString(self.name)
         fb_shapes = TosaSerializer.serializeInt32Vec(builder, self.shape)
         if self.data:
-            u8_data = list()
-            # little endianess
-            if self.dtype == DType.BOOL:
-                for val in self.data:
-                    val_u8 = np.uint8(val)
-                    u8_data.append(val_u8)
-            elif self.dtype == DType.INT4:
-                in_size = len(self.data)
-                out_size = (in_size + 1) // 2
-                for i in range(out_size):
-                    val_0 = self.data[2 * i]
-                    if (2 * i + 1) < in_size:
-                        val_1 = self.data[2 * i + 1]
-                    else:
-                        val_1 = 0
-                    val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
-                    val_u8 = np.uint8(val_i8)
-                    u8_data.append(val_u8)
-            elif self.dtype == DType.INT8:
-                for val in self.data:
-                    val_u8 = np.array(val).astype(dtype=np.uint8)
-                    u8_data.append(val_u8)
-            elif self.dtype == DType.INT16:
-                for val in self.data:
-                    val_u16 = np.array(val).astype(dtype=np.uint16)
-                    b0 = val_u16 & ByteMask
-                    b1 = (val_u16 >> np.uint16(8)) & ByteMask
-                    u8_data.extend([b0, b1])
-            elif self.dtype == DType.INT32:
-                for val in self.data:
-                    val_u32 = np.array(val).astype(dtype=np.uint32)
-                    b0 = val_u32 & ByteMask
-                    b1 = (val_u32 >> np.uint32(8)) & ByteMask
-                    b2 = (val_u32 >> np.uint32(16)) & ByteMask
-                    b3 = (val_u32 >> np.uint32(24)) & ByteMask
-                    u8_data.extend([b0, b1, b2, b3])
-            elif self.dtype == DType.INT48:
-                for val in self.data:
-                    val_u64 = np.uint64(val)
-                    b0 = val_u64 & ByteMask
-                    b1 = (val_u64 >> np.uint64(8)) & ByteMask
-                    b2 = (val_u64 >> np.uint64(16)) & ByteMask
-                    b3 = (val_u64 >> np.uint64(24)) & ByteMask
-                    b4 = (val_u64 >> np.uint64(32)) & ByteMask
-                    b5 = (val_u64 >> np.uint64(40)) & ByteMask
-                    u8_data.extend([b0, b1, b2, b3, b4, b5])
-            elif self.dtype == DType.SHAPE:
-                for val in self.data:
-                    val_u64 = np.uint64(val)
-                    b0 = val_u64 & ByteMask
-                    b1 = (val_u64 >> np.uint64(8)) & ByteMask
-                    b2 = (val_u64 >> np.uint64(16)) & ByteMask
-                    b3 = (val_u64 >> np.uint64(24)) & ByteMask
-                    b4 = (val_u64 >> np.uint64(32)) & ByteMask
-                    b5 = (val_u64 >> np.uint64(40)) & ByteMask
-                    b6 = (val_u64 >> np.uint64(48)) & ByteMask
-                    b7 = (val_u64 >> np.uint64(56)) & ByteMask
-                    u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
-            elif self.dtype == DType.FP16:
-                np_arr = np.array(self.data, dtype=np.float16)
-                u8_data.extend(np_arr.view(np.uint8))
-            elif (
-                self.dtype == DType.FP32
-                or self.dtype == DType.BF16
-                or self.dtype == DType.FP8E4M3
-                or self.dtype == DType.FP8E5M2
-            ):
-                # for val in self.data:
-                #     b = struct.pack("!f", val)
-                #     u8_data.extend([b[3], b[2], b[1], b[0]])
-                np_arr = np.array(self.data, dtype=np.float32)
-                u8_data.extend(np_arr.view(np.uint8))
-            elif self.dtype == TosaDType.DType:
-                # Serialize DType enum data as uint8 bytes
-                for val in self.data:
-                    np_arr = np.array(self.data, dtype=np.uint32)
-                    u8_data.extend(np_arr.view(np.uint8))
-            else:
-                raise Exception(
-                    "unsupported data type {}".format(DTypeNames[self.dtype])
-                )
+            u8_data = TosaSerializer.convertDataToUint8Vec(self.dtype, self.data)
             fb_data = TosaSerializer.serializeUint8Vec(builder, u8_data)
 
         TosaTensor.Start(builder)
@@ -958,3 +883,171 @@
             return val
         else:
             return [val]
+
+    @staticmethod
+    def convertDataToUint8Vec(dtype, data):
+        """Serialize a flat sequence of values of type dtype into bytes.
+
+        Multi-byte values are emitted in little-endian byte order.
+
+        Args:
+            dtype: a DType value, or the TosaDType.DType class itself when
+                data holds DType enum values.
+            data: flat sequence of values to serialize.
+
+        Returns:
+            List with one entry per serialized byte, each in 0..255. Entries
+            are numpy values or Python ints depending on dtype.
+
+        Raises:
+            Exception: if dtype is not a supported data type.
+        """
+        u8_data = list()
+        if dtype == DType.BOOL:
+            for val in data:
+                val_u8 = np.uint8(val)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT4:
+            # Pack two 4-bit values per byte, first value in the low nibble
+            in_size = len(data)
+            out_size = (in_size + 1) // 2
+            for i in range(out_size):
+                val_0 = data[2 * i]
+                if (2 * i + 1) < in_size:
+                    val_1 = data[2 * i + 1]
+                else:
+                    val_1 = 0
+                val_i8 = (val_0 & 0xF) | ((val_1 & 0xF) << 4)
+                val_u8 = np.uint8(val_i8)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT8:
+            for val in data:
+                val_u8 = np.array(val).astype(dtype=np.uint8)
+                u8_data.append(val_u8)
+        elif dtype == DType.INT16:
+            for val in data:
+                val_u16 = np.array(val).astype(dtype=np.uint16)
+                b0 = val_u16 & ByteMask
+                b1 = (val_u16 >> np.uint16(8)) & ByteMask
+                u8_data.extend([b0, b1])
+        elif dtype == DType.INT32:
+            for val in data:
+                val_u32 = np.array(val).astype(dtype=np.uint32)
+                b0 = val_u32 & ByteMask
+                b1 = (val_u32 >> np.uint32(8)) & ByteMask
+                b2 = (val_u32 >> np.uint32(16)) & ByteMask
+                b3 = (val_u32 >> np.uint32(24)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3])
+        elif dtype == DType.INT48:
+            for val in data:
+                val_u64 = np.uint64(val)
+                b0 = val_u64 & ByteMask
+                b1 = (val_u64 >> np.uint64(8)) & ByteMask
+                b2 = (val_u64 >> np.uint64(16)) & ByteMask
+                b3 = (val_u64 >> np.uint64(24)) & ByteMask
+                b4 = (val_u64 >> np.uint64(32)) & ByteMask
+                b5 = (val_u64 >> np.uint64(40)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3, b4, b5])
+        elif dtype == DType.SHAPE:
+            for val in data:
+                val_u64 = np.uint64(val)
+                b0 = val_u64 & ByteMask
+                b1 = (val_u64 >> np.uint64(8)) & ByteMask
+                b2 = (val_u64 >> np.uint64(16)) & ByteMask
+                b3 = (val_u64 >> np.uint64(24)) & ByteMask
+                b4 = (val_u64 >> np.uint64(32)) & ByteMask
+                b5 = (val_u64 >> np.uint64(40)) & ByteMask
+                b6 = (val_u64 >> np.uint64(48)) & ByteMask
+                b7 = (val_u64 >> np.uint64(56)) & ByteMask
+                u8_data.extend([b0, b1, b2, b3, b4, b5, b6, b7])
+        elif dtype == DType.FP16:
+            # "<f2" / "<f4": little-endian regardless of host byte order
+            np_arr = np.array(data, dtype="<f2")
+            u8_data.extend(np_arr.view(np.uint8))
+        elif dtype == DType.FP32:
+            np_arr = np.array(data, dtype="<f4")
+            u8_data.extend(np_arr.view(np.uint8))
+        elif dtype == DType.BF16:
+            for val in data:
+                # BF16 is the upper half of an FP32 value. In the little-endian
+                # packing b[0] is the least significant byte, so keep b[2] and
+                # b[3]; the low 16 mantissa bits are truncated, not rounded.
+                b = struct.pack("<f", val)
+                u8_data.extend([b[2], b[3]])
+        elif dtype == DType.FP8E4M3:
+            for val in data:
+                u8_data.append(
+                    TosaSerializer._convertFloatToFP8Bits(
+                        val, exp_bits=4, man_bits=3, has_inf=False
+                    )
+                )
+        elif dtype == DType.FP8E5M2:
+            for val in data:
+                u8_data.append(
+                    TosaSerializer._convertFloatToFP8Bits(
+                        val, exp_bits=5, man_bits=2, has_inf=True
+                    )
+                )
+        elif dtype == TosaDType.DType:
+            # Serialize DType enum values as little-endian uint32, 4 bytes each
+            np_arr = np.array(data, dtype="<u4")
+            u8_data.extend(np_arr.view(np.uint8))
+        else:
+            raise Exception("unsupported data type {}".format(DTypeNames[dtype]))
+        return u8_data
+
+    @staticmethod
+    def _convertFloatToFP8Bits(val, exp_bits, man_bits, has_inf):
+        """Return the FP8 bit pattern (0..255) nearest to val.
+
+        Rounds to nearest, ties to even. Values below the normal range are
+        encoded as subnormals and may round to (signed) zero.
+
+        Args:
+            val: value to convert (anything float() accepts).
+            exp_bits: exponent bits (4 for FP8E4M3, 5 for FP8E5M2).
+            man_bits: mantissa bits (3 for FP8E4M3, 2 for FP8E5M2).
+            has_inf: True for IEEE-style formats with infinities (FP8E5M2);
+                False for formats without them, where only the all-ones
+                exponent and mantissa pattern is NaN (FP8E4M3, max 448).
+
+        Overflow and infinities map to infinity when has_inf is True and to
+        NaN otherwise. NaN maps to a NaN pattern. The sign bit is kept.
+        """
+        import math
+
+        bias = (1 << (exp_bits - 1)) - 1
+        if has_inf:
+            # All-ones exponent: zero mantissa is infinity, otherwise NaN
+            overflow_code = ((1 << exp_bits) - 1) << man_bits
+            nan_code = overflow_code | (1 << (man_bits - 1))
+            max_code = overflow_code - 1
+        else:
+            # No infinity: only the all-ones exponent and mantissa is NaN
+            nan_code = (1 << (exp_bits + man_bits)) - 1
+            overflow_code = nan_code
+            max_code = nan_code - 1
+
+        val = float(val)
+        sign = 0x80 if math.copysign(1.0, val) < 0 else 0x00
+        if math.isnan(val):
+            return sign | nan_code
+        mag = abs(val)
+        if math.isinf(mag):
+            return sign | overflow_code
+        if mag == 0.0:
+            return sign
+
+        # Unbiased exponent of mag, clamped to the smallest normal exponent
+        # so that values below the normal range become subnormals.
+        exp = max(math.frexp(mag)[1] - 1, 1 - bias)
+        # mag in units of the FP8 spacing at that exponent. Scaling by a power
+        # of two is exact, and round() on a float rounds half to even.
+        scaled = round(math.ldexp(mag, man_bits - exp))
+        # scaled is 1.m * 2**man_bits for normals and 0.m * 2**man_bits for
+        # subnormals, so adding the exponent field yields the bit pattern
+        # directly, including a rounding carry into the next exponent.
+        code = ((exp + bias - 1) << man_bits) + scaled
+        if code > max_code:
+            return sign | overflow_code
+        return sign | code