blob: ac09bbb0bb7fe18208f904f4b52d79e9051cdc2f [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "control_flow.h"
17#include "subgraph_traverser.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070018using namespace TosaReference;
19using namespace Eigen;
20using namespace tosa;
21
Kevin Chengacb550f2021-06-29 15:32:19 -070022OpControlFlow::OpControlFlow(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, Op op_, uint64_t id_)
23 : GraphNode(sgt_, op_, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -070024{
25 tsh = tsh_;
26}
27
28OpControlFlow::~OpControlFlow()
29{}
30
31int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
32 std::vector<TosaReference::Tensor*>& block_inputs,
33 std::vector<TosaReference::Tensor*>& block_outputs)
34{
35 std::string block_name = block->GetName();
36
37 DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
38
Jerry Ge9e94af82022-10-27 09:57:00 -070039 SubgraphTraverser block_sgt(block, tsh, this->parent_sgt);
Eric Kunzee5e26762020-10-13 16:11:07 -070040
Kevin Cheng5d00c692021-10-15 20:06:00 +000041 ERROR_IF(block_sgt.initializeGraph(), "evalBlock(): Unable to initialize graph traverser for %s",
42 block_name.c_str());
43 ERROR_IF(block_sgt.linkTensorsAndNodes(), "evalBlock(): Failed to link tensors and nodes for %s",
44 block_name.c_str());
45 ERROR_IF(block_sgt.validateGraph(), "evalBlock(): Failed to validate subgraph for %s", block_name.c_str());
Jerry Gee5cabbf2023-07-17 21:33:17 +000046 ERROR_IF(block_sgt.allocateInputTensors(), "evalBlock(): Failed to allocate input tensors for %s",
47 block_name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -070048
Kevin Cheng5d00c692021-10-15 20:06:00 +000049 int num_input_tensors = block_sgt.getNumInputTensors();
50 int num_output_tensors = block_sgt.getNumOutputTensors();
Eric Kunzee5e26762020-10-13 16:11:07 -070051
52 for (size_t i = 0; i < block_inputs.size(); i++)
53 {
54 DEBUG_HIGH(OP, "Input[%ld]: %s", i, block_inputs[i]->getName().c_str());
55 }
56 for (size_t i = 0; i < block_outputs.size(); i++)
57 {
58 DEBUG_HIGH(OP, "Output[%ld]: %s", i, block_outputs[i]->getName().c_str());
59 }
60
61 ASSERT_MSG((size_t)num_input_tensors == block_inputs.size(),
62 "op block %s inputs[%lu] does not match with graph traverser's inputs[%d]", block_name.c_str(),
63 block_inputs.size(), num_input_tensors);
64 ASSERT_MSG((size_t)num_output_tensors == block_outputs.size(),
65 "op block %s outputs[%lu] does not match with graph traverser's outputs[%d]", block_name.c_str(),
66 block_outputs.size(), num_output_tensors);
67
68 // set graph traverser's input = basic block's input
69 for (int i = 0; i < num_input_tensors; i++)
70 {
Kevin Cheng5d00c692021-10-15 20:06:00 +000071 TosaReference::Tensor* tensor = block_sgt.getInputTensor(i);
72 ERROR_IF(!tensor->is_allocated(), "block %s input tensor %s are not initialized before use", block_name.c_str(),
73 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -070074
75 if (tensor->copyValueFrom(block_inputs[i]))
76 {
77 WARNING("Fail to copy tensor value %s -> %s", block_inputs[i]->getName().c_str(),
78 tensor->getName().c_str());
79 return 1;
80 }
81
Kevin Cheng14d7f7a2021-05-12 10:44:49 -070082 tensor->setIsValid();
83
Eric Kunzee5e26762020-10-13 16:11:07 -070084 // Push ready consumers to the next node list
85 for (auto gn : tensor->getConsumers())
86 {
87 if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
88 {
Kevin Cheng5d00c692021-10-15 20:06:00 +000089 block_sgt.addToNextNodeList(gn);
Eric Kunzee5e26762020-10-13 16:11:07 -070090 }
91 }
92 }
93
Kevin Cheng5d00c692021-10-15 20:06:00 +000094 ERROR_IF(block_sgt.evaluateAll(), "Error evaluating network. Giving up.");
95
96 // pass block status back
97 switch (block_sgt.getGraphStatus())
98 {
Jerry Ge9c9c8da2023-07-19 23:08:16 +000099 case GraphStatus::TOSA_VALID: {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000100 DEBUG_MED(OP, "Successfully evaluating block %s", block_name.c_str());
101 break;
102 }
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000103 case GraphStatus::TOSA_UNPREDICTABLE: {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000104 DEBUG_MED(OP, "Finish evaluating block %s but result is UNPREDICTABLE", block_name.c_str());
105 DEBUG_MED(OP, "Setting parent graph status to UNPREDICTABLE");
106 parent_sgt->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
107 break;
108 }
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000109 case GraphStatus::TOSA_ERROR: {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000110 DEBUG_MED(OP, "Fail evaluating block %s. Result is ERROR", block_name.c_str());
111 if (parent_sgt->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE)
112 {
113 DEBUG_MED(OP, "Setting parent graph status to ERROR");
114 parent_sgt->setGraphStatus(GraphStatus::TOSA_ERROR);
115 return 1;
116 }
117 }
118 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700119
120 // make sure output tensor is evaluated and show its value
121 bool all_output_valid = true;
122 for (int i = 0; i < num_output_tensors; i++)
123 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000124 const TosaReference::Tensor* ct = block_sgt.getOutputTensor(i);
Eric Kunzee5e26762020-10-13 16:11:07 -0700125 ASSERT_MEM(ct);
126 if (!ct->getIsValid())
127 {
128 ct->dumpTensorParams(g_func_debug.func_debug_file);
129 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
130 {
131 ct->dumpTensor(g_func_debug.func_debug_file);
132 }
133 all_output_valid = false;
134 }
135 }
136 if (!all_output_valid)
137 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000138 block_sgt.dumpGraph(g_func_debug.func_debug_file);
Kevin Cheng903763c2021-09-28 16:14:52 -0700139 ERROR_IF(true, "SubgraphTraverser \"%s\" error: Output tensors are not all valid at the end of evaluation.",
140 block_name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700141 }
142
143 // set basic block's output = subgraph_traverser's output
144 for (int i = 0; i < num_output_tensors; i++)
145 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000146 TosaReference::Tensor* tensor = block_sgt.getOutputTensor(i);
147 ERROR_IF(!tensor->is_allocated(), "block %s input tensor %s are not initialized before use", block_name.c_str(),
148 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700149
150 if (block_outputs[i]->copyValueFrom(tensor))
151 {
152 WARNING("Fail to copy tensor value %s -> %s", tensor->getName().c_str(), outputs[i]->getName().c_str());
153 return 1;
154 }
155 }
156 return 0;
157}
158
Kevin Chengacb550f2021-06-29 15:32:19 -0700159OpCondIf::OpCondIf(SubgraphTraverser* sgt_, TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
160 : OpControlFlow(sgt_, tsh_, Op_COND_IF, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700161{
162 INIT_ATTRIBUTE(CondIf);
163}
164
165OpCondIf::~OpCondIf()
166{
167 if (attribute)
168 delete attribute;
169}
170
171int OpCondIf::checkTensorAttributes()
172{
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100173 ERROR_IF(!tsh, "OpCondIf: tosa serialization handler must not be null");
174
Kevin Cheng5d00c692021-10-15 20:06:00 +0000175 ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
Eric Kunzee5e26762020-10-13 16:11:07 -0700176
Tai Lya4d748b2023-03-28 22:06:56 +0000177 ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
178 "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()),
Kevin Cheng5d00c692021-10-15 20:06:00 +0000179 inputs[0]->getRank());
Eric Kunzee5e26762020-10-13 16:11:07 -0700180
181 cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
182 ASSERT_MEM(cond);
183
Tai Ly8ead6c42024-02-14 22:35:44 +0000184 auto then_region = tsh->GetRegionByName(attribute->then_graph());
185 auto else_region = tsh->GetRegionByName(attribute->else_graph());
Tai Ly4e9a9772023-03-16 22:24:05 +0000186 if (then_region && else_region)
187 {
Tai Ly8ead6c42024-02-14 22:35:44 +0000188 // new serialization: then_graph and else_graph point to regions
Tai Ly4e9a9772023-03-16 22:24:05 +0000189 then_block = then_region->GetBlocks().front();
190 else_block = else_region->GetBlocks().front();
191 }
192 else
193 {
Tai Ly8ead6c42024-02-14 22:35:44 +0000194 // old serialization: then_graph and else_graph point to blocks in curr_region
Tai Ly4e9a9772023-03-16 22:24:05 +0000195 auto region_name = getParentSGT()->getRegionName();
196 auto curr_region = tsh->GetRegionByName(region_name);
Tai Ly8ead6c42024-02-14 22:35:44 +0000197 then_block = curr_region->GetBlockByName(attribute->then_graph());
198 else_block = curr_region->GetBlockByName(attribute->else_graph());
Tai Ly4e9a9772023-03-16 22:24:05 +0000199 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700200
Tai Ly8ead6c42024-02-14 22:35:44 +0000201 ERROR_IF(!then_block, "OpCondIf: fail to resolve then_graph %s", attribute->then_graph().c_str());
Kevin Cheng5d00c692021-10-15 20:06:00 +0000202
Tai Ly8ead6c42024-02-14 22:35:44 +0000203 ERROR_IF(!else_block, "OpCondIf: fail to resolve else_graph %s", attribute->else_graph().c_str());
Kevin Cheng5d00c692021-10-15 20:06:00 +0000204
205 // Make sure operator input/output matches block input/output
206 // Skip the first rank 0 bool tensor on input list
207 int32_t num_input_tensor = getInputs().size() - 1;
208 int32_t num_output_tensor = getOutputs().size();
Jerry Ge9e94af82022-10-27 09:57:00 -0700209
Kevin Cheng5d00c692021-10-15 20:06:00 +0000210 ERROR_IF((int32_t)then_block->GetInputs().size() != num_input_tensor,
211 "OpCondIf: then_block has unexpected number of input");
212 ERROR_IF((int32_t)else_block->GetInputs().size() != num_input_tensor,
213 "OpCondIf: else_block has unexpected number of input");
214 ERROR_IF((int32_t)then_block->GetOutputs().size() != num_output_tensor,
215 "OpCondIf: then_block has unexpected number of output");
216 ERROR_IF((int32_t)else_block->GetOutputs().size() != num_output_tensor,
217 "OpCondIf: else_block has unexpected number of output");
218
219 for (int32_t i = 0; i < num_input_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700220 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000221 Tensor* operator_input = getInputs()[i + 1];
222 std::string then_block_input_name = then_block->GetInputs()[i];
223 std::string else_block_input_name = else_block->GetInputs()[i];
224 TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name);
225 TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name);
Tai Lya4d748b2023-03-28 22:06:56 +0000226 ERROR_IF(operator_input->getDtype() != ConvertDType(then_block_input->GetDtype()),
Kevin Cheng5d00c692021-10-15 20:06:00 +0000227 "OpCondIf: input tensor type mismatch with then_block input type");
Tai Lya4d748b2023-03-28 22:06:56 +0000228 ERROR_IF(operator_input->getDtype() != ConvertDType(else_block_input->GetDtype()),
Kevin Cheng5d00c692021-10-15 20:06:00 +0000229 "OpCondIf: input tensor type mismatch with else_block input type");
230 ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(),
231 "OpCondIf: input tensor rank mismatch with then_block input rank");
232 ERROR_IF(operator_input->getRank() != (int32_t)else_block_input->GetShape().size(),
233 "OpCondIf: input tensor rank mismatch with else_block input rank");
234 for (int32_t d = 0; d < operator_input->getRank(); d++)
235 {
236 ERROR_IF(operator_input->getShape()[d] != then_block_input->GetShape()[d],
237 "OpCondIf: input tensor dimension mismatch with then_block input dimension");
238 ERROR_IF(operator_input->getShape()[d] != else_block_input->GetShape()[d],
239 "OpCondIf: input tensor dimension mismatch with else_block input dimension");
240 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700241 }
242
Kevin Cheng5d00c692021-10-15 20:06:00 +0000243 for (int32_t i = 0; i < num_output_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700244 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000245 Tensor* operator_output = getOutputs()[i];
246 std::string then_block_output_name = then_block->GetOutputs()[i];
247 std::string else_block_output_name = else_block->GetOutputs()[i];
248 TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name);
249 TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name);
Tai Lya4d748b2023-03-28 22:06:56 +0000250 ERROR_IF(operator_output->getDtype() != ConvertDType(then_block_output->GetDtype()),
Kevin Cheng5d00c692021-10-15 20:06:00 +0000251 "OpCondIf: output tensor type mismatch with then_block output type");
Tai Lya4d748b2023-03-28 22:06:56 +0000252 ERROR_IF(operator_output->getDtype() != ConvertDType(else_block_output->GetDtype()),
Kevin Cheng5d00c692021-10-15 20:06:00 +0000253 "OpCondIf: output tensor type mismatch with else_block output type");
254 ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(),
255 "OpCondIf: output tensor rank mismatch with then_block output rank");
256 ERROR_IF(operator_output->getRank() != (int32_t)else_block_output->GetShape().size(),
257 "OpCondIf: output tensor rank mismatch with else_block output rank");
258 for (int32_t d = 0; d < operator_output->getRank(); d++)
259 {
260 ERROR_IF(operator_output->getShape()[d] != then_block_output->GetShape()[d],
261 "OpCondIf: output tensor dimension mismatch with then_block output dimension");
262 ERROR_IF(operator_output->getShape()[d] != else_block_output->GetShape()[d],
263 "OpCondIf: output tensor dimension mismatch with else_block output dimension");
264 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700265 }
266
267 return 0;
268}
269
270int OpCondIf::eval()
271{
272 bool cond_val = cond->getTensor()(0);
273 std::vector<TosaReference::Tensor*> block_inputs(getInputs().begin() + 1, getInputs().end());
274
275 if (cond_val)
276 {
277 if (evalBlock(then_block, block_inputs, getOutputs()))
278 {
Tai Ly8ead6c42024-02-14 22:35:44 +0000279 WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_graph().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700280 return 1;
281 }
282 }
283 else
284 {
285 if (evalBlock(else_block, block_inputs, getOutputs()))
286 {
Tai Ly8ead6c42024-02-14 22:35:44 +0000287 WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_graph().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700288 return 1;
289 }
290 }
291
292 return GraphNode::eval();
293}
294
Kevin Chengacb550f2021-06-29 15:32:19 -0700295OpWhileLoop::OpWhileLoop(SubgraphTraverser* sgt_,
296 TosaSerializationHandler* tsh_,
297 TosaAttributeBase* attribute_,
298 uint64_t id_)
299 : OpControlFlow(sgt_, tsh_, Op_WHILE_LOOP, id_)
Eric Kunzee5e26762020-10-13 16:11:07 -0700300{
301 INIT_ATTRIBUTE(WhileLoop);
302}
303
304OpWhileLoop::~OpWhileLoop()
305{
306 if (attribute)
307 delete attribute;
308}
309
310int OpWhileLoop::checkTensorAttributes()
311{
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000312 if (!tsh)
313 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100314 WARNING("OpWhileLoop: tosa serialization handler must not be null");
315 return 1;
316 }
317
Eric Kunzee5e26762020-10-13 16:11:07 -0700318 if (getInputs().size() <= 0)
319 {
320 WARNING("OpWhileLoop: must have at least 1 operands");
321 return 1;
322 }
323
324 if (getInputs().size() != getOutputs().size())
325 {
326 WARNING("OpWhileLoop: inputs and outputs size must match");
327 return 1;
328 }
329
Tai Ly8ead6c42024-02-14 22:35:44 +0000330 auto cond_region = tsh->GetRegionByName(attribute->cond_graph());
331 auto body_region = tsh->GetRegionByName(attribute->body_graph());
Tai Ly4e9a9772023-03-16 22:24:05 +0000332 if (cond_region && body_region)
333 {
Tai Ly8ead6c42024-02-14 22:35:44 +0000334 // new serialization: then_graph and else_graph point to regions
Tai Ly4e9a9772023-03-16 22:24:05 +0000335 cond_block = cond_region->GetBlocks().front();
336 body_block = body_region->GetBlocks().front();
337 }
338 else
339 {
340 auto region_name = getParentSGT()->getRegionName();
341 auto curr_region = tsh->GetRegionByName(region_name);
Tai Ly8ead6c42024-02-14 22:35:44 +0000342 cond_block = curr_region->GetBlockByName(attribute->cond_graph());
343 body_block = curr_region->GetBlockByName(attribute->body_graph());
Tai Ly4e9a9772023-03-16 22:24:05 +0000344 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700345
Tai Ly8ead6c42024-02-14 22:35:44 +0000346 ERROR_IF(!cond_block, "OpWhileLoop: fail to resolve cond_graph %s", attribute->cond_graph().c_str());
347 ERROR_IF(!body_block, "OpWhileLoop: fail to resolve body_graph %s", attribute->body_graph().c_str());
Kevin Cheng5d00c692021-10-15 20:06:00 +0000348
349 // Make sure operator input/output matches block input/output
350 int32_t num_block_tensor = getInputs().size();
351 ERROR_IF((int32_t)getOutputs().size() != num_block_tensor,
352 "OpWhileLoop: operator input tensor doesn't match output");
353 ERROR_IF((int32_t)cond_block->GetInputs().size() != num_block_tensor,
354 "OpWhileLoop: cond_block has unexpected number of input");
355 ERROR_IF((int32_t)body_block->GetInputs().size() != num_block_tensor,
356 "OpWhileLoop: body_block has unexpected number of input");
357 ERROR_IF((int32_t)body_block->GetOutputs().size() != num_block_tensor,
358 "OpWhileLoop: body_block has unexpected number of output");
359 for (int32_t i = 0; i < num_block_tensor; i++)
Eric Kunzee5e26762020-10-13 16:11:07 -0700360 {
Kevin Cheng5d00c692021-10-15 20:06:00 +0000361 Tensor* operator_input = getInputs()[i];
362 Tensor* operator_output = getOutputs()[i];
363 ERROR_IF(operator_input->matchRankTypeShape(*operator_output),
364 "OpWhileLoop: operator input tensor mismatch operator output tensor");
365
366 std::string cond_block_input_name = cond_block->GetInputs()[i];
367 std::string body_block_input_name = body_block->GetInputs()[i];
368 std::string body_block_output_name = body_block->GetOutputs()[i];
369 TosaSerializationTensor* cond_block_input = cond_block->GetTensorByName(cond_block_input_name);
370 TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name);
371 TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name);
372
Tai Lya4d748b2023-03-28 22:06:56 +0000373 ERROR_IF(operator_input->getDtype() != ConvertDType(cond_block_input->GetDtype()),
Kevin Cheng5d00c692021-10-15 20:06:00 +0000374 "OpWhileLoop: input tensor type mismatch with cond_block input type");
Tai Lya4d748b2023-03-28 22:06:56 +0000375 ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_input->GetDtype()),
Kevin Cheng5d00c692021-10-15 20:06:00 +0000376 "OpWhileLoop: input tensor type mismatch with body_block input type");
Tai Lya4d748b2023-03-28 22:06:56 +0000377 ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_output->GetDtype()),
Kevin Cheng5d00c692021-10-15 20:06:00 +0000378 "OpWhileLoop: input tensor type mismatch with body_block output type");
379 ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(),
380 "OpWhileLoop: input tensor rank mismatch with cond_block input rank");
381 ERROR_IF(operator_input->getRank() != (int32_t)body_block_input->GetShape().size(),
382 "OpWhileLoop: input tensor rank mismatch with body_block input rank");
383 ERROR_IF(operator_input->getRank() != (int32_t)body_block_output->GetShape().size(),
384 "OpWhileLoop: input tensor rank mismatch with body_block output rank");
385
386 for (int32_t d = 0; d < operator_input->getRank(); d++)
387 {
388 ERROR_IF(operator_input->getShape()[d] != cond_block_input->GetShape()[d],
389 "OpWhileLoop: input tensor dimension mismatch with cond_block input dimension");
390 ERROR_IF(operator_input->getShape()[d] != body_block_input->GetShape()[d],
391 "OpWhileLoop: input tensor dimension mismatch with body_block input dimension");
392 ERROR_IF(operator_input->getShape()[d] != body_block_output->GetShape()[d],
393 "OpWhileLoop: input tensor dimension mismatch with body_block output dimension");
394 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700395 }
396
Kevin Cheng5d00c692021-10-15 20:06:00 +0000397 ERROR_IF(cond_block->GetOutputs().size() != 1, "OpWhileLoop: cond_block can only have 1 output tensor");
398 std::string cond_block_output_name = cond_block->GetOutputs()[0];
399 TosaSerializationTensor* cond_block_output = cond_block->GetTensorByName(cond_block_output_name);
400 ERROR_IF(cond_block_output->GetDtype() != DType_BOOL, "OpWhileLoop: cond_block output can only be bool type");
401 ERROR_IF(cond_block_output->GetShape().size() != 0, "OpWhileLoop: cond_block output can only be rank 0");
Eric Kunzee5e26762020-10-13 16:11:07 -0700402
403 return 0;
404}
405
406int OpWhileLoop::eval()
407{
Tai Lya4d748b2023-03-28 22:06:56 +0000408 TosaReference::Tensor0<bool> cond_output_ctensor("cond_output", DType_BOOL, std::vector<int32_t>({}));
Eric Kunzee5e26762020-10-13 16:11:07 -0700409
410 cond_output_ctensor.allocate();
411 std::vector<TosaReference::Tensor*> cond_block_outputs;
412 cond_block_outputs.push_back(&cond_output_ctensor);
413
414 size_t num_input_output = getInputs().size();
415 size_t eval_count = 0;
416
417 while (eval_count++ < MAX_WHILE_LOOP_ITERATION)
418 {
419 if (evalBlock(cond_block, getInputs(), cond_block_outputs))
420 {
Tai Ly8ead6c42024-02-14 22:35:44 +0000421 WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_graph().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700422 return 1;
423 }
424 bool cond_val = cond_output_ctensor.getTensor()(0);
425 DEBUG_HIGH(OP, "Conditional block value: %d", cond_val);
426
427 if (cond_val)
428 {
429 if (evalBlock(body_block, getInputs(), getOutputs()))
430 {
Tai Ly8ead6c42024-02-14 22:35:44 +0000431 WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_graph().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700432 return 1;
433 }
434
435 // assigning output tensors value back to input tensors value for next iteration
436 for (size_t i = 0; i < num_input_output; i++)
437 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700438 getInputs()[i] = getOutputs()[i];
Eric Kunzee5e26762020-10-13 16:11:07 -0700439 }
440 }
441 else
442 {
443 // in last iteration or the case it never evaluates body block
444 // assign input tensors value to output tensors
445 for (size_t i = 0; i < num_input_output; i++)
446 {
447 if (getOutputs()[i]->copyValueFrom(getInputs()[i]))
448 {
449 WARNING("Fail to copy tensor value %s -> %s", getInputs()[i]->getName().c_str(),
450 getOutputs()[i]->getName().c_str());
451 return 1;
452 }
453 }
454 break;
455 }
456 }
457
458 return GraphNode::eval();
459}