blob: d31ec1c4d7c9014298d84af9a4116761dbc7d064 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
Jerry Ge13a32912023-07-03 16:36:41 +00002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunze2364dcd2021-04-26 11:06:57 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
James Ward485a11d2022-08-05 13:48:37 +010017#include "half.hpp"
Jerry Ge13a32912023-07-03 16:36:41 +000018#include <algorithm>
// Fixed preamble of every .npy header this file reads or writes:
//   "\x93NUMPY"  - magic string
//   "\x1\x0"     - format version 1.0
//   "\x76\x0"    - little-endian header-dictionary length (0x0076 = 118 bytes)
//   "{"          - opening brace of the header dictionary
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Total size of the fixed header in bytes: 10-byte preamble + 118-byte dictionary
static constexpr int NUMPY_HEADER_SZ = 128;
// Maximum number of shape dimensions parsed from a header
static constexpr int NUMPY_MAX_DIMS_SUPPORTED = 10;
// Length of "descr':'": once all whitespace is stripped from the header
// dictionary, the dtype string starts this many characters after "descr"
static constexpr int NUMPY_HEADER_DESC_OFFSET = 8;
Eric Kunze2364dcd2021-04-26 11:06:57 -070026
27NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
28{
29 const char dtype_str[] = "'|b1'";
30 return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
31}
32
33NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
34{
Jerry Ge13a32912023-07-03 16:36:41 +000035 const char dtype_str_uint8[] = "'|u1'";
36 const char dtype_str_int8[] = "'|i1'";
37 const char dtype_str_uint16[] = "'<u2'";
38 const char dtype_str_int16[] = "'<i2'";
39 const char dtype_str_int32[] = "'<i4'";
40
41 FILE* infile = nullptr;
42 NPError rc = HEADER_PARSE_ERROR;
43 assert(filename);
44 assert(databuf);
45
46 infile = fopen(filename, "rb");
47 if (!infile)
48 {
49 return FILE_NOT_FOUND;
50 }
51
52 bool is_signed = false;
53 int bit_length;
54 char byte_order;
55 rc = getHeader(infile, is_signed, bit_length, byte_order);
56 if (rc != NO_ERROR)
57 return rc;
58
59 switch (bit_length)
60 {
61 case 1: // 8-bit
62 if (is_signed)
63 {
64 // int8
65 int8_t* i8databuf = nullptr;
66 i8databuf = (int8_t*)calloc(sizeof(i8databuf), elems);
67
68 rc = readFromNpyFileCommon(filename, dtype_str_int8, sizeof(int8_t), elems, i8databuf, false);
69
70 for (unsigned i = 0; i < elems; ++i)
71 {
72 databuf[i] = (int32_t)i8databuf[i];
73 }
74 free(i8databuf);
75
76 return rc;
77 }
78 else
79 {
80 // uint8
81 uint8_t* ui8databuf = nullptr;
82 ui8databuf = (uint8_t*)calloc(sizeof(ui8databuf), elems);
83
84 rc = readFromNpyFileCommon(filename, dtype_str_uint8, sizeof(uint8_t), elems, ui8databuf, false);
85
86 for (unsigned i = 0; i < elems; ++i)
87 {
88 databuf[i] = (int32_t)ui8databuf[i];
89 }
90 free(ui8databuf);
91 }
92 break;
93 case 2: // 16-bit
94 if (is_signed)
95 {
96 // int16
97 int16_t* i16databuf = nullptr;
98 i16databuf = (int16_t*)calloc(sizeof(i16databuf), elems);
99
100 rc = readFromNpyFileCommon(filename, dtype_str_int16, sizeof(int16_t), elems, i16databuf, false);
101
102 for (unsigned i = 0; i < elems; ++i)
103 {
104 databuf[i] = (int32_t)i16databuf[i];
105 }
106 free(i16databuf);
107
108 return rc;
109 }
110 else
111 {
112 // uint16
113 uint16_t* ui16databuf = nullptr;
114 ui16databuf = (uint16_t*)calloc(sizeof(ui16databuf), elems);
115
116 rc = readFromNpyFileCommon(filename, dtype_str_uint16, sizeof(uint16_t), elems, ui16databuf, false);
117
118 for (unsigned i = 0; i < elems; ++i)
119 {
120 databuf[i] = (int32_t)ui16databuf[i];
121 }
122 free(ui16databuf);
123
124 return rc;
125 }
126 break;
127 case 4: // 32-bit
128 if (is_signed)
129 {
130 // int32
131 return readFromNpyFileCommon(filename, dtype_str_int32, sizeof(int32_t), elems, databuf, false);
132 }
133 else
134 {
135 // uint32, not supported
136 return DATA_TYPE_NOT_SUPPORTED;
137 }
138 break;
139 default:
140 return DATA_TYPE_NOT_SUPPORTED;
141 break;
142 }
143
144 return rc;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700145}
146
147NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
148{
149 const char dtype_str[] = "'<i8'";
150 return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
151}
152
153NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
154{
155 const char dtype_str[] = "'<f4'";
156 return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
157}
158
Tai Ly3ef34fb2023-04-04 20:34:05 +0000159NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf)
160{
161 const char dtype_str[] = "'<f8'";
162 return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false);
163}
164
James Ward485a11d2022-08-05 13:48:37 +0100165NumpyUtilities::NPError
166 NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
167{
168 const char dtype_str[] = "'<f2'";
169 return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false);
170}
171
Eric Kunze2364dcd2021-04-26 11:06:57 -0700172NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
173 const char* dtype_str,
174 const size_t elementsize,
175 const uint32_t elems,
176 void* databuf,
177 bool bool_translate)
178{
179 FILE* infile = nullptr;
180 NPError rc = NO_ERROR;
181
182 assert(filename);
183 assert(databuf);
184
185 infile = fopen(filename, "rb");
186 if (!infile)
187 {
188 return FILE_NOT_FOUND;
189 }
190
191 rc = checkNpyHeader(infile, elems, dtype_str);
192 if (rc == NO_ERROR)
193 {
194 if (bool_translate)
195 {
196 // Read in the data from numpy byte array to native bool
197 // array format
198 bool* buf = reinterpret_cast<bool*>(databuf);
199 for (uint32_t i = 0; i < elems; i++)
200 {
201 int val = fgetc(infile);
202
203 if (val == EOF)
204 {
205 rc = FILE_IO_ERROR;
206 }
207
208 buf[i] = val;
209 }
210 }
211 else
212 {
213 // Now we are at the beginning of the data
214 // Parse based on the datatype and number of dimensions
215 if (fread(databuf, elementsize, elems, infile) != elems)
216 {
217 rc = FILE_IO_ERROR;
218 }
219 }
220 }
221
222 if (infile)
223 fclose(infile);
224
225 return rc;
226}
227
Jerry Ge13a32912023-07-03 16:36:41 +0000228NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order)
229{
230 char buf[NUMPY_HEADER_SZ + 1];
231 NPError rc = NO_ERROR;
232 assert(infile);
233
234 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
235 {
236 return HEADER_PARSE_ERROR;
237 }
238 char* ptr;
239 ptr = buf + sizeof(NUMPY_HEADER_STR) - 1;
240
241 std::string dic_string(ptr);
242 auto descr_loc = dic_string.find("descr");
243
244 // Reference: https://en.cppreference.com/w/cpp/algorithm/remove
245 // remove all the white spaces for the following offset NUMPY_HEADER_DESC_OFFSET to work
246 dic_string.erase(
247 std::remove_if(dic_string.begin(), dic_string.end(), [](unsigned char x) { return std::isspace(x); }),
248 dic_string.end());
249 // The dic_string is constant: descr': ', add a offset of NUMPY_HEADER_DESC_OFFSET
250 // to the actual dtype string station
251 dic_string = dic_string.substr(descr_loc + NUMPY_HEADER_DESC_OFFSET, 3);
252
253 // Fill byte_order;
254 char byte_order_c[1];
255 strcpy(byte_order_c, dic_string.substr(0, 1).c_str());
256 byte_order = byte_order_c[0];
257
258 // Fill is_signed
259 char is_signed_c[1];
260 strcpy(is_signed_c, dic_string.substr(1, 1).c_str());
261 is_signed = is_signed_c[0] == 'u' ? false : true;
262
263 // Fill bit_length
264 char bit_length_c[1];
265 strcpy(bit_length_c, dic_string.substr(2, 1).c_str());
266 bit_length = (int)(bit_length_c[0] - '0');
267
268 rewind(infile);
269 return rc;
270}
271
Eric Kunze2364dcd2021-04-26 11:06:57 -0700272NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
273{
274 char buf[NUMPY_HEADER_SZ + 1];
275 char* ptr = nullptr;
276 NPError rc = NO_ERROR;
277 bool foundFormat = false;
278 bool foundOrder = false;
279 bool foundShape = false;
280 bool fortranOrder = false;
281 std::vector<int> shape;
282 uint32_t totalElems = 1;
283 char* outer_end = NULL;
284
285 assert(infile);
286 assert(elems > 0);
287
288 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
289 {
290 return HEADER_PARSE_ERROR;
291 }
292
293 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
294 {
295 return HEADER_PARSE_ERROR;
296 }
297
298 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
299
300 // Read in the data type, order, and shape
301 while (ptr && (!foundFormat || !foundOrder || !foundShape))
302 {
303
304 // End of string?
305 if (!ptr)
306 break;
307
308 // Skip whitespace
309 while (isspace(*ptr))
310 ptr++;
311
312 // Parse the dictionary field name
313 if (!strcmp(ptr, "'descr'"))
314 {
315 ptr = strtok_r(NULL, ",", &outer_end);
316 if (!ptr)
317 break;
318
319 while (isspace(*ptr))
320 ptr++;
321
322 if (strcmp(ptr, dtype_str))
323 {
324 return FILE_TYPE_MISMATCH;
325 }
326
327 foundFormat = true;
328 }
329 else if (!strcmp(ptr, "'fortran_order'"))
330 {
331 ptr = strtok_r(NULL, ",", &outer_end);
332 if (!ptr)
333 break;
334
335 while (isspace(*ptr))
336 ptr++;
337
338 if (!strcmp(ptr, "False"))
339 {
340 fortranOrder = false;
341 }
342 else
343 {
344 return FILE_TYPE_MISMATCH;
345 }
346
347 foundOrder = true;
348 }
349 else if (!strcmp(ptr, "'shape'"))
350 {
351
352 ptr = strtok_r(NULL, "(", &outer_end);
353 if (!ptr)
354 break;
355 ptr = strtok_r(NULL, ")", &outer_end);
356 if (!ptr)
357 break;
358
359 while (isspace(*ptr))
360 ptr++;
361
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100362 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700363 char* end = NULL;
364
365 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100366 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700367 {
368 // Out of dimensions
369 if (!ptr)
370 break;
371
372 int dim = atoi(ptr);
373
374 // Dimension is 0
375 if (dim == 0)
376 break;
377
378 shape.push_back(dim);
379 totalElems *= dim;
380 ptr = strtok_r(NULL, ",", &end);
381 }
382
383 foundShape = true;
384 }
385 else
386 {
387 return HEADER_PARSE_ERROR;
388 }
389
390 if (!ptr)
391 break;
392
393 ptr = strtok_r(NULL, ":", &outer_end);
394 }
395
396 if (!foundShape || !foundFormat || !foundOrder)
397 {
398 return HEADER_PARSE_ERROR;
399 }
400
401 // Validate header
402 if (fortranOrder)
403 {
404 return FILE_TYPE_MISMATCH;
405 }
406
407 if (totalElems != elems)
408 {
409 return BUFFER_SIZE_MISMATCH;
410 }
411
412 // Go back to the begininng and read until the end of the header dictionary
413 rewind(infile);
414 int val;
415
416 do
417 {
418 val = fgetc(infile);
419 } while (val != EOF && val != '\n');
420
421 return rc;
422}
423
424NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
425{
426 std::vector<int32_t> shape = { (int32_t)elems };
427 return writeToNpyFile(filename, shape, databuf);
428}
429
430NumpyUtilities::NPError
431 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
432{
433 const char dtype_str[] = "'|b1'";
434 return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
435}
436
437NumpyUtilities::NPError
Jerry Gec93cb4a2023-05-18 21:16:22 +0000438 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint8_t* databuf)
439{
440 std::vector<int32_t> shape = { (int32_t)elems };
441 return writeToNpyFile(filename, shape, databuf);
442}
443
444NumpyUtilities::NPError
445 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint8_t* databuf)
446{
447 const char dtype_str[] = "'|u1'";
448 return writeToNpyFileCommon(filename, dtype_str, sizeof(uint8_t), shape, databuf, false);
449}
450
451NumpyUtilities::NPError
452 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int8_t* databuf)
453{
454 std::vector<int32_t> shape = { (int32_t)elems };
455 return writeToNpyFile(filename, shape, databuf);
456}
457
458NumpyUtilities::NPError
459 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int8_t* databuf)
460{
461 const char dtype_str[] = "'|i1'";
462 return writeToNpyFileCommon(filename, dtype_str, sizeof(int8_t), shape, databuf, false);
463}
464
465NumpyUtilities::NPError
466 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const uint16_t* databuf)
467{
468 std::vector<int32_t> shape = { (int32_t)elems };
469 return writeToNpyFile(filename, shape, databuf);
470}
471
472NumpyUtilities::NPError
473 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const uint16_t* databuf)
474{
475 const char dtype_str[] = "'<u2'";
476 return writeToNpyFileCommon(filename, dtype_str, sizeof(uint16_t), shape, databuf, false);
477}
478
479NumpyUtilities::NPError
480 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int16_t* databuf)
481{
482 std::vector<int32_t> shape = { (int32_t)elems };
483 return writeToNpyFile(filename, shape, databuf);
484}
485
486NumpyUtilities::NPError
487 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int16_t* databuf)
488{
489 const char dtype_str[] = "'<i2'";
490 return writeToNpyFileCommon(filename, dtype_str, sizeof(int16_t), shape, databuf, false);
491}
492
493NumpyUtilities::NPError
Eric Kunze2364dcd2021-04-26 11:06:57 -0700494 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
495{
496 std::vector<int32_t> shape = { (int32_t)elems };
497 return writeToNpyFile(filename, shape, databuf);
498}
499
500NumpyUtilities::NPError
501 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
502{
503 const char dtype_str[] = "'<i4'";
504 return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
505}
506
507NumpyUtilities::NPError
508 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
509{
510 std::vector<int32_t> shape = { (int32_t)elems };
511 return writeToNpyFile(filename, shape, databuf);
512}
513
514NumpyUtilities::NPError
515 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
516{
517 const char dtype_str[] = "'<i8'";
518 return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
519}
520
521NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
522{
523 std::vector<int32_t> shape = { (int32_t)elems };
524 return writeToNpyFile(filename, shape, databuf);
525}
526
527NumpyUtilities::NPError
528 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
529{
530 const char dtype_str[] = "'<f4'";
531 return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
532}
533
Tai Ly3ef34fb2023-04-04 20:34:05 +0000534NumpyUtilities::NPError
535 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf)
536{
537 std::vector<int32_t> shape = { (int32_t)elems };
538 return writeToNpyFile(filename, shape, databuf);
539}
540
541NumpyUtilities::NPError
542 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf)
543{
544 const char dtype_str[] = "'<f8'";
545 return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false);
546}
547
James Ward485a11d2022-08-05 13:48:37 +0100548NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
549 const std::vector<int32_t>& shape,
550 const half_float::half* databuf)
551{
552 const char dtype_str[] = "'<f2'";
553 return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false);
554}
555
Eric Kunze2364dcd2021-04-26 11:06:57 -0700556NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
557 const char* dtype_str,
558 const size_t elementsize,
559 const std::vector<int32_t>& shape,
560 const void* databuf,
561 bool bool_translate)
562{
563 FILE* outfile = nullptr;
564 NPError rc = NO_ERROR;
565 uint32_t totalElems = 1;
566
567 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700568 assert(databuf);
569
570 outfile = fopen(filename, "wb");
571
572 if (!outfile)
573 {
574 return FILE_NOT_FOUND;
575 }
576
577 for (uint32_t i = 0; i < shape.size(); i++)
578 {
579 totalElems *= shape[i];
580 }
581
582 rc = writeNpyHeader(outfile, shape, dtype_str);
583
584 if (rc == NO_ERROR)
585 {
586 if (bool_translate)
587 {
588 // Numpy save format stores booleans as a byte array
589 // with one byte per boolean. This somewhat inefficiently
590 // remaps from system bool[] to this format.
591 const bool* buf = reinterpret_cast<const bool*>(databuf);
592 for (uint32_t i = 0; i < totalElems; i++)
593 {
594 int val = buf[i] ? 1 : 0;
595 if (fputc(val, outfile) == EOF)
596 {
597 rc = FILE_IO_ERROR;
598 }
599 }
600 }
601 else
602 {
603 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
604 {
605 rc = FILE_IO_ERROR;
606 }
607 }
608 }
609
610 if (outfile)
611 fclose(outfile);
612
613 return rc;
614}
615
616NumpyUtilities::NPError
617 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
618{
619 NPError rc = NO_ERROR;
620 uint32_t i;
621 char header[NUMPY_HEADER_SZ + 1];
622 int headerPos = 0;
623
624 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700625
626 // Space-fill the header and end with a newline to start per numpy spec
627 memset(header, 0x20, NUMPY_HEADER_SZ);
628 header[NUMPY_HEADER_SZ - 1] = '\n';
629 header[NUMPY_HEADER_SZ] = 0;
630
631 // Write out the hard-coded header. We only support a 128-byte 1.0 header
632 // for now, which should be sufficient for simple tensor types of any
633 // reasonable rank.
634 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
635 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
636
637 // Output the format dictionary
638 // Hard-coded for I32 for now
639 headerPos +=
640 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
641 dtype_str, shape.empty() ? 1 : shape[0]);
642
643 // Remainder of shape array
644 for (i = 1; i < shape.size(); i++)
645 {
646 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
647 }
648
649 // Close off the dictionary
650 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
651
652 // snprintf leaves a NULL at the end. Replace with a space
653 header[headerPos] = 0x20;
654
655 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
656 {
657 rc = FILE_IO_ERROR;
658 }
659
660 return rc;
661}