blob: 52b1806b4208baec54f809dce1a48a1521fb6aa2 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Won Jeon2c34b462024-02-06 18:37:00 +00002// Copyright (c) 2020-2024, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "subgraph_traverser.h"
James Ward24dbc422022-10-19 12:20:31 +010017#include "arith_util.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000018#include "tosa_model_types.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070019
// SUBGRAPH_ERROR_IF(COND, fmt, ...)
//
// Graph-level ERROR_IF check, meant to be used as a statement inside SubgraphTraverser
// member functions that return int. When COND is true it:
//   - marks the graph TOSA_ERROR, unless it is already TOSA_UNPREDICTABLE,
//   - prints the failing condition, the formatted message and a backtrace to the debug file,
//   - makes the enclosing function return 1.
//
// The body is wrapped in do { ... } while (0) so that the expansion is exactly one statement:
// it can be followed by ';' and used as the sole statement of an unbraced if/else without
// capturing a following `else`.
#ifndef SUBGRAPH_ERROR_IF
#define SUBGRAPH_ERROR_IF(COND, fmt, ...)                                                                              \
    do                                                                                                                 \
    {                                                                                                                  \
        if ((COND))                                                                                                    \
        {                                                                                                              \
            if (this->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE)                                             \
            {                                                                                                          \
                this->setGraphStatus(GraphStatus::TOSA_ERROR);                                                         \
            }                                                                                                          \
            fprintf(g_func_debug.func_debug_file, COL_FATAL("SUBGRAPH_ERROR_IF() fails AT %s:%d %s(): (%s)\n"),        \
                    __FILE__, __LINE__, __func__, #COND);                                                              \
            fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__);                                 \
            func_print_backtrace(g_func_debug.func_debug_file);                                                        \
            return 1;                                                                                                  \
        }                                                                                                              \
    } while (0)
#endif
35
Eric Kunzee5e26762020-10-13 16:11:07 -070036using namespace TosaReference;
37using namespace Eigen;
38using namespace tosa;
39
Jerry Ge9c9c8da2023-07-19 23:08:16 +000040SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block,
41 TosaSerializationHandler* _tsh,
42 SubgraphTraverser* _parent_sgt)
Eric Kunzee5e26762020-10-13 16:11:07 -070043{
Jerry Ge9e94af82022-10-27 09:57:00 -070044
Kevin Chengacb550f2021-06-29 15:32:19 -070045 graph_status = GraphStatus::TOSA_VALID;
Jerry Ge9c9c8da2023-07-19 23:08:16 +000046 block = _block;
Eric Kunzee5e26762020-10-13 16:11:07 -070047
Jerry Ge9c9c8da2023-07-19 23:08:16 +000048 tsh = _tsh;
Jerry Ge9e94af82022-10-27 09:57:00 -070049 parent_sgt = _parent_sgt;
Eric Kunzee5e26762020-10-13 16:11:07 -070050 tensors.clear();
51 nodes.clear();
52 nextNodeList.clear();
53}
54
55SubgraphTraverser::~SubgraphTraverser()
56{
57 nextNodeList.clear();
58
59 for (GraphNode* n : nodes)
60 {
61 delete n;
62 }
63 nodes.clear();
64
65 for (TosaReference::Tensor* t : tensors)
66 {
Tai Lycf84bc92023-09-07 20:49:09 +000067 if (t->getIsVariable() && parent_sgt)
68 {
69 // variable tensors are owned by top level sgt
70 continue;
71 }
Eric Kunzee5e26762020-10-13 16:11:07 -070072 if (t->is_allocated())
73 {
74 t->deallocate();
75 }
76 delete t;
77 }
78 tensors.clear();
79}
80
81int SubgraphTraverser::getNumInputTensors() const
82{
83 return inputTensors.size();
84}
85
86TosaReference::Tensor* SubgraphTraverser::getInputTensor(const unsigned int idx) const
87{
88 return inputTensors[idx];
89}
90
91TosaReference::Tensor* SubgraphTraverser::getInputTensorByName(const std::string name) const
92{
93 for (auto t : inputTensors)
94 {
95 if (t->getName() == name)
96 {
97 return t;
98 }
99 }
100
101 return nullptr;
102}
103
104int SubgraphTraverser::getNumOutputTensors() const
105{
106 return outputTensors.size();
107}
108
109TosaReference::Tensor* SubgraphTraverser::getOutputTensor(const unsigned int idx) const
110{
111 return outputTensors[idx];
112}
113
114TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::string name) const
115{
116 for (auto t : outputTensors)
117 {
118 if (t->getName() == name)
119 {
120 return t;
121 }
122 }
123
124 return nullptr;
125}
126
Tai Lycf84bc92023-09-07 20:49:09 +0000127int SubgraphTraverser::getNumVariableTensors() const
128{
129 return variableTensors.size();
130}
131
132TosaReference::Tensor* SubgraphTraverser::getVariableTensor(const unsigned int idx) const
133{
134 return variableTensors[idx];
135}
136
137// find variable tensor by name in top level sgt's @a variableTensors
138TosaReference::Tensor* SubgraphTraverser::getVariableTensorByName(const std::string name) const
139{
140 // variable tensors are owned by top level sgt
141 if (parent_sgt)
142 {
143 return parent_sgt->getVariableTensorByName(name);
144 }
145
146 for (auto t : variableTensors)
147 {
148 if (t->getName() == name)
149 {
150 return t;
151 }
152 }
153
154 return nullptr;
155}
156
157// add variable tensor to top level sgt's @a variableTensors
158int SubgraphTraverser::registerVariableTensor(Tensor* tensor)
159{
160 SUBGRAPH_ERROR_IF(!tensor->getIsVariable(),
161 "SubgraphTraverser::registerVariableTensor(): tensor %s is not a variable",
162 tensor->getName().c_str());
163 // variable tensors are owned by top level sgt
164 if (parent_sgt)
165 {
166 return parent_sgt->registerVariableTensor(tensor);
167 }
168 variableTensors.push_back(tensor);
169 return 0;
170}
171
Eric Kunzee5e26762020-10-13 16:11:07 -0700172int SubgraphTraverser::initializeGraph()
173{
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 int idx = 0;
Kevin Chengc72b59c2021-09-29 16:57:55 -0700175
Jerry Ge9e94af82022-10-27 09:57:00 -0700176 std::vector<TosaSerializationTensor*> ser_tensor_vec;
177 // Get all the serialized tensors from TosaSerializationHandler.
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100178 if (tsh)
Jerry Ge9e94af82022-10-27 09:57:00 -0700179 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100180 for (auto region : tsh->GetRegions())
Jerry Ge9e94af82022-10-27 09:57:00 -0700181 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100182 for (auto block : region->GetBlocks())
Tai Ly4e9a9772023-03-16 22:24:05 +0000183 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100184 for (auto ser_tensor : block->GetTensors())
185 {
186 ser_tensor_vec.push_back(ser_tensor);
187 }
Tai Ly4e9a9772023-03-16 22:24:05 +0000188 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700189 }
190 }
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100191 else
192 {
193 for (auto ser_tensor : block->GetTensors())
194 {
195 ser_tensor_vec.push_back(ser_tensor);
196 }
197 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700198
199 std::vector<GraphNode*> non_const_node_vec;
Eric Kunzee5e26762020-10-13 16:11:07 -0700200 for (auto op : block->GetOperators())
201 {
202 // translated TosaSerializationOperator to GraphNode
Tai Lya4d748b2023-03-28 22:06:56 +0000203 TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN;
204 TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN;
205 TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000206 uint32_t input_rank = 0;
207 uint32_t output_rank = 0;
208 uint32_t weight_rank = 0;
209 int32_t input_index = -1;
210 int32_t weight_index = -1;
Kevin Cheng550ccc52021-03-03 11:21:43 -0800211
212 switch (op->GetOp())
Eric Kunzee5e26762020-10-13 16:11:07 -0700213 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800214 case Op_CONV2D:
Kevin Cheng1533b852021-09-01 12:51:58 -0700215 case Op_CONV3D:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800216 case Op_DEPTHWISE_CONV2D:
217 case Op_TRANSPOSE_CONV2D:
218 case Op_FULLY_CONNECTED:
219 input_index = 0;
220 weight_index = 1;
221 break;
222 case Op_SELECT:
223 input_index = 1;
224 break;
225 default:
226 if (!op->GetInputTensorNames().empty())
227 input_index = 0;
228 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700229 }
230
Kevin Cheng550ccc52021-03-03 11:21:43 -0800231 if (input_index != -1)
Kevin Chengdf862692021-02-22 15:22:22 -0800232 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700233 SUBGRAPH_ERROR_IF(
234 (size_t)input_index >= op->GetInputTensorNames().size(),
235 "SubgraphTraverser::initializeGraph(): Op=%s, input_index %d must be within [0, num_input - 1]",
236 EnumNamesOp()[op->GetOp()], input_index);
Kevin Chengdf862692021-02-22 15:22:22 -0800237
Kevin Cheng550ccc52021-03-03 11:21:43 -0800238 std::string input_name = op->GetInputTensorNames()[input_index];
Jerry Ge9e94af82022-10-27 09:57:00 -0700239 TosaSerializationTensor* input_tensor = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000240 for (auto ser_tensor : ser_tensor_vec)
241 {
242 if (ser_tensor->GetName() == input_name)
243 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700244 input_tensor = ser_tensor;
245 }
246 }
247
Kevin Cheng903763c2021-09-28 16:14:52 -0700248 SUBGRAPH_ERROR_IF(
249 !input_tensor,
250 "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
251 input_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000252 input_dtype = ConvertDType(input_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800253 input_rank = input_tensor->GetShape().size();
Kevin Chengdf862692021-02-22 15:22:22 -0800254 }
255
Kevin Cheng550ccc52021-03-03 11:21:43 -0800256 if (weight_index != -1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700257 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700258 SUBGRAPH_ERROR_IF(
259 (size_t)weight_index >= op->GetInputTensorNames().size(),
260 "SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
261 EnumNamesOp()[op->GetOp()], weight_index);
Kevin Cheng550ccc52021-03-03 11:21:43 -0800262 std::string weight_name = op->GetInputTensorNames()[weight_index];
Jerry Ge9e94af82022-10-27 09:57:00 -0700263 TosaSerializationTensor* weight_tensor = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000264 for (auto ser_tensor : ser_tensor_vec)
265 {
266 if (ser_tensor->GetName() == weight_name)
267 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700268 weight_tensor = ser_tensor;
269 }
270 }
271
Kevin Cheng903763c2021-09-28 16:14:52 -0700272 SUBGRAPH_ERROR_IF(
273 !weight_tensor,
274 "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
275 weight_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000276 weight_dtype = ConvertDType(weight_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800277 weight_rank = weight_tensor->GetShape().size();
Eric Kunzee5e26762020-10-13 16:11:07 -0700278 }
279
Kevin Cheng478101b2021-10-04 10:43:14 -0700280 SUBGRAPH_ERROR_IF(op->GetOutputTensorNames().size() == 0,
281 "SubgraphTraverser::initializeGraph(): Op=%s must have at least one output tensor.",
282 EnumNamesOp()[op->GetOp()]);
Kevin Cheng550ccc52021-03-03 11:21:43 -0800283 std::string output_name = op->GetOutputTensorNames()[0];
284 TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700285 SUBGRAPH_ERROR_IF(
286 !output_tensor,
287 "SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler",
288 output_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000289 output_dtype = ConvertDType(output_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800290 output_rank = output_tensor->GetShape().size();
291
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
293 EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
294
Jerry Ge9e94af82022-10-27 09:57:00 -0700295 GraphNode* node = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000296 if (this->parent_sgt)
297 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700298 node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000299 input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
Jerry Ge9e94af82022-10-27 09:57:00 -0700300 node->setInMainBlock(false);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000301 }
302 else
303 {
304 node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank,
305 output_dtype, output_rank, weight_dtype, weight_rank);
306 if (node)
307 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700308 node->setInMainBlock(true);
309 }
310 }
311
Kevin Cheng550ccc52021-03-03 11:21:43 -0800312 if (!node)
Eric Kunzee5e26762020-10-13 16:11:07 -0700313 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800314 if (weight_index == -1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700315 {
316 fprintf(g_func_debug.func_debug_file,
Kevin Cheng903763c2021-09-28 16:14:52 -0700317 "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) "
318 "-> (%s rank %d)",
Tai Lya4d748b2023-03-28 22:06:56 +0000319 EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
320 EnumNameTOSAREFTYPE(output_dtype), output_rank);
Eric Kunzee5e26762020-10-13 16:11:07 -0700321 }
322 else
323 {
324 fprintf(g_func_debug.func_debug_file,
Kevin Cheng903763c2021-09-28 16:14:52 -0700325 "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), "
326 "weight=(%s rank %d) -> (%s rank %d)",
Tai Lya4d748b2023-03-28 22:06:56 +0000327 EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
328 EnumNameTOSAREFTYPE(weight_dtype), weight_rank, EnumNameTOSAREFTYPE(output_dtype), output_rank);
Eric Kunzee5e26762020-10-13 16:11:07 -0700329 }
330
Kevin Cheng550ccc52021-03-03 11:21:43 -0800331 for (auto& ts : op->GetInputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700332 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700333 fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Input: %s\n", ts.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700334 }
335
Kevin Cheng550ccc52021-03-03 11:21:43 -0800336 for (auto& ts : op->GetOutputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700337 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700338 fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Output: %s\n", ts.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700339 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700340 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported operation type or rank.");
Eric Kunzee5e26762020-10-13 16:11:07 -0700341 }
342
Kevin Chengc72b59c2021-09-29 16:57:55 -0700343 // Elementwise operator might set TOSA_ERROR when registering lambda function when creating the op.
344 // Check graph status after the op being constructed.
345 SUBGRAPH_ERROR_IF(getGraphStatus() == GraphStatus::TOSA_ERROR,
346 "SubgraphTraverser::initializeGraph(): Op %8s triggered ERROR_IF() when constructing the op.",
347 EnumNamesOp()[op->GetOp()]);
348
Kevin Cheng550ccc52021-03-03 11:21:43 -0800349 for (auto& name : op->GetInputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700350 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800351 node->addInputName(name);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700352 used_tensor_name_set.insert(name);
Eric Kunzee5e26762020-10-13 16:11:07 -0700353 }
354
355 for (auto name : op->GetOutputTensorNames())
356 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800357 node->addOutputName(name);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700358 used_tensor_name_set.insert(name);
Eric Kunzee5e26762020-10-13 16:11:07 -0700359 }
360
Kevin Cheng550ccc52021-03-03 11:21:43 -0800361 addNode(node);
Eric Kunzee5e26762020-10-13 16:11:07 -0700362
363 // if node doesn't have any inputs (i.e. CONST)
364 // it should be ready for evaluation
Kevin Cheng550ccc52021-03-03 11:21:43 -0800365 if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
Eric Kunzee5e26762020-10-13 16:11:07 -0700366 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800367 addToNextNodeList(node);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000368 }
369 else if (!node->getInMainBlock())
370 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700371 non_const_node_vec.push_back(node);
Eric Kunzee5e26762020-10-13 16:11:07 -0700372 }
373
Tai Lycf84bc92023-09-07 20:49:09 +0000374 // Bug fix: add the ready node in main block for evaluation
375 if (node->hasAllInputsReady() && !node->getOnNextNodeList() && !node->getEvaluated())
376 {
377 addToNextNodeList(node);
378 }
379
Eric Kunzee5e26762020-10-13 16:11:07 -0700380 idx++;
381 }
382
383 for (auto ts : block->GetTensors())
384 {
Tai Lycf84bc92023-09-07 20:49:09 +0000385 addTensor(ts);
Kevin Chengcc61be32021-10-14 17:09:57 -0700386 }
387
388 DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
389 for (auto& input_name : block->GetInputs())
390 {
391 TosaReference::Tensor* tensor = findTensorByName(input_name);
392 DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
393 if (tensor)
394 {
395 tensor->setIsSubgraphInput();
396 inputTensors.push_back(tensor);
397 }
398 else
399 {
400 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
401 input_name.c_str());
402 }
403 }
404
405 DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
406 for (auto& output_name : block->GetOutputs())
407 {
408 TosaReference::Tensor* tensor = findTensorByName(output_name);
Jerry Ge9e94af82022-10-27 09:57:00 -0700409 DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str());
Kevin Chengcc61be32021-10-14 17:09:57 -0700410 if (tensor)
411 {
412 tensor->setIsSubgraphOutput();
413 outputTensors.push_back(tensor);
414 }
415 else
416 {
417 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
418 output_name.c_str());
419 }
420 }
421
422 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
423 {
424 dumpNextNodeList(g_func_debug.func_debug_file);
425 }
426
Jerry Ge9e94af82022-10-27 09:57:00 -0700427 // If the node is not in mainblock and not const
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000428 for (auto node : non_const_node_vec)
429 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700430 bool all_inputs_from_parent = true;
431 for (std::string& name : node->getInputNames())
432 {
433 TosaReference::Tensor* t = findTensorByName(name);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000434 if (!t->getIsParentGraphOutput())
435 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700436 all_inputs_from_parent = false;
437 }
438 }
439 // In the children block, when a node has all its inputs from parent
440 // block, we have to manually add this node to the evaluation list
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000441 if (all_inputs_from_parent && !node->getOnNextNodeList())
442 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700443 addToNextNodeList(node);
444 }
445 }
Kevin Chengcc61be32021-10-14 17:09:57 -0700446 return 0;
447}
448
Jerry Gee5cabbf2023-07-17 21:33:17 +0000449int SubgraphTraverser::allocateInputTensors()
Kevin Chengcc61be32021-10-14 17:09:57 -0700450{
Jerry Gee5cabbf2023-07-17 21:33:17 +0000451 auto input_tensor_names_vec = block->GetInputs();
452
453 for (auto input_tensor_name : input_tensor_names_vec)
Kevin Chengcc61be32021-10-14 17:09:57 -0700454 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000455 this->allocateTensor(input_tensor_name);
456 }
457
Tai Lycf84bc92023-09-07 20:49:09 +0000458 // allocate variable tensors if not already allocated
459 for (auto ts : block->GetTensors())
460 {
461 if (ts->GetVariable())
462 {
463 TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
464 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateInputTensors(): can't find tensor %s.",
465 ts->GetName().c_str());
466 if (!tensor->is_allocated())
467 {
468 DEBUG_INFO(GT, "Is a VariableTensor %s", ts->GetName().c_str());
469 this->allocateTensor(ts->GetName());
470 }
471 }
472 }
473
Jerry Gee5cabbf2023-07-17 21:33:17 +0000474 return 0;
475}
476
477int SubgraphTraverser::allocateTensor(std::string name)
478{
479 auto ts = block->GetTensorByName(name);
480
481 // Bail out if tensor is used and any of its dimension is invalid.
482 auto got = used_tensor_name_set.find(ts->GetName());
483 if (got != used_tensor_name_set.end())
484 {
485 uint32_t elements = 1;
486 for (auto& dim : ts->GetShape())
Kevin Chengacb550f2021-06-29 15:32:19 -0700487 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000488 if (dim <= 0)
Kevin Chengacb550f2021-06-29 15:32:19 -0700489 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000490 DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim);
491 this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
492 return 1;
493 }
494 if (dim > static_cast<int32_t>(TOSA_MAX_TENSOR_SIZE / elements))
495 {
496 // Size greather than maximum defined in spec
497 DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str());
498 this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
499 return 1;
Kevin Chengacb550f2021-06-29 15:32:19 -0700500 }
501 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000502 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700503
Jerry Gee5cabbf2023-07-17 21:33:17 +0000504 TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
505 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700506
Jerry Gee5cabbf2023-07-17 21:33:17 +0000507 DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
508 if (tensor->allocate())
509 {
510 FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
511 }
512
Tai Ly8690a082023-12-18 20:40:24 +0000513 // set valid for constant tensors:
514 if ((ts->GetShape().empty() && ts->GetDtype() == DType_SHAPE))
515 {
516 // corner case: const_shape {} has no data
517 tensor->setIsValid();
518 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000519 if (!ts->GetData().empty())
520 {
Tai Lycf84bc92023-09-07 20:49:09 +0000521 if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy)
522 return 0;
Jerry Gee5cabbf2023-07-17 21:33:17 +0000523 DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
524 auto serialization_dtype = ts->GetDtype();
525 switch (serialization_dtype)
Kevin Chengcc61be32021-10-14 17:09:57 -0700526 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000527 case DType_INT4: {
528 std::vector<int8_t> i4_data;
529 TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
530 std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
531 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
Eric Kunzee5e26762020-10-13 16:11:07 -0700532 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000533 break;
534 case DType_INT8: {
535 std::vector<int8_t> i8_data;
536 TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
537 std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
538 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
539 }
540 break;
541 case DType_INT16: {
542 std::vector<int16_t> i16_data;
543 TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
544 std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
545 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
546 }
547 break;
548 case DType_INT32: {
549 std::vector<int32_t> i32_data;
550 TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
551 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
552 }
553 break;
Tai Ly8690a082023-12-18 20:40:24 +0000554 case DType_INT48: {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000555 std::vector<int64_t> i64_data;
556 TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
557 tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
558 }
559 break;
Tai Ly8690a082023-12-18 20:40:24 +0000560 case DType_SHAPE: {
561 std::vector<int64_t> i64_data;
562 TosaSerializationHandler::ConvertU8toI64(ts->GetData(), tensor->getElementCount(), i64_data);
563 tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
564 }
565 break;
Jerry Gee5cabbf2023-07-17 21:33:17 +0000566 case DType_FP16: {
Jerry Ge68780532024-02-26 14:57:45 -0800567 std::vector<half_float::half> f16_data;
Jerry Gee5cabbf2023-07-17 21:33:17 +0000568 TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
569 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
570 {
571 std::vector<double> f64_data(f16_data.begin(), f16_data.end());
572 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
573 }
574 else
575 {
Jerry Ge68780532024-02-26 14:57:45 -0800576 std::vector<float> f32_data(f16_data.begin(), f16_data.end());
577 tensor->setTensorValueFloat(f32_data.size(), f32_data.data());
Jerry Gee5cabbf2023-07-17 21:33:17 +0000578 }
579 }
580 break;
581 case DType_BF16: {
582 std::vector<float> fp32_data;
583 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
584 // Ensure valid bfloat16 stored in each float
585 for (auto f : fp32_data)
586 ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
587 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
588 {
589 std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
590 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
591 }
592 else
593 {
594 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
595 }
596 }
597 break;
Won Jeon2c34b462024-02-06 18:37:00 +0000598 case DType_FP8E4M3:
599 case DType_FP8E5M2: {
600 std::vector<float> fp32_data;
601 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
602 // Ensure valid fp8 stored in each float
603 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
604 {
605 std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
606 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
607 }
608 else
609 {
610 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
611 }
612 }
613 break;
Jerry Gee5cabbf2023-07-17 21:33:17 +0000614 case DType_FP32: {
615 std::vector<float> fp32_data;
616 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
617 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
618 {
619 std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
620 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
621 }
622 else
623 {
624 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
625 }
626 }
627 break;
628 case DType_BOOL: {
629 std::vector<bool> bool_data;
630 TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
631
632 // std::vector<bool>::data() will return bit mask instead of array of bool array.
633 // Need to translate manually.
634 bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
635 for (size_t i = 0; i < bool_data.size(); i++)
636 {
637 bool_array[i] = bool_data[i];
638 }
639 tensor->setTensorValueBool(bool_data.size(), bool_array);
640 }
641 break;
642 default:
643 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
644 EnumNameDType(ts->GetDtype()));
Eric Kunzee5e26762020-10-13 16:11:07 -0700645 }
Tai Lycf84bc92023-09-07 20:49:09 +0000646 tensor->setIsValid();
Tai Ly8690a082023-12-18 20:40:24 +0000647 }
648
649 if (tensor->getIsValid())
650 {
Tai Lycf84bc92023-09-07 20:49:09 +0000651 // Push ready consumers to the next node list
652 for (auto gn : tensor->getConsumers())
653 {
654 if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated())
655 {
656 addToNextNodeList(gn);
657 }
658 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700659 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700660 return 0;
661}
662
663int SubgraphTraverser::isFullyEvaluated() const
664{
665 return nextNodeList.empty();
666}
667
668GraphNode* SubgraphTraverser::getNextNode()
669{
670 GraphNode* nextNode = nextNodeList.front();
671 ASSERT_MSG(nextNode, "SubgraphTraverser::getNextNode(): called with empty next node list");
672 ASSERT_MSG(nextNode->getOnNextNodeList(),
673 "SubgraphTraverser::getNextNode(): internal state error: node is not listed as being on next node list");
674
675 nextNodeList.pop_front();
676
677 nextNode->clearOnNextNodeList();
678 return nextNode;
679}
680
681int SubgraphTraverser::addToNextNodeList(GraphNode* nextNode)
682{
683 ASSERT_MSG(nextNode, "SubgraphTraverser::addToNextNodeList(): called with no node");
684 ASSERT_MSG(!nextNode->getOnNextNodeList(),
685 "SubgraphTraverser::addToNextNodeList(): internal state error: node is already on next node list");
686
687 nextNode->setOnNextNodeList();
688 nextNodeList.push_back(nextNode);
689
690 return 0;
691}
692
693int SubgraphTraverser::evaluateNextNode()
694{
695 if (isFullyEvaluated())
696 return 0;
697
698 GraphNode* currNode = getNextNode();
699
700 DEBUG_INFO(GT, "Evaluating node_%03lu, %8s, output tensor=%s", currNode->getID(), EnumNamesOp()[currNode->getOp()],
701 currNode->getOutputNames()[0].c_str());
702
703 // Sanity check for never-ending loops
704 if (currNode->getEvalCount() >= MAX_EVAL_COUNT && (currNode->getEvalCount() % MAX_EVAL_COUNT) == 0)
705 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700706 WARNING("SubgraphTraverser::evaluateNextNode(): Node %lu has been evaluated %d times. Loop suspected.",
707 currNode->getID(), currNode->getEvalCount());
Eric Kunzee5e26762020-10-13 16:11:07 -0700708 }
709
Kevin Cheng550ccc52021-03-03 11:21:43 -0800710 for (auto tensor : currNode->getOutputs())
Eric Kunzee5e26762020-10-13 16:11:07 -0700711 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800712 if (!tensor->is_allocated())
Jerry Gee5cabbf2023-07-17 21:33:17 +0000713 {
714 if (this->allocateTensor(tensor->getName()))
Eric Kunzee5e26762020-10-13 16:11:07 -0700715 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700716 FATAL_ERROR("SubgraphTraverser::evaluateNextNode(): Failed to allocate Eigen tensor %s",
717 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700718 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000719 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700720 }
721
722 if (currNode->eval())
723 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700724 WARNING("SubgraphTraverser::evaluateNextNode(): Failed to evaluate node: %lu", currNode->getID());
Kevin Chengacb550f2021-06-29 15:32:19 -0700725 return 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700726 }
727
Tai Lycf84bc92023-09-07 20:49:09 +0000728 currNode->setEvaluated();
729
Eric Kunzee5e26762020-10-13 16:11:07 -0700730 // free input tensor if all of its consumers have all of their outputs ready and it's not block's output
Jerry Gee5cabbf2023-07-17 21:33:17 +0000731 for (auto tensor : currNode->getInputs())
732 {
733 bool in_use = false;
734
735 auto tensor_check = findTensorByName(tensor->getName());
736 if (tensor_check->getIsParentGraphOutput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700737 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000738 // if it's parent's block output tensor, we can't free it
739 continue;
740 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700741
Tai Lycf84bc92023-09-07 20:49:09 +0000742 if (tensor->getIsVariable())
743 {
744 // if tensor is a Variable, we cannot free it
745 continue;
746 }
747
Jerry Gee5cabbf2023-07-17 21:33:17 +0000748 for (auto node : tensor->getConsumers())
749 {
750 // If the node is inside a loop, the input tensor is still needed
751 if (!node->hasAllOutputsReady())
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000752 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000753 in_use = true;
Jerry Ge9e94af82022-10-27 09:57:00 -0700754 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700755 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000756
757 for (auto name : block->GetOutputs())
758 {
759 if (name == tensor->getName())
760 {
761 in_use = true;
762 }
763 }
764
765 if (!in_use)
766 {
767 tensor->deallocate();
768 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700769 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000770
Eric Kunzee5e26762020-10-13 16:11:07 -0700771 // Search the output tensors of this node to see if
772 // there are now new ready nodes available from completing this node
773 for (TosaReference::Tensor* tensor : currNode->getOutputs())
774 {
775 for (GraphNode* node : tensor->getConsumers())
776 {
Tai Lycf84bc92023-09-07 20:49:09 +0000777 if (!node->getOnNextNodeList() && node->hasAllInputsReady() && !node->getEvaluated())
Eric Kunzee5e26762020-10-13 16:11:07 -0700778 {
779 addToNextNodeList(node);
780 }
781 }
782 }
783
784 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
785 {
786 dumpNextNodeList(g_func_debug.func_debug_file);
787 }
788
789 if (g_func_config.dump_intermediates)
790 {
791 currNode->dumpNode(g_func_debug.func_debug_file);
792 for (auto outs : currNode->getOutputs())
793 {
794 outs->dumpTensorParams(g_func_debug.func_debug_file);
795 outs->dumpTensor(g_func_debug.func_debug_file);
796 fprintf(g_func_debug.func_debug_file, "\n");
797 }
798 }
799
800 return 0;
801}
802
803int SubgraphTraverser::dumpNextNodeList(FILE* out) const
804{
805
806 // Dump next node list
807 fprintf(out, "Next node list\n");
808
809 if (nextNodeList.empty())
810 {
811 fprintf(out, "<empty>\n");
812 }
813
814 for (auto gn : nextNodeList)
815 {
816 gn->dumpNode(out);
817 }
818
819 fprintf(out, "Done.\n");
820 return 0;
821}
822
823int SubgraphTraverser::clearAllNodeMarkings()
824{
825 for (GraphNode* currNode : nodes)
826 {
827 currNode->clearNodeMarked();
828 }
829
830 return false;
831}
832
Tai Lycf84bc92023-09-07 20:49:09 +0000833int SubgraphTraverser::addTensor(const TosaSerializationTensor* ts)
Eric Kunzee5e26762020-10-13 16:11:07 -0700834{
Tai Lycf84bc92023-09-07 20:49:09 +0000835 TosaReference::Tensor* tensor = nullptr;
836
837 // variable tensors are shared: make new tensor only if not found
838 if (ts->GetVariable())
839 {
840 tensor = getVariableTensorByName(ts->GetName());
841 }
842
843 if (!tensor)
844 {
845 DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
846 tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
847
848 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
849 ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
850
851 if (ts->GetVariable())
852 {
853 tensor->setIsVariable();
854 registerVariableTensor(tensor);
855 }
856 }
857
Eric Kunzee5e26762020-10-13 16:11:07 -0700858 // Enforce no duplicate tensors/tensor names
859 // O(N), but the number of tensors is small
860 for (TosaReference::Tensor* currTensor : tensors)
861 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800862 if (tensor == currTensor || currTensor->getName() == tensor->getName())
Eric Kunzee5e26762020-10-13 16:11:07 -0700863 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700864 FATAL_ERROR("SubgraphTraverser::addTensor(): Duplicate tensor or tensor name being added to graph: %s\n",
865 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700866 return 1;
867 }
868 }
869
Kevin Cheng550ccc52021-03-03 11:21:43 -0800870 tensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700871
Kevin Cheng550ccc52021-03-03 11:21:43 -0800872 if (tensor->getIsSubgraphInput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700873 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800874 inputTensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700875 }
876
Kevin Cheng550ccc52021-03-03 11:21:43 -0800877 if (tensor->getIsSubgraphOutput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700878 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800879 outputTensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700880 }
881
882 return 0;
883}
884int SubgraphTraverser::addNode(GraphNode* newNode)
885{
886 // Enforce no duplicate nodes
887 for (GraphNode* currNode : nodes)
888 {
889 if (currNode == newNode)
890 {
Tai Lycf84bc92023-09-07 20:49:09 +0000891 FATAL_ERROR("SubgraphTraverser::addNode(): duplicate node being added to graph");
Eric Kunzee5e26762020-10-13 16:11:07 -0700892 return 1;
893 }
894 }
895
896 nodes.push_back(newNode);
897
898 return 0;
899}
900
901TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
902{
Jerry Ge9e94af82022-10-27 09:57:00 -0700903 TosaReference::Tensor* res_tensor = nullptr;
904
Eric Kunzee5e26762020-10-13 16:11:07 -0700905 for (TosaReference::Tensor* currTensor : tensors)
906 {
907 if (currTensor->getName() == name)
908 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700909 res_tensor = currTensor;
910 return res_tensor;
Eric Kunzee5e26762020-10-13 16:11:07 -0700911 }
912 }
913
Jerry Ge9e94af82022-10-27 09:57:00 -0700914 if (parent_sgt)
915 {
916 for (TosaReference::Tensor* currTensor : parent_sgt->tensors)
917 {
918 if (currTensor->getName() == name)
919 {
920 res_tensor = currTensor;
921 res_tensor->setIsParentGraphOutput();
922 }
923 }
924 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700925
Jerry Ge9e94af82022-10-27 09:57:00 -0700926 if (!res_tensor)
927 {
928 WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
929 return nullptr;
930 }
931 return res_tensor;
Eric Kunzee5e26762020-10-13 16:11:07 -0700932}
933
934int SubgraphTraverser::linkTensorsAndNodes()
935{
936 // Nodes have a list of input/output tensor names
937 // For each node, read this list, link up the tensors with their inputs/outputs
938 for (GraphNode* currNode : nodes)
939 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700940 // Link inputs/consuming nodes
941 for (std::string& name : currNode->getInputNames())
942 {
943 TosaReference::Tensor* t = findTensorByName(name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700944 SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
945 name.c_str(), currNode->getID());
946 SUBGRAPH_ERROR_IF(currNode->addInputTensor(t),
947 "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
948 name.c_str(), currNode->getID());
949 SUBGRAPH_ERROR_IF(t->addConsumer(currNode),
950 "SubgraphTraverser::linkTensorsAndNodes(): cannot link consumer node %lu to tensor %s\n",
951 currNode->getID(), name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700952 }
953
954 // Link outputs/producing nodes
955 for (std::string& name : currNode->getOutputNames())
956 {
957 TosaReference::Tensor* t = findTensorByName(name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700958 SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
959 name.c_str(), currNode->getID());
960 SUBGRAPH_ERROR_IF(currNode->addOutputTensor(t),
961 "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
962 name.c_str(), currNode->getID());
Eric Kunzee5e26762020-10-13 16:11:07 -0700963
Kevin Cheng903763c2021-09-28 16:14:52 -0700964 SUBGRAPH_ERROR_IF(
965 t->setProducer(currNode),
966 "SubgraphTraverser::linkTensorsAndNodes(): cannot link producer node %lu to tensor tensor %s\n",
967 currNode->getID(), name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700968 }
969 }
970
971 return 0;
972}
973
974int SubgraphTraverser::validateGraph()
975{
976 // Need to make sure that:
977 // - each tensor is actually used
978 // - input and output tesnsors truly are just input and just output
979 // Graph building already determined that each node has found its input/output tensors
980
981 for (TosaReference::Tensor* currTensor : tensors)
982 {
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700983 // It's okay for block input tensor not being consumed by operators.
984 // This is common in control flow op execution.
985 if (!currTensor->getIsSubgraphInput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700986 {
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700987 if (!currTensor->getProducer() && currTensor->getConsumers().empty())
Eric Kunzee5e26762020-10-13 16:11:07 -0700988 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700989 WARNING("SubgraphTraverser::validateGraph(): TosaReference::Tensor %s has no producers or consumers\n",
Eric Kunzee5e26762020-10-13 16:11:07 -0700990 currTensor->getName().c_str());
991 return 1;
992 }
993 }
994
Eric Kunzee5e26762020-10-13 16:11:07 -0700995 if (g_func_config.tosa_profile == 0)
996 {
Tai Lya4d748b2023-03-28 22:06:56 +0000997 TOSA_REF_TYPE dtype = currTensor->getDtype();
Eric Kunzee5e26762020-10-13 16:11:07 -0700998
999 // Float-point disallowed
Tai Lya4d748b2023-03-28 22:06:56 +00001000 if (dtype == TOSA_REF_TYPE_FP32 || dtype == TOSA_REF_TYPE_FP16)
Eric Kunzee5e26762020-10-13 16:11:07 -07001001 {
Kevin Cheng903763c2021-09-28 16:14:52 -07001002 WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
1003 "disabled, but %s tensor %s found\n",
Tai Lya4d748b2023-03-28 22:06:56 +00001004 EnumNameTOSAREFTYPE(dtype), currTensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -07001005 return 1;
1006 }
1007 }
1008 else if (g_func_config.tosa_profile == 1 || g_func_config.tosa_profile == 2)
1009 {
1010 // Do nothing. All FP types allowed
1011 // Currently no implementation difference between Main Inference and Main Training modes
1012 }
1013 else
1014 {
Kevin Cheng903763c2021-09-28 16:14:52 -07001015 FATAL_ERROR("SubgraphTraverser::validateGraph(): TOSA profile not recognized: %d",
1016 g_func_config.tosa_profile);
Eric Kunzee5e26762020-10-13 16:11:07 -07001017 }
1018 }
1019
1020 for (GraphNode* currNode : nodes)
1021 {
Kevin Cheng903763c2021-09-28 16:14:52 -07001022 SUBGRAPH_ERROR_IF(currNode->checkTensorAttributes(),
1023 "SubgraphTraverser::validateGraph(): TosaReference::Tensor attribute check failed");
Eric Kunzee5e26762020-10-13 16:11:07 -07001024 }
1025
1026 if (outputTensors.size() <= 0)
1027 {
1028 DEBUG_MED(GT, "Graph output tensor empty");
1029 return 0;
1030 }
1031
1032 return 0;
1033}
1034
1035int SubgraphTraverser::dumpGraph(FILE* out) const
1036{
1037 int i = 0;
1038
1039 fprintf(out, "Full graph dump:\n");
1040 for (GraphNode* currNode : nodes)
1041 {
1042 fprintf(out, "Node [%d]: ", i++);
1043 currNode->dumpNode(out);
1044 }
1045
1046 return 0;
1047}
1048
1049int SubgraphTraverser::evaluateAll()
1050{
1051 // evaluation loop
1052 while (!isFullyEvaluated())
1053 {
1054 if (evaluateNextNode())
1055 {
1056 return 1;
1057 }
1058 }
1059
1060 return 0;
1061}