blob: 4860df2275012a2e53b1db9af82e10b2b6a6f873 [file] [log] [blame]
Anthony Barbiera3adb3a2017-09-13 16:03:39 +01001/*
2 Copyright 2017 Leon Merten Lohse
3
4 Permission is hereby granted, free of charge, to any person obtaining a copy
5 of this software and associated documentation files (the "Software"), to deal
6 in the Software without restriction, including without limitation the rights
7 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 copies of the Software, and to permit persons to whom the Software is
9 furnished to do so, subject to the following conditions:
10
11 The above copyright notice and this permission notice shall be included in
12 all copies or substantial portions of the Software.
13
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 SOFTWARE.
21*/
22
23#include <complex>
24#include <fstream>
25#include <string>
26#include <sstream>
27#include <cstdint>
28#include <vector>
29#include <endian.h>
30#include <typeinfo>
31#include <typeindex>
32#include <stdexcept>
33#include <algorithm>
34#include <map>
35#include <regex>
36
37namespace npy {
namespace {
/** Render an arithmetic value as text. Reference: support/ToolchainSupport.h
 *
 * @note Used in place of std::to_string, which is missing in some Android
 *       toolchains.
 *
 * @param[in] value Integer or floating-point value to render.
 *
 * @return The text obtained by streaming @p value with operator<<.
 */
template <typename T, typename std::enable_if<std::is_arithmetic<typename std::decay<T>::type>::value, int>::type = 0>
inline std::string to_string(T&& value)
{
    std::ostringstream formatter;
    formatter << std::forward<T>(value);
    return formatter.str();
}
} // namespace
56
// Signature bytes checked by read_magic() and emitted by write_magic().
constexpr char magic_string[] = "\x93NUMPY";
// Length of the signature without the terminating NUL of the literal.
constexpr size_t magic_string_length = sizeof(magic_string) - 1;

// Byte-order characters used as the first character of a dtype description.
constexpr unsigned char little_endian_char = '<';
constexpr unsigned char big_endian_char = '>';
constexpr unsigned char no_endian_char = '|';
63
// Report whether the host is little endian.
inline bool isle(void) {
    const unsigned int probe = 1;
    // On a little-endian host the lowest-addressed byte of `probe` holds the 1.
    return *reinterpret_cast<const char*>(&probe) != 0;
}
73
74inline void write_magic(std::ostream& ostream, unsigned char v_major=1, unsigned char v_minor=0) {
75 ostream.write(magic_string, magic_string_length);
76 ostream.put(v_major);
77 ostream.put(v_minor);
78}
79
80inline void read_magic(std::istream& istream, unsigned char *v_major, unsigned char *v_minor) {
81 char *buf = new char[magic_string_length+2];
82 istream.read(buf, magic_string_length+2);
83
84 if(!istream) {
85 throw std::runtime_error("io error: failed reading file");
86 }
87
88 for (size_t i=0; i < magic_string_length; i++) {
89 if(buf[i] != magic_string[i]) {
90 throw std::runtime_error("this file do not have a valid npy format.");
91 }
92 }
93
94 *v_major = buf[magic_string_length];
95 *v_minor = buf[magic_string_length+1];
96 delete[] buf;
97}
98
99
100
101inline std::string get_typestring(const std::type_info& t) {
102 std::string endianness;
103 std::string no_endianness(no_endian_char, 1);
104 // little endian or big endian?
105 if (isle())
106 endianness = little_endian_char;
107 else
108 endianness = big_endian_char;
109
110 std::map<std::type_index, std::string> map;
111
112 map[std::type_index(typeid(float))] = endianness + "f" + to_string(sizeof(float));
113 map[std::type_index(typeid(double))] = endianness + "f" + to_string(sizeof(double));
114 map[std::type_index(typeid(long double))] = endianness + "f" + to_string(sizeof(long double));
115
116 map[std::type_index(typeid(char))] = no_endianness + "i" + to_string(sizeof(char));
117 map[std::type_index(typeid(short))] = endianness + "i" + to_string(sizeof(short));
118 map[std::type_index(typeid(int))] = endianness + "i" + to_string(sizeof(int));
119 map[std::type_index(typeid(long))] = endianness + "i" + to_string(sizeof(long));
120 map[std::type_index(typeid(long long))] = endianness + "i" + to_string(sizeof(long long));
121
122 map[std::type_index(typeid(unsigned char))] = no_endianness + "u" + to_string(sizeof(unsigned char));
123 map[std::type_index(typeid(unsigned short))] = endianness + "u" + to_string(sizeof(unsigned short));
124 map[std::type_index(typeid(unsigned int))] = endianness + "u" + to_string(sizeof(unsigned int));
125 map[std::type_index(typeid(unsigned long))] = endianness + "u" + to_string(sizeof(unsigned long));
126 map[std::type_index(typeid(unsigned long long))] = endianness + "u" + to_string(sizeof(unsigned long long));
127
128 map[std::type_index(typeid(std::complex<float>))] = endianness + "c" + to_string(sizeof(std::complex<float>));
129 map[std::type_index(typeid(std::complex<double>))] = endianness + "c" + to_string(sizeof(std::complex<double>));
130 map[std::type_index(typeid(std::complex<long double>))] = endianness + "c" + to_string(sizeof(std::complex<long double>));
131
132 if (map.count(std::type_index(t)) > 0)
133 return map[std::type_index(t)];
134 else
135 throw std::runtime_error("unsupported data type");
136}
137
// Validate a quoted dtype description such as '<f4' (quotes included).
// The text must be: quote, one of <>|, one of ifuc, one or more digits, quote.
// Throws std::runtime_error("invalid typestring") otherwise.
inline void parse_typestring( std::string typestring){
    const std::regex pattern("'([<>|])([ifuc])(\\d+)'");
    std::smatch groups;

    // A successful full match yields the whole match plus three capture groups.
    const bool matched = std::regex_match(typestring, groups, pattern);
    if (!matched || groups.size() != 4) {
        throw std::runtime_error("invalid typestring");
    }
}
148
/** Strip one leading and one trailing delimiter from a string.
 *
 * @param[in] s           Text expected to start with @p delim_front and end
 *                        with @p delim_back.
 * @param[in] delim_front Required first character.
 * @param[in] delim_back  Required last character.
 *
 * @return @p s without its first and last character.
 *
 * @throws std::runtime_error if @p s is shorter than two characters or is not
 *         enclosed by the given delimiters.
 */
inline std::string unwrap_s(std::string s, char delim_front, char delim_back) {
    // The length check comes first: front()/back() must not be called on an
    // empty string, and a single character cannot be both delimiters.
    if (s.length() < 2 || s.front() != delim_front || s.back() != delim_back)
        throw std::runtime_error("unable to unwrap");

    return s.substr(1, s.length() - 2);
}
155
// Return everything after the first ':' of a "key:value" string, or an empty
// string when there is no ':'.
inline std::string get_value_from_map(std::string mapstr) {
    const std::string::size_type colon = mapstr.find(':');
    return (colon == std::string::npos) ? std::string() : mapstr.substr(colon + 1);
}
163
/** Remove the last character of a string if it equals a given character.
 *
 * @param[in,out] s String to trim; left untouched when it is empty or does
 *                  not end with @p c.
 * @param[in]     c Character to remove.
 */
inline void pop_char(std::string& s, char c) {
    // back() must not be called on an empty string.
    if (!s.empty() && s.back() == c)
        s.pop_back();
}
168
169inline void ParseHeader(std::string header, std::string& descr, bool *fortran_order, std::vector<unsigned long>& shape) {
170 /*
171 The first 6 bytes are a magic string: exactly "x93NUMPY".
172
173 The next 1 byte is an unsigned byte: the major version number of the file format, e.g. x01.
174
175 The next 1 byte is an unsigned byte: the minor version number of the file format, e.g. x00. Note: the version of the file format is not tied to the version of the numpy package.
176
177 The next 2 bytes form a little-endian unsigned short int: the length of the header data HEADER_LEN.
178
179 The next HEADER_LEN bytes form the header data describing the array's format. It is an ASCII string which contains a Python literal expression of a dictionary. It is terminated by a newline ('n') and padded with spaces ('x20') to make the total length of the magic string + 4 + HEADER_LEN be evenly divisible by 16 for alignment purposes.
180
181 The dictionary contains three keys:
182
183 "descr" : dtype.descr
184 An object that can be passed as an argument to the numpy.dtype() constructor to create the array's dtype.
185 "fortran_order" : bool
186 Whether the array data is Fortran-contiguous or not. Since Fortran-contiguous arrays are a common form of non-C-contiguity, we allow them to be written directly to disk for efficiency.
187 "shape" : tuple of int
188 The shape of the array.
189 For repeatability and readability, this dictionary is formatted using pprint.pformat() so the keys are in alphabetic order.
190 */
191
192 // remove trailing newline
193 if (header.back() != '\n')
194 throw std::runtime_error("invalid header");
195 header.pop_back();
196
197 // remove all whitespaces
198 header.erase(std::remove(header.begin(), header.end(), ' '), header.end());
199
200 // unwrap dictionary
201 header = unwrap_s(header, '{', '}');
202
203 // find the positions of the 3 dictionary keys
204 size_t keypos_descr = header.find("'descr'");
205 size_t keypos_fortran = header.find("'fortran_order'");
206 size_t keypos_shape = header.find("'shape'");
207
208 // make sure all the keys are present
209 if (keypos_descr == std::string::npos)
210 throw std::runtime_error("missing 'descr' key");
211 if (keypos_fortran == std::string::npos)
212 throw std::runtime_error("missing 'fortran_order' key");
213 if (keypos_shape == std::string::npos)
214 throw std::runtime_error("missing 'shape' key");
215
216 // make sure the keys are in order
217 if (keypos_descr >= keypos_fortran || keypos_fortran >= keypos_shape)
218 throw std::runtime_error("header keys in wrong order");
219
220 // get the 3 key-value pairs
221 std::string keyvalue_descr;
222 keyvalue_descr = header.substr(keypos_descr, keypos_fortran - keypos_descr);
223 pop_char(keyvalue_descr, ',');
224
225 std::string keyvalue_fortran;
226 keyvalue_fortran = header.substr(keypos_fortran, keypos_shape - keypos_fortran);
227 pop_char(keyvalue_fortran, ',');
228
229 std::string keyvalue_shape;
230 keyvalue_shape = header.substr(keypos_shape, std::string::npos);
231 pop_char(keyvalue_shape, ',');
232
233 // get the values (right side of `:')
234 std::string descr_s = get_value_from_map(keyvalue_descr);
235 std::string fortran_s = get_value_from_map(keyvalue_fortran);
236 std::string shape_s = get_value_from_map(keyvalue_shape);
237
238 parse_typestring(descr_s);
239 descr = unwrap_s(descr_s, '\'', '\'');
240
241 // convert literal Python bool to C++ bool
242 if (fortran_s == "True")
243 *fortran_order = true;
244 else if (fortran_s == "False")
245 *fortran_order = false;
246 else
247 throw std::runtime_error("invalid fortran_order value");
248
249 // parse the shape Python tuple ( x, y, z,)
250
251 // first clear the vector
252 shape.clear();
253 shape_s = unwrap_s(shape_s, '(', ')');
254
255 // a tokenizer would be nice...
256 size_t pos = 0;
257 size_t pos_next;
258 for(;;) {
259 pos_next = shape_s.find_first_of(',', pos);
260 std::string dim_s;
261 if (pos_next != std::string::npos)
262 dim_s = shape_s.substr(pos, pos_next - pos);
263 else
264 dim_s = shape_s.substr(pos);
265 pop_char(dim_s, ',');
266 if (dim_s.length() == 0) {
267 if (pos_next != std::string::npos)
268 throw std::runtime_error("invalid shape");
269 }else{
270 std::stringstream ss;
271 ss << dim_s;
272 unsigned long tmp;
273 ss >> tmp;
274 shape.push_back(tmp);
275 }
276 if (pos_next != std::string::npos)
277 pos = ++pos_next;
278 else
279 break;
280 }
281}
282
283inline void WriteHeader(std::ostream& out, const std::string& descr, bool fortran_order, unsigned int n_dims, const unsigned long shape[])
284{
285 std::ostringstream ss_header;
286 std::string s_fortran_order;
287 if (fortran_order)
288 s_fortran_order = "True";
289 else
290 s_fortran_order = "False";
291
292 std::ostringstream ss_shape;
293 ss_shape << "(";
294 for (unsigned int n=0; n < n_dims; n++){
295 ss_shape << shape[n] << ", ";
296 }
297 ss_shape << ")";
298
299 ss_header << "{'descr': '" << descr << "', 'fortran_order': " << s_fortran_order << ", 'shape': " << ss_shape.str() << " }";
300
301 size_t header_len_pre = ss_header.str().length() + 1;
302 size_t metadata_len = magic_string_length + 2 + 2 + header_len_pre;
303
304 unsigned char version[2] = {1, 0};
305 if (metadata_len >= 255*255) {
306 metadata_len = magic_string_length + 2 + 4 + header_len_pre;
307 version[0] = 2;
308 version[1] = 0;
309 }
310 size_t padding_len = 16 - metadata_len % 16;
311 std::string padding (padding_len, ' ');
312 ss_header << padding;
313 ss_header << '\n';
314
315 std::string header = ss_header.str();
316
317 // write magic
318 write_magic(out, version[0], version[1]);
319
320 // write header length
321 if (version[0] == 1 && version[1] == 0) {
322 uint16_t header_len_le16 = htole16(header.length());
323 out.write(reinterpret_cast<char *>(&header_len_le16), 2);
324 }else{
325 uint32_t header_len_le32 = htole32(header.length());
326 out.write(reinterpret_cast<char *>(&header_len_le32), 4);
327 }
328
329 out << header;
330}
331
332template<typename Scalar>
333void SaveArrayAsNumpy( const std::string& filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[], const std::vector<Scalar>& data)
334{
335 std::string typestring = get_typestring(typeid(Scalar));
336
337 std::ofstream stream( filename, std::ofstream::binary);
338 if(!stream) {
339 throw std::runtime_error("io error: failed to open a file.");
340 }
341 WriteHeader(stream, typestring, fortran_order, n_dims, shape);
342
343 size_t size = 1;
344 for (unsigned int i=0; i<n_dims; ++i)
345 size *= shape[i];
346 stream.write(reinterpret_cast<const char*>(&data[0]), sizeof(Scalar) * size);
347}
348
/** Read the header of a version 1.0 file.
 *
 * Expects the stream to be positioned at the 2-byte little-endian header
 * length; reads that field and then the header text it announces.
 *
 * @param[in] istream Stream to read from.
 *
 * @return The header text, exactly as many bytes as the length field states.
 *
 * @throws std::runtime_error if the length field or the header text cannot be
 *         read completely.
 */
inline std::string read_header_1_0(std::istream& istream) {
    // The length is stored least-significant byte first; assembling it by
    // hand decodes it the same way on every host.
    unsigned char length_bytes[2];
    istream.read(reinterpret_cast<char *>(length_bytes), 2);
    if (!istream) {
        throw std::runtime_error("io error: failed reading file");
    }
    const uint16_t header_length =
        static_cast<uint16_t>(length_bytes[0] | (length_bytes[1] << 8));

    // Files whose preamble (magic + version + length field + header) is not a
    // multiple of 16 bytes long are accepted without complaint.

    // The string owns the buffer, so nothing leaks if reading throws.
    std::string header(static_cast<size_t>(header_length), '\0');
    istream.read(&header[0], header_length);
    if (!istream) {
        throw std::runtime_error("io error: failed reading file");
    }

    return header;
}
367
/** Read the header of a version 2.0 file.
 *
 * Expects the stream to be positioned at the 4-byte little-endian header
 * length; reads that field and then the header text it announces.
 *
 * @param[in] istream Stream to read from.
 *
 * @return The header text, exactly as many bytes as the length field states.
 *
 * @throws std::runtime_error if the length field or the header text cannot be
 *         read completely.
 */
inline std::string read_header_2_0(std::istream& istream) {
    // The length is stored least-significant byte first; assembling it by
    // hand decodes it the same way on every host.
    unsigned char length_bytes[4];
    istream.read(reinterpret_cast<char *>(length_bytes), 4);
    if (!istream) {
        throw std::runtime_error("io error: failed reading file");
    }
    const uint32_t header_length =
        static_cast<uint32_t>(length_bytes[0])
        | (static_cast<uint32_t>(length_bytes[1]) << 8)
        | (static_cast<uint32_t>(length_bytes[2]) << 16)
        | (static_cast<uint32_t>(length_bytes[3]) << 24);

    // Files whose preamble (magic + version + length field + header) is not a
    // multiple of 16 bytes long are accepted without complaint.

    // The string owns the buffer, so nothing leaks if reading throws.
    std::string header(static_cast<size_t>(header_length), '\0');
    istream.read(&header[0], header_length);
    if (!istream) {
        throw std::runtime_error("io error: failed reading file");
    }

    return header;
}
386
387template<typename Scalar>
388void LoadArrayFromNumpy(const std::string& filename, std::vector<unsigned long>& shape, std::vector<Scalar>& data)
389{
390 std::ifstream stream(filename, std::ifstream::binary);
391 if(!stream) {
392 throw std::runtime_error("io error: failed to open a file.");
393 }
394 // check magic bytes an version number
395 unsigned char v_major, v_minor;
396 read_magic(stream, &v_major, &v_minor);
397
398 std::string header;
399
400 if(v_major == 1 && v_minor == 0){
401 header = read_header_1_0(stream);
402 }else if(v_major == 2 && v_minor == 0) {
403 header = read_header_2_0(stream);
404 }else{
405 throw std::runtime_error("unsupported file format version");
406 }
407
408 // parse header
409 bool fortran_order;
410 std::string typestr;
411
412 ParseHeader(header, typestr, &fortran_order, shape);
413
414 // check if the typestring matches the given one
415 std::string expect_typestr = get_typestring(typeid(Scalar));
416 if (typestr != expect_typestr) {
417 throw std::runtime_error("formatting error: typestrings not matching");
418 }
419
420 // compute the data size based on the shape
421 size_t total_size = 1;
422 for(size_t i=0; i<shape.size(); ++i) {
423 total_size *= shape[i];
424 }
425 data.resize(total_size);
426
427 // read the data
428 stream.read(reinterpret_cast<char*>(&data[0]), sizeof(Scalar)*total_size);
429}
430
431} // namespace npy