blob: ad33d234bf76068ff2540a9533ec06713cec6626 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
#include "tosa_serialization_handler.h"

#include <iostream>
#include <utility>
using namespace tosa;
19using namespace tosa;
20
21TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
22 const flatbuffers::Vector<uint32_t>& usage,
23 const flatbuffers::Vector<int32_t>& shape,
24 DType dtype,
25 const flatbuffers::Vector<uint32_t>& format,
26 const flatbuffers::String* npy_filename)
27{
28 _dtype = dtype;
29
30 _usage = new std::vector<Usage>(usage.size());
31 for (uint32_t us : usage)
32 {
33 _usage->push_back((Usage)us);
34 }
35 assert(_usage);
36
37 _format = new std::vector<Format>(format.size());
38 for (uint32_t fm : format)
39 {
40 _format->push_back((Format)fm);
41 }
42 assert(_format);
43
44 _shape = new std::vector<int32_t>(shape.begin(), shape.end());
45
46 _shape = new std::vector<int32_t>(shape.begin(), shape.end());
47 assert(_shape);
48
49 assert(name);
50 _name = new std::string(name->str());
51 assert(_name);
52
53 if (npy_filename)
54 {
55 _npy_filename = new std::string(npy_filename->str());
56 assert(_npy_filename);
57 }
58 else
59 {
60 _npy_filename = nullptr;
61 }
62}
63
64TosaSerializationTensor::TosaSerializationTensor(std::string name,
65 const std::vector<Usage>& usage,
66 const std::vector<int32_t>& shape,
67 DType dtype,
68 const std::vector<Format>& format,
69 const std::string* npy_filename)
70{
71
72 _dtype = dtype;
73
74 _usage = new std::vector<Usage>(usage);
75 assert(_usage);
76
77 _format = new std::vector<Format>(format);
78 assert(_format);
79
80 _shape = new std::vector<int32_t>(shape);
81 assert(_shape);
82
83 _name = new std::string(name);
84 assert(_name);
85
86 if (npy_filename)
87 {
88 _npy_filename = new std::string(*npy_filename);
89 assert(_npy_filename);
90 }
91 else
92 {
93 _npy_filename = nullptr;
94 }
95}
96
97TosaSerializationTensor::TosaSerializationTensor()
98{
99 _dtype = DType_UNKNOWN;
100
101 _usage = new std::vector<Usage>();
102 _format = new std::vector<Format>();
103 _shape = new std::vector<int32_t>();
104 _name = new std::string("UNKNOWN");
105 assert(_usage && _format && _shape && _name);
106
107 _npy_filename = nullptr;
108}
109
110TosaSerializationTensor::TosaSerializationTensor(const TosaSerializationTensor& rhs)
111{
112 _dtype = rhs._dtype;
113
114 assert(rhs._usage);
115 _usage = new std::vector<Usage>(*rhs._usage);
116 assert(_usage);
117
118 assert(rhs._format);
119 _format = new std::vector<Format>(*rhs._format);
120 assert(_format);
121
122 assert(rhs._shape);
123 _shape = new std::vector<int32_t>(*rhs._shape);
124 assert(_shape);
125
126 assert(rhs._name);
127 _name = new std::string(*rhs._name);
128 assert(_name);
129
130 if (rhs._npy_filename)
131 {
132 _npy_filename = new std::string(*rhs._npy_filename);
133 assert(_npy_filename);
134 }
135 else
136 {
137 _npy_filename = nullptr;
138 }
139}
140
141TosaSerializationTensor& TosaSerializationTensor::operator=(const TosaSerializationTensor& rhs)
142{
143 _dtype = rhs._dtype;
144
145 delete _usage;
146 assert(rhs._usage);
147 _usage = new std::vector<Usage>(*rhs._usage);
148 assert(_usage);
149
150 delete _format;
151 assert(rhs._format);
152 _format = new std::vector<Format>(*rhs._format);
153 assert(_format);
154
155 delete _shape;
156 assert(rhs._shape);
157 _shape = new std::vector<int32_t>(*rhs._shape);
158 assert(_shape);
159
160 delete _name;
161 assert(rhs._name);
162 _name = new std::string(*rhs._name);
163 assert(_name);
164
165 if (_npy_filename)
166 delete _npy_filename;
167
168 if (rhs._npy_filename)
169 {
170 _npy_filename = new std::string(*rhs._npy_filename);
171 }
172 else
173 {
174 _npy_filename = nullptr;
175 }
176 return *this;
177}
178
179TosaSerializationTensor::TosaSerializationTensor(TosaSerializationTensor&& rhs)
180{
181 _dtype = rhs._dtype;
182 std::swap(_format, rhs._format);
183 std::swap(_usage, rhs._usage);
184 std::swap(_shape, rhs._shape);
185 std::swap(_name, rhs._name);
186 std::swap(_npy_filename, rhs._npy_filename);
187}
188
189TosaSerializationTensor& TosaSerializationTensor::operator=(TosaSerializationTensor&& rhs)
190{
191 _dtype = rhs._dtype;
192 std::swap(_format, rhs._format);
193 std::swap(_usage, rhs._usage);
194 std::swap(_shape, rhs._shape);
195 std::swap(_name, rhs._name);
196 std::swap(_npy_filename, rhs._npy_filename);
197 return *this;
198}
199
200TosaSerializationTensor::~TosaSerializationTensor()
201{
202 delete _usage;
203 delete _format;
204 delete _shape;
205 delete _name;
206 if (_npy_filename)
207 delete _npy_filename;
208}
209
210TosaSerializationOperator::TosaSerializationOperator(Op op,
211 Attribute attribute_type,
212 const TosaAttributeBase* attribute,
213 QuantInfo qinfo_type,
214 const TosaQuantInfoBase* qinfo,
215 std::vector<std::string> input_tensor_names,
216 std::vector<std::string> output_tensor_names)
217{
218 _op = op;
219 _attribute_type = attribute_type;
220
221 switch (attribute_type)
222 {
223 case Attribute_NONE:
224 _attribute = new TosaNoneAttribute();
225 break;
226#define DEF_ATTRIBUTE(NAME, ...) \
227 case Attribute_##NAME##Attribute: \
228 _attribute = new Tosa##NAME##Attribute(attribute); \
229 break;
230#include "attribute.def"
231#undef DEF_ATTRIBUTE
232 default:
233 printf("TosaSerializationOperator::TosaSerializationOperator(): Attribute %s not implemented yet\n",
234 EnumNamesAttribute()[attribute_type]);
235 assert(0);
236 }
237
238 _qinfo_type = qinfo_type;
239 switch (qinfo_type)
240 {
241 case QuantInfo_NONE:
242 _qinfo = new TosaNoneQuantInfo();
243 break;
244#define DEF_QUANTIZATION_INFO(NAME, ...) \
245 case QuantInfo_##NAME##QuantInfo: \
246 _qinfo = new Tosa##NAME##QuantInfo(qinfo); \
247 break;
248#include "quant_info.def"
249#undef DEF_QUANTIZATION_INFO
250 default:
251 printf("TosaSerializationOperator::TosaSerializationOperator(): QuantInfo %s not implemented yet\n",
252 EnumNamesQuantInfo()[qinfo_type]);
253 assert(0);
254 }
255
256 assert(_attribute && _qinfo);
257
258 _input_tensor_names = new std::vector<std::string>(input_tensor_names);
259 _output_tensor_names = new std::vector<std::string>(output_tensor_names);
260
261 assert(_input_tensor_names && _output_tensor_names);
262
263 _input_tensors = new std::vector<TosaSerializationTensor*>();
264 _output_tensors = new std::vector<TosaSerializationTensor*>();
265
266 assert(_input_tensors && _output_tensors);
267}
268
269TosaSerializationOperator::~TosaSerializationOperator()
270{
271 delete _attribute;
272 delete _qinfo;
273 delete _input_tensor_names;
274 delete _output_tensor_names;
275 // TosaSerializationTensor should be free'd in TosaSerializationSerializationHandler destructor
276 delete _input_tensors;
277 delete _output_tensors;
278}
279
280TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string name,
281 std::vector<TosaSerializationOperator*> operators,
282 std::vector<TosaSerializationTensor*> tensors,
283 std::vector<std::string> inputs,
284 std::vector<std::string> outputs)
285{
286
287 _name = new std::string(name);
288 assert(_name);
289
290 _operators = new std::vector<TosaSerializationOperator*>(operators);
291 assert(_operators);
292
293 _tensors = new std::vector<TosaSerializationTensor*>(tensors);
294 assert(_tensors);
295
296 _inputs = new std::vector<std::string>(inputs);
297 assert(_inputs);
298
299 _outputs = new std::vector<std::string>(outputs);
300 assert(_outputs);
301}
302
303TosaSerializationBasicBlock::~TosaSerializationBasicBlock()
304{
305 delete _name;
306
307 // deallocate all operators
308 for (auto op : GetOperators())
309 {
310 delete op; // ~TosaSerializationOperator()
311 }
312 delete _operators;
313
314 // deallocate all tensors
315 for (auto ts : GetTensors())
316 {
317 delete ts; // ~TosaSerializationTensor()
318 }
319 _tensors->clear();
320
321 delete _inputs;
322 delete _outputs;
323}
324
325TosaSerializationHandler::TosaSerializationHandler()
326{
327 _schemaLoaded = false;
328 _builder = new flatbuffers::FlatBufferBuilder();
329 _parser = new flatbuffers::Parser();
330 _blocks = new std::vector<TosaSerializationBasicBlock*>();
331
332 assert(_builder && _parser && _blocks);
333
334 SetTosaVersion();
335}
336
337TosaSerializationHandler::~TosaSerializationHandler()
338{
339 if (_version)
340 delete _version;
341 delete _builder;
342 delete _parser;
343
344 Clear(); // deallocate all basic blocks
345
346 delete _blocks;
347}
348
349tosa_err_t TosaSerializationHandler::SetTosaVersion()
350{
351 // version is specified within .fbs
352 // and it's encoded as defaulted value of CreateTosaVersion()
353 // need to write out one object to read that value out
354 // TODO: very costly now. is there any better way to encode constant in .fbs?
355 auto fboffset_version = CreateVersion(*_builder);
356 auto fboffset_tosa_graph = CreateTosaGraphDirect(*_builder, fboffset_version, nullptr);
357 _builder->Finish(fboffset_tosa_graph);
358 std::string jsongen;
359 uint8_t* buf = _builder->GetBufferPointer();
360 auto fb_tosa_graph = GetTosaGraph(buf);
361 auto fb_tosa_version = fb_tosa_graph->version();
362
363 _version = new TosaVersion(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
364 fb_tosa_version->_experimental());
365
366 assert(_version);
367 return TOSA_OK;
368}
369
370tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename)
371{
372 std::string schema;
373 bool ok;
374
375 ok = flatbuffers::LoadFile(schema_filename, false, &schema);
376 if (!ok)
377 {
378 printf("Error loading schema file: %s\n", schema_filename);
379 return TOSA_FILE_ERROR;
380 }
381
382 ok = _parser->Parse(schema.c_str());
383 if (!ok)
384 {
385 printf("Error parsing ISA schema file: %s\n", schema_filename);
386 return TOSA_FILE_ERROR;
387 }
388 _schemaLoaded = true;
389
390 return TOSA_OK;
391}
392
393tosa_err_t TosaSerializationHandler::LoadFileJson(const char* filename)
394{
395 std::string jsonfile;
396 bool ok;
397 tosa_err_t err;
398
399 if (!_schemaLoaded)
400 {
401 return TOSA_SCHEMA_MISSING;
402 }
403
404 ok = flatbuffers::LoadFile(filename, false, &jsonfile);
405 if (!ok)
406 {
407 printf("Error loading json file: %s\n", filename);
408 return TOSA_FILE_ERROR;
409 }
410
411 ok = _parser->Parse(jsonfile.c_str());
412 if (!ok)
413 {
414 printf("Error parsing json file: %s\n", filename);
415 return TOSA_FILE_ERROR;
416 }
417
418 uint8_t* buf = _parser->builder_.GetBufferPointer();
419
420 err = InitWithBuf(buf);
421 if (err != TOSA_OK)
422 {
423 return err;
424 }
425
426 return TOSA_OK;
427}
428
429tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename)
430{
431 std::string jsongen;
432 tosa_err_t err;
433
434 if (!_schemaLoaded)
435 {
436 return TOSA_SCHEMA_MISSING;
437 }
438
439 err = FreezeBuilder();
440 if (err != TOSA_OK)
441 {
442 return err;
443 }
444
445 uint8_t* buf = _builder->GetBufferPointer();
446
447 if (!GenerateText(*_parser, buf, &jsongen))
448 {
449 printf("Couldn't serialize parsed data to JSON!\n");
450 return TOSA_FILE_ERROR;
451 }
452
453 FILE* file = fopen(filename, "wb");
454
455 if (!file)
456 {
457 printf("Couldn't open output file: %s\n", filename);
458 return TOSA_FILE_ERROR;
459 }
460
461 if (fwrite(jsongen.c_str(), sizeof(char), jsongen.size(), file) != jsongen.size())
462 {
463 printf("Error writing to json output file: %s\n", filename);
464 fclose(file);
465 return TOSA_FILE_ERROR;
466 }
467
468 if (file)
469 fclose(file);
470
471 return TOSA_OK;
472}
473
474tosa_err_t TosaSerializationHandler::LoadFileTosaFlatbuffer(const char* filename)
475{
476 std::string read_buffer;
477 tosa_err_t err;
478 uint8_t* buf;
479 bool ok;
480
481 ok = flatbuffers::LoadFile(filename, false, &read_buffer);
482 if (!ok)
483 {
484 printf("Error loading flatbuffer file: %s\n", filename);
485 return TOSA_FILE_ERROR;
486 }
487
488 buf = (uint8_t*)read_buffer.data();
489
490 err = InitWithBuf(buf);
491 if (err != TOSA_OK)
492 {
493 return err;
494 }
495
496 return TOSA_OK;
497}
498
499tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename)
500{
501 tosa_err_t err;
502
503 err = FreezeBuilder();
504 if (err != TOSA_OK)
505 {
506 return err;
507 }
508
509 uint8_t* buf = _builder->GetBufferPointer();
510
511 bool ok = flatbuffers::SaveFile(filename, (const char*)buf, _builder->GetSize(), false);
512 if (!ok)
513 {
514 printf("Error saving floatbuffer file: %s\n", filename);
515 return TOSA_FILE_ERROR;
516 }
517
518 return TOSA_OK;
519}
520
521tosa_err_t TosaSerializationHandler::Clear()
522{
523 // deallocate all basic blocks
524 for (auto bb : GetBlocks())
525 {
526 delete bb;
527 }
528 _blocks->clear();
529
530 return TOSA_OK;
531}
532
533tosa_err_t TosaSerializationHandler::CheckTosaVersion(const TosaVersion& read_version)
534{
535 if ((*_version) != read_version)
536 {
537 printf("WARNING: read tosa version: %s != schema tosa version %s\n", read_version.to_string().c_str(),
538 this->_version->to_string().c_str());
539 return TOSA_VERSION_MISMATCH;
540 }
541
542 return TOSA_OK;
543}
544
545tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
546{
547 auto fb_tosa_graph = GetTosaGraph(buf);
548 auto fb_tosa_version = fb_tosa_graph->version();
549 auto fb_tosa_blocks = fb_tosa_graph->blocks();
550
551 std::vector<std::string> operator_inputs_container;
552 std::vector<std::string> operator_outputs_container;
553
554 std::vector<TosaSerializationOperator*> block_operators_container;
555 std::vector<TosaSerializationTensor*> block_tensors_container;
556 std::vector<std::string> block_inputs_container;
557 std::vector<std::string> block_outputs_container;
558
559 TosaAttributeBase* typed_attribute = NULL;
560 TosaQuantInfoBase* typed_qinfo = NULL;
561 TosaSerializationOperator* new_operator = NULL;
562 TosaSerializationBasicBlock* new_block = NULL;
563 TosaSerializationTensor* new_tensor = NULL;
564
565 // erase container
566 Clear();
567
568 TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
569 fb_tosa_version->_experimental());
570 tosa_err_t err = CheckTosaVersion(read_version);
571
572 if (err != TOSA_OK)
573 return err;
574
575 for (size_t i = 0; i < fb_tosa_blocks->size(); i++)
576 {
577 auto curr_block = fb_tosa_blocks->Get(i);
578
579 auto block_name = curr_block->name()->str();
580
581 auto fb_tosa_operators = curr_block->operators();
582 block_operators_container.clear();
583 for (size_t j = 0; j < fb_tosa_operators->size(); j++)
584 {
585 auto curr_operator = fb_tosa_operators->Get(j);
586
587 auto operator_op = curr_operator->op();
588 auto attribute_type = curr_operator->attribute_type();
589 auto attribute = curr_operator->attribute();
590 auto operator_qinfo_type = curr_operator->quant_info_type();
591 auto operator_qinfo = curr_operator->quant_info();
592
593 // input tensors
594 auto operator_inputs = curr_operator->inputs();
595 operator_inputs_container.clear();
596 if (operator_inputs)
597 {
598 for (size_t k = 0; k < operator_inputs->size(); k++)
599 {
600 auto curr_input = operator_inputs->Get(k);
601 operator_inputs_container.push_back(curr_input->str());
602 }
603 }
604
605 // output tensors
606 auto operator_outputs = curr_operator->outputs();
607 operator_outputs_container.clear();
608 if (operator_outputs)
609 {
610 for (size_t k = 0; k < operator_outputs->size(); k++)
611 {
612 auto curr_output = operator_outputs->Get(k);
613 operator_outputs_container.push_back(curr_output->str());
614 }
615 }
616
617 switch (attribute_type)
618 {
619 case Attribute_NONE:
620 typed_attribute = new TosaNoneAttribute();
621 break;
622#define DEF_ATTRIBUTE(NAME, ...) \
623 case Attribute_##NAME##Attribute: \
624 typed_attribute = new Tosa##NAME##Attribute(attribute); \
625 break;
626#include "attribute.def"
627#undef DEF_ATTRIBUTE
628 default:
629 printf("TosaSerializationHandler::InitWithBuf(): Attribute %s not implemented yet\n",
630 EnumNamesAttribute()[attribute_type]);
631 return TOSA_INTERNAL_ERROR;
632 }
633
634 switch (operator_qinfo_type)
635 {
636 case QuantInfo_NONE:
637 typed_qinfo = new TosaNoneQuantInfo();
638 break;
639#define DEF_QUANTIZATION_INFO(NAME, ...) \
640 case QuantInfo_##NAME##QuantInfo: \
641 typed_qinfo = new Tosa##NAME##QuantInfo(operator_qinfo); \
642 break;
643
644#include "quant_info.def"
645#undef DEF_QUANTIZATION_INFO
646 default:
647 printf("TosaSerializationHandler::InitWithBuf(): QuantInfo %s not implemented yet\n",
648 EnumNamesQuantInfo()[operator_qinfo_type]);
649 return TOSA_INTERNAL_ERROR;
650 }
651
652 new_operator =
653 new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, operator_qinfo_type,
654 typed_qinfo, operator_inputs_container, operator_outputs_container);
655 if (new_operator)
656 {
657 block_operators_container.push_back(new_operator);
658 }
659 else
660 {
661 return TOSA_MEMORY_ERROR;
662 }
663
664 if (typed_attribute)
665 delete typed_attribute;
666 if (typed_qinfo)
667 delete typed_qinfo;
668 }
669
670 auto fb_tosa_tensors = curr_block->tensors();
671 block_tensors_container.clear();
672 for (size_t j = 0; j < fb_tosa_tensors->size(); j++)
673 {
674 auto curr_tensor = fb_tosa_tensors->Get(j);
675
676 auto tensor_name = curr_tensor->name();
677 auto tensor_usage = curr_tensor->usage();
678 auto tensor_shape = curr_tensor->shape();
679 auto tensor_type = curr_tensor->type();
680 auto tensor_format = curr_tensor->format();
681 auto tensor_npy_filename = curr_tensor->npy_filename();
682
683 new_tensor = new TosaSerializationTensor(tensor_name, *tensor_usage, *tensor_shape, tensor_type,
684 *tensor_format, tensor_npy_filename);
685 if (new_tensor)
686 {
687 block_tensors_container.push_back(new_tensor);
688 }
689 else
690 {
691 return TOSA_MEMORY_ERROR;
692 }
693 }
694
695 auto block_inputs = curr_block->inputs();
696 auto block_outputs = curr_block->outputs();
697
698 block_inputs_container.clear();
699 block_outputs_container.clear();
700
701 for (size_t j = 0; j < block_inputs->size(); j++)
702 {
703 auto curr_block_input = block_inputs->Get(j);
704 block_inputs_container.push_back(curr_block_input->str());
705 }
706 for (size_t j = 0; j < block_outputs->size(); j++)
707 {
708 auto curr_block_output = block_outputs->Get(j);
709 block_outputs_container.push_back(curr_block_output->str());
710 }
711
712 new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container,
713 block_inputs_container, block_outputs_container);
714 if (new_block)
715 {
716 this->GetBlocks().push_back(new_block);
717 }
718 else
719 {
720 return TOSA_MEMORY_ERROR;
721 }
722 }
723
724 return TOSA_OK;
725}
726
727tosa_err_t TosaSerializationHandler::FreezeBuilder()
728{
729 std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
730
731 std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators;
732 std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors;
733 std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs;
734 std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs;
735
736 std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs;
737 std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs;
738
739 // translate TosaFlatbufferOperator to flatbuffers::Offset<TosaOperator>
740 for (auto block : GetBlocks())
741 {
742 fboffset_block_operators.clear();
743 fboffset_block_tensors.clear();
744 fboffset_block_inputs.clear();
745 fboffset_block_outputs.clear();
746
747 auto block_name = _builder->CreateString(block->GetName().c_str());
748
749 for (auto tensor_str : block->GetInputs())
750 {
751 auto tensor_name = _builder->CreateString(tensor_str.c_str());
752 fboffset_block_inputs.push_back(tensor_name);
753 }
754
755 for (auto tensor_str : block->GetOutputs())
756 {
757 auto tensor_name = _builder->CreateString(tensor_str.c_str());
758 fboffset_block_outputs.push_back(tensor_name);
759 }
760
761 auto fb_block_inputs = _builder->CreateVector(fboffset_block_inputs);
762 auto fb_block_outputs = _builder->CreateVector(fboffset_block_outputs);
763
764 for (auto op : block->GetOperators())
765 {
766 fboffset_operator_inputs.clear();
767 fboffset_operator_outputs.clear();
768
769 auto operator_op = op->GetOp();
770 auto attribute_type = op->GetAttributeType();
771
772 for (auto tensor_str : op->GetInputTensorNames())
773 {
774 auto tensor_name = _builder->CreateString(tensor_str.c_str());
775 fboffset_operator_inputs.push_back(tensor_name);
776 }
777
778 for (auto tensor_str : op->GetOutputTensorNames())
779 {
780 auto tensor_name = _builder->CreateString(tensor_str.c_str());
781 fboffset_operator_outputs.push_back(tensor_name);
782 }
783
784 auto fb_operator_inputs = _builder->CreateVector(fboffset_operator_inputs);
785 auto fb_operator_outputs = _builder->CreateVector(fboffset_operator_outputs);
786
787 flatbuffers::Offset<void> fb_attribute;
788 switch (attribute_type)
789 {
790 case Attribute_NONE:
791 fb_attribute = 0;
792 break;
793
794#define DEF_ARGS_S_STR(NAME, V) , _builder->CreateString(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V().c_str())
795#define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V()
796
797#define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
798#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
799#define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
800#define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
801#define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V)
802
803#define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V)
804#define DEF_ARGS_V(NAME, T, V) , _builder->CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V())
805
806#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
807#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
808#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
809 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2)
810#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
811 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3)
812#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
813 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
814 DEF_ARGS_##F4(NAME, T4, V4)
815#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
816 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
817 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5)
818#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
819 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
820 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
821#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \
822 case Attribute_##NAME##Attribute: \
823 fb_attribute = Create##NAME##Attribute(*_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \
824 break;
825
826#include "attribute.def"
827#undef DEF_ATTRIBUTE
828#undef DEF_ARGS_1
829#undef DEF_ARGS_2
830#undef DEF_ARGS_3
831#undef DEF_ARGS_4
832#undef DEF_ARGS_5
833#undef DEF_ARGS_6
834#undef DEF_ARGS_7
835#undef DEF_ARGS_S
836#undef DEF_ARGS_V
837#undef DEF_ARGS_S_int32_t
838#undef DEF_ARGS_S_float
839#undef DEF_ARGS_S_bool
840#undef DEF_ARGS_S_ResizeMode
841#undef DEF_ARGS_S_string
842#undef DEF_ARGS_S_STR
843#undef DEF_ARGS_S_DEFAULT
844 default:
845 printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
846 EnumNamesAttribute()[attribute_type]);
847 return TOSA_INTERNAL_ERROR;
848 }
849
850 auto qinfo_type = op->GetQInfoType();
851 flatbuffers::Offset<void> fb_operator_qinfo;
852 switch (qinfo_type)
853 {
854 case QuantInfo_NONE:
855 fb_operator_qinfo = 0;
856 break;
857#define DEF_ARGS_S(NAME, T, V) , reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V()
858#define DEF_ARGS_V(NAME, T, V) , _builder->CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V())
859
860#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
861#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
862#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
863 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2)
864#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
865 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3)
866#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
867 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
868 DEF_ARGS_##F4(NAME, T4, V4)
869#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
870 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
871 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5)
872#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
873 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
874 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
875#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
876 V7) \
877 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
878 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
879 DEF_ARGS_##F7(NAME, T7, V7)
880#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
881 V7, T8, F8, V8) \
882 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
883 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
884 DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8)
885#define DEF_ARGS_10(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
886 V7, T8, F8, V8, T9, F9, V9) \
887 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
888 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
889 DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) DEF_ARGS_##F9(NAME, T9, V9)
890#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \
891 case QuantInfo_##NAME##QuantInfo: \
892 fb_operator_qinfo = \
893 Create##NAME##QuantInfo(*_builder DEF_ARGS_##NUM_ARGS(NAME##QuantInfo, __VA_ARGS__)).Union(); \
894 break;
895
896#include "quant_info.def"
897#undef DEF_QUANTIZATION_INFO
898#undef DEF_ARGS_1
899#undef DEF_ARGS_2
900#undef DEF_ARGS_3
901#undef DEF_ARGS_4
902#undef DEF_ARGS_5
903#undef DEF_ARGS_6
904#undef DEF_ARGS_7
905#undef DEF_ARGS_8
906#undef DEF_ARGS_9
907#undef DEF_ARGS_10
908#undef DEF_ARGS_S
909#undef DEF_ARGS_V
910 default:
911 printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
912 EnumNamesAttribute()[attribute_type]);
913 return TOSA_INTERNAL_ERROR;
914 }
915
916 auto fboffset_operator =
917 CreateTosaOperator(*_builder, operator_op, attribute_type, fb_attribute, fb_operator_inputs,
918 fb_operator_outputs, qinfo_type, fb_operator_qinfo);
919 fboffset_block_operators.push_back(fboffset_operator);
920 }
921
922 auto fb_block_operators = _builder->CreateVector(fboffset_block_operators);
923
924 for (auto tensor : block->GetTensors())
925 {
926
927 auto tensor_name = _builder->CreateString(tensor->GetName().c_str());
928 auto tensor_usage =
929 _builder->CreateVector(std::vector<uint32_t>(tensor->GetUsage().begin(), tensor->GetUsage().end()));
930 auto tensor_shape = _builder->CreateVector(tensor->GetShape());
931 auto tensor_dtype = tensor->GetDtype();
932 auto tensor_format =
933 _builder->CreateVector(std::vector<uint32_t>(tensor->GetFormat().begin(), tensor->GetFormat().end()));
934 flatbuffers::Offset<flatbuffers::String> tensor_npy_filename = 0;
935 if (tensor->GetNpyFilePtr())
936 tensor_npy_filename = _builder->CreateString(tensor->GetNpyFilePtr()->c_str());
937
938 auto fboffset_tensor = CreateTosaTensor(*_builder, tensor_name, tensor_shape, tensor_dtype, tensor_usage,
939 tensor_format, tensor_npy_filename);
940 fboffset_block_tensors.push_back(fboffset_tensor);
941 }
942
943 auto fb_block_tensors = _builder->CreateVector(fboffset_block_tensors);
944
945 auto fboffset_block = CreateTosaBasicBlock(*_builder, block_name, fb_block_operators, fb_block_tensors,
946 fb_block_inputs, fb_block_outputs);
947 fboffset_blocks.push_back(fboffset_block);
948 }
949
950 auto fb_blocks = _builder->CreateVector(fboffset_blocks);
951
952 auto fb_version = CreateVersion(*_builder, GetTosaVersion()->_major, GetTosaVersion()->_minor,
953 GetTosaVersion()->_patch, GetTosaVersion()->_experimental);
954
955 auto fb_graph = CreateTosaGraph(*_builder, fb_version, fb_blocks);
956 _builder->Finish(fb_graph);
957
958 return TOSA_OK;
959}
960
// Magic NUMPY header: the fixed prefix of a version 1.0 .npy file followed by the
// opening brace of its header dictionary.
//   "\x93" "NUMPY" - 6-byte magic string
//   "\x01\x00"     - format version 1.0
//   "\x76\x00"     - little-endian header length: 0x76 = 118 bytes
//   "{"            - start of the header dictionary
// The 10 prefix bytes plus the 118 header bytes equal NUMPY_HEADER_SZ (128).
static constexpr char NUMPY_HEADER_STR[] = "\x93"
                                           "NUMPY"
                                           "\x01\x00"
                                           "\x76\x00"
                                           "{";
static constexpr int NUMPY_HEADER_SZ = 128;
964
965NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, bool* databuf)
966{
967 const char dtype_str[] = "'|b1'";
968 FILE* infile = nullptr;
969 NPError rc = NO_ERROR;
970
971 assert(filename);
972 assert(databuf);
973
974 infile = fopen(filename, "rb");
975 if (!infile)
976 {
977 rc = FILE_NOT_FOUND;
978 goto done;
979 }
980
981 rc = checkNpyHeader(infile, elems, dtype_str);
982 if (rc != NO_ERROR)
983 {
984 goto done;
985 }
986
987 // Read in the data from numpy byte array to native bool
988 // array format
989 for (uint32_t i = 0; i < elems; i++)
990 {
991 int val = fgetc(infile);
992
993 if (val == EOF)
994 {
995 rc = FILE_IO_ERROR;
996 goto done;
997 }
998
999 databuf[i] = val;
1000 }
1001
1002done:
1003
1004 if (infile)
1005 fclose(infile);
1006
1007 return rc;
1008}
1009
1010NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int32_t* databuf)
1011{
1012 const char dtype_str[] = "'<i4'";
1013 FILE* infile = nullptr;
1014 NPError rc = NO_ERROR;
1015
1016 assert(filename);
1017 assert(databuf);
1018
1019 infile = fopen(filename, "rb");
1020 if (!infile)
1021 {
1022 rc = FILE_NOT_FOUND;
1023 goto done;
1024 }
1025
1026 rc = checkNpyHeader(infile, elems, dtype_str);
1027 if (rc != NO_ERROR)
1028 {
1029 goto done;
1030 }
1031
1032 // Now we are at the beginning of the data
1033 // Parse based on the datatype and number of dimensions
1034 if (fread(databuf, sizeof(int32_t), elems, infile) != elems)
1035 {
1036 rc = FILE_IO_ERROR;
1037 goto done;
1038 }
1039
1040done:
1041
1042 if (infile)
1043 fclose(infile);
1044
1045 return rc;
1046}
1047
1048NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, int64_t* databuf)
1049{
1050 const char dtype_str[] = "'<i8'";
1051 FILE* infile = nullptr;
1052 NPError rc = NO_ERROR;
1053
1054 assert(filename);
1055 assert(databuf);
1056
1057 infile = fopen(filename, "rb");
1058 if (!infile)
1059 {
1060 rc = FILE_NOT_FOUND;
1061 goto done;
1062 }
1063
1064 rc = checkNpyHeader(infile, elems, dtype_str);
1065 if (rc != NO_ERROR)
1066 {
1067 goto done;
1068 }
1069
1070 // Now we are at the beginning of the data
1071 // Parse based on the datatype and number of dimensions
1072 if (fread(databuf, sizeof(int64_t), elems, infile) != elems)
1073 {
1074 rc = FILE_IO_ERROR;
1075 goto done;
1076 }
1077
1078done:
1079
1080 if (infile)
1081 fclose(infile);
1082
1083 return rc;
1084}
1085
1086NumpyUtilities::NPError NumpyUtilities::readFromNpyFile(const char* filename, const uint32_t elems, float* databuf)
1087{
1088 const char dtype_str[] = "'<f4'";
1089 FILE* infile = nullptr;
1090 NPError rc = NO_ERROR;
1091
1092 assert(filename);
1093 assert(databuf);
1094
1095 infile = fopen(filename, "rb");
1096 if (!infile)
1097 {
1098 rc = FILE_NOT_FOUND;
1099 goto done;
1100 }
1101
1102 rc = checkNpyHeader(infile, elems, dtype_str);
1103 if (rc != NO_ERROR)
1104 {
1105 goto done;
1106 }
1107
1108 // Now we are at the beginning of the data
1109 // Parse based on the datatype and number of dimensions
1110 if (fread(databuf, sizeof(float), elems, infile) != elems)
1111 {
1112 rc = FILE_IO_ERROR;
1113 goto done;
1114 }
1115
1116done:
1117
1118 if (infile)
1119 fclose(infile);
1120
1121 return rc;
1122}
1123
1124NumpyUtilities::NPError NumpyUtilities::checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str)
1125{
1126 char buf[NUMPY_HEADER_SZ + 1];
1127 char* ptr = nullptr;
1128 NPError rc = NO_ERROR;
1129 bool foundFormat = false;
1130 bool foundOrder = false;
1131 bool foundShape = false;
1132 bool fortranOrder = false;
1133 std::vector<int> shape;
1134 uint32_t totalElems = 1;
1135 char* outer_end = NULL;
1136
1137 assert(infile);
1138 assert(elems > 0);
1139
1140 if (fread(buf, NUMPY_HEADER_SZ, 1, infile) != 1)
1141 {
1142 rc = HEADER_PARSE_ERROR;
1143 goto done;
1144 }
1145
1146 if (memcmp(buf, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1))
1147 {
1148 rc = HEADER_PARSE_ERROR;
1149 goto done;
1150 }
1151
1152 ptr = strtok_r(buf + sizeof(NUMPY_HEADER_STR) - 1, ":", &outer_end);
1153
1154 // Read in the data type, order, and shape
1155 while (ptr && (!foundFormat || !foundOrder || !foundShape))
1156 {
1157
1158 // End of string?
1159 if (!ptr)
1160 break;
1161
1162 // Skip whitespace
1163 while (isspace(*ptr))
1164 ptr++;
1165
1166 // Parse the dictionary field name
1167 if (!strcmp(ptr, "'descr'"))
1168 {
1169 ptr = strtok_r(NULL, ",", &outer_end);
1170 if (!ptr)
1171 break;
1172
1173 while (isspace(*ptr))
1174 ptr++;
1175
1176 if (strcmp(ptr, dtype_str))
1177 {
1178 rc = FILE_TYPE_MISMATCH;
1179 goto done;
1180 }
1181
1182 foundFormat = true;
1183 }
1184 else if (!strcmp(ptr, "'fortran_order'"))
1185 {
1186 ptr = strtok_r(NULL, ",", &outer_end);
1187 if (!ptr)
1188 break;
1189
1190 while (isspace(*ptr))
1191 ptr++;
1192
1193 if (!strcmp(ptr, "False"))
1194 {
1195 fortranOrder = false;
1196 }
1197 else
1198 {
1199 rc = FILE_TYPE_MISMATCH;
1200 goto done;
1201 }
1202
1203 foundOrder = true;
1204 }
1205 else if (!strcmp(ptr, "'shape'"))
1206 {
1207
1208 ptr = strtok_r(NULL, "(", &outer_end);
1209 if (!ptr)
1210 break;
1211 ptr = strtok_r(NULL, ")", &outer_end);
1212 if (!ptr)
1213 break;
1214
1215 while (isspace(*ptr))
1216 ptr++;
1217
1218 // The shape contains N comma-separated integers. Read up to 4.
1219 char* end = NULL;
1220
1221 ptr = strtok_r(ptr, ",", &end);
1222 for (int i = 0; i < 4; i++)
1223 {
1224 // Out of dimensions
1225 if (!ptr)
1226 break;
1227
Kevin Cheng40253152020-12-30 10:12:35 -08001228 int dim = atoi(ptr);
1229
1230 // Dimension is 0
1231 if (dim == 0)
1232 break;
1233
1234 shape.push_back(dim);
1235 totalElems *= dim;
Eric Kunzee5e26762020-10-13 16:11:07 -07001236 ptr = strtok_r(NULL, ",", &end);
1237 }
1238
1239 foundShape = true;
1240 }
1241 else
1242 {
1243 rc = HEADER_PARSE_ERROR;
1244 goto done;
1245 }
1246
1247 if (!ptr)
1248 break;
1249
1250 ptr = strtok_r(NULL, ":", &outer_end);
1251 }
1252
1253 if (!foundShape || !foundFormat || !foundOrder)
1254 {
1255 rc = HEADER_PARSE_ERROR;
1256 goto done;
1257 }
1258
1259 // Validate header
1260 if (fortranOrder != false)
1261 {
1262 rc = FILE_TYPE_MISMATCH;
1263 goto done;
1264 }
1265
1266 if (totalElems != elems)
1267 {
1268 rc = BUFFER_SIZE_MISMATCH;
1269 goto done;
1270 }
1271
1272 // Go back to the begininng and read until the end of the header dictionary
1273 rewind(infile);
1274 int val;
1275
1276 do
1277 {
1278 val = fgetc(infile);
1279 } while (val != EOF && val != '\n');
1280
1281done:
1282
1283 return rc;
1284}
1285
1286NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const bool* databuf)
1287{
1288 std::vector<int32_t> shape = { (int32_t)elems };
1289 return writeToNpyFile(filename, shape, databuf);
1290}
1291
1292NumpyUtilities::NPError
1293 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* databuf)
1294{
1295 const char dtype_str[] = "'|b1'";
1296 FILE* outfile = nullptr;
1297 NPError rc = NO_ERROR;
1298 uint32_t totalElems = 1;
1299
1300 assert(filename);
1301 assert(shape.size() >= 0);
1302 assert(databuf);
1303
1304 outfile = fopen(filename, "wb");
1305
1306 if (!outfile)
1307 {
1308 rc = FILE_NOT_FOUND;
1309 goto done;
1310 }
1311
1312 for (uint32_t i = 0; i < shape.size(); i++)
1313 {
1314 totalElems *= shape[i];
1315 }
1316
1317 rc = writeNpyHeader(outfile, shape, dtype_str);
1318
1319 // Numpy save format stores booleans as a byte array
1320 // with one byte per boolean. This somewhat inefficiently
1321 // remaps from system bool[] to this format.
1322 for (uint32_t i = 0; i < totalElems; i++)
1323 {
1324 int val = databuf[i] ? 1 : 0;
1325 if (fputc(val, outfile) == EOF)
1326 {
1327 rc = FILE_IO_ERROR;
1328 goto done;
1329 }
1330 }
1331
1332done:
1333
1334 if (outfile)
1335 fclose(outfile);
1336
1337 return rc;
1338}
1339
1340NumpyUtilities::NPError
1341 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* databuf)
1342{
1343 std::vector<int32_t> shape = { (int32_t)elems };
1344 return writeToNpyFile(filename, shape, databuf);
1345}
1346
1347NumpyUtilities::NPError
1348 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* databuf)
1349{
1350 const char dtype_str[] = "'<i4'";
1351 FILE* outfile = nullptr;
1352 NPError rc = NO_ERROR;
1353 uint32_t totalElems = 1;
1354
1355 assert(filename);
1356 assert(shape.size() >= 0);
1357 assert(databuf);
1358
1359 outfile = fopen(filename, "wb");
1360
1361 if (!outfile)
1362 {
1363 rc = FILE_NOT_FOUND;
1364 goto done;
1365 }
1366
1367 for (uint32_t i = 0; i < shape.size(); i++)
1368 {
1369 totalElems *= shape[i];
1370 }
1371
1372 rc = writeNpyHeader(outfile, shape, dtype_str);
1373
1374 if (fwrite(databuf, sizeof(int32_t), totalElems, outfile) != totalElems)
1375 {
1376 rc = FILE_IO_ERROR;
1377 goto done;
1378 }
1379
1380done:
1381
1382 if (outfile)
1383 fclose(outfile);
1384
1385 return rc;
1386}
1387
1388NumpyUtilities::NPError
1389 NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* databuf)
1390{
1391 std::vector<int32_t> shape = { (int32_t)elems };
1392 return writeToNpyFile(filename, shape, databuf);
1393}
1394
1395NumpyUtilities::NPError
1396 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* databuf)
1397{
1398 const char dtype_str[] = "'<i8'";
1399 FILE* outfile = nullptr;
1400 NPError rc = NO_ERROR;
1401 uint32_t totalElems = 1;
1402
1403 assert(filename);
1404 assert(shape.size() >= 0);
1405 assert(databuf);
1406
1407 outfile = fopen(filename, "wb");
1408
1409 if (!outfile)
1410 {
1411 rc = FILE_NOT_FOUND;
1412 goto done;
1413 }
1414
1415 for (uint32_t i = 0; i < shape.size(); i++)
1416 {
1417 totalElems *= shape[i];
1418 }
1419
1420 rc = writeNpyHeader(outfile, shape, dtype_str);
1421
1422 if (fwrite(databuf, sizeof(int64_t), totalElems, outfile) != totalElems)
1423 {
1424 rc = FILE_IO_ERROR;
1425 goto done;
1426 }
1427
1428done:
1429
1430 if (outfile)
1431 fclose(outfile);
1432
1433 return rc;
1434}
1435
1436NumpyUtilities::NPError NumpyUtilities::writeToNpyFile(const char* filename, const uint32_t elems, const float* databuf)
1437{
1438 std::vector<int32_t> shape = { (int32_t)elems };
1439 return writeToNpyFile(filename, shape, databuf);
1440}
1441
1442NumpyUtilities::NPError
1443 NumpyUtilities::writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* databuf)
1444{
1445 const char dtype_str[] = "'<f4'";
1446 FILE* outfile = nullptr;
1447 NPError rc = NO_ERROR;
1448 uint32_t totalElems = 1;
1449
1450 assert(filename);
1451 assert(shape.size() >= 0);
1452 assert(databuf);
1453
1454 outfile = fopen(filename, "wb");
1455
1456 if (!outfile)
1457 {
1458 rc = FILE_NOT_FOUND;
1459 goto done;
1460 }
1461
1462 for (uint32_t i = 0; i < shape.size(); i++)
1463 {
1464 totalElems *= shape[i];
1465 }
1466
1467 rc = writeNpyHeader(outfile, shape, dtype_str);
1468
1469 if (fwrite(databuf, sizeof(float), totalElems, outfile) != totalElems)
1470 {
1471 rc = FILE_IO_ERROR;
1472 goto done;
1473 }
1474
1475done:
1476
1477 if (outfile)
1478 fclose(outfile);
1479
1480 return rc;
1481}
1482
1483NumpyUtilities::NPError
1484 NumpyUtilities::writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str)
1485{
1486 NPError rc = NO_ERROR;
1487 uint32_t i;
1488 char header[NUMPY_HEADER_SZ + 1];
1489 int headerPos = 0;
1490
1491 assert(outfile);
1492 assert(shape.size() >= 0);
1493
1494 // Space-fill the header and end with a newline to start per numpy spec
1495 memset(header, 0x20, NUMPY_HEADER_SZ);
1496 header[NUMPY_HEADER_SZ - 1] = '\n';
1497 header[NUMPY_HEADER_SZ] = 0;
1498
1499 // Write out the hard-coded header. We only support a 128-byte 1.0 header
1500 // for now, which should be sufficient for simple tensor types of any
1501 // reasonable rank.
1502 memcpy(header, NUMPY_HEADER_STR, sizeof(NUMPY_HEADER_STR) - 1);
1503 headerPos += sizeof(NUMPY_HEADER_STR) - 1;
1504
1505 // Output the format dictionary
1506 // Hard-coded for I32 for now
1507 headerPos +=
1508 snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "'descr': %s, 'fortran_order': False, 'shape': (%d,",
1509 dtype_str, shape.size() > 0 ? shape[0] : 1);
1510
1511 // Remainder of shape array
1512 for (i = 1; i < shape.size(); i++)
1513 {
1514 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, " %d,", shape[i]);
1515 }
1516
1517 // Close off the dictionary
1518 headerPos += snprintf(header + headerPos, NUMPY_HEADER_SZ - headerPos, "), }");
1519
1520 // snprintf leaves a NULL at the end. Replace with a space
1521 header[headerPos] = 0x20;
1522
1523 if (fwrite(header, NUMPY_HEADER_SZ, 1, outfile) != 1)
1524 {
1525 rc = FILE_IO_ERROR;
1526 goto done;
1527 }
1528
1529done:
1530
1531 return rc;
1532}