blob: a9a336bc5938fa777f23f881e96763f7d21992bf [file] [log] [blame]
// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#ifndef GRAPH_NODE_H
17#define GRAPH_NODE_H
18
19#include "attribute.h"
Kevin Chengacb550f2021-06-29 15:32:19 -070020#include "subgraph_traverser.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070021#include "tensor.h"
22#include "tosa_generated.h"
23#include <iostream>
24
// Explicit-instantiation helpers.
//
// Every DEF_INSTANTIATE_* macro below emits exactly one statement of the form
//     template class TosaReference::OP<ARGS...>;
// Element-type arguments are given as bare suffixes (e.g. INT8, FP32) and are
// token-pasted onto the `DType_` prefix; rank arguments are passed through as-is.

// Shared tail: `template class TosaReference::OP<...>;` for the given arguments.
#define DEF_INSTANTIATE_CLASS_TEMPLATE(OP, ...) template class TosaReference::OP<__VA_ARGS__>;

// OP<RANK, DType_DTYPE>
#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) DEF_INSTANTIATE_CLASS_TEMPLATE(OP, RANK, DType_##DTYPE)

// OP<RANK, DType_DTYPE1, DType_DTYPE2>
#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_CLASS_TEMPLATE(OP, RANK, DType_##DTYPE1, DType_##DTYPE2)

// OP<RANK1, RANK2, DType_DTYPE>
#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
    DEF_INSTANTIATE_CLASS_TEMPLATE(OP, RANK1, RANK2, DType_##DTYPE)

// OP<RANK1, RANK2, DType_DTYPE1, DType_DTYPE2>
#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_CLASS_TEMPLATE(OP, RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2)

// OP<DType_DTYPE>
#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) DEF_INSTANTIATE_CLASS_TEMPLATE(OP, DType_##DTYPE)

// OP<DType_DTYPE1, DType_DTYPE2>
#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) DEF_INSTANTIATE_CLASS_TEMPLATE(OP, DType_##DTYPE1, DType_##DTYPE2)

// OP<DType_DTYPE1, DType_DTYPE2, DType_DTYPE3>
#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
    DEF_INSTANTIATE_CLASS_TEMPLATE(OP, DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3)

// OP<DType_DTYPE1, DType_DTYPE2, OP_TYPE>: the last argument is used verbatim
// (it is NOT given the DType_ prefix).
#define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE) \
    DEF_INSTANTIATE_CLASS_TEMPLATE(OP, DType_##DTYPE1, DType_##DTYPE2, OP_TYPE)
45
// Rank sweeps: instantiate OP once per tensor rank in the named range, using
// the single-instantiation macros DEF_INSTANTIATE_ONE_RANK_ONE_TYPE and
// DEF_INSTANTIATE_ONE_RANK_TWO_TYPE.

// OP<R, DType_DTYPE> for R = 1, 2, ..., 6.
#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)

// OP<R, DType_DTYPE> for R = 0, 1, ..., 6: the rank-0 instantiation followed
// by the 1..6 sweep above.
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \
    DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE)

// OP<R, DType_DTYPE1, DType_DTYPE2> for R = 0, 1, ..., 6.
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \
    DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2)
71
// Helper for DEF_INSTANTIATE_RESHAPE: OP<RANK1, R, DType_DTYPE> for a fixed
// first rank RANK1 and every second rank R = 0, 1, ..., 6.
#define DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, RANK1, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 0, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 1, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 2, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 3, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 4, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 5, DTYPE) \
    DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, 6, DTYPE)

// Instantiate OP<RANK1, RANK2, DType_DTYPE> for every (RANK1, RANK2) pair in
// {0..6} x {0..6} -- 49 instantiations, ordered by RANK1 and then by RANK2.
// (Presumably RANK1/RANK2 are the input/output ranks of the reshape -- TODO confirm.)
#define DEF_INSTANTIATE_RESHAPE(OP, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 0, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 1, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 2, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 3, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 4, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 5, DTYPE) \
    DEF_INSTANTIATE_RESHAPE_FROM_RANK(OP, 6, DTYPE)
122
// Copy the incoming generic attribute into a newly allocated
// Tosa<ATTRIBUTE_NAME>Attribute.
//
// Must be expanded where two names are in scope:
//   - `attribute_`: the incoming attribute pointer, checked with dynamic_cast;
//   - `attribute` : assigned `new Tosa<ATTRIBUTE_NAME>Attribute(p)` on success.
// If `attribute_` is not a Tosa<ATTRIBUTE_NAME>Attribute, FATAL_ERROR is raised
// instead and `attribute` is left untouched.
//
// The body is wrapped in do { ... } while (0) so the macro is a single
// statement that requires a trailing ';'. As a bare if/else, a use such as
// `if (c) INIT_ATTRIBUTE(X); else ...` did not compile, because the ';' ended
// the outer `if` and left the `else` dangling.
#define INIT_ATTRIBUTE(ATTRIBUTE_NAME) \
    do \
    { \
        if (auto p = dynamic_cast<Tosa##ATTRIBUTE_NAME##Attribute*>(attribute_)) \
        { \
            attribute = new Tosa##ATTRIBUTE_NAME##Attribute(p); \
            ASSERT_MEM(attribute); \
        } \
        else \
        { \
            FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute"); \
        } \
    } while (0)
133
Eric Kunzee5e26762020-10-13 16:11:07 -0700134namespace TosaReference
135{
136
Kevin Chengacb550f2021-06-29 15:32:19 -0700137class SubgraphTraverser;
138
Eric Kunzee5e26762020-10-13 16:11:07 -0700139// Nodes in the graph (e.g., tosa operators) are defined with this base
140// class.
141class GraphNode
142{
143public:
Kevin Chengacb550f2021-06-29 15:32:19 -0700144 GraphNode(SubgraphTraverser* parent_sgt_, const tosa::Op& nodeType_, const uint64_t id_);
Eric Kunzee5e26762020-10-13 16:11:07 -0700145 virtual ~GraphNode();
146
147 int addInputName(std::string& name);
148 int addOutputName(std::string& name);
149
150 int addInputTensor(Tensor* tens);
151 int addOutputTensor(Tensor* tens);
Eric Kunzee5e26762020-10-13 16:11:07 -0700152 // Validate that the input tensors match properly
153 // in their types, attributes, rank, etc well enough to be
154 // processed.
155 //
156 // This function should be pure virtual (eventually) in order to force
157 // derivative operators to implement the check, but we'll initially
158 // provide a default function so that GraphNode can be instantiated
159 // directly for testing purposes.
160 virtual int checkTensorAttributes();
161
162 // Evalute the node/operator
163 virtual int eval();
164
165 int hasAllInputsReady() const;
166 int hasAllOutputsReady() const;
167
168 int dumpNode(FILE* out);
169 int dumpNode(std::ostream& out);
170
171 int setNodeMarked()
172 {
173 isMarked = true;
174 return 0;
175 }
176
177 int getNodeMarked() const
178 {
179 return isMarked;
180 }
181
182 int clearNodeMarked()
183 {
184 isMarked = false;
185 return 0;
186 }
187
188 int getEvalCount() const
189 {
190 return evalCount;
191 }
192
193 uint64_t getID() const
194 {
195 return nodeId;
196 }
197
198 std::vector<std::string>& getInputNames()
199 {
200 return inputNames;
201 }
202
203 std::vector<std::string>& getOutputNames()
204 {
205 return outputNames;
206 }
207
208 std::vector<Tensor*>& getOutputs()
209 {
210 return outputs;
211 }
212
213 std::vector<Tensor*>& getInputs()
214 {
215 return inputs;
216 }
217
218 int getOnNextNodeList() const
219 {
220 return onNextNodeList;
221 }
222
223 int setOnNextNodeList()
224 {
225 onNextNodeList = true;
226 return 0;
227 }
228
229 int clearOnNextNodeList()
230 {
231 onNextNodeList = false;
232 return 0;
233 }
234
235 tosa::Op getOp() const
236 {
237 return nodeType;
238 }
239
Jerry Ge9e94af82022-10-27 09:57:00 -0700240 SubgraphTraverser* getParentSGT()
241 {
242 return parent_sgt;
243 }
244
245 int setInMainBlock(bool isInMainBlock)
246 {
247 inMainBlock = isInMainBlock;
248 return 0;
249 }
250
251 bool getInMainBlock()
252 {
253 return inMainBlock;
254 }
255
TatWai Chongf7326092022-06-08 12:17:14 -0700256 // Helper functions.
257 int idiv_check(int input1, int input2, int& result);
258
Eric Kunzee5e26762020-10-13 16:11:07 -0700259protected:
260 // Print out a node validation error
261 int printNodeValidationError(const std::string& msg);
262
263 int setRequiredOperands(const int in, const int out)
264 {
265 requiredInputCount = in;
266 requiredOutputCount = out;
267 return 0;
268 }
269
270 int setRequiredRank(const int min, const int max = -1)
271 {
272 if (max == -1)
273 {
274 requiredRankMin = requiredRankMax = min;
275 }
276 else
277 {
278 requiredRankMin = min;
279 requiredRankMax = max;
280 }
281
282 ASSERT_MSG(requiredRankMin <= requiredRankMax,
283 "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
284 requiredRankMax);
285
286 return 0;
287 }
288
289 int validateRequiredOperands();
290 int validateRequiredRank(const Tensor* t);
291
Kevin Chengacb550f2021-06-29 15:32:19 -0700292 // Parent SubgraphTraverser
293 SubgraphTraverser* parent_sgt;
294
Eric Kunzee5e26762020-10-13 16:11:07 -0700295 // Description of the node type (e.g., CONST, CONV2D, etc...)
296 tosa::Op nodeType;
297
298 // A list of input tensor names
299 std::vector<std::string> inputNames;
300
301 // A list of the output tensor names
302 std::vector<std::string> outputNames;
303
304 // A list of the input tensors (after names have been matched up)
305 std::vector<Tensor*> inputs;
306
307 // A list of the output tensors (after names have been matched up)
308 std::vector<Tensor*> outputs;
309
310 // Unique node ID for debugging
311 uint64_t nodeId;
312
313 // Flag used for graph analysis
314 int isMarked;
315
316 // Number of times eval() has been called for this node
317 int evalCount;
318
319 // Flag indicating that this node is ready and is on the
320 // next-node list.
321 int onNextNodeList;
322
323 // Required input/output tensor counts for node validation
324 // -1 means any number is allowed
325 int requiredInputCount;
326 int requiredOutputCount;
327
328 // Required rank ranges for input/output tensors
329 // -1 means n/a
330 int requiredRankMin;
331 int requiredRankMax;
Jerry Ge9e94af82022-10-27 09:57:00 -0700332
333 bool inMainBlock;
Eric Kunzee5e26762020-10-13 16:11:07 -0700334};
335
336}; // namespace TosaReference
337
338#endif