blob: c770d45f75a025d9fce2ed08dc5952e484cabdc0 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
James Ward485a11d2022-08-05 13:48:37 +010017#include "half.hpp"
Eric Kunze2364dcd2021-04-26 11:06:57 -070018
// Fixed prefix shared by every .npy file this module reads or writes:
//   "\x93NUMPY"  magic string
//   \x1 \x0      format version 1.0
//   \x76 \x0     0x0076 = 118; together with the 10 bytes before it this adds
//                up to NUMPY_HEADER_SZ (read as the little-endian header
//                length per the NumPy format spec)
//   '{'          opening brace of the header dictionary
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Total size in bytes of the header block (prefix + dictionary + padding)
static constexpr int NUMPY_HEADER_SZ = 128;
// Upper bound on the number of shape dimensions parsed from a header
static constexpr int NUMPY_MAX_DIMS_SUPPORTED = 10;
Eric Kunze2364dcd2021-04-26 11:06:57 -070024
25NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
26{
27 const char dtype_str[] = "'|b1'";
28 return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
29}
30
31NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
32{
33 const char dtype_str[] = "'<i4'";
34 return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false);
35}
36
37NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
38{
39 const char dtype_str[] = "'<i8'";
40 return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
41}
42
43NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
44{
45 const char dtype_str[] = "'<f4'";
46 return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
47}
48
James Ward485a11d2022-08-05 13:48:37 +010049NumpyUtilities::NPError
50 NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, half_float::half* databuf)
51{
52 const char dtype_str[] = "'<f2'";
53 return readFromNpyFileCommon(filename, dtype_str, sizeof(half_float::half), elems, databuf, false);
54}
55
Eric Kunze2364dcd2021-04-26 11:06:57 -070056NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
57 const char* dtype_str,
58 const size_t elementsize,
59 const uint32_t elems,
60 void* databuf,
61 bool bool_translate)
62{
63 FILE* infile = nullptr;
64 NPError rc = NO_ERROR;
65
66 assert(filename);
67 assert(databuf);
68
69 infile = fopen(filename, "rb");
70 if (!infile)
71 {
72 return FILE_NOT_FOUND;
73 }
74
75 rc = checkNpyHeader(infile, elems, dtype_str);
76 if (rc == NO_ERROR)
77 {
78 if (bool_translate)
79 {
80 // Read in the data from numpy byte array to native bool
81 // array format
82 bool* buf = reinterpret_cast<bool*>(databuf);
83 for (uint32_t i = 0; i < elems; i++)
84 {
85 int val = fgetc(infile);
86
87 if (val == EOF)
88 {
89 rc = FILE_IO_ERROR;
90 }
91
92 buf[i] = val;
93 }
94 }
95 else
96 {
97 // Now we are at the beginning of the data
98 // Parse based on the datatype and number of dimensions
99 if (fread(databuf, elementsize, elems, infile) != elems)
100 {
101 rc = FILE_IO_ERROR;
102 }
103 }
104 }
105
106 if (infile)
107 fclose(infile);
108
109 return rc;
110}
111
112NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
113{
114 char buf[NUMPY_HEADER_SZ + 1];
115 char* ptr = nullptr;
116 NPError rc = NO_ERROR;
117 bool foundFormat = false;
118 bool foundOrder = false;
119 bool foundShape = false;
120 bool fortranOrder = false;
121 std::vector<int> shape;
122 uint32_t totalElems = 1;
123 char* outer_end = NULL;
124
125 assert(infile);
126 assert(elems > 0);
127
128 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
129 {
130 return HEADER_PARSE_ERROR;
131 }
132
133 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
134 {
135 return HEADER_PARSE_ERROR;
136 }
137
138 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
139
140 // Read in the data type, order, and shape
141 while (ptr && (!foundFormat || !foundOrder || !foundShape))
142 {
143
144 // End of string?
145 if (!ptr)
146 break;
147
148 // Skip whitespace
149 while (isspace(*ptr))
150 ptr++;
151
152 // Parse the dictionary field name
153 if (!strcmp(ptr, "'descr'"))
154 {
155 ptr = strtok_r(NULL, ",", &outer_end);
156 if (!ptr)
157 break;
158
159 while (isspace(*ptr))
160 ptr++;
161
162 if (strcmp(ptr, dtype_str))
163 {
164 return FILE_TYPE_MISMATCH;
165 }
166
167 foundFormat = true;
168 }
169 else if (!strcmp(ptr, "'fortran_order'"))
170 {
171 ptr = strtok_r(NULL, ",", &outer_end);
172 if (!ptr)
173 break;
174
175 while (isspace(*ptr))
176 ptr++;
177
178 if (!strcmp(ptr, "False"))
179 {
180 fortranOrder = false;
181 }
182 else
183 {
184 return FILE_TYPE_MISMATCH;
185 }
186
187 foundOrder = true;
188 }
189 else if (!strcmp(ptr, "'shape'"))
190 {
191
192 ptr = strtok_r(NULL, "(", &outer_end);
193 if (!ptr)
194 break;
195 ptr = strtok_r(NULL, ")", &outer_end);
196 if (!ptr)
197 break;
198
199 while (isspace(*ptr))
200 ptr++;
201
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100202 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700203 char* end = NULL;
204
205 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100206 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700207 {
208 // Out of dimensions
209 if (!ptr)
210 break;
211
212 int dim = atoi(ptr);
213
214 // Dimension is 0
215 if (dim == 0)
216 break;
217
218 shape.push_back(dim);
219 totalElems *= dim;
220 ptr = strtok_r(NULL, ",", &end);
221 }
222
223 foundShape = true;
224 }
225 else
226 {
227 return HEADER_PARSE_ERROR;
228 }
229
230 if (!ptr)
231 break;
232
233 ptr = strtok_r(NULL, ":", &outer_end);
234 }
235
236 if (!foundShape || !foundFormat || !foundOrder)
237 {
238 return HEADER_PARSE_ERROR;
239 }
240
241 // Validate header
242 if (fortranOrder)
243 {
244 return FILE_TYPE_MISMATCH;
245 }
246
247 if (totalElems != elems)
248 {
249 return BUFFER_SIZE_MISMATCH;
250 }
251
252 // Go back to the begininng and read until the end of the header dictionary
253 rewind(infile);
254 int val;
255
256 do
257 {
258 val = fgetc(infile);
259 } while (val != EOF && val != '\n');
260
261 return rc;
262}
263
264NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
265{
266 std::vector<int32_t> shape = { (int32_t)elems };
267 return writeToNpyFile(filename, shape, databuf);
268}
269
270NumpyUtilities::NPError
271 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
272{
273 const char dtype_str[] = "'|b1'";
274 return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
275}
276
277NumpyUtilities::NPError
278 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
279{
280 std::vector<int32_t> shape = { (int32_t)elems };
281 return writeToNpyFile(filename, shape, databuf);
282}
283
284NumpyUtilities::NPError
285 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
286{
287 const char dtype_str[] = "'<i4'";
288 return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
289}
290
291NumpyUtilities::NPError
292 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
293{
294 std::vector<int32_t> shape = { (int32_t)elems };
295 return writeToNpyFile(filename, shape, databuf);
296}
297
298NumpyUtilities::NPError
299 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
300{
301 const char dtype_str[] = "'<i8'";
302 return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
303}
304
305NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
306{
307 std::vector<int32_t> shape = { (int32_t)elems };
308 return writeToNpyFile(filename, shape, databuf);
309}
310
311NumpyUtilities::NPError
312 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
313{
314 const char dtype_str[] = "'<f4'";
315 return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
316}
317
James Ward485a11d2022-08-05 13:48:37 +0100318NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename,
319 const std::vector<int32_t>& shape,
320 const half_float::half* databuf)
321{
322 const char dtype_str[] = "'<f2'";
323 return writeToNpyFileCommon(filename, dtype_str, sizeof(half_float::half), shape, databuf, false);
324}
325
Eric Kunze2364dcd2021-04-26 11:06:57 -0700326NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
327 const char* dtype_str,
328 const size_t elementsize,
329 const std::vector<int32_t>& shape,
330 const void* databuf,
331 bool bool_translate)
332{
333 FILE* outfile = nullptr;
334 NPError rc = NO_ERROR;
335 uint32_t totalElems = 1;
336
337 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700338 assert(databuf);
339
340 outfile = fopen(filename, "wb");
341
342 if (!outfile)
343 {
344 return FILE_NOT_FOUND;
345 }
346
347 for (uint32_t i = 0; i < shape.size(); i++)
348 {
349 totalElems *= shape[i];
350 }
351
352 rc = writeNpyHeader(outfile, shape, dtype_str);
353
354 if (rc == NO_ERROR)
355 {
356 if (bool_translate)
357 {
358 // Numpy save format stores booleans as a byte array
359 // with one byte per boolean. This somewhat inefficiently
360 // remaps from system bool[] to this format.
361 const bool* buf = reinterpret_cast<const bool*>(databuf);
362 for (uint32_t i = 0; i < totalElems; i++)
363 {
364 int val = buf[i] ? 1 : 0;
365 if (fputc(val, outfile) == EOF)
366 {
367 rc = FILE_IO_ERROR;
368 }
369 }
370 }
371 else
372 {
373 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
374 {
375 rc = FILE_IO_ERROR;
376 }
377 }
378 }
379
380 if (outfile)
381 fclose(outfile);
382
383 return rc;
384}
385
386NumpyUtilities::NPError
387 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
388{
389 NPError rc = NO_ERROR;
390 uint32_t i;
391 char header[NUMPY_HEADER_SZ + 1];
392 int headerPos = 0;
393
394 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700395
396 // Space-fill the header and end with a newline to start per numpy spec
397 memset(header, 0x20, NUMPY_HEADER_SZ);
398 header[NUMPY_HEADER_SZ - 1] = '\n';
399 header[NUMPY_HEADER_SZ] = 0;
400
401 // Write out the hard-coded header. We only support a 128-byte 1.0 header
402 // for now, which should be sufficient for simple tensor types of any
403 // reasonable rank.
404 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
405 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
406
407 // Output the format dictionary
408 // Hard-coded for I32 for now
409 headerPos +=
410 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
411 dtype_str, shape.empty() ? 1 : shape[0]);
412
413 // Remainder of shape array
414 for (i = 1; i < shape.size(); i++)
415 {
416 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
417 }
418
419 // Close off the dictionary
420 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
421
422 // snprintf leaves a NULL at the end. Replace with a space
423 header[headerPos] = 0x20;
424
425 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
426 {
427 rc = FILE_IO_ERROR;
428 }
429
430 return rc;
431}