// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#ifndef _TOSA_SERIALIZATION_HANDLER_H
17#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
#include "tosa_generated.h"
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>
27
// Version of the TOSA serialization format implemented by this library.
// Keep these values in sync with the default version values in schema/tosa.fbs.
#define TOSA_VERSION_MAJOR 0
#define TOSA_VERSION_MINOR 80
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT true
// NOTE(review): presumably the byte alignment applied to tensor data buffers by
// TosaSerializationHandler::ForceAlignTensorData() -- confirm against the .cpp.
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
34
Eric Kunze2364dcd2021-04-26 11:06:57 -070035namespace tosa
36{
37
// Status codes returned by the TosaSerializationHandler load/save/convert entry points.
// Enumerators are numbered implicitly from 0, so TOSA_OK == 0.
enum tosa_err_t
{
    TOSA_OK,
    TOSA_USER_ERROR,
    TOSA_FILE_ERROR,
    TOSA_MEMORY_ERROR,
    TOSA_SCHEMA_MISSING,
    TOSA_INTERNAL_ERROR,
    TOSA_VERSION_MISMATCH,
    NUM_TOSA_ERROR    // count of the status codes above; must stay last
};
49
// Version descriptor: major.minor.patch plus a "draft" flag.
struct TosaVersion
{
    // Default member initializers keep a default-constructed TosaVersion well defined
    // (0.0.0, not a draft) instead of leaving the fields indeterminate.
    int32_t _major = 0;
    int32_t _minor = 0;
    int32_t _patch = 0;
    bool _draft    = false;

    // Result of comparing the version stored in a flatbuffer with the serializer's own version.
    enum class compat_t
    {
        COMPLETELY_COMPATIBLE,    // all four fields are equal
        BACKWARD_COMPATIBLE,      // flatbuffer is older, but new enough to be supported
        NOT_COMPATIBLE
    };

    TosaVersion() = default;
    TosaVersion(int32_t major, int32_t minor, int32_t patch, bool draft)
    {
        set_version(major, minor, patch, draft);
    }

    // Overwrites all four version fields.
    void set_version(int32_t major, int32_t minor, int32_t patch, bool draft)
    {
        _major = major;
        _minor = minor;
        _patch = patch;
        _draft = draft;
    }

    // Formats the version as "<major>.<minor>.<patch>", with a trailing "d" for draft versions.
    std::string to_string() const
    {
        std::string str = std::to_string(_major);
        str += ".";
        str += std::to_string(_minor);
        str += ".";
        str += std::to_string(_patch);
        if (_draft)
        {
            str += "d";
        }
        return str;
    }

    // Strict ordering: returns true when version1 is older than version2.
    // Fields are compared in the order major, minor, patch. When all three are equal,
    // a draft version is considered older than the non-draft version with the same numbers.
    static bool less_than(const TosaVersion& version1, const TosaVersion& version2)
    {
        if (version1._major != version2._major)
        {
            return version1._major < version2._major;
        }
        if (version1._minor != version2._minor)
        {
            return version1._minor < version2._minor;
        }
        if (version1._patch != version2._patch)
        {
            return version1._patch < version2._patch;
        }
        return version1._draft && !version2._draft;
    }

    // Classifies how the version found in a flatbuffer (tosa_fb_version) relates to the
    // version of this serializer (serializer_version):
    //  - COMPLETELY_COMPATIBLE: major, minor, patch and draft are all identical.
    //  - BACKWARD_COMPATIBLE:   the flatbuffer is at least 0.70.x (or has major > 0) and is
    //                           strictly older than the serializer according to less_than().
    //  - NOT_COMPATIBLE:        everything else, including flatbuffers newer than the serializer.
    static TosaVersion::compat_t is_compatible(const TosaVersion& tosa_fb_version,
                                               const TosaVersion& serializer_version)
    {
        const bool identical = (serializer_version._major == tosa_fb_version._major) &&
                               (serializer_version._minor == tosa_fb_version._minor) &&
                               (serializer_version._patch == tosa_fb_version._patch) &&
                               (serializer_version._draft == tosa_fb_version._draft);
        if (identical)
        {
            return TosaVersion::compat_t::COMPLETELY_COMPATIBLE;
        }

        // We currently support backward compatibility starting from 0.70.0
        // TODO: need to double-check this logic right before TOSA 1.0.0 release
        const bool in_supported_range =
            (tosa_fb_version._major == 0 && tosa_fb_version._minor >= 70) || (tosa_fb_version._major > 0);
        if (in_supported_range && less_than(tosa_fb_version, serializer_version))
        {
            return TosaVersion::compat_t::BACKWARD_COMPATIBLE;
        }
        return TosaVersion::compat_t::NOT_COMPATIBLE;
    }
};
142
Eric Kunze2364dcd2021-04-26 11:06:57 -0700143class TosaSerializationHandler;
144
145class TosaSerializationTensor
146{
147public:
148 // constructor and destructor
149 TosaSerializationTensor(const flatbuffers::String* name,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700150 const flatbuffers::Vector<int32_t>* shape,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700151 DType dtype,
Jerry Ge442261b2022-09-09 13:38:56 -0700152 const flatbuffers::Vector<uint8_t>* data,
Tai Lyc6939a42023-08-21 17:00:29 +0000153 const bool variable = false,
154 const bool is_unranked = false);
Kevin Cheng545a5082021-11-11 01:36:33 +0000155 TosaSerializationTensor(const std::string& name,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700156 const std::vector<int32_t>& shape,
157 DType dtype,
Jerry Ge442261b2022-09-09 13:38:56 -0700158 const std::vector<uint8_t>& data,
Tai Lyc6939a42023-08-21 17:00:29 +0000159 const bool variable = false,
160 const bool is_unranked = false);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700161 TosaSerializationTensor();
162 ~TosaSerializationTensor();
163
164 // accessor
165 std::string GetName() const
166 {
167 return _name;
168 }
169 const std::vector<int32_t>& GetShape() const
170 {
171 return _shape;
172 }
Jerry Ge442261b2022-09-09 13:38:56 -0700173 DType GetDtype() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700174 {
175 return _dtype;
176 }
Jerry Ge442261b2022-09-09 13:38:56 -0700177 bool GetVariable() const
178 {
179 return _variable;
180 }
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700181 const std::vector<uint8_t>& GetData() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700182 {
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700183 return _data;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700184 }
Tai Lyc6939a42023-08-21 17:00:29 +0000185 const bool GetIsUnranked() const
186 {
187 return _is_unranked;
188 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700189
190 // modifier
191 void SetDtype(DType dtype)
192 {
193 _dtype = dtype;
194 }
195 void SetName(std::string name)
196 {
197 _name = name;
198 }
Kevin Cheng545a5082021-11-11 01:36:33 +0000199 void SetData(const std::vector<uint8_t>& data)
200 {
201 _data = data;
202 }
203 void SetData(std::vector<uint8_t>&& data)
204 {
205 _data = std::move(data);
206 }
Tai Lyc6939a42023-08-21 17:00:29 +0000207 void SetIsUnranked(const bool value)
208 {
209 _is_unranked = value;
210 }
Jerry Geab8d2342023-04-26 22:31:11 +0000211 void SetDimSize(size_t dim, uint32_t new_size)
212 {
213 if (dim < 0 || dim >= _shape.size())
214 {
215 printf("dim is out of bound\n");
216 assert(0);
217 }
218 _shape[dim] = new_size;
219 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700220
221private:
222 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
223 std::vector<int32_t> _shape; /* shape of the tensor */
224 std::string _name; /* name of the tensor, used for solving dependency */
Jerry Ge442261b2022-09-09 13:38:56 -0700225 bool _variable; /* is this a variable tensor */
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700226 std::vector<uint8_t> _data; /* data array */
Tai Lyc6939a42023-08-21 17:00:29 +0000227 bool _is_unranked; /* whether this is an unranked tensor */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700228};
229
230class TosaSerializationOperator
231{
232public:
233 // use default copy, void constructor
234 // constructor and destructor
235 TosaSerializationOperator(Op op,
236 Attribute attribute_type,
237 const TosaAttributeBase* attribute,
Kevin Cheng545a5082021-11-11 01:36:33 +0000238 const std::vector<std::string>& input_tensor_names,
239 const std::vector<std::string>& output_tensor_names);
240 TosaSerializationOperator(Op op,
241 Attribute attribute_type,
242 const TosaAttributeBase* attribute,
Kevin Cheng545a5082021-11-11 01:36:33 +0000243 std::vector<std::string>&& input_tensor_names,
244 std::vector<std::string>&& output_tensor_names);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700245 ~TosaSerializationOperator();
246
247 // accessor
248 Op GetOp() const
249 {
250 return _op;
251 }
252 Attribute GetAttributeType() const
253 {
254 return _attribute_type;
255 }
256 TosaAttributeBase* GetAttribute() const
257 {
258 return _attribute;
259 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700260 std::vector<std::string>& GetInputTensorNames()
261 {
262 return _input_tensor_names;
263 }
264 std::vector<std::string>& GetOutputTensorNames()
265 {
266 return _output_tensor_names;
267 }
268
269private:
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000270 void InitializeAttribute(Attribute attribute_type, const TosaAttributeBase* attribute);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700271 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
272 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
273 TosaAttributeBase* _attribute; /* real attribute class goes here */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700274 std::vector<std::string> _input_tensor_names; /* array of input tensor names */
275 std::vector<std::string> _output_tensor_names; /* array of output tensor names */
276};
277
278class TosaSerializationBasicBlock
279{
280public:
281 // constructor and destructor
Kevin Cheng545a5082021-11-11 01:36:33 +0000282 TosaSerializationBasicBlock(const std::string& name,
Jerry Ge13c78a62022-10-04 20:32:39 -0700283 const std::string& region_name,
Kevin Cheng545a5082021-11-11 01:36:33 +0000284 const std::vector<TosaSerializationOperator*>& operators,
285 const std::vector<TosaSerializationTensor*>& tensors,
286 const std::vector<std::string>& inputs,
287 const std::vector<std::string>& outputs);
288 TosaSerializationBasicBlock(std::string&& name,
Jerry Ge13c78a62022-10-04 20:32:39 -0700289 std::string&& region_name,
Kevin Cheng545a5082021-11-11 01:36:33 +0000290 std::vector<TosaSerializationOperator*>&& operators,
291 std::vector<TosaSerializationTensor*>&& tensors,
292 std::vector<std::string>&& inputs,
293 std::vector<std::string>&& outputs);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700294 ~TosaSerializationBasicBlock();
295
296 // accessor
297 std::string GetName() const
298 {
299 return _name;
300 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700301 std::string GetRegionName() const
302 {
303 return _region_name;
304 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700305 std::vector<TosaSerializationOperator*>& GetOperators()
306 {
307 return _operators;
308 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700309
Eric Kunze2364dcd2021-04-26 11:06:57 -0700310 std::vector<TosaSerializationTensor*>& GetTensors()
311 {
312 return _tensors;
313 }
314
315 TosaSerializationTensor* GetTensorByName(std::string name)
316 {
317 TosaSerializationTensor* result = nullptr;
318 for (auto tensor : GetTensors())
319 {
320 if (tensor->GetName() == name)
321 {
322 result = tensor;
323 break;
324 }
325 }
326 return result;
327 }
328
329 std::vector<std::string>& GetInputs()
330 {
331 return _inputs;
332 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700333
Eric Kunze2364dcd2021-04-26 11:06:57 -0700334 std::vector<std::string>& GetOutputs()
335 {
336 return _outputs;
337 }
338
339private:
Jerry Ge13c78a62022-10-04 20:32:39 -0700340 std::string _name; /* name of basic block */
341 std::string _region_name;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700342 std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
343 std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
344 std::vector<std::string> _inputs; /* array of string to specify block inputs */
345 std::vector<std::string> _outputs; /* array of string to specify block outputs */
346};
347
Jerry Ge13c78a62022-10-04 20:32:39 -0700348class TosaSerializationRegion
349{
350public:
351 // constructor and desctructor
352 TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks);
353 TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks);
354 ~TosaSerializationRegion();
355
356 // accessors
357 std::string GetName() const
358 {
359 return this->_name;
360 }
361
362 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
363 {
364 return this->_blocks;
365 }
366
367 TosaSerializationBasicBlock* GetBlockByName(std::string name)
368 {
369 TosaSerializationBasicBlock* result = nullptr;
370 for (auto block : GetBlocks())
371 {
372 if (block->GetName() == name)
373 {
374 result = block;
375 break;
376 }
377 }
378 return result;
379 }
380
381private:
382 std::string _name; /* name of basic block */
383 std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */
384};
385
Eric Kunze2364dcd2021-04-26 11:06:57 -0700386/*
387 * this is a helper class for writing/reading Tosa ISA
388 * supported format: .tosa (flatbuffer), .json
389 * and provide high-level std::vector-like interface
390 * to access internal data structure
391 */
392class TosaSerializationHandler
393{
394public:
395 // constructor and destructor
396 TosaSerializationHandler();
397 ~TosaSerializationHandler();
398
399 // file io
400 tosa_err_t LoadFileJson(const char* filename);
401 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
Aaron DeBattista8b3903a2021-11-18 16:38:11 +0000402 tosa_err_t LoadFileTosaFlatbuffer(const void* input, int in_size);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700403 tosa_err_t SaveFileJson(const char* filename);
404 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
405 tosa_err_t LoadFileSchema(const char* schema_filename);
406
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700407 // data format conversion. little-endian.
James Ward485a11d2022-08-05 13:48:37 +0100408 static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700409 static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
410 static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
411 static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
412 static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
413 static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700414 static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700415 static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
416
James Ward485a11d2022-08-05 13:48:37 +0100417 static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700418 static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
419 static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
420 static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
421 static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
422 static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700423 static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700424 static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
425
Jerry Ge442261b2022-09-09 13:38:56 -0700426 static void ForceAlignTensorData(std::vector<uint8_t>& buf);
427
Eric Kunze2364dcd2021-04-26 11:06:57 -0700428 // version
Kevin Chenge6563f52021-10-20 12:12:02 -0700429 const TosaVersion& GetVersion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700430 {
431 return _version;
432 }
433
434 // accessor
Jerry Ge13c78a62022-10-04 20:32:39 -0700435 std::vector<TosaSerializationRegion*>& GetRegions()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700436 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700437 return _regions;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700438 }
439
Jerry Ge13c78a62022-10-04 20:32:39 -0700440 TosaSerializationRegion* GetMainRegion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700441 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700442 return _regions[0];
443 }
444
445 TosaSerializationRegion* GetRegionByName(std::string name)
446 {
447 TosaSerializationRegion* result = nullptr;
448 for (auto region : GetRegions())
Eric Kunze2364dcd2021-04-26 11:06:57 -0700449 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700450 if (region->GetName() == name)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700451 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700452 result = region;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700453 break;
454 }
455 }
456 return result;
457 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700458
459 bool GetSchemaLoaded() const
460 {
461 return _schemaLoaded;
462 }
463
464protected:
465 tosa_err_t Clear();
Kevin Chengb97cb1d2021-10-14 11:53:39 -0700466 tosa_err_t Deserialize(const uint8_t* buf);
467 tosa_err_t Serialize();
Eric Kunze2364dcd2021-04-26 11:06:57 -0700468
469private:
Jerry Ge13c78a62022-10-04 20:32:39 -0700470 TosaVersion _version; /* version struct */
471 flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
472 flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
473 std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */
474 bool _schemaLoaded; /* is the schema properly loaded? */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700475};
476
477} // namespace tosa
478
479#endif // _TOSA_SERIALIZATION_HANDLER_H