blob: e4171d7e305bee4ae7cb1f6d149d63b30106edfc [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
Eric Kunzecc426df2024-01-03 00:27:59 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunze2364dcd2021-04-26 11:06:57 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
James Ward485a11d2022-08-05 13:48:37 +010017#include "half.hpp"
Jerry Ge13a32912023-07-03 16:36:41 +000018#include <algorithm>
Eric Kunzecc426df2024-01-03 00:27:59 +000019#include <memory>
TatWai Chong679bdad2023-07-31 15:15:12 -070020
// Fixed prefix of every npy file this module reads or writes (11 bytes):
//   "\x93NUMPY"   magic string
//   \x1 \x0       format version 1.0
//   \x76 \x0      header-length field: 0x0076 (118) bytes follow the prefix
//   '{'           first character of the header dictionary
// The 10 bytes before '{' plus the 118-byte header give NUMPY_HEADER_SZ.
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Total byte size of the only header layout supported here.
static constexpr int NUMPY_HEADER_SZ = 128;
// Upper bound on the number of shape dimensions parsed from a header.
static constexpr int NUMPY_MAX_DIMS_SUPPORTED = 10;
Eric Kunze2364dcd2021-04-26 11:06:57 -070026
TatWai Chong679bdad2023-07-31 15:15:12 -070027// This is an entry function for reading 8-/16-/32-bit npy file.
28template <>
Eric Kunze2364dcd2021-04-26 11:06:57 -070029NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
30{
Jerry Ge13a32912023-07-03 16:36:41 +000031 FILE* infile = nullptr;
32 NPError rc = HEADER_PARSE_ERROR;
33 assert(filename);
34 assert(databuf);
35
36 infile = fopen(filename, "rb");
37 if (!infile)
38 {
39 return FILE_NOT_FOUND;
40 }
41
TatWai Chong679bdad2023-07-31 15:15:12 -070042 bool is_signed = false;
43 int length_per_byte = 0;
Jerry Ge13a32912023-07-03 16:36:41 +000044 char byte_order;
TatWai Chong679bdad2023-07-31 15:15:12 -070045 rc = getHeader(infile, is_signed, length_per_byte, byte_order);
Jerry Ge13a32912023-07-03 16:36:41 +000046 if (rc != NO_ERROR)
47 return rc;
48
TatWai Chong679bdad2023-07-31 15:15:12 -070049 switch (length_per_byte)
Jerry Ge13a32912023-07-03 16:36:41 +000050 {
TatWai Chong679bdad2023-07-31 15:15:12 -070051 case 1:
Jerry Ge13a32912023-07-03 16:36:41 +000052 if (is_signed)
53 {
TatWai Chong679bdad2023-07-31 15:15:12 -070054 int8_t* tmp_buf = new int8_t[elems];
55 rc = readFromNpyFile<int8_t>(filename, elems, tmp_buf);
56 copyBufferByElement(databuf, tmp_buf, elems);
Eric Kunzecc426df2024-01-03 00:27:59 +000057 delete[] tmp_buf;
Jerry Ge13a32912023-07-03 16:36:41 +000058 }
59 else
60 {
TatWai Chong679bdad2023-07-31 15:15:12 -070061 uint8_t* tmp_buf = new uint8_t[elems];
62 rc = readFromNpyFile<uint8_t>(filename, elems, tmp_buf);
63 copyBufferByElement(databuf, tmp_buf, elems);
Eric Kunzecc426df2024-01-03 00:27:59 +000064 delete[] tmp_buf;
Jerry Ge13a32912023-07-03 16:36:41 +000065 }
66 break;
TatWai Chong679bdad2023-07-31 15:15:12 -070067 case 2:
Jerry Ge13a32912023-07-03 16:36:41 +000068 if (is_signed)
69 {
TatWai Chong679bdad2023-07-31 15:15:12 -070070 int16_t* tmp_buf = new int16_t[elems];
71 rc = readFromNpyFile<int16_t>(filename, elems, tmp_buf);
72 copyBufferByElement(databuf, tmp_buf, elems);
Eric Kunzecc426df2024-01-03 00:27:59 +000073 delete[] tmp_buf;
Jerry Ge13a32912023-07-03 16:36:41 +000074 }
75 else
76 {
TatWai Chong679bdad2023-07-31 15:15:12 -070077 uint16_t* tmp_buf = new uint16_t[elems];
78 rc = readFromNpyFile<uint16_t>(filename, elems, tmp_buf);
79 copyBufferByElement(databuf, tmp_buf, elems);
Eric Kunzecc426df2024-01-03 00:27:59 +000080 delete[] tmp_buf;
Jerry Ge13a32912023-07-03 16:36:41 +000081 }
82 break;
TatWai Chong679bdad2023-07-31 15:15:12 -070083 case 4:
Jerry Ge13a32912023-07-03 16:36:41 +000084 if (is_signed)
85 {
TatWai Chong679bdad2023-07-31 15:15:12 -070086 bool is_bool;
87 const char* dtype_str = getDTypeString<int32_t>(is_bool);
88 rc = readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, is_bool);
Jerry Ge13a32912023-07-03 16:36:41 +000089 }
90 else
91 {
92 // uint32, not supported
TatWai Chong679bdad2023-07-31 15:15:12 -070093 rc = DATA_TYPE_NOT_SUPPORTED;
Jerry Ge13a32912023-07-03 16:36:41 +000094 }
95 break;
96 default:
97 return DATA_TYPE_NOT_SUPPORTED;
98 break;
99 }
100
101 return rc;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700102}
103
Eric Kunze2364dcd2021-04-26 11:06:57 -0700104NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
105 const char* dtype_str,
106 const size_t elementsize,
107 const uint32_t elems,
108 void* databuf,
109 bool bool_translate)
110{
111 FILE* infile = nullptr;
112 NPError rc = NO_ERROR;
113
114 assert(filename);
115 assert(databuf);
116
117 infile = fopen(filename, "rb");
118 if (!infile)
119 {
120 return FILE_NOT_FOUND;
121 }
122
123 rc = checkNpyHeader(infile, elems, dtype_str);
124 if (rc == NO_ERROR)
125 {
126 if (bool_translate)
127 {
128 // Read in the data from numpy byte array to native bool
129 // array format
130 bool* buf = reinterpret_cast<bool*>(databuf);
131 for (uint32_t i = 0; i < elems; i++)
132 {
133 int val = fgetc(infile);
134
135 if (val == EOF)
136 {
137 rc = FILE_IO_ERROR;
138 }
139
140 buf[i] = val;
141 }
142 }
143 else
144 {
145 // Now we are at the beginning of the data
146 // Parse based on the datatype and number of dimensions
147 if (fread(databuf, elementsize, elems, infile) != elems)
148 {
149 rc = FILE_IO_ERROR;
150 }
151 }
152 }
153
154 if (infile)
155 fclose(infile);
156
157 return rc;
158}
159
Jerry Ge13a32912023-07-03 16:36:41 +0000160NumpyUtilities::NPError NumpyUtilities::getHeader(FILE* infile, bool& is_signed, int& bit_length, char& byte_order)
161{
162 char buf[NUMPY_HEADER_SZ + 1];
163 NPError rc = NO_ERROR;
164 assert(infile);
165
166 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
167 {
168 return HEADER_PARSE_ERROR;
169 }
Jerry Ge13a32912023-07-03 16:36:41 +0000170
Jerry Ge0bc46e42023-07-07 21:52:26 +0000171 // Validate the numpy magic number
172 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
173 {
174 return HEADER_PARSE_ERROR;
175 }
Jerry Ge13a32912023-07-03 16:36:41 +0000176
Jerry Ge0bc46e42023-07-07 21:52:26 +0000177 std::string dic_string(buf, NUMPY_HEADER_SZ);
Jerry Ge13a32912023-07-03 16:36:41 +0000178
Jerry Ge0bc46e42023-07-07 21:52:26 +0000179 std::string desc_str("descr':");
180 size_t offset = dic_string.find(desc_str);
181 if (offset == std::string::npos)
182 return HEADER_PARSE_ERROR;
Jerry Ge13a32912023-07-03 16:36:41 +0000183
Jerry Ge0bc46e42023-07-07 21:52:26 +0000184 offset += desc_str.size() + 1;
185 // Skip whitespace and the opening '
186 while (offset < dic_string.size() && (std::isspace(dic_string[offset]) || dic_string[offset] == '\''))
187 offset++;
188 // Check for overflow
189 if (offset + 2 > dic_string.size())
190 return HEADER_PARSE_ERROR;
Jerry Ge13a32912023-07-03 16:36:41 +0000191
Jerry Ge0bc46e42023-07-07 21:52:26 +0000192 byte_order = dic_string[offset];
193 is_signed = dic_string[offset + 1] == 'u' ? false : true;
194 bit_length = (int)dic_string[offset + 2] - '0';
Jerry Ge13a32912023-07-03 16:36:41 +0000195
196 rewind(infile);
197 return rc;
198}
199
Eric Kunze2364dcd2021-04-26 11:06:57 -0700200NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
201{
202 char buf[NUMPY_HEADER_SZ + 1];
203 char* ptr = nullptr;
204 NPError rc = NO_ERROR;
205 bool foundFormat = false;
206 bool foundOrder = false;
207 bool foundShape = false;
208 bool fortranOrder = false;
209 std::vector<int> shape;
210 uint32_t totalElems = 1;
211 char* outer_end = NULL;
212
213 assert(infile);
214 assert(elems > 0);
215
216 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
217 {
218 return HEADER_PARSE_ERROR;
219 }
220
221 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
222 {
223 return HEADER_PARSE_ERROR;
224 }
225
226 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
227
228 // Read in the data type, order, and shape
229 while (ptr && (!foundFormat || !foundOrder || !foundShape))
230 {
231
232 // End of string?
233 if (!ptr)
234 break;
235
236 // Skip whitespace
237 while (isspace(*ptr))
238 ptr++;
239
240 // Parse the dictionary field name
241 if (!strcmp(ptr, "'descr'"))
242 {
243 ptr = strtok_r(NULL, ",", &outer_end);
244 if (!ptr)
245 break;
246
247 while (isspace(*ptr))
248 ptr++;
249
250 if (strcmp(ptr, dtype_str))
251 {
252 return FILE_TYPE_MISMATCH;
253 }
254
255 foundFormat = true;
256 }
257 else if (!strcmp(ptr, "'fortran_order'"))
258 {
259 ptr = strtok_r(NULL, ",", &outer_end);
260 if (!ptr)
261 break;
262
263 while (isspace(*ptr))
264 ptr++;
265
266 if (!strcmp(ptr, "False"))
267 {
268 fortranOrder = false;
269 }
270 else
271 {
272 return FILE_TYPE_MISMATCH;
273 }
274
275 foundOrder = true;
276 }
277 else if (!strcmp(ptr, "'shape'"))
278 {
279
280 ptr = strtok_r(NULL, "(", &outer_end);
281 if (!ptr)
282 break;
283 ptr = strtok_r(NULL, ")", &outer_end);
284 if (!ptr)
285 break;
286
287 while (isspace(*ptr))
288 ptr++;
289
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100290 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700291 char* end = NULL;
292
293 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100294 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700295 {
296 // Out of dimensions
297 if (!ptr)
298 break;
299
300 int dim = atoi(ptr);
301
302 // Dimension is 0
303 if (dim == 0)
304 break;
305
306 shape.push_back(dim);
307 totalElems *= dim;
308 ptr = strtok_r(NULL, ",", &end);
309 }
310
311 foundShape = true;
312 }
313 else
314 {
315 return HEADER_PARSE_ERROR;
316 }
317
318 if (!ptr)
319 break;
320
321 ptr = strtok_r(NULL, ":", &outer_end);
322 }
323
324 if (!foundShape || !foundFormat || !foundOrder)
325 {
326 return HEADER_PARSE_ERROR;
327 }
328
329 // Validate header
330 if (fortranOrder)
331 {
332 return FILE_TYPE_MISMATCH;
333 }
334
335 if (totalElems != elems)
336 {
337 return BUFFER_SIZE_MISMATCH;
338 }
339
340 // Go back to the begininng and read until the end of the header dictionary
341 rewind(infile);
342 int val;
343
344 do
345 {
346 val = fgetc(infile);
347 } while (val != EOF && val != '\n');
348
349 return rc;
350}
351
Eric Kunze2364dcd2021-04-26 11:06:57 -0700352NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
353 const char* dtype_str,
354 const size_t elementsize,
355 const std::vector<int32_t>& shape,
356 const void* databuf,
357 bool bool_translate)
358{
359 FILE* outfile = nullptr;
360 NPError rc = NO_ERROR;
361 uint32_t totalElems = 1;
362
363 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700364 assert(databuf);
365
366 outfile = fopen(filename, "wb");
367
368 if (!outfile)
369 {
370 return FILE_NOT_FOUND;
371 }
372
373 for (uint32_t i = 0; i < shape.size(); i++)
374 {
375 totalElems *= shape[i];
376 }
377
378 rc = writeNpyHeader(outfile, shape, dtype_str);
379
380 if (rc == NO_ERROR)
381 {
382 if (bool_translate)
383 {
384 // Numpy save format stores booleans as a byte array
385 // with one byte per boolean. This somewhat inefficiently
386 // remaps from system bool[] to this format.
387 const bool* buf = reinterpret_cast<const bool*>(databuf);
388 for (uint32_t i = 0; i < totalElems; i++)
389 {
390 int val = buf[i] ? 1 : 0;
391 if (fputc(val, outfile) == EOF)
392 {
393 rc = FILE_IO_ERROR;
394 }
395 }
396 }
397 else
398 {
399 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
400 {
401 rc = FILE_IO_ERROR;
402 }
403 }
404 }
405
406 if (outfile)
407 fclose(outfile);
408
409 return rc;
410}
411
412NumpyUtilities::NPError
413 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
414{
415 NPError rc = NO_ERROR;
416 uint32_t i;
417 char header[NUMPY_HEADER_SZ + 1];
418 int headerPos = 0;
419
420 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700421
422 // Space-fill the header and end with a newline to start per numpy spec
423 memset(header, 0x20, NUMPY_HEADER_SZ);
424 header[NUMPY_HEADER_SZ - 1] = '\n';
425 header[NUMPY_HEADER_SZ] = 0;
426
427 // Write out the hard-coded header. We only support a 128-byte 1.0 header
428 // for now, which should be sufficient for simple tensor types of any
429 // reasonable rank.
430 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
431 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
432
433 // Output the format dictionary
434 // Hard-coded for I32 for now
Jeremy Johnson8f9e2842024-04-04 11:14:06 +0100435 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos,
436 "'descr': %s, 'fortran_order': False, 'shape': (", dtype_str);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700437
Jeremy Johnson8f9e2842024-04-04 11:14:06 +0100438 // Add shape contents (if any - as this will be empty for rank 0)
439 for (i = 0; i < shape.size(); i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700440 {
441 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
442 }
443
444 // Close off the dictionary
445 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
446
447 // snprintf leaves a NULL at the end. Replace with a space
448 header[headerPos] = 0x20;
449
450 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
451 {
452 rc = FILE_IO_ERROR;
453 }
454
455 return rc;
456}