blob: 124b8e0c7d83ab7591e6ad6165a9450ab3f562f6 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16#ifndef _TOSA_SERIALIZATION_HANDLER_H
17#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "quant_info.h"
#include "tosa_generated.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>
#include <utility>
#include <vector>
27
28namespace tosa
29{
30
// Status codes returned by the TosaSerializationHandler file-IO and version entry points.
// Values are spelled out explicitly; they match the implicit 0..7 numbering.
enum tosa_err_t
{
    TOSA_OK               = 0, // success
    TOSA_USER_ERROR       = 1,
    TOSA_FILE_ERROR       = 2,
    TOSA_MEMORY_ERROR     = 3,
    TOSA_SCHEMA_MISSING   = 4,
    TOSA_INTERNAL_ERROR   = 5,
    TOSA_VERSION_MISMATCH = 6,
    NUM_TOSA_ERROR        = 7 // sentinel: equals the number of codes above, not a real status
};
42
// TOSA version number (major.minor.patch) plus an "experimental" flag.
struct TosaVersion
{
    int32_t _major;
    int32_t _minor;
    int32_t _patch;
    bool _experimental;

    // A version must always be fully specified.
    TosaVersion() = delete;
    TosaVersion(int32_t major, int32_t minor, int32_t patch, bool experimental)
        : _major(major)
        , _minor(minor)
        , _patch(patch)
        , _experimental(experimental)
    {}

    // Render as "MAJOR.MINOR.PATCH", with "(experimental)" appended when the flag is set.
    std::string to_string() const
    {
        std::string str;
        str += std::to_string(_major) + ".";
        str += std::to_string(_minor) + ".";
        str += std::to_string(_patch);
        if (_experimental)
            str += "(experimental)";
        return str;
    }

    // Equal only when all four fields match, including the experimental flag.
    // const-qualified so that const versions (e.g. a `const TosaVersion&` argument)
    // can appear on either side of the comparison.
    bool operator==(const TosaVersion& rhs) const
    {
        return _major == rhs._major && _minor == rhs._minor && _patch == rhs._patch &&
               _experimental == rhs._experimental;
    }

    bool operator!=(const TosaVersion& rhs) const
    {
        return !(*this == rhs);
    }
};
84
85class TosaSerializationHandler;
86
87class TosaSerializationTensor
88{
89public:
90 // constructor and destructor
91 TosaSerializationTensor(const flatbuffers::String* name,
92 const flatbuffers::Vector<uint32_t>& usage,
93 const flatbuffers::Vector<int32_t>& shape,
94 DType dtype,
95 const flatbuffers::Vector<uint32_t>& format,
96 const flatbuffers::String* npy_filename);
97 TosaSerializationTensor(std::string name,
98 const std::vector<Usage>& usage,
99 const std::vector<int32_t>& shape,
100 DType dtype,
101 const std::vector<Format>& format,
102 const std::string* npy_filename);
103 TosaSerializationTensor();
104 ~TosaSerializationTensor();
105
106 // copy constructor/assignment
107 TosaSerializationTensor(const TosaSerializationTensor& rhs);
108 TosaSerializationTensor& operator=(const TosaSerializationTensor& rhs);
109
110 // move constructor/assignment
111 TosaSerializationTensor(TosaSerializationTensor&& rhs);
112 TosaSerializationTensor& operator=(TosaSerializationTensor&& rhs);
113
114 // accessor
115 std::string GetName() const
116 {
117 return *_name;
118 }
119 const std::vector<int32_t>& GetShape() const
120 {
121 return *_shape;
122 }
123 DType GetDtype()
124 {
125 return _dtype;
126 }
127 bool HasFormat(Format format)
128 {
129 for (Format us : *_format)
130 {
131 if (us == format)
132 return true;
133 }
134 return false;
135 }
136 std::vector<Format>& GetFormat()
137 {
138 return *_format;
139 }
140 bool HasUsage(Usage usage)
141 {
142 for (Usage us : *_usage)
143 {
144 if (us == usage)
145 return true;
146 }
147 return false;
148 }
149 std::vector<Usage>& GetUsage()
150 {
151 return *_usage;
152 }
153 std::string* GetNpyFilePtr() const
154 {
155 return _npy_filename;
156 }
157
158 // modifier
159 void SetDtype(DType dtype)
160 {
161 _dtype = dtype;
162 }
163 void SetName(std::string name)
164 {
165 *_name = name;
166 }
167
168private:
169 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
170 std::vector<Format>* _format; /* list of possible tensor format */
171 std::vector<Usage>* _usage; /* list of possible tensor usage */
172 std::vector<int32_t>* _shape; /* shape of the tensor */
173 std::string* _name; /* name of the tensor, used for solving dependency */
174 std::string* _npy_filename; /* numpy array filename if not null. so null is the distinguisher */
175};
176
177class TosaSerializationOperator
178{
179public:
180 // use default copy, void constructor
181 // constructor and destructor
182 TosaSerializationOperator(Op op_name,
183 Attribute attribute_type,
184 const TosaAttributeBase* attribute,
185 QuantInfo qinfo_type,
186 const TosaQuantInfoBase* qinfo,
187 std::vector<std::string> input_tensor_names,
188 std::vector<std::string> output_tensor_names);
189 ~TosaSerializationOperator();
190
191 // accessor
192 Op GetOp() const
193 {
194 return _op;
195 }
196 Attribute GetAttributeType() const
197 {
198 return _attribute_type;
199 }
200 TosaAttributeBase* GetAttribute() const
201 {
202 return _attribute;
203 }
204 QuantInfo GetQInfoType() const
205 {
206 return _qinfo_type;
207 }
208 TosaQuantInfoBase* GetQInfo() const
209 {
210 return _qinfo;
211 }
212 std::vector<std::string>& GetInputTensorNames() const
213 {
214 return *_input_tensor_names;
215 }
216 std::vector<std::string>& GetOutputTensorNames() const
217 {
218 return *_output_tensor_names;
219 }
220 std::vector<TosaSerializationTensor*>& GetInputTensors() const
221 {
222 return *_input_tensors;
223 }
224 std::vector<TosaSerializationTensor*>& GetOutputTensors() const
225 {
226 return *_output_tensors;
227 }
228
229private:
230 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
231 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
232 TosaAttributeBase* _attribute; /* real attribute class goes here */
233 QuantInfo _qinfo_type; /* QuantInfo enum */
234 TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */
235 std::vector<std::string>* _input_tensor_names; /* array of input tensor names */
236 std::vector<std::string>* _output_tensor_names; /* array of output tensor names */
237
238 std::vector<TosaSerializationTensor*>* _input_tensors; /* array of input TosaSerializationTensor */
239 std::vector<TosaSerializationTensor*>* _output_tensors; /* array of output TosaSerializationTensor */
240};
241
242class TosaSerializationBasicBlock
243{
244public:
245 // constructor and destructor
246 TosaSerializationBasicBlock(std::string name,
247 std::vector<TosaSerializationOperator*> operators,
248 std::vector<TosaSerializationTensor*> tensors,
249 std::vector<std::string> inputs,
250 std::vector<std::string> outputs);
251 ~TosaSerializationBasicBlock();
252
253 // accessor
254 std::string GetName() const
255 {
256 return *_name;
257 }
258 std::vector<TosaSerializationOperator*>& GetOperators()
259 {
260 return *_operators;
261 }
262 std::vector<TosaSerializationTensor*>& GetTensors()
263 {
264 return *_tensors;
265 }
266
267 TosaSerializationTensor* GetTensorByName(std::string name)
268 {
269 TosaSerializationTensor* result = nullptr;
270 for (auto tensor : GetTensors())
271 {
272 if (tensor->GetName() == name)
273 {
274 result = tensor;
275 break;
276 }
277 }
278 return result;
279 }
280
281 std::vector<std::string>& GetInputs()
282 {
283 return *_inputs;
284 }
285 std::vector<std::string>& GetOutputs()
286 {
287 return *_outputs;
288 }
289
290private:
291 std::string* _name; /* name of basic block */
292 std::vector<TosaSerializationOperator*>* _operators; /* TosaSerializationOperator list */
293 std::vector<TosaSerializationTensor*>* _tensors; /* TosaSerializationTensor list */
294 std::vector<std::string>* _inputs; /* array of string to specify block inputs */
295 std::vector<std::string>* _outputs; /* array of string to specify block outputs */
296};
297
298/*
299 * this is a helper class for writing/reading Tosa ISA
300 * supported format: .tosa (flatbuffer), .json
301 * and provide high-level std::vector-like interface
302 * to access internal data structure
303 */
304class TosaSerializationHandler
305{
306public:
307 // constructor and destructor
308 TosaSerializationHandler();
309 ~TosaSerializationHandler();
310
311 // file io
312 tosa_err_t LoadFileJson(const char* filename);
313 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
314 tosa_err_t SaveFileJson(const char* filename);
315 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
316 tosa_err_t LoadFileSchema(const char* filename);
317
318 // version
319 TosaVersion* GetTosaVersion() const
320 {
321 return _version;
322 }
323
324 // accessor
325 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
326 {
327 return *_blocks;
328 }
329
330 TosaSerializationBasicBlock* GetBlockByName(std::string name)
331 {
332 TosaSerializationBasicBlock* result = nullptr;
333 for (auto block : GetBlocks())
334 {
335 if (block->GetName() == name)
336 {
337 result = block;
338 break;
339 }
340 }
341 return result;
342 }
343 TosaSerializationBasicBlock* GetMainBlock()
344 {
345 TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
346 assert(main_block);
347 return main_block;
348 }
349
350 std::vector<std::string>& GetInputs()
351 {
352 return GetMainBlock()->GetInputs();
353 }
354 std::vector<std::string>& GetOutputs()
355 {
356 return GetMainBlock()->GetOutputs();
357 }
358
359 bool GetSchemaLoaded() const
360 {
361 return _schemaLoaded;
362 }
363
364protected:
365 tosa_err_t Clear();
366 tosa_err_t InitWithBuf(const uint8_t* buf);
367 tosa_err_t FreezeBuilder();
368 tosa_err_t SetTosaVersion();
369 tosa_err_t CheckTosaVersion(const TosaVersion& read_version);
370
371private:
372 TosaVersion* _version; /* tosa version */
373 flatbuffers::FlatBufferBuilder* _builder; /* flatbuffer builder */
374 flatbuffers::Parser* _parser; /* flatbuffer parser, used for json parsing */
375 std::vector<TosaSerializationBasicBlock*>* _blocks; /* array structure to store all TosaSerializationBasicBlock */
376 bool _schemaLoaded; /* is the schema properly loaded? */
377};
378
// Static helpers for reading and writing numpy .npy files holding bool, int32, int64 or
// float data. Every helper returns an NPError status code.
class NumpyUtilities
{
public:
    // Status codes returned by the helpers below; NO_ERROR (0) means success.
    enum NPError
    {
        NO_ERROR = 0,
        FILE_NOT_FOUND,
        FILE_IO_ERROR,
        FILE_TYPE_MISMATCH,
        HEADER_PARSE_ERROR,
        BUFFER_SIZE_MISMATCH,
    };

    // Read `elems` elements from the .npy file `filename` into `buf`.
    // NOTE(review): buf presumably must have room for at least `elems` elements — confirm in the .cpp.
    static NPError readFromNpyFile(const char* filename, const uint32_t elems, float* buf);

    static NPError readFromNpyFile(const char* filename, const uint32_t elems, int32_t* buf);

    static NPError readFromNpyFile(const char* filename, const uint32_t elems, int64_t* buf);

    static NPError readFromNpyFile(const char* filename, const uint32_t elems, bool* buf);

    // Write `buf` to the .npy file `filename`. The `shape` overloads take the array shape;
    // the `elems` overloads take only an element count (presumably written as a 1-D array — confirm).
    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const bool* buf);

    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const bool* buf);

    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int32_t* buf);

    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int32_t* buf);

    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const int64_t* buf);

    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const int64_t* buf);

    static NPError writeToNpyFile(const char* filename, const std::vector<int32_t>& shape, const float* buf);

    static NPError writeToNpyFile(const char* filename, const uint32_t elems, const float* buf);

private:
    // NOTE(review): presumably validates the header of an already-opened file against the
    // expected element count and numpy dtype string.
    static NPError checkNpyHeader(FILE* infile, const uint32_t elems, const char* dtype_str);
    // NOTE(review): presumably writes a .npy header for `shape`/`dtype_str` to the output stream.
    static NPError writeNpyHeader(FILE* outfile, const std::vector<int32_t>& shape, const char* dtype_str);
};
420
421} // namespace tosa
422
423#endif // _TOSA_SERIALIZATION_HANDLER_H