blob: 440ca04373cbfa6f654a2601c01d06731ee98a2e [file] [log] [blame]
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef GRAPH_NODE_H
17#define GRAPH_NODE_H
18
19#include "attribute.h"
Kevin Chengacb550f2021-06-29 15:32:19 -070020#include "subgraph_traverser.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "tensor.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070022#include <iostream>
23
// Explicit-instantiation helpers.
//
// Each macro below expands to exactly one explicit instantiation statement of
// the form `template class TosaReference::OP_CLASS<...>;` (the trailing
// semicolon is part of the expansion). Data-type arguments are bare suffixes
// that are token-pasted onto `TOSA_REF_TYPE_` (e.g. passing FP32 yields
// TOSA_REF_TYPE_FP32); rank arguments are forwarded unchanged.

// OP_CLASS<R, TOSA_REF_TYPE_DT>
#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP_CLASS, R, DT) \
    template class TosaReference::OP_CLASS<R, TOSA_REF_TYPE_##DT>;

// OP_CLASS<R, TOSA_REF_TYPE_DT1, TOSA_REF_TYPE_DT2>
#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP_CLASS, R, DT1, DT2) \
    template class TosaReference::OP_CLASS<R, TOSA_REF_TYPE_##DT1, TOSA_REF_TYPE_##DT2>;

// OP_CLASS<R1, R2, TOSA_REF_TYPE_DT>
#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP_CLASS, R1, R2, DT) \
    template class TosaReference::OP_CLASS<R1, R2, TOSA_REF_TYPE_##DT>;

// OP_CLASS<R1, R2, TOSA_REF_TYPE_DT1, TOSA_REF_TYPE_DT2>
#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP_CLASS, R1, R2, DT1, DT2) \
    template class TosaReference::OP_CLASS<R1, R2, TOSA_REF_TYPE_##DT1, TOSA_REF_TYPE_##DT2>;

// OP_CLASS<TOSA_REF_TYPE_DT>
#define DEF_INSTANTIATE_ONE_TYPE(OP_CLASS, DT) \
    template class TosaReference::OP_CLASS<TOSA_REF_TYPE_##DT>;

// OP_CLASS<TOSA_REF_TYPE_DT1, TOSA_REF_TYPE_DT2>
#define DEF_INSTANTIATE_TWO_TYPE(OP_CLASS, DT1, DT2) \
    template class TosaReference::OP_CLASS<TOSA_REF_TYPE_##DT1, TOSA_REF_TYPE_##DT2>;

// OP_CLASS<TOSA_REF_TYPE_DT1, TOSA_REF_TYPE_DT2, TOSA_REF_TYPE_DT3>
#define DEF_INSTANTIATE_THREE_TYPE(OP_CLASS, DT1, DT2, DT3) \
    template class TosaReference::OP_CLASS<TOSA_REF_TYPE_##DT1, TOSA_REF_TYPE_##DT2, TOSA_REF_TYPE_##DT3>;

// OP_CLASS<TOSA_REF_TYPE_DT1, TOSA_REF_TYPE_DT2, T3>
// The third template argument T3 is passed through verbatim; unlike the
// data-type arguments it is NOT pasted onto TOSA_REF_TYPE_.
#define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP_CLASS, DT1, DT2, T3) \
    template class TosaReference::OP_CLASS<TOSA_REF_TYPE_##DT1, TOSA_REF_TYPE_##DT2, T3>;

// Rank fan-out helpers.
//
// These expand to a sequence of single-instantiation macro invocations
// (DEF_INSTANTIATE_ONE_RANK_*_TYPE / DEF_INSTANTIATE_TWO_RANK_ONE_TYPE), one
// per rank value, in ascending rank order. They add no semicolons of their
// own; each single-instantiation expansion supplies its own.

// Ranks 1..6 of OP_CLASS<R, DT>.
#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP_CLASS, DT) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP_CLASS, 1, DT) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP_CLASS, 2, DT) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP_CLASS, 3, DT) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP_CLASS, 4, DT) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP_CLASS, 5, DT) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP_CLASS, 6, DT)

// Ranks 0..6 of OP_CLASS<R, DT>: rank 0 first, then the 1..6 fan-out.
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP_CLASS, DT) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP_CLASS, 0, DT) \
    DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP_CLASS, DT)

// Ranks 0..6 of OP_CLASS<R, DT1, DT2>.
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP_CLASS, DT1, DT2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP_CLASS, 0, DT1, DT2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP_CLASS, 1, DT1, DT2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP_CLASS, 2, DT1, DT2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP_CLASS, 3, DT1, DT2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP_CLASS, 4, DT1, DT2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP_CLASS, 5, DT1, DT2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP_CLASS, 6, DT1, DT2)

// One row of the reshape grid: OP_CLASS<R1, R2, DT> for a fixed R1 and
// R2 = 0..6 in ascending order.
#define DEF_INSTANTIATE_RESHAPE_ROW(OP_CLASS, R1, DT) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP_CLASS, R1, 0, DT) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP_CLASS, R1, 1, DT) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP_CLASS, R1, 2, DT) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP_CLASS, R1, 3, DT) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP_CLASS, R1, 4, DT) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP_CLASS, R1, 5, DT) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP_CLASS, R1, 6, DT)

// All 7 x 7 (R1, R2) pairs over ranks 0..6 of OP_CLASS<R1, R2, DT>,
// with R1 varying slowest.
#define DEF_INSTANTIATE_RESHAPE(OP_CLASS, DT) \
    DEF_INSTANTIATE_RESHAPE_ROW(OP_CLASS, 0, DT) \
    DEF_INSTANTIATE_RESHAPE_ROW(OP_CLASS, 1, DT) \
    DEF_INSTANTIATE_RESHAPE_ROW(OP_CLASS, 2, DT) \
    DEF_INSTANTIATE_RESHAPE_ROW(OP_CLASS, 3, DT) \
    DEF_INSTANTIATE_RESHAPE_ROW(OP_CLASS, 4, DT) \
    DEF_INSTANTIATE_RESHAPE_ROW(OP_CLASS, 5, DT) \
    DEF_INSTANTIATE_RESHAPE_ROW(OP_CLASS, 6, DT)

// Copy-construct the operator's typed attribute from the generic one.
//
// Expects `attribute_` (the generic attribute pointer) and `attribute` (the
// member that receives the typed copy) to be in scope at the expansion site.
// If `attribute_` dynamic_casts to Tosa<ATTR_NAME>Attribute*, a new
// Tosa<ATTR_NAME>Attribute is heap-allocated from it into `attribute` and
// checked with ASSERT_MEM; otherwise FATAL_ERROR reports the mismatch.
//
// NOTE: this expands to a bare if/else statement (no do/while(0) wrapper), so
// use it as a standalone statement, not as the body of an unbraced if.
#define INIT_ATTRIBUTE(ATTR_NAME) \
    if (auto* typed_attr = dynamic_cast<Tosa##ATTR_NAME##Attribute*>(attribute_)) \
    { \
        attribute = new Tosa##ATTR_NAME##Attribute(typed_attr); \
        ASSERT_MEM(attribute); \
    } \
    else \
    { \
        FATAL_ERROR("Can't initialize Tosa" #ATTR_NAME "Attribute"); \
    }

Eric Kunzee5e26762020-10-13 16:11:07 -0700135namespace TosaReference
136{
137
Kevin Chengacb550f2021-06-29 15:32:19 -0700138class SubgraphTraverser;
139
Eric Kunzee5e26762020-10-13 16:11:07 -0700140// Nodes in the graph (e.g., tosa operators) are defined with this base
141// class.
142class GraphNode
143{
144public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700145 GraphNode(SubgraphTraverser* parent_sgt_, const tosa::Op& nodeType_, const uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700146 virtual ~GraphNode();
147
148 int addInputName(std::string& name);
149 int addOutputName(std::string& name);
150
151 int addInputTensor(Tensor* tens);
152 int addOutputTensor(Tensor* tens);
Eric Kunzee5e26762020-10-13 16:11:07 -0700153 // Validate that the input tensors match properly
154 // in their types, attributes, rank, etc well enough to be
155 // processed.
156 //
157 // This function should be pure virtual (eventually) in order to force
158 // derivative operators to implement the check, but we'll initially
159 // provide a default function so that GraphNode can be instantiated
160 // directly for testing purposes.
161 virtual int checkTensorAttributes();
162
163 // Evalute the node/operator
164 virtual int eval();
165
166 int hasAllInputsReady() const;
167 int hasAllOutputsReady() const;
168
169 int dumpNode(FILE* out);
170 int dumpNode(std::ostream& out);
171
172 int setNodeMarked()
173 {
174 isMarked = true;
175 return 0;
176 }
177
178 int getNodeMarked() const
179 {
180 return isMarked;
181 }
182
183 int clearNodeMarked()
184 {
185 isMarked = false;
186 return 0;
187 }
188
189 int getEvalCount() const
190 {
191 return evalCount;
192 }
193
194 uint64_t getID() const
195 {
196 return nodeId;
197 }
198
199 std::vector<std::string>& getInputNames()
200 {
201 return inputNames;
202 }
203
204 std::vector<std::string>& getOutputNames()
205 {
206 return outputNames;
207 }
208
209 std::vector<Tensor*>& getOutputs()
210 {
211 return outputs;
212 }
213
214 std::vector<Tensor*>& getInputs()
215 {
216 return inputs;
217 }
218
219 int getOnNextNodeList() const
220 {
221 return onNextNodeList;
222 }
223
224 int setOnNextNodeList()
225 {
226 onNextNodeList = true;
227 return 0;
228 }
229
230 int clearOnNextNodeList()
231 {
232 onNextNodeList = false;
233 return 0;
234 }
235
236 tosa::Op getOp() const
237 {
238 return nodeType;
239 }
240
Jerry Ge9e94af82022-10-27 09:57:00 -0700241 SubgraphTraverser* getParentSGT()
242 {
243 return parent_sgt;
244 }
245
246 int setInMainBlock(bool isInMainBlock)
247 {
248 inMainBlock = isInMainBlock;
249 return 0;
250 }
251
252 bool getInMainBlock()
253 {
254 return inMainBlock;
255 }
256
TatWai Chongf7326092022-06-08 12:17:14 -0700257 // Helper functions.
258 int idiv_check(int input1, int input2, int& result);
TatWai Chongfb879822023-08-31 16:58:27 -0700259 int idiv_floor(int input1, int input2);
TatWai Chongf7326092022-06-08 12:17:14 -0700260
Eric Kunzee5e26762020-10-13 16:11:07 -0700261protected:
262 // Print out a node validation error
263 int printNodeValidationError(const std::string& msg);
264
265 int setRequiredOperands(const int in, const int out)
266 {
267 requiredInputCount = in;
268 requiredOutputCount = out;
269 return 0;
270 }
271
272 int setRequiredRank(const int min, const int max = -1)
273 {
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000274 requiredRankMin = min;
275 requiredRankMax = max;
Eric Kunzee5e26762020-10-13 16:11:07 -0700276
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000277 if (requiredRankMin >= 0 && requiredRankMax >= 0)
278 {
279 ASSERT_MSG(requiredRankMin <= requiredRankMax,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000280 "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
281 requiredRankMax);
Jerry Ge0bd4ec82023-05-01 18:36:43 +0000282 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700283
284 return 0;
285 }
286
287 int validateRequiredOperands();
288 int validateRequiredRank(const Tensor* t);
289
Kevin Chengacb550f2021-06-29 15:32:19 -0700290 // Parent SubgraphTraverser
291 SubgraphTraverser* parent_sgt;
292
Eric Kunzee5e26762020-10-13 16:11:07 -0700293 // Description of the node type (e.g., CONST, CONV2D, etc...)
294 tosa::Op nodeType;
295
296 // A list of input tensor names
297 std::vector<std::string> inputNames;
298
299 // A list of the output tensor names
300 std::vector<std::string> outputNames;
301
302 // A list of the input tensors (after names have been matched up)
303 std::vector<Tensor*> inputs;
304
305 // A list of the output tensors (after names have been matched up)
306 std::vector<Tensor*> outputs;
307
308 // Unique node ID for debugging
309 uint64_t nodeId;
310
311 // Flag used for graph analysis
312 int isMarked;
313
314 // Number of times eval() has been called for this node
315 int evalCount;
316
317 // Flag indicating that this node is ready and is on the
318 // next-node list.
319 int onNextNodeList;
320
321 // Required input/output tensor counts for node validation
322 // -1 means any number is allowed
323 int requiredInputCount;
324 int requiredOutputCount;
325
326 // Required rank ranges for input/output tensors
327 // -1 means n/a
328 int requiredRankMin;
329 int requiredRankMax;
Jerry Ge9e94af82022-10-27 09:57:00 -0700330
331 bool inMainBlock;
Eric Kunzee5e26762020-10-13 16:11:07 -0700332};
333
334}; // namespace TosaReference
335
336#endif