blob: 14a8acca7a6f3dda06140454e7562fc3edbe6bad [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#ifndef GRAPH_NODE_H
17#define GRAPH_NODE_H
18
19#include "attribute.h"
20#include "quant_info.h"
Kevin Chengacb550f2021-06-29 15:32:19 -070021#include "subgraph_traverser.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070022#include "tensor.h"
23#include "tosa_generated.h"
24#include <iostream>
25
// Explicit template-instantiation helpers.
//
// Every DEF_INSTANTIATE_* macro below expands to one or more explicit
// instantiation definitions of the form
//
//     template class TosaReference::OP<...>;
//
// Each DTYPE argument is token-pasted onto the "DType_" prefix, so callers pass
// the bare suffix (passing INT8 yields DType_INT8). Every expansion already
// ends in ';'.

// One rank parameter, one dtype parameter.
#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
    template class TosaReference::OP<RANK, DType_##DTYPE>;

// One rank parameter, two dtype parameters.
#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
    template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>;

// Two rank parameters, one dtype parameter.
#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
    template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE>;

// Two rank parameters, two dtype parameters.
#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
    template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>;

// Dtype parameter(s) only; no rank parameter.
#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) \
    template class TosaReference::OP<DType_##DTYPE>;

#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) \
    template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;

// Instantiate a one-rank, one-dtype template for ranks 1..6, ascending.
#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE)          \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE)          \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE)          \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE)          \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE)          \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)

// Ranks 0..6: the rank-0 instantiation followed by the 1..6 sweep above.
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE)          \
    DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE)

// Instantiate a one-rank, two-dtype template for ranks 0..6, ascending.
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2)          \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2)          \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2)          \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2)          \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2)          \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2)          \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2)

// Helper for DEF_INSTANTIATE_RESHAPE: holds the first rank at RANK1 and
// sweeps the second rank over 0..6.
#define DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, RANK1, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 0, DTYPE)  \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 1, DTYPE)  \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 2, DTYPE)  \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 3, DTYPE)  \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 4, DTYPE)  \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 5, DTYPE)  \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 6, DTYPE)

// All 49 (RANK1, RANK2) pairs with both ranks in 0..6; RANK1 is the outer
// (slower-varying) index, so the order is (0,0), (0,1), ..., (6,6).
#define DEF_INSTANTIATE_RESHAPE(OP, DTYPE)          \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 0, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 1, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 2, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 3, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 4, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 5, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 6, DTYPE)
117
/*
 * INIT_ATTRIBUTE(ATTRIBUTE_NAME)
 *
 * Expands to a single if/else statement that copies an operator attribute.
 * The expansion refers to two names that must already be in scope where the
 * macro is used: `attribute_` (the source pointer) and `attribute` (the
 * destination). NOTE(review): presumably a constructor parameter and a member
 * of the operator class; confirm at the call sites.
 *
 * - If `attribute_` dynamic_casts to Tosa<ATTRIBUTE_NAME>Attribute*, a new
 *   Tosa<ATTRIBUTE_NAME>Attribute is heap-allocated from that pointer and
 *   stored in `attribute`. The result is then checked with ASSERT_MEM.
 * - Otherwise FATAL_ERROR is raised with a message naming the expected type.
 *
 * The expansion is a bare if/else, not a do { } while (0) wrapper. Brace the
 * call site when it is the body of another if/else. Releasing the allocated
 * object is not visible here (NOTE(review): confirm who deletes it).
 */
#define INIT_ATTRIBUTE(ATTRIBUTE_NAME)                                                                                 \
    if (auto p = dynamic_cast<Tosa##ATTRIBUTE_NAME##Attribute*>(attribute_))                                           \
    {                                                                                                                  \
        attribute = new Tosa##ATTRIBUTE_NAME##Attribute(p);                                                            \
        ASSERT_MEM(attribute);                                                                                         \
    }                                                                                                                  \
    else                                                                                                               \
    {                                                                                                                  \
        FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute");                                              \
    }

/*
 * INIT_QINFO(QINFO_NAME)
 *
 * Same shape as INIT_ATTRIBUTE, but for quantization info. The expansion
 * refers to `qinfo_` (source) and `qinfo` (destination), which must be in
 * scope at the call site.
 *
 * - If `qinfo_` dynamic_casts to Tosa<QINFO_NAME>QuantInfo*, a new copy is
 *   heap-allocated into `qinfo` and checked with ASSERT_MEM.
 * - Otherwise `qinfo` is set to nullptr.
 *
 * Unlike INIT_ATTRIBUTE, missing quant info is not treated as an error here.
 */
#define INIT_QINFO(QINFO_NAME)                                                                                         \
    if (auto p = dynamic_cast<Tosa##QINFO_NAME##QuantInfo*>(qinfo_))                                                   \
    {                                                                                                                  \
        qinfo = new Tosa##QINFO_NAME##QuantInfo(p);                                                                    \
        ASSERT_MEM(qinfo);                                                                                             \
    }                                                                                                                  \
    else                                                                                                               \
    {                                                                                                                  \
        qinfo = nullptr;                                                                                               \
    }
139
140namespace TosaReference
141{
142
Kevin Chengacb550f2021-06-29 15:32:19 -0700143class SubgraphTraverser;
144
Eric Kunzee5e26762020-10-13 16:11:07 -0700145// Nodes in the graph (e.g., tosa operators) are defined with this base
146// class.
147class GraphNode
148{
149public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700150 GraphNode(SubgraphTraverser* parent_sgt_, const tosa::Op& nodeType_, const uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700151 virtual ~GraphNode();
152
153 int addInputName(std::string& name);
154 int addOutputName(std::string& name);
155
156 int addInputTensor(Tensor* tens);
157 int addOutputTensor(Tensor* tens);
158
159 // Validate that the input tensors match properly
160 // in their types, attributes, rank, etc well enough to be
161 // processed.
162 //
163 // This function should be pure virtual (eventually) in order to force
164 // derivative operators to implement the check, but we'll initially
165 // provide a default function so that GraphNode can be instantiated
166 // directly for testing purposes.
167 virtual int checkTensorAttributes();
168
169 // Evalute the node/operator
170 virtual int eval();
171
172 int hasAllInputsReady() const;
173 int hasAllOutputsReady() const;
174
175 int dumpNode(FILE* out);
176 int dumpNode(std::ostream& out);
177
178 int setNodeMarked()
179 {
180 isMarked = true;
181 return 0;
182 }
183
184 int getNodeMarked() const
185 {
186 return isMarked;
187 }
188
189 int clearNodeMarked()
190 {
191 isMarked = false;
192 return 0;
193 }
194
195 int getEvalCount() const
196 {
197 return evalCount;
198 }
199
200 uint64_t getID() const
201 {
202 return nodeId;
203 }
204
205 std::vector<std::string>& getInputNames()
206 {
207 return inputNames;
208 }
209
210 std::vector<std::string>& getOutputNames()
211 {
212 return outputNames;
213 }
214
215 std::vector<Tensor*>& getOutputs()
216 {
217 return outputs;
218 }
219
220 std::vector<Tensor*>& getInputs()
221 {
222 return inputs;
223 }
224
225 int getOnNextNodeList() const
226 {
227 return onNextNodeList;
228 }
229
230 int setOnNextNodeList()
231 {
232 onNextNodeList = true;
233 return 0;
234 }
235
236 int clearOnNextNodeList()
237 {
238 onNextNodeList = false;
239 return 0;
240 }
241
242 tosa::Op getOp() const
243 {
244 return nodeType;
245 }
246
247protected:
248 // Print out a node validation error
249 int printNodeValidationError(const std::string& msg);
250
251 int setRequiredOperands(const int in, const int out)
252 {
253 requiredInputCount = in;
254 requiredOutputCount = out;
255 return 0;
256 }
257
258 int setRequiredRank(const int min, const int max = -1)
259 {
260 if (max == -1)
261 {
262 requiredRankMin = requiredRankMax = min;
263 }
264 else
265 {
266 requiredRankMin = min;
267 requiredRankMax = max;
268 }
269
270 ASSERT_MSG(requiredRankMin <= requiredRankMax,
271 "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
272 requiredRankMax);
273
274 return 0;
275 }
276
277 int validateRequiredOperands();
278 int validateRequiredRank(const Tensor* t);
279
Kevin Chengacb550f2021-06-29 15:32:19 -0700280 // Parent SubgraphTraverser
281 SubgraphTraverser* parent_sgt;
282
Eric Kunzee5e26762020-10-13 16:11:07 -0700283 // Description of the node type (e.g., CONST, CONV2D, etc...)
284 tosa::Op nodeType;
285
286 // A list of input tensor names
287 std::vector<std::string> inputNames;
288
289 // A list of the output tensor names
290 std::vector<std::string> outputNames;
291
292 // A list of the input tensors (after names have been matched up)
293 std::vector<Tensor*> inputs;
294
295 // A list of the output tensors (after names have been matched up)
296 std::vector<Tensor*> outputs;
297
298 // Unique node ID for debugging
299 uint64_t nodeId;
300
301 // Flag used for graph analysis
302 int isMarked;
303
304 // Number of times eval() has been called for this node
305 int evalCount;
306
307 // Flag indicating that this node is ready and is on the
308 // next-node list.
309 int onNextNodeList;
310
311 // Required input/output tensor counts for node validation
312 // -1 means any number is allowed
313 int requiredInputCount;
314 int requiredOutputCount;
315
316 // Required rank ranges for input/output tensors
317 // -1 means n/a
318 int requiredRankMin;
319 int requiredRankMax;
320};
321
322}; // namespace TosaReference
323
324#endif