blob: 1781e40543ac221919327d73bdbd8781bf12c20c [file] [log] [blame]

// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
#include "graph_node.h"

#include <limits>

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
21
Kevin Chengacb550f2021-06-29 15:32:19 -070022GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const uint64_t id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070023{
Kevin Chengacb550f2021-06-29 15:32:19 -070024 parent_sgt = parent_sgt_;
25 nodeType = nodeType_;
26 nodeId = id_;
Eric Kunzee5e26762020-10-13 16:11:07 -070027 inputs.clear();
28 outputs.clear();
29 inputNames.clear();
30 outputNames.clear();
31 clearNodeMarked();
32 evalCount = 0;
33 clearOnNextNodeList();
34 setRequiredOperands(-1, -1);
35 setRequiredRank(-1);
Jerry Ge9e94af82022-10-27 09:57:00 -070036 inMainBlock = false;
Eric Kunzee5e26762020-10-13 16:11:07 -070037}
38
39GraphNode::~GraphNode()
40{}
41
42int GraphNode::addInputName(std::string& name)
43{
44 inputNames.push_back(name);
45 return 0;
46}
47
48int GraphNode::addOutputName(std::string& name)
49{
50 outputNames.push_back(name);
51 return 0;
52}
53
54int GraphNode::addInputTensor(Tensor* tens)
55{
56 ASSERT_MSG(tens, "GraphNode::addInputTensor: no tensor provided");
57 inputs.push_back(tens);
58 return 0;
59}
60
61int GraphNode::addOutputTensor(Tensor* tens)
62{
63 ASSERT_MSG(tens, "GraphNode::addOutputTensor: no tensor provided");
64 outputs.push_back(tens);
65 return 0;
66}
67
68int GraphNode::checkTensorAttributes()
69{
70 // Placeholder
71 return 0;
72}
73
74int GraphNode::eval()
75{
76 // Placeholder evaluation function
77 evalCount++;
78
79 // this should be set by derived op
80 for (auto ct : getOutputs())
81 {
82 ct->setIsValid();
83 }
84
85 return 0;
86}
87
88int GraphNode::hasAllInputsReady() const
89{
90 for (size_t i = 0; i < inputs.size(); i++)
91 {
92 if (!inputs[i]->getIsValid())
93 return false;
94 }
95
96 return true;
97}
98
99int GraphNode::hasAllOutputsReady() const
100{
101 for (size_t i = 0; i < outputs.size(); i++)
102 {
103 if (!outputs[i]->getIsValid())
104 return false;
105 }
106
107 return true;
108}
109
110int GraphNode::dumpNode(FILE* out)
111{
112 int i;
113 fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType],
114 nodeId, evalCount, onNextNodeList, isMarked);
115
116 i = 0;
117 for (Tensor* ins : inputs)
118 {
119 fprintf(out, " Input[%d] ", i++);
120 ins->dumpTensorParams(out);
121 }
122
123 i = 0;
124 for (Tensor* outs : outputs)
125 {
126 fprintf(out, " Output[%d] ", i++);
127 outs->dumpTensorParams(out);
128 }
129
130 return 0;
131}
132
133int GraphNode::dumpNode(std::ostream& out)
134{
135 int i;
136
137 out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount
138 << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl;
139
140 out << " Inputs:";
141 for (std::string& name : inputNames)
142 {
143 out << " " << name;
144 }
145 out << std::endl;
146
147 i = 0;
148 for (Tensor* ins : inputs)
149 {
150 out << " Input[" << i++ << "]: ";
151 ins->dumpTensorParams(out);
152 }
153
154 out << " Outputs:";
155 for (std::string& name : outputNames)
156 {
157 out << " " << name;
158 }
159 out << std::endl;
160
161 i = 0;
162 for (Tensor* outs : outputs)
163 {
164 out << " Output[" << i++ << "]: ";
165 outs->dumpTensorParams(out);
166 }
167 return 0;
168}
169
170int GraphNode::printNodeValidationError(const std::string& msg)
171{
172 std::cout << "Operator validation error: " << msg << std::endl;
173 ;
174 dumpNode(std::cout);
175
176 return 0;
177}
178
179int GraphNode::validateRequiredOperands()
180{
181 if (requiredInputCount >= 0 && inputs.size() != (size_t)requiredInputCount)
182 {
183 printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator must have " +
184 std::to_string(requiredInputCount) + " input(s)");
185 return 1;
186 }
187
188 if (requiredOutputCount >= 0 && outputs.size() != (size_t)requiredOutputCount)
189 {
190 printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator output must have exactly " +
191 std::to_string(requiredOutputCount) + " output(s)");
192 return 1;
193 }
194
195 return 0;
196}
197
198int GraphNode::validateRequiredRank(const Tensor* t)
199{
200 if (requiredRankMin >= 0 && requiredRankMax >= 0)
201 {
Kevin Cheng6097c3d2021-09-23 15:25:24 -0700202 std::string err_message = std::string(EnumNamesOp()[nodeType]) +
203 " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" +
204 std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) +
205 "]. tensorName: " + t->getName();
206 ERROR_IF(t->checkRequiredRank(requiredRankMin, requiredRankMax), "%s", err_message.c_str());
207
208 return 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700209 }
210
211 if (requiredRankMin >= 0)
212 {
Kevin Cheng6097c3d2021-09-23 15:25:24 -0700213 std::string err_message = std::string(EnumNamesOp()[nodeType]) +
214 " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " +
215 std::to_string(requiredRankMin) + ". tensorName: " + t->getName();
216 ERROR_IF(t->checkRequiredRank(requiredRankMin), "%s", err_message.c_str());
217
218 return 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700219 }
220
221 return 0;
222}
TatWai Chongf7326092022-06-08 12:17:14 -0700223
224int GraphNode::idiv_check(int input1, int input2, int& result)
225{
226 ERROR_IF(input2 == 0, "idiv_check: input2 must not be zero");
227 ERROR_IF(input1 % input2 != 0, "idiv_check: input1 must be a multiple of input2");
228
229 result = input1 / input2;
230 return 0;
231}