Simplify overloaded writeToNpyFiles and readFromNpyFiles

templatize these functions instead to reduce redundant code.

Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Change-Id: Ie8b6f7d2b489c3508fea72481ce38f0db6d0c490
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index e9c4bb4..83fbd5c 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -40,56 +40,86 @@
         DATA_TYPE_NOT_SUPPORTED,
     };
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf);
+    template <typename T>
+    static const char* getDTypeString(bool& is_bool)
+    {
+        is_bool = false;
+        if (std::is_same<T, bool>::value)
+        {
+            is_bool = true;
+            return "'|b1'";
+        }
+        if (std::is_same<T, uint8_t>::value)
+        {
+            return "'|u1'";
+        }
+        if (std::is_same<T, int8_t>::value)
+        {
+            return "'|i1'";
+        }
+        if (std::is_same<T, uint16_t>::value)
+        {
+            return "'<u2'";
+        }
+        if (std::is_same<T, int16_t>::value)
+        {
+            return "'<i2'";
+        }
+        if (std::is_same<T, int32_t>::value)
+        {
+            return "'<i4'";
+        }
+        if (std::is_same<T, int64_t>::value)
+        {
+            return "'<i8'";
+        }
+        if (std::is_same<T, float>::value)
+        {
+            return "'<f4'";
+        }
+        if (std::is_same<T, double>::value)
+        {
+            return "'<f8'";
+        }
+        if (std::is_same<T, half_float::half>::value)
+        {
+            return "'<f2'";
+        }
+        assert(false && "unsupported Dtype");
+    };
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, double* databuf);
+    template <typename T>
+    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf)
+    {
+        std::vector<int32_t> shape = { static_cast<int32_t>(elems) };
+        return writeToNpyFile(filename, shape, databuf);
+    }
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf);
+    template <typename T>
+    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf)
+    {
+        bool is_bool;
+        const char* dtype_str = getDTypeString<T>(is_bool);
+        return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool);
+    }
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
+    template <typename T>
+    static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf)
+    {
+        bool is_bool;
+        const char* dtype_str = getDTypeString<T>(is_bool);
+        return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool);
+    }
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf);
-
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf);
-
-    static NPError
-        writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const half_float::half* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf);
+    template <typename D, typename S>
+    static void copyBufferByElement(D* dest_buf, S* src_buf, int num)
+    {
+        static_assert(sizeof(D) >= sizeof(S));
+        for (int i = 0; i < num; ++i)
+        {
+            dest_buf[i] = src_buf[i];
+        }
+    }
 
 private:
     static NPError writeToNpyFileCommon(const char* filename,
diff --git a/src/numpy_utils.cpp b/src/numpy_utils.cpp
index 0002fd9..64460bd 100644
--- a/src/numpy_utils.cpp
+++ b/src/numpy_utils.cpp
@@ -16,6 +16,7 @@
 #include "numpy_utils.h"
 #include "half.hpp"
 #include <algorithm>
+
 // Magic NUMPY header
 static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
 static const int NUMPY_HEADER_SZ     = 128;
@@ -24,20 +25,10 @@
 // Offset for NUMPY header desc dictionary string
 static const int NUMPY_HEADER_DESC_OFFSET = 8;
 
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
-{
-    const char dtype_str[] = "'|b1'";
-    return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
-}
-
+// This is an entry function for reading 8-/16-/32-bit npy file.
+template <>
 NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
 {
-    const char dtype_str_uint8[]  = "'|u1'";
-    const char dtype_str_int8[]   = "'|i1'";
-    const char dtype_str_uint16[] = "'<u2'";
-    const char dtype_str_int16[]  = "'<i2'";
-    const char dtype_str_int32[]  = "'<i4'";
-
     FILE* infile = nullptr;
     NPError rc   = HEADER_PARSE_ERROR;
     assert(filename);
@@ -49,91 +40,58 @@
         return FILE_NOT_FOUND;
     }
 
-    bool is_signed = false;
-    int bit_length;
+    bool is_signed      = false;
+    int length_per_byte = 0;
     char byte_order;
-    rc = getHeader(infile, is_signed, bit_length, byte_order);
+    rc = getHeader(infile, is_signed, length_per_byte, byte_order);
     if (rc != NO_ERROR)
         return rc;
 
-    switch (bit_length)
+    switch (length_per_byte)
     {
-        case 1:    // 8-bit
+        case 1:
             if (is_signed)
             {
-                // int8
-                int8_t* i8databuf = nullptr;
-                i8databuf         = (int8_t*)calloc(sizeof(i8databuf), elems);
-
-                rc = readFromNpyFileCommon(filename, dtype_str_int8, sizeof(int8_t), elems, i8databuf, false);
-
-                for (unsigned i = 0; i < elems; ++i)
-                {
-                    databuf[i] = (int32_t)i8databuf[i];
-                }
-                free(i8databuf);
-
-                return rc;
+                int8_t* tmp_buf = new int8_t[elems];
+                rc              = readFromNpyFile<int8_t>(filename, elems, tmp_buf);
+                copyBufferByElement(databuf, tmp_buf, elems);
+                free(tmp_buf);
             }
             else
             {
-                // uint8
-                uint8_t* ui8databuf = nullptr;
-                ui8databuf          = (uint8_t*)calloc(sizeof(ui8databuf), elems);
-
-                rc = readFromNpyFileCommon(filename, dtype_str_uint8, sizeof(uint8_t), elems, ui8databuf, false);
-
-                for (unsigned i = 0; i < elems; ++i)
-                {
-                    databuf[i] = (int32_t)ui8databuf[i];
-                }
-                free(ui8databuf);
+                uint8_t* tmp_buf = new uint8_t[elems];
+                rc               = readFromNpyFile<uint8_t>(filename, elems, tmp_buf);
+                copyBufferByElement(databuf, tmp_buf, elems);
+                free(tmp_buf);
             }
             break;
-        case 2:    // 16-bit
+        case 2:
             if (is_signed)
             {
-                // int16
-                int16_t* i16databuf = nullptr;
-                i16databuf          = (int16_t*)calloc(sizeof(i16databuf), elems);
-
-                rc = readFromNpyFileCommon(filename, dtype_str_int16, sizeof(int16_t), elems, i16databuf, false);
-
-                for (unsigned i = 0; i < elems; ++i)
-                {
-                    databuf[i] = (int32_t)i16databuf[i];
-                }
-                free(i16databuf);
-
-                return rc;
+                int16_t* tmp_buf = new int16_t[elems];
+                rc               = readFromNpyFile<int16_t>(filename, elems, tmp_buf);
+                copyBufferByElement(databuf, tmp_buf, elems);
+                free(tmp_buf);
             }
             else
             {
-                // uint16
-                uint16_t* ui16databuf = nullptr;
-                ui16databuf           = (uint16_t*)calloc(sizeof(ui16databuf), elems);
-
-                rc = readFromNpyFileCommon(filename, dtype_str_uint16, sizeof(uint16_t), elems, ui16databuf, false);
-
-                for (unsigned i = 0; i < elems; ++i)
-                {
-                    databuf[i] = (int32_t)ui16databuf[i];
-                }
-                free(ui16databuf);
-
-                return rc;
+                uint16_t* tmp_buf = new uint16_t[elems];
+                rc                = readFromNpyFile<uint16_t>(filename, elems, tmp_buf);
+                copyBufferByElement(databuf, tmp_buf, elems);
+                free(tmp_buf);
             }
             break;
-        case 4:    // 32-bit
+        case 4:
             if (is_signed)
             {
-                // int32
-                return readFromNpyFileCommon(filename, dtype_str_int32, sizeof(int32_t), elems, databuf, false);
+                bool is_bool;
+                const char* dtype_str = getDTypeString<int32_t>(is_bool);
+                rc = readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, is_bool);
             }
             else
             {
                 // uint32, not supported
-                return DATA_TYPE_NOT_SUPPORTED;
+                rc = DATA_TYPE_NOT_SUPPORTED;
             }
             break;
         default:
@@ -144,31 +102,6 @@
     return rc;
 }
 
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
-{
-    const char dtype_str[] = "'<i8'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
-{
-    const char dtype_str[] = "'<f4'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf)
-{
-    const char dtype_str[] = "'<f8'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
-{
-    const char dtype_str[] = "'<f2'";
-    return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false);
-}
-
 NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
                                                               const char* dtype_str,
                                                               const size_t elementsize,
@@ -418,138 +351,6 @@
     return rc;
 }
 
-NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
-{
-    const char dtype_str[] = "'|b1'";
-    return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true);    // bools written as size 1
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf)
-{
-    const char dtype_str[] = "'|u1'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(uint8_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf)
-{
-    const char dtype_str[] = "'|i1'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(int8_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf)
-{
-    const char dtype_str[] = "'<u2'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(uint16_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf)
-{
-    const char dtype_str[] = "'<i2'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(int16_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
-{
-    const char dtype_str[] = "'<i4'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
-{
-    const char dtype_str[] = "'<i8'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
-{
-    const char dtype_str[] = "'<f4'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf)
-{
-    std::vector<int32_t> shape = { (int32_t)elems };
-    return writeToNpyFile(filename, shape, databuf);
-}
-
-NumpyUtilities::NPError
-    NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf)
-{
-    const char dtype_str[] = "'<f8'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false);
-}
-
-NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
-                                                       const std::vector<int32_t>& shape,
-                                                       const half_float::half* databuf)
-{
-    const char dtype_str[] = "'<f2'";
-    return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false);
-}
-
 NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
                                                              const char* dtype_str,
                                                              const size_t elementsize,