Use native size of Bfloat16 and Float8 for serialization/deserialization

Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I0d2075f90988d4fd1139a11b5c154bdd600bb2cd
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index e4171d7..7cf5f94 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -247,6 +247,14 @@
             while (isspace(*ptr))
                 ptr++;
 
+            // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but
+            // default NumPy does not understand this notation, which causes trouble
+            // when other code tries to open this file.
+            // To avoid this, '|u1' notation is used when the file is written, and the uint8
+            // data is viewed as float8_e5m2 later when the file is read.
+            if (!strcmp(dtype_str, "'<f1'"))
+                dtype_str = "'|u1'";
+
             if (strcmp(ptr, dtype_str))
             {
                 return FILE_TYPE_MISMATCH;
@@ -430,6 +438,13 @@
     memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
     headerPos += sizeof(NUMPY_HEADER_STR) - 1;
 
+    // Default NumPy does not understand the '<f1' (float8_e5m2) dtype string, so the
+    // header is written with '|u1' (uint8) instead, letting Python read the .npy file.
+    if (!strcmp(dtype_str, "'<f1'"))
+    {
+        dtype_str = "'|u1'";
+    }
+
     // Output the format dictionary
     // Hard-coded for I32 for now
     headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
@@ -438,7 +453,19 @@
     // Add shape contents (if any - as this will be empty for rank 0)
     for (i = 0; i < shape.size(); i++)
     {
-        headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
+        // Match the shape formatting of NumPy files produced by tosa_refmodel_sut_run:
+        // rank 1 is written as "N," and higher ranks as "N, M" with no trailing comma.
+        if (i == 0)
+        {
+            if (shape.size() == 1)
+                headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]);
+            else
+                headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]);
+        }
+        else
+        {
+            headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]);
+        }
     }
 
     // Close off the dictionary