blob: e438235fe7ab14ed88901dc8422da17f40c6aae3 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "numpy_utils.h"
17
// Fixed preamble of every .npy file this module reads or writes:
//   bytes 0-5  : magic "\x93NUMPY"
//   bytes 6-7  : format version 1.0
//   bytes 8-9  : little-endian header-dictionary length, 0x0076 = 118
//   byte  10   : the dictionary's opening '{'
// The 10-byte prefix plus the 118-byte dictionary gives a 128-byte header.
static constexpr char NUMPY_HEADER_STR[] = "\x93NUMPY\x1\x0\x76\x0{";
static constexpr int NUMPY_HEADER_SZ     = 128;
21
22NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
23{
24 const char dtype_str[] = "'|b1'";
25 return readFromNpyFileCommon(filename, dtype_str, 1, elems, databuf, true);
26}
27
28NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
29{
30 const char dtype_str[] = "'<i4'";
31 return readFromNpyFileCommon(filename, dtype_str, sizeof(int32_t), elems, databuf, false);
32}
33
34NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
35{
36 const char dtype_str[] = "'<i8'";
37 return readFromNpyFileCommon(filename, dtype_str, sizeof(int64_t), elems, databuf, false);
38}
39
40NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
41{
42 const char dtype_str[] = "'<f4'";
43 return readFromNpyFileCommon(filename, dtype_str, sizeof(float), elems, databuf, false);
44}
45
46NumpyUtilities::NPError NumpyUtilities::readFromNpyFileCommon(const char* filename,
47 const char* dtype_str,
48 const size_t elementsize,
49 const uint32_t elems,
50 void* databuf,
51 bool bool_translate)
52{
53 FILE* infile = nullptr;
54 NPError rc = NO_ERROR;
55
56 assert(filename);
57 assert(databuf);
58
59 infile = fopen(filename, "rb");
60 if (!infile)
61 {
62 return FILE_NOT_FOUND;
63 }
64
65 rc = checkNpyHeader(infile, elems, dtype_str);
66 if (rc == NO_ERROR)
67 {
68 if (bool_translate)
69 {
70 // Read in the data from numpy byte array to native bool
71 // array format
72 bool* buf = reinterpret_cast<bool*>(databuf);
73 for (uint32_t i = 0; i < elems; i++)
74 {
75 int val = fgetc(infile);
76
77 if (val == EOF)
78 {
79 rc = FILE_IO_ERROR;
80 }
81
82 buf[i] = val;
83 }
84 }
85 else
86 {
87 // Now we are at the beginning of the data
88 // Parse based on the datatype and number of dimensions
89 if (fread(databuf, elementsize, elems, infile) != elems)
90 {
91 rc = FILE_IO_ERROR;
92 }
93 }
94 }
95
96 if (infile)
97 fclose(infile);
98
99 return rc;
100}
101
102NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
103{
104 char buf[NUMPY_HEADER_SZ + 1];
105 char* ptr = nullptr;
106 NPError rc = NO_ERROR;
107 bool foundFormat = false;
108 bool foundOrder = false;
109 bool foundShape = false;
110 bool fortranOrder = false;
111 std::vector<int> shape;
112 uint32_t totalElems = 1;
113 char* outer_end = NULL;
114
115 assert(infile);
116 assert(elems > 0);
117
118 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
119 {
120 return HEADER_PARSE_ERROR;
121 }
122
123 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
124 {
125 return HEADER_PARSE_ERROR;
126 }
127
128 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
129
130 // Read in the data type, order, and shape
131 while (ptr && (!foundFormat || !foundOrder || !foundShape))
132 {
133
134 // End of string?
135 if (!ptr)
136 break;
137
138 // Skip whitespace
139 while (isspace(*ptr))
140 ptr++;
141
142 // Parse the dictionary field name
143 if (!strcmp(ptr, "'descr'"))
144 {
145 ptr = strtok_r(NULL, ",", &outer_end);
146 if (!ptr)
147 break;
148
149 while (isspace(*ptr))
150 ptr++;
151
152 if (strcmp(ptr, dtype_str))
153 {
154 return FILE_TYPE_MISMATCH;
155 }
156
157 foundFormat = true;
158 }
159 else if (!strcmp(ptr, "'fortran_order'"))
160 {
161 ptr = strtok_r(NULL, ",", &outer_end);
162 if (!ptr)
163 break;
164
165 while (isspace(*ptr))
166 ptr++;
167
168 if (!strcmp(ptr, "False"))
169 {
170 fortranOrder = false;
171 }
172 else
173 {
174 return FILE_TYPE_MISMATCH;
175 }
176
177 foundOrder = true;
178 }
179 else if (!strcmp(ptr, "'shape'"))
180 {
181
182 ptr = strtok_r(NULL, "(", &outer_end);
183 if (!ptr)
184 break;
185 ptr = strtok_r(NULL, ")", &outer_end);
186 if (!ptr)
187 break;
188
189 while (isspace(*ptr))
190 ptr++;
191
192 // The shape contains N comma-separated integers. Read up to 4.
193 char* end = NULL;
194
195 ptr = strtok_r(ptr, ",", &end);
196 for (int i = 0; i < 4; i++)
197 {
198 // Out of dimensions
199 if (!ptr)
200 break;
201
202 int dim = atoi(ptr);
203
204 // Dimension is 0
205 if (dim == 0)
206 break;
207
208 shape.push_back(dim);
209 totalElems *= dim;
210 ptr = strtok_r(NULL, ",", &end);
211 }
212
213 foundShape = true;
214 }
215 else
216 {
217 return HEADER_PARSE_ERROR;
218 }
219
220 if (!ptr)
221 break;
222
223 ptr = strtok_r(NULL, ":", &outer_end);
224 }
225
226 if (!foundShape || !foundFormat || !foundOrder)
227 {
228 return HEADER_PARSE_ERROR;
229 }
230
231 // Validate header
232 if (fortranOrder)
233 {
234 return FILE_TYPE_MISMATCH;
235 }
236
237 if (totalElems != elems)
238 {
239 return BUFFER_SIZE_MISMATCH;
240 }
241
242 // Go back to the begininng and read until the end of the header dictionary
243 rewind(infile);
244 int val;
245
246 do
247 {
248 val = fgetc(infile);
249 } while (val != EOF && val != '\n');
250
251 return rc;
252}
253
254NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
255{
256 std::vector<int32_t> shape = { (int32_t)elems };
257 return writeToNpyFile(filename, shape, databuf);
258}
259
260NumpyUtilities::NPError
261 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
262{
263 const char dtype_str[] = "'|b1'";
264 return writeToNpyFileCommon(filename, dtype_str, 1, shape, databuf, true); // bools written as size 1
265}
266
267NumpyUtilities::NPError
268 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
269{
270 std::vector<int32_t> shape = { (int32_t)elems };
271 return writeToNpyFile(filename, shape, databuf);
272}
273
274NumpyUtilities::NPError
275 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
276{
277 const char dtype_str[] = "'<i4'";
278 return writeToNpyFileCommon(filename, dtype_str, sizeof(int32_t), shape, databuf, false);
279}
280
281NumpyUtilities::NPError
282 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
283{
284 std::vector<int32_t> shape = { (int32_t)elems };
285 return writeToNpyFile(filename, shape, databuf);
286}
287
288NumpyUtilities::NPError
289 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
290{
291 const char dtype_str[] = "'<i8'";
292 return writeToNpyFileCommon(filename, dtype_str, sizeof(int64_t), shape, databuf, false);
293}
294
295NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
296{
297 std::vector<int32_t> shape = { (int32_t)elems };
298 return writeToNpyFile(filename, shape, databuf);
299}
300
301NumpyUtilities::NPError
302 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
303{
304 const char dtype_str[] = "'<f4'";
305 return writeToNpyFileCommon(filename, dtype_str, sizeof(float), shape, databuf, false);
306}
307
308NumpyUtilities::NPError NumpyUtilities::writeToNpyFileCommon(const char* filename,
309 const char* dtype_str,
310 const size_t elementsize,
311 const std::vector<int32_t>& shape,
312 const void* databuf,
313 bool bool_translate)
314{
315 FILE* outfile = nullptr;
316 NPError rc = NO_ERROR;
317 uint32_t totalElems = 1;
318
319 assert(filename);
320 assert(shape.size() >= 0);
321 assert(databuf);
322
323 outfile = fopen(filename, "wb");
324
325 if (!outfile)
326 {
327 return FILE_NOT_FOUND;
328 }
329
330 for (uint32_t i = 0; i < shape.size(); i++)
331 {
332 totalElems *= shape[i];
333 }
334
335 rc = writeNpyHeader(outfile, shape, dtype_str);
336
337 if (rc == NO_ERROR)
338 {
339 if (bool_translate)
340 {
341 // Numpy save format stores booleans as a byte array
342 // with one byte per boolean. This somewhat inefficiently
343 // remaps from system bool[] to this format.
344 const bool* buf = reinterpret_cast<const bool*>(databuf);
345 for (uint32_t i = 0; i < totalElems; i++)
346 {
347 int val = buf[i] ? 1 : 0;
348 if (fputc(val, outfile) == EOF)
349 {
350 rc = FILE_IO_ERROR;
351 }
352 }
353 }
354 else
355 {
356 if (fwrite(databuf, elementsize, totalElems, outfile) != totalElems)
357 {
358 rc = FILE_IO_ERROR;
359 }
360 }
361 }
362
363 if (outfile)
364 fclose(outfile);
365
366 return rc;
367}
368
369NumpyUtilities::NPError
370 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
371{
372 NPError rc = NO_ERROR;
373 uint32_t i;
374 char header[NUMPY_HEADER_SZ + 1];
375 int headerPos = 0;
376
377 assert(outfile);
378 assert(shape.size() >= 0);
379
380 // Space-fill the header and end with a newline to start per numpy spec
381 memset(header, 0x20, NUMPY_HEADER_SZ);
382 header[NUMPY_HEADER_SZ - 1] = '\n';
383 header[NUMPY_HEADER_SZ] = 0;
384
385 // Write out the hard-coded header. We only support a 128-byte 1.0 header
386 // for now, which should be sufficient for simple tensor types of any
387 // reasonable rank.
388 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
389 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
390
391 // Output the format dictionary
392 // Hard-coded for I32 for now
393 headerPos +=
394 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
395 dtype_str, shape.empty() ? 1 : shape[0]);
396
397 // Remainder of shape array
398 for (i = 1; i < shape.size(); i++)
399 {
400 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
401 }
402
403 // Close off the dictionary
404 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
405
406 // snprintf leaves a NULL at the end. Replace with a space
407 header[headerPos] = 0x20;
408
409 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
410 {
411 rc = FILE_IO_ERROR;
412 }
413
414 return rc;
415}