blob: 942652dd7e65eca59e5e7513264c5e85a701ffa9 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "control_flow.h"
17#include "subgraph_traverser.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070018using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
Kevin Chengacb550f2021-06-29 15:32:19 -070022OpControlFlow::OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_)
23 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070024{
25 tsh = tsh_;
26}
27
28OpControlFlow::~OpControlFlow()
29{}
30
31int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
32 std::vector<TosaReference::Tensor*>& block_inputs,
33 std::vector<TosaReference::Tensor*>& block_outputs)
34{
35 std::string block_name = block->GetName();
36
37 DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
38
Jerry Ge9e94af82022-10-27 09:57:00 -070039 SubgraphTraverser block_sgt(block, tsh, this->parent_sgt);
Eric Kunzee5e26762020-10-13 16:11:07 -070040
Kevin Cheng5d00c692021-10-15 20:06:00 +000041 ERROR_IF(block_sgt.initializeGraph(), "evalBlock(): Unable to initialize graph traverser for %s",
42 block_name.c_str());
43 ERROR_IF(block_sgt.linkTensorsAndNodes(), "evalBlock(): Failed to link tensors and nodes for %s",
44 block_name.c_str());
45 ERROR_IF(block_sgt.validateGraph(), "evalBlock(): Failed to validate subgraph for %s", block_name.c_str());
46 ERROR_IF(block_sgt.allocateTensor(), "evalBlock(): Failed to allocate tensor for %s", block_name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -070047
Kevin Cheng5d00c692021-10-15 20:06:00 +000048 int num_input_tensors = block_sgt.getNumInputTensors();
49 int num_output_tensors = block_sgt.getNumOutputTensors();
Eric Kunzee5e26762020-10-13 16:11:07 -070050
51 for (size_t i = 0; i < block_inputs.size(); i++)
52 {
53 DEBUG_HIGH(OP, "Input[%ld]: %s", i, block_inputs[i]->getName().c_str());
54 }
55 for (size_t i = 0; i < block_outputs.size(); i++)
56 {
57 DEBUG_HIGH(OP, "Output[%ld]: %s", i, block_outputs[i]->getName().c_str());
58 }
59
60 ASSERT_MSG((size_t)num_input_tensors == block_inputs.size(),
61 "op block %s inputs[%lu] does not match with graph traverser's inputs[%d]", block_name.c_str(),
62 block_inputs.size(), num_input_tensors);
63 ASSERT_MSG((size_t)num_output_tensors == block_outputs.size(),
64 "op block %s outputs[%lu] does not match with graph traverser's outputs[%d]", block_name.c_str(),
65 block_outputs.size(), num_output_tensors);
66
67 // set graph traverser's input = basic block's input
68 for (int i = 0; i < num_input_tensors; i++)
69 {
Kevin Cheng5d00c692021-10-15 20:06:00 +000070 TosaReference::Tensor* tensor = block_sgt.getInputTensor(i);
71 ERROR_IF(!tensor->is_allocated(), "block %s input tensor %s are not initialized before use", block_name.c_str(),
72 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -070073
74 if (tensor->copyValueFrom(block_inputs[i]))
75 {
76 WARNING("Fail to copy tensor value %s -> %s", block_inputs[i]->getName().c_str(),
77 tensor->getName().c_str());
78 return 1;
79 }
80
Kevin Cheng14d7f7a2021-05-12 10:44:49 -070081 tensor->setIsValid();
82
Eric Kunzee5e26762020-10-13 16:11:07 -070083 // Push ready consumers to the next node list
84 for (auto gn : tensor->getConsumers())
85 {
86 if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
87 {
Kevin Cheng5d00c692021-10-15 20:06:00 +000088 block_sgt.addToNextNodeList(gn);
Eric Kunzee5e26762020-10-13 16:11:07 -070089 }
90 }
91 }
92
Kevin Cheng5d00c692021-10-15 20:06:00 +000093 ERROR_IF(block_sgt.evaluateAll(), "Error evaluating network. Giving up.");
94
95 // pass block status back
96 switch (block_sgt.getGraphStatus())
97 {
98 case GraphStatus::TOSA_VALID:
99 {
100 DEBUG_MED(OP, "Successfully evaluating block %s", block_name.c_str());
101 break;
102 }
103 case GraphStatus::TOSA_UNPREDICTABLE:
104 {
105 DEBUG_MED(OP, "Finish evaluating block %s but result is UNPREDICTABLE", block_name.c_str());
106 DEBUG_MED(OP, "Setting parent graph status to UNPREDICTABLE");
107 parent_sgt->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
108 break;
109 }
110 case GraphStatus::TOSA_ERROR:
111 {
112 DEBUG_MED(OP, "Fail evaluating block %s. Result is ERROR", block_name.c_str());
113 if (parent_sgt->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE)
114 {
115 DEBUG_MED(OP, "Setting parent graph status to ERROR");
116 parent_sgt->setGraphStatus(GraphStatus::TOSA_ERROR);
117 return 1;
118 }
119 }
120 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700121
122 // make sure output tensor is evaluated and show its value
123 bool all_output_valid = true;
124 for (int i = 0; i < num_output_tensors; i++)
125 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000126 const TosaReference::Tensor* ct = block_sgt.getOutputTensor(i);
Eric Kunzee5e26762020-10-13 16:11:07 -0700127 ASSERT_MEM(ct);
128 if (!ct->getIsValid())
129 {
130 ct->dumpTensorParams(g_func_debug.func_debug_file);
131 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
132 {
133 ct->dumpTensor(g_func_debug.func_debug_file);
134 }
135 all_output_valid = false;
136 }
137 }
138 if (!all_output_valid)
139 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000140 block_sgt.dumpGraph(g_func_debug.func_debug_file);
Kevin Cheng903763c2021-09-28 16:14:52 -0700141 ERROR_IF(true, "SubgraphTraverser \"%s\" error: Output tensors are not all valid at the end of evaluation.",
142 block_name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700143 }
144
145 // set basic block's output = subgraph_traverser's output
146 for (int i = 0; i < num_output_tensors; i++)
147 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000148 TosaReference::Tensor* tensor = block_sgt.getOutputTensor(i);
149 ERROR_IF(!tensor->is_allocated(), "block %s input tensor %s are not initialized before use", block_name.c_str(),
150 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700151
152 if (block_outputs[i]->copyValueFrom(tensor))
153 {
154 WARNING("Fail to copy tensor value %s -> %s", tensor->getName().c_str(), outputs[i]->getName().c_str());
155 return 1;
156 }
157 }
158 return 0;
159}
160
Kevin Chengacb550f2021-06-29 15:32:19 -0700161OpCondIf::OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
162 : OpControlFlow(sgt_, tsh_, Op_COND_IF, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700163{
164 INIT_ATTRIBUTE(CondIf);
165}
166
167OpCondIf::~OpCondIf()
168{
169 if (attribute)
170 delete attribute;
171}
172
173int OpCondIf::checkTensorAttributes()
174{
Kevin Cheng5d00c692021-10-15 20:06:00 +0000175 ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
Eric Kunzee5e26762020-10-13 16:11:07 -0700176
Kevin Cheng5d00c692021-10-15 20:06:00 +0000177 ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0,
178 "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
179 inputs[0]->getRank());
Eric Kunzee5e26762020-10-13 16:11:07 -0700180
181 cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
182 ASSERT_MEM(cond);
183
Jerry Ge9e94af82022-10-27 09:57:00 -0700184 auto region_name = getParentSGT()->getRegionName();
185 auto curr_region = tsh->GetRegionByName(region_name);
186 then_block = curr_region->GetBlockByName(attribute->then_branch());
187 else_block = curr_region->GetBlockByName(attribute->else_branch());
Eric Kunzee5e26762020-10-13 16:11:07 -0700188
Kevin Cheng5d00c692021-10-15 20:06:00 +0000189 ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
190
191 ERROR_IF(!else_block, "OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str());
192
193 // Make sure operator input/output matches block input/output
194 // Skip the first rank 0 bool tensor on input list
195 int32_t num_input_tensor = getInputs().size() - 1;
196 int32_t num_output_tensor = getOutputs().size();
Jerry Ge9e94af82022-10-27 09:57:00 -0700197
Kevin Cheng5d00c692021-10-15 20:06:00 +0000198 ERROR_IF((int32_t)then_block->GetInputs().size() != num_input_tensor,
199 "OpCondIf: then_block has unexpected number of input");
200 ERROR_IF((int32_t)else_block->GetInputs().size() != num_input_tensor,
201 "OpCondIf: else_block has unexpected number of input");
202 ERROR_IF((int32_t)then_block->GetOutputs().size() != num_output_tensor,
203 "OpCondIf: then_block has unexpected number of output");
204 ERROR_IF((int32_t)else_block->GetOutputs().size() != num_output_tensor,
205 "OpCondIf: else_block has unexpected number of output");
206
207 for (int32_t i = 0; i < num_input_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700208 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000209 Tensor* operator_input = getInputs()[i + 1];
210 std::string then_block_input_name = then_block->GetInputs()[i];
211 std::string else_block_input_name = else_block->GetInputs()[i];
212 TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name);
213 TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name);
214 ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(),
215 "OpCondIf: input tensor type mismatch with then_block input type");
216 ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(),
217 "OpCondIf: input tensor type mismatch with else_block input type");
218 ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(),
219 "OpCondIf: input tensor rank mismatch with then_block input rank");
220 ERROR_IF(operator_input->getRank() != (int32_t)else_block_input->GetShape().size(),
221 "OpCondIf: input tensor rank mismatch with else_block input rank");
222 for (int32_t d = 0; d < operator_input->getRank(); d++)
223 {
224 ERROR_IF(operator_input->getShape()[d] != then_block_input->GetShape()[d],
225 "OpCondIf: input tensor dimension mismatch with then_block input dimension");
226 ERROR_IF(operator_input->getShape()[d] != else_block_input->GetShape()[d],
227 "OpCondIf: input tensor dimension mismatch with else_block input dimension");
228 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700229 }
230
Kevin Cheng5d00c692021-10-15 20:06:00 +0000231 for (int32_t i = 0; i < num_output_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700232 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000233 Tensor* operator_output = getOutputs()[i];
234 std::string then_block_output_name = then_block->GetOutputs()[i];
235 std::string else_block_output_name = else_block->GetOutputs()[i];
236 TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name);
237 TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name);
238 ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(),
239 "OpCondIf: output tensor type mismatch with then_block output type");
240 ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(),
241 "OpCondIf: output tensor type mismatch with else_block output type");
242 ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(),
243 "OpCondIf: output tensor rank mismatch with then_block output rank");
244 ERROR_IF(operator_output->getRank() != (int32_t)else_block_output->GetShape().size(),
245 "OpCondIf: output tensor rank mismatch with else_block output rank");
246 for (int32_t d = 0; d < operator_output->getRank(); d++)
247 {
248 ERROR_IF(operator_output->getShape()[d] != then_block_output->GetShape()[d],
249 "OpCondIf: output tensor dimension mismatch with then_block output dimension");
250 ERROR_IF(operator_output->getShape()[d] != else_block_output->GetShape()[d],
251 "OpCondIf: output tensor dimension mismatch with else_block output dimension");
252 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700253 }
254
255 return 0;
256}
257
258int OpCondIf::eval()
259{
260 bool cond_val = cond->getTensor()(0);
261 std::vector<TosaReference::Tensor*> block_inputs(getInputs().begin() + 1, getInputs().end());
262
263 if (cond_val)
264 {
265 if (evalBlock(then_block, block_inputs, getOutputs()))
266 {
267 WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str());
268 return 1;
269 }
270 }
271 else
272 {
273 if (evalBlock(else_block, block_inputs, getOutputs()))
274 {
275 WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str());
276 return 1;
277 }
278 }
279
280 return GraphNode::eval();
281}
282
Kevin Chengacb550f2021-06-29 15:32:19 -0700283OpWhileLoop::OpWhileLoop(SubgraphTraverser* sgt_,
284 TosaSerializationHandler* tsh_,
285 TosaAttributeBase* attribute_,
286 uint64_t id_)
287 : OpControlFlow(sgt_, tsh_, Op_WHILE_LOOP, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700288{
289 INIT_ATTRIBUTE(WhileLoop);
290}
291
292OpWhileLoop::~OpWhileLoop()
293{
294 if (attribute)
295 delete attribute;
296}
297
298int OpWhileLoop::checkTensorAttributes()
299{
300 if (getInputs().size() <= 0)
301 {
302 WARNING("OpWhileLoop: must have at least 1 operands");
303 return 1;
304 }
305
306 if (getInputs().size() != getOutputs().size())
307 {
308 WARNING("OpWhileLoop: inputs and outputs size must match");
309 return 1;
310 }
311
Jerry Ge9e94af82022-10-27 09:57:00 -0700312 auto region_name = getParentSGT()->getRegionName();
313 auto curr_region = tsh->GetRegionByName(region_name);
314 cond_block = curr_region->GetBlockByName(attribute->cond_branch());
315 body_block = curr_region->GetBlockByName(attribute->body_branch());
Eric Kunzee5e26762020-10-13 16:11:07 -0700316
Kevin Cheng5d00c692021-10-15 20:06:00 +0000317 ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
318 ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
319
320 // Make sure operator input/output matches block input/output
321 int32_t num_block_tensor = getInputs().size();
322 ERROR_IF((int32_t)getOutputs().size() != num_block_tensor,
323 "OpWhileLoop: operator input tensor doesn't match output");
324 ERROR_IF((int32_t)cond_block->GetInputs().size() != num_block_tensor,
325 "OpWhileLoop: cond_block has unexpected number of input");
326 ERROR_IF((int32_t)body_block->GetInputs().size() != num_block_tensor,
327 "OpWhileLoop: body_block has unexpected number of input");
328 ERROR_IF((int32_t)body_block->GetOutputs().size() != num_block_tensor,
329 "OpWhileLoop: body_block has unexpected number of output");
330 for (int32_t i = 0; i < num_block_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700331 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000332 Tensor* operator_input = getInputs()[i];
333 Tensor* operator_output = getOutputs()[i];
334 ERROR_IF(operator_input->matchRankTypeShape(*operator_output),
335 "OpWhileLoop: operator input tensor mismatch operator output tensor");
336
337 std::string cond_block_input_name = cond_block->GetInputs()[i];
338 std::string body_block_input_name = body_block->GetInputs()[i];
339 std::string body_block_output_name = body_block->GetOutputs()[i];
340 TosaSerializationTensor* cond_block_input = cond_block->GetTensorByName(cond_block_input_name);
341 TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name);
342 TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name);
343
344 ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(),
345 "OpWhileLoop: input tensor type mismatch with cond_block input type");
346 ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(),
347 "OpWhileLoop: input tensor type mismatch with body_block input type");
348 ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(),
349 "OpWhileLoop: input tensor type mismatch with body_block output type");
350 ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(),
351 "OpWhileLoop: input tensor rank mismatch with cond_block input rank");
352 ERROR_IF(operator_input->getRank() != (int32_t)body_block_input->GetShape().size(),
353 "OpWhileLoop: input tensor rank mismatch with body_block input rank");
354 ERROR_IF(operator_input->getRank() != (int32_t)body_block_output->GetShape().size(),
355 "OpWhileLoop: input tensor rank mismatch with body_block output rank");
356
357 for (int32_t d = 0; d < operator_input->getRank(); d++)
358 {
359 ERROR_IF(operator_input->getShape()[d] != cond_block_input->GetShape()[d],
360 "OpWhileLoop: input tensor dimension mismatch with cond_block input dimension");
361 ERROR_IF(operator_input->getShape()[d] != body_block_input->GetShape()[d],
362 "OpWhileLoop: input tensor dimension mismatch with body_block input dimension");
363 ERROR_IF(operator_input->getShape()[d] != body_block_output->GetShape()[d],
364 "OpWhileLoop: input tensor dimension mismatch with body_block output dimension");
365 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700366 }
367
Kevin Cheng5d00c692021-10-15 20:06:00 +0000368 ERROR_IF(cond_block->GetOutputs().size() != 1, "OpWhileLoop: cond_block can only have 1 output tensor");
369 std::string cond_block_output_name = cond_block->GetOutputs()[0];
370 TosaSerializationTensor* cond_block_output = cond_block->GetTensorByName(cond_block_output_name);
371 ERROR_IF(cond_block_output->GetDtype() != DType_BOOL, "OpWhileLoop: cond_block output can only be bool type");
372 ERROR_IF(cond_block_output->GetShape().size() != 0, "OpWhileLoop: cond_block output can only be rank 0");
Eric Kunzee5e26762020-10-13 16:11:07 -0700373
374 return 0;
375}
376
377int OpWhileLoop::eval()
378{
379
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700380 TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({}));
Eric Kunzee5e26762020-10-13 16:11:07 -0700381
382 cond_output_ctensor.allocate();
383 std::vector<TosaReference::Tensor*> cond_block_outputs;
384 cond_block_outputs.push_back(&cond_output_ctensor);
385
386 size_t num_input_output = getInputs().size();
387 size_t eval_count = 0;
388
389 while (eval_count++ < MAX_WHILE_LOOP_ITERATION)
390 {
391 if (evalBlock(cond_block, getInputs(), cond_block_outputs))
392 {
393 WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str());
394 return 1;
395 }
396 bool cond_val = cond_output_ctensor.getTensor()(0);
397 DEBUG_HIGH(OP, "Conditional block value: %d", cond_val);
398
399 if (cond_val)
400 {
401 if (evalBlock(body_block, getInputs(), getOutputs()))
402 {
403 WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str());
404 return 1;
405 }
406
407 // assigning output tensors value back to input tensors value for next iteration
408 for (size_t i = 0; i < num_input_output; i++)
409 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700410 getInputs()[i] = getOutputs()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700411 }
412 }
413 else
414 {
415 // in last iteration or the case it never evaluates body block
416 // assign input tensors value to output tensors
417 for (size_t i = 0; i < num_input_output; i++)
418 {
419 if (getOutputs()[i]->copyValueFrom(getInputs()[i]))
420 {
421 WARNING("Fail to copy tensor value %s -> %s", getInputs()[i]->getName().c_str(),
422 getOutputs()[i]->getName().c_str());
423 return 1;
424 }
425 }
426 break;
427 }
428 }
429
430 return GraphNode::eval();
431}