blob: daceecdf0b5105a24a5d81fd9dc800006e33c951 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef _TOSA_SERIALIZATION_HANDLER_H
17#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
#include "tosa_generated.h"
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>
27
// Serialization format version.
// NOTE: these values must stay in sync with the default version value declared in schema/tosa.fbs.
#define TOSA_VERSION_MAJOR 0
#define TOSA_VERSION_MINOR 70
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT false

// Byte alignment constant for tensor data buffers
// (presumably applied when serializing tensor data; its use lives outside this header).
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
34
Eric Kunze2364dcd2021-04-26 11:06:57 -070035namespace tosa
36{
37
// Status codes returned by the TosaSerializationHandler file-I/O and conversion functions.
// Values are spelled out explicitly; they match the implicit numbering (0, 1, 2, ...).
enum tosa_err_t
{
    TOSA_OK               = 0,
    TOSA_USER_ERROR       = 1,
    TOSA_FILE_ERROR       = 2,
    TOSA_MEMORY_ERROR     = 3,
    TOSA_SCHEMA_MISSING   = 4,
    TOSA_INTERNAL_ERROR   = 5,
    TOSA_VERSION_MISMATCH = 6,
    NUM_TOSA_ERROR        = 7    // sentinel: equals the number of status codes listed above
};
49
// Version triple (major.minor.patch) plus a "draft" flag, with helpers to
// format it and to compare two versions for compatibility.
struct TosaVersion
{
    // Default member initializers: a default-constructed TosaVersion is a
    // well-defined 0.0.0 (non-draft) instead of holding indeterminate values,
    // which to_string() / is_compatible() would otherwise read (undefined behavior).
    int32_t _major = 0;
    int32_t _minor = 0;
    int32_t _patch = 0;
    bool _draft    = false;

    // Result of is_compatible().
    enum class compat_t
    {
        COMPLETELY_COMPATIBLE,    // major, minor, patch and draft flag all equal
        PARTIALLY_COMPATIBLE,     // same major.minor, but patch and/or draft flag differ
        NOT_COMPATIBLE            // major or minor differ
    };

    TosaVersion() = default;
    TosaVersion(int32_t major, int32_t minor, int32_t patch, bool draft)
    {
        set_version(major, minor, patch, draft);
    }

    // Overwrite all four version fields.
    void set_version(int32_t major, int32_t minor, int32_t patch, bool draft)
    {
        _major = major;
        _minor = minor;
        _patch = patch;
        _draft = draft;
    }

    // Format as "major.minor.patch", with a trailing "d" for draft versions (e.g. "0.70.0d").
    std::string to_string() const
    {
        std::string str;
        str += std::to_string(_major) + ".";
        str += std::to_string(_minor) + ".";
        str += std::to_string(_patch);
        if (_draft)
            str += "d";
        return str;
    }

    // Compare against `rhs`: only major and minor must match for partial
    // compatibility; an exact match of all four fields is complete compatibility.
    compat_t is_compatible(const TosaVersion& rhs) const
    {
        if (rhs._major == _major && rhs._minor == _minor)
        {
            if (rhs._patch == _patch && rhs._draft == _draft)
            {
                return TosaVersion::compat_t::COMPLETELY_COMPATIBLE;
            }
            else
            {
                return TosaVersion::compat_t::PARTIALLY_COMPATIBLE;
            }
        }
        return TosaVersion::compat_t::NOT_COMPATIBLE;
    }
};
105
Eric Kunze2364dcd2021-04-26 11:06:57 -0700106class TosaSerializationHandler;
107
108class TosaSerializationTensor
109{
110public:
111 // constructor and destructor
112 TosaSerializationTensor(const flatbuffers::String* name,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700113 const flatbuffers::Vector<int32_t>* shape,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700114 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700115 const flatbuffers::Vector<uint8_t>* data);
Kevin Cheng545a5082021-11-11 01:36:33 +0000116 TosaSerializationTensor(const std::string& name,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700117 const std::vector<int32_t>& shape,
118 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700119 const std::vector<uint8_t>& data);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700120 TosaSerializationTensor();
121 ~TosaSerializationTensor();
122
123 // accessor
124 std::string GetName() const
125 {
126 return _name;
127 }
128 const std::vector<int32_t>& GetShape() const
129 {
130 return _shape;
131 }
132 DType GetDtype()
133 {
134 return _dtype;
135 }
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700136 const std::vector<uint8_t>& GetData() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700137 {
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700138 return _data;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700139 }
140
141 // modifier
142 void SetDtype(DType dtype)
143 {
144 _dtype = dtype;
145 }
146 void SetName(std::string name)
147 {
148 _name = name;
149 }
Kevin Cheng545a5082021-11-11 01:36:33 +0000150 void SetData(const std::vector<uint8_t>& data)
151 {
152 _data = data;
153 }
154 void SetData(std::vector<uint8_t>&& data)
155 {
156 _data = std::move(data);
157 }
Jerry Geab8d2342023-04-26 22:31:11 +0000158 void SetDimSize(size_t dim, uint32_t new_size)
159 {
160 if (dim < 0 || dim >= _shape.size())
161 {
162 printf("dim is out of bound\n");
163 assert(0);
164 }
165 _shape[dim] = new_size;
166 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700167
168private:
169 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
170 std::vector<int32_t> _shape; /* shape of the tensor */
171 std::string _name; /* name of the tensor, used for solving dependency */
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700172 std::vector<uint8_t> _data; /* data array */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700173};
174
175class TosaSerializationOperator
176{
177public:
178 // use default copy, void constructor
179 // constructor and destructor
180 TosaSerializationOperator(Op op,
181 Attribute attribute_type,
182 const TosaAttributeBase* attribute,
Kevin Cheng545a5082021-11-11 01:36:33 +0000183 const std::vector<std::string>& input_tensor_names,
184 const std::vector<std::string>& output_tensor_names);
185 TosaSerializationOperator(Op op,
186 Attribute attribute_type,
187 const TosaAttributeBase* attribute,
Kevin Cheng545a5082021-11-11 01:36:33 +0000188 std::vector<std::string>&& input_tensor_names,
189 std::vector<std::string>&& output_tensor_names);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700190 ~TosaSerializationOperator();
191
192 // accessor
193 Op GetOp() const
194 {
195 return _op;
196 }
197 Attribute GetAttributeType() const
198 {
199 return _attribute_type;
200 }
201 TosaAttributeBase* GetAttribute() const
202 {
203 return _attribute;
204 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700205 std::vector<std::string>& GetInputTensorNames()
206 {
207 return _input_tensor_names;
208 }
209 std::vector<std::string>& GetOutputTensorNames()
210 {
211 return _output_tensor_names;
212 }
213
214private:
Eric Kunzebdcc3fe2022-06-07 05:17:37 +0000215 void InitializeAttribute(Attribute attribute_type, const TosaAttributeBase* attribute);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700216 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
217 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
218 TosaAttributeBase* _attribute; /* real attribute class goes here */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700219 std::vector<std::string> _input_tensor_names; /* array of input tensor names */
220 std::vector<std::string> _output_tensor_names; /* array of output tensor names */
221};
222
223class TosaSerializationBasicBlock
224{
225public:
226 // constructor and destructor
Kevin Cheng545a5082021-11-11 01:36:33 +0000227 TosaSerializationBasicBlock(const std::string& name,
Jerry Ge13c78a62022-10-04 20:32:39 -0700228 const std::string& region_name,
Kevin Cheng545a5082021-11-11 01:36:33 +0000229 const std::vector<TosaSerializationOperator*>& operators,
230 const std::vector<TosaSerializationTensor*>& tensors,
231 const std::vector<std::string>& inputs,
232 const std::vector<std::string>& outputs);
233 TosaSerializationBasicBlock(std::string&& name,
Jerry Ge13c78a62022-10-04 20:32:39 -0700234 std::string&& region_name,
Kevin Cheng545a5082021-11-11 01:36:33 +0000235 std::vector<TosaSerializationOperator*>&& operators,
236 std::vector<TosaSerializationTensor*>&& tensors,
237 std::vector<std::string>&& inputs,
238 std::vector<std::string>&& outputs);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700239 ~TosaSerializationBasicBlock();
240
241 // accessor
242 std::string GetName() const
243 {
244 return _name;
245 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700246 std::string GetRegionName() const
247 {
248 return _region_name;
249 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700250 std::vector<TosaSerializationOperator*>& GetOperators()
251 {
252 return _operators;
253 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700254
Eric Kunze2364dcd2021-04-26 11:06:57 -0700255 std::vector<TosaSerializationTensor*>& GetTensors()
256 {
257 return _tensors;
258 }
259
260 TosaSerializationTensor* GetTensorByName(std::string name)
261 {
262 TosaSerializationTensor* result = nullptr;
263 for (auto tensor : GetTensors())
264 {
265 if (tensor->GetName() == name)
266 {
267 result = tensor;
268 break;
269 }
270 }
271 return result;
272 }
273
274 std::vector<std::string>& GetInputs()
275 {
276 return _inputs;
277 }
Jerry Ge13c78a62022-10-04 20:32:39 -0700278
Eric Kunze2364dcd2021-04-26 11:06:57 -0700279 std::vector<std::string>& GetOutputs()
280 {
281 return _outputs;
282 }
283
284private:
Jerry Ge13c78a62022-10-04 20:32:39 -0700285 std::string _name; /* name of basic block */
286 std::string _region_name;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700287 std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
288 std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
289 std::vector<std::string> _inputs; /* array of string to specify block inputs */
290 std::vector<std::string> _outputs; /* array of string to specify block outputs */
291};
292
Jerry Ge13c78a62022-10-04 20:32:39 -0700293class TosaSerializationRegion
294{
295public:
296 // constructor and desctructor
297 TosaSerializationRegion(const std::string& name, const std::vector<TosaSerializationBasicBlock*>& blocks);
298 TosaSerializationRegion(const std::string&& name, const std::vector<TosaSerializationBasicBlock*>&& blocks);
299 ~TosaSerializationRegion();
300
301 // accessors
302 std::string GetName() const
303 {
304 return this->_name;
305 }
306
307 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
308 {
309 return this->_blocks;
310 }
311
312 TosaSerializationBasicBlock* GetBlockByName(std::string name)
313 {
314 TosaSerializationBasicBlock* result = nullptr;
315 for (auto block : GetBlocks())
316 {
317 if (block->GetName() == name)
318 {
319 result = block;
320 break;
321 }
322 }
323 return result;
324 }
325
326private:
327 std::string _name; /* name of basic block */
328 std::vector<TosaSerializationBasicBlock*> _blocks; /* TosaSerializationBasicBlock list */
329};
330
Eric Kunze2364dcd2021-04-26 11:06:57 -0700331/*
332 * this is a helper class for writing/reading Tosa ISA
333 * supported format: .tosa (flatbuffer), .json
334 * and provide high-level std::vector-like interface
335 * to access internal data structure
336 */
337class TosaSerializationHandler
338{
339public:
340 // constructor and destructor
341 TosaSerializationHandler();
342 ~TosaSerializationHandler();
343
344 // file io
345 tosa_err_t LoadFileJson(const char* filename);
346 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
Aaron DeBattista8b3903a2021-11-18 16:38:11 +0000347 tosa_err_t LoadFileTosaFlatbuffer(const void* input, int in_size);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700348 tosa_err_t SaveFileJson(const char* filename);
349 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
350 tosa_err_t LoadFileSchema(const char* schema_filename);
351
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700352 // data format conversion. little-endian.
James Ward485a11d2022-08-05 13:48:37 +0100353 static tosa_err_t ConvertF16toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700354 static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
355 static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
356 static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
357 static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
358 static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700359 static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700360 static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
361
James Ward485a11d2022-08-05 13:48:37 +0100362 static tosa_err_t ConvertU8toF16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700363 static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
364 static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
365 static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
366 static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
367 static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700368 static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700369 static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
370
Eric Kunze2364dcd2021-04-26 11:06:57 -0700371 // version
Kevin Chenge6563f52021-10-20 12:12:02 -0700372 const TosaVersion& GetVersion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700373 {
374 return _version;
375 }
376
377 // accessor
Jerry Ge13c78a62022-10-04 20:32:39 -0700378 std::vector<TosaSerializationRegion*>& GetRegions()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700379 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700380 return _regions;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700381 }
382
Jerry Ge13c78a62022-10-04 20:32:39 -0700383 TosaSerializationRegion* GetMainRegion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700384 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700385 return _regions[0];
386 }
387
388 TosaSerializationRegion* GetRegionByName(std::string name)
389 {
390 TosaSerializationRegion* result = nullptr;
391 for (auto region : GetRegions())
Eric Kunze2364dcd2021-04-26 11:06:57 -0700392 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700393 if (region->GetName() == name)
Eric Kunze2364dcd2021-04-26 11:06:57 -0700394 {
Jerry Ge13c78a62022-10-04 20:32:39 -0700395 result = region;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700396 break;
397 }
398 }
399 return result;
400 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700401
402 bool GetSchemaLoaded() const
403 {
404 return _schemaLoaded;
405 }
406
407protected:
408 tosa_err_t Clear();
Kevin Chengb97cb1d2021-10-14 11:53:39 -0700409 tosa_err_t Deserialize(const uint8_t* buf);
410 tosa_err_t Serialize();
Eric Kunze2364dcd2021-04-26 11:06:57 -0700411
412private:
Jerry Ge13c78a62022-10-04 20:32:39 -0700413 TosaVersion _version; /* version struct */
414 flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
415 flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
416 std::vector<TosaSerializationRegion*> _regions; /* array structure to store all TosaSerializationRegion */
417 bool _schemaLoaded; /* is the schema properly loaded? */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700418};
419
420} // namespace tosa
421
422#endif // _TOSA_SERIALIZATION_HANDLER_H