blob: 4d894edbd4ebed78e4bc126ada4c057e52de3b21 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
// Copyright (c) 2020-2021, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef _TOSA_SERIALIZATION_HANDLER_H
17#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
#include "quant_info.h"
#include "tosa_generated.h"
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
28
// Keep these version numbers in sync with the default version values declared in schema/tosa.fbs.
#define TOSA_VERSION_MAJOR 0
#define TOSA_VERSION_MINOR 24
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT true
// NOTE(review): presumably the byte alignment forced on serialized tensor buffers - confirm against the .cpp.
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
35
Eric Kunze2364dcd2021-04-26 11:06:57 -070036namespace tosa
37{
38
// Status codes used as the return type of the serialization API.
// Enumerators are numbered consecutively starting at 0 (TOSA_OK);
// NUM_TOSA_ERROR is last, so its value equals the number of codes before it.
enum tosa_err_t
{
    TOSA_OK,
    TOSA_USER_ERROR,
    TOSA_FILE_ERROR,
    TOSA_MEMORY_ERROR,
    TOSA_SCHEMA_MISSING,
    TOSA_INTERNAL_ERROR,
    TOSA_VERSION_MISMATCH,
    NUM_TOSA_ERROR    // sentinel / count, not a real status
};
50
// Version descriptor made of major/minor/patch numbers plus a "draft" flag.
struct TosaVersion
{
    // Default member initializers: a default-constructed TosaVersion is
    // 0.0.0 (non-draft) rather than holding indeterminate values.
    int32_t _major = 0;
    int32_t _minor = 0;
    int32_t _patch = 0;
    bool _draft    = false;

    // Result of comparing two versions with is_compatible().
    enum class compat_t
    {
        COMPLETELY_COMPATIBLE,    // major, minor, patch and draft all match
        PARTIALLY_COMPATIBLE,     // major and minor match; patch and/or draft differ
        NOT_COMPATIBLE            // major or minor differ
    };

    TosaVersion() = default;
    TosaVersion(int32_t major, int32_t minor, int32_t patch, bool draft)
        : _major(major)
        , _minor(minor)
        , _patch(patch)
        , _draft(draft)
    {}

    // Overwrite all four fields at once.
    void set_version(int32_t major, int32_t minor, int32_t patch, bool draft)
    {
        _major = major;
        _minor = minor;
        _patch = patch;
        _draft = draft;
    }

    // Format as "<major>.<minor>.<patch>", with a trailing "d" when the draft flag is set.
    std::string to_string() const
    {
        std::string str;
        str += std::to_string(_major) + ".";
        str += std::to_string(_minor) + ".";
        str += std::to_string(_patch);
        if (_draft)
        {
            str += "d";
        }
        return str;
    }

    // Compare this version with rhs; see compat_t for the meaning of each result.
    compat_t is_compatible(const TosaVersion& rhs) const
    {
        // Major and minor must both match for any level of compatibility.
        if (rhs._major != _major || rhs._minor != _minor)
        {
            return compat_t::NOT_COMPATIBLE;
        }
        if (rhs._patch == _patch && rhs._draft == _draft)
        {
            return compat_t::COMPLETELY_COMPATIBLE;
        }
        return compat_t::PARTIALLY_COMPATIBLE;
    }
};
106
// Forward declaration; the class is defined further down in this header.
class TosaSerializationHandler;
108
109class TosaSerializationTensor
110{
111public:
112 // constructor and destructor
113 TosaSerializationTensor(const flatbuffers::String* name,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700114 const flatbuffers::Vector<int32_t>* shape,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700115 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700116 const flatbuffers::Vector<uint8_t>* data);
Kevin Cheng545a5082021-11-11 01:36:33 +0000117 TosaSerializationTensor(const std::string& name,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700118 const std::vector<int32_t>& shape,
119 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700120 const std::vector<uint8_t>& data);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700121 TosaSerializationTensor();
122 ~TosaSerializationTensor();
123
124 // accessor
125 std::string GetName() const
126 {
127 return _name;
128 }
129 const std::vector<int32_t>& GetShape() const
130 {
131 return _shape;
132 }
133 DType GetDtype()
134 {
135 return _dtype;
136 }
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700137 const std::vector<uint8_t>& GetData() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700138 {
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700139 return _data;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700140 }
141
142 // modifier
143 void SetDtype(DType dtype)
144 {
145 _dtype = dtype;
146 }
147 void SetName(std::string name)
148 {
149 _name = name;
150 }
Kevin Cheng545a5082021-11-11 01:36:33 +0000151 void SetData(const std::vector<uint8_t>& data)
152 {
153 _data = data;
154 }
155 void SetData(std::vector<uint8_t>&& data)
156 {
157 _data = std::move(data);
158 }
Eric Kunze2364dcd2021-04-26 11:06:57 -0700159
160private:
161 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
162 std::vector<int32_t> _shape; /* shape of the tensor */
163 std::string _name; /* name of the tensor, used for solving dependency */
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700164 std::vector<uint8_t> _data; /* data array */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700165};
166
167class TosaSerializationOperator
168{
169public:
170 // use default copy, void constructor
171 // constructor and destructor
172 TosaSerializationOperator(Op op,
173 Attribute attribute_type,
174 const TosaAttributeBase* attribute,
175 QuantInfo qinfo_type,
176 const TosaQuantInfoBase* qinfo,
Kevin Cheng545a5082021-11-11 01:36:33 +0000177 const std::vector<std::string>& input_tensor_names,
178 const std::vector<std::string>& output_tensor_names);
179 TosaSerializationOperator(Op op,
180 Attribute attribute_type,
181 const TosaAttributeBase* attribute,
182 QuantInfo qinfo_type,
183 const TosaQuantInfoBase* qinfo,
184 std::vector<std::string>&& input_tensor_names,
185 std::vector<std::string>&& output_tensor_names);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700186 ~TosaSerializationOperator();
187
188 // accessor
189 Op GetOp() const
190 {
191 return _op;
192 }
193 Attribute GetAttributeType() const
194 {
195 return _attribute_type;
196 }
197 TosaAttributeBase* GetAttribute() const
198 {
199 return _attribute;
200 }
201 QuantInfo GetQInfoType() const
202 {
203 return _qinfo_type;
204 }
205 TosaQuantInfoBase* GetQInfo() const
206 {
207 return _qinfo;
208 }
209 std::vector<std::string>& GetInputTensorNames()
210 {
211 return _input_tensor_names;
212 }
213 std::vector<std::string>& GetOutputTensorNames()
214 {
215 return _output_tensor_names;
216 }
217
218private:
Kevin Cheng545a5082021-11-11 01:36:33 +0000219 void InitializeAttributeQinfo(Attribute attribute_type,
220 const TosaAttributeBase* attribute,
221 QuantInfo qinfo_type,
222 const TosaQuantInfoBase* qinfo);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700223 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
224 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
225 TosaAttributeBase* _attribute; /* real attribute class goes here */
226 QuantInfo _qinfo_type; /* QuantInfo enum */
227 TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */
228 std::vector<std::string> _input_tensor_names; /* array of input tensor names */
229 std::vector<std::string> _output_tensor_names; /* array of output tensor names */
230};
231
232class TosaSerializationBasicBlock
233{
234public:
235 // constructor and destructor
Kevin Cheng545a5082021-11-11 01:36:33 +0000236 TosaSerializationBasicBlock(const std::string& name,
237 const std::vector<TosaSerializationOperator*>& operators,
238 const std::vector<TosaSerializationTensor*>& tensors,
239 const std::vector<std::string>& inputs,
240 const std::vector<std::string>& outputs);
241 TosaSerializationBasicBlock(std::string&& name,
242 std::vector<TosaSerializationOperator*>&& operators,
243 std::vector<TosaSerializationTensor*>&& tensors,
244 std::vector<std::string>&& inputs,
245 std::vector<std::string>&& outputs);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700246 ~TosaSerializationBasicBlock();
247
248 // accessor
249 std::string GetName() const
250 {
251 return _name;
252 }
253 std::vector<TosaSerializationOperator*>& GetOperators()
254 {
255 return _operators;
256 }
257 std::vector<TosaSerializationTensor*>& GetTensors()
258 {
259 return _tensors;
260 }
261
262 TosaSerializationTensor* GetTensorByName(std::string name)
263 {
264 TosaSerializationTensor* result = nullptr;
265 for (auto tensor : GetTensors())
266 {
267 if (tensor->GetName() == name)
268 {
269 result = tensor;
270 break;
271 }
272 }
273 return result;
274 }
275
276 std::vector<std::string>& GetInputs()
277 {
278 return _inputs;
279 }
280 std::vector<std::string>& GetOutputs()
281 {
282 return _outputs;
283 }
284
285private:
286 std::string _name; /* name of basic block */
287 std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
288 std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
289 std::vector<std::string> _inputs; /* array of string to specify block inputs */
290 std::vector<std::string> _outputs; /* array of string to specify block outputs */
291};
292
293/*
294 * this is a helper class for writing/reading Tosa ISA
295 * supported format: .tosa (flatbuffer), .json
296 * and provide high-level std::vector-like interface
297 * to access internal data structure
298 */
299class TosaSerializationHandler
300{
301public:
302 // constructor and destructor
303 TosaSerializationHandler();
304 ~TosaSerializationHandler();
305
306 // file io
307 tosa_err_t LoadFileJson(const char* filename);
308 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
Aaron DeBattista8b3903a2021-11-18 16:38:11 +0000309 tosa_err_t LoadFileTosaFlatbuffer(const void* input, int in_size);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700310 tosa_err_t SaveFileJson(const char* filename);
311 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
312 tosa_err_t LoadFileSchema(const char* schema_filename);
313
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700314 // data format conversion. little-endian.
315 static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
316 static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
317 static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
318 static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
319 static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700320 static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700321 static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
322
323 static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
324 static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
325 static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
326 static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
327 static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700328 static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700329 static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
330
Eric Kunze2364dcd2021-04-26 11:06:57 -0700331 // version
Kevin Chenge6563f52021-10-20 12:12:02 -0700332 const TosaVersion& GetVersion()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700333 {
334 return _version;
335 }
336
337 // accessor
338 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
339 {
340 return _blocks;
341 }
342
343 TosaSerializationBasicBlock* GetBlockByName(std::string name)
344 {
345 TosaSerializationBasicBlock* result = nullptr;
346 for (auto block : GetBlocks())
347 {
348 if (block->GetName() == name)
349 {
350 result = block;
351 break;
352 }
353 }
354 return result;
355 }
356 TosaSerializationBasicBlock* GetMainBlock()
357 {
358 TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
359 assert(main_block);
360 return main_block;
361 }
362
363 std::vector<std::string>& GetInputs()
364 {
365 return GetMainBlock()->GetInputs();
366 }
367 std::vector<std::string>& GetOutputs()
368 {
369 return GetMainBlock()->GetOutputs();
370 }
371
372 bool GetSchemaLoaded() const
373 {
374 return _schemaLoaded;
375 }
376
377protected:
378 tosa_err_t Clear();
Kevin Chengb97cb1d2021-10-14 11:53:39 -0700379 tosa_err_t Deserialize(const uint8_t* buf);
380 tosa_err_t Serialize();
Kevin Chenga81a7a12021-11-10 14:07:34 -0800381 TosaVersion ParseTosaSchemaVersion(std::string schema);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700382
383private:
Kevin Chenge6563f52021-10-20 12:12:02 -0700384 TosaVersion _version; /* version struct */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700385 flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
386 flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
387 std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */
388 bool _schemaLoaded; /* is the schema properly loaded? */
389};
390
391} // namespace tosa
392
393#endif // _TOSA_SERIALIZATION_HANDLER_H