blob: 7105caf1bb66c77976639f904928e5fde3cfa9c5 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16#include "control_flow.h"
17#include "subgraph_traverser.h"
18
19using namespace TosaReference;
20using namespace Eigen;
21using namespace tosa;
22
Kevin Chengacb550f2021-06-29 15:32:19 -070023OpControlFlow::OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_)
24 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070025{
26 tsh = tsh_;
27}
28
29OpControlFlow::~OpControlFlow()
30{}
31
32int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
33 std::vector<TosaReference::Tensor*>& block_inputs,
34 std::vector<TosaReference::Tensor*>& block_outputs)
35{
36 std::string block_name = block->GetName();
37
38 DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
39
Kevin Cheng5d00c692021-10-15 20:06:00 +000040 SubgraphTraverser block_sgt(block, tsh);
Eric Kunzee5e26762020-10-13 16:11:07 -070041
Kevin Cheng5d00c692021-10-15 20:06:00 +000042 ERROR_IF(block_sgt.initializeGraph(), "evalBlock(): Unable to initialize graph traverser for %s",
43 block_name.c_str());
44 ERROR_IF(block_sgt.linkTensorsAndNodes(), "evalBlock(): Failed to link tensors and nodes for %s",
45 block_name.c_str());
46 ERROR_IF(block_sgt.validateGraph(), "evalBlock(): Failed to validate subgraph for %s", block_name.c_str());
47 ERROR_IF(block_sgt.allocateTensor(), "evalBlock(): Failed to allocate tensor for %s", block_name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -070048
Kevin Cheng5d00c692021-10-15 20:06:00 +000049 int num_input_tensors = block_sgt.getNumInputTensors();
50 int num_output_tensors = block_sgt.getNumOutputTensors();
Eric Kunzee5e26762020-10-13 16:11:07 -070051
52 for (size_t i = 0; i < block_inputs.size(); i++)
53 {
54 DEBUG_HIGH(OP, "Input[%ld]: %s", i, block_inputs[i]->getName().c_str());
55 }
56 for (size_t i = 0; i < block_outputs.size(); i++)
57 {
58 DEBUG_HIGH(OP, "Output[%ld]: %s", i, block_outputs[i]->getName().c_str());
59 }
60
61 ASSERT_MSG((size_t)num_input_tensors == block_inputs.size(),
62 "op block %s inputs[%lu] does not match with graph traverser's inputs[%d]", block_name.c_str(),
63 block_inputs.size(), num_input_tensors);
64 ASSERT_MSG((size_t)num_output_tensors == block_outputs.size(),
65 "op block %s outputs[%lu] does not match with graph traverser's outputs[%d]", block_name.c_str(),
66 block_outputs.size(), num_output_tensors);
67
68 // set graph traverser's input = basic block's input
69 for (int i = 0; i < num_input_tensors; i++)
70 {
Kevin Cheng5d00c692021-10-15 20:06:00 +000071 TosaReference::Tensor* tensor = block_sgt.getInputTensor(i);
72 ERROR_IF(!tensor->is_allocated(), "block %s input tensor %s are not initialized before use", block_name.c_str(),
73 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -070074
75 if (tensor->copyValueFrom(block_inputs[i]))
76 {
77 WARNING("Fail to copy tensor value %s -> %s", block_inputs[i]->getName().c_str(),
78 tensor->getName().c_str());
79 return 1;
80 }
81
Kevin Cheng14d7f7a2021-05-12 10:44:49 -070082 tensor->setIsValid();
83
Eric Kunzee5e26762020-10-13 16:11:07 -070084 // Push ready consumers to the next node list
85 for (auto gn : tensor->getConsumers())
86 {
87 if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
88 {
Kevin Cheng5d00c692021-10-15 20:06:00 +000089 block_sgt.addToNextNodeList(gn);
Eric Kunzee5e26762020-10-13 16:11:07 -070090 }
91 }
92 }
93
Kevin Cheng5d00c692021-10-15 20:06:00 +000094 ERROR_IF(block_sgt.evaluateAll(), "Error evaluating network. Giving up.");
95
96 // pass block status back
97 switch (block_sgt.getGraphStatus())
98 {
99 case GraphStatus::TOSA_VALID:
100 {
101 DEBUG_MED(OP, "Successfully evaluating block %s", block_name.c_str());
102 break;
103 }
104 case GraphStatus::TOSA_UNPREDICTABLE:
105 {
106 DEBUG_MED(OP, "Finish evaluating block %s but result is UNPREDICTABLE", block_name.c_str());
107 DEBUG_MED(OP, "Setting parent graph status to UNPREDICTABLE");
108 parent_sgt->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
109 break;
110 }
111 case GraphStatus::TOSA_ERROR:
112 {
113 DEBUG_MED(OP, "Fail evaluating block %s. Result is ERROR", block_name.c_str());
114 if (parent_sgt->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE)
115 {
116 DEBUG_MED(OP, "Setting parent graph status to ERROR");
117 parent_sgt->setGraphStatus(GraphStatus::TOSA_ERROR);
118 return 1;
119 }
120 }
121 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700122
123 // make sure output tensor is evaluated and show its value
124 bool all_output_valid = true;
125 for (int i = 0; i < num_output_tensors; i++)
126 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000127 const TosaReference::Tensor* ct = block_sgt.getOutputTensor(i);
Eric Kunzee5e26762020-10-13 16:11:07 -0700128 ASSERT_MEM(ct);
129 if (!ct->getIsValid())
130 {
131 ct->dumpTensorParams(g_func_debug.func_debug_file);
132 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
133 {
134 ct->dumpTensor(g_func_debug.func_debug_file);
135 }
136 all_output_valid = false;
137 }
138 }
139 if (!all_output_valid)
140 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000141 block_sgt.dumpGraph(g_func_debug.func_debug_file);
Kevin Cheng903763c2021-09-28 16:14:52 -0700142 ERROR_IF(true, "SubgraphTraverser \"%s\" error: Output tensors are not all valid at the end of evaluation.",
143 block_name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700144 }
145
146 // set basic block's output = subgraph_traverser's output
147 for (int i = 0; i < num_output_tensors; i++)
148 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000149 TosaReference::Tensor* tensor = block_sgt.getOutputTensor(i);
150 ERROR_IF(!tensor->is_allocated(), "block %s input tensor %s are not initialized before use", block_name.c_str(),
151 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700152
153 if (block_outputs[i]->copyValueFrom(tensor))
154 {
155 WARNING("Fail to copy tensor value %s -> %s", tensor->getName().c_str(), outputs[i]->getName().c_str());
156 return 1;
157 }
158 }
159 return 0;
160}
161
Kevin Chengacb550f2021-06-29 15:32:19 -0700162OpCondIf::OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
163 : OpControlFlow(sgt_, tsh_, Op_COND_IF, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700164{
165 INIT_ATTRIBUTE(CondIf);
166}
167
168OpCondIf::~OpCondIf()
169{
170 if (attribute)
171 delete attribute;
172}
173
174int OpCondIf::checkTensorAttributes()
175{
Kevin Cheng5d00c692021-10-15 20:06:00 +0000176 ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
Eric Kunzee5e26762020-10-13 16:11:07 -0700177
Kevin Cheng5d00c692021-10-15 20:06:00 +0000178 ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0,
179 "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
180 inputs[0]->getRank());
Eric Kunzee5e26762020-10-13 16:11:07 -0700181
182 cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
183 ASSERT_MEM(cond);
184
185 then_block = tsh->GetBlockByName(attribute->then_branch());
186 else_block = tsh->GetBlockByName(attribute->else_branch());
187
Kevin Cheng5d00c692021-10-15 20:06:00 +0000188 ERROR_IF(!then_block, "OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
189
190 ERROR_IF(!else_block, "OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str());
191
192 // Make sure operator input/output matches block input/output
193 // Skip the first rank 0 bool tensor on input list
194 int32_t num_input_tensor = getInputs().size() - 1;
195 int32_t num_output_tensor = getOutputs().size();
196 ERROR_IF((int32_t)then_block->GetInputs().size() != num_input_tensor,
197 "OpCondIf: then_block has unexpected number of input");
198 ERROR_IF((int32_t)else_block->GetInputs().size() != num_input_tensor,
199 "OpCondIf: else_block has unexpected number of input");
200 ERROR_IF((int32_t)then_block->GetOutputs().size() != num_output_tensor,
201 "OpCondIf: then_block has unexpected number of output");
202 ERROR_IF((int32_t)else_block->GetOutputs().size() != num_output_tensor,
203 "OpCondIf: else_block has unexpected number of output");
204
205 for (int32_t i = 0; i < num_input_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700206 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000207 Tensor* operator_input = getInputs()[i + 1];
208 std::string then_block_input_name = then_block->GetInputs()[i];
209 std::string else_block_input_name = else_block->GetInputs()[i];
210 TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name);
211 TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name);
212 ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(),
213 "OpCondIf: input tensor type mismatch with then_block input type");
214 ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(),
215 "OpCondIf: input tensor type mismatch with else_block input type");
216 ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(),
217 "OpCondIf: input tensor rank mismatch with then_block input rank");
218 ERROR_IF(operator_input->getRank() != (int32_t)else_block_input->GetShape().size(),
219 "OpCondIf: input tensor rank mismatch with else_block input rank");
220 for (int32_t d = 0; d < operator_input->getRank(); d++)
221 {
222 ERROR_IF(operator_input->getShape()[d] != then_block_input->GetShape()[d],
223 "OpCondIf: input tensor dimension mismatch with then_block input dimension");
224 ERROR_IF(operator_input->getShape()[d] != else_block_input->GetShape()[d],
225 "OpCondIf: input tensor dimension mismatch with else_block input dimension");
226 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700227 }
228
Kevin Cheng5d00c692021-10-15 20:06:00 +0000229 for (int32_t i = 0; i < num_output_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700230 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000231 Tensor* operator_output = getOutputs()[i];
232 std::string then_block_output_name = then_block->GetOutputs()[i];
233 std::string else_block_output_name = else_block->GetOutputs()[i];
234 TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name);
235 TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name);
236 ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(),
237 "OpCondIf: output tensor type mismatch with then_block output type");
238 ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(),
239 "OpCondIf: output tensor type mismatch with else_block output type");
240 ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(),
241 "OpCondIf: output tensor rank mismatch with then_block output rank");
242 ERROR_IF(operator_output->getRank() != (int32_t)else_block_output->GetShape().size(),
243 "OpCondIf: output tensor rank mismatch with else_block output rank");
244 for (int32_t d = 0; d < operator_output->getRank(); d++)
245 {
246 ERROR_IF(operator_output->getShape()[d] != then_block_output->GetShape()[d],
247 "OpCondIf: output tensor dimension mismatch with then_block output dimension");
248 ERROR_IF(operator_output->getShape()[d] != else_block_output->GetShape()[d],
249 "OpCondIf: output tensor dimension mismatch with else_block output dimension");
250 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700251 }
252
253 return 0;
254}
255
256int OpCondIf::eval()
257{
258 bool cond_val = cond->getTensor()(0);
259 std::vector<TosaReference::Tensor*> block_inputs(getInputs().begin() + 1, getInputs().end());
260
261 if (cond_val)
262 {
263 if (evalBlock(then_block, block_inputs, getOutputs()))
264 {
265 WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str());
266 return 1;
267 }
268 }
269 else
270 {
271 if (evalBlock(else_block, block_inputs, getOutputs()))
272 {
273 WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str());
274 return 1;
275 }
276 }
277
278 return GraphNode::eval();
279}
280
Kevin Chengacb550f2021-06-29 15:32:19 -0700281OpWhileLoop::OpWhileLoop(SubgraphTraverser* sgt_,
282 TosaSerializationHandler* tsh_,
283 TosaAttributeBase* attribute_,
284 uint64_t id_)
285 : OpControlFlow(sgt_, tsh_, Op_WHILE_LOOP, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700286{
287 INIT_ATTRIBUTE(WhileLoop);
288}
289
290OpWhileLoop::~OpWhileLoop()
291{
292 if (attribute)
293 delete attribute;
294}
295
296int OpWhileLoop::checkTensorAttributes()
297{
298 if (getInputs().size() <= 0)
299 {
300 WARNING("OpWhileLoop: must have at least 1 operands");
301 return 1;
302 }
303
304 if (getInputs().size() != getOutputs().size())
305 {
306 WARNING("OpWhileLoop: inputs and outputs size must match");
307 return 1;
308 }
309
310 cond_block = tsh->GetBlockByName(attribute->cond_branch());
311 body_block = tsh->GetBlockByName(attribute->body_branch());
312
Kevin Cheng5d00c692021-10-15 20:06:00 +0000313 ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
314 ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
315
316 // Make sure operator input/output matches block input/output
317 int32_t num_block_tensor = getInputs().size();
318 ERROR_IF((int32_t)getOutputs().size() != num_block_tensor,
319 "OpWhileLoop: operator input tensor doesn't match output");
320 ERROR_IF((int32_t)cond_block->GetInputs().size() != num_block_tensor,
321 "OpWhileLoop: cond_block has unexpected number of input");
322 ERROR_IF((int32_t)body_block->GetInputs().size() != num_block_tensor,
323 "OpWhileLoop: body_block has unexpected number of input");
324 ERROR_IF((int32_t)body_block->GetOutputs().size() != num_block_tensor,
325 "OpWhileLoop: body_block has unexpected number of output");
326 for (int32_t i = 0; i < num_block_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700327 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000328 Tensor* operator_input = getInputs()[i];
329 Tensor* operator_output = getOutputs()[i];
330 ERROR_IF(operator_input->matchRankTypeShape(*operator_output),
331 "OpWhileLoop: operator input tensor mismatch operator output tensor");
332
333 std::string cond_block_input_name = cond_block->GetInputs()[i];
334 std::string body_block_input_name = body_block->GetInputs()[i];
335 std::string body_block_output_name = body_block->GetOutputs()[i];
336 TosaSerializationTensor* cond_block_input = cond_block->GetTensorByName(cond_block_input_name);
337 TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name);
338 TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name);
339
340 ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(),
341 "OpWhileLoop: input tensor type mismatch with cond_block input type");
342 ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(),
343 "OpWhileLoop: input tensor type mismatch with body_block input type");
344 ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(),
345 "OpWhileLoop: input tensor type mismatch with body_block output type");
346 ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(),
347 "OpWhileLoop: input tensor rank mismatch with cond_block input rank");
348 ERROR_IF(operator_input->getRank() != (int32_t)body_block_input->GetShape().size(),
349 "OpWhileLoop: input tensor rank mismatch with body_block input rank");
350 ERROR_IF(operator_input->getRank() != (int32_t)body_block_output->GetShape().size(),
351 "OpWhileLoop: input tensor rank mismatch with body_block output rank");
352
353 for (int32_t d = 0; d < operator_input->getRank(); d++)
354 {
355 ERROR_IF(operator_input->getShape()[d] != cond_block_input->GetShape()[d],
356 "OpWhileLoop: input tensor dimension mismatch with cond_block input dimension");
357 ERROR_IF(operator_input->getShape()[d] != body_block_input->GetShape()[d],
358 "OpWhileLoop: input tensor dimension mismatch with body_block input dimension");
359 ERROR_IF(operator_input->getShape()[d] != body_block_output->GetShape()[d],
360 "OpWhileLoop: input tensor dimension mismatch with body_block output dimension");
361 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700362 }
363
Kevin Cheng5d00c692021-10-15 20:06:00 +0000364 ERROR_IF(cond_block->GetOutputs().size() != 1, "OpWhileLoop: cond_block can only have 1 output tensor");
365 std::string cond_block_output_name = cond_block->GetOutputs()[0];
366 TosaSerializationTensor* cond_block_output = cond_block->GetTensorByName(cond_block_output_name);
367 ERROR_IF(cond_block_output->GetDtype() != DType_BOOL, "OpWhileLoop: cond_block output can only be bool type");
368 ERROR_IF(cond_block_output->GetShape().size() != 0, "OpWhileLoop: cond_block output can only be rank 0");
Eric Kunzee5e26762020-10-13 16:11:07 -0700369
370 return 0;
371}
372
373int OpWhileLoop::eval()
374{
375
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700376 TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({}));
Eric Kunzee5e26762020-10-13 16:11:07 -0700377
378 cond_output_ctensor.allocate();
379 std::vector<TosaReference::Tensor*> cond_block_outputs;
380 cond_block_outputs.push_back(&cond_output_ctensor);
381
382 size_t num_input_output = getInputs().size();
383 size_t eval_count = 0;
384
385 while (eval_count++ < MAX_WHILE_LOOP_ITERATION)
386 {
387 if (evalBlock(cond_block, getInputs(), cond_block_outputs))
388 {
389 WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str());
390 return 1;
391 }
392 bool cond_val = cond_output_ctensor.getTensor()(0);
393 DEBUG_HIGH(OP, "Conditional block value: %d", cond_val);
394
395 if (cond_val)
396 {
397 if (evalBlock(body_block, getInputs(), getOutputs()))
398 {
399 WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str());
400 return 1;
401 }
402
403 // assigning output tensors value back to input tensors value for next iteration
404 for (size_t i = 0; i < num_input_output; i++)
405 {
406 if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
407 {
408 WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
409 getInputs()[i]->getName().c_str());
410 return 1;
411 }
412 }
413 }
414 else
415 {
416 // in last iteration or the case it never evaluates body block
417 // assign input tensors value to output tensors
418 for (size_t i = 0; i < num_input_output; i++)
419 {
420 if (getOutputs()[i]->copyValueFrom(getInputs()[i]))
421 {
422 WARNING("Fail to copy tensor value %s -> %s", getInputs()[i]->getName().c_str(),
423 getOutputs()[i]->getName().c_str());
424 return 1;
425 }
426 }
427 break;
428 }
429 }
430
431 return GraphNode::eval();
432}