blob: 123908ac2cd7216380ca2a3f35e85896352fa4f2 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
James Ward485a11d2022-08-05 13:48:37 +010017#include "half.hpp"
Eric Kunze2364dcd2021-04-26 11:06:57 -070018
// Fixed preamble of every header this file reads or writes: the NUMPY magic
// bytes, format version 1.0, a two-byte header-length field (0x76) and the
// opening brace of the header dictionary.
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Total size in bytes of the header block (preamble plus padded dictionary).
static constexpr int NUMPY_HEADER_SZ = 128;
// Upper bound on the number of shape dimensions parsed from a header.
static constexpr int NUMPY_MAX_DIMS_SUPPORTED = 10;
Eric Kunze2364dcd2021-04-26 11:06:57 -070024
25NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
26{
27 const char dtype_str[] = "'|b1'";
28 return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
29}
30
31NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
32{
33 const char dtype_str[] = "'<i4'";
34 return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false);
35}
36
37NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
38{
39 const char dtype_str[] = "'<i8'";
40 return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
41}
42
43NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
44{
45 const char dtype_str[] = "'<f4'";
46 return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
47}
48
Tai Ly3ef34fb2023-04-04 20:34:05 +000049NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, double* databuf)
50{
51 const char dtype_str[] = "'<f8'";
52 return readFromNpyFileCommon(filename, dtype_str, sizeof(double), elems, databuf, false);
53}
54
James Ward485a11d2022-08-05 13:48:37 +010055NumpyUtilities::NPError
56 NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
57{
58 const char dtype_str[] = "'<f2'";
59 return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false);
60}
61
Eric Kunze2364dcd2021-04-26 11:06:57 -070062NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
63 const char* dtype_str,
64 const size_t elementsize,
65 const uint32_t elems,
66 void* databuf,
67 bool bool_translate)
68{
69 FILE* infile = nullptr;
70 NPError rc = NO_ERROR;
71
72 assert(filename);
73 assert(databuf);
74
75 infile = fopen(filename, "rb");
76 if (!infile)
77 {
78 return FILE_NOT_FOUND;
79 }
80
81 rc = checkNpyHeader(infile, elems, dtype_str);
82 if (rc == NO_ERROR)
83 {
84 if (bool_translate)
85 {
86 // Read in the data from numpy byte array to native bool
87 // array format
88 bool* buf = reinterpret_cast<bool*>(databuf);
89 for (uint32_t i = 0; i < elems; i++)
90 {
91 int val = fgetc(infile);
92
93 if (val == EOF)
94 {
95 rc = FILE_IO_ERROR;
96 }
97
98 buf[i] = val;
99 }
100 }
101 else
102 {
103 // Now we are at the beginning of the data
104 // Parse based on the datatype and number of dimensions
105 if (fread(databuf, elementsize, elems, infile) != elems)
106 {
107 rc = FILE_IO_ERROR;
108 }
109 }
110 }
111
112 if (infile)
113 fclose(infile);
114
115 return rc;
116}
117
118NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
119{
120 char buf[NUMPY_HEADER_SZ + 1];
121 char* ptr = nullptr;
122 NPError rc = NO_ERROR;
123 bool foundFormat = false;
124 bool foundOrder = false;
125 bool foundShape = false;
126 bool fortranOrder = false;
127 std::vector<int> shape;
128 uint32_t totalElems = 1;
129 char* outer_end = NULL;
130
131 assert(infile);
132 assert(elems > 0);
133
134 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
135 {
136 return HEADER_PARSE_ERROR;
137 }
138
139 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
140 {
141 return HEADER_PARSE_ERROR;
142 }
143
144 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
145
146 // Read in the data type, order, and shape
147 while (ptr && (!foundFormat || !foundOrder || !foundShape))
148 {
149
150 // End of string?
151 if (!ptr)
152 break;
153
154 // Skip whitespace
155 while (isspace(*ptr))
156 ptr++;
157
158 // Parse the dictionary field name
159 if (!strcmp(ptr, "'descr'"))
160 {
161 ptr = strtok_r(NULL, ",", &outer_end);
162 if (!ptr)
163 break;
164
165 while (isspace(*ptr))
166 ptr++;
167
168 if (strcmp(ptr, dtype_str))
169 {
170 return FILE_TYPE_MISMATCH;
171 }
172
173 foundFormat = true;
174 }
175 else if (!strcmp(ptr, "'fortran_order'"))
176 {
177 ptr = strtok_r(NULL, ",", &outer_end);
178 if (!ptr)
179 break;
180
181 while (isspace(*ptr))
182 ptr++;
183
184 if (!strcmp(ptr, "False"))
185 {
186 fortranOrder = false;
187 }
188 else
189 {
190 return FILE_TYPE_MISMATCH;
191 }
192
193 foundOrder = true;
194 }
195 else if (!strcmp(ptr, "'shape'"))
196 {
197
198 ptr = strtok_r(NULL, "(", &outer_end);
199 if (!ptr)
200 break;
201 ptr = strtok_r(NULL, ")", &outer_end);
202 if (!ptr)
203 break;
204
205 while (isspace(*ptr))
206 ptr++;
207
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100208 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700209 char* end = NULL;
210
211 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100212 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700213 {
214 // Out of dimensions
215 if (!ptr)
216 break;
217
218 int dim = atoi(ptr);
219
220 // Dimension is 0
221 if (dim == 0)
222 break;
223
224 shape.push_back(dim);
225 totalElems *= dim;
226 ptr = strtok_r(NULL, ",", &end);
227 }
228
229 foundShape = true;
230 }
231 else
232 {
233 return HEADER_PARSE_ERROR;
234 }
235
236 if (!ptr)
237 break;
238
239 ptr = strtok_r(NULL, ":", &outer_end);
240 }
241
242 if (!foundShape || !foundFormat || !foundOrder)
243 {
244 return HEADER_PARSE_ERROR;
245 }
246
247 // Validate header
248 if (fortranOrder)
249 {
250 return FILE_TYPE_MISMATCH;
251 }
252
253 if (totalElems != elems)
254 {
255 return BUFFER_SIZE_MISMATCH;
256 }
257
258 // Go back to the begininng and read until the end of the header dictionary
259 rewind(infile);
260 int val;
261
262 do
263 {
264 val = fgetc(infile);
265 } while (val != EOF && val != '\n');
266
267 return rc;
268}
269
270NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
271{
272 std::vector<int32_t> shape = { (int32_t)elems };
273 return writeToNpyFile(filename, shape, databuf);
274}
275
276NumpyUtilities::NPError
277 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
278{
279 const char dtype_str[] = "'|b1'";
280 return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
281}
282
283NumpyUtilities::NPError
284 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
285{
286 std::vector<int32_t> shape = { (int32_t)elems };
287 return writeToNpyFile(filename, shape, databuf);
288}
289
290NumpyUtilities::NPError
291 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
292{
293 const char dtype_str[] = "'<i4'";
294 return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
295}
296
297NumpyUtilities::NPError
298 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
299{
300 std::vector<int32_t> shape = { (int32_t)elems };
301 return writeToNpyFile(filename, shape, databuf);
302}
303
304NumpyUtilities::NPError
305 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
306{
307 const char dtype_str[] = "'<i8'";
308 return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
309}
310
311NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
312{
313 std::vector<int32_t> shape = { (int32_t)elems };
314 return writeToNpyFile(filename, shape, databuf);
315}
316
317NumpyUtilities::NPError
318 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
319{
320 const char dtype_str[] = "'<f4'";
321 return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
322}
323
Tai Ly3ef34fb2023-04-04 20:34:05 +0000324NumpyUtilities::NPError
325 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const double* databuf)
326{
327 std::vector<int32_t> shape = { (int32_t)elems };
328 return writeToNpyFile(filename, shape, databuf);
329}
330
331NumpyUtilities::NPError
332 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const double* databuf)
333{
334 const char dtype_str[] = "'<f8'";
335 return writeToNpyFileCommon(filename, dtype_str, sizeof(double), shape, databuf, false);
336}
337
James Ward485a11d2022-08-05 13:48:37 +0100338NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
339 const std::vector<int32_t>& shape,
340 const half_float::half* databuf)
341{
342 const char dtype_str[] = "'<f2'";
343 return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false);
344}
345
Eric Kunze2364dcd2021-04-26 11:06:57 -0700346NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
347 const char* dtype_str,
348 const size_t elementsize,
349 const std::vector<int32_t>& shape,
350 const void* databuf,
351 bool bool_translate)
352{
353 FILE* outfile = nullptr;
354 NPError rc = NO_ERROR;
355 uint32_t totalElems = 1;
356
357 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700358 assert(databuf);
359
360 outfile = fopen(filename, "wb");
361
362 if (!outfile)
363 {
364 return FILE_NOT_FOUND;
365 }
366
367 for (uint32_t i = 0; i < shape.size(); i++)
368 {
369 totalElems *= shape[i];
370 }
371
372 rc = writeNpyHeader(outfile, shape, dtype_str);
373
374 if (rc == NO_ERROR)
375 {
376 if (bool_translate)
377 {
378 // Numpy save format stores booleans as a byte array
379 // with one byte per boolean. This somewhat inefficiently
380 // remaps from system bool[] to this format.
381 const bool* buf = reinterpret_cast<const bool*>(databuf);
382 for (uint32_t i = 0; i < totalElems; i++)
383 {
384 int val = buf[i] ? 1 : 0;
385 if (fputc(val, outfile) == EOF)
386 {
387 rc = FILE_IO_ERROR;
388 }
389 }
390 }
391 else
392 {
393 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
394 {
395 rc = FILE_IO_ERROR;
396 }
397 }
398 }
399
400 if (outfile)
401 fclose(outfile);
402
403 return rc;
404}
405
406NumpyUtilities::NPError
407 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
408{
409 NPError rc = NO_ERROR;
410 uint32_t i;
411 char header[NUMPY_HEADER_SZ + 1];
412 int headerPos = 0;
413
414 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700415
416 // Space-fill the header and end with a newline to start per numpy spec
417 memset(header, 0x20, NUMPY_HEADER_SZ);
418 header[NUMPY_HEADER_SZ - 1] = '\n';
419 header[NUMPY_HEADER_SZ] = 0;
420
421 // Write out the hard-coded header. We only support a 128-byte 1.0 header
422 // for now, which should be sufficient for simple tensor types of any
423 // reasonable rank.
424 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
425 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
426
427 // Output the format dictionary
428 // Hard-coded for I32 for now
429 headerPos +=
430 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
431 dtype_str, shape.empty() ? 1 : shape[0]);
432
433 // Remainder of shape array
434 for (i = 1; i < shape.size(); i++)
435 {
436 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
437 }
438
439 // Close off the dictionary
440 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
441
442 // snprintf leaves a NULL at the end. Replace with a space
443 header[headerPos] = 0x20;
444
445 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
446 {
447 rc = FILE_IO_ERROR;
448 }
449
450 return rc;
451}