Simplify overloaded writeToNpyFile and readFromNpyFile

Replace the per-type overloads with function templates to reduce redundant code.

Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Change-Id: Ie8b6f7d2b489c3508fea72481ce38f0db6d0c490
diff --git a/include/numpy_utils.h b/include/numpy_utils.h
index e9c4bb4..83fbd5c 100644
--- a/include/numpy_utils.h
+++ b/include/numpy_utils.h
@@ -40,56 +40,104 @@
         DATA_TYPE_NOT_SUPPORTED,
     };
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* databuf);
+    /// Returns the NumPy dtype descriptor string (e.g. "'<f4'") for element type T.
+    /// @param is_bool set to true when T is bool, false otherwise.
+    /// Element types with no mapping below fail to compile instead of asserting at
+    /// run time, so an unsupported type can never reach the *Common helpers.
+    template <typename T>
+    static const char* getDTypeString(bool& is_bool)
+    {
+        is_bool = std::is_same<T, bool>::value;
+        if constexpr (std::is_same<T, bool>::value)
+        {
+            return "'|b1'";
+        }
+        else if constexpr (std::is_same<T, uint8_t>::value)
+        {
+            return "'|u1'";
+        }
+        else if constexpr (std::is_same<T, int8_t>::value)
+        {
+            return "'|i1'";
+        }
+        else if constexpr (std::is_same<T, uint16_t>::value)
+        {
+            return "'<u2'";
+        }
+        else if constexpr (std::is_same<T, int16_t>::value)
+        {
+            return "'<i2'";
+        }
+        else if constexpr (std::is_same<T, int32_t>::value)
+        {
+            return "'<i4'";
+        }
+        else if constexpr (std::is_same<T, int64_t>::value)
+        {
+            return "'<i8'";
+        }
+        else if constexpr (std::is_same<T, float>::value)
+        {
+            return "'<f4'";
+        }
+        else if constexpr (std::is_same<T, double>::value)
+        {
+            return "'<f8'";
+        }
+        else if constexpr (std::is_same<T, half_float::half>::value)
+        {
+            return "'<f2'";
+        }
+        else
+        {
+            // sizeof(T) == 0 depends on T, so this only fires when the branch is
+            // instantiated for an unsupported type. (A runtime assert here would let
+            // NDEBUG builds fall off the end of a non-void function: undefined behaviour.)
+            static_assert(sizeof(T) == 0, "unsupported Dtype");
+            return nullptr;
+        }
+    }
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, double* databuf);
+    /// Writes `elems` values of type T as a one-dimensional array of shape {elems}.
+    template <typename T>
+    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf)
+    {
+        const std::vector<int32_t> shape = { static_cast<int32_t>(elems) };
+        return writeToNpyFile(filename, shape, databuf);
+    }
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf);
+    /// Writes databuf with the given shape. The dtype descriptor, element size and
+    /// bool flag are derived from T and forwarded to writeToNpyFileCommon.
+    template <typename T>
+    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf)
+    {
+        bool is_bool = false;
+        const char* dtype_str = getDTypeString<T>(is_bool);
+        return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool);
+    }
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
+    /// Reads `elems` values of type T into databuf. The dtype descriptor, element size
+    /// and bool flag are derived from T and forwarded to readFromNpyFileCommon.
+    template <typename T>
+    static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf)
+    {
+        bool is_bool = false;
+        const char* dtype_str = getDTypeString<T>(is_bool);
+        return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool);
+    }
 
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf);
-
-    static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf);
-
-    static NPError
-        writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const half_float::half* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf);
-
-    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf);
+    /// Copies `num` elements from src_buf to dest_buf, converting each S to D by assignment.
+    /// The static_assert only compares sizes; it does not rule out value-changing
+    /// conversions between types of equal width (e.g. int32_t -> float).
+    template <typename D, typename S>
+    static void copyBufferByElement(D* dest_buf, const S* src_buf, int num)
+    {
+        static_assert(sizeof(D) >= sizeof(S), "destination element must be at least as wide as source");
+        for (int i = 0; i < num; ++i)
+        {
+            dest_buf[i] = src_buf[i];
+        }
+    }
 
 private:
     static NPError writeToNpyFileCommon(const char* filename,