blob: 80c680f3ae1b609d89a0354d06ef16b926154a2e [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
17
// Fixed prefix of every .npy file handled here: the magic bytes "\x93NUMPY",
// two version bytes (\x1, \x0), two header-length bytes (\x76, \x0) and the
// opening brace of the header dictionary.
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
// Total size in bytes of the header block that is read and written.
static constexpr int NUMPY_HEADER_SZ = 128;
// Upper bound on the number of shape dimensions parsed from a header.
static constexpr int NUMPY_MAX_DIMS_SUPPORTED = 10;
Eric Kunze2364dcd2021-04-26 11:06:57 -070023
24NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
25{
26 const char dtype_str[] = "'|b1'";
27 return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
28}
29
30NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
31{
32 const char dtype_str[] = "'<i4'";
33 return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false);
34}
35
36NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
37{
38 const char dtype_str[] = "'<i8'";
39 return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
40}
41
42NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
43{
44 const char dtype_str[] = "'<f4'";
45 return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
46}
47
48NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
49 const char* dtype_str,
50 const size_t elementsize,
51 const uint32_t elems,
52 void* databuf,
53 bool bool_translate)
54{
55 FILE* infile = nullptr;
56 NPError rc = NO_ERROR;
57
58 assert(filename);
59 assert(databuf);
60
61 infile = fopen(filename, "rb");
62 if (!infile)
63 {
64 return FILE_NOT_FOUND;
65 }
66
67 rc = checkNpyHeader(infile, elems, dtype_str);
68 if (rc == NO_ERROR)
69 {
70 if (bool_translate)
71 {
72 // Read in the data from numpy byte array to native bool
73 // array format
74 bool* buf = reinterpret_cast<bool*>(databuf);
75 for (uint32_t i = 0; i < elems; i++)
76 {
77 int val = fgetc(infile);
78
79 if (val == EOF)
80 {
81 rc = FILE_IO_ERROR;
82 }
83
84 buf[i] = val;
85 }
86 }
87 else
88 {
89 // Now we are at the beginning of the data
90 // Parse based on the datatype and number of dimensions
91 if (fread(databuf, elementsize, elems, infile) != elems)
92 {
93 rc = FILE_IO_ERROR;
94 }
95 }
96 }
97
98 if (infile)
99 fclose(infile);
100
101 return rc;
102}
103
104NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
105{
106 char buf[NUMPY_HEADER_SZ + 1];
107 char* ptr = nullptr;
108 NPError rc = NO_ERROR;
109 bool foundFormat = false;
110 bool foundOrder = false;
111 bool foundShape = false;
112 bool fortranOrder = false;
113 std::vector<int> shape;
114 uint32_t totalElems = 1;
115 char* outer_end = NULL;
116
117 assert(infile);
118 assert(elems > 0);
119
120 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
121 {
122 return HEADER_PARSE_ERROR;
123 }
124
125 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
126 {
127 return HEADER_PARSE_ERROR;
128 }
129
130 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
131
132 // Read in the data type, order, and shape
133 while (ptr && (!foundFormat || !foundOrder || !foundShape))
134 {
135
136 // End of string?
137 if (!ptr)
138 break;
139
140 // Skip whitespace
141 while (isspace(*ptr))
142 ptr++;
143
144 // Parse the dictionary field name
145 if (!strcmp(ptr, "'descr'"))
146 {
147 ptr = strtok_r(NULL, ",", &outer_end);
148 if (!ptr)
149 break;
150
151 while (isspace(*ptr))
152 ptr++;
153
154 if (strcmp(ptr, dtype_str))
155 {
156 return FILE_TYPE_MISMATCH;
157 }
158
159 foundFormat = true;
160 }
161 else if (!strcmp(ptr, "'fortran_order'"))
162 {
163 ptr = strtok_r(NULL, ",", &outer_end);
164 if (!ptr)
165 break;
166
167 while (isspace(*ptr))
168 ptr++;
169
170 if (!strcmp(ptr, "False"))
171 {
172 fortranOrder = false;
173 }
174 else
175 {
176 return FILE_TYPE_MISMATCH;
177 }
178
179 foundOrder = true;
180 }
181 else if (!strcmp(ptr, "'shape'"))
182 {
183
184 ptr = strtok_r(NULL, "(", &outer_end);
185 if (!ptr)
186 break;
187 ptr = strtok_r(NULL, ")", &outer_end);
188 if (!ptr)
189 break;
190
191 while (isspace(*ptr))
192 ptr++;
193
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100194 // The shape contains N comma-separated integers. Read up to MAX_DIMS.
Eric Kunze2364dcd2021-04-26 11:06:57 -0700195 char* end = NULL;
196
197 ptr = strtok_r(ptr, ",", &end);
Jeremy Johnson82dbb322021-07-08 11:53:04 +0100198 for (int i = 0; i < NUMPY_MAX_DIMS_SUPPORTED; i++)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700199 {
200 // Out of dimensions
201 if (!ptr)
202 break;
203
204 int dim = atoi(ptr);
205
206 // Dimension is 0
207 if (dim == 0)
208 break;
209
210 shape.push_back(dim);
211 totalElems *= dim;
212 ptr = strtok_r(NULL, ",", &end);
213 }
214
215 foundShape = true;
216 }
217 else
218 {
219 return HEADER_PARSE_ERROR;
220 }
221
222 if (!ptr)
223 break;
224
225 ptr = strtok_r(NULL, ":", &outer_end);
226 }
227
228 if (!foundShape || !foundFormat || !foundOrder)
229 {
230 return HEADER_PARSE_ERROR;
231 }
232
233 // Validate header
234 if (fortranOrder)
235 {
236 return FILE_TYPE_MISMATCH;
237 }
238
239 if (totalElems != elems)
240 {
241 return BUFFER_SIZE_MISMATCH;
242 }
243
244 // Go back to the begininng and read until the end of the header dictionary
245 rewind(infile);
246 int val;
247
248 do
249 {
250 val = fgetc(infile);
251 } while (val != EOF && val != '\n');
252
253 return rc;
254}
255
256NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
257{
258 std::vector<int32_t> shape = { (int32_t)elems };
259 return writeToNpyFile(filename, shape, databuf);
260}
261
262NumpyUtilities::NPError
263 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
264{
265 const char dtype_str[] = "'|b1'";
266 return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
267}
268
269NumpyUtilities::NPError
270 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
271{
272 std::vector<int32_t> shape = { (int32_t)elems };
273 return writeToNpyFile(filename, shape, databuf);
274}
275
276NumpyUtilities::NPError
277 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
278{
279 const char dtype_str[] = "'<i4'";
280 return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
281}
282
283NumpyUtilities::NPError
284 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
285{
286 std::vector<int32_t> shape = { (int32_t)elems };
287 return writeToNpyFile(filename, shape, databuf);
288}
289
290NumpyUtilities::NPError
291 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
292{
293 const char dtype_str[] = "'<i8'";
294 return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
295}
296
297NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
298{
299 std::vector<int32_t> shape = { (int32_t)elems };
300 return writeToNpyFile(filename, shape, databuf);
301}
302
303NumpyUtilities::NPError
304 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
305{
306 const char dtype_str[] = "'<f4'";
307 return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
308}
309
310NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
311 const char* dtype_str,
312 const size_t elementsize,
313 const std::vector<int32_t>& shape,
314 const void* databuf,
315 bool bool_translate)
316{
317 FILE* outfile = nullptr;
318 NPError rc = NO_ERROR;
319 uint32_t totalElems = 1;
320
321 assert(filename);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700322 assert(databuf);
323
324 outfile = fopen(filename, "wb");
325
326 if (!outfile)
327 {
328 return FILE_NOT_FOUND;
329 }
330
331 for (uint32_t i = 0; i < shape.size(); i++)
332 {
333 totalElems *= shape[i];
334 }
335
336 rc = writeNpyHeader(outfile, shape, dtype_str);
337
338 if (rc == NO_ERROR)
339 {
340 if (bool_translate)
341 {
342 // Numpy save format stores booleans as a byte array
343 // with one byte per boolean. This somewhat inefficiently
344 // remaps from system bool[] to this format.
345 const bool* buf = reinterpret_cast<const bool*>(databuf);
346 for (uint32_t i = 0; i < totalElems; i++)
347 {
348 int val = buf[i] ? 1 : 0;
349 if (fputc(val, outfile) == EOF)
350 {
351 rc = FILE_IO_ERROR;
352 }
353 }
354 }
355 else
356 {
357 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
358 {
359 rc = FILE_IO_ERROR;
360 }
361 }
362 }
363
364 if (outfile)
365 fclose(outfile);
366
367 return rc;
368}
369
370NumpyUtilities::NPError
371 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
372{
373 NPError rc = NO_ERROR;
374 uint32_t i;
375 char header[NUMPY_HEADER_SZ + 1];
376 int headerPos = 0;
377
378 assert(outfile);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700379
380 // Space-fill the header and end with a newline to start per numpy spec
381 memset(header, 0x20, NUMPY_HEADER_SZ);
382 header[NUMPY_HEADER_SZ - 1] = '\n';
383 header[NUMPY_HEADER_SZ] = 0;
384
385 // Write out the hard-coded header. We only support a 128-byte 1.0 header
386 // for now, which should be sufficient for simple tensor types of any
387 // reasonable rank.
388 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
389 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
390
391 // Output the format dictionary
392 // Hard-coded for I32 for now
393 headerPos +=
394 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
395 dtype_str, shape.empty() ? 1 : shape[0]);
396
397 // Remainder of shape array
398 for (i = 1; i < shape.size(); i++)
399 {
400 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
401 }
402
403 // Close off the dictionary
404 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
405
406 // snprintf leaves a NULL at the end. Replace with a space
407 header[headerPos] = 0x20;
408
409 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
410 {
411 rc = FILE_IO_ERROR;
412 }
413
414 return rc;
415}