blob: 4399426de1faa774c7c906d3deaf8e604e46b797 [file] [log] [blame]
Anthony Barbiera3adb3a2017-09-13 16:03:39 +01001/*
2 Copyright 2017 Leon Merten Lohse
3
4 Permission is hereby granted, free of charge, to any person obtaining a copy
5 of this software and associated documentation files (the "Software"), to deal
6 in the Software without restriction, including without limitation the rights
7 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 copies of the Software, and to permit persons to whom the Software is
9 furnished to do so, subject to the following conditions:
10
11 The above copyright notice and this permission notice shall be included in
12 all copies or substantial portions of the Software.
13
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 SOFTWARE.
21*/
22
Jakub Sujak3b504ef2022-12-07 23:55:22 +000023#ifndef NPY_HPP_
24#define NPY_HPP_
Anthony Barbier87f21cd2017-11-10 16:27:32 +000025
#include <algorithm>
#include <array>
#include <cctype>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010043
Jakub Sujak3b504ef2022-12-07 23:55:22 +000044
Anthony Barbiera3adb3a2017-09-13 16:03:39 +010045namespace npy {
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +010046
/* Compile-time test for byte order.
   If your compiler does not define any of these macros per default, you may
   want to define one of them manually.
   Defaults to little endian order.

   The misspelled *MIBSEB* names are kept only so that builds which defined
   them manually keep working; the names compilers actually define for
   big-endian MIPS are the *MIPSEB* ones. */
#if (defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) || \
    (defined(__BYTE_ORDER) && defined(__BIG_ENDIAN) && __BYTE_ORDER == __BIG_ENDIAN) || \
    defined(__BIG_ENDIAN__) || \
    defined(__ARMEB__) || \
    defined(__THUMBEB__) || \
    defined(__AARCH64EB__) || \
    defined(_MIPSEB) || defined(__MIPSEB) || defined(__MIPSEB__) || \
    defined(_MIBSEB) || defined(__MIBSEB) || defined(__MIBSEB__)
const bool big_endian = true;
#else
const bool big_endian = false;
#endif

/* Magic prefix of every npy file: the byte 0x93 followed by "NUMPY". */
const char magic_string[] = "\x93NUMPY";
const std::size_t magic_string_length = 6;

/* Byte order markers used as the first character of a type string. */
const char little_endian_char = '<';
const char big_endian_char = '>';
const char no_endian_char = '|';

/* Accepted byte order markers and type kind characters. */
constexpr std::array<char, 3>
endian_chars = {little_endian_char, big_endian_char, no_endian_char};
constexpr std::array<char, 4>
numtype_chars = {'f', 'i', 'u', 'c'};

/* Byte order marker matching the byte order detected above. */
constexpr char host_endian_char = (big_endian ?
                                   big_endian_char :
                                   little_endian_char);

/* npy array length */
typedef unsigned long int ndarray_len_t;

/* File format version as (major, minor). */
typedef std::pair<char, char> version_t;
83
/**
  Element type description: a byte order character, a kind character and the
  size of one element.
 */
struct dtype_t {
  const char byteorder;
  const char kind;
  const unsigned int itemsize;

  /**
    Returns the type string, i.e. byteorder and kind followed by the decimal
    itemsize (for example "<f4").
   */
  inline std::string str() const {
    std::string typestring;
    typestring.push_back(byteorder);
    typestring.push_back(kind);
    typestring += std::to_string(itemsize);
    return typestring;
  }

  /**
    Returns the three members by value as a tuple, so that two dtypes can be
    compared with the tuple comparison operators.
   */
  inline std::tuple<const char, const char, const unsigned int> tie() const {
    return std::tuple<const char, const char, const unsigned int>(byteorder, kind, itemsize);
  }
};
101
102
103struct header_t {
104 const dtype_t dtype;
105 const bool fortran_order;
106 const std::vector <ndarray_len_t> shape;
107};
108
109inline void write_magic(std::ostream &ostream, version_t version) {
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100110 ostream.write(magic_string, magic_string_length);
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000111 ostream.put(version.first);
112 ostream.put(version.second);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100113}
114
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000115inline version_t read_magic(std::istream &istream) {
116 char buf[magic_string_length + 2];
117 istream.read(buf, magic_string_length + 2);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100118
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000119 if (!istream) {
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000120 throw std::runtime_error("io error: failed reading file");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100121 }
122
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000123 if (0 != std::memcmp(buf, magic_string, magic_string_length))
124 throw std::runtime_error("this file does not have a valid npy format.");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100125
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000126 version_t version;
127 version.first = buf[magic_string_length];
128 version.second = buf[magic_string_length + 1];
129
130 return version;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100131}
132
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000133const std::unordered_map<std::type_index, dtype_t> dtype_map = {
134 {std::type_index(typeid(float)), {host_endian_char, 'f', sizeof(float)}},
135 {std::type_index(typeid(double)), {host_endian_char, 'f', sizeof(double)}},
136 {std::type_index(typeid(long double)), {host_endian_char, 'f', sizeof(long double)}},
137 {std::type_index(typeid(char)), {no_endian_char, 'i', sizeof(char)}},
138 {std::type_index(typeid(signed char)), {no_endian_char, 'i', sizeof(signed char)}},
139 {std::type_index(typeid(short)), {host_endian_char, 'i', sizeof(short)}},
140 {std::type_index(typeid(int)), {host_endian_char, 'i', sizeof(int)}},
141 {std::type_index(typeid(long)), {host_endian_char, 'i', sizeof(long)}},
142 {std::type_index(typeid(long long)), {host_endian_char, 'i', sizeof(long long)}},
143 {std::type_index(typeid(unsigned char)), {no_endian_char, 'u', sizeof(unsigned char)}},
144 {std::type_index(typeid(unsigned short)), {host_endian_char, 'u', sizeof(unsigned short)}},
145 {std::type_index(typeid(unsigned int)), {host_endian_char, 'u', sizeof(unsigned int)}},
146 {std::type_index(typeid(unsigned long)), {host_endian_char, 'u', sizeof(unsigned long)}},
147 {std::type_index(typeid(unsigned long long)), {host_endian_char, 'u', sizeof(unsigned long long)}},
148 {std::type_index(typeid(std::complex<float>)), {host_endian_char, 'c', sizeof(std::complex<float>)}},
149 {std::type_index(typeid(std::complex<double>)), {host_endian_char, 'c', sizeof(std::complex<double>)}},
150 {std::type_index(typeid(std::complex<long double>)), {host_endian_char, 'c', sizeof(std::complex<long double>)}}
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100151};
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100152
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100153
// helpers

/**
  Returns true if every character of str is a decimal digit.
  An empty string yields true.
 */
inline bool is_digits(const std::string &str) {
  // std::isdigit must not be called with a negative value, so every char is
  // converted to unsigned char first.
  return std::all_of(str.begin(), str.end(),
                     [](unsigned char c) { return std::isdigit(c) != 0; });
}
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100158
/**
  Returns true if val compares equal to at least one element of arr.
 */
template<typename T, size_t N>
inline bool in_array(T val, const std::array <T, N> &arr) {
  return std::find(std::begin(arr), std::end(arr), val) != std::end(arr);
}
163
164inline dtype_t parse_descr(std::string typestring) {
165 if (typestring.length() < 3) {
166 throw std::runtime_error("invalid typestring (length)");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100167 }
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000168
169 char byteorder_c = typestring.at(0);
170 char kind_c = typestring.at(1);
171 std::string itemsize_s = typestring.substr(2);
172
173 if (!in_array(byteorder_c, endian_chars)) {
174 throw std::runtime_error("invalid typestring (byteorder)");
175 }
176
177 if (!in_array(kind_c, numtype_chars)) {
178 throw std::runtime_error("invalid typestring (kind)");
179 }
180
181 if (!is_digits(itemsize_s)) {
182 throw std::runtime_error("invalid typestring (itemsize)");
183 }
184 unsigned int itemsize = std::stoul(itemsize_s);
185
186 return {byteorder_c, kind_c, itemsize};
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100187}
188
namespace pyparse {

/**
  Removes leading and trailing whitespaces (spaces and tabs)
 */
inline std::string trim(const std::string &str) {
  const std::string whitespace = " \t";
  const auto begin = str.find_first_not_of(whitespace);

  // empty string or whitespace only
  if (begin == std::string::npos) {
    return "";
  }

  const auto end = str.find_last_not_of(whitespace);

  return str.substr(begin, end - begin + 1);
}

/**
  Returns the trimmed text after the first ':' of a "key: value" string,
  or an empty string if there is no ':'.
 */
inline std::string get_value_from_map(const std::string &mapstr) {
  const size_t sep_pos = mapstr.find(':');
  if (sep_pos == std::string::npos) {
    return "";
  }

  return trim(mapstr.substr(sep_pos + 1));
}

/**
  Parses the string representation of a Python dict

  The keys need to be known and may not appear anywhere else in the data.

  Returns a map from each key to the raw (still unparsed) text of its value.
  Returns an empty map if keys is empty. Throws std::runtime_error if the
  input is not wrapped in curly braces or if one of the keys is missing.
 */
inline std::unordered_map <std::string, std::string> parse_dict(std::string in, const std::vector <std::string> &keys) {
  std::unordered_map <std::string, std::string> map;

  if (keys.size() == 0) {
    return map;
  }

  in = trim(in);

  // unwrap dictionary
  // The emptiness test has to come first: front() and back() must not be
  // called on an empty string.
  if (in.empty() || (in.front() != '{') || (in.back() != '}')) {
    throw std::runtime_error("Not a Python dictionary.");
  }
  in = in.substr(1, in.length() - 2);

  // locate every (quoted) key
  std::vector <std::pair<size_t, std::string>> positions;

  for (auto const &value : keys) {
    const size_t pos = in.find("'" + value + "'");

    if (pos == std::string::npos) {
      throw std::runtime_error("Missing '" + value + "' key.");
    }

    positions.push_back(std::pair<size_t, std::string>{pos, value});
  }

  // sort by position in dict
  std::sort(positions.begin(), positions.end());

  for (size_t i = 0; i < positions.size(); ++i) {
    const size_t begin{positions[i].first};
    size_t end{std::string::npos};

    const std::string key = positions[i].second;

    // an entry extends up to the next key, the last one up to the end
    if (i + 1 < positions.size()) {
      end = positions[i + 1].first;
    }

    std::string raw_value = trim(in.substr(begin, end - begin));

    // drop the comma that separates this entry from the next one
    if (!raw_value.empty() && raw_value.back() == ',') {
      raw_value.pop_back();
    }

    map[key] = get_value_from_map(raw_value);
  }

  return map;
}

/**
  Parses the string representation of a Python boolean

  Throws std::runtime_error for anything but "True" and "False".
 */
inline bool parse_bool(const std::string &in) {
  if (in == "True") {
    return true;
  }
  if (in == "False") {
    return false;
  }

  throw std::runtime_error("Invalid python boolan.");
}

/**
  Parses the string representation of a Python str

  Returns the text between the enclosing single quotes. Throws
  std::runtime_error if the input is not enclosed in single quotes.
 */
inline std::string parse_str(const std::string &in) {
  // At least two characters are required: a lone quote would otherwise count
  // as opening and closing quote at the same time.
  if ((in.length() >= 2) && (in.front() == '\'') && (in.back() == '\'')) {
    return in.substr(1, in.length() - 2);
  }

  throw std::runtime_error("Invalid python string.");
}

/**
  Parses the string represenatation of a Python tuple into a vector of its items

  The items are returned as found between the commas, without trimming.
  Throws std::runtime_error if the input is not wrapped in parentheses.
 */
inline std::vector <std::string> parse_tuple(std::string in) {
  std::vector <std::string> v;
  const char seperator = ',';

  in = trim(in);

  // front() and back() must not be called on an empty string
  if (in.empty() || (in.front() != '(') || (in.back() != ')')) {
    throw std::runtime_error("Invalid Python tuple.");
  }
  in = in.substr(1, in.length() - 2);

  std::istringstream iss(in);

  for (std::string token; std::getline(iss, token, seperator);) {
    v.push_back(token);
  }

  return v;
}

/**
  Writes the items of v as a Python tuple: "()" for no item, "(x,)" for a
  single item and "(x, y, ...)" otherwise.
 */
template<typename T>
inline std::string write_tuple(const std::vector <T> &v) {
  if (v.size() == 0) {
    return "()";
  }

  std::ostringstream ss;

  if (v.size() == 1) {
    ss << "(" << v.front() << ",)";
  } else {
    const std::string delimiter = ", ";
    // v.size() > 1
    ss << "(";
    // all items but the last one are followed by the delimiter
    std::copy(v.begin(), v.end() - 1, std::ostream_iterator<T>(ss, delimiter.c_str()));
    ss << v.back();
    ss << ")";
  }

  return ss.str();
}

/**
  Writes b as a Python boolean literal.
 */
inline std::string write_boolean(bool b) {
  if (b) {
    return "True";
  }
  return "False";
}

} // namespace pyparse
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000347
348
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000349inline header_t parse_header(std::string header) {
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100350 /*
351 The first 6 bytes are a magic string: exactly "x93NUMPY".
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100352 The next 1 byte is an unsigned byte: the major version number of the file format, e.g. x01.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100353 The next 1 byte is an unsigned byte: the minor version number of the file format, e.g. x00. Note: the version of the file format is not tied to the version of the numpy package.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100354 The next 2 bytes form a little-endian unsigned short int: the length of the header data HEADER_LEN.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100355 The next HEADER_LEN bytes form the header data describing the array's format. It is an ASCII string which contains a Python literal expression of a dictionary. It is terminated by a newline ('n') and padded with spaces ('x20') to make the total length of the magic string + 4 + HEADER_LEN be evenly divisible by 16 for alignment purposes.
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100356 The dictionary contains three keys:
357
358 "descr" : dtype.descr
359 An object that can be passed as an argument to the numpy.dtype() constructor to create the array's dtype.
360 "fortran_order" : bool
361 Whether the array data is Fortran-contiguous or not. Since Fortran-contiguous arrays are a common form of non-C-contiguity, we allow them to be written directly to disk for efficiency.
362 "shape" : tuple of int
363 The shape of the array.
364 For repeatability and readability, this dictionary is formatted using pprint.pformat() so the keys are in alphabetic order.
365 */
366
367 // remove trailing newline
368 if (header.back() != '\n')
369 throw std::runtime_error("invalid header");
370 header.pop_back();
371
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000372 // parse the dictionary
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000373 std::vector <std::string> keys{"descr", "fortran_order", "shape"};
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000374 auto dict_map = npy::pyparse::parse_dict(header, keys);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100375
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000376 if (dict_map.size() == 0)
377 throw std::runtime_error("invalid dictionary in header");
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100378
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000379 std::string descr_s = dict_map["descr"];
380 std::string fortran_s = dict_map["fortran_order"];
381 std::string shape_s = dict_map["shape"];
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100382
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000383 std::string descr = npy::pyparse::parse_str(descr_s);
384 dtype_t dtype = parse_descr(descr);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100385
386 // convert literal Python bool to C++ bool
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000387 bool fortran_order = npy::pyparse::parse_bool(fortran_s);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100388
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000389 // parse the shape tuple
390 auto shape_v = npy::pyparse::parse_tuple(shape_s);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100391
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000392 std::vector <ndarray_len_t> shape;
393 for (auto item : shape_v) {
394 ndarray_len_t dim = static_cast<ndarray_len_t>(std::stoul(item));
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000395 shape.push_back(dim);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100396 }
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000397
398 return {dtype, fortran_order, shape};
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100399}
400
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000401
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000402inline std::string
403write_header_dict(const std::string &descr, bool fortran_order, const std::vector <ndarray_len_t> &shape) {
404 std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order);
405 std::string shape_s = npy::pyparse::write_tuple(shape);
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000406
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000407 return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + ", 'shape': " + shape_s + ", }";
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000408}
409
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000410inline void write_header(std::ostream &out, const header_t &header) {
411 std::string header_dict = write_header_dict(header.dtype.str(), header.fortran_order, header.shape);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100412
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000413 size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100414
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000415 version_t version{1, 0};
416 if (length >= 255 * 255) {
417 length = magic_string_length + 2 + 4 + header_dict.length() + 1;
418 version = {2, 0};
419 }
420 size_t padding_len = 16 - length % 16;
421 std::string padding(padding_len, ' ');
422
423 // write magic
424 write_magic(out, version);
425
426 // write header length
427 if (version == version_t{1, 0}) {
428 uint8_t header_len_le16[2];
429 uint16_t header_len = static_cast<uint16_t>(header_dict.length() + padding.length() + 1);
430
431 header_len_le16[0] = (header_len >> 0) & 0xff;
432 header_len_le16[1] = (header_len >> 8) & 0xff;
433 out.write(reinterpret_cast<char *>(header_len_le16), 2);
434 } else {
435 uint8_t header_len_le32[4];
436 uint32_t header_len = static_cast<uint32_t>(header_dict.length() + padding.length() + 1);
437
438 header_len_le32[0] = (header_len >> 0) & 0xff;
439 header_len_le32[1] = (header_len >> 8) & 0xff;
440 header_len_le32[2] = (header_len >> 16) & 0xff;
441 header_len_le32[3] = (header_len >> 24) & 0xff;
442 out.write(reinterpret_cast<char *>(header_len_le32), 4);
443 }
444
445 out << header_dict << padding << '\n';
446}
447
448inline std::string read_header(std::istream &istream) {
449 // check magic bytes an version number
450 version_t version = read_magic(istream);
451
452 uint32_t header_length;
453 if (version == version_t{1, 0}) {
454 uint8_t header_len_le16[2];
455 istream.read(reinterpret_cast<char *>(header_len_le16), 2);
456 header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8);
457
458 if ((magic_string_length + 2 + 2 + header_length) % 16 != 0) {
459 // TODO(llohse): display warning
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100460 }
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000461 } else if (version == version_t{2, 0}) {
462 uint8_t header_len_le32[4];
463 istream.read(reinterpret_cast<char *>(header_len_le32), 4);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100464
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000465 header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8)
466 | (header_len_le32[2] << 16) | (header_len_le32[3] << 24);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100467
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000468 if ((magic_string_length + 2 + 4 + header_length) % 16 != 0) {
469 // TODO(llohse): display warning
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100470 }
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000471 } else {
472 throw std::runtime_error("unsupported file format version");
473 }
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100474
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000475 auto buf_v = std::vector<char>(header_length);
476 istream.read(buf_v.data(), header_length);
477 std::string header(buf_v.data(), header_length);
478
479 return header;
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100480}
481
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000482inline ndarray_len_t comp_size(const std::vector <ndarray_len_t> &shape) {
483 ndarray_len_t size = 1;
484 for (ndarray_len_t i : shape)
485 size *= i;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100486
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000487 return size;
Anthony Barbier3c5b4ff2017-10-12 13:20:52 +0100488}
489
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100490template<typename Scalar>
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000491inline void
492SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[],
493 const Scalar* data) {
494// static_assert(has_typestring<Scalar>::value, "scalar type not understood");
495 const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100496
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000497 std::ofstream stream(filename, std::ofstream::binary);
498 if (!stream) {
499 throw std::runtime_error("io error: failed to open a file.");
500 }
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100501
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000502 std::vector <ndarray_len_t> shape_v(shape, shape + n_dims);
503 header_t header{dtype, fortran_order, shape_v};
504 write_header(stream, header);
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000505
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000506 auto size = static_cast<size_t>(comp_size(shape_v));
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000507
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000508 stream.write(reinterpret_cast<const char *>(data), sizeof(Scalar) * size);
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100509}
510
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100511template<typename Scalar>
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000512inline void
513SaveArrayAsNumpy(const std::string &filename, bool fortran_order, unsigned int n_dims, const unsigned long shape[],
514 const std::vector <Scalar> &data) {
515 SaveArrayAsNumpy(filename, fortran_order, n_dims, shape, data.data());
Anthony Barbiera3adb3a2017-09-13 16:03:39 +0100516}
517
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000518template<typename Scalar>
519inline void
520LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, std::vector <Scalar> &data) {
521 bool fortran_order;
522 LoadArrayFromNumpy<Scalar>(filename, shape, fortran_order, data);
523}
Anthony Barbier87f21cd2017-11-10 16:27:32 +0000524
Jakub Sujak3b504ef2022-12-07 23:55:22 +0000525template<typename Scalar>
526inline void LoadArrayFromNumpy(const std::string &filename, std::vector<unsigned long> &shape, bool &fortran_order,
527 std::vector <Scalar> &data) {
528 std::ifstream stream(filename, std::ifstream::binary);
529 if (!stream) {
530 throw std::runtime_error("io error: failed to open a file.");
531 }
532
533 std::string header_s = read_header(stream);
534
535 // parse header
536 header_t header = parse_header(header_s);
537
538 // check if the typestring matches the given one
539// static_assert(has_typestring<Scalar>::value, "scalar type not understood");
540 const dtype_t dtype = dtype_map.at(std::type_index(typeid(Scalar)));
541
542 if (header.dtype.tie() != dtype.tie()) {
543 throw std::runtime_error("formatting error: typestrings not matching");
544 }
545
546 shape = header.shape;
547 fortran_order = header.fortran_order;
548
549 // compute the data size based on the shape
550 auto size = static_cast<size_t>(comp_size(shape));
551 data.resize(size);
552
553 // read the data
554 stream.read(reinterpret_cast<char *>(data.data()), sizeof(Scalar) * size);
555}
556
557} // namespace npy
558
559#endif // NPY_HPP_