blob: 24244ca272831a874f94238757e7a41809f19032 [file] [log] [blame]
Anthony Barbiera3adb3a2017-09-13 16:03:39 +01001/*
2 Copyright 2017 Leon Merten Lohse
3
4 Permission is hereby granted, free of charge, to any person obtaining a copy
5 of this software and associated documentation files (the "Software"), to deal
6 in the Software without restriction, including without limitation the rights
7 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 copies of the Software, and to permit persons to whom the Software is
9 furnished to do so, subject to the following conditions:
10
11 The above copyright notice and this permission notice shall be included in
12 all copies or substantial portions of the Software.
13
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 SOFTWARE.
21*/
22
Anthony Barbier87f21cd2017-11-10 16:27:32 +000023#ifndef NPY_H
24#define NPY_H
25
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010026#include <complex>
27#include <fstream>
28#include <string>
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010029#include <iostream>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010030#include <sstream>
31#include <cstdint>
Anthony Barbier87f21cd2017-11-10 16:27:32 +000032#include <cstring>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010033#include <vector>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010034#include <stdexcept>
35#include <algorithm>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010036#include <regex>
Anthony Barbier87f21cd2017-11-10 16:27:32 +000037#include <unordered_map>
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010038
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010039
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010040namespace npy {
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010041
/* Compile-time test for byte order.
   If your compiler does not define any of these macros by default, you may
   want to define one of them manually.
   Defaults to little endian order.

   __BYTE_ORDER__ / __ORDER_BIG_ENDIAN__ are predefined by GCC and Clang.
   The MIPS big-endian macros are spelled _MIPSEB / __MIPSEB / __MIPSEB__;
   the historical misspellings (*MIBSEB*) are still accepted so that builds
   which define them manually keep working. */
#if (defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) || \
    (defined(__BYTE_ORDER) && defined(__BIG_ENDIAN) && __BYTE_ORDER == __BIG_ENDIAN) || \
    defined(__BIG_ENDIAN__) || \
    defined(__ARMEB__) || \
    defined(__THUMBEB__) || \
    defined(__AARCH64EB__) || \
    defined(_MIPSEB) || defined(__MIPSEB) || defined(__MIPSEB__) || \
    defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__)
const bool big_endian = true;
#else
const bool big_endian = false;
#endif


/* Magic prefix of every npy file; the length excludes the terminating NUL. */
const char magic_string[] = "\x93NUMPY";
const size_t magic_string_length = sizeof(magic_string) - 1;

/* Byte-order markers used as the first character of a typestring. */
const char little_endian_char = '<';
const char big_endian_char = '>';
const char no_endian_char = '|';

/* Byte-order marker matching the machine this code is compiled for. */
constexpr char host_endian_char = ( big_endian ?
    big_endian_char :
    little_endian_char );

/* npy array length */
typedef unsigned long int ndarray_len_t;
71
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010072inline void write_magic(std::ostream& ostream, unsigned char v_major=1, unsigned char v_minor=0) {
73 ostream.write(magic_string, magic_string_length);
74 ostream.put(v_major);
75 ostream.put(v_minor);
76}
77
Anthony Barbier87f21cd2017-11-10 16:27:32 +000078inline void read_magic(std::istream& istream, unsigned char& v_major, unsigned char& v_minor) {
79 char buf[magic_string_length+2];
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010080 istream.read(buf, magic_string_length+2);
81
82 if(!istream) {
Anthony Barbier87f21cd2017-11-10 16:27:32 +000083 throw std::runtime_error("io error: failed reading file");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010084 }
85
Anthony Barbier87f21cd2017-11-10 16:27:32 +000086 if (0 != std::memcmp(buf, magic_string, magic_string_length))
87 throw std::runtime_error("this file does not have a valid npy format.");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010088
Anthony Barbier87f21cd2017-11-10 16:27:32 +000089 v_major = buf[magic_string_length];
90 v_minor = buf[magic_string_length+1];
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010091}
92
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010093// typestring magic
94struct Typestring {
95 private:
96 char c_endian;
97 char c_type;
98 int len;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010099
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100100 public:
101 inline std::string str() {
102 const size_t max_buflen = 16;
103 char buf[max_buflen];
104 std::sprintf(buf, "%c%c%u", c_endian, c_type, len);
105 return std::string(buf);
106 }
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100107
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000108 Typestring(const std::vector<float>& v)
109 :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(float)} {}
110 Typestring(const std::vector<double>& v)
111 :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(double)} {}
112 Typestring(const std::vector<long double>& v)
113 :c_endian {host_endian_char}, c_type {'f'}, len {sizeof(long double)} {}
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100114
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000115 Typestring(const std::vector<char>& v)
116 :c_endian {no_endian_char}, c_type {'i'}, len {sizeof(char)} {}
117 Typestring(const std::vector<short>& v)
118 :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(short)} {}
119 Typestring(const std::vector<int>& v)
120 :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(int)} {}
121 Typestring(const std::vector<long>& v)
122 :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long)} {}
123 Typestring(const std::vector<long long>& v) :c_endian {host_endian_char}, c_type {'i'}, len {sizeof(long long)} {}
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100124
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000125 Typestring(const std::vector<unsigned char>& v)
126 :c_endian {no_endian_char}, c_type {'u'}, len {sizeof(unsigned char)} {}
127 Typestring(const std::vector<unsigned short>& v)
128 :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned short)} {}
129 Typestring(const std::vector<unsigned int>& v)
130 :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned int)} {}
131 Typestring(const std::vector<unsigned long>& v)
132 :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long)} {}
133 Typestring(const std::vector<unsigned long long>& v)
134 :c_endian {host_endian_char}, c_type {'u'}, len {sizeof(unsigned long long)} {}
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100135
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000136 Typestring(const std::vector<std::complex<float>>& v)
137 :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<float>)} {}
138 Typestring(const std::vector<std::complex<double>>& v)
139 :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<double>)} {}
140 Typestring(const std::vector<std::complex<long double>>& v)
141 :c_endian {host_endian_char}, c_type {'c'}, len {sizeof(std::complex<long double>)} {}
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100142};
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100143
/**
  Validates a quoted typestring such as "'<f4'".

  The string must consist of a single-quoted byte-order character (<, > or |),
  a type kind (i, f, u or c) and a decimal size. Nothing is returned; a
  std::runtime_error is thrown when the string does not match.
 */
inline void parse_typestring( std::string typestring){
  const std::regex pattern("'([<>|])([ifuc])(\\d+)'");
  std::smatch groups;

  const bool matched = std::regex_match(typestring, groups, pattern);

  // a successful match always yields the full match plus three sub-groups
  if (!matched || groups.size() != 4)
    throw std::runtime_error("invalid typestring");
}
154
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000155namespace pyparse {
156
/**
  Strips leading and trailing blanks (spaces and tabs) from a string.

  Returns an empty string when the input contains nothing but blanks.
 */
inline std::string trim(const std::string& str) {
  const char blanks[] = " \t";

  const std::string::size_type first = str.find_first_not_of(blanks);
  if (first == std::string::npos)
    return std::string();

  const std::string::size_type last = str.find_last_not_of(blanks);
  return std::string(str, first, last - first + 1);
}
171
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000172
173inline std::string get_value_from_map(const std::string& mapstr) {
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100174 size_t sep_pos = mapstr.find_first_of(":");
175 if (sep_pos == std::string::npos)
176 return "";
177
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000178 std::string tmp = mapstr.substr(sep_pos+1);
179 return trim(tmp);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100180}
181
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000182/**
183 Parses the string representation of a Python dict
184
185 The keys need to be known and may not appear anywhere else in the data.
186 */
187inline std::unordered_map<std::string, std::string> parse_dict(std::string in, std::vector<std::string>& keys) {
188
189 std::unordered_map<std::string, std::string> map;
190
191 if (keys.size() == 0)
192 return map;
193
194 in = trim(in);
195
196 // unwrap dictionary
197 if ((in.front() == '{') && (in.back() == '}'))
198 in = in.substr(1, in.length()-2);
199 else
200 throw std::runtime_error("Not a Python dictionary.");
201
202 std::vector<std::pair<size_t, std::string>> positions;
203
204 for (auto const& value : keys) {
205 size_t pos = in.find( "'" + value + "'" );
206
207 if (pos == std::string::npos)
208 throw std::runtime_error("Missing '"+value+"' key.");
209
210 std::pair<size_t, std::string> position_pair { pos, value };
211 positions.push_back(position_pair);
212 }
213
214 // sort by position in dict
215 std::sort(positions.begin(), positions.end() );
216
217 for(size_t i = 0; i < positions.size(); ++i) {
218 std::string raw_value;
219 size_t begin { positions[i].first };
220 size_t end { std::string::npos };
221
222 std::string key = positions[i].second;
223
224 if ( i+1 < positions.size() )
225 end = positions[i+1].first;
226
227 raw_value = in.substr(begin, end-begin);
228
229 raw_value = trim(raw_value);
230
231 if (raw_value.back() == ',')
232 raw_value.pop_back();
233
234 map[key] = get_value_from_map(raw_value);
235 }
236
237 return map;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100238}
239
/**
  Converts the textual form of a Python boolean ("True" / "False") into a
  C++ bool. Any other text raises std::runtime_error.
 */
inline bool parse_bool(const std::string& in) {
  const bool is_true = (in == "True");

  if (!is_true && in != "False")
    throw std::runtime_error("Invalid python boolan.");

  return is_true;
}
251
/**
  Parses the string representation of a Python str

  The input must be wrapped in single quotes; the text between them is
  returned unchanged (no escape sequences are interpreted).

  Throws std::runtime_error if the input is not quoted. Input shorter than
  two characters (including the empty string) can never be a quoted string
  and is rejected as well, which also keeps front()/back() off empty input.
 */
inline std::string parse_str(const std::string& in) {
  if (in.length() >= 2 && (in.front() == '\'') && (in.back() == '\''))
    return in.substr(1, in.length()-2);

  throw std::runtime_error("Invalid python string.");
}
261
262/**
263 Parses the string represenatation of a Python tuple into a vector of its items
264 */
265inline std::vector<std::string> parse_tuple(std::string in) {
266 std::vector<std::string> v;
267 const char seperator = ',';
268
269 in = trim(in);
270
271 if ((in.front() == '(') && (in.back() == ')'))
272 in = in.substr(1, in.length()-2);
273 else
274 throw std::runtime_error("Invalid Python tuple.");
275
276 std::istringstream iss(in);
277
278 for (std::string token; std::getline(iss, token, seperator);) {
279 v.push_back(token);
280 }
281
282 return v;
283}
284
/**
  Formats a vector as a Python tuple literal.

  An empty vector gives an empty string, a single element gives "(x,)" and
  several elements give "(a, b, c)". Elements are formatted with operator<<.
 */
template <typename T>
inline std::string write_tuple(const std::vector<T>& v) {
  if (v.empty())
    return "";

  std::ostringstream out;
  out << "(" << v.front();

  if (v.size() == 1) {
    // a one-element tuple needs the trailing comma
    out << ",";
  } else {
    for (auto it = v.begin() + 1; it != v.end(); ++it)
      out << ", " << *it;
  }

  out << ")";
  return out.str();
}
305
/**
  Returns the Python literal for a bool: "True" or "False".
 */
inline std::string write_boolean(bool b) {
  return b ? "True" : "False";
}
312
313} // namespace pyparse
314
315
316inline void parse_header(std::string header, std::string& descr, bool& fortran_order, std::vector<ndarray_len_t>& shape) {
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100317 /*
318 The first 6 bytes are a magic string: exactly "x93NUMPY".
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100319 The next 1 byte is an unsigned byte: the major version number of the file format, e.g. x01.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100320 The next 1 byte is an unsigned byte: the minor version number of the file format, e.g. x00. Note: the version of the file format is not tied to the version of the numpy package.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100321 The next 2 bytes form a little-endian unsigned short int: the length of the header data HEADER_LEN.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100322 The next HEADER_LEN bytes form the header data describing the array's format. It is an ASCII string which contains a Python literal expression of a dictionary. It is terminated by a newline ('n') and padded with spaces ('x20') to make the total length of the magic string + 4 + HEADER_LEN be evenly divisible by 16 for alignment purposes.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100323 The dictionary contains three keys:
324
325 "descr" : dtype.descr
326 An object that can be passed as an argument to the numpy.dtype() constructor to create the array's dtype.
327 "fortran_order" : bool
328 Whether the array data is Fortran-contiguous or not. Since Fortran-contiguous arrays are a common form of non-C-contiguity, we allow them to be written directly to disk for efficiency.
329 "shape" : tuple of int
330 The shape of the array.
331 For repeatability and readability, this dictionary is formatted using pprint.pformat() so the keys are in alphabetic order.
332 */
333
334 // remove trailing newline
335 if (header.back() != '\n')
336 throw std::runtime_error("invalid header");
337 header.pop_back();
338
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000339 // parse the dictionary
340 std::vector<std::string> keys { "descr", "fortran_order", "shape" };
341 auto dict_map = npy::pyparse::parse_dict(header, keys);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100342
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000343 if (dict_map.size() == 0)
344 throw std::runtime_error("invalid dictionary in header");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100345
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000346 std::string descr_s = dict_map["descr"];
347 std::string fortran_s = dict_map["fortran_order"];
348 std::string shape_s = dict_map["shape"];
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100349
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000350 // TODO: extract info from typestring
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100351 parse_typestring(descr_s);
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000352 // remove
353 descr = npy::pyparse::parse_str(descr_s);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100354
355 // convert literal Python bool to C++ bool
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000356 fortran_order = npy::pyparse::parse_bool(fortran_s);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100357
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000358 // parse the shape tuple
359 auto shape_v = npy::pyparse::parse_tuple(shape_s);
360 if (shape_v.size() == 0)
361 throw std::runtime_error("invalid shape tuple in header");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100362
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000363 for ( auto item : shape_v ) {
364 std::stringstream stream(item);
365 unsigned long value;
366 stream >> value;
367 ndarray_len_t dim = static_cast<ndarray_len_t>(value);
368 shape.push_back(dim);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100369 }
370}
371
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000372
373inline std::string write_header_dict(const std::string& descr, bool fortran_order, const std::vector<ndarray_len_t>& shape) {
374 std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order);
375 std::string shape_s = npy::pyparse::write_tuple(shape);
376
377 return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }";
378}
379
380inline void write_header(std::ostream& out, const std::string& descr, bool fortran_order, const std::vector<ndarray_len_t>& shape_v)
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100381{
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000382 std::string header_dict = write_header_dict(descr, fortran_order, shape_v);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100383
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000384 size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100385
386 unsigned char version[2] = {1, 0};
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000387 if (length >= 255*255) {
388 length = magic_string_length + 2 + 4 + header_dict.length() + 1;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100389 version[0] = 2;
390 version[1] = 0;
391 }
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000392 size_t padding_len = 16 - length % 16;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100393 std::string padding (padding_len, ' ');
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100394
395 // write magic
396 write_magic(out, version[0], version[1]);
397
398 // write header length
399 if (version[0] == 1 && version[1] == 0) {
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100400 char header_len_le16[2];
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000401 uint16_t header_len = header_dict.length() + padding.length() + 1;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100402
403 header_len_le16[0] = (header_len >> 0) & 0xff;
404 header_len_le16[1] = (header_len >> 8) & 0xff;
405 out.write(reinterpret_cast<char *>(header_len_le16), 2);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100406 }else{
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100407 char header_len_le32[4];
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000408 uint32_t header_len = header_dict.length() + padding.length() + 1;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100409
410 header_len_le32[0] = (header_len >> 0) & 0xff;
411 header_len_le32[1] = (header_len >> 8) & 0xff;
412 header_len_le32[2] = (header_len >> 16) & 0xff;
413 header_len_le32[3] = (header_len >> 24) & 0xff;
414 out.write(reinterpret_cast<char *>(header_len_le32), 4);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100415 }
416
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000417 out << header_dict << padding << '\n';
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100418}
419
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000420inline std::string read_header(std::istream& istream) {
421 // check magic bytes an version number
422 unsigned char v_major, v_minor;
423 read_magic(istream, v_major, v_minor);
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100424
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000425 uint32_t header_length;
426 if(v_major == 1 && v_minor == 0){
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100427
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000428 char header_len_le16[2];
429 istream.read(header_len_le16, 2);
430 header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
431
432 if((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
433 // TODO: display warning
434 }
435 }else if(v_major == 2 && v_minor == 0) {
436 char header_len_le32[4];
437 istream.read(header_len_le32, 4);
438
439 header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
440 | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
441
442 if((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100443 // TODO: display warning
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000444 }
445 }else{
446 throw std::runtime_error("unsupported file format version");
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100447 }
448
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000449 auto buf_v = std::vector<char>();
450 buf_v.reserve(header_length);
451 istream.read(buf_v.data(), header_length);
452 std::string header(buf_v.data(), header_length);
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100453
454 return header;
455}
456
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000457inline ndarray_len_t comp_size(const std::vector<ndarray_len_t>& shape) {
458 ndarray_len_t size = 1;
459 for (ndarray_len_t i : shape )
460 size *= i;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100461
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000462 return size;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100463}
464
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100465template<typename Scalar>
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000466inline void SaveArrayAsNumpy( const std::string& filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[], const std::vector<Scalar>& data)
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100467{
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000468 Typestring typestring_o(data);
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100469 std::string typestring = typestring_o.str();
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100470
471 std::ofstream stream( filename, std::ofstream::binary);
472 if(!stream) {
473 throw std::runtime_error("io error: failed to open a file.");
474 }
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100475
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000476 std::vector<ndarray_len_t> shape_v(shape, shape+n_dims);
477 write_header(stream, typestring, fortran_order, shape_v);
478
479 auto size = static_cast<size_t>(comp_size(shape_v));
480
481 stream.write(reinterpret_cast<const char*>(data.data()), sizeof(Scalar) * size);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100482}
483
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100484
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100485template<typename Scalar>
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000486inline void LoadArrayFromNumpy(const std::string& filename, std::vector<unsigned long>& shape, std::vector<Scalar>& data)
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100487{
488 std::ifstream stream(filename, std::ifstream::binary);
489 if(!stream) {
490 throw std::runtime_error("io error: failed to open a file.");
491 }
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100492
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000493 std::string header = read_header(stream);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100494
495 // parse header
496 bool fortran_order;
497 std::string typestr;
498
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000499 parse_header(header, typestr, fortran_order, shape);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100500
501 // check if the typestring matches the given one
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100502 Typestring typestring_o {data};
503 std::string expect_typestr = typestring_o.str();
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100504 if (typestr != expect_typestr) {
505 throw std::runtime_error("formatting error: typestrings not matching");
506 }
507
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000508
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100509 // compute the data size based on the shape
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000510 auto size = static_cast<size_t>(comp_size(shape));
511 data.resize(size);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100512
513 // read the data
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000514 stream.read(reinterpret_cast<char*>(data.data()), sizeof(Scalar)*size);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100515}
516
517} // namespace npy
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000518
519#endif // NPY_H