blob: 97dfdb07c0d5d9bdd9d2bbd450f3dd70c3afb92a [file] [log] [blame]
// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
16#ifndef _TOSA_SERIALIZATION_HANDLER_H
17#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
#include "tosa_generated.h"
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>
27
// Version of the serialization format implemented by this library.
// Keep these values in sync with the default version values in schema/tosa.fbs.
#define TOSA_VERSION_MAJOR 0
#define TOSA_VERSION_MINOR 100
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT true
// Alignment constant for tensor data buffers.
// NOTE(review): presumably a byte count consumed by ForceAlignTensorData() - confirm in the implementation.
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
34
Eric Kunze2364dcd2021-04-26 11:06:57 -070035namespace tosa
36{
37
/// Status codes returned by the serialization API.
/// TOSA_OK is the zero value; NUM_TOSA_ERROR is a trailing sentinel equal to
/// the number of real codes and must stay last.
enum tosa_err_t
{
    TOSA_OK,
    TOSA_USER_ERROR,
    TOSA_FILE_ERROR,
    TOSA_MEMORY_ERROR,
    TOSA_SCHEMA_MISSING,
    TOSA_INTERNAL_ERROR,
    TOSA_VERSION_MISMATCH,
    NUM_TOSA_ERROR
};
49
/**
 * A major.minor.patch version with an additional "draft" flag.
 *
 * Provides string formatting, a strict ordering (less_than) and a
 * compatibility check between the version stored in a flatbuffer and the
 * version implemented by the serializer.
 */
struct TosaVersion
{
    // Default member initializers make a default-constructed object a
    // well-defined non-draft 0.0.0 instead of holding indeterminate values.
    int32_t _major = 0;
    int32_t _minor = 0;
    int32_t _patch = 0;
    bool _draft    = false;

    /// Outcome of is_compatible().
    enum class compat_t
    {
        COMPLETELY_COMPATIBLE,    ///< major, minor, patch and draft all match
        BACKWARD_COMPATIBLE,      ///< file is >= 0.70 and strictly older than the serializer
        NOT_COMPATIBLE            ///< anything else
    };

    TosaVersion() = default;
    TosaVersion(int32_t major, int32_t minor, int32_t patch, bool draft)
    {
        set_version(major, minor, patch, draft);
    }

    /// Overwrite all four version fields.
    void set_version(int32_t major, int32_t minor, int32_t patch, bool draft)
    {
        _major = major;
        _minor = minor;
        _patch = patch;
        _draft = draft;
    }

    /// Format as "<major>.<minor>.<patch>", with a trailing "d" for draft versions.
    std::string to_string() const
    {
        std::string text = std::to_string(_major) + "." + std::to_string(_minor) + "." + std::to_string(_patch);
        if (_draft)
        {
            text += "d";
        }
        return text;
    }

    /**
     * Strict "older than" ordering.
     *
     * Compares major, then minor, then patch. When all three are equal, a
     * draft version is considered older than the non-draft one.
     *
     * @return true if version1 is strictly older than version2.
     */
    static bool less_than(const TosaVersion& version1, const TosaVersion& version2)
    {
        if (version1._major != version2._major)
        {
            return version1._major < version2._major;
        }
        if (version1._minor != version2._minor)
        {
            return version1._minor < version2._minor;
        }
        if (version1._patch != version2._patch)
        {
            return version1._patch < version2._patch;
        }
        // Same numeric version: only draft -> release counts as "older".
        return version1._draft && !version2._draft;
    }

    /**
     * Decide whether a flatbuffer written with tosa_fb_version can be read by
     * a serializer implementing serializer_version.
     *
     * @param tosa_fb_version    version recorded in the flatbuffer
     * @param serializer_version version implemented by the reader
     * @return COMPLETELY_COMPATIBLE on an exact match (including the draft
     *         flag); BACKWARD_COMPATIBLE when the file version is at least
     *         0.70 and strictly older than the serializer; NOT_COMPATIBLE
     *         otherwise (including files newer than the serializer).
     */
    static TosaVersion::compat_t is_compatible(const TosaVersion& tosa_fb_version,
                                               const TosaVersion& serializer_version)
    {
        const bool exact_match = (serializer_version._major == tosa_fb_version._major) &&
                                 (serializer_version._minor == tosa_fb_version._minor) &&
                                 (serializer_version._patch == tosa_fb_version._patch) &&
                                 (serializer_version._draft == tosa_fb_version._draft);
        if (exact_match)
        {
            return TosaVersion::compat_t::COMPLETELY_COMPATIBLE;
        }

        // We currently support backward compatibility starting from 0.70.0
        // TODO: need to double-check this logic right before TOSA 1.0.0 release
        const bool in_supported_range =
            (tosa_fb_version._major == 0 && tosa_fb_version._minor >= 70) || (tosa_fb_version._major > 0);
        if (in_supported_range && less_than(tosa_fb_version, serializer_version))
        {
            return TosaVersion::compat_t::BACKWARD_COMPATIBLE;
        }
        return TosaVersion::compat_t::NOT_COMPATIBLE;
    }
};
142
// Forward declaration; the full class is defined later in this header.
class TosaSerializationHandler;
144
145class TosaSerializationTensor
146{
147public:
148 // constructor and destructor
149 TosaSerializationTensor(const flatbuffers::String* name,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700150 const flatbuffers::Vector<int32_t>* shape,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700151 DType dtype,
Jerry Ge442261b2022-09-09 13:38:56 -0700152 const flatbuffers::Vector<uint8_t>* data,
Tai Lyd0520b92023-09-19 21:30:18 +0000153 const bool variable = false,
154 const bool is_unranked = false,
155 const flatbuffers::String* variable_name = NULL);
Kevin Cheng545a5082021-11-11 01:36:33 +0000156 TosaSerializationTensor(const std::string& name,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700157 const std::vector<int32_t>& shape,
158 DType dtype,
Jerry Ge442261b2022-09-09 13:38:56 -0700159 const std::vector<uint8_t>& data,
Tai Lyd0520b92023-09-19 21:30:18 +0000160 const bool variable = false,
161 const bool is_unranked = false,
162 const std::string& variable_name = "");
Eric Kunze2364dcd2021-04-26 11:06:57 -0700163 TosaSerializationTensor();
164 ~TosaSerializationTensor();
165
166 // accessor
167 std::string GetName() const
168 {
169 return _name;
170 }
171 const std::vector<int32_t>& GetShape() const
172 {
173 return _shape;
174 }
Jerry Ge442261b2022-09-09 13:38:56 -0700175 DType GetDtype() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700176 {
177 return _dtype;
178 }
Jerry Ge442261b2022-09-09 13:38:56 -0700179 bool GetVariable() const
180 {
181 return _variable;
182 }
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700183 const std::vector<uint8_t>& GetData() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700184 {
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700185 return _data;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700186 }
Eric Kunzecc426df2024-01-03 00:27:59 +0000187 bool GetIsUnranked() const
Tai Lyc6939a42023-08-21 17:00:29 +0000188 {
189 return _is_unranked;
190 }
Tai Lyd0520b92023-09-19 21:30:18 +0000191 const std::string GetVariableName() const
192 {
193 return _variable_name;
194 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700195
196 // modifier
197 void SetDtype(DType dtype)
198 {
199 _dtype = dtype;
200 }
201 void SetName(std::string name)
202 {
203 _name = name;
204 }
Kevin Cheng545a5082021-11-11 01:36:33 +0000205 void SetData(const std::vector<uint8_t>& data)
206 {
207 _data = data;
208 }
209 void SetData(std::vector<uint8_t>&& data)
210 {
211 _data = std::move(data);
212 }
Tai Lyc6939a42023-08-21 17:00:29 +0000213 void SetIsUnranked(const bool value)
214 {
215 _is_unranked = value;
216 }
Jerry Geab8d2342023-04-26 22:31:11 +0000217 void SetDimSize(size_t dim, uint32_t new_size)
218 {
Eric Kunzecc426df2024-01-03 00:27:59 +0000219 if (dim >= _shape.size())
Jerry Geab8d2342023-04-26 22:31:11 +0000220 {
221 printf("dim is out of bound\n");
222 assert(0);
223 }
224 _shape[dim] = new_size;
225 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700226
227private:
228 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
229 std::vector<int32_t> _shape; /* shape of the tensor */
230 std::string _name; /* name of the tensor, used for solving dependency */
Jerry Ge442261b2022-09-09 13:38:56 -0700231 bool _variable; /* is this a variable tensor */
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700232 std::vector<uint8_t> _data; /* data array */
Tai Lyc6939a42023-08-21 17:00:29 +0000233 bool _is_unranked; /* whether this is an unranked tensor */
Tai Lyd0520b92023-09-19 21:30:18 +0000234 std::string _variable_name; /* name for variable tensors */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700235};
236
237class TosaSerializationOperator
238{
239public:
240 // use default copy, void constructor
241 // constructor and destructor
242 TosaSerializationOperator(Op op,
243 Attribute attribute_type,
244 const TosaAttributeBase* attribute,
Kevin Cheng545a5082021-11-11 01:36:33 +0000245 const std::vector<std::string>& input_tensor_names,
246 const std::vector<std::string>& output_tensor_names);
247 TosaSerializationOperator(Op op,
248 Attribute attribute_type,
249 const TosaAttributeBase* attribute,
Kevin Cheng545a5082021-11-11 01:36:33 +0000250 std::vector<std::string>&& input_tensor_names,
251 std::vector<std::string>&& output_tensor_names);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700252 ~TosaSerializationOperator();
253
254 // accessor
255 Op GetOp() const
256 {
257 return _op;
258 }
259 Attribute GetAttributeType() const
260 {
261 return _attribute_type;
262 }
263 TosaAttributeBase* GetAttribute() const
264 {
265 return _attribute;
266 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700267 std::vector<std::string>& GetInputTensorNames()
268 {
269 return _input_tensor_names;
270 }
271 std::vector<std::string>& GetOutputTensorNames()
272 {
273 return _output_tensor_names;
274 }
275
276private:
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000277 void InitializeAttribute(Attribute attribute_type, const TosaAttributeBase* attribute);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700278 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
279 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
280 TosaAttributeBase* _attribute; /* real attribute class goes here */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700281 std::vector<std::string> _input_tensor_names; /* array of input tensor names */
282 std::vector<std::string> _output_tensor_names; /* array of output tensor names */
283};
284
285class TosaSerializationBasicBlock
286{
287public:
288 // constructor and destructor
Kevin Cheng545a5082021-11-11 01:36:33 +0000289 TosaSerializationBasicBlock(const std::string& name,
Jerry Ge13c78a62022-10-04 20:32:39 -0700290 const std::string& region_name,
Kevin Cheng545a5082021-11-11 01:36:33 +0000291 const std::vector<TosaSerializationOperator*>& operators,
292 const std::vector<TosaSerializationTensor*>& tensors,
293 const std::vector<std::string>& inputs,
294 const std::vector<std::string>& outputs);
295 TosaSerializationBasicBlock(std::string&& name,
Jerry Ge13c78a62022-10-04 20:32:39 -0700296 std::string&& region_name,
Kevin Cheng545a5082021-11-11 01:36:33 +0000297 std::vector<TosaSerializationOperator*>&& operators,
298 std::vector<TosaSerializationTensor*>&& tensors,
299 std::vector<std::string>&& inputs,
300 std::vector<std::string>&& outputs);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700301 ~TosaSerializationBasicBlock();
302
303 // accessor
304 std::string GetName() const
305 {
306 return _name;
307 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700308 std::string GetRegionName() const
309 {
310 return _region_name;
311 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700312 std::vector<TosaSerializationOperator*>& GetOperators()
313 {
314 return _operators;
315 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700316
Eric Kunze2364dcd2021-04-26 11:06:57 -0700317 std::vector<TosaSerializationTensor*>& GetTensors()
318 {
319 return _tensors;
320 }
321
322 TosaSerializationTensor* GetTensorByName(std::string name)
323 {
324 TosaSerializationTensor* result = nullptr;
325 for (auto tensor : GetTensors())
326 {
327 if (tensor->GetName() == name)
328 {
329 result = tensor;
330 break;
331 }
332 }
333 return result;
334 }
335
336 std::vector<std::string>& GetInputs()
337 {
338 return _inputs;
339 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700340
Eric Kunze2364dcd2021-04-26 11:06:57 -0700341 std::vector<std::string>& GetOutputs()
342 {
343 return _outputs;
344 }
345
346private:
Jerry Ge13c78a62022-10-04 20:32:39 -0700347 std::string _name; /* name of basic block */
348 std::string _region_name;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700349 std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
350 std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
351 std::vector<std::string> _inputs; /* array of string to specify block inputs */
352 std::vector<std::string> _outputs; /* array of string to specify block outputs */
353};
354
Jerry Ge13c78a62022-10-04 20:32:39 -0700355class TosaSerializationRegion
356{
357public:
358 // constructor and desctructor
359 TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks);
360 TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks);
361 ~TosaSerializationRegion();
362
363 // accessors
364 std::string GetName() const
365 {
366 return this->_name;
367 }
368
369 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
370 {
371 return this->_blocks;
372 }
373
374 TosaSerializationBasicBlock* GetBlockByName(std::string name)
375 {
376 TosaSerializationBasicBlock* result = nullptr;
377 for (auto block : GetBlocks())
378 {
379 if (block->GetName() == name)
380 {
381 result = block;
382 break;
383 }
384 }
385 return result;
386 }
387
388private:
389 std::string _name; /* name of basic block */
390 std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */
391};
392
Eric Kunze2364dcd2021-04-26 11:06:57 -0700393/*
394 * this is a helper class for writing/reading Tosa ISA
395 * supported format: .tosa (flatbuffer), .json
396 * and provide high-level std::vector-like interface
397 * to access internal data structure
398 */
399class TosaSerializationHandler
400{
401public:
402 // constructor and destructor
403 TosaSerializationHandler();
404 ~TosaSerializationHandler();
405
406 // file io
407 tosa_err_t LoadFileJson(const char* filename);
408 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
Aaron DeBattista8b3903a2021-11-18 16:38:11 +0000409 tosa_err_t LoadFileTosaFlatbuffer(const void* input, int in_size);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700410 tosa_err_t SaveFileJson(const char* filename);
411 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
412 tosa_err_t LoadFileSchema(const char* schema_filename);
413
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700414 // data format conversion. little-endian.
James Ward485a11d2022-08-05 13:48:37 +0100415 static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700416 static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
Tai Ly5d580fa2023-12-15 20:34:51 +0000417 static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700418 static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
419 static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
420 static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
421 static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700422 static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700423 static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
424
James Ward485a11d2022-08-05 13:48:37 +0100425 static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700426 static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
Tai Ly5d580fa2023-12-15 20:34:51 +0000427 static tosa_err_t ConvertU8toI64(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700428 static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
429 static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
430 static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
431 static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700432 static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700433 static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
434
Jerry Ge442261b2022-09-09 13:38:56 -0700435 static void ForceAlignTensorData(std::vector<uint8_t>& buf);
436
Eric Kunze2364dcd2021-04-26 11:06:57 -0700437 // version
Kevin Chenge6563f52021-10-20 12:12:02 -0700438 const TosaVersion& GetVersion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700439 {
440 return _version;
441 }
442
443 // accessor
Jerry Ge13c78a62022-10-04 20:32:39 -0700444 std::vector<TosaSerializationRegion*>& GetRegions()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700445 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700446 return _regions;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700447 }
448
Jerry Ge13c78a62022-10-04 20:32:39 -0700449 TosaSerializationRegion* GetMainRegion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700450 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700451 return _regions[0];
452 }
453
454 TosaSerializationRegion* GetRegionByName(std::string name)
455 {
456 TosaSerializationRegion* result = nullptr;
457 for (auto region : GetRegions())
Eric Kunze2364dcd2021-04-26 11:06:57 -0700458 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700459 if (region->GetName() == name)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700460 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700461 result = region;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700462 break;
463 }
464 }
465 return result;
466 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700467
468 bool GetSchemaLoaded() const
469 {
470 return _schemaLoaded;
471 }
472
473protected:
474 tosa_err_t Clear();
Kevin Chengb97cb1d2021-10-14 11:53:39 -0700475 tosa_err_t Deserialize(const uint8_t* buf);
476 tosa_err_t Serialize();
Eric Kunze2364dcd2021-04-26 11:06:57 -0700477
478private:
Jerry Ge13c78a62022-10-04 20:32:39 -0700479 TosaVersion _version; /* version struct */
480 flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
481 flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
482 std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */
483 bool _schemaLoaded; /* is the schema properly loaded? */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700484};
485
486} // namespace tosa
487
488#endif // _TOSA_SERIALIZATION_HANDLER_H