blob: db9481b996c8e07ab80aae70dbde56178bcd95e3 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef _TOSA_SERIALIZATION_HANDLER_H
17#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
#include "quant_info.h"
#include "tosa_generated.h"
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
28
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070029#define TENSOR_BUFFER_FORCE_ALIGNMENT 8
30
Eric Kunze2364dcd2021-04-26 11:06:57 -070031namespace tosa
32{
33
// Result codes returned by the TosaSerializationHandler file-io and
// data-conversion methods. Values are listed explicitly; their order is
// unchanged from the original declaration.
enum tosa_err_t
{
    TOSA_OK               = 0,
    TOSA_USER_ERROR       = 1,
    TOSA_FILE_ERROR       = 2,
    TOSA_MEMORY_ERROR     = 3,
    TOSA_SCHEMA_MISSING   = 4,
    TOSA_INTERNAL_ERROR   = 5,
    TOSA_VERSION_MISMATCH = 6,
    NUM_TOSA_ERROR        = 7    // sentinel: count of the codes above, not a real result
};
45
// Version triple (major.minor.patch) plus an "experimental" flag.
// A default-constructed TosaVersion is invalid (_valid == false) until
// set_version() is called; all numeric fields start at zero so reading them
// is never undefined behaviour.
struct TosaVersion
{
    int32_t _major     = 0;
    int32_t _minor     = 0;
    int32_t _patch     = 0;
    bool _experimental = false;
    bool _valid        = false;    // set to true by set_version()

    TosaVersion() = default;

    TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental)
    {
        set_version(major, minor, patch, experimental);
    }

    // Stores all four components and marks the version as valid.
    void set_version(int32_t major, int32_t minor, int32_t patch, bool experimental)
    {
        _major        = major;
        _minor        = minor;
        _patch        = patch;
        _experimental = experimental;
        _valid        = true;
    }

    // Formats as "major.minor.patch", with "(experimental)" appended when the
    // experimental flag is set. Asserts that the version is valid.
    std::string to_string() const
    {
        assert(_valid);
        std::string str;
        str += std::to_string(_major) + ".";
        str += std::to_string(_minor) + ".";
        str += std::to_string(_patch);
        if (_experimental)
            str += "(experimental)";
        return str;
    }

    // Equal when both versions are valid and all four components match.
    // Asserts that *this is valid; an invalid rhs never compares equal.
    bool operator==(const TosaVersion& rhs) const
    {
        assert(_valid);
        if (!_valid || !rhs._valid)
            return false;
        return rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch &&
               rhs._experimental == _experimental;
    }

    // Negation of operator==; an invalid *this is always "not equal".
    bool operator!=(const TosaVersion& rhs) const
    {
        assert(_valid);
        if (!_valid)
            return true;
        return !((*this) == rhs);
    }
};
105
106class TosaSerializationHandler;
107
108class TosaSerializationTensor
109{
110public:
111 // constructor and destructor
112 TosaSerializationTensor(const flatbuffers::String* name,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700113 const flatbuffers::Vector<int32_t>* shape,
Eric Kunze2364dcd2021-04-26 11:06:57 -0700114 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700115 const flatbuffers::Vector<uint8_t>* data);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700116 TosaSerializationTensor(std::string& name,
117 const std::vector<int32_t>& shape,
118 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700119 const std::vector<uint8_t>& data);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700120 TosaSerializationTensor();
121 ~TosaSerializationTensor();
122
123 // accessor
124 std::string GetName() const
125 {
126 return _name;
127 }
128 const std::vector<int32_t>& GetShape() const
129 {
130 return _shape;
131 }
132 DType GetDtype()
133 {
134 return _dtype;
135 }
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700136 const std::vector<uint8_t>& GetData() const
Eric Kunze2364dcd2021-04-26 11:06:57 -0700137 {
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700138 return _data;
Eric Kunze2364dcd2021-04-26 11:06:57 -0700139 }
140
141 // modifier
142 void SetDtype(DType dtype)
143 {
144 _dtype = dtype;
145 }
146 void SetName(std::string name)
147 {
148 _name = name;
149 }
150
151private:
152 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
153 std::vector<int32_t> _shape; /* shape of the tensor */
154 std::string _name; /* name of the tensor, used for solving dependency */
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700155 std::vector<uint8_t> _data; /* data array */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700156};
157
158class TosaSerializationOperator
159{
160public:
161 // use default copy, void constructor
162 // constructor and destructor
163 TosaSerializationOperator(Op op,
164 Attribute attribute_type,
165 const TosaAttributeBase* attribute,
166 QuantInfo qinfo_type,
167 const TosaQuantInfoBase* qinfo,
168 std::vector<std::string> input_tensor_names,
169 std::vector<std::string> output_tensor_names);
170 ~TosaSerializationOperator();
171
172 // accessor
173 Op GetOp() const
174 {
175 return _op;
176 }
177 Attribute GetAttributeType() const
178 {
179 return _attribute_type;
180 }
181 TosaAttributeBase* GetAttribute() const
182 {
183 return _attribute;
184 }
185 QuantInfo GetQInfoType() const
186 {
187 return _qinfo_type;
188 }
189 TosaQuantInfoBase* GetQInfo() const
190 {
191 return _qinfo;
192 }
193 std::vector<std::string>& GetInputTensorNames()
194 {
195 return _input_tensor_names;
196 }
197 std::vector<std::string>& GetOutputTensorNames()
198 {
199 return _output_tensor_names;
200 }
201
202private:
203 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
204 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
205 TosaAttributeBase* _attribute; /* real attribute class goes here */
206 QuantInfo _qinfo_type; /* QuantInfo enum */
207 TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */
208 std::vector<std::string> _input_tensor_names; /* array of input tensor names */
209 std::vector<std::string> _output_tensor_names; /* array of output tensor names */
210};
211
212class TosaSerializationBasicBlock
213{
214public:
215 // constructor and destructor
216 TosaSerializationBasicBlock(std::string name,
217 std::vector<TosaSerializationOperator*> operators,
218 std::vector<TosaSerializationTensor*> tensors,
219 std::vector<std::string> inputs,
220 std::vector<std::string> outputs);
221 ~TosaSerializationBasicBlock();
222
223 // accessor
224 std::string GetName() const
225 {
226 return _name;
227 }
228 std::vector<TosaSerializationOperator*>& GetOperators()
229 {
230 return _operators;
231 }
232 std::vector<TosaSerializationTensor*>& GetTensors()
233 {
234 return _tensors;
235 }
236
237 TosaSerializationTensor* GetTensorByName(std::string name)
238 {
239 TosaSerializationTensor* result = nullptr;
240 for (auto tensor : GetTensors())
241 {
242 if (tensor->GetName() == name)
243 {
244 result = tensor;
245 break;
246 }
247 }
248 return result;
249 }
250
251 std::vector<std::string>& GetInputs()
252 {
253 return _inputs;
254 }
255 std::vector<std::string>& GetOutputs()
256 {
257 return _outputs;
258 }
259
260private:
261 std::string _name; /* name of basic block */
262 std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
263 std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
264 std::vector<std::string> _inputs; /* array of string to specify block inputs */
265 std::vector<std::string> _outputs; /* array of string to specify block outputs */
266};
267
268/*
269 * this is a helper class for writing/reading Tosa ISA
270 * supported format: .tosa (flatbuffer), .json
271 * and provide high-level std::vector-like interface
272 * to access internal data structure
273 */
274class TosaSerializationHandler
275{
276public:
277 // constructor and destructor
278 TosaSerializationHandler();
279 ~TosaSerializationHandler();
280
281 // file io
282 tosa_err_t LoadFileJson(const char* filename);
283 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
284 tosa_err_t SaveFileJson(const char* filename);
285 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
286 tosa_err_t LoadFileSchema(const char* schema_filename);
287
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700288 // data format conversion. little-endian.
289 static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
290 static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
291 static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
292 static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
293 static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
294 static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
295
296 static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
297 static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
298 static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
299 static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
300 static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
301 static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
302
Eric Kunze2364dcd2021-04-26 11:06:57 -0700303 // version
304 const TosaVersion& GetTosaVersion() const
305 {
306 return _version;
307 }
308
309 // accessor
310 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
311 {
312 return _blocks;
313 }
314
315 TosaSerializationBasicBlock* GetBlockByName(std::string name)
316 {
317 TosaSerializationBasicBlock* result = nullptr;
318 for (auto block : GetBlocks())
319 {
320 if (block->GetName() == name)
321 {
322 result = block;
323 break;
324 }
325 }
326 return result;
327 }
328 TosaSerializationBasicBlock* GetMainBlock()
329 {
330 TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
331 assert(main_block);
332 return main_block;
333 }
334
335 std::vector<std::string>& GetInputs()
336 {
337 return GetMainBlock()->GetInputs();
338 }
339 std::vector<std::string>& GetOutputs()
340 {
341 return GetMainBlock()->GetOutputs();
342 }
343
344 bool GetSchemaLoaded() const
345 {
346 return _schemaLoaded;
347 }
348
349protected:
350 tosa_err_t Clear();
351 tosa_err_t InitWithBuf(const uint8_t* buf);
352 tosa_err_t FreezeBuilder();
353 tosa_err_t SetTosaVersion();
354 tosa_err_t CheckTosaVersion(const TosaVersion& read_version);
355
356private:
357 TosaVersion _version; /* tosa version */
358 flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
359 flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
360 std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */
361 bool _schemaLoaded; /* is the schema properly loaded? */
362};
363
364} // namespace tosa
365
366#endif // _TOSA_SERIALIZATION_HANDLER_H