Support reading any dtype into a 32-bit buffer

Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: Ic6b43539fcb2d75c5614d3addccd24a06e9f2a31
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index 29d7e11..e9c4bb4 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -37,6 +37,7 @@
         FILE_TYPE_MISMATCH,
         HEADER_PARSE_ERROR,
         BUFFER_SIZE_MISMATCH,
+        DATA_TYPE_NOT_SUPPORTED,
     };
 
     static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf);
@@ -45,14 +46,6 @@
 
     static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf);
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, uint8_t* databuf);
-
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, int8_t* databuf);
-
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, uint16_t* databuf);
-
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, int16_t* databuf);
-
     static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
 
     static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf);
@@ -112,6 +105,7 @@
                                          void* databuf,
                                          bool bool_translate);
     static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
+    static NPError getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order);    // reads the dtype descriptor from the .npy header
     static NPError writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str);
 };
 
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index 65d76e3..d31ec1c 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
@@ -15,12 +15,14 @@
 
 #include "numpy_utils.h"
 #include "half.hpp"
-
+#include <algorithm>
 // Magic NUMPY header
 static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
 static const int NUMPY_HEADER_SZ     = 128;
 // Maximum shape dimensions supported
 static const int NUMPY_MAX_DIMS_SUPPORTED = 10;
+// Offset from the 'd' of "descr" to the dtype string in the whitespace-stripped header dictionary
+static const int NUMPY_HEADER_DESC_OFFSET = 8;
 
 NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
 {
@@ -28,34 +30,100 @@
     return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
 }
 
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, uint8_t* databuf)
-{
-    const char dtype_str[] = "'|u1'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(uint8_t), elems, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int8_t* databuf)
-{
-    const char dtype_str[] = "'|i1'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(int8_t), elems, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, uint16_t* databuf)
-{
-    const char dtype_str[] = "'<u2'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(uint16_t), elems, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int16_t* databuf)
-{
-    const char dtype_str[] = "'<i2'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(int16_t), elems, databuf, false);
-}
-
+// Reads an integer .npy file into a caller-provided 32-bit buffer.
+//
+// The element type stored in the file is probed with getHeader(). int8, uint8, int16
+// and uint16 payloads are read at their native width into a staging buffer and then
+// widened element by element into databuf; int32 payloads are read straight into
+// databuf. uint32 and every other item size yield DATA_TYPE_NOT_SUPPORTED.
+//
+// filename: path of the .npy file; must not be null.
+// elems:    number of elements to read into databuf.
+// databuf:  destination for the 32-bit values; must not be null.
+//
+// Returns FILE_NOT_FOUND if the file cannot be opened, the getHeader() status if the
+// header probe fails, otherwise the status reported by readFromNpyFileCommon().
 NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
 {
-    const char dtype_str[] = "'<i4'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false);
+    const char dtype_str_uint8[]  = "'|u1'";
+    const char dtype_str_int8[]   = "'|i1'";
+    const char dtype_str_uint16[] = "'<u2'";
+    const char dtype_str_int16[]  = "'<i2'";
+    const char dtype_str_int32[]  = "'<i4'";
+
+    assert(filename);
+    assert(databuf);
+
+    FILE* infile = fopen(filename, "rb");
+    if (!infile)
+    {
+        return FILE_NOT_FOUND;
+    }
+
+    bool is_signed  = false;
+    int bit_length  = 0;
+    char byte_order = '\0';
+    const NPError header_rc = getHeader(infile, is_signed, bit_length, byte_order);
+    // The handle is only needed for the probe: readFromNpyFileCommon() is handed the
+    // filename again below, so close it on every path instead of leaking it.
+    fclose(infile);
+    if (header_rc != NO_ERROR)
+    {
+        return header_rc;
+    }
+
+    // Staging buffers hold at least one element so data() is never null when elems == 0.
+    const size_t staging_elems = std::max<size_t>(elems, 1);
+
+    NPError rc = DATA_TYPE_NOT_SUPPORTED;
+    switch (bit_length)
+    {
+        case 1:    // 8-bit
+            if (is_signed)
+            {
+                // int8
+                std::vector<int8_t> i8databuf(staging_elems);
+                rc = readFromNpyFileCommon(filename, dtype_str_int8, sizeof(int8_t), elems, i8databuf.data(), false);
+                std::copy(i8databuf.begin(), i8databuf.begin() + elems, databuf);
+            }
+            else
+            {
+                // uint8
+                std::vector<uint8_t> ui8databuf(staging_elems);
+                rc = readFromNpyFileCommon(filename, dtype_str_uint8, sizeof(uint8_t), elems, ui8databuf.data(), false);
+                std::copy(ui8databuf.begin(), ui8databuf.begin() + elems, databuf);
+            }
+            break;
+        case 2:    // 16-bit
+            if (is_signed)
+            {
+                // int16
+                std::vector<int16_t> i16databuf(staging_elems);
+                rc = readFromNpyFileCommon(filename, dtype_str_int16, sizeof(int16_t), elems, i16databuf.data(), false);
+                std::copy(i16databuf.begin(), i16databuf.begin() + elems, databuf);
+            }
+            else
+            {
+                // uint16
+                std::vector<uint16_t> ui16databuf(staging_elems);
+                rc = readFromNpyFileCommon(filename, dtype_str_uint16, sizeof(uint16_t), elems, ui16databuf.data(), false);
+                std::copy(ui16databuf.begin(), ui16databuf.begin() + elems, databuf);
+            }
+            break;
+        case 4:    // 32-bit
+            if (is_signed)
+            {
+                // int32 already has the destination width, so read in place.
+                rc = readFromNpyFileCommon(filename, dtype_str_int32, sizeof(int32_t), elems, databuf, false);
+            }
+            // uint32 is not supported: rc keeps DATA_TYPE_NOT_SUPPORTED.
+            break;
+        default:
+            // Unsupported item size: rc keeps DATA_TYPE_NOT_SUPPORTED.
+            break;
+    }
+
+    return rc;
 }
 
 NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
@@ -139,6 +225,50 @@
     return rc;
 }
 
+// Extracts the dtype descriptor from the header of an already opened .npy file.
+// On success byte_order, is_signed and bit_length receive the three characters of the
+// 'descr' value (for '<i4': '<', signed, 4; bit_length is the digit after the type
+// character) and infile is rewound. Returns HEADER_PARSE_ERROR if the header cannot be
+// read or holds no complete 'descr' entry.
+NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order)
+{
+    char buf[NUMPY_HEADER_SZ + 1];
+    assert(infile);
+
+    if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
+    {
+        return HEADER_PARSE_ERROR;
+    }
+    // fread() does not terminate the buffer; do so before treating it as a C string.
+    buf[NUMPY_HEADER_SZ] = '\0';
+
+    // Skip the magic prefix; the rest is the header dictionary text.
+    std::string dic_string(buf + sizeof(NUMPY_HEADER_STR) - 1);
+
+    // Reference: https://en.cppreference.com/w/cpp/algorithm/remove
+    // remove all the white spaces for the following offset NUMPY_HEADER_DESC_OFFSET to work
+    dic_string.erase(
+        std::remove_if(dic_string.begin(), dic_string.end(), [](unsigned char x) { return std::isspace(x); }),
+        dic_string.end());
+
+    // Search only after stripping so the index refers to the string that is sliced below.
+    const auto descr_loc = dic_string.find("descr");
+    if (descr_loc == std::string::npos || dic_string.size() < descr_loc + NUMPY_HEADER_DESC_OFFSET + 3)
+    {
+        return HEADER_PARSE_ERROR;
+    }
+
+    // Stripped, the entry reads descr':'<i4' so the dtype starts NUMPY_HEADER_DESC_OFFSET
+    // characters after the 'd'. Plain indexing avoids the old strcpy into 1-byte arrays.
+    const std::string dtype = dic_string.substr(descr_loc + NUMPY_HEADER_DESC_OFFSET, 3);
+    byte_order = dtype[0];
+    is_signed  = dtype[1] != 'u';
+    bit_length = dtype[2] - '0';
+
+    rewind(infile);
+    return NO_ERROR;
+}
+
 NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
 {
     char buf[NUMPY_HEADER_SZ + 1];