blob: 7cf5f94ecf6b0a406054a3b2c4be44c71241ea07 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
Eric Kunzecc426df2024-01-03 00:27:59 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunze2364dcd2021-04-26 11:06:57 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
James Ward485a11d2022-08-05 13:48:37 +010017#include "half.hpp"
Jerry Ge13a32912023-07-03 16:36:41 +000018#include <algorithm>
Eric Kunzecc426df2024-01-03 00:27:59 +000019#include <memory>
TatWai Chong679bdad2023-07-31 15:15:12 -070020
Eric Kunze2364dcd2021-04-26 11:06:57 -070021// Magic NUMPY header
// Layout of the literal: the bytes 0x93 'N' 'U' 'M' 'P' 'Y', then 0x01 0x00
// (the "1.0" header version this file writes), then 0x76 0x00, then the '{'
// that opens the header dictionary. 0x76 is 118, which together with the 10
// bytes that precede the dictionary adds up to NUMPY_HEADER_SZ.
// The literal contains embedded NUL bytes, so its length must be taken with
// sizeof(NUMPY_HEADER_STR) - 1 and never with strlen().
22static const char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Total size in bytes of the fixed header block that is read and written.
23static const int NUMPY_HEADER_SZ = 128;
Jeremy Johnson82dbb322021-07-08 11:53:04 +010024// Maximum shape dimensions supported
// Upper bound on the number of shape entries parsed by checkNpyHeader.
25static const int NUMPY_MAX_DIMS_SUPPORTED = 10;
Eric Kunze2364dcd2021-04-26 11:06:57 -070026
TatWai Chong679bdad2023-07-31 15:15:12 -070027// This is an entry function for reading 8-/16-/32-bit npy file.
28template <>
Eric Kunze2364dcd2021-04-26 11:06:57 -070029NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
30{
Jerry Ge13a32912023-07-03 16:36:41 +000031 FILE* infile = nullptr;
32 NPError rc = HEADER_PARSE_ERROR;
33 assert(filename);
34 assert(databuf);
35
36 infile = fopen(filename, "rb");
37 if (!infile)
38 {
39 return FILE_NOT_FOUND;
40 }
41
TatWai Chong679bdad2023-07-31 15:15:12 -070042 bool is_signed = false;
43 int length_per_byte = 0;
Jerry Ge13a32912023-07-03 16:36:41 +000044 char byte_order;
TatWai Chong679bdad2023-07-31 15:15:12 -070045 rc = getHeader(infile, is_signed, length_per_byte, byte_order);
Jerry Ge13a32912023-07-03 16:36:41 +000046 if (rc != NO_ERROR)
47 return rc;
48
TatWai Chong679bdad2023-07-31 15:15:12 -070049 switch (length_per_byte)
Jerry Ge13a32912023-07-03 16:36:41 +000050 {
TatWai Chong679bdad2023-07-31 15:15:12 -070051 case 1:
Jerry Ge13a32912023-07-03 16:36:41 +000052 if (is_signed)
53 {
TatWai Chong679bdad2023-07-31 15:15:12 -070054 int8_t* tmp_buf = new int8_t[elems];
55 rc = readFromNpyFile<int8_t>(filename, elems, tmp_buf);
56 copyBufferByElement(databuf, tmp_buf, elems);
Eric Kunzecc426df2024-01-03 00:27:59 +000057 delete[] tmp_buf;
Jerry Ge13a32912023-07-03 16:36:41 +000058 }
59 else
60 {
TatWai Chong679bdad2023-07-31 15:15:12 -070061 uint8_t* tmp_buf = new uint8_t[elems];
62 rc = readFromNpyFile<uint8_t>(filename, elems, tmp_buf);
63 copyBufferByElement(databuf, tmp_buf, elems);
Eric Kunzecc426df2024-01-03 00:27:59 +000064 delete[] tmp_buf;
Jerry Ge13a32912023-07-03 16:36:41 +000065 }
66 break;
TatWai Chong679bdad2023-07-31 15:15:12 -070067 case 2:
Jerry Ge13a32912023-07-03 16:36:41 +000068 if (is_signed)
69 {
TatWai Chong679bdad2023-07-31 15:15:12 -070070 int16_t* tmp_buf = new int16_t[elems];
71 rc = readFromNpyFile<int16_t>(filename, elems, tmp_buf);
72 copyBufferByElement(databuf, tmp_buf, elems);
Eric Kunzecc426df2024-01-03 00:27:59 +000073 delete[] tmp_buf;
Jerry Ge13a32912023-07-03 16:36:41 +000074 }
75 else
76 {
TatWai Chong679bdad2023-07-31 15:15:12 -070077 uint16_t* tmp_buf = new uint16_t[elems];
78 rc = readFromNpyFile<uint16_t>(filename, elems, tmp_buf);
79 copyBufferByElement(databuf, tmp_buf, elems);
Eric Kunzecc426df2024-01-03 00:27:59 +000080 delete[] tmp_buf;
Jerry Ge13a32912023-07-03 16:36:41 +000081 }
82 break;
TatWai Chong679bdad2023-07-31 15:15:12 -070083 case 4:
Jerry Ge13a32912023-07-03 16:36:41 +000084 if (is_signed)
85 {
TatWai Chong679bdad2023-07-31 15:15:12 -070086 bool is_bool;
87 const char* dtype_str = getDTypeString<int32_t>(is_bool);
88 rc = readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, is_bool);
Jerry Ge13a32912023-07-03 16:36:41 +000089 }
90 else
91 {
92 // uint32, not supported
TatWai Chong679bdad2023-07-31 15:15:12 -070093 rc = DATA_TYPE_NOT_SUPPORTED;
Jerry Ge13a32912023-07-03 16:36:41 +000094 }
95 break;
96 default:
97 return DATA_TYPE_NOT_SUPPORTED;
98 break;
99 }
100
101 return rc;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700102}
103
Eric Kunze2364dcd2021-04-26 11:06:57 -0700104NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
105 const char* dtype_str,
106 const size_t elementsize,
107 const uint32_t elems,
108 void* databuf,
109 bool bool_translate)
110{
111 FILE* infile = nullptr;
112 NPError rc = NO_ERROR;
113
114 assert(filename);
115 assert(databuf);
116
117 infile = fopen(filename, "rb");
118 if (!infile)
119 {
120 return FILE_NOT_FOUND;
121 }
122
123 rc = checkNpyHeader(infile, elems, dtype_str);
124 if (rc == NO_ERROR)
125 {
126 if (bool_translate)
127 {
128 // Read in the data from numpy byte array to native bool
129 // array format
130 bool* buf = reinterpret_cast<bool*>(databuf);
131 for (uint32_t i = 0; i < elems; i++)
132 {
133 int val = fgetc(infile);
134
135 if (val == EOF)
136 {
137 rc = FILE_IO_ERROR;
138 }
139
140 buf[i] = val;
141 }
142 }
143 else
144 {
145 // Now we are at the beginning of the data
146 // Parse based on the datatype and number of dimensions
147 if (fread(databuf, elementsize, elems, infile) != elems)
148 {
149 rc = FILE_IO_ERROR;
150 }
151 }
152 }
153
154 if (infile)
155 fclose(infile);
156
157 return rc;
158}
159
Jerry Ge13a32912023-07-03 16:36:41 +0000160NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order)
161{
162 char buf[NUMPY_HEADER_SZ + 1];
163 NPError rc = NO_ERROR;
164 assert(infile);
165
166 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
167 {
168 return HEADER_PARSE_ERROR;
169 }
Jerry Ge13a32912023-07-03 16:36:41 +0000170
Jerry Ge0bc46e42023-07-07 21:52:26 +0000171 // Validate the numpy magic number
172 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
173 {
174 return HEADER_PARSE_ERROR;
175 }
Jerry Ge13a32912023-07-03 16:36:41 +0000176
Jerry Ge0bc46e42023-07-07 21:52:26 +0000177 std::string dic_string(buf, NUMPY_HEADER_SZ);
Jerry Ge13a32912023-07-03 16:36:41 +0000178
Jerry Ge0bc46e42023-07-07 21:52:26 +0000179 std::string desc_str("descr':");
180 size_t offset = dic_string.find(desc_str);
181 if (offset == std::string::npos)
182 return HEADER_PARSE_ERROR;
Jerry Ge13a32912023-07-03 16:36:41 +0000183
Jerry Ge0bc46e42023-07-07 21:52:26 +0000184 offset += desc_str.size() + 1;
185 // Skip whitespace and the opening '
186 while (offset < dic_string.size() && (std::isspace(dic_string[offset]) || dic_string[offset] == '\''))
187 offset++;
188 // Check for overflow
189 if (offset + 2 > dic_string.size())
190 return HEADER_PARSE_ERROR;
Jerry Ge13a32912023-07-03 16:36:41 +0000191
Jerry Ge0bc46e42023-07-07 21:52:26 +0000192 byte_order = dic_string[offset];
193 is_signed = dic_string[offset + 1] == 'u' ? false : true;
194 bit_length = (int)dic_string[offset + 2] - '0';
Jerry Ge13a32912023-07-03 16:36:41 +0000195
196 rewind(infile);
197 return rc;
198}
199
Eric Kunze2364dcd2021-04-26 11:06:57 -0700200NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
201{
202 char buf[NUMPY_HEADER_SZ + 1];
203 char* ptr = nullptr;
204 NPError rc = NO_ERROR;
205 bool foundFormat = false;
206 bool foundOrder = false;
207 bool foundShape = false;
208 bool fortranOrder = false;
209 std::vector<int> shape;
210 uint32_t totalElems = 1;
211 char* outer_end = NULL;
212
213 assert(infile);
214 assert(elems > 0);
215
216 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
217 {
218 return HEADER_PARSE_ERROR;
219 }
220
221 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
222 {
223 return HEADER_PARSE_ERROR;
224 }
225
226 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
227
228 // Read in the data type, order, and shape
229 while (ptr && (!foundFormat || !foundOrder || !foundShape))
230 {
231
232 // End of string?
233 if (!ptr)
234 break;
235
236 // Skip whitespace
237 while (isspace(*ptr))
238 ptr++;
239
240 // Parse the dictionary field name
241 if (!strcmp(ptr, "'descr'"))
242 {
243 ptr = strtok_r(NULL, ",", &outer_end);
244 if (!ptr)
245 break;
246
247 while (isspace(*ptr))
248 ptr++;
249
Won Jeona8141522024-04-29 23:57:27 +0000250 // ml_dtypes writes '<f1' for 'numpy.dtype' in the header for float8_e5m2, but
251 // default NumPy does not understand this notation, which causes trouble
252 // when other code tries to open this file.
253 // To avoid this, '|u1' notation is used when the file is written, and the uint8
254 // data is viewed as float8_e5m2 later when the file is read.
255 if (!strcmp(dtype_str, "'<f1'"))
256 dtype_str = "'|u1'";
257
Eric Kunze2364dcd2021-04-26 11:06:57 -0700258 if (strcmp(ptr, dtype_str))
259 {
260 return FILE_TYPE_MISMATCH;
261 }
262
263 foundFormat = true;
264 }
265 else if (!strcmp(ptr, "'fortran_order'"))
266 {
267 ptr = strtok_r(NULL, ",", &outer_end);
268 if (!ptr)
269 break;
270
271 while (isspace(*ptr))
272 ptr++;
273
274 if (!strcmp(ptr, "False"))
275 {
276 fortranOrder = false;
277 }
278 else
279 {
280 return FILE_TYPE_MISMATCH;
281 }
282
283 foundOrder = true;
284 }
285 else if (!strcmp(ptr, "'shape'"))
286 {
287
288 ptr = strtok_r(NULL, "(", &outer_end);
289 if (!ptr)
290 break;
291 ptr = strtok_r(NULL, ")", &outer_end);
292 if (!ptr)
293 break;
294
295 while (isspace(*ptr))
296 ptr++;
297
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100298 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700299 char* end = NULL;
300
301 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100302 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700303 {
304 // Out of dimensions
305 if (!ptr)
306 break;
307
308 int dim = atoi(ptr);
309
310 // Dimension is 0
311 if (dim == 0)
312 break;
313
314 shape.push_back(dim);
315 totalElems *= dim;
316 ptr = strtok_r(NULL, ",", &end);
317 }
318
319 foundShape = true;
320 }
321 else
322 {
323 return HEADER_PARSE_ERROR;
324 }
325
326 if (!ptr)
327 break;
328
329 ptr = strtok_r(NULL, ":", &outer_end);
330 }
331
332 if (!foundShape || !foundFormat || !foundOrder)
333 {
334 return HEADER_PARSE_ERROR;
335 }
336
337 // Validate header
338 if (fortranOrder)
339 {
340 return FILE_TYPE_MISMATCH;
341 }
342
343 if (totalElems != elems)
344 {
345 return BUFFER_SIZE_MISMATCH;
346 }
347
348 // Go back to the begininng and read until the end of the header dictionary
349 rewind(infile);
350 int val;
351
352 do
353 {
354 val = fgetc(infile);
355 } while (val != EOF && val != '\n');
356
357 return rc;
358}
359
Eric Kunze2364dcd2021-04-26 11:06:57 -0700360NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
361 const char* dtype_str,
362 const size_t elementsize,
363 const std::vector<int32_t>& shape,
364 const void* databuf,
365 bool bool_translate)
366{
367 FILE* outfile = nullptr;
368 NPError rc = NO_ERROR;
369 uint32_t totalElems = 1;
370
371 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700372 assert(databuf);
373
374 outfile = fopen(filename, "wb");
375
376 if (!outfile)
377 {
378 return FILE_NOT_FOUND;
379 }
380
381 for (uint32_t i = 0; i < shape.size(); i++)
382 {
383 totalElems *= shape[i];
384 }
385
386 rc = writeNpyHeader(outfile, shape, dtype_str);
387
388 if (rc == NO_ERROR)
389 {
390 if (bool_translate)
391 {
392 // Numpy save format stores booleans as a byte array
393 // with one byte per boolean. This somewhat inefficiently
394 // remaps from system bool[] to this format.
395 const bool* buf = reinterpret_cast<const bool*>(databuf);
396 for (uint32_t i = 0; i < totalElems; i++)
397 {
398 int val = buf[i] ? 1 : 0;
399 if (fputc(val, outfile) == EOF)
400 {
401 rc = FILE_IO_ERROR;
402 }
403 }
404 }
405 else
406 {
407 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
408 {
409 rc = FILE_IO_ERROR;
410 }
411 }
412 }
413
414 if (outfile)
415 fclose(outfile);
416
417 return rc;
418}
419
420NumpyUtilities::NPError
421 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
422{
423 NPError rc = NO_ERROR;
424 uint32_t i;
425 char header[NUMPY_HEADER_SZ + 1];
426 int headerPos = 0;
427
428 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700429
430 // Space-fill the header and end with a newline to start per numpy spec
431 memset(header, 0x20, NUMPY_HEADER_SZ);
432 header[NUMPY_HEADER_SZ - 1] = '\n';
433 header[NUMPY_HEADER_SZ] = 0;
434
435 // Write out the hard-coded header. We only support a 128-byte 1.0 header
436 // for now, which should be sufficient for simple tensor types of any
437 // reasonable rank.
438 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
439 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
440
Won Jeona8141522024-04-29 23:57:27 +0000441 // NumPy does not understand float8_e5m2, so change it to uint8 type, so that
442 // Python can read .npy files.
443 if (!strcmp(dtype_str, "'<f1'"))
444 {
445 dtype_str = "'|u1'";
446 }
447
Eric Kunze2364dcd2021-04-26 11:06:57 -0700448 // Output the format dictionary
449 // Hard-coded for I32 for now
Jeremy Johnson8f9e2842024-04-04 11:14:06 +0100450 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
451 "'descr': %s, 'fortran_order': False, 'shape': (", dtype_str);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700452
Jeremy Johnson8f9e2842024-04-04 11:14:06 +0100453 // Add shape contents (if any - as this will be empty for rank 0)
454 for (i = 0; i < shape.size(); i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700455 {
Won Jeona8141522024-04-29 23:57:27 +0000456 // Output NumPy file from tosa_refmodel_sut_run generates the shape information
457 // without a trailing comma when the rank is greater than 1.
458 if (i == 0)
459 {
460 if (shape.size() == 1)
461 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d,", shape[i]);
462 else
463 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "%d", shape[i]);
464 }
465 else
466 {
467 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, ", %d", shape[i]);
468 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700469 }
470
471 // Close off the dictionary
472 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
473
474 // snprintf leaves a NULL at the end. Replace with a space
475 header[headerPos] = 0x20;
476
477 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
478 {
479 rc = FILE_IO_ERROR;
480 }
481
482 return rc;
483}