blob: e4f22159316c4805e8ede69a681ccb3438ec3eff [file] [log] [blame]
Anthony Barbiera3adb3a2017-09-13 16:03:39 +01001/*
2 Copyright 2017 Leon Merten Lohse
3
4 Permission is hereby granted, free of charge, to any person obtaining a copy
5 of this software and associated documentation files (the "Software"), to deal
6 in the Software without restriction, including without limitation the rights
7 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 copies of the Software, and to permit persons to whom the Software is
9 furnished to do so, subject to the following conditions:
10
11 The above copyright notice and this permission notice shall be included in
12 all copies or substantial portions of the Software.
13
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 SOFTWARE.
21*/
22
Anthony Barbier87f21cd2017-11-10 16:27:32 +000023#ifndef NPY_H
24#define NPY_H
25
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010026#include <complex>
27#include <fstream>
28#include <string>
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010029#include <iostream>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010030#include <sstream>
31#include <cstdint>
Anthony Barbier87f21cd2017-11-10 16:27:32 +000032#include <cstring>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010033#include <vector>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010034#include <stdexcept>
35#include <algorithm>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010036#include <regex>
Anthony Barbier87f21cd2017-11-10 16:27:32 +000037#include <unordered_map>
Pablo Marquez Telloa774d632022-05-31 19:39:18 +010038#include <type_traits>
39#include <iterator>
40#include <utility>
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010041
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010042namespace npy {
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010043
/* Compile-time test for byte order.
   Big-endian hosts are detected from:
     - the GCC/Clang predefined macros __BYTE_ORDER__ / __ORDER_BIG_ENDIAN__,
     - __BYTE_ORDER / __BIG_ENDIAN (e.g. from glibc's <endian.h>), if an
       earlier include has defined them,
     - architecture macros for big-endian ARM/Thumb, AArch64 and MIPS.
   If your compiler defines none of these, you may want to define one of
   them manually.
   Defaults to little endian order. */
#if (defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \
     __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) || \
    (defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN) || \
    defined(__BIG_ENDIAN__) || \
    defined(__ARMEB__) || \
    defined(__THUMBEB__) || \
    defined(__AARCH64EB__) || \
    defined(_MIPSEB) || defined(__MIPSEB) || defined(__MIPSEB__)
const bool big_endian = true;
#else
const bool big_endian = false;
#endif
58
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010059
// Magic prefix that opens every .npy file. The length excludes the
// terminating NUL that the string literal carries.
const char magic_string[] = "\x93NUMPY";
const size_t magic_string_length = sizeof(magic_string) - 1;

// Byte-order characters used as the first character of a NumPy typestring.
const char little_endian_char = '<';  // little-endian
const char big_endian_char = '>';     // big-endian
const char no_endian_char = '|';      // byte order not applicable (1-byte types)
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010066
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010067constexpr char host_endian_char = ( big_endian ?
68 big_endian_char :
69 little_endian_char );
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010070
Anthony Barbier87f21cd2017-11-10 16:27:32 +000071/* npy array length */
72typedef unsigned long int ndarray_len_t;
73
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010074inline void write_magic(std::ostream& ostream, unsigned char v_major=1, unsigned char v_minor=0) {
75 ostream.write(magic_string, magic_string_length);
76 ostream.put(v_major);
77 ostream.put(v_minor);
78}
79
Anthony Barbier87f21cd2017-11-10 16:27:32 +000080inline void read_magic(std::istream& istream, unsigned char& v_major, unsigned char& v_minor) {
81 char buf[magic_string_length+2];
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010082 istream.read(buf, magic_string_length+2);
83
84 if(!istream) {
Anthony Barbier87f21cd2017-11-10 16:27:32 +000085 throw std::runtime_error("io error: failed reading file");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010086 }
87
Anthony Barbier87f21cd2017-11-10 16:27:32 +000088 if (0 != std::memcmp(buf, magic_string, magic_string_length))
89 throw std::runtime_error("this file does not have a valid npy format.");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010090
Anthony Barbier87f21cd2017-11-10 16:27:32 +000091 v_major = buf[magic_string_length];
92 v_minor = buf[magic_string_length+1];
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010093}
94
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010095// typestring magic
96struct Typestring {
97 private:
98 char c_endian;
99 char c_type;
100 int len;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100101
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100102 public:
103 inline std::string str() {
104 const size_t max_buflen = 16;
105 char buf[max_buflen];
106 std::sprintf(buf, "%c%c%u", c_endian, c_type, len);
107 return std::string(buf);
108 }
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100109
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000110 Typestring(const std::vector<float>& v)
111 :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(float)} {}
112 Typestring(const std::vector<double>& v)
113 :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(double)} {}
114 Typestring(const std::vector<long double>& v)
115 :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(long double)} {}
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100116
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000117 Typestring(const std::vector<char>& v)
118 :c_endian {no_endian_char}, c_type {'i'}, len {sizeof(char)} {}
119 Typestring(const std::vector<short>& v)
120 :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(short)} {}
121 Typestring(const std::vector<int>& v)
122 :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(int)} {}
123 Typestring(const std::vector<long>& v)
124 :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long)} {}
125 Typestring(const std::vector<long long>& v) :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long long)} {}
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100126
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000127 Typestring(const std::vector<unsigned char>& v)
128 :c_endian {no_endian_char}, c_type {'u'}, len {sizeof(unsigned char)} {}
129 Typestring(const std::vector<unsigned short>& v)
130 :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned short)} {}
131 Typestring(const std::vector<unsigned int>& v)
132 :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned int)} {}
133 Typestring(const std::vector<unsigned long>& v)
134 :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long)} {}
135 Typestring(const std::vector<unsigned long long>& v)
136 :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long long)} {}
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100137
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000138 Typestring(const std::vector<std::complex<float>>& v)
139 :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<float>)} {}
140 Typestring(const std::vector<std::complex<double>>& v)
141 :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<double>)} {}
142 Typestring(const std::vector<std::complex<long double>>& v)
143 :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<long double>)} {}
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100144};
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100145
/**
  Validates a quoted NumPy typestring such as "'<f8'".

  The whole input must be a single-quoted byte-order character ('<', '>' or
  '|'), a kind character ('i', 'f', 'u' or 'c') and a decimal size.
  @throws std::runtime_error("invalid typestring") otherwise
 */
inline void parse_typestring( std::string typestring){
  static const std::regex pattern("'([<>|])([ifuc])(\\d+)'");
  std::smatch groups;

  const bool matched = std::regex_match(typestring, groups, pattern);
  if (!matched || groups.size() != 4) {
    throw std::runtime_error("invalid typestring");
  }
}
156
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000157namespace pyparse {
158
/**
  Returns a copy of str with leading and trailing spaces and tabs removed.
  Other whitespace such as '\n' is kept. Returns "" when str is empty or
  consists only of spaces/tabs.
 */
inline std::string trim(const std::string& str) {
  const char blanks[] = " \t";

  const auto first = str.find_first_not_of(blanks);
  if (first == std::string::npos) {
    return "";
  }

  const auto last = str.find_last_not_of(blanks);
  return str.substr(first, last - first + 1);
}
173
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000174
175inline std::string get_value_from_map(const std::string& mapstr) {
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100176 size_t sep_pos = mapstr.find_first_of(":");
177 if (sep_pos == std::string::npos)
178 return "";
179
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000180 std::string tmp = mapstr.substr(sep_pos+1);
181 return trim(tmp);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100182}
183
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000184/**
185 Parses the string representation of a Python dict
186
187 The keys need to be known and may not appear anywhere else in the data.
188 */
189inline std::unordered_map<std::string, std::string> parse_dict(std::string in, std::vector<std::string>& keys) {
190
191 std::unordered_map<std::string, std::string> map;
192
193 if (keys.size() == 0)
194 return map;
195
196 in = trim(in);
197
198 // unwrap dictionary
199 if ((in.front() == '{') && (in.back() == '}'))
200 in = in.substr(1, in.length()-2);
201 else
202 throw std::runtime_error("Not a Python dictionary.");
203
204 std::vector<std::pair<size_t, std::string>> positions;
205
206 for (auto const& value : keys) {
207 size_t pos = in.find( "'" + value + "'" );
208
209 if (pos == std::string::npos)
210 throw std::runtime_error("Missing '"+value+"' key.");
211
212 std::pair<size_t, std::string> position_pair { pos, value };
213 positions.push_back(position_pair);
214 }
215
216 // sort by position in dict
217 std::sort(positions.begin(), positions.end() );
218
219 for(size_t i = 0; i < positions.size(); ++i) {
220 std::string raw_value;
221 size_t begin { positions[i].first };
222 size_t end { std::string::npos };
223
224 std::string key = positions[i].second;
225
226 if ( i+1 < positions.size() )
227 end = positions[i+1].first;
228
229 raw_value = in.substr(begin, end-begin);
230
231 raw_value = trim(raw_value);
232
233 if (raw_value.back() == ',')
234 raw_value.pop_back();
235
236 map[key] = get_value_from_map(raw_value);
237 }
238
239 return map;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100240}
241
/**
  Parses the string representation of a Python boolean.

  @return true for "True", false for "False" (exact, case-sensitive match)
  @throws std::runtime_error for any other input
 */
inline bool parse_bool(const std::string& in) {
  if (in == "True")
    return true;
  if (in == "False")
    return false;

  throw std::runtime_error("Invalid python boolean.");
}
253
/**
  Parses the string representation of a single-quoted Python str and
  returns its contents without the quotes. Escape sequences are not decoded.

  @throws std::runtime_error if `in` is not of the form 'text'
 */
inline std::string parse_str(const std::string& in) {
  // At least the two quotes must be present. This also keeps front()/back()
  // off an empty string and rejects a lone "'".
  if (in.size() >= 2 && in.front() == '\'' && in.back() == '\'')
    return in.substr(1, in.length() - 2);

  throw std::runtime_error("Invalid python string.");
}
263
264/**
265 Parses the string represenatation of a Python tuple into a vector of its items
266 */
267inline std::vector<std::string> parse_tuple(std::string in) {
268 std::vector<std::string> v;
269 const char seperator = ',';
270
271 in = trim(in);
272
273 if ((in.front() == '(') && (in.back() == ')'))
274 in = in.substr(1, in.length()-2);
275 else
276 throw std::runtime_error("Invalid Python tuple.");
277
278 std::istringstream iss(in);
279
280 for (std::string token; std::getline(iss, token, seperator);) {
281 v.push_back(token);
282 }
283
284 return v;
285}
286
/**
  Formats the elements of v as a Python tuple literal:
    {}        -> "()"
    {3}       -> "(3,)"
    {2, 3, 4} -> "(2, 3, 4)"
  Elements are written with operator<<.
 */
template <typename T>
inline std::string write_tuple(const std::vector<T>& v) {
  // The empty tuple must still be a valid Python literal (it is the shape of
  // a 0-d array); an empty string would yield "'shape': , }" in the header.
  if (v.empty())
    return "()";

  std::ostringstream ss;

  if (v.size() == 1) {
    // a one-element tuple needs the trailing comma in Python
    ss << "(" << v.front() << ",)";
  } else {
    const std::string delimiter = ", ";
    ss << "(";
    std::copy(v.begin(), v.end() - 1, std::ostream_iterator<T>(ss, delimiter.c_str()));
    ss << v.back();
    ss << ")";
  }

  return ss.str();
}
307
/**
  Returns the Python literal for b: "True" or "False".
 */
inline std::string write_boolean(bool b) {
  return b ? "True" : "False";
}
314
315} // namespace pyparse
316
317
318inline void parse_header(std::string header, std::string& descr, bool& fortran_order, std::vector<ndarray_len_t>& shape) {
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100319 /*
320 The first 6 bytes are a magic string: exactly "x93NUMPY".
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100321 The next 1 byte is an unsigned byte: the major version number of the file format, e.g. x01.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100322 The next 1 byte is an unsigned byte: the minor version number of the file format, e.g. x00. Note: the version of the file format is not tied to the version of the numpy package.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100323 The next 2 bytes form a little-endian unsigned short int: the length of the header data HEADER_LEN.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100324 The next HEADER_LEN bytes form the header data describing the array's format. It is an ASCII string which contains a Python literal expression of a dictionary. It is terminated by a newline ('n') and padded with spaces ('x20') to make the total length of the magic string + 4 + HEADER_LEN be evenly divisible by 16 for alignment purposes.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100325 The dictionary contains three keys:
326
327 "descr" : dtype.descr
328 An object that can be passed as an argument to the numpy.dtype() constructor to create the array's dtype.
329 "fortran_order" : bool
330 Whether the array data is Fortran-contiguous or not. Since Fortran-contiguous arrays are a common form of non-C-contiguity, we allow them to be written directly to disk for efficiency.
331 "shape" : tuple of int
332 The shape of the array.
333 For repeatability and readability, this dictionary is formatted using pprint.pformat() so the keys are in alphabetic order.
334 */
335
336 // remove trailing newline
337 if (header.back() != '\n')
338 throw std::runtime_error("invalid header");
339 header.pop_back();
340
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000341 // parse the dictionary
342 std::vector<std::string> keys { "descr", "fortran_order", "shape" };
343 auto dict_map = npy::pyparse::parse_dict(header, keys);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100344
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000345 if (dict_map.size() == 0)
346 throw std::runtime_error("invalid dictionary in header");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100347
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000348 std::string descr_s = dict_map["descr"];
349 std::string fortran_s = dict_map["fortran_order"];
350 std::string shape_s = dict_map["shape"];
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100351
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000352 // TODO: extract info from typestring
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100353 parse_typestring(descr_s);
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000354 // remove
355 descr = npy::pyparse::parse_str(descr_s);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100356
357 // convert literal Python bool to C++ bool
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000358 fortran_order = npy::pyparse::parse_bool(fortran_s);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100359
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000360 // parse the shape tuple
361 auto shape_v = npy::pyparse::parse_tuple(shape_s);
362 if (shape_v.size() == 0)
363 throw std::runtime_error("invalid shape tuple in header");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100364
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000365 for ( auto item : shape_v ) {
366 std::stringstream stream(item);
367 unsigned long value;
368 stream >> value;
369 ndarray_len_t dim = static_cast<ndarray_len_t>(value);
370 shape.push_back(dim);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100371 }
372}
373
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000374
375inline std::string write_header_dict(const std::string& descr, bool fortran_order, const std::vector<ndarray_len_t>& shape) {
376 std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order);
377 std::string shape_s = npy::pyparse::write_tuple(shape);
378
379 return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }";
380}
381
382inline void write_header(std::ostream& out, const std::string& descr, bool fortran_order, const std::vector<ndarray_len_t>& shape_v)
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100383{
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000384 std::string header_dict = write_header_dict(descr, fortran_order, shape_v);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100385
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000386 size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100387
388 unsigned char version[2] = {1, 0};
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000389 if (length >= 255*255) {
390 length = magic_string_length + 2 + 4 + header_dict.length() + 1;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100391 version[0] = 2;
392 version[1] = 0;
393 }
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000394 size_t padding_len = 16 - length % 16;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100395 std::string padding (padding_len, ' ');
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100396
397 // write magic
398 write_magic(out, version[0], version[1]);
399
400 // write header length
401 if (version[0] == 1 && version[1] == 0) {
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100402 char header_len_le16[2];
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000403 uint16_t header_len = header_dict.length() + padding.length() + 1;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100404
405 header_len_le16[0] = (header_len >> 0) & 0xff;
406 header_len_le16[1] = (header_len >> 8) & 0xff;
407 out.write(reinterpret_cast<char *>(header_len_le16), 2);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100408 }else{
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100409 char header_len_le32[4];
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000410 uint32_t header_len = header_dict.length() + padding.length() + 1;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100411
412 header_len_le32[0] = (header_len >> 0) & 0xff;
413 header_len_le32[1] = (header_len >> 8) & 0xff;
414 header_len_le32[2] = (header_len >> 16) & 0xff;
415 header_len_le32[3] = (header_len >> 24) & 0xff;
416 out.write(reinterpret_cast<char *>(header_len_le32), 4);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100417 }
418
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000419 out << header_dict << padding << '\n';
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100420}
421
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000422inline std::string read_header(std::istream& istream) {
423 // check magic bytes an version number
424 unsigned char v_major, v_minor;
425 read_magic(istream, v_major, v_minor);
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100426
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000427 uint32_t header_length;
428 if(v_major == 1 && v_minor == 0){
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100429
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000430 char header_len_le16[2];
431 istream.read(header_len_le16, 2);
432 header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
433
434 if((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
435 // TODO: display warning
436 }
437 }else if(v_major == 2 && v_minor == 0) {
438 char header_len_le32[4];
439 istream.read(header_len_le32, 4);
440
441 header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
442 | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
443
444 if((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100445 // TODO: display warning
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000446 }
447 }else{
448 throw std::runtime_error("unsupported file format version");
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100449 }
450
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000451 auto buf_v = std::vector<char>();
452 buf_v.reserve(header_length);
453 istream.read(buf_v.data(), header_length);
454 std::string header(buf_v.data(), header_length);
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100455
456 return header;
457}
458
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000459inline ndarray_len_t comp_size(const std::vector<ndarray_len_t>& shape) {
460 ndarray_len_t size = 1;
461 for (ndarray_len_t i : shape )
462 size *= i;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100463
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000464 return size;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100465}
466
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100467template<typename Scalar>
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000468inline void SaveArrayAsNumpy( const std::string& filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[], const std::vector<Scalar>& data)
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100469{
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000470 Typestring typestring_o(data);
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100471 std::string typestring = typestring_o.str();
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100472
473 std::ofstream stream( filename, std::ofstream::binary);
474 if(!stream) {
475 throw std::runtime_error("io error: failed to open a file.");
476 }
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100477
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000478 std::vector<ndarray_len_t> shape_v(shape, shape+n_dims);
479 write_header(stream, typestring, fortran_order, shape_v);
480
481 auto size = static_cast<size_t>(comp_size(shape_v));
482
483 stream.write(reinterpret_cast<const char*>(data.data()), sizeof(Scalar) * size);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100484}
485
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100486
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100487template<typename Scalar>
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000488inline void LoadArrayFromNumpy(const std::string& filename, std::vector<unsigned long>& shape, std::vector<Scalar>& data)
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100489{
490 std::ifstream stream(filename, std::ifstream::binary);
491 if(!stream) {
492 throw std::runtime_error("io error: failed to open a file.");
493 }
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100494
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000495 std::string header = read_header(stream);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100496
497 // parse header
498 bool fortran_order;
499 std::string typestr;
500
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000501 parse_header(header, typestr, fortran_order, shape);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100502
503 // check if the typestring matches the given one
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100504 Typestring typestring_o {data};
505 std::string expect_typestr = typestring_o.str();
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100506 if (typestr != expect_typestr) {
507 throw std::runtime_error("formatting error: typestrings not matching");
508 }
509
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000510
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100511 // compute the data size based on the shape
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000512 auto size = static_cast<size_t>(comp_size(shape));
513 data.resize(size);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100514
515 // read the data
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000516 stream.read(reinterpret_cast<char*>(data.data()), sizeof(Scalar)*size);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100517}
518
519} // namespace npy
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000520
521#endif // NPY_H