blob: 7fd82828f090bda457d4e53d71d02a6aa65b17b8 [file] [log] [blame]
Eric Kunze2364dcd2021-04-26 11:06:57 -07001
2// Copyright (c) 2020-2021, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef _TOSA_SERIALIZATION_HANDLER_H
17#define _TOSA_SERIALIZATION_HANDLER_H
#include "attribute.h"
#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"
#include "numpy_utils.h"
#include "quant_info.h"
#include "tosa_generated.h"
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
28
// Alignment value applied to tensor data buffers.
// NOTE(review): unit presumably bytes — confirm against the serializer.
#define TENSOR_BUFFER_FORCE_ALIGNMENT 8

// TOSA version this header describes: major.minor.patch plus a draft flag.
#define TOSA_VERSION_MAJOR 0
#define TOSA_VERSION_MINOR 23
#define TOSA_VERSION_PATCH 0
#define TOSA_VERSION_DRAFT true
34
Eric Kunze2364dcd2021-04-26 11:06:57 -070035namespace tosa
36{
37
// Status codes returned by the TosaSerializationHandler file-io and
// data-conversion functions. Values are explicit so they stay stable.
enum tosa_err_t
{
    TOSA_OK               = 0,
    TOSA_USER_ERROR       = 1,
    TOSA_FILE_ERROR       = 2,
    TOSA_MEMORY_ERROR     = 3,
    TOSA_SCHEMA_MISSING   = 4,
    TOSA_INTERNAL_ERROR   = 5,
    TOSA_VERSION_MISMATCH = 6,
    NUM_TOSA_ERROR        = 7 // count of the codes above; must remain last
};

// Forward declaration of the top-level handler class defined further below.
class TosaSerializationHandler;
51
52class TosaSerializationTensor
53{
54public:
55 // constructor and destructor
56 TosaSerializationTensor(const flatbuffers::String* name,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070057 const flatbuffers::Vector<int32_t>* shape,
Eric Kunze2364dcd2021-04-26 11:06:57 -070058 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070059 const flatbuffers::Vector<uint8_t>* data);
Eric Kunze2364dcd2021-04-26 11:06:57 -070060 TosaSerializationTensor(std::string& name,
61 const std::vector<int32_t>& shape,
62 DType dtype,
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070063 const std::vector<uint8_t>& data);
Eric Kunze2364dcd2021-04-26 11:06:57 -070064 TosaSerializationTensor();
65 ~TosaSerializationTensor();
66
67 // accessor
68 std::string GetName() const
69 {
70 return _name;
71 }
72 const std::vector<int32_t>& GetShape() const
73 {
74 return _shape;
75 }
76 DType GetDtype()
77 {
78 return _dtype;
79 }
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070080 const std::vector<uint8_t>& GetData() const
Eric Kunze2364dcd2021-04-26 11:06:57 -070081 {
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070082 return _data;
Eric Kunze2364dcd2021-04-26 11:06:57 -070083 }
84
85 // modifier
86 void SetDtype(DType dtype)
87 {
88 _dtype = dtype;
89 }
90 void SetName(std::string name)
91 {
92 _name = name;
93 }
94
95private:
96 DType _dtype; /* data type enumeration, see tosa_isa_generated.h */
97 std::vector<int32_t> _shape; /* shape of the tensor */
98 std::string _name; /* name of the tensor, used for solving dependency */
Kevin Cheng3bb1bc12021-06-17 15:57:08 -070099 std::vector<uint8_t> _data; /* data array */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700100};
101
102class TosaSerializationOperator
103{
104public:
105 // use default copy, void constructor
106 // constructor and destructor
107 TosaSerializationOperator(Op op,
108 Attribute attribute_type,
109 const TosaAttributeBase* attribute,
110 QuantInfo qinfo_type,
111 const TosaQuantInfoBase* qinfo,
112 std::vector<std::string> input_tensor_names,
113 std::vector<std::string> output_tensor_names);
114 ~TosaSerializationOperator();
115
116 // accessor
117 Op GetOp() const
118 {
119 return _op;
120 }
121 Attribute GetAttributeType() const
122 {
123 return _attribute_type;
124 }
125 TosaAttributeBase* GetAttribute() const
126 {
127 return _attribute;
128 }
129 QuantInfo GetQInfoType() const
130 {
131 return _qinfo_type;
132 }
133 TosaQuantInfoBase* GetQInfo() const
134 {
135 return _qinfo;
136 }
137 std::vector<std::string>& GetInputTensorNames()
138 {
139 return _input_tensor_names;
140 }
141 std::vector<std::string>& GetOutputTensorNames()
142 {
143 return _output_tensor_names;
144 }
145
146private:
147 Op _op; /* operator enum, see tosa_isa_generated.h for enumeration table */
148 Attribute _attribute_type; /* operator attribute enum, used for dynamic casting TosaAttributeBase class */
149 TosaAttributeBase* _attribute; /* real attribute class goes here */
150 QuantInfo _qinfo_type; /* QuantInfo enum */
151 TosaQuantInfoBase* _qinfo; /* base class pointer of QuantInfo */
152 std::vector<std::string> _input_tensor_names; /* array of input tensor names */
153 std::vector<std::string> _output_tensor_names; /* array of output tensor names */
154};
155
156class TosaSerializationBasicBlock
157{
158public:
159 // constructor and destructor
160 TosaSerializationBasicBlock(std::string name,
161 std::vector<TosaSerializationOperator*> operators,
162 std::vector<TosaSerializationTensor*> tensors,
163 std::vector<std::string> inputs,
164 std::vector<std::string> outputs);
165 ~TosaSerializationBasicBlock();
166
167 // accessor
168 std::string GetName() const
169 {
170 return _name;
171 }
172 std::vector<TosaSerializationOperator*>& GetOperators()
173 {
174 return _operators;
175 }
176 std::vector<TosaSerializationTensor*>& GetTensors()
177 {
178 return _tensors;
179 }
180
181 TosaSerializationTensor* GetTensorByName(std::string name)
182 {
183 TosaSerializationTensor* result = nullptr;
184 for (auto tensor : GetTensors())
185 {
186 if (tensor->GetName() == name)
187 {
188 result = tensor;
189 break;
190 }
191 }
192 return result;
193 }
194
195 std::vector<std::string>& GetInputs()
196 {
197 return _inputs;
198 }
199 std::vector<std::string>& GetOutputs()
200 {
201 return _outputs;
202 }
203
204private:
205 std::string _name; /* name of basic block */
206 std::vector<TosaSerializationOperator*> _operators; /* TosaSerializationOperator list */
207 std::vector<TosaSerializationTensor*> _tensors; /* TosaSerializationTensor list */
208 std::vector<std::string> _inputs; /* array of string to specify block inputs */
209 std::vector<std::string> _outputs; /* array of string to specify block outputs */
210};
211
212/*
213 * this is a helper class for writing/reading Tosa ISA
214 * supported format: .tosa (flatbuffer), .json
215 * and provide high-level std::vector-like interface
216 * to access internal data structure
217 */
218class TosaSerializationHandler
219{
220public:
221 // constructor and destructor
222 TosaSerializationHandler();
223 ~TosaSerializationHandler();
224
225 // file io
226 tosa_err_t LoadFileJson(const char* filename);
227 tosa_err_t LoadFileTosaFlatbuffer(const char* filename);
228 tosa_err_t SaveFileJson(const char* filename);
229 tosa_err_t SaveFileTosaFlatbuffer(const char* filename);
230 tosa_err_t LoadFileSchema(const char* schema_filename);
231
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700232 // data format conversion. little-endian.
233 static tosa_err_t ConvertF32toU8(const std::vector<float>& in, std::vector<uint8_t>& out);
234 static tosa_err_t ConvertI48toU8(const std::vector<int64_t>& in, std::vector<uint8_t>& out);
235 static tosa_err_t ConvertI32toU8(const std::vector<int32_t>& in, std::vector<uint8_t>& out);
236 static tosa_err_t ConvertI16toU8(const std::vector<int16_t>& in, std::vector<uint8_t>& out);
237 static tosa_err_t ConvertI8toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700238 static tosa_err_t ConvertI4toU8(const std::vector<int8_t>& in, std::vector<uint8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700239 static tosa_err_t ConvertBooltoU8(const std::vector<bool>& in, std::vector<uint8_t>& out);
240
241 static tosa_err_t ConvertU8toF32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<float>& out);
242 static tosa_err_t ConvertU8toI48(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int64_t>& out);
243 static tosa_err_t ConvertU8toI32(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int32_t>& out);
244 static tosa_err_t ConvertU8toI16(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int16_t>& out);
245 static tosa_err_t ConvertU8toI8(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3ce56342021-07-28 13:42:29 -0700246 static tosa_err_t ConvertU8toI4(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<int8_t>& out);
Kevin Cheng3bb1bc12021-06-17 15:57:08 -0700247 static tosa_err_t ConvertU8toBool(const std::vector<uint8_t>& in, uint32_t out_size, std::vector<bool>& out);
248
Eric Kunze2364dcd2021-04-26 11:06:57 -0700249 // version
Kevin Chengb97cb1d2021-10-14 11:53:39 -0700250 const std::string& GetVersionStr()
Eric Kunze2364dcd2021-04-26 11:06:57 -0700251 {
252 return _version;
253 }
254
255 // accessor
256 std::vector<TosaSerializationBasicBlock*>& GetBlocks()
257 {
258 return _blocks;
259 }
260
261 TosaSerializationBasicBlock* GetBlockByName(std::string name)
262 {
263 TosaSerializationBasicBlock* result = nullptr;
264 for (auto block : GetBlocks())
265 {
266 if (block->GetName() == name)
267 {
268 result = block;
269 break;
270 }
271 }
272 return result;
273 }
274 TosaSerializationBasicBlock* GetMainBlock()
275 {
276 TosaSerializationBasicBlock* main_block = GetBlockByName(std::string("main"));
277 assert(main_block);
278 return main_block;
279 }
280
281 std::vector<std::string>& GetInputs()
282 {
283 return GetMainBlock()->GetInputs();
284 }
285 std::vector<std::string>& GetOutputs()
286 {
287 return GetMainBlock()->GetOutputs();
288 }
289
290 bool GetSchemaLoaded() const
291 {
292 return _schemaLoaded;
293 }
294
295protected:
296 tosa_err_t Clear();
Kevin Chengb97cb1d2021-10-14 11:53:39 -0700297 tosa_err_t Deserialize(const uint8_t* buf);
298 tosa_err_t Serialize();
299 std::string VersionToStr(int32_t major, int32_t minor, int32_t patch, bool draft);
Eric Kunze2364dcd2021-04-26 11:06:57 -0700300
301private:
Kevin Chengb97cb1d2021-10-14 11:53:39 -0700302 std::string _version; /* version string */
Eric Kunze2364dcd2021-04-26 11:06:57 -0700303 flatbuffers::FlatBufferBuilder _builder; /* flatbuffer builder */
304 flatbuffers::Parser _parser; /* flatbuffer parser, used for json parsing */
305 std::vector<TosaSerializationBasicBlock*> _blocks; /* array structure to store all TosaSerializationBasicBlock */
306 bool _schemaLoaded; /* is the schema properly loaded? */
307};
308
309} // namespace tosa
310
311#endif // _TOSA_SERIALIZATION_HANDLER_H