blob: 398590d016b473645301b315a0ef42dc1a59b294 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef _TOSA_SERIALIZATION_HANDLER_H
17#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
#include "quant_info.h"
#include "tosa_generated.h"
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
28
29namespace tosa
30{
31
// Status codes returned by the TosaSerializationHandler load/save/version
// functions. Values are spelled out explicitly; they match the implicit
// 0-based numbering of the enumerators.
enum tosa_err_t
{
    TOSA_OK               = 0,
    TOSA_USER_ERROR       = 1,
    TOSA_FILE_ERROR       = 2,
    TOSA_MEMORY_ERROR     = 3,
    TOSA_SCHEMA_MISSING   = 4,
    TOSA_INTERNAL_ERROR   = 5,
    TOSA_VERSION_MISMATCH = 6,
    // Sentinel: equals the number of status codes listed above.
    NUM_TOSA_ERROR = 7
};
43
// Version number (major.minor.patch, plus an "experimental" flag) used when
// reading/writing serialized TOSA files.
//
// A default-constructed TosaVersion is invalid (_valid == false) until
// set_version() is called; all numeric fields are zero-initialized so an
// invalid version never holds indeterminate values.
struct TosaVersion
{
    int32_t _major     = 0;
    int32_t _minor     = 0;
    int32_t _patch     = 0;
    bool _experimental = false;
    bool _valid        = false;

    // Construct an invalid (unset) version.
    TosaVersion() = default;

    // Construct a valid version from its components.
    TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental)
    {
        set_version(major, minor, patch, experimental);
    }

    // Set all components and mark the version valid.
    void set_version(int32_t major, int32_t minor, int32_t patch, bool experimental)
    {
        _major        = major;
        _minor        = minor;
        _patch        = patch;
        _experimental = experimental;
        _valid        = true;
    }

    // Format as "major.minor.patch", with "(experimental)" appended when the
    // experimental flag is set. Asserts the version is valid.
    std::string to_string() const
    {
        assert(_valid);
        std::string str;
        str += std::to_string(_major) + ".";
        str += std::to_string(_minor) + ".";
        str += std::to_string(_patch);
        if (_experimental)
            str += "(experimental)";
        return str;
    }

    // Two versions are equal only if both are valid and all four components
    // match. const so it can be used on const references (e.g. the value
    // returned by TosaSerializationHandler::GetTosaVersion()).
    bool operator==(const TosaVersion& rhs) const
    {
        assert(_valid);
        // An invalid version on either side never compares equal.
        if (!_valid || !rhs._valid)
            return false;
        return rhs._major == _major && rhs._minor == _minor && rhs._patch == _patch &&
               rhs._experimental == _experimental;
    }

    // Logical negation of operator==; an invalid left-hand side is always
    // "not equal".
    bool operator!=(const TosaVersion& rhs) const
    {
        assert(_valid);
        if (!_valid)
            return true;
        return !((*this) == rhs);
    }
};
103
104class TosaSerializationHandler;
105
106class TosaSerializationTensor
107{
108public:
109 // constructor and destructor
110 TosaSerializationTensor(const flatbuffers::String* name,
111 const flatbuffers::Vector<int32_t>& shape,
112 DType dtype,
113 const flatbuffers::String* npy_filename);
114 TosaSerializationTensor(std::string& name,
115 const std::vector<int32_t>& shape,
116 DType dtype,
117 const std::string& npy_filename);
118 TosaSerializationTensor();
119 ~TosaSerializationTensor();
120
121 // accessor
122 std::string GetName() const
123 {
124 return _name;
125 }
126 const std::vector<int32_t>& GetShape() const
127 {
128 return _shape;
129 }
130 DType GetDtype()
131 {
132 return _dtype;
133 }
134 const std::string& GetNpyFilePtr() const
135 {
136 return _npy_filename;
137 }
138
139 // modifier
140 void SetDtype(DType dtype)
141 {
142 _dtype = dtype;
143 }
144 void SetName(std::string name)
145 {
146 _name = name;
147 }
148
149private:
150 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
151 std::vector<int32_t> _shape; /* shape of the tensor */
152 std::string _name; /* name of the tensor, used for solving dependency */
153 std::string _npy_filename; /* numpy array filename if not null. so null is the distinguisher */
154};
155
156class TosaSerializationOperator
157{
158public:
159 // use default copy, void constructor
160 // constructor and destructor
161 TosaSerializationOperator(Op op,
162 Attribute attribute_type,
163 const TosaAttributeBase* attribute,
164 QuantInfo qinfo_type,
165 const TosaQuantInfoBase* qinfo,
166 std::vector<std::string> input_tensor_names,
167 std::vector<std::string> output_tensor_names);
168 ~TosaSerializationOperator();
169
170 // accessor
171 Op GetOp() const
172 {
173 return _op;
174 }
175 Attribute GetAttributeType() const
176 {
177 return _attribute_type;
178 }
179 TosaAttributeBase* GetAttribute() const
180 {
181 return _attribute;
182 }
183 QuantInfo GetQInfoType() const
184 {
185 return _qinfo_type;
186 }
187 TosaQuantInfoBase* GetQInfo() const
188 {
189 return _qinfo;
190 }
191 std::vector<std::string>& GetInputTensorNames()
192 {
193 return _input_tensor_names;
194 }
195 std::vector<std::string>& GetOutputTensorNames()
196 {
197 return _output_tensor_names;
198 }
199
200private:
201 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
202 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
203 TosaAttributeBase* _attribute; /* real attribute class goes here */
204 QuantInfo _qinfo_type; /* QuantInfo enum */
205 TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */
206 std::vector<std::string> _input_tensor_names; /* array of input tensor names */
207 std::vector<std::string> _output_tensor_names; /* array of output tensor names */
208};
209
210class TosaSerializationBasicBlock
211{
212public:
213 // constructor and destructor
214 TosaSerializationBasicBlock(std::string name,
215 std::vector<TosaSerializationOperator*> operators,
216 std::vector<TosaSerializationTensor*> tensors,
217 std::vector<std::string> inputs,
218 std::vector<std::string> outputs);
219 ~TosaSerializationBasicBlock();
220
221 // accessor
222 std::string GetName() const
223 {
224 return _name;
225 }
226 std::vector<TosaSerializationOperator*>& GetOperators()
227 {
228 return _operators;
229 }
230 std::vector<TosaSerializationTensor*>& GetTensors()
231 {
232 return _tensors;
233 }
234
235 TosaSerializationTensor* GetTensorByName(std::string name)
236 {
237 TosaSerializationTensor* result = nullptr;
238 for (auto tensor : GetTensors())
239 {
240 if (tensor->GetName() == name)
241 {
242 result = tensor;
243 break;
244 }
245 }
246 return result;
247 }
248
249 std::vector<std::string>& GetInputs()
250 {
251 return _inputs;
252 }
253 std::vector<std::string>& GetOutputs()
254 {
255 return _outputs;
256 }
257
258private:
259 std::string _name; /* name of basic block */
260 std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
261 std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
262 std::vector<std::string> _inputs; /* array of string to specify block inputs */
263 std::vector<std::string> _outputs; /* array of string to specify block outputs */
264};
265
266/*
267 * this is a helper class for writing/reading Tosa ISA
268 * supported format: .tosa (flatbuffer), .json
269 * and provide high-level std::vector-like interface
270 * to access internal data structure
271 */
272class TosaSerializationHandler
273{
274public:
275 // constructor and destructor
276 TosaSerializationHandler();
277 ~TosaSerializationHandler();
278
279 // file io
280 tosa_err_t LoadFileJson(const char* filename);
281 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
282 tosa_err_t SaveFileJson(const char* filename);
283 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
284 tosa_err_t LoadFileSchema(const char* schema_filename);
285
286 // version
287 const TosaVersion& GetTosaVersion() const
288 {
289 return _version;
290 }
291
292 // accessor
293 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
294 {
295 return _blocks;
296 }
297
298 TosaSerializationBasicBlock* GetBlockByName(std::string name)
299 {
300 TosaSerializationBasicBlock* result = nullptr;
301 for (auto block : GetBlocks())
302 {
303 if (block->GetName() == name)
304 {
305 result = block;
306 break;
307 }
308 }
309 return result;
310 }
311 TosaSerializationBasicBlock* GetMainBlock()
312 {
313 TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
314 assert(main_block);
315 return main_block;
316 }
317
318 std::vector<std::string>& GetInputs()
319 {
320 return GetMainBlock()->GetInputs();
321 }
322 std::vector<std::string>& GetOutputs()
323 {
324 return GetMainBlock()->GetOutputs();
325 }
326
327 bool GetSchemaLoaded() const
328 {
329 return _schemaLoaded;
330 }
331
332protected:
333 tosa_err_t Clear();
334 tosa_err_t InitWithBuf(const uint8_t* buf);
335 tosa_err_t FreezeBuilder();
336 tosa_err_t SetTosaVersion();
337 tosa_err_t CheckTosaVersion(const TosaVersion& read_version);
338
339private:
340 TosaVersion _version; /* tosa version */
341 flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
342 flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
343 std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */
344 bool _schemaLoaded; /* is the schema properly loaded? */
345};
346
347} // namespace tosa
348
349#endif // _TOSA_SERIALIZATION_HANDLER_H