blob: 65d76e388044c624c58e6f86a595e6362eb52cb9 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
James Ward485a11d2022-08-05 13:48:37 +010017#include "half.hpp"
Eric Kunze2364dcd2021-04-26 11:06:57 -070018
// Fixed prefix expected at the start of every supported file: the "\x93NUMPY"
// magic string, format version 1.0, a little-endian header length of 0x0076
// (118) bytes, and the opening '{' of the header dictionary.
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Total size in bytes of the header region that is read and written as a unit.
static constexpr int NUMPY_HEADER_SZ = 128;
// Upper bound on the number of shape dimensions parsed from a header.
static constexpr int NUMPY_MAX_DIMS_SUPPORTED = 10;
Eric Kunze2364dcd2021-04-26 11:06:57 -070024
25NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
26{
27 const char dtype_str[] = "'|b1'";
28 return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
29}
30
Jerry Gec93cb4a2023-05-18 21:16:22 +000031NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, uint8_t* databuf)
32{
33 const char dtype_str[] = "'|u1'";
34 return readFromNpyFileCommon(filename, dtype_str, sizeof(uint8_t), elems, databuf, false);
35}
36
37NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int8_t* databuf)
38{
39 const char dtype_str[] = "'|i1'";
40 return readFromNpyFileCommon(filename, dtype_str, sizeof(int8_t), elems, databuf, false);
41}
42
43NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, uint16_t* databuf)
44{
45 const char dtype_str[] = "'<u2'";
46 return readFromNpyFileCommon(filename, dtype_str, sizeof(uint16_t), elems, databuf, false);
47}
48
49NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int16_t* databuf)
50{
51 const char dtype_str[] = "'<i2'";
52 return readFromNpyFileCommon(filename, dtype_str, sizeof(int16_t), elems, databuf, false);
53}
54
Eric Kunze2364dcd2021-04-26 11:06:57 -070055NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
56{
57 const char dtype_str[] = "'<i4'";
58 return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false);
59}
60
61NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
62{
63 const char dtype_str[] = "'<i8'";
64 return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
65}
66
67NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
68{
69 const char dtype_str[] = "'<f4'";
70 return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
71}
72
Tai Ly3ef34fb2023-04-04 20:34:05 +000073NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf)
74{
75 const char dtype_str[] = "'<f8'";
76 return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false);
77}
78
James Ward485a11d2022-08-05 13:48:37 +010079NumpyUtilities::NPError
80 NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
81{
82 const char dtype_str[] = "'<f2'";
83 return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false);
84}
85
Eric Kunze2364dcd2021-04-26 11:06:57 -070086NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
87 const char* dtype_str,
88 const size_t elementsize,
89 const uint32_t elems,
90 void* databuf,
91 bool bool_translate)
92{
93 FILE* infile = nullptr;
94 NPError rc = NO_ERROR;
95
96 assert(filename);
97 assert(databuf);
98
99 infile = fopen(filename, "rb");
100 if (!infile)
101 {
102 return FILE_NOT_FOUND;
103 }
104
105 rc = checkNpyHeader(infile, elems, dtype_str);
106 if (rc == NO_ERROR)
107 {
108 if (bool_translate)
109 {
110 // Read in the data from numpy byte array to native bool
111 // array format
112 bool* buf = reinterpret_cast<bool*>(databuf);
113 for (uint32_t i = 0; i < elems; i++)
114 {
115 int val = fgetc(infile);
116
117 if (val == EOF)
118 {
119 rc = FILE_IO_ERROR;
120 }
121
122 buf[i] = val;
123 }
124 }
125 else
126 {
127 // Now we are at the beginning of the data
128 // Parse based on the datatype and number of dimensions
129 if (fread(databuf, elementsize, elems, infile) != elems)
130 {
131 rc = FILE_IO_ERROR;
132 }
133 }
134 }
135
136 if (infile)
137 fclose(infile);
138
139 return rc;
140}
141
142NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
143{
144 char buf[NUMPY_HEADER_SZ + 1];
145 char* ptr = nullptr;
146 NPError rc = NO_ERROR;
147 bool foundFormat = false;
148 bool foundOrder = false;
149 bool foundShape = false;
150 bool fortranOrder = false;
151 std::vector<int> shape;
152 uint32_t totalElems = 1;
153 char* outer_end = NULL;
154
155 assert(infile);
156 assert(elems > 0);
157
158 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
159 {
160 return HEADER_PARSE_ERROR;
161 }
162
163 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
164 {
165 return HEADER_PARSE_ERROR;
166 }
167
168 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
169
170 // Read in the data type, order, and shape
171 while (ptr && (!foundFormat || !foundOrder || !foundShape))
172 {
173
174 // End of string?
175 if (!ptr)
176 break;
177
178 // Skip whitespace
179 while (isspace(*ptr))
180 ptr++;
181
182 // Parse the dictionary field name
183 if (!strcmp(ptr, "'descr'"))
184 {
185 ptr = strtok_r(NULL, ",", &outer_end);
186 if (!ptr)
187 break;
188
189 while (isspace(*ptr))
190 ptr++;
191
192 if (strcmp(ptr, dtype_str))
193 {
194 return FILE_TYPE_MISMATCH;
195 }
196
197 foundFormat = true;
198 }
199 else if (!strcmp(ptr, "'fortran_order'"))
200 {
201 ptr = strtok_r(NULL, ",", &outer_end);
202 if (!ptr)
203 break;
204
205 while (isspace(*ptr))
206 ptr++;
207
208 if (!strcmp(ptr, "False"))
209 {
210 fortranOrder = false;
211 }
212 else
213 {
214 return FILE_TYPE_MISMATCH;
215 }
216
217 foundOrder = true;
218 }
219 else if (!strcmp(ptr, "'shape'"))
220 {
221
222 ptr = strtok_r(NULL, "(", &outer_end);
223 if (!ptr)
224 break;
225 ptr = strtok_r(NULL, ")", &outer_end);
226 if (!ptr)
227 break;
228
229 while (isspace(*ptr))
230 ptr++;
231
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100232 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700233 char* end = NULL;
234
235 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100236 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700237 {
238 // Out of dimensions
239 if (!ptr)
240 break;
241
242 int dim = atoi(ptr);
243
244 // Dimension is 0
245 if (dim == 0)
246 break;
247
248 shape.push_back(dim);
249 totalElems *= dim;
250 ptr = strtok_r(NULL, ",", &end);
251 }
252
253 foundShape = true;
254 }
255 else
256 {
257 return HEADER_PARSE_ERROR;
258 }
259
260 if (!ptr)
261 break;
262
263 ptr = strtok_r(NULL, ":", &outer_end);
264 }
265
266 if (!foundShape || !foundFormat || !foundOrder)
267 {
268 return HEADER_PARSE_ERROR;
269 }
270
271 // Validate header
272 if (fortranOrder)
273 {
274 return FILE_TYPE_MISMATCH;
275 }
276
277 if (totalElems != elems)
278 {
279 return BUFFER_SIZE_MISMATCH;
280 }
281
282 // Go back to the begininng and read until the end of the header dictionary
283 rewind(infile);
284 int val;
285
286 do
287 {
288 val = fgetc(infile);
289 } while (val != EOF && val != '\n');
290
291 return rc;
292}
293
294NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
295{
296 std::vector<int32_t> shape = { (int32_t)elems };
297 return writeToNpyFile(filename, shape, databuf);
298}
299
300NumpyUtilities::NPError
301 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
302{
303 const char dtype_str[] = "'|b1'";
304 return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
305}
306
307NumpyUtilities::NPError
Jerry Gec93cb4a2023-05-18 21:16:22 +0000308 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf)
309{
310 std::vector<int32_t> shape = { (int32_t)elems };
311 return writeToNpyFile(filename, shape, databuf);
312}
313
314NumpyUtilities::NPError
315 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf)
316{
317 const char dtype_str[] = "'|u1'";
318 return writeToNpyFileCommon(filename, dtype_str, sizeof(uint8_t), shape, databuf, false);
319}
320
321NumpyUtilities::NPError
322 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf)
323{
324 std::vector<int32_t> shape = { (int32_t)elems };
325 return writeToNpyFile(filename, shape, databuf);
326}
327
328NumpyUtilities::NPError
329 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf)
330{
331 const char dtype_str[] = "'|i1'";
332 return writeToNpyFileCommon(filename, dtype_str, sizeof(int8_t), shape, databuf, false);
333}
334
335NumpyUtilities::NPError
336 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf)
337{
338 std::vector<int32_t> shape = { (int32_t)elems };
339 return writeToNpyFile(filename, shape, databuf);
340}
341
342NumpyUtilities::NPError
343 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf)
344{
345 const char dtype_str[] = "'<u2'";
346 return writeToNpyFileCommon(filename, dtype_str, sizeof(uint16_t), shape, databuf, false);
347}
348
349NumpyUtilities::NPError
350 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf)
351{
352 std::vector<int32_t> shape = { (int32_t)elems };
353 return writeToNpyFile(filename, shape, databuf);
354}
355
356NumpyUtilities::NPError
357 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf)
358{
359 const char dtype_str[] = "'<i2'";
360 return writeToNpyFileCommon(filename, dtype_str, sizeof(int16_t), shape, databuf, false);
361}
362
363NumpyUtilities::NPError
Eric Kunze2364dcd2021-04-26 11:06:57 -0700364 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
365{
366 std::vector<int32_t> shape = { (int32_t)elems };
367 return writeToNpyFile(filename, shape, databuf);
368}
369
370NumpyUtilities::NPError
371 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
372{
373 const char dtype_str[] = "'<i4'";
374 return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
375}
376
377NumpyUtilities::NPError
378 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
379{
380 std::vector<int32_t> shape = { (int32_t)elems };
381 return writeToNpyFile(filename, shape, databuf);
382}
383
384NumpyUtilities::NPError
385 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
386{
387 const char dtype_str[] = "'<i8'";
388 return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
389}
390
391NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
392{
393 std::vector<int32_t> shape = { (int32_t)elems };
394 return writeToNpyFile(filename, shape, databuf);
395}
396
397NumpyUtilities::NPError
398 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
399{
400 const char dtype_str[] = "'<f4'";
401 return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
402}
403
Tai Ly3ef34fb2023-04-04 20:34:05 +0000404NumpyUtilities::NPError
405 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf)
406{
407 std::vector<int32_t> shape = { (int32_t)elems };
408 return writeToNpyFile(filename, shape, databuf);
409}
410
411NumpyUtilities::NPError
412 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf)
413{
414 const char dtype_str[] = "'<f8'";
415 return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false);
416}
417
James Ward485a11d2022-08-05 13:48:37 +0100418NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
419 const std::vector<int32_t>& shape,
420 const half_float::half* databuf)
421{
422 const char dtype_str[] = "'<f2'";
423 return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false);
424}
425
Eric Kunze2364dcd2021-04-26 11:06:57 -0700426NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
427 const char* dtype_str,
428 const size_t elementsize,
429 const std::vector<int32_t>& shape,
430 const void* databuf,
431 bool bool_translate)
432{
433 FILE* outfile = nullptr;
434 NPError rc = NO_ERROR;
435 uint32_t totalElems = 1;
436
437 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700438 assert(databuf);
439
440 outfile = fopen(filename, "wb");
441
442 if (!outfile)
443 {
444 return FILE_NOT_FOUND;
445 }
446
447 for (uint32_t i = 0; i < shape.size(); i++)
448 {
449 totalElems *= shape[i];
450 }
451
452 rc = writeNpyHeader(outfile, shape, dtype_str);
453
454 if (rc == NO_ERROR)
455 {
456 if (bool_translate)
457 {
458 // Numpy save format stores booleans as a byte array
459 // with one byte per boolean. This somewhat inefficiently
460 // remaps from system bool[] to this format.
461 const bool* buf = reinterpret_cast<const bool*>(databuf);
462 for (uint32_t i = 0; i < totalElems; i++)
463 {
464 int val = buf[i] ? 1 : 0;
465 if (fputc(val, outfile) == EOF)
466 {
467 rc = FILE_IO_ERROR;
468 }
469 }
470 }
471 else
472 {
473 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
474 {
475 rc = FILE_IO_ERROR;
476 }
477 }
478 }
479
480 if (outfile)
481 fclose(outfile);
482
483 return rc;
484}
485
486NumpyUtilities::NPError
487 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
488{
489 NPError rc = NO_ERROR;
490 uint32_t i;
491 char header[NUMPY_HEADER_SZ + 1];
492 int headerPos = 0;
493
494 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700495
496 // Space-fill the header and end with a newline to start per numpy spec
497 memset(header, 0x20, NUMPY_HEADER_SZ);
498 header[NUMPY_HEADER_SZ - 1] = '\n';
499 header[NUMPY_HEADER_SZ] = 0;
500
501 // Write out the hard-coded header. We only support a 128-byte 1.0 header
502 // for now, which should be sufficient for simple tensor types of any
503 // reasonable rank.
504 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
505 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
506
507 // Output the format dictionary
508 // Hard-coded for I32 for now
509 headerPos +=
510 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
511 dtype_str, shape.empty() ? 1 : shape[0]);
512
513 // Remainder of shape array
514 for (i = 1; i < shape.size(); i++)
515 {
516 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
517 }
518
519 // Close off the dictionary
520 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
521
522 // snprintf leaves a NULL at the end. Replace with a space
523 header[headerPos] = 0x20;
524
525 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
526 {
527 rc = FILE_IO_ERROR;
528 }
529
530 return rc;
531}