blob: 5f859327e59db5cb7942fc6e185772098370e3a0 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
2// Copyright (c) 2020, ARM Limited.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
#include "graph_node.h"

#include <climits>
17
18using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
Kevin Chengacb550f2021-06-29 15:32:19 -070022GraphNode::GraphNode(SubgraphTraverser* parent_sgt_, const Op& nodeType_, const uint64_t id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070023{
Kevin Chengacb550f2021-06-29 15:32:19 -070024 parent_sgt = parent_sgt_;
25 nodeType = nodeType_;
26 nodeId = id_;
Eric Kunzee5e26762020-10-13 16:11:07 -070027 inputs.clear();
28 outputs.clear();
29 inputNames.clear();
30 outputNames.clear();
31 clearNodeMarked();
32 evalCount = 0;
33 clearOnNextNodeList();
34 setRequiredOperands(-1, -1);
35 setRequiredRank(-1);
36}
37
38GraphNode::~GraphNode()
39{}
40
41int GraphNode::addInputName(std::string& name)
42{
43 inputNames.push_back(name);
44 return 0;
45}
46
47int GraphNode::addOutputName(std::string& name)
48{
49 outputNames.push_back(name);
50 return 0;
51}
52
53int GraphNode::addInputTensor(Tensor* tens)
54{
55 ASSERT_MSG(tens, "GraphNode::addInputTensor: no tensor provided");
56 inputs.push_back(tens);
57 return 0;
58}
59
60int GraphNode::addOutputTensor(Tensor* tens)
61{
62 ASSERT_MSG(tens, "GraphNode::addOutputTensor: no tensor provided");
63 outputs.push_back(tens);
64 return 0;
65}
66
67int GraphNode::checkTensorAttributes()
68{
69 // Placeholder
70 return 0;
71}
72
73int GraphNode::eval()
74{
75 // Placeholder evaluation function
76 evalCount++;
77
78 // this should be set by derived op
79 for (auto ct : getOutputs())
80 {
81 ct->setIsValid();
82 }
83
84 return 0;
85}
86
87int GraphNode::hasAllInputsReady() const
88{
89 for (size_t i = 0; i < inputs.size(); i++)
90 {
91 if (!inputs[i]->getIsValid())
92 return false;
93 }
94
95 return true;
96}
97
98int GraphNode::hasAllOutputsReady() const
99{
100 for (size_t i = 0; i < outputs.size(); i++)
101 {
102 if (!outputs[i]->getIsValid())
103 return false;
104 }
105
106 return true;
107}
108
109int GraphNode::dumpNode(FILE* out)
110{
111 int i;
112 fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType],
113 nodeId, evalCount, onNextNodeList, isMarked);
114
115 i = 0;
116 for (Tensor* ins : inputs)
117 {
118 fprintf(out, " Input[%d] ", i++);
119 ins->dumpTensorParams(out);
120 }
121
122 i = 0;
123 for (Tensor* outs : outputs)
124 {
125 fprintf(out, " Output[%d] ", i++);
126 outs->dumpTensorParams(out);
127 }
128
129 return 0;
130}
131
132int GraphNode::dumpNode(std::ostream& out)
133{
134 int i;
135
136 out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount
137 << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl;
138
139 out << " Inputs:";
140 for (std::string& name : inputNames)
141 {
142 out << " " << name;
143 }
144 out << std::endl;
145
146 i = 0;
147 for (Tensor* ins : inputs)
148 {
149 out << " Input[" << i++ << "]: ";
150 ins->dumpTensorParams(out);
151 }
152
153 out << " Outputs:";
154 for (std::string& name : outputNames)
155 {
156 out << " " << name;
157 }
158 out << std::endl;
159
160 i = 0;
161 for (Tensor* outs : outputs)
162 {
163 out << " Output[" << i++ << "]: ";
164 outs->dumpTensorParams(out);
165 }
166 return 0;
167}
168
169int GraphNode::printNodeValidationError(const std::string& msg)
170{
171 std::cout << "Operator validation error: " << msg << std::endl;
172 ;
173 dumpNode(std::cout);
174
175 return 0;
176}
177
178int GraphNode::validateRequiredOperands()
179{
180 if (requiredInputCount >= 0 && inputs.size() != (size_t)requiredInputCount)
181 {
182 printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator must have " +
183 std::to_string(requiredInputCount) + " input(s)");
184 return 1;
185 }
186
187 if (requiredOutputCount >= 0 && outputs.size() != (size_t)requiredOutputCount)
188 {
189 printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator output must have exactly " +
190 std::to_string(requiredOutputCount) + " output(s)");
191 return 1;
192 }
193
194 return 0;
195}
196
197int GraphNode::validateRequiredRank(const Tensor* t)
198{
199 if (requiredRankMin >= 0 && requiredRankMax >= 0)
200 {
Kevin Cheng6097c3d2021-09-23 15:25:24 -0700201 std::string err_message = std::string(EnumNamesOp()[nodeType]) +
202 " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" +
203 std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) +
204 "]. tensorName: " + t->getName();
205 ERROR_IF(t->checkRequiredRank(requiredRankMin, requiredRankMax), "%s", err_message.c_str());
206
207 return 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700208 }
209
210 if (requiredRankMin >= 0)
211 {
Kevin Cheng6097c3d2021-09-23 15:25:24 -0700212 std::string err_message = std::string(EnumNamesOp()[nodeType]) +
213 " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " +
214 std::to_string(requiredRankMin) + ". tensorName: " + t->getName();
215 ERROR_IF(t->checkRequiredRank(requiredRankMin), "%s", err_message.c_str());
216
217 return 0;
Eric Kunzee5e26762020-10-13 16:11:07 -0700218 }
219
220 return 0;
221}
TatWai Chongf7326092022-06-08 12:17:14 -0700222
223int GraphNode::idiv_check(int input1, int input2, int& result)
224{
225 ERROR_IF(input2 == 0, "idiv_check: input2 must not be zero");
226 ERROR_IF(input1 % input2 != 0, "idiv_check: input1 must be a multiple of input2");
227
228 result = input1 / input2;
229 return 0;
230}