blob: b3fe8d6a0e8a3b564e4e7cf42073451a4e4ba745 [file] [log] [blame]
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef GRAPH_NODE_H
17#define GRAPH_NODE_H
18
19#include "attribute.h"
Kevin Chengacb550f2021-06-29 15:32:19 -070020#include "subgraph_traverser.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "tensor.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070022#include <iostream>
23
// Explicit-instantiation helpers.
//
// Each macro expands to exactly one declaration of the form
//     template class TosaReference::OP<...>;
// Every DTYPE argument is token-pasted onto the TOSA_REF_TYPE_ prefix, so a
// caller writes e.g. FP32 and gets TOSA_REF_TYPE_FP32. Because the argument
// is an operand of ##, it is pasted as written (not macro-expanded first).
// RANK arguments and the RESIZE variant's OP_TYPE are passed through verbatim.

// OP<RANK, DTYPE>
#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
    template class TosaReference::OP<RANK, TOSA_REF_TYPE_##DTYPE>;

// OP<RANK, DTYPE1, DTYPE2>
#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
    template class TosaReference::OP<RANK, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;

// OP<RANK1, RANK2, DTYPE>
#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
    template class TosaReference::OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE>;

// OP<RANK1, RANK2, DTYPE1, DTYPE2>
#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
    template class TosaReference::OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;

// OP<DTYPE>
#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) \
    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE>;

// OP<DTYPE1, DTYPE2>
#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) \
    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;

// OP<DTYPE1, DTYPE2, DTYPE3>
#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>;

// OP<DTYPE1, DTYPE2, OP_TYPE> -- the third argument is used as-is (no prefix).
#define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE) \
    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, OP_TYPE>;
TatWai Chongf7326092022-06-08 12:17:14 -070046
// Instantiate OP<RANK, DTYPE> once for each RANK in [1, 6], in ascending order.
#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)

// Instantiate OP<RANK, DTYPE> once for each RANK in [0, 6], in ascending
// order: the rank-0 instantiation followed by the whole [1, 6] range.
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \
    DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE)

// Instantiate OP<RANK, DTYPE1, DTYPE2> once for each RANK in [0, 6], in
// ascending order.
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2)
72
// Instantiate OP<RANK1, RANK2, DTYPE> for one fixed RANK1 and each RANK2 in
// [0, 6], in ascending RANK2 order. Building block for DEF_INSTANTIATE_RESHAPE.
#define DEF_INSTANTIATE_RESHAPE_FIXED_RANK1(OP, RANK1, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 0, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 1, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 2, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 3, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 4, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 5, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 6, DTYPE)

// Instantiate OP<RANK1, RANK2, DTYPE> for every (RANK1, RANK2) pair in
// [0, 6] x [0, 6] -- 49 instantiations, ordered by RANK1 and then RANK2.
#define DEF_INSTANTIATE_RESHAPE(OP, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FIXED_RANK1(OP, 0, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FIXED_RANK1(OP, 1, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FIXED_RANK1(OP, 2, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FIXED_RANK1(OP, 3, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FIXED_RANK1(OP, 4, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FIXED_RANK1(OP, 5, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FIXED_RANK1(OP, 6, DTYPE)
123
// Copy the incoming attribute into this object's `attribute` member.
//
// The expansion refers to two names that must be in scope at the use site:
//   attribute_  - the source pointer, dynamic_cast to Tosa<NAME>Attribute*
//   attribute   - the destination, assigned a newly allocated
//                 Tosa<NAME>Attribute constructed from the cast pointer
// If attribute_ is not dynamically a Tosa<NAME>Attribute, FATAL_ERROR is
// invoked with "Can't initialize Tosa<NAME>Attribute". ASSERT_MEM and
// FATAL_ERROR are project macros defined elsewhere.
//
// The body is wrapped in do { ... } while (0) so the macro is a single
// statement: write `INIT_ATTRIBUTE(Name);` with a trailing ';'. It can then
// also be the sole body of an unbraced if/else. With a bare if/else body,
// the user's ';' would orphan a following `else`.
#define INIT_ATTRIBUTE(ATTRIBUTE_NAME) \
    do \
    { \
        if (auto p = dynamic_cast<Tosa##ATTRIBUTE_NAME##Attribute*>(attribute_)) \
        { \
            attribute = new Tosa##ATTRIBUTE_NAME##Attribute(p); \
            ASSERT_MEM(attribute); \
        } \
        else \
        { \
            FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute"); \
        } \
    } while (0)
134
Eric Kunzee5e26762020-10-13 16:11:07 -0700135namespace TosaReference
136{
137
Kevin Chengacb550f2021-06-29 15:32:19 -0700138class SubgraphTraverser;
139
Eric Kunzee5e26762020-10-13 16:11:07 -0700140// Nodes in the graph (e.g., tosa operators) are defined with this base
141// class.
142class GraphNode
143{
144public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700145 GraphNode(SubgraphTraverser* parent_sgt_, const tosa::Op& nodeType_, const uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700146 virtual ~GraphNode();
147
148 int addInputName(std::string& name);
149 int addOutputName(std::string& name);
150
151 int addInputTensor(Tensor* tens);
152 int addOutputTensor(Tensor* tens);
Eric Kunzee5e26762020-10-13 16:11:07 -0700153 // Validate that the input tensors match properly
154 // in their types, attributes, rank, etc well enough to be
155 // processed.
156 //
157 // This function should be pure virtual (eventually) in order to force
158 // derivative operators to implement the check, but we'll initially
159 // provide a default function so that GraphNode can be instantiated
160 // directly for testing purposes.
161 virtual int checkTensorAttributes();
162
163 // Evalute the node/operator
164 virtual int eval();
165
166 int hasAllInputsReady() const;
167 int hasAllOutputsReady() const;
168
169 int dumpNode(FILE* out);
170 int dumpNode(std::ostream& out);
171
172 int setNodeMarked()
173 {
174 isMarked = true;
175 return 0;
176 }
177
178 int getNodeMarked() const
179 {
180 return isMarked;
181 }
182
183 int clearNodeMarked()
184 {
185 isMarked = false;
186 return 0;
187 }
188
189 int getEvalCount() const
190 {
191 return evalCount;
192 }
193
194 uint64_t getID() const
195 {
196 return nodeId;
197 }
198
199 std::vector<std::string>& getInputNames()
200 {
201 return inputNames;
202 }
203
204 std::vector<std::string>& getOutputNames()
205 {
206 return outputNames;
207 }
208
209 std::vector<Tensor*>& getOutputs()
210 {
211 return outputs;
212 }
213
214 std::vector<Tensor*>& getInputs()
215 {
216 return inputs;
217 }
218
219 int getOnNextNodeList() const
220 {
221 return onNextNodeList;
222 }
223
224 int setOnNextNodeList()
225 {
226 onNextNodeList = true;
227 return 0;
228 }
229
230 int clearOnNextNodeList()
231 {
232 onNextNodeList = false;
233 return 0;
234 }
235
236 tosa::Op getOp() const
237 {
238 return nodeType;
239 }
240
Jerry Ge9e94af82022-10-27 09:57:00 -0700241 SubgraphTraverser* getParentSGT()
242 {
243 return parent_sgt;
244 }
245
246 int setInMainBlock(bool isInMainBlock)
247 {
248 inMainBlock = isInMainBlock;
249 return 0;
250 }
251
252 bool getInMainBlock()
253 {
254 return inMainBlock;
255 }
256
Tai Ly47625642023-09-07 20:49:09 +0000257 int getEvaluated() const
258 {
259 return evaluated;
260 }
261
262 int setEvaluated()
263 {
264 evaluated = true;
265 return 0;
266 }
267
268 int clearEvaluated()
269 {
270 evaluated = false;
271 return 0;
272 }
273
TatWai Chongf7326092022-06-08 12:17:14 -0700274 // Helper functions.
275 int idiv_check(int input1, int input2, int& result);
276
Eric Kunzee5e26762020-10-13 16:11:07 -0700277protected:
278 // Print out a node validation error
279 int printNodeValidationError(const std::string& msg);
280
281 int setRequiredOperands(const int in, const int out)
282 {
283 requiredInputCount = in;
284 requiredOutputCount = out;
285 return 0;
286 }
287
288 int setRequiredRank(const int min, const int max = -1)
289 {
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000290 requiredRankMin = min;
291 requiredRankMax = max;
Eric Kunzee5e26762020-10-13 16:11:07 -0700292
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000293 if (requiredRankMin >= 0 && requiredRankMax >= 0)
294 {
295 ASSERT_MSG(requiredRankMin <= requiredRankMax,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000296 "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
297 requiredRankMax);
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000298 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700299
300 return 0;
301 }
302
303 int validateRequiredOperands();
304 int validateRequiredRank(const Tensor* t);
305
Kevin Chengacb550f2021-06-29 15:32:19 -0700306 // Parent SubgraphTraverser
307 SubgraphTraverser* parent_sgt;
308
Eric Kunzee5e26762020-10-13 16:11:07 -0700309 // Description of the node type (e.g., CONST, CONV2D, etc...)
310 tosa::Op nodeType;
311
312 // A list of input tensor names
313 std::vector<std::string> inputNames;
314
315 // A list of the output tensor names
316 std::vector<std::string> outputNames;
317
318 // A list of the input tensors (after names have been matched up)
319 std::vector<Tensor*> inputs;
320
321 // A list of the output tensors (after names have been matched up)
322 std::vector<Tensor*> outputs;
323
324 // Unique node ID for debugging
325 uint64_t nodeId;
326
327 // Flag used for graph analysis
328 int isMarked;
329
330 // Number of times eval() has been called for this node
331 int evalCount;
332
333 // Flag indicating that this node is ready and is on the
334 // next-node list.
335 int onNextNodeList;
336
Tai Ly47625642023-09-07 20:49:09 +0000337 // Flag indicating that this node has been evaluated before
338 int evaluated;
339
Eric Kunzee5e26762020-10-13 16:11:07 -0700340 // Required input/output tensor counts for node validation
341 // -1 means any number is allowed
342 int requiredInputCount;
343 int requiredOutputCount;
344
345 // Required rank ranges for input/output tensors
346 // -1 means n/a
347 int requiredRankMin;
348 int requiredRankMax;
Jerry Ge9e94af82022-10-27 09:57:00 -0700349
350 bool inMainBlock;
Eric Kunzee5e26762020-10-13 16:11:07 -0700351};
352
353}; // namespace TosaReference
354
355#endif