blob: 4d693967cfef9583084e5dcbcf5767593738ae2b [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tosa_serialization_handler.h"

#include <cstring>
#include <iostream>
#include <utility>
using namespace tosa;
20
21TosaSerializationTensor::TosaSerializationTensor(const flatbuffers::String* name,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070022 const flatbuffers::Vector<int32_t>* shape,
Eric Kunze2364dcd2021-04-26 11:06:57 -070023 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070024 const flatbuffers::Vector<uint8_t>* data)
Eric Kunze2364dcd2021-04-26 11:06:57 -070025{
26 _dtype = dtype;
27
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070028 std::copy(shape->begin(), shape->end(), std::back_inserter(_shape));
Eric Kunze2364dcd2021-04-26 11:06:57 -070029
30 assert(name);
31 _name = name->str();
32
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070033 if (data)
Eric Kunze2364dcd2021-04-26 11:06:57 -070034 {
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070035 std::copy(data->begin(), data->end(), std::back_inserter(_data));
Eric Kunze2364dcd2021-04-26 11:06:57 -070036 }
37}
38
39TosaSerializationTensor::TosaSerializationTensor(std::string& name,
40 const std::vector<int32_t>& shape,
41 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070042 const std::vector<uint8_t>& data)
Eric Kunze2364dcd2021-04-26 11:06:57 -070043{
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070044 _dtype = dtype;
45 _shape = shape;
46 _name = name;
47 _data = data;
Eric Kunze2364dcd2021-04-26 11:06:57 -070048}
49
50TosaSerializationTensor::TosaSerializationTensor()
51{
52 _dtype = DType_UNKNOWN;
53
54 _name = "UNKNOWN";
55}
56
57TosaSerializationTensor::~TosaSerializationTensor()
58{}
59
60TosaSerializationOperator::TosaSerializationOperator(Op op,
61 Attribute attribute_type,
62 const TosaAttributeBase* attribute,
63 QuantInfo qinfo_type,
64 const TosaQuantInfoBase* qinfo,
65 std::vector<std::string> input_tensor_names,
66 std::vector<std::string> output_tensor_names)
67{
68 _op = op;
69 _attribute_type = attribute_type;
70
71 switch (attribute_type)
72 {
73 case Attribute_NONE:
74 _attribute = new TosaNoneAttribute();
75 break;
76#define DEF_ATTRIBUTE(NAME, ...) \
77 case Attribute_##NAME##Attribute: \
78 _attribute = new Tosa##NAME##Attribute(attribute); \
79 break;
80#include "attribute.def"
81#undef DEF_ATTRIBUTE
82 default:
83 printf("TosaSerializationOperator::TosaSerializationOperator(): Attribute %s not implemented yet\n",
84 EnumNamesAttribute()[attribute_type]);
85 assert(0);
86 }
87
88 _qinfo_type = qinfo_type;
89 switch (qinfo_type)
90 {
91 case QuantInfo_NONE:
92 _qinfo = new TosaNoneQuantInfo();
93 break;
94#define DEF_QUANTIZATION_INFO(NAME, ...) \
95 case QuantInfo_##NAME##QuantInfo: \
96 _qinfo = new Tosa##NAME##QuantInfo(qinfo); \
97 break;
98#include "quant_info.def"
99#undef DEF_QUANTIZATION_INFO
100 default:
101 printf("TosaSerializationOperator::TosaSerializationOperator(): QuantInfo %s not implemented yet\n",
102 EnumNamesQuantInfo()[qinfo_type]);
103 assert(0);
104 }
105
106 assert(_attribute && _qinfo);
107
108 _input_tensor_names = input_tensor_names;
109 _output_tensor_names = output_tensor_names;
110}
111
112TosaSerializationOperator::~TosaSerializationOperator()
113{
114 delete _attribute;
115 delete _qinfo;
116 // TosaSerializationTensor should be free'd in TosaSerializationSerializationHandler destructor
117}
118
119TosaSerializationBasicBlock::TosaSerializationBasicBlock(std::string name,
120 std::vector<TosaSerializationOperator*> operators,
121 std::vector<TosaSerializationTensor*> tensors,
122 std::vector<std::string> inputs,
123 std::vector<std::string> outputs)
124{
125
126 _name = name;
127 _operators = operators;
128 _tensors = tensors;
129 _inputs = inputs;
130 _outputs = outputs;
131}
132
133TosaSerializationBasicBlock::~TosaSerializationBasicBlock()
134{
135 // deallocate all operators
136 for (auto op : GetOperators())
137 {
138 delete op; // ~TosaSerializationOperator()
139 }
140
141 // deallocate all tensors
142 for (auto ts : GetTensors())
143 {
144 delete ts; // ~TosaSerializationTensor()
145 }
146}
147
148TosaSerializationHandler::TosaSerializationHandler()
149{
150 _schemaLoaded = false;
151
152 SetTosaVersion();
153}
154
155TosaSerializationHandler::~TosaSerializationHandler()
156{
157 Clear(); // deallocate all basic blocks
158}
159
160tosa_err_t TosaSerializationHandler::SetTosaVersion()
161{
162 // version is specified within .fbs
163 // and it's encoded as defaulted value of CreateTosaVersion()
164 // need to write out one object to read that value out
165 // TODO: very costly now. is there any better way to encode constant in .fbs?
166 auto fboffset_version = CreateVersion(_builder);
167 auto fboffset_tosa_graph = CreateTosaGraphDirect(_builder, fboffset_version, nullptr);
168 _builder.Finish(fboffset_tosa_graph);
169 std::string jsongen;
170 uint8_t* buf = _builder.GetBufferPointer();
171 auto fb_tosa_graph = GetTosaGraph(buf);
172 auto fb_tosa_version = fb_tosa_graph->version();
173
174 _version.set_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
175 fb_tosa_version->_experimental());
176
177 return TOSA_OK;
178}
179
180tosa_err_t TosaSerializationHandler::LoadFileSchema(const char* schema_filename)
181{
182 std::string schema;
183 bool ok;
184
185 ok = flatbuffers::LoadFile(schema_filename, false, &schema);
186 if (!ok)
187 {
188 printf("Error loading schema file: %s\n", schema_filename);
189 return TOSA_FILE_ERROR;
190 }
191
192 ok = _parser.Parse(schema.c_str());
193 if (!ok)
194 {
195 printf("Error parsing ISA schema file: %s\n", schema_filename);
196 return TOSA_FILE_ERROR;
197 }
198 _schemaLoaded = true;
199
200 return TOSA_OK;
201}
202
203tosa_err_t TosaSerializationHandler::LoadFileJson(const char* filename)
204{
205 std::string jsonfile;
206 bool ok;
207 tosa_err_t err;
208
209 if (!_schemaLoaded)
210 {
211 return TOSA_SCHEMA_MISSING;
212 }
213
214 ok = flatbuffers::LoadFile(filename, false, &jsonfile);
215 if (!ok)
216 {
217 printf("Error loading json file: %s\n", filename);
218 return TOSA_FILE_ERROR;
219 }
220
221 ok = _parser.Parse(jsonfile.c_str());
222 if (!ok)
223 {
224 printf("Error parsing json file: %s\n", filename);
225 return TOSA_FILE_ERROR;
226 }
227
228 uint8_t* buf = _parser.builder_.GetBufferPointer();
229
230 err = InitWithBuf(buf);
231 if (err != TOSA_OK)
232 {
233 return err;
234 }
235
236 return TOSA_OK;
237}
238
239tosa_err_t TosaSerializationHandler::SaveFileJson(const char* filename)
240{
241 std::string jsongen;
242 tosa_err_t err;
243
244 if (!_schemaLoaded)
245 {
246 return TOSA_SCHEMA_MISSING;
247 }
248
249 err = FreezeBuilder();
250 if (err != TOSA_OK)
251 {
252 return err;
253 }
254
255 uint8_t* buf = _builder.GetBufferPointer();
256
257 if (!GenerateText(_parser, buf, &jsongen))
258 {
259 printf("Couldn't serialize parsed data to JSON!\n");
260 return TOSA_FILE_ERROR;
261 }
262
263 FILE* file = fopen(filename, "wb");
264
265 if (!file)
266 {
267 printf("Couldn't open output file: %s\n", filename);
268 return TOSA_FILE_ERROR;
269 }
270
271 if (fwrite(jsongen.c_str(), sizeof(char), jsongen.size(), file) != jsongen.size())
272 {
273 printf("Error writing to json output file: %s\n", filename);
274 fclose(file);
275 return TOSA_FILE_ERROR;
276 }
277
278 if (file)
279 fclose(file);
280
281 return TOSA_OK;
282}
283
284tosa_err_t TosaSerializationHandler::LoadFileTosaFlatbuffer(const char* filename)
285{
286 std::string read_buffer;
287 tosa_err_t err;
288 uint8_t* buf;
289 bool ok;
290
291 ok = flatbuffers::LoadFile(filename, false, &read_buffer);
292 if (!ok)
293 {
294 printf("Error loading flatbuffer file: %s\n", filename);
295 return TOSA_FILE_ERROR;
296 }
297
298 buf = (uint8_t*)read_buffer.data();
299
300 err = InitWithBuf(buf);
301 if (err != TOSA_OK)
302 {
303 return err;
304 }
305
306 return TOSA_OK;
307}
308
309tosa_err_t TosaSerializationHandler::SaveFileTosaFlatbuffer(const char* filename)
310{
311 tosa_err_t err;
312
313 err = FreezeBuilder();
314 if (err != TOSA_OK)
315 {
316 return err;
317 }
318
319 uint8_t* buf = _builder.GetBufferPointer();
320
321 bool ok = flatbuffers::SaveFile(filename, (const char*)buf, _builder.GetSize(), false);
322 if (!ok)
323 {
324 printf("Error saving floatbuffer file: %s\n", filename);
325 return TOSA_FILE_ERROR;
326 }
327
328 return TOSA_OK;
329}
330
331tosa_err_t TosaSerializationHandler::Clear()
332{
333 // deallocate all basic blocks
334 for (auto bb : GetBlocks())
335 {
336 delete bb;
337 }
338 _blocks.clear();
339
340 return TOSA_OK;
341}
342
343tosa_err_t TosaSerializationHandler::CheckTosaVersion(const TosaVersion& read_version)
344{
345 if (_version != read_version)
346 {
347 printf("WARNING: read tosa version: %s != schema tosa version %s\n", read_version.to_string().c_str(),
348 _version.to_string().c_str());
349 return TOSA_VERSION_MISMATCH;
350 }
351
352 return TOSA_OK;
353}
354
355tosa_err_t TosaSerializationHandler::InitWithBuf(const uint8_t* buf)
356{
357 auto fb_tosa_graph = GetTosaGraph(buf);
358 auto fb_tosa_version = fb_tosa_graph->version();
359 auto fb_tosa_blocks = fb_tosa_graph->blocks();
360
361 std::vector<std::string> operator_inputs_container;
362 std::vector<std::string> operator_outputs_container;
363
364 std::vector<TosaSerializationOperator*> block_operators_container;
365 std::vector<TosaSerializationTensor*> block_tensors_container;
366 std::vector<std::string> block_inputs_container;
367 std::vector<std::string> block_outputs_container;
368
369 TosaAttributeBase* typed_attribute = NULL;
370 TosaQuantInfoBase* typed_qinfo = NULL;
371 TosaSerializationOperator* new_operator = NULL;
372 TosaSerializationBasicBlock* new_block = NULL;
373 TosaSerializationTensor* new_tensor = NULL;
374
375 // erase container
376 Clear();
377
378 TosaVersion read_version(fb_tosa_version->_major(), fb_tosa_version->_minor(), fb_tosa_version->_patch(),
379 fb_tosa_version->_experimental());
380 tosa_err_t err = CheckTosaVersion(read_version);
381
382 if (err != TOSA_OK)
383 return err;
384
385 for (size_t i = 0; i < fb_tosa_blocks->size(); i++)
386 {
387 auto curr_block = fb_tosa_blocks->Get(i);
388
389 auto block_name = curr_block->name()->str();
390
391 auto fb_tosa_operators = curr_block->operators();
392 block_operators_container.clear();
393 for (size_t j = 0; j < fb_tosa_operators->size(); j++)
394 {
395 auto curr_operator = fb_tosa_operators->Get(j);
396
397 auto operator_op = curr_operator->op();
398 auto attribute_type = curr_operator->attribute_type();
399 auto attribute = curr_operator->attribute();
400 auto operator_qinfo_type = curr_operator->quant_info_type();
401 auto operator_qinfo = curr_operator->quant_info();
402
403 // input tensors
404 auto operator_inputs = curr_operator->inputs();
405 operator_inputs_container.clear();
406 if (operator_inputs)
407 {
408 for (size_t k = 0; k < operator_inputs->size(); k++)
409 {
410 auto curr_input = operator_inputs->Get(k);
411 operator_inputs_container.push_back(curr_input->str());
412 }
413 }
414
415 // output tensors
416 auto operator_outputs = curr_operator->outputs();
417 operator_outputs_container.clear();
418 if (operator_outputs)
419 {
420 for (size_t k = 0; k < operator_outputs->size(); k++)
421 {
422 auto curr_output = operator_outputs->Get(k);
423 operator_outputs_container.push_back(curr_output->str());
424 }
425 }
426
427 switch (attribute_type)
428 {
429 case Attribute_NONE:
430 typed_attribute = new TosaNoneAttribute();
431 break;
432#define DEF_ATTRIBUTE(NAME, ...) \
433 case Attribute_##NAME##Attribute: \
434 typed_attribute = new Tosa##NAME##Attribute(attribute); \
435 break;
436#include "attribute.def"
437#undef DEF_ATTRIBUTE
438 default:
439 printf("TosaSerializationHandler::InitWithBuf(): Attribute %s not implemented yet\n",
440 EnumNamesAttribute()[attribute_type]);
441 return TOSA_INTERNAL_ERROR;
442 }
443
444 switch (operator_qinfo_type)
445 {
446 case QuantInfo_NONE:
447 typed_qinfo = new TosaNoneQuantInfo();
448 break;
449#define DEF_QUANTIZATION_INFO(NAME, ...) \
450 case QuantInfo_##NAME##QuantInfo: \
451 typed_qinfo = new Tosa##NAME##QuantInfo(operator_qinfo); \
452 break;
453
454#include "quant_info.def"
455#undef DEF_QUANTIZATION_INFO
456 default:
457 printf("TosaSerializationHandler::InitWithBuf(): QuantInfo %s not implemented yet\n",
458 EnumNamesQuantInfo()[operator_qinfo_type]);
459 return TOSA_INTERNAL_ERROR;
460 }
461
462 new_operator =
463 new TosaSerializationOperator(operator_op, attribute_type, typed_attribute, operator_qinfo_type,
464 typed_qinfo, operator_inputs_container, operator_outputs_container);
465 if (new_operator)
466 {
467 block_operators_container.push_back(new_operator);
468 }
469 else
470 {
471 return TOSA_MEMORY_ERROR;
472 }
473
474 if (typed_attribute)
475 delete typed_attribute;
476 if (typed_qinfo)
477 delete typed_qinfo;
478 }
479
480 auto fb_tosa_tensors = curr_block->tensors();
481 block_tensors_container.clear();
482 for (size_t j = 0; j < fb_tosa_tensors->size(); j++)
483 {
484 auto curr_tensor = fb_tosa_tensors->Get(j);
485
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700486 auto tensor_name = curr_tensor->name();
487 auto tensor_shape = curr_tensor->shape();
488 auto tensor_type = curr_tensor->type();
489 auto tensor_data = curr_tensor->data();
Eric Kunze2364dcd2021-04-26 11:06:57 -0700490
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700491 new_tensor = new TosaSerializationTensor(tensor_name, tensor_shape, tensor_type, tensor_data);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700492 if (new_tensor)
493 {
494 block_tensors_container.push_back(new_tensor);
495 }
496 else
497 {
498 return TOSA_MEMORY_ERROR;
499 }
500 }
501
502 auto block_inputs = curr_block->inputs();
503 auto block_outputs = curr_block->outputs();
504
505 block_inputs_container.clear();
506 block_outputs_container.clear();
507
508 for (size_t j = 0; j < block_inputs->size(); j++)
509 {
510 auto curr_block_input = block_inputs->Get(j);
511 block_inputs_container.push_back(curr_block_input->str());
512 }
513 for (size_t j = 0; j < block_outputs->size(); j++)
514 {
515 auto curr_block_output = block_outputs->Get(j);
516 block_outputs_container.push_back(curr_block_output->str());
517 }
518
519 new_block = new TosaSerializationBasicBlock(block_name, block_operators_container, block_tensors_container,
520 block_inputs_container, block_outputs_container);
521 if (new_block)
522 {
523 this->GetBlocks().push_back(new_block);
524 }
525 else
526 {
527 return TOSA_MEMORY_ERROR;
528 }
529 }
530
531 return TOSA_OK;
532}
533
534tosa_err_t TosaSerializationHandler::FreezeBuilder()
535{
536 std::vector<flatbuffers::Offset<TosaBasicBlock>> fboffset_blocks;
537
538 std::vector<flatbuffers::Offset<TosaOperator>> fboffset_block_operators;
539 std::vector<flatbuffers::Offset<TosaTensor>> fboffset_block_tensors;
540 std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_inputs;
541 std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_block_outputs;
542
543 std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_inputs;
544 std::vector<flatbuffers::Offset<flatbuffers::String>> fboffset_operator_outputs;
545
546 // translate TosaFlatbufferOperator to flatbuffers::Offset<TosaOperator>
547 for (auto block : GetBlocks())
548 {
549 fboffset_block_operators.clear();
550 fboffset_block_tensors.clear();
551 fboffset_block_inputs.clear();
552 fboffset_block_outputs.clear();
553
554 auto block_name = _builder.CreateString(block->GetName().c_str());
555
556 for (auto tensor_str : block->GetInputs())
557 {
558 auto tensor_name = _builder.CreateString(tensor_str.c_str());
559 fboffset_block_inputs.push_back(tensor_name);
560 }
561
562 for (auto tensor_str : block->GetOutputs())
563 {
564 auto tensor_name = _builder.CreateString(tensor_str.c_str());
565 fboffset_block_outputs.push_back(tensor_name);
566 }
567
568 auto fb_block_inputs = _builder.CreateVector(fboffset_block_inputs);
569 auto fb_block_outputs = _builder.CreateVector(fboffset_block_outputs);
570
571 for (auto op : block->GetOperators())
572 {
573 fboffset_operator_inputs.clear();
574 fboffset_operator_outputs.clear();
575
576 auto operator_op = op->GetOp();
577 auto attribute_type = op->GetAttributeType();
578
579 for (auto tensor_str : op->GetInputTensorNames())
580 {
581 auto tensor_name = _builder.CreateString(tensor_str.c_str());
582 fboffset_operator_inputs.push_back(tensor_name);
583 }
584
585 for (auto tensor_str : op->GetOutputTensorNames())
586 {
587 auto tensor_name = _builder.CreateString(tensor_str.c_str());
588 fboffset_operator_outputs.push_back(tensor_name);
589 }
590
591 auto fb_operator_inputs = _builder.CreateVector(fboffset_operator_inputs);
592 auto fb_operator_outputs = _builder.CreateVector(fboffset_operator_outputs);
593
594 flatbuffers::Offset<void> fb_attribute;
595 switch (attribute_type)
596 {
597 case Attribute_NONE:
598 fb_attribute = 0;
599 break;
600
601#define DEF_ARGS_S_STR(NAME, V) , _builder.CreateString(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V().c_str())
602#define DEF_ARGS_S_DEFAULT(NAME, V) , reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V()
603
604#define DEF_ARGS_S_int32_t(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
605#define DEF_ARGS_S_float(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
606#define DEF_ARGS_S_bool(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
607#define DEF_ARGS_S_ResizeMode(NAME, V) DEF_ARGS_S_DEFAULT(NAME, V)
608#define DEF_ARGS_S_string(NAME, V) DEF_ARGS_S_STR(NAME, V)
609
610#define DEF_ARGS_S(NAME, T, V) DEF_ARGS_S_##T(NAME, V)
611#define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetAttribute())->V())
612
613#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
614#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
615#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
616 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2)
617#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
618 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3)
619#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
620 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
621 DEF_ARGS_##F4(NAME, T4, V4)
622#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
623 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
624 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5)
625#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
626 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
627 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
628#define DEF_ATTRIBUTE(NAME, NUM_ARGS, ...) \
629 case Attribute_##NAME##Attribute: \
630 fb_attribute = Create##NAME##Attribute(_builder DEF_ARGS_##NUM_ARGS(NAME##Attribute, __VA_ARGS__)).Union(); \
631 break;
632
633#include "attribute.def"
634#undef DEF_ATTRIBUTE
635#undef DEF_ARGS_1
636#undef DEF_ARGS_2
637#undef DEF_ARGS_3
638#undef DEF_ARGS_4
639#undef DEF_ARGS_5
640#undef DEF_ARGS_6
641#undef DEF_ARGS_7
642#undef DEF_ARGS_S
643#undef DEF_ARGS_V
644#undef DEF_ARGS_S_int32_t
645#undef DEF_ARGS_S_float
646#undef DEF_ARGS_S_bool
647#undef DEF_ARGS_S_ResizeMode
648#undef DEF_ARGS_S_string
649#undef DEF_ARGS_S_STR
650#undef DEF_ARGS_S_DEFAULT
651 default:
652 printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
653 EnumNamesAttribute()[attribute_type]);
654 return TOSA_INTERNAL_ERROR;
655 }
656
657 auto qinfo_type = op->GetQInfoType();
658 flatbuffers::Offset<void> fb_operator_qinfo;
659 switch (qinfo_type)
660 {
661 case QuantInfo_NONE:
662 fb_operator_qinfo = 0;
663 break;
664#define DEF_ARGS_S(NAME, T, V) , reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V()
665#define DEF_ARGS_V(NAME, T, V) , _builder.CreateVector<T>(reinterpret_cast<Tosa##NAME*>(op->GetQInfo())->V())
666
667#define DEF_ARGS_1(NAME, T0, F0, V0) DEF_ARGS_##F0(NAME, T0, V0)
668#define DEF_ARGS_2(NAME, T0, F0, V0, T1, F1, V1) DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1)
669#define DEF_ARGS_3(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2) \
670 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2)
671#define DEF_ARGS_4(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3) \
672 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3)
673#define DEF_ARGS_5(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4) \
674 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
675 DEF_ARGS_##F4(NAME, T4, V4)
676#define DEF_ARGS_6(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5) \
677 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
678 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5)
679#define DEF_ARGS_7(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6) \
680 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
681 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6)
682#define DEF_ARGS_8(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
683 V7) \
684 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
685 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
686 DEF_ARGS_##F7(NAME, T7, V7)
687#define DEF_ARGS_9(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
688 V7, T8, F8, V8) \
689 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
690 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
691 DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8)
692#define DEF_ARGS_10(NAME, T0, F0, V0, T1, F1, V1, T2, F2, V2, T3, F3, V3, T4, F4, V4, T5, F5, V5, T6, F6, V6, T7, F7, \
693 V7, T8, F8, V8, T9, F9, V9) \
694 DEF_ARGS_##F0(NAME, T0, V0) DEF_ARGS_##F1(NAME, T1, V1) DEF_ARGS_##F2(NAME, T2, V2) DEF_ARGS_##F3(NAME, T3, V3) \
695 DEF_ARGS_##F4(NAME, T4, V4) DEF_ARGS_##F5(NAME, T5, V5) DEF_ARGS_##F6(NAME, T6, V6) \
696 DEF_ARGS_##F7(NAME, T7, V7) DEF_ARGS_##F8(NAME, T8, V8) DEF_ARGS_##F9(NAME, T9, V9)
697#define DEF_QUANTIZATION_INFO(NAME, NUM_ARGS, ...) \
698 case QuantInfo_##NAME##QuantInfo: \
699 fb_operator_qinfo = \
700 Create##NAME##QuantInfo(_builder DEF_ARGS_##NUM_ARGS(NAME##QuantInfo, __VA_ARGS__)).Union(); \
701 break;
702
703#include "quant_info.def"
704#undef DEF_QUANTIZATION_INFO
705#undef DEF_ARGS_1
706#undef DEF_ARGS_2
707#undef DEF_ARGS_3
708#undef DEF_ARGS_4
709#undef DEF_ARGS_5
710#undef DEF_ARGS_6
711#undef DEF_ARGS_7
712#undef DEF_ARGS_8
713#undef DEF_ARGS_9
714#undef DEF_ARGS_10
715#undef DEF_ARGS_S
716#undef DEF_ARGS_V
717 default:
718 printf("TosaSerializationHandler::FreezeBuilder(): Attribute %s not implemented yet\n",
719 EnumNamesAttribute()[attribute_type]);
720 return TOSA_INTERNAL_ERROR;
721 }
722
723 auto fboffset_operator =
724 CreateTosaOperator(_builder, operator_op, attribute_type, fb_attribute, fb_operator_inputs,
725 fb_operator_outputs, qinfo_type, fb_operator_qinfo);
726 fboffset_block_operators.push_back(fboffset_operator);
727 }
728
729 auto fb_block_operators = _builder.CreateVector(fboffset_block_operators);
730
731 for (auto tensor : block->GetTensors())
732 {
733
734 auto tensor_name = _builder.CreateString(tensor->GetName().c_str());
735 auto tensor_shape = _builder.CreateVector(tensor->GetShape());
736 auto tensor_dtype = tensor->GetDtype();
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700737 auto tensor_data = _builder.CreateVector(tensor->GetData());
Eric Kunze2364dcd2021-04-26 11:06:57 -0700738
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700739 auto fboffset_tensor = CreateTosaTensor(_builder, tensor_name, tensor_shape, tensor_dtype, tensor_data);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700740 fboffset_block_tensors.push_back(fboffset_tensor);
741 }
742
743 auto fb_block_tensors = _builder.CreateVector(fboffset_block_tensors);
744
745 auto fboffset_block = CreateTosaBasicBlock(_builder, block_name, fb_block_operators, fb_block_tensors,
746 fb_block_inputs, fb_block_outputs);
747 fboffset_blocks.push_back(fboffset_block);
748 }
749
750 auto fb_blocks = _builder.CreateVector(fboffset_blocks);
751
752 auto fb_version = CreateVersion(_builder, GetTosaVersion()._major, GetTosaVersion()._minor, GetTosaVersion()._patch,
753 GetTosaVersion()._experimental);
754
755 auto fb_graph = CreateTosaGraph(_builder, fb_version, fb_blocks);
756 _builder.Finish(fb_graph);
757
758 return TOSA_OK;
759}
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700760
761void zero_pad(std::vector<uint8_t>& buf)
762{
763 while ((buf.size() % TENSOR_BUFFER_FORCE_ALIGNMENT) != 0)
764 {
765 buf.push_back(0);
766 }
767}
768
769tosa_err_t TosaSerializationHandler::ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out)
770{
771 out.clear();
772 for (auto val : in)
773 {
774 uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val);
775 out.push_back(*val_u32 & 0xFF);
776 out.push_back((*val_u32 >> 8) & 0xFF);
777 out.push_back((*val_u32 >> 16) & 0xFF);
778 out.push_back((*val_u32 >> 24) & 0xFF);
779 }
780 zero_pad(out);
781 return TOSA_OK;
782}
783
784tosa_err_t TosaSerializationHandler::ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out)
785{
786 out.clear();
787 for (auto val : in)
788 {
789 uint64_t* val_u64 = reinterpret_cast<uint64_t*>(&val);
790 out.push_back(*val_u64 & 0xFF);
791 out.push_back((*val_u64 >> 8) & 0xFF);
792 out.push_back((*val_u64 >> 16) & 0xFF);
793 out.push_back((*val_u64 >> 24) & 0xFF);
794 out.push_back((*val_u64 >> 32) & 0xFF);
795 out.push_back((*val_u64 >> 40) & 0xFF);
796 }
797 zero_pad(out);
798 return TOSA_OK;
799}
800
801tosa_err_t TosaSerializationHandler::ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out)
802{
803 out.clear();
804 for (auto val : in)
805 {
806 uint32_t* val_u32 = reinterpret_cast<uint32_t*>(&val);
807 out.push_back(*val_u32 & 0xFF);
808 out.push_back((*val_u32 >> 8) & 0xFF);
809 out.push_back((*val_u32 >> 16) & 0xFF);
810 out.push_back((*val_u32 >> 24) & 0xFF);
811 }
812 zero_pad(out);
813 return TOSA_OK;
814}
815
816tosa_err_t TosaSerializationHandler::ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out)
817{
818 out.clear();
819 for (auto val : in)
820 {
821 uint16_t* val_u16 = reinterpret_cast<uint16_t*>(&val);
822 out.push_back(*val_u16 & 0xFF);
823 out.push_back((*val_u16 >> 8) & 0xFF);
824 }
825 zero_pad(out);
826 return TOSA_OK;
827}
828
829tosa_err_t TosaSerializationHandler::ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out)
830{
831 out.clear();
832 for (auto val : in)
833 {
834 uint8_t* val_u8 = reinterpret_cast<uint8_t*>(&val);
835 out.push_back(*val_u8);
836 }
837 zero_pad(out);
838 return TOSA_OK;
839}
840
Kevin Cheng3ce56342021-07-28 13:42:29 -0700841// Two int4 values are packed into one byte out.
842// For given input value val_0 = in[2*i], and val_1 = in[2*i+1],
843// they'll be packed as out[3:0] = val_0, and out[7:4] = val_1
844tosa_err_t TosaSerializationHandler::ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out)
845{
846 out.clear();
847 uint32_t in_size = in.size();
848 uint32_t out_size = (in_size % 2 == 0) ? (in_size / 2) : ((in_size + 1) / 2);
849 for (int i = 0; i < out_size; i++)
850 {
851 int8_t val_0 = in[2 * i];
852 int8_t val_1 = 0;
853 if (2 * i + 1 < in_size)
854 {
855 val_1 = in[2 * i + 1];
856 }
857 // In TOSA spec, int4 ranges [-7, 7]
858 if (val_0 < -7 || val_0 > 7 || val_1 < -7 || val_1 > 7)
859 {
860 printf("TosaSerializationHandler::ConvertI4toU8(): element in input array (%d or %d) exceeds int4 range.\n",
861 val_0, val_1);
862 return TOSA_USER_ERROR;
863 }
864 int8_t val_packed = (val_0 & 0xF) | ((val_1 & 0xF) << 4);
865 uint8_t val_u8 = static_cast<uint8_t>(val_packed);
866 out.push_back(val_u8);
867 }
868 zero_pad(out);
869 return TOSA_OK;
870}
871
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700872tosa_err_t TosaSerializationHandler::ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out)
873{
874 out.clear();
875 for (auto val : in)
876 {
877 uint8_t* val_u8 = reinterpret_cast<uint8_t*>(&val);
878 out.push_back(*val_u8);
879 }
880 zero_pad(out);
881 return TOSA_OK;
882}
883
884tosa_err_t
885 TosaSerializationHandler::ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out)
886{
887 out.clear();
888 if (in.size() < out_size * sizeof(float))
889 {
890 printf("TosaSerializationHandler::ConvertU8toF32(): uint8 buffer size %ld must >= target size %ld\n", in.size(),
891 out_size * sizeof(float));
892 return TOSA_USER_ERROR;
893 }
894 for (int i = 0; i < out_size; i++)
895 {
896 uint32_t byte0 = in[i * sizeof(float)];
897 uint32_t byte1 = in[i * sizeof(float) + 1];
898 uint32_t byte2 = in[i * sizeof(float) + 2];
899 uint32_t byte3 = in[i * sizeof(float) + 3];
900 uint32_t val_u32 = byte0 + (byte1 << 8) + (byte2 << 16) + (byte3 << 24);
901 float* val_fp32 = reinterpret_cast<float*>(&val_u32);
902 out.push_back(*val_fp32);
903 }
904 return TOSA_OK;
905}
906
907tosa_err_t TosaSerializationHandler::ConvertU8toI48(const std::vector<uint8_t>& in,
908 uint32_t out_size,
909 std::vector<int64_t>& out)
910{
911 out.clear();
912 if (in.size() < out_size * 6 /* sizeof(int48) */)
913 {
914 printf("TosaSerializationHandler::ConvertU8toI48(): uint8 buffer size %ld must >= target size %d\n", in.size(),
915 out_size * 6);
916 return TOSA_USER_ERROR;
917 }
918 for (int i = 0; i < out_size; i++)
919 {
920 uint64_t byte0 = in[i * 6];
921 uint64_t byte1 = in[i * 6 + 1];
922 uint64_t byte2 = in[i * 6 + 2];
923 uint64_t byte3 = in[i * 6 + 3];
924 uint64_t byte4 = in[i * 6 + 4];
925 uint64_t byte5 = in[i * 6 + 5];
926 bool sign = ((byte5 >> 7) & 1) == 1 ? true : false;
927 uint64_t val_u64 = byte0 + (byte1 << 8) + (byte2 << 16) + (byte3 << 24) + (byte4 << 32) + (byte5 << 40);
928 if (sign)
929 {
930 uint64_t sext_mask = (0xFFFFUL << 48);
931 val_u64 |= sext_mask;
932 }
933 int64_t* val_i64 = reinterpret_cast<int64_t*>(&val_u64);
934 out.push_back(*val_i64);
935 }
936 return TOSA_OK;
937}
938
939tosa_err_t TosaSerializationHandler::ConvertU8toI32(const std::vector<uint8_t>& in,
940 uint32_t out_size,
941 std::vector<int32_t>& out)
942{
943 out.clear();
944 if (in.size() < out_size * sizeof(int32_t))
945 {
946 printf("TosaSerializationHandler::ConvertU8toI32(): uint8 buffer size %ld must >= target size %ld\n", in.size(),
947 out_size * sizeof(int32_t));
948 return TOSA_USER_ERROR;
949 }
950 for (int i = 0; i < out_size; i++)
951 {
952 uint32_t byte0 = in[i * sizeof(int32_t)];
953 uint32_t byte1 = in[i * sizeof(int32_t) + 1];
954 uint32_t byte2 = in[i * sizeof(int32_t) + 2];
955 uint32_t byte3 = in[i * sizeof(int32_t) + 3];
956 uint32_t val_u32 = byte0 + (byte1 << 8) + (byte2 << 16) + (byte3 << 24);
957 int32_t* val_i32 = reinterpret_cast<int32_t*>(&val_u32);
958 out.push_back(*val_i32);
959 }
960 return TOSA_OK;
961}
962
963tosa_err_t TosaSerializationHandler::ConvertU8toI16(const std::vector<uint8_t>& in,
964 uint32_t out_size,
965 std::vector<int16_t>& out)
966{
967 out.clear();
968 if (in.size() < out_size * sizeof(int16_t))
969 {
970 printf("TosaSerializationHandler::ConvertU8toI16(): uint8 buffer size %ld must >= target size %ld\n", in.size(),
971 out_size * sizeof(int16_t));
972 return TOSA_USER_ERROR;
973 }
974 for (int i = 0; i < out_size; i++)
975 {
976 uint16_t byte0 = in[i * sizeof(int16_t)];
977 uint16_t byte1 = in[i * sizeof(int16_t) + 1];
978 uint16_t val_u16 = byte0 + (byte1 << 8);
979 int16_t* val_i16 = reinterpret_cast<int16_t*>(&val_u16);
980 out.push_back(*val_i16);
981 }
982 return TOSA_OK;
983}
984
985tosa_err_t
986 TosaSerializationHandler::ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out)
987{
988 out.clear();
989 if (in.size() < out_size * sizeof(int8_t))
990 {
991 printf("TosaSerializationHandler::ConvertU8toI8(): uint8 buffer size %ld must >= target size %ld\n", in.size(),
Kevin Cheng3ce56342021-07-28 13:42:29 -0700992 out_size * sizeof(int8_t));
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700993 return TOSA_USER_ERROR;
994 }
995 for (int i = 0; i < out_size; i++)
996 {
997 uint8_t val_u8 = in[i];
998 int8_t* val_i8 = reinterpret_cast<int8_t*>(&val_u8);
999 out.push_back(*val_i8);
1000 }
1001 return TOSA_OK;
1002}
1003
1004tosa_err_t
Kevin Cheng3ce56342021-07-28 13:42:29 -07001005 TosaSerializationHandler::ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out)
1006{
1007 out.clear();
1008 if (out_size > in.size() * 2)
1009 {
1010 printf("TosaSerializationHandler::ConvertU8toI4(): output size %u must <= uint8 buffer size %ld x 2.\n",
1011 out_size, in.size());
1012 return TOSA_USER_ERROR;
1013 }
1014 for (int i = 0; i < in.size(); i++)
1015 {
1016 uint8_t val_u8 = in[i];
1017 uint8_t val_0_u4 = val_u8 & 0xF;
1018 uint8_t val_1_u4 = val_u8 >> 4;
1019 uint8_t val_0_u8_sext = (val_0_u4 & 0x08) ? (val_0_u4 | 0xF0) : val_0_u4;
1020 uint8_t val_1_u8_sext = (val_1_u4 & 0x08) ? (val_1_u4 | 0xF0) : val_1_u4;
1021 int8_t val_0 = static_cast<int8_t>(val_0_u8_sext);
1022 int8_t val_1 = static_cast<int8_t>(val_1_u8_sext);
1023 // In TOSA spec, int4 ranges [-7, 7]
1024 if (val_0 < -7 || val_0 > 7 || val_1 < -7 || val_1 > 7)
1025 {
1026 printf(
1027 "TosaSerializationHandler::ConvertU8toI4(): element in output array (%d or %d) exceeds int4 range.\n",
1028 val_0, val_1);
1029 return TOSA_USER_ERROR;
1030 }
1031 out.push_back(val_0);
1032 if (2 * i + 1 < out_size)
1033 out.push_back(val_1);
1034 }
1035 return TOSA_OK;
1036}
1037
1038tosa_err_t
Kevin Cheng3bb1bc12021-06-17 15:57:08 -07001039 TosaSerializationHandler::ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out)
1040{
1041 out.clear();
1042 if (in.size() < out_size * sizeof(bool))
1043 {
1044 printf("TosaSerializationHandler::ConvertU8toBool(): uint8 buffer size %ld must >= target size %ld\n",
1045 in.size(), out_size * sizeof(bool));
1046 return TOSA_USER_ERROR;
1047 }
1048 for (int i = 0; i < out_size; i++)
1049 {
1050 uint8_t val_u8 = in[i];
1051 bool* val_bool = reinterpret_cast<bool*>(&val_u8);
1052 out.push_back(*val_bool);
1053 }
1054 return TOSA_OK;
1055}