blob: 60cf77e3e4ea4b9d09bf19d07d81d03bfb093313 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef _TOSA_NUMPY_UTILS_H
17#define _TOSA_NUMPY_UTILS_H
18
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <type_traits>
#include <vector>

#include "half.hpp"
28
Eric Kunze2364dcd2021-04-26 11:06:57 -070029class NumpyUtilities
30{
31public:
32 enum NPError
33 {
34 NO_ERROR = 0,
35 FILE_NOT_FOUND,
36 FILE_IO_ERROR,
37 FILE_TYPE_MISMATCH,
38 HEADER_PARSE_ERROR,
39 BUFFER_SIZE_MISMATCH,
Jerry Ge13a32912023-07-03 16:36:41 +000040 DATA_TYPE_NOT_SUPPORTED,
Eric Kunze2364dcd2021-04-26 11:06:57 -070041 };
42
TatWai Chong679bdad2023-07-31 15:15:12 -070043 template <typename T>
44 static const char* getDTypeString(bool& is_bool)
45 {
46 is_bool = false;
47 if (std::is_same<T, bool>::value)
48 {
49 is_bool = true;
50 return "'|b1'";
51 }
52 if (std::is_same<T, uint8_t>::value)
53 {
54 return "'|u1'";
55 }
56 if (std::is_same<T, int8_t>::value)
57 {
58 return "'|i1'";
59 }
60 if (std::is_same<T, uint16_t>::value)
61 {
62 return "'<u2'";
63 }
64 if (std::is_same<T, int16_t>::value)
65 {
66 return "'<i2'";
67 }
68 if (std::is_same<T, int32_t>::value)
69 {
70 return "'<i4'";
71 }
72 if (std::is_same<T, int64_t>::value)
73 {
74 return "'<i8'";
75 }
76 if (std::is_same<T, float>::value)
77 {
78 return "'<f4'";
79 }
80 if (std::is_same<T, double>::value)
81 {
82 return "'<f8'";
83 }
84 if (std::is_same<T, half_float::half>::value)
85 {
86 return "'<f2'";
87 }
88 assert(false && "unsupported Dtype");
89 };
Eric Kunze2364dcd2021-04-26 11:06:57 -070090
TatWai Chong679bdad2023-07-31 15:15:12 -070091 template <typename T>
92 static NPError writeToNpyFile(const char* filename, const uint32_t elems, const T* databuf)
93 {
94 std::vector<int32_t> shape = { static_cast<int32_t>(elems) };
95 return writeToNpyFile(filename, shape, databuf);
96 }
Tai Ly3ef34fb2023-04-04 20:34:05 +000097
TatWai Chong679bdad2023-07-31 15:15:12 -070098 template <typename T>
99 static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const T* databuf)
100 {
101 bool is_bool;
102 const char* dtype_str = getDTypeString<T>(is_bool);
103 return writeToNpyFileCommon(filename, dtype_str, sizeof(T), shape, databuf, is_bool);
104 }
James Ward485a11d2022-08-05 13:48:37 +0100105
TatWai Chong679bdad2023-07-31 15:15:12 -0700106 template <typename T>
107 static NPError readFromNpyFile(const char* filename, const uint32_t elems, T* databuf)
108 {
109 bool is_bool;
110 const char* dtype_str = getDTypeString<T>(is_bool);
111 return readFromNpyFileCommon(filename, dtype_str, sizeof(T), elems, databuf, is_bool);
112 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700113
TatWai Chong679bdad2023-07-31 15:15:12 -0700114 template <typename D, typename S>
115 static void copyBufferByElement(D* dest_buf, S* src_buf, int num)
116 {
Won Jeon780ffb52023-08-21 13:32:36 -0700117 static_assert(sizeof(D) >= sizeof(S), "The size of dest_buf must be equal to or larger than that of src_buf");
TatWai Chong679bdad2023-07-31 15:15:12 -0700118 for (int i = 0; i < num; ++i)
119 {
120 dest_buf[i] = src_buf[i];
121 }
122 }
Tai Ly3ef34fb2023-04-04 20:34:05 +0000123
Eric Kunze2364dcd2021-04-26 11:06:57 -0700124private:
125 static NPError writeToNpyFileCommon(const char* filename,
126 const char* dtype_str,
127 const size_t elementsize,
128 const std::vector<int32_t>& shape,
129 const void* databuf,
130 bool bool_translate);
131 static NPError readFromNpyFileCommon(const char* filename,
132 const char* dtype_str,
133 const size_t elementsize,
134 const uint32_t elems,
135 void* databuf,
136 bool bool_translate);
137 static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
Jerry Ge13a32912023-07-03 16:36:41 +0000138 static NPError getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700139 static NPError writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str);
140};
141
TatWai Chong39b5edc2023-08-18 17:58:17 +0000142template <>
143NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf);
144
Eric Kunze2364dcd2021-04-26 11:06:57 -0700145#endif // _TOSA_NUMPY_UTILS_H