blob: 0002fd9c65c545440ead822713dcb5f15049b85c [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
Jerry Ge13a32912023-07-03 16:36:41 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunze2364dcd2021-04-26 11:06:57 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
James Ward485a11d2022-08-05 13:48:37 +010017#include "half.hpp"
Jerry Ge13a32912023-07-03 16:36:41 +000018#include <algorithm>
// Magic NUMPY header: "\x93NUMPY", format version 1.0 (bytes 0x01 0x00), a
// little-endian 16-bit header length of 0x0076 (118 bytes, so the 10-byte
// prefix plus the dictionary totals NUMPY_HEADER_SZ), then the '{' that opens
// the header dictionary.
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Fixed total size, in bytes, of every header read or written by this file.
static constexpr int NUMPY_HEADER_SZ = 128;
// Maximum number of shape dimensions parsed from a header.
static constexpr int NUMPY_MAX_DIMS_SUPPORTED = 10;
// Offset for NUMPY header desc dictionary string.
// NOTE(review): not referenced by any function in this file section.
static constexpr int NUMPY_HEADER_DESC_OFFSET = 8;
Eric Kunze2364dcd2021-04-26 11:06:57 -070026
27NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
28{
29 const char dtype_str[] = "'|b1'";
30 return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
31}
32
33NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
34{
Jerry Ge13a32912023-07-03 16:36:41 +000035 const char dtype_str_uint8[] = "'|u1'";
36 const char dtype_str_int8[] = "'|i1'";
37 const char dtype_str_uint16[] = "'<u2'";
38 const char dtype_str_int16[] = "'<i2'";
39 const char dtype_str_int32[] = "'<i4'";
40
41 FILE* infile = nullptr;
42 NPError rc = HEADER_PARSE_ERROR;
43 assert(filename);
44 assert(databuf);
45
46 infile = fopen(filename, "rb");
47 if (!infile)
48 {
49 return FILE_NOT_FOUND;
50 }
51
52 bool is_signed = false;
53 int bit_length;
54 char byte_order;
55 rc = getHeader(infile, is_signed, bit_length, byte_order);
56 if (rc != NO_ERROR)
57 return rc;
58
59 switch (bit_length)
60 {
61 case 1: // 8-bit
62 if (is_signed)
63 {
64 // int8
65 int8_t* i8databuf = nullptr;
66 i8databuf = (int8_t*)calloc(sizeof(i8databuf), elems);
67
68 rc = readFromNpyFileCommon(filename, dtype_str_int8, sizeof(int8_t), elems, i8databuf, false);
69
70 for (unsigned i = 0; i < elems; ++i)
71 {
72 databuf[i] = (int32_t)i8databuf[i];
73 }
74 free(i8databuf);
75
76 return rc;
77 }
78 else
79 {
80 // uint8
81 uint8_t* ui8databuf = nullptr;
82 ui8databuf = (uint8_t*)calloc(sizeof(ui8databuf), elems);
83
84 rc = readFromNpyFileCommon(filename, dtype_str_uint8, sizeof(uint8_t), elems, ui8databuf, false);
85
86 for (unsigned i = 0; i < elems; ++i)
87 {
88 databuf[i] = (int32_t)ui8databuf[i];
89 }
90 free(ui8databuf);
91 }
92 break;
93 case 2: // 16-bit
94 if (is_signed)
95 {
96 // int16
97 int16_t* i16databuf = nullptr;
98 i16databuf = (int16_t*)calloc(sizeof(i16databuf), elems);
99
100 rc = readFromNpyFileCommon(filename, dtype_str_int16, sizeof(int16_t), elems, i16databuf, false);
101
102 for (unsigned i = 0; i < elems; ++i)
103 {
104 databuf[i] = (int32_t)i16databuf[i];
105 }
106 free(i16databuf);
107
108 return rc;
109 }
110 else
111 {
112 // uint16
113 uint16_t* ui16databuf = nullptr;
114 ui16databuf = (uint16_t*)calloc(sizeof(ui16databuf), elems);
115
116 rc = readFromNpyFileCommon(filename, dtype_str_uint16, sizeof(uint16_t), elems, ui16databuf, false);
117
118 for (unsigned i = 0; i < elems; ++i)
119 {
120 databuf[i] = (int32_t)ui16databuf[i];
121 }
122 free(ui16databuf);
123
124 return rc;
125 }
126 break;
127 case 4: // 32-bit
128 if (is_signed)
129 {
130 // int32
131 return readFromNpyFileCommon(filename, dtype_str_int32, sizeof(int32_t), elems, databuf, false);
132 }
133 else
134 {
135 // uint32, not supported
136 return DATA_TYPE_NOT_SUPPORTED;
137 }
138 break;
139 default:
140 return DATA_TYPE_NOT_SUPPORTED;
141 break;
142 }
143
144 return rc;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700145}
146
147NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
148{
149 const char dtype_str[] = "'<i8'";
150 return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
151}
152
153NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
154{
155 const char dtype_str[] = "'<f4'";
156 return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
157}
158
Tai Ly3ef34fb2023-04-04 20:34:05 +0000159NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf)
160{
161 const char dtype_str[] = "'<f8'";
162 return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false);
163}
164
James Ward485a11d2022-08-05 13:48:37 +0100165NumpyUtilities::NPError
166 NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
167{
168 const char dtype_str[] = "'<f2'";
169 return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false);
170}
171
Eric Kunze2364dcd2021-04-26 11:06:57 -0700172NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
173 const char* dtype_str,
174 const size_t elementsize,
175 const uint32_t elems,
176 void* databuf,
177 bool bool_translate)
178{
179 FILE* infile = nullptr;
180 NPError rc = NO_ERROR;
181
182 assert(filename);
183 assert(databuf);
184
185 infile = fopen(filename, "rb");
186 if (!infile)
187 {
188 return FILE_NOT_FOUND;
189 }
190
191 rc = checkNpyHeader(infile, elems, dtype_str);
192 if (rc == NO_ERROR)
193 {
194 if (bool_translate)
195 {
196 // Read in the data from numpy byte array to native bool
197 // array format
198 bool* buf = reinterpret_cast<bool*>(databuf);
199 for (uint32_t i = 0; i < elems; i++)
200 {
201 int val = fgetc(infile);
202
203 if (val == EOF)
204 {
205 rc = FILE_IO_ERROR;
206 }
207
208 buf[i] = val;
209 }
210 }
211 else
212 {
213 // Now we are at the beginning of the data
214 // Parse based on the datatype and number of dimensions
215 if (fread(databuf, elementsize, elems, infile) != elems)
216 {
217 rc = FILE_IO_ERROR;
218 }
219 }
220 }
221
222 if (infile)
223 fclose(infile);
224
225 return rc;
226}
227
Jerry Ge13a32912023-07-03 16:36:41 +0000228NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order)
229{
230 char buf[NUMPY_HEADER_SZ + 1];
231 NPError rc = NO_ERROR;
232 assert(infile);
233
234 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
235 {
236 return HEADER_PARSE_ERROR;
237 }
238 char* ptr;
Jerry Ge13a32912023-07-03 16:36:41 +0000239
Jerry Ge0bc46e42023-07-07 21:52:26 +0000240 // Validate the numpy magic number
241 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
242 {
243 return HEADER_PARSE_ERROR;
244 }
Jerry Ge13a32912023-07-03 16:36:41 +0000245
Jerry Ge0bc46e42023-07-07 21:52:26 +0000246 std::string dic_string(buf, NUMPY_HEADER_SZ);
Jerry Ge13a32912023-07-03 16:36:41 +0000247
Jerry Ge0bc46e42023-07-07 21:52:26 +0000248 std::string desc_str("descr':");
249 size_t offset = dic_string.find(desc_str);
250 if (offset == std::string::npos)
251 return HEADER_PARSE_ERROR;
Jerry Ge13a32912023-07-03 16:36:41 +0000252
Jerry Ge0bc46e42023-07-07 21:52:26 +0000253 offset += desc_str.size() + 1;
254 // Skip whitespace and the opening '
255 while (offset < dic_string.size() && (std::isspace(dic_string[offset]) || dic_string[offset] == '\''))
256 offset++;
257 // Check for overflow
258 if (offset + 2 > dic_string.size())
259 return HEADER_PARSE_ERROR;
Jerry Ge13a32912023-07-03 16:36:41 +0000260
Jerry Ge0bc46e42023-07-07 21:52:26 +0000261 byte_order = dic_string[offset];
262 is_signed = dic_string[offset + 1] == 'u' ? false : true;
263 bit_length = (int)dic_string[offset + 2] - '0';
Jerry Ge13a32912023-07-03 16:36:41 +0000264
265 rewind(infile);
266 return rc;
267}
268
Eric Kunze2364dcd2021-04-26 11:06:57 -0700269NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
270{
271 char buf[NUMPY_HEADER_SZ + 1];
272 char* ptr = nullptr;
273 NPError rc = NO_ERROR;
274 bool foundFormat = false;
275 bool foundOrder = false;
276 bool foundShape = false;
277 bool fortranOrder = false;
278 std::vector<int> shape;
279 uint32_t totalElems = 1;
280 char* outer_end = NULL;
281
282 assert(infile);
283 assert(elems > 0);
284
285 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
286 {
287 return HEADER_PARSE_ERROR;
288 }
289
290 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
291 {
292 return HEADER_PARSE_ERROR;
293 }
294
295 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
296
297 // Read in the data type, order, and shape
298 while (ptr && (!foundFormat || !foundOrder || !foundShape))
299 {
300
301 // End of string?
302 if (!ptr)
303 break;
304
305 // Skip whitespace
306 while (isspace(*ptr))
307 ptr++;
308
309 // Parse the dictionary field name
310 if (!strcmp(ptr, "'descr'"))
311 {
312 ptr = strtok_r(NULL, ",", &outer_end);
313 if (!ptr)
314 break;
315
316 while (isspace(*ptr))
317 ptr++;
318
319 if (strcmp(ptr, dtype_str))
320 {
321 return FILE_TYPE_MISMATCH;
322 }
323
324 foundFormat = true;
325 }
326 else if (!strcmp(ptr, "'fortran_order'"))
327 {
328 ptr = strtok_r(NULL, ",", &outer_end);
329 if (!ptr)
330 break;
331
332 while (isspace(*ptr))
333 ptr++;
334
335 if (!strcmp(ptr, "False"))
336 {
337 fortranOrder = false;
338 }
339 else
340 {
341 return FILE_TYPE_MISMATCH;
342 }
343
344 foundOrder = true;
345 }
346 else if (!strcmp(ptr, "'shape'"))
347 {
348
349 ptr = strtok_r(NULL, "(", &outer_end);
350 if (!ptr)
351 break;
352 ptr = strtok_r(NULL, ")", &outer_end);
353 if (!ptr)
354 break;
355
356 while (isspace(*ptr))
357 ptr++;
358
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100359 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700360 char* end = NULL;
361
362 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100363 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700364 {
365 // Out of dimensions
366 if (!ptr)
367 break;
368
369 int dim = atoi(ptr);
370
371 // Dimension is 0
372 if (dim == 0)
373 break;
374
375 shape.push_back(dim);
376 totalElems *= dim;
377 ptr = strtok_r(NULL, ",", &end);
378 }
379
380 foundShape = true;
381 }
382 else
383 {
384 return HEADER_PARSE_ERROR;
385 }
386
387 if (!ptr)
388 break;
389
390 ptr = strtok_r(NULL, ":", &outer_end);
391 }
392
393 if (!foundShape || !foundFormat || !foundOrder)
394 {
395 return HEADER_PARSE_ERROR;
396 }
397
398 // Validate header
399 if (fortranOrder)
400 {
401 return FILE_TYPE_MISMATCH;
402 }
403
404 if (totalElems != elems)
405 {
406 return BUFFER_SIZE_MISMATCH;
407 }
408
409 // Go back to the begininng and read until the end of the header dictionary
410 rewind(infile);
411 int val;
412
413 do
414 {
415 val = fgetc(infile);
416 } while (val != EOF && val != '\n');
417
418 return rc;
419}
420
421NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
422{
423 std::vector<int32_t> shape = { (int32_t)elems };
424 return writeToNpyFile(filename, shape, databuf);
425}
426
427NumpyUtilities::NPError
428 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
429{
430 const char dtype_str[] = "'|b1'";
431 return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
432}
433
434NumpyUtilities::NPError
Jerry Gec93cb4a2023-05-18 21:16:22 +0000435 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf)
436{
437 std::vector<int32_t> shape = { (int32_t)elems };
438 return writeToNpyFile(filename, shape, databuf);
439}
440
441NumpyUtilities::NPError
442 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf)
443{
444 const char dtype_str[] = "'|u1'";
445 return writeToNpyFileCommon(filename, dtype_str, sizeof(uint8_t), shape, databuf, false);
446}
447
448NumpyUtilities::NPError
449 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf)
450{
451 std::vector<int32_t> shape = { (int32_t)elems };
452 return writeToNpyFile(filename, shape, databuf);
453}
454
455NumpyUtilities::NPError
456 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf)
457{
458 const char dtype_str[] = "'|i1'";
459 return writeToNpyFileCommon(filename, dtype_str, sizeof(int8_t), shape, databuf, false);
460}
461
462NumpyUtilities::NPError
463 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf)
464{
465 std::vector<int32_t> shape = { (int32_t)elems };
466 return writeToNpyFile(filename, shape, databuf);
467}
468
469NumpyUtilities::NPError
470 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf)
471{
472 const char dtype_str[] = "'<u2'";
473 return writeToNpyFileCommon(filename, dtype_str, sizeof(uint16_t), shape, databuf, false);
474}
475
476NumpyUtilities::NPError
477 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf)
478{
479 std::vector<int32_t> shape = { (int32_t)elems };
480 return writeToNpyFile(filename, shape, databuf);
481}
482
483NumpyUtilities::NPError
484 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf)
485{
486 const char dtype_str[] = "'<i2'";
487 return writeToNpyFileCommon(filename, dtype_str, sizeof(int16_t), shape, databuf, false);
488}
489
490NumpyUtilities::NPError
Eric Kunze2364dcd2021-04-26 11:06:57 -0700491 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
492{
493 std::vector<int32_t> shape = { (int32_t)elems };
494 return writeToNpyFile(filename, shape, databuf);
495}
496
497NumpyUtilities::NPError
498 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
499{
500 const char dtype_str[] = "'<i4'";
501 return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
502}
503
504NumpyUtilities::NPError
505 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
506{
507 std::vector<int32_t> shape = { (int32_t)elems };
508 return writeToNpyFile(filename, shape, databuf);
509}
510
511NumpyUtilities::NPError
512 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
513{
514 const char dtype_str[] = "'<i8'";
515 return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
516}
517
518NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
519{
520 std::vector<int32_t> shape = { (int32_t)elems };
521 return writeToNpyFile(filename, shape, databuf);
522}
523
524NumpyUtilities::NPError
525 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
526{
527 const char dtype_str[] = "'<f4'";
528 return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
529}
530
Tai Ly3ef34fb2023-04-04 20:34:05 +0000531NumpyUtilities::NPError
532 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf)
533{
534 std::vector<int32_t> shape = { (int32_t)elems };
535 return writeToNpyFile(filename, shape, databuf);
536}
537
538NumpyUtilities::NPError
539 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf)
540{
541 const char dtype_str[] = "'<f8'";
542 return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false);
543}
544
James Ward485a11d2022-08-05 13:48:37 +0100545NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
546 const std::vector<int32_t>& shape,
547 const half_float::half* databuf)
548{
549 const char dtype_str[] = "'<f2'";
550 return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false);
551}
552
Eric Kunze2364dcd2021-04-26 11:06:57 -0700553NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
554 const char* dtype_str,
555 const size_t elementsize,
556 const std::vector<int32_t>& shape,
557 const void* databuf,
558 bool bool_translate)
559{
560 FILE* outfile = nullptr;
561 NPError rc = NO_ERROR;
562 uint32_t totalElems = 1;
563
564 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700565 assert(databuf);
566
567 outfile = fopen(filename, "wb");
568
569 if (!outfile)
570 {
571 return FILE_NOT_FOUND;
572 }
573
574 for (uint32_t i = 0; i < shape.size(); i++)
575 {
576 totalElems *= shape[i];
577 }
578
579 rc = writeNpyHeader(outfile, shape, dtype_str);
580
581 if (rc == NO_ERROR)
582 {
583 if (bool_translate)
584 {
585 // Numpy save format stores booleans as a byte array
586 // with one byte per boolean. This somewhat inefficiently
587 // remaps from system bool[] to this format.
588 const bool* buf = reinterpret_cast<const bool*>(databuf);
589 for (uint32_t i = 0; i < totalElems; i++)
590 {
591 int val = buf[i] ? 1 : 0;
592 if (fputc(val, outfile) == EOF)
593 {
594 rc = FILE_IO_ERROR;
595 }
596 }
597 }
598 else
599 {
600 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
601 {
602 rc = FILE_IO_ERROR;
603 }
604 }
605 }
606
607 if (outfile)
608 fclose(outfile);
609
610 return rc;
611}
612
613NumpyUtilities::NPError
614 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
615{
616 NPError rc = NO_ERROR;
617 uint32_t i;
618 char header[NUMPY_HEADER_SZ + 1];
619 int headerPos = 0;
620
621 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700622
623 // Space-fill the header and end with a newline to start per numpy spec
624 memset(header, 0x20, NUMPY_HEADER_SZ);
625 header[NUMPY_HEADER_SZ - 1] = '\n';
626 header[NUMPY_HEADER_SZ] = 0;
627
628 // Write out the hard-coded header. We only support a 128-byte 1.0 header
629 // for now, which should be sufficient for simple tensor types of any
630 // reasonable rank.
631 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
632 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
633
634 // Output the format dictionary
635 // Hard-coded for I32 for now
636 headerPos +=
637 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
638 dtype_str, shape.empty() ? 1 : shape[0]);
639
640 // Remainder of shape array
641 for (i = 1; i < shape.size(); i++)
642 {
643 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
644 }
645
646 // Close off the dictionary
647 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
648
649 // snprintf leaves a NULL at the end. Replace with a space
650 header[headerPos] = 0x20;
651
652 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
653 {
654 rc = FILE_IO_ERROR;
655 }
656
657 return rc;
658}