
// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef _TOSA_SERIALIZATION_HANDLER_H
#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
#include "tosa_generated.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Keep these version numbers in sync with the default version values in schema/tosa.fbs
#define TOSA_VERSION_MAJOR 0
#define TOSA_VERSION_MINOR 90
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT false
// NOTE(review): presumably the alignment (in bytes) applied by ForceAlignTensorData — confirm in the implementation.
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8

namespace tosa
{

// Status codes returned by the serialization handler's file I/O and
// data-conversion functions. TOSA_OK is the first enumerator (value 0).
enum tosa_err_t
{
    TOSA_OK,
    TOSA_USER_ERROR,
    TOSA_FILE_ERROR,
    TOSA_MEMORY_ERROR,
    TOSA_SCHEMA_MISSING,
    TOSA_INTERNAL_ERROR,
    TOSA_VERSION_MISMATCH,
    NUM_TOSA_ERROR    // not a status: equals the number of codes listed above
};

// Version triple (major.minor.patch) plus a draft flag, with helpers to
// format it, order two versions, and classify their compatibility.
struct TosaVersion
{
    // In-class initializers give a default-constructed version a defined
    // value (0.0.0, non-draft) instead of indeterminate members.
    int32_t _major = 0;
    int32_t _minor = 0;
    int32_t _patch = 0;
    bool _draft    = false;

    // Outcome of is_compatible().
    enum class compat_t
    {
        COMPLETELY_COMPATIBLE,    // all four fields are equal
        BACKWARD_COMPATIBLE,      // flatbuffer version is older than the serializer but still accepted
        NOT_COMPATIBLE
    };

    TosaVersion() = default;
    TosaVersion(int32_t major, int32_t minor, int32_t patch, bool draft)
        : _major(major)
        , _minor(minor)
        , _patch(patch)
        , _draft(draft)
    {}

    // Overwrites all four fields.
    void set_version(int32_t major, int32_t minor, int32_t patch, bool draft)
    {
        _major = major;
        _minor = minor;
        _patch = patch;
        _draft = draft;
    }

    // Formats the version as "major.minor.patch", with a trailing "d" for drafts.
    std::string to_string() const
    {
        std::string str = std::to_string(_major);
        str += ".";
        str += std::to_string(_minor);
        str += ".";
        str += std::to_string(_patch);
        if (_draft)
            str += "d";
        return str;
    }

    // Strict ordering: returns true when version1 precedes version2.
    // Fields are compared major, then minor, then patch; for identical
    // numbers a draft precedes the non-draft version.
    static bool less_than(const TosaVersion& version1, const TosaVersion& version2)
    {
        if (version1._major != version2._major)
            return version1._major < version2._major;
        if (version1._minor != version2._minor)
            return version1._minor < version2._minor;
        if (version1._patch != version2._patch)
            return version1._patch < version2._patch;
        return version1._draft && !version2._draft;
    }

    // Classifies how the version stored in a flatbuffer (tosa_fb_version)
    // relates to the version of this serializer (serializer_version).
    static TosaVersion::compat_t is_compatible(const TosaVersion& tosa_fb_version,
                                               const TosaVersion& serializer_version)
    {
        const bool identical = (serializer_version._major == tosa_fb_version._major) &&
                               (serializer_version._minor == tosa_fb_version._minor) &&
                               (serializer_version._patch == tosa_fb_version._patch) &&
                               (serializer_version._draft == tosa_fb_version._draft);
        if (identical)
            return TosaVersion::compat_t::COMPLETELY_COMPATIBLE;

        // We currently support backward compatibility starting from 0.90.0
        const bool in_supported_range =
            (tosa_fb_version._major > 0) || (tosa_fb_version._major == 0 && tosa_fb_version._minor >= 90);
        if (in_supported_range && less_than(tosa_fb_version, serializer_version))
            return TosaVersion::compat_t::BACKWARD_COMPATIBLE;

        return TosaVersion::compat_t::NOT_COMPATIBLE;
    }
};

// Forward declaration; the class is defined later in this header.
class TosaSerializationHandler;

144class TosaSerializationTensor
145{
146public:
147 // constructor and destructor
148 TosaSerializationTensor(const flatbuffers::String* name,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700149 const flatbuffers::Vector<int32_t>* shape,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700150 DType dtype,
Jerry Ge442261b2022-09-09 13:38:56 -0700151 const flatbuffers::Vector<uint8_t>* data,
Tai Lyd0520b92023-09-19 21:30:18 +0000152 const bool variable = false,
153 const bool is_unranked = false,
154 const flatbuffers::String* variable_name = NULL);
Kevin Cheng545a5082021-11-11 01:36:33 +0000155 TosaSerializationTensor(const std::string& name,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700156 const std::vector<int32_t>& shape,
157 DType dtype,
Jerry Ge442261b2022-09-09 13:38:56 -0700158 const std::vector<uint8_t>& data,
Tai Lyd0520b92023-09-19 21:30:18 +0000159 const bool variable = false,
160 const bool is_unranked = false,
161 const std::string& variable_name = "");
Eric Kunze2364dcd2021-04-26 11:06:57 -0700162 TosaSerializationTensor();
163 ~TosaSerializationTensor();
164
165 // accessor
166 std::string GetName() const
167 {
168 return _name;
169 }
170 const std::vector<int32_t>& GetShape() const
171 {
172 return _shape;
173 }
Jerry Ge442261b2022-09-09 13:38:56 -0700174 DType GetDtype() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700175 {
176 return _dtype;
177 }
Jerry Ge442261b2022-09-09 13:38:56 -0700178 bool GetVariable() const
179 {
180 return _variable;
181 }
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700182 const std::vector<uint8_t>& GetData() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700183 {
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700184 return _data;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700185 }
Eric Kunzecc426df2024-01-03 00:27:59 +0000186 bool GetIsUnranked() const
Tai Lyc6939a42023-08-21 17:00:29 +0000187 {
188 return _is_unranked;
189 }
Tai Lyd0520b92023-09-19 21:30:18 +0000190 const std::string GetVariableName() const
191 {
192 return _variable_name;
193 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700194
195 // modifier
196 void SetDtype(DType dtype)
197 {
198 _dtype = dtype;
199 }
200 void SetName(std::string name)
201 {
202 _name = name;
203 }
Kevin Cheng545a5082021-11-11 01:36:33 +0000204 void SetData(const std::vector<uint8_t>& data)
205 {
206 _data = data;
207 }
208 void SetData(std::vector<uint8_t>&& data)
209 {
210 _data = std::move(data);
211 }
Tai Lyc6939a42023-08-21 17:00:29 +0000212 void SetIsUnranked(const bool value)
213 {
214 _is_unranked = value;
215 }
Jerry Geab8d2342023-04-26 22:31:11 +0000216 void SetDimSize(size_t dim, uint32_t new_size)
217 {
Eric Kunzecc426df2024-01-03 00:27:59 +0000218 if (dim >= _shape.size())
Jerry Geab8d2342023-04-26 22:31:11 +0000219 {
220 printf("dim is out of bound\n");
221 assert(0);
222 }
223 _shape[dim] = new_size;
224 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700225
226private:
227 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
228 std::vector<int32_t> _shape; /* shape of the tensor */
229 std::string _name; /* name of the tensor, used for solving dependency */
Jerry Ge442261b2022-09-09 13:38:56 -0700230 bool _variable; /* is this a variable tensor */
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700231 std::vector<uint8_t> _data; /* data array */
Tai Lyc6939a42023-08-21 17:00:29 +0000232 bool _is_unranked; /* whether this is an unranked tensor */
Tai Lyd0520b92023-09-19 21:30:18 +0000233 std::string _variable_name; /* name for variable tensors */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700234};
235
236class TosaSerializationOperator
237{
238public:
239 // use default copy, void constructor
240 // constructor and destructor
241 TosaSerializationOperator(Op op,
242 Attribute attribute_type,
243 const TosaAttributeBase* attribute,
Kevin Cheng545a5082021-11-11 01:36:33 +0000244 const std::vector<std::string>& input_tensor_names,
245 const std::vector<std::string>& output_tensor_names);
246 TosaSerializationOperator(Op op,
247 Attribute attribute_type,
248 const TosaAttributeBase* attribute,
Kevin Cheng545a5082021-11-11 01:36:33 +0000249 std::vector<std::string>&& input_tensor_names,
250 std::vector<std::string>&& output_tensor_names);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700251 ~TosaSerializationOperator();
252
253 // accessor
254 Op GetOp() const
255 {
256 return _op;
257 }
258 Attribute GetAttributeType() const
259 {
260 return _attribute_type;
261 }
262 TosaAttributeBase* GetAttribute() const
263 {
264 return _attribute;
265 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700266 std::vector<std::string>& GetInputTensorNames()
267 {
268 return _input_tensor_names;
269 }
270 std::vector<std::string>& GetOutputTensorNames()
271 {
272 return _output_tensor_names;
273 }
274
275private:
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000276 void InitializeAttribute(Attribute attribute_type, const TosaAttributeBase* attribute);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700277 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
278 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
279 TosaAttributeBase* _attribute; /* real attribute class goes here */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700280 std::vector<std::string> _input_tensor_names; /* array of input tensor names */
281 std::vector<std::string> _output_tensor_names; /* array of output tensor names */
282};
283
284class TosaSerializationBasicBlock
285{
286public:
287 // constructor and destructor
Kevin Cheng545a5082021-11-11 01:36:33 +0000288 TosaSerializationBasicBlock(const std::string& name,
Jerry Ge13c78a62022-10-04 20:32:39 -0700289 const std::string& region_name,
Kevin Cheng545a5082021-11-11 01:36:33 +0000290 const std::vector<TosaSerializationOperator*>& operators,
291 const std::vector<TosaSerializationTensor*>& tensors,
292 const std::vector<std::string>& inputs,
293 const std::vector<std::string>& outputs);
294 TosaSerializationBasicBlock(std::string&& name,
Jerry Ge13c78a62022-10-04 20:32:39 -0700295 std::string&& region_name,
Kevin Cheng545a5082021-11-11 01:36:33 +0000296 std::vector<TosaSerializationOperator*>&& operators,
297 std::vector<TosaSerializationTensor*>&& tensors,
298 std::vector<std::string>&& inputs,
299 std::vector<std::string>&& outputs);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700300 ~TosaSerializationBasicBlock();
301
302 // accessor
303 std::string GetName() const
304 {
305 return _name;
306 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700307 std::string GetRegionName() const
308 {
309 return _region_name;
310 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700311 std::vector<TosaSerializationOperator*>& GetOperators()
312 {
313 return _operators;
314 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700315
Eric Kunze2364dcd2021-04-26 11:06:57 -0700316 std::vector<TosaSerializationTensor*>& GetTensors()
317 {
318 return _tensors;
319 }
320
321 TosaSerializationTensor* GetTensorByName(std::string name)
322 {
323 TosaSerializationTensor* result = nullptr;
324 for (auto tensor : GetTensors())
325 {
326 if (tensor->GetName() == name)
327 {
328 result = tensor;
329 break;
330 }
331 }
332 return result;
333 }
334
335 std::vector<std::string>& GetInputs()
336 {
337 return _inputs;
338 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700339
Eric Kunze2364dcd2021-04-26 11:06:57 -0700340 std::vector<std::string>& GetOutputs()
341 {
342 return _outputs;
343 }
344
345private:
Jerry Ge13c78a62022-10-04 20:32:39 -0700346 std::string _name; /* name of basic block */
347 std::string _region_name;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700348 std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
349 std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
350 std::vector<std::string> _inputs; /* array of string to specify block inputs */
351 std::vector<std::string> _outputs; /* array of string to specify block outputs */
352};
353
Jerry Ge13c78a62022-10-04 20:32:39 -0700354class TosaSerializationRegion
355{
356public:
357 // constructor and desctructor
358 TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks);
359 TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks);
360 ~TosaSerializationRegion();
361
362 // accessors
363 std::string GetName() const
364 {
365 return this->_name;
366 }
367
368 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
369 {
370 return this->_blocks;
371 }
372
373 TosaSerializationBasicBlock* GetBlockByName(std::string name)
374 {
375 TosaSerializationBasicBlock* result = nullptr;
376 for (auto block : GetBlocks())
377 {
378 if (block->GetName() == name)
379 {
380 result = block;
381 break;
382 }
383 }
384 return result;
385 }
386
387private:
388 std::string _name; /* name of basic block */
389 std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */
390};
391
Eric Kunze2364dcd2021-04-26 11:06:57 -0700392/*
393 * this is a helper class for writing/reading Tosa ISA
394 * supported format: .tosa (flatbuffer), .json
395 * and provide high-level std::vector-like interface
396 * to access internal data structure
397 */
398class TosaSerializationHandler
399{
400public:
401 // constructor and destructor
402 TosaSerializationHandler();
403 ~TosaSerializationHandler();
404
405 // file io
406 tosa_err_t LoadFileJson(const char* filename);
407 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
Aaron DeBattista8b3903a2021-11-18 16:38:11 +0000408 tosa_err_t LoadFileTosaFlatbuffer(const void* input, int in_size);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700409 tosa_err_t SaveFileJson(const char* filename);
410 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
411 tosa_err_t LoadFileSchema(const char* schema_filename);
412
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700413 // data format conversion. little-endian.
James Ward485a11d2022-08-05 13:48:37 +0100414 static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700415 static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
Tai Ly5d580fa2023-12-15 20:34:51 +0000416 static tosa_err_t ConvertI64toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700417 static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
418 static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
419 static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
420 static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700421 static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700422 static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
423
James Ward485a11d2022-08-05 13:48:37 +0100424 static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700425 static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
Tai Ly5d580fa2023-12-15 20:34:51 +0000426 static tosa_err_t ConvertU8toI64(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700427 static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
428 static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
429 static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
430 static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700431 static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700432 static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
433
Jerry Ge442261b2022-09-09 13:38:56 -0700434 static void ForceAlignTensorData(std::vector<uint8_t>& buf);
435
Eric Kunze2364dcd2021-04-26 11:06:57 -0700436 // version
Kevin Chenge6563f52021-10-20 12:12:02 -0700437 const TosaVersion& GetVersion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700438 {
439 return _version;
440 }
441
442 // accessor
Jerry Ge13c78a62022-10-04 20:32:39 -0700443 std::vector<TosaSerializationRegion*>& GetRegions()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700444 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700445 return _regions;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700446 }
447
Jerry Ge13c78a62022-10-04 20:32:39 -0700448 TosaSerializationRegion* GetMainRegion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700449 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700450 return _regions[0];
451 }
452
453 TosaSerializationRegion* GetRegionByName(std::string name)
454 {
455 TosaSerializationRegion* result = nullptr;
456 for (auto region : GetRegions())
Eric Kunze2364dcd2021-04-26 11:06:57 -0700457 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700458 if (region->GetName() == name)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700459 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700460 result = region;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700461 break;
462 }
463 }
464 return result;
465 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700466
467 bool GetSchemaLoaded() const
468 {
469 return _schemaLoaded;
470 }
471
472protected:
473 tosa_err_t Clear();
Kevin Chengb97cb1d2021-10-14 11:53:39 -0700474 tosa_err_t Deserialize(const uint8_t* buf);
475 tosa_err_t Serialize();
Eric Kunze2364dcd2021-04-26 11:06:57 -0700476
477private:
Jerry Ge13c78a62022-10-04 20:32:39 -0700478 TosaVersion _version; /* version struct */
479 flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
480 flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
481 std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */
482 bool _schemaLoaded; /* is the schema properly loaded? */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700483};
484
}    // namespace tosa

#endif    // _TOSA_SERIALIZATION_HANDLER_H