blob: 64460bdec0c50b1d0907b6069a2c0835d9004059 [file] [log] [blame]

// Copyright (c) 2020-2023, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

16#include "numpy_utils.h"
James Ward485a11d2022-08-05 13:48:37 +010017#include "half.hpp"
Jerry Ge13a32912023-07-03 16:36:41 +000018#include <algorithm>
TatWai Chong679bdad2023-07-31 15:15:12 -070019
// Leading bytes of every header this file reads or writes: the npy magic
// string, two version bytes (1, 0), a two-byte header-length field
// (0x76, 0x00) and the opening brace of the header dictionary.
// The literal contains embedded NUL bytes, so always measure it with
// sizeof(NUMPY_HEADER_STR) - 1 and compare it with memcmp, never with
// strlen/strcmp.
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Total size in bytes of the fixed-size header block read from / written to a file
static constexpr int NUMPY_HEADER_SZ = 128;
// Maximum shape dimensions supported
static constexpr int NUMPY_MAX_DIMS_SUPPORTED = 10;
// Offset for NUMPY header desc dictionary string
static constexpr int NUMPY_HEADER_DESC_OFFSET = 8;
Eric Kunze2364dcd2021-04-26 11:06:57 -070027
TatWai Chong679bdad2023-07-31 15:15:12 -070028// This is an entry function for reading 8-/16-/32-bit npy file.
29template <>
Eric Kunze2364dcd2021-04-26 11:06:57 -070030NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
31{
Jerry Ge13a32912023-07-03 16:36:41 +000032 FILE* infile = nullptr;
33 NPError rc = HEADER_PARSE_ERROR;
34 assert(filename);
35 assert(databuf);
36
37 infile = fopen(filename, "rb");
38 if (!infile)
39 {
40 return FILE_NOT_FOUND;
41 }
42
TatWai Chong679bdad2023-07-31 15:15:12 -070043 bool is_signed = false;
44 int length_per_byte = 0;
Jerry Ge13a32912023-07-03 16:36:41 +000045 char byte_order;
TatWai Chong679bdad2023-07-31 15:15:12 -070046 rc = getHeader(infile, is_signed, length_per_byte, byte_order);
Jerry Ge13a32912023-07-03 16:36:41 +000047 if (rc != NO_ERROR)
48 return rc;
49
TatWai Chong679bdad2023-07-31 15:15:12 -070050 switch (length_per_byte)
Jerry Ge13a32912023-07-03 16:36:41 +000051 {
TatWai Chong679bdad2023-07-31 15:15:12 -070052 case 1:
Jerry Ge13a32912023-07-03 16:36:41 +000053 if (is_signed)
54 {
TatWai Chong679bdad2023-07-31 15:15:12 -070055 int8_t* tmp_buf = new int8_t[elems];
56 rc = readFromNpyFile<int8_t>(filename, elems, tmp_buf);
57 copyBufferByElement(databuf, tmp_buf, elems);
58 free(tmp_buf);
Jerry Ge13a32912023-07-03 16:36:41 +000059 }
60 else
61 {
TatWai Chong679bdad2023-07-31 15:15:12 -070062 uint8_t* tmp_buf = new uint8_t[elems];
63 rc = readFromNpyFile<uint8_t>(filename, elems, tmp_buf);
64 copyBufferByElement(databuf, tmp_buf, elems);
65 free(tmp_buf);
Jerry Ge13a32912023-07-03 16:36:41 +000066 }
67 break;
TatWai Chong679bdad2023-07-31 15:15:12 -070068 case 2:
Jerry Ge13a32912023-07-03 16:36:41 +000069 if (is_signed)
70 {
TatWai Chong679bdad2023-07-31 15:15:12 -070071 int16_t* tmp_buf = new int16_t[elems];
72 rc = readFromNpyFile<int16_t>(filename, elems, tmp_buf);
73 copyBufferByElement(databuf, tmp_buf, elems);
74 free(tmp_buf);
Jerry Ge13a32912023-07-03 16:36:41 +000075 }
76 else
77 {
TatWai Chong679bdad2023-07-31 15:15:12 -070078 uint16_t* tmp_buf = new uint16_t[elems];
79 rc = readFromNpyFile<uint16_t>(filename, elems, tmp_buf);
80 copyBufferByElement(databuf, tmp_buf, elems);
81 free(tmp_buf);
Jerry Ge13a32912023-07-03 16:36:41 +000082 }
83 break;
TatWai Chong679bdad2023-07-31 15:15:12 -070084 case 4:
Jerry Ge13a32912023-07-03 16:36:41 +000085 if (is_signed)
86 {
TatWai Chong679bdad2023-07-31 15:15:12 -070087 bool is_bool;
88 const char* dtype_str = getDTypeString<int32_t>(is_bool);
89 rc = readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, is_bool);
Jerry Ge13a32912023-07-03 16:36:41 +000090 }
91 else
92 {
93 // uint32, not supported
TatWai Chong679bdad2023-07-31 15:15:12 -070094 rc = DATA_TYPE_NOT_SUPPORTED;
Jerry Ge13a32912023-07-03 16:36:41 +000095 }
96 break;
97 default:
98 return DATA_TYPE_NOT_SUPPORTED;
99 break;
100 }
101
102 return rc;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700103}
104
Eric Kunze2364dcd2021-04-26 11:06:57 -0700105NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
106 const char* dtype_str,
107 const size_t elementsize,
108 const uint32_t elems,
109 void* databuf,
110 bool bool_translate)
111{
112 FILE* infile = nullptr;
113 NPError rc = NO_ERROR;
114
115 assert(filename);
116 assert(databuf);
117
118 infile = fopen(filename, "rb");
119 if (!infile)
120 {
121 return FILE_NOT_FOUND;
122 }
123
124 rc = checkNpyHeader(infile, elems, dtype_str);
125 if (rc == NO_ERROR)
126 {
127 if (bool_translate)
128 {
129 // Read in the data from numpy byte array to native bool
130 // array format
131 bool* buf = reinterpret_cast<bool*>(databuf);
132 for (uint32_t i = 0; i < elems; i++)
133 {
134 int val = fgetc(infile);
135
136 if (val == EOF)
137 {
138 rc = FILE_IO_ERROR;
139 }
140
141 buf[i] = val;
142 }
143 }
144 else
145 {
146 // Now we are at the beginning of the data
147 // Parse based on the datatype and number of dimensions
148 if (fread(databuf, elementsize, elems, infile) != elems)
149 {
150 rc = FILE_IO_ERROR;
151 }
152 }
153 }
154
155 if (infile)
156 fclose(infile);
157
158 return rc;
159}
160
Jerry Ge13a32912023-07-03 16:36:41 +0000161NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order)
162{
163 char buf[NUMPY_HEADER_SZ + 1];
164 NPError rc = NO_ERROR;
165 assert(infile);
166
167 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
168 {
169 return HEADER_PARSE_ERROR;
170 }
171 char* ptr;
Jerry Ge13a32912023-07-03 16:36:41 +0000172
Jerry Ge0bc46e42023-07-07 21:52:26 +0000173 // Validate the numpy magic number
174 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
175 {
176 return HEADER_PARSE_ERROR;
177 }
Jerry Ge13a32912023-07-03 16:36:41 +0000178
Jerry Ge0bc46e42023-07-07 21:52:26 +0000179 std::string dic_string(buf, NUMPY_HEADER_SZ);
Jerry Ge13a32912023-07-03 16:36:41 +0000180
Jerry Ge0bc46e42023-07-07 21:52:26 +0000181 std::string desc_str("descr':");
182 size_t offset = dic_string.find(desc_str);
183 if (offset == std::string::npos)
184 return HEADER_PARSE_ERROR;
Jerry Ge13a32912023-07-03 16:36:41 +0000185
Jerry Ge0bc46e42023-07-07 21:52:26 +0000186 offset += desc_str.size() + 1;
187 // Skip whitespace and the opening '
188 while (offset < dic_string.size() && (std::isspace(dic_string[offset]) || dic_string[offset] == '\''))
189 offset++;
190 // Check for overflow
191 if (offset + 2 > dic_string.size())
192 return HEADER_PARSE_ERROR;
Jerry Ge13a32912023-07-03 16:36:41 +0000193
Jerry Ge0bc46e42023-07-07 21:52:26 +0000194 byte_order = dic_string[offset];
195 is_signed = dic_string[offset + 1] == 'u' ? false : true;
196 bit_length = (int)dic_string[offset + 2] - '0';
Jerry Ge13a32912023-07-03 16:36:41 +0000197
198 rewind(infile);
199 return rc;
200}
201
Eric Kunze2364dcd2021-04-26 11:06:57 -0700202NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
203{
204 char buf[NUMPY_HEADER_SZ + 1];
205 char* ptr = nullptr;
206 NPError rc = NO_ERROR;
207 bool foundFormat = false;
208 bool foundOrder = false;
209 bool foundShape = false;
210 bool fortranOrder = false;
211 std::vector<int> shape;
212 uint32_t totalElems = 1;
213 char* outer_end = NULL;
214
215 assert(infile);
216 assert(elems > 0);
217
218 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
219 {
220 return HEADER_PARSE_ERROR;
221 }
222
223 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
224 {
225 return HEADER_PARSE_ERROR;
226 }
227
228 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
229
230 // Read in the data type, order, and shape
231 while (ptr && (!foundFormat || !foundOrder || !foundShape))
232 {
233
234 // End of string?
235 if (!ptr)
236 break;
237
238 // Skip whitespace
239 while (isspace(*ptr))
240 ptr++;
241
242 // Parse the dictionary field name
243 if (!strcmp(ptr, "'descr'"))
244 {
245 ptr = strtok_r(NULL, ",", &outer_end);
246 if (!ptr)
247 break;
248
249 while (isspace(*ptr))
250 ptr++;
251
252 if (strcmp(ptr, dtype_str))
253 {
254 return FILE_TYPE_MISMATCH;
255 }
256
257 foundFormat = true;
258 }
259 else if (!strcmp(ptr, "'fortran_order'"))
260 {
261 ptr = strtok_r(NULL, ",", &outer_end);
262 if (!ptr)
263 break;
264
265 while (isspace(*ptr))
266 ptr++;
267
268 if (!strcmp(ptr, "False"))
269 {
270 fortranOrder = false;
271 }
272 else
273 {
274 return FILE_TYPE_MISMATCH;
275 }
276
277 foundOrder = true;
278 }
279 else if (!strcmp(ptr, "'shape'"))
280 {
281
282 ptr = strtok_r(NULL, "(", &outer_end);
283 if (!ptr)
284 break;
285 ptr = strtok_r(NULL, ")", &outer_end);
286 if (!ptr)
287 break;
288
289 while (isspace(*ptr))
290 ptr++;
291
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100292 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700293 char* end = NULL;
294
295 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100296 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700297 {
298 // Out of dimensions
299 if (!ptr)
300 break;
301
302 int dim = atoi(ptr);
303
304 // Dimension is 0
305 if (dim == 0)
306 break;
307
308 shape.push_back(dim);
309 totalElems *= dim;
310 ptr = strtok_r(NULL, ",", &end);
311 }
312
313 foundShape = true;
314 }
315 else
316 {
317 return HEADER_PARSE_ERROR;
318 }
319
320 if (!ptr)
321 break;
322
323 ptr = strtok_r(NULL, ":", &outer_end);
324 }
325
326 if (!foundShape || !foundFormat || !foundOrder)
327 {
328 return HEADER_PARSE_ERROR;
329 }
330
331 // Validate header
332 if (fortranOrder)
333 {
334 return FILE_TYPE_MISMATCH;
335 }
336
337 if (totalElems != elems)
338 {
339 return BUFFER_SIZE_MISMATCH;
340 }
341
342 // Go back to the begininng and read until the end of the header dictionary
343 rewind(infile);
344 int val;
345
346 do
347 {
348 val = fgetc(infile);
349 } while (val != EOF && val != '\n');
350
351 return rc;
352}
353
Eric Kunze2364dcd2021-04-26 11:06:57 -0700354NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
355 const char* dtype_str,
356 const size_t elementsize,
357 const std::vector<int32_t>& shape,
358 const void* databuf,
359 bool bool_translate)
360{
361 FILE* outfile = nullptr;
362 NPError rc = NO_ERROR;
363 uint32_t totalElems = 1;
364
365 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700366 assert(databuf);
367
368 outfile = fopen(filename, "wb");
369
370 if (!outfile)
371 {
372 return FILE_NOT_FOUND;
373 }
374
375 for (uint32_t i = 0; i < shape.size(); i++)
376 {
377 totalElems *= shape[i];
378 }
379
380 rc = writeNpyHeader(outfile, shape, dtype_str);
381
382 if (rc == NO_ERROR)
383 {
384 if (bool_translate)
385 {
386 // Numpy save format stores booleans as a byte array
387 // with one byte per boolean. This somewhat inefficiently
388 // remaps from system bool[] to this format.
389 const bool* buf = reinterpret_cast<const bool*>(databuf);
390 for (uint32_t i = 0; i < totalElems; i++)
391 {
392 int val = buf[i] ? 1 : 0;
393 if (fputc(val, outfile) == EOF)
394 {
395 rc = FILE_IO_ERROR;
396 }
397 }
398 }
399 else
400 {
401 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
402 {
403 rc = FILE_IO_ERROR;
404 }
405 }
406 }
407
408 if (outfile)
409 fclose(outfile);
410
411 return rc;
412}
413
414NumpyUtilities::NPError
415 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
416{
417 NPError rc = NO_ERROR;
418 uint32_t i;
419 char header[NUMPY_HEADER_SZ + 1];
420 int headerPos = 0;
421
422 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700423
424 // Space-fill the header and end with a newline to start per numpy spec
425 memset(header, 0x20, NUMPY_HEADER_SZ);
426 header[NUMPY_HEADER_SZ - 1] = '\n';
427 header[NUMPY_HEADER_SZ] = 0;
428
429 // Write out the hard-coded header. We only support a 128-byte 1.0 header
430 // for now, which should be sufficient for simple tensor types of any
431 // reasonable rank.
432 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
433 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
434
435 // Output the format dictionary
436 // Hard-coded for I32 for now
437 headerPos +=
438 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
439 dtype_str, shape.empty() ? 1 : shape[0]);
440
441 // Remainder of shape array
442 for (i = 1; i < shape.size(); i++)
443 {
444 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
445 }
446
447 // Close off the dictionary
448 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
449
450 // snprintf leaves a NULL at the end. Replace with a space
451 header[headerPos] = 0x20;
452
453 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
454 {
455 rc = FILE_IO_ERROR;
456 }
457
458 return rc;
459}