blob: 745213e7195a23b268c9202d37c6d305f2f3a2fe [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "subgraph_traverser.h"
James Ward24dbc422022-10-19 12:20:31 +010017#include "arith_util.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000018#include "tosa_model_types.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070019
#ifndef SUBGRAPH_ERROR_IF
// SUBGRAPH_ERROR_IF(COND, fmt, ...)
//
// If COND is true: set the graph status to TOSA_ERROR (unless it is already
// TOSA_UNPREDICTABLE, which is kept), print the failing condition together with
// its file/line/function, then the printf-style message `fmt, ...`, then a
// backtrace to g_func_debug.func_debug_file, and finally `return 1;` from the
// calling function. Used inside SubgraphTraverser member functions returning int.
//
// The do { ... } while (0) wrapper makes the expansion a single statement, so it
// composes safely with un-braced if/else and requires a trailing ';' at the call site.
#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
    do \
    { \
        if ((COND)) \
        { \
            if (this->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \
            { \
                this->setGraphStatus(GraphStatus::TOSA_ERROR); \
            } \
            fprintf(g_func_debug.func_debug_file, COL_FATAL("SUBGRAPH_ERROR_IF() fails AT %s:%d %s(): (%s)\n"), \
                    __FILE__, __LINE__, __func__, #COND); \
            fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
            func_print_backtrace(g_func_debug.func_debug_file); \
            return 1; \
        } \
    } while (0)
#endif
35
Eric Kunzee5e26762020-10-13 16:11:07 -070036using namespace TosaReference;
37using namespace Eigen;
38using namespace tosa;
39
Jerry Ge9c9c8da2023-07-19 23:08:16 +000040SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block,
41 TosaSerializationHandler* _tsh,
42 SubgraphTraverser* _parent_sgt)
Eric Kunzee5e26762020-10-13 16:11:07 -070043{
Jerry Ge9e94af82022-10-27 09:57:00 -070044
Kevin Chengacb550f2021-06-29 15:32:19 -070045 graph_status = GraphStatus::TOSA_VALID;
Jerry Ge9c9c8da2023-07-19 23:08:16 +000046 block = _block;
Eric Kunzee5e26762020-10-13 16:11:07 -070047
Jerry Ge9c9c8da2023-07-19 23:08:16 +000048 tsh = _tsh;
Jerry Ge9e94af82022-10-27 09:57:00 -070049 parent_sgt = _parent_sgt;
Eric Kunzee5e26762020-10-13 16:11:07 -070050 tensors.clear();
51 nodes.clear();
52 nextNodeList.clear();
53}
54
55SubgraphTraverser::~SubgraphTraverser()
56{
57 nextNodeList.clear();
58
59 for (GraphNode* n : nodes)
60 {
61 delete n;
62 }
63 nodes.clear();
64
65 for (TosaReference::Tensor* t : tensors)
66 {
Tai Ly47625642023-09-07 20:49:09 +000067 if (t->getIsVariable() && parent_sgt)
68 {
69 // variable tensors are owned by top level sgt
70 continue;
71 }
Eric Kunzee5e26762020-10-13 16:11:07 -070072 if (t->is_allocated())
73 {
74 t->deallocate();
75 }
76 delete t;
77 }
78 tensors.clear();
79}
80
81int SubgraphTraverser::getNumInputTensors() const
82{
83 return inputTensors.size();
84}
85
86TosaReference::Tensor* SubgraphTraverser::getInputTensor(const unsigned int idx) const
87{
88 return inputTensors[idx];
89}
90
91TosaReference::Tensor* SubgraphTraverser::getInputTensorByName(const std::string name) const
92{
93 for (auto t : inputTensors)
94 {
95 if (t->getName() == name)
96 {
97 return t;
98 }
99 }
100
101 return nullptr;
102}
103
104int SubgraphTraverser::getNumOutputTensors() const
105{
106 return outputTensors.size();
107}
108
109TosaReference::Tensor* SubgraphTraverser::getOutputTensor(const unsigned int idx) const
110{
111 return outputTensors[idx];
112}
113
114TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::string name) const
115{
116 for (auto t : outputTensors)
117 {
118 if (t->getName() == name)
119 {
120 return t;
121 }
122 }
123
124 return nullptr;
125}
126
Tai Ly47625642023-09-07 20:49:09 +0000127int SubgraphTraverser::getNumVariableTensors() const
128{
129 return variableTensors.size();
130}
131
132TosaReference::Tensor* SubgraphTraverser::getVariableTensor(const unsigned int idx) const
133{
134 return variableTensors[idx];
135}
136
137// find variable tensor by name in top level sgt's @a variableTensors
138TosaReference::Tensor* SubgraphTraverser::getVariableTensorByName(const std::string name) const
139{
140 // variable tensors are owned by top level sgt
141 if (parent_sgt)
142 {
143 return parent_sgt->getVariableTensorByName(name);
144 }
145
146 for (auto t : variableTensors)
147 {
148 if (t->getName() == name)
149 {
150 return t;
151 }
152 }
153
154 return nullptr;
155}
156
157// add variable tensor to top level sgt's @a variableTensors
158int SubgraphTraverser::registerVariableTensor(Tensor* tensor)
159{
160 SUBGRAPH_ERROR_IF(!tensor->getIsVariable(),
161 "SubgraphTraverser::registerVariableTensor(): tensor %s is not a variable",
162 tensor->getName().c_str());
163 // variable tensors are owned by top level sgt
164 if (parent_sgt)
165 {
166 return parent_sgt->registerVariableTensor(tensor);
167 }
168 variableTensors.push_back(tensor);
169 return 0;
170}
171
Eric Kunzee5e26762020-10-13 16:11:07 -0700172int SubgraphTraverser::initializeGraph()
173{
Eric Kunzee5e26762020-10-13 16:11:07 -0700174 int idx = 0;
Kevin Chengc72b59c2021-09-29 16:57:55 -0700175
Jerry Ge9e94af82022-10-27 09:57:00 -0700176 std::vector<TosaSerializationTensor*> ser_tensor_vec;
177 // Get all the serialized tensors from TosaSerializationHandler.
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100178 if (tsh)
Jerry Ge9e94af82022-10-27 09:57:00 -0700179 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100180 for (auto region : tsh->GetRegions())
Jerry Ge9e94af82022-10-27 09:57:00 -0700181 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100182 for (auto block : region->GetBlocks())
Tai Ly4e9a9772023-03-16 22:24:05 +0000183 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100184 for (auto ser_tensor : block->GetTensors())
185 {
186 ser_tensor_vec.push_back(ser_tensor);
187 }
Tai Ly4e9a9772023-03-16 22:24:05 +0000188 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700189 }
190 }
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100191 else
192 {
193 for (auto ser_tensor : block->GetTensors())
194 {
195 ser_tensor_vec.push_back(ser_tensor);
196 }
197 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700198
199 std::vector<GraphNode*> non_const_node_vec;
Eric Kunzee5e26762020-10-13 16:11:07 -0700200 for (auto op : block->GetOperators())
201 {
202 // translated TosaSerializationOperator to GraphNode
Tai Lya4d748b2023-03-28 22:06:56 +0000203 TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN;
204 TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN;
205 TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000206 uint32_t input_rank = 0;
207 uint32_t output_rank = 0;
208 uint32_t weight_rank = 0;
209 int32_t input_index = -1;
210 int32_t weight_index = -1;
Kevin Cheng550ccc52021-03-03 11:21:43 -0800211
212 switch (op->GetOp())
Eric Kunzee5e26762020-10-13 16:11:07 -0700213 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800214 case Op_CONV2D:
Kevin Cheng1533b852021-09-01 12:51:58 -0700215 case Op_CONV3D:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800216 case Op_DEPTHWISE_CONV2D:
217 case Op_TRANSPOSE_CONV2D:
218 case Op_FULLY_CONNECTED:
219 input_index = 0;
220 weight_index = 1;
221 break;
222 case Op_SELECT:
223 input_index = 1;
224 break;
225 default:
226 if (!op->GetInputTensorNames().empty())
227 input_index = 0;
228 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700229 }
230
Kevin Cheng550ccc52021-03-03 11:21:43 -0800231 if (input_index != -1)
Kevin Chengdf862692021-02-22 15:22:22 -0800232 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700233 SUBGRAPH_ERROR_IF(
234 (size_t)input_index >= op->GetInputTensorNames().size(),
235 "SubgraphTraverser::initializeGraph(): Op=%s, input_index %d must be within [0, num_input - 1]",
236 EnumNamesOp()[op->GetOp()], input_index);
Kevin Chengdf862692021-02-22 15:22:22 -0800237
Kevin Cheng550ccc52021-03-03 11:21:43 -0800238 std::string input_name = op->GetInputTensorNames()[input_index];
Jerry Ge9e94af82022-10-27 09:57:00 -0700239 TosaSerializationTensor* input_tensor = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000240 for (auto ser_tensor : ser_tensor_vec)
241 {
242 if (ser_tensor->GetName() == input_name)
243 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700244 input_tensor = ser_tensor;
245 }
246 }
247
Kevin Cheng903763c2021-09-28 16:14:52 -0700248 SUBGRAPH_ERROR_IF(
249 !input_tensor,
250 "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
251 input_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000252 input_dtype = ConvertDType(input_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800253 input_rank = input_tensor->GetShape().size();
Kevin Chengdf862692021-02-22 15:22:22 -0800254 }
255
Kevin Cheng550ccc52021-03-03 11:21:43 -0800256 if (weight_index != -1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700257 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700258 SUBGRAPH_ERROR_IF(
259 (size_t)weight_index >= op->GetInputTensorNames().size(),
260 "SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
261 EnumNamesOp()[op->GetOp()], weight_index);
Kevin Cheng550ccc52021-03-03 11:21:43 -0800262 std::string weight_name = op->GetInputTensorNames()[weight_index];
Jerry Ge9e94af82022-10-27 09:57:00 -0700263 TosaSerializationTensor* weight_tensor = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000264 for (auto ser_tensor : ser_tensor_vec)
265 {
266 if (ser_tensor->GetName() == weight_name)
267 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700268 weight_tensor = ser_tensor;
269 }
270 }
271
Kevin Cheng903763c2021-09-28 16:14:52 -0700272 SUBGRAPH_ERROR_IF(
273 !weight_tensor,
274 "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
275 weight_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000276 weight_dtype = ConvertDType(weight_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800277 weight_rank = weight_tensor->GetShape().size();
Eric Kunzee5e26762020-10-13 16:11:07 -0700278 }
279
Kevin Cheng478101b2021-10-04 10:43:14 -0700280 SUBGRAPH_ERROR_IF(op->GetOutputTensorNames().size() == 0,
281 "SubgraphTraverser::initializeGraph(): Op=%s must have at least one output tensor.",
282 EnumNamesOp()[op->GetOp()]);
Kevin Cheng550ccc52021-03-03 11:21:43 -0800283 std::string output_name = op->GetOutputTensorNames()[0];
284 TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700285 SUBGRAPH_ERROR_IF(
286 !output_tensor,
287 "SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler",
288 output_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000289 output_dtype = ConvertDType(output_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800290 output_rank = output_tensor->GetShape().size();
291
Eric Kunzee5e26762020-10-13 16:11:07 -0700292 DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
293 EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
294
Jerry Ge9e94af82022-10-27 09:57:00 -0700295 GraphNode* node = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000296 if (this->parent_sgt)
297 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700298 node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000299 input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
Jerry Ge9e94af82022-10-27 09:57:00 -0700300 node->setInMainBlock(false);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000301 }
302 else
303 {
304 node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank,
305 output_dtype, output_rank, weight_dtype, weight_rank);
306 if (node)
307 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700308 node->setInMainBlock(true);
309 }
310 }
311
Kevin Cheng550ccc52021-03-03 11:21:43 -0800312 if (!node)
Eric Kunzee5e26762020-10-13 16:11:07 -0700313 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800314 if (weight_index == -1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700315 {
316 fprintf(g_func_debug.func_debug_file,
Kevin Cheng903763c2021-09-28 16:14:52 -0700317 "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) "
318 "-> (%s rank %d)",
Tai Lya4d748b2023-03-28 22:06:56 +0000319 EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
320 EnumNameTOSAREFTYPE(output_dtype), output_rank);
Eric Kunzee5e26762020-10-13 16:11:07 -0700321 }
322 else
323 {
324 fprintf(g_func_debug.func_debug_file,
Kevin Cheng903763c2021-09-28 16:14:52 -0700325 "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), "
326 "weight=(%s rank %d) -> (%s rank %d)",
Tai Lya4d748b2023-03-28 22:06:56 +0000327 EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
328 EnumNameTOSAREFTYPE(weight_dtype), weight_rank, EnumNameTOSAREFTYPE(output_dtype), output_rank);
Eric Kunzee5e26762020-10-13 16:11:07 -0700329 }
330
Kevin Cheng550ccc52021-03-03 11:21:43 -0800331 for (auto& ts : op->GetInputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700332 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700333 fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Input: %s\n", ts.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700334 }
335
Kevin Cheng550ccc52021-03-03 11:21:43 -0800336 for (auto& ts : op->GetOutputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700337 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700338 fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Output: %s\n", ts.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700339 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700340 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported operation type or rank.");
Eric Kunzee5e26762020-10-13 16:11:07 -0700341 }
342
Kevin Chengc72b59c2021-09-29 16:57:55 -0700343 // Elementwise operator might set TOSA_ERROR when registering lambda function when creating the op.
344 // Check graph status after the op being constructed.
345 SUBGRAPH_ERROR_IF(getGraphStatus() == GraphStatus::TOSA_ERROR,
346 "SubgraphTraverser::initializeGraph(): Op %8s triggered ERROR_IF() when constructing the op.",
347 EnumNamesOp()[op->GetOp()]);
348
Kevin Cheng550ccc52021-03-03 11:21:43 -0800349 for (auto& name : op->GetInputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700350 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800351 node->addInputName(name);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700352 used_tensor_name_set.insert(name);
Eric Kunzee5e26762020-10-13 16:11:07 -0700353 }
354
355 for (auto name : op->GetOutputTensorNames())
356 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800357 node->addOutputName(name);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700358 used_tensor_name_set.insert(name);
Eric Kunzee5e26762020-10-13 16:11:07 -0700359 }
360
Kevin Cheng550ccc52021-03-03 11:21:43 -0800361 addNode(node);
Eric Kunzee5e26762020-10-13 16:11:07 -0700362
363 // if node doesn't have any inputs (i.e. CONST)
364 // it should be ready for evaluation
Kevin Cheng550ccc52021-03-03 11:21:43 -0800365 if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
Eric Kunzee5e26762020-10-13 16:11:07 -0700366 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800367 addToNextNodeList(node);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000368 }
369 else if (!node->getInMainBlock())
370 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700371 non_const_node_vec.push_back(node);
Eric Kunzee5e26762020-10-13 16:11:07 -0700372 }
373
Tai Ly47625642023-09-07 20:49:09 +0000374 // Bug fix: add the ready node in main block for evaluation
375 if (node->hasAllInputsReady() && !node->getOnNextNodeList() && !node->getEvaluated())
376 {
377 addToNextNodeList(node);
378 }
379
Eric Kunzee5e26762020-10-13 16:11:07 -0700380 idx++;
381 }
382
383 for (auto ts : block->GetTensors())
384 {
Tai Ly47625642023-09-07 20:49:09 +0000385 addTensor(ts);
Kevin Chengcc61be32021-10-14 17:09:57 -0700386 }
387
388 DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
389 for (auto& input_name : block->GetInputs())
390 {
391 TosaReference::Tensor* tensor = findTensorByName(input_name);
392 DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
393 if (tensor)
394 {
395 tensor->setIsSubgraphInput();
396 inputTensors.push_back(tensor);
397 }
398 else
399 {
400 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
401 input_name.c_str());
402 }
403 }
404
405 DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
406 for (auto& output_name : block->GetOutputs())
407 {
408 TosaReference::Tensor* tensor = findTensorByName(output_name);
Jerry Ge9e94af82022-10-27 09:57:00 -0700409 DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str());
Kevin Chengcc61be32021-10-14 17:09:57 -0700410 if (tensor)
411 {
412 tensor->setIsSubgraphOutput();
413 outputTensors.push_back(tensor);
414 }
415 else
416 {
417 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
418 output_name.c_str());
419 }
420 }
421
422 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
423 {
424 dumpNextNodeList(g_func_debug.func_debug_file);
425 }
426
Jerry Ge9e94af82022-10-27 09:57:00 -0700427 // If the node is not in mainblock and not const
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000428 for (auto node : non_const_node_vec)
429 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700430 bool all_inputs_from_parent = true;
431 for (std::string& name : node->getInputNames())
432 {
433 TosaReference::Tensor* t = findTensorByName(name);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000434 if (!t->getIsParentGraphOutput())
435 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700436 all_inputs_from_parent = false;
437 }
438 }
439 // In the children block, when a node has all its inputs from parent
440 // block, we have to manually add this node to the evaluation list
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000441 if (all_inputs_from_parent && !node->getOnNextNodeList())
442 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700443 addToNextNodeList(node);
444 }
445 }
Kevin Chengcc61be32021-10-14 17:09:57 -0700446 return 0;
447}
448
Jerry Gee5cabbf2023-07-17 21:33:17 +0000449int SubgraphTraverser::allocateInputTensors()
Kevin Chengcc61be32021-10-14 17:09:57 -0700450{
Jerry Gee5cabbf2023-07-17 21:33:17 +0000451 auto input_tensor_names_vec = block->GetInputs();
452
453 for (auto input_tensor_name : input_tensor_names_vec)
Kevin Chengcc61be32021-10-14 17:09:57 -0700454 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000455 this->allocateTensor(input_tensor_name);
456 }
457
Tai Ly47625642023-09-07 20:49:09 +0000458 // allocate variable tensors if not already allocated
459 for (auto ts : block->GetTensors())
460 {
461 if (ts->GetVariable())
462 {
463 TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
464 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateInputTensors(): can't find tensor %s.",
465 ts->GetName().c_str());
466 if (!tensor->is_allocated())
467 {
468 DEBUG_INFO(GT, "Is a VariableTensor %s", ts->GetName().c_str());
469 this->allocateTensor(ts->GetName());
470 }
471 }
472 }
473
Jerry Gee5cabbf2023-07-17 21:33:17 +0000474 return 0;
475}
476
477int SubgraphTraverser::allocateTensor(std::string name)
478{
479 auto ts = block->GetTensorByName(name);
480
481 // Bail out if tensor is used and any of its dimension is invalid.
482 auto got = used_tensor_name_set.find(ts->GetName());
483 if (got != used_tensor_name_set.end())
484 {
485 uint32_t elements = 1;
486 for (auto& dim : ts->GetShape())
Kevin Chengacb550f2021-06-29 15:32:19 -0700487 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000488 if (dim <= 0)
Kevin Chengacb550f2021-06-29 15:32:19 -0700489 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000490 DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim);
491 this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
492 return 1;
493 }
494 if (dim > static_cast<int32_t>(TOSA_MAX_TENSOR_SIZE / elements))
495 {
496 // Size greather than maximum defined in spec
497 DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str());
498 this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
499 return 1;
Kevin Chengacb550f2021-06-29 15:32:19 -0700500 }
501 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000502 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700503
Jerry Gee5cabbf2023-07-17 21:33:17 +0000504 TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
505 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700506
Jerry Gee5cabbf2023-07-17 21:33:17 +0000507 DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
508 if (tensor->allocate())
509 {
510 FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
511 }
512
513 if (!ts->GetData().empty())
514 {
Tai Ly47625642023-09-07 20:49:09 +0000515 if (ts->GetVariable() && g_func_config.initialize_variable_tensor_from_numpy)
516 return 0;
Jerry Gee5cabbf2023-07-17 21:33:17 +0000517 DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
518 auto serialization_dtype = ts->GetDtype();
519 switch (serialization_dtype)
Kevin Chengcc61be32021-10-14 17:09:57 -0700520 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000521 case DType_INT4: {
522 std::vector<int8_t> i4_data;
523 TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
524 std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
525 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
Eric Kunzee5e26762020-10-13 16:11:07 -0700526 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000527 break;
528 case DType_INT8: {
529 std::vector<int8_t> i8_data;
530 TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
531 std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
532 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
533 }
534 break;
535 case DType_INT16: {
536 std::vector<int16_t> i16_data;
537 TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
538 std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
539 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
540 }
541 break;
542 case DType_INT32: {
543 std::vector<int32_t> i32_data;
544 TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
545 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
546 }
547 break;
Won Jeona21b2e82023-08-10 10:33:01 +0000548 case DType_INT48:
549 case DType_SHAPE: {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000550 std::vector<int64_t> i64_data;
551 TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
552 tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
553 }
554 break;
555 case DType_FP16: {
556 // Interpret f16 data as float
557 std::vector<float> f16_data;
558 TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
559 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
560 {
561 std::vector<double> f64_data(f16_data.begin(), f16_data.end());
562 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
563 }
564 else
565 {
566 tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
567 }
568 }
569 break;
570 case DType_BF16: {
571 std::vector<float> fp32_data;
572 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
573 // Ensure valid bfloat16 stored in each float
574 for (auto f : fp32_data)
575 ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
576 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
577 {
578 std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
579 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
580 }
581 else
582 {
583 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
584 }
585 }
586 break;
587 case DType_FP32: {
588 std::vector<float> fp32_data;
589 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
590 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
591 {
592 std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
593 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
594 }
595 else
596 {
597 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
598 }
599 }
600 break;
601 case DType_BOOL: {
602 std::vector<bool> bool_data;
603 TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
604
605 // std::vector<bool>::data() will return bit mask instead of array of bool array.
606 // Need to translate manually.
607 bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
608 for (size_t i = 0; i < bool_data.size(); i++)
609 {
610 bool_array[i] = bool_data[i];
611 }
612 tensor->setTensorValueBool(bool_data.size(), bool_array);
613 }
614 break;
615 default:
616 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
617 EnumNameDType(ts->GetDtype()));
Eric Kunzee5e26762020-10-13 16:11:07 -0700618 }
Tai Ly47625642023-09-07 20:49:09 +0000619 tensor->setIsValid();
620 // Push ready consumers to the next node list
621 for (auto gn : tensor->getConsumers())
622 {
623 if (gn->hasAllInputsReady() && !gn->getOnNextNodeList() && !gn->getEvaluated())
624 {
625 addToNextNodeList(gn);
626 }
627 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700628 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700629 return 0;
630}
631
632int SubgraphTraverser::isFullyEvaluated() const
633{
634 return nextNodeList.empty();
635}
636
637GraphNode* SubgraphTraverser::getNextNode()
638{
639 GraphNode* nextNode = nextNodeList.front();
640 ASSERT_MSG(nextNode, "SubgraphTraverser::getNextNode(): called with empty next node list");
641 ASSERT_MSG(nextNode->getOnNextNodeList(),
642 "SubgraphTraverser::getNextNode(): internal state error: node is not listed as being on next node list");
643
644 nextNodeList.pop_front();
645
646 nextNode->clearOnNextNodeList();
647 return nextNode;
648}
649
650int SubgraphTraverser::addToNextNodeList(GraphNode* nextNode)
651{
652 ASSERT_MSG(nextNode, "SubgraphTraverser::addToNextNodeList(): called with no node");
653 ASSERT_MSG(!nextNode->getOnNextNodeList(),
654 "SubgraphTraverser::addToNextNodeList(): internal state error: node is already on next node list");
655
656 nextNode->setOnNextNodeList();
657 nextNodeList.push_back(nextNode);
658
659 return 0;
660}
661
662int SubgraphTraverser::evaluateNextNode()
663{
664 if (isFullyEvaluated())
665 return 0;
666
667 GraphNode* currNode = getNextNode();
668
669 DEBUG_INFO(GT, "Evaluating node_%03lu, %8s, output tensor=%s", currNode->getID(), EnumNamesOp()[currNode->getOp()],
670 currNode->getOutputNames()[0].c_str());
671
672 // Sanity check for never-ending loops
673 if (currNode->getEvalCount() >= MAX_EVAL_COUNT && (currNode->getEvalCount() % MAX_EVAL_COUNT) == 0)
674 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700675 WARNING("SubgraphTraverser::evaluateNextNode(): Node %lu has been evaluated %d times. Loop suspected.",
676 currNode->getID(), currNode->getEvalCount());
Eric Kunzee5e26762020-10-13 16:11:07 -0700677 }
678
Kevin Cheng550ccc52021-03-03 11:21:43 -0800679 for (auto tensor : currNode->getOutputs())
Eric Kunzee5e26762020-10-13 16:11:07 -0700680 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800681 if (!tensor->is_allocated())
Jerry Gee5cabbf2023-07-17 21:33:17 +0000682 {
683 if (this->allocateTensor(tensor->getName()))
Eric Kunzee5e26762020-10-13 16:11:07 -0700684 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700685 FATAL_ERROR("SubgraphTraverser::evaluateNextNode(): Failed to allocate Eigen tensor %s",
686 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700687 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000688 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700689 }
690
691 if (currNode->eval())
692 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700693 WARNING("SubgraphTraverser::evaluateNextNode(): Failed to evaluate node: %lu", currNode->getID());
Kevin Chengacb550f2021-06-29 15:32:19 -0700694 return 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700695 }
696
Tai Ly47625642023-09-07 20:49:09 +0000697 currNode->setEvaluated();
698
Eric Kunzee5e26762020-10-13 16:11:07 -0700699 // free input tensor if all of its consumers have all of their outputs ready and it's not block's output
Jerry Gee5cabbf2023-07-17 21:33:17 +0000700 for (auto tensor : currNode->getInputs())
701 {
702 bool in_use = false;
703
704 auto tensor_check = findTensorByName(tensor->getName());
705 if (tensor_check->getIsParentGraphOutput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700706 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000707 // if it's parent's block output tensor, we can't free it
708 continue;
709 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700710
Tai Ly47625642023-09-07 20:49:09 +0000711 if (tensor->getIsVariable())
712 {
713 // if tensor is a Variable, we cannot free it
714 continue;
715 }
716
Jerry Gee5cabbf2023-07-17 21:33:17 +0000717 for (auto node : tensor->getConsumers())
718 {
719 // If the node is inside a loop, the input tensor is still needed
720 if (!node->hasAllOutputsReady())
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000721 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000722 in_use = true;
Jerry Ge9e94af82022-10-27 09:57:00 -0700723 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700724 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000725
726 for (auto name : block->GetOutputs())
727 {
728 if (name == tensor->getName())
729 {
730 in_use = true;
731 }
732 }
733
734 if (!in_use)
735 {
736 tensor->deallocate();
737 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700738 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000739
Eric Kunzee5e26762020-10-13 16:11:07 -0700740 // Search the output tensors of this node to see if
741 // there are now new ready nodes available from completing this node
742 for (TosaReference::Tensor* tensor : currNode->getOutputs())
743 {
744 for (GraphNode* node : tensor->getConsumers())
745 {
Tai Ly47625642023-09-07 20:49:09 +0000746 if (!node->getOnNextNodeList() && node->hasAllInputsReady() && !node->getEvaluated())
Eric Kunzee5e26762020-10-13 16:11:07 -0700747 {
748 addToNextNodeList(node);
749 }
750 }
751 }
752
753 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
754 {
755 dumpNextNodeList(g_func_debug.func_debug_file);
756 }
757
758 if (g_func_config.dump_intermediates)
759 {
760 currNode->dumpNode(g_func_debug.func_debug_file);
761 for (auto outs : currNode->getOutputs())
762 {
763 outs->dumpTensorParams(g_func_debug.func_debug_file);
764 outs->dumpTensor(g_func_debug.func_debug_file);
765 fprintf(g_func_debug.func_debug_file, "\n");
766 }
767 }
768
769 return 0;
770}
771
772int SubgraphTraverser::dumpNextNodeList(FILE* out) const
773{
774
775 // Dump next node list
776 fprintf(out, "Next node list\n");
777
778 if (nextNodeList.empty())
779 {
780 fprintf(out, "<empty>\n");
781 }
782
783 for (auto gn : nextNodeList)
784 {
785 gn->dumpNode(out);
786 }
787
788 fprintf(out, "Done.\n");
789 return 0;
790}
791
792int SubgraphTraverser::clearAllNodeMarkings()
793{
794 for (GraphNode* currNode : nodes)
795 {
796 currNode->clearNodeMarked();
797 }
798
799 return false;
800}
801
Tai Ly47625642023-09-07 20:49:09 +0000802int SubgraphTraverser::addTensor(const TosaSerializationTensor* ts)
Eric Kunzee5e26762020-10-13 16:11:07 -0700803{
Tai Ly47625642023-09-07 20:49:09 +0000804 TosaReference::Tensor* tensor = nullptr;
805
806 // variable tensors are shared: make new tensor only if not found
807 if (ts->GetVariable())
808 {
809 tensor = getVariableTensorByName(ts->GetName());
810 }
811
812 if (!tensor)
813 {
814 DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
815 tensor = TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
816
817 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
818 ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
819
820 if (ts->GetVariable())
821 {
822 tensor->setIsVariable();
823 registerVariableTensor(tensor);
824 }
825 }
826
Eric Kunzee5e26762020-10-13 16:11:07 -0700827 // Enforce no duplicate tensors/tensor names
828 // O(N), but the number of tensors is small
829 for (TosaReference::Tensor* currTensor : tensors)
830 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800831 if (tensor == currTensor || currTensor->getName() == tensor->getName())
Eric Kunzee5e26762020-10-13 16:11:07 -0700832 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700833 FATAL_ERROR("SubgraphTraverser::addTensor(): Duplicate tensor or tensor name being added to graph: %s\n",
834 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700835 return 1;
836 }
837 }
838
Kevin Cheng550ccc52021-03-03 11:21:43 -0800839 tensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700840
Kevin Cheng550ccc52021-03-03 11:21:43 -0800841 if (tensor->getIsSubgraphInput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700842 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800843 inputTensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700844 }
845
Kevin Cheng550ccc52021-03-03 11:21:43 -0800846 if (tensor->getIsSubgraphOutput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700847 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800848 outputTensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700849 }
850
851 return 0;
852}
853int SubgraphTraverser::addNode(GraphNode* newNode)
854{
855 // Enforce no duplicate nodes
856 for (GraphNode* currNode : nodes)
857 {
858 if (currNode == newNode)
859 {
Tai Ly47625642023-09-07 20:49:09 +0000860 FATAL_ERROR("SubgraphTraverser::addNode(): duplicate node being added to graph");
Eric Kunzee5e26762020-10-13 16:11:07 -0700861 return 1;
862 }
863 }
864
865 nodes.push_back(newNode);
866
867 return 0;
868}
869
870TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
871{
Jerry Ge9e94af82022-10-27 09:57:00 -0700872 TosaReference::Tensor* res_tensor = nullptr;
873
Eric Kunzee5e26762020-10-13 16:11:07 -0700874 for (TosaReference::Tensor* currTensor : tensors)
875 {
876 if (currTensor->getName() == name)
877 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700878 res_tensor = currTensor;
879 return res_tensor;
Eric Kunzee5e26762020-10-13 16:11:07 -0700880 }
881 }
882
Jerry Ge9e94af82022-10-27 09:57:00 -0700883 if (parent_sgt)
884 {
885 for (TosaReference::Tensor* currTensor : parent_sgt->tensors)
886 {
887 if (currTensor->getName() == name)
888 {
889 res_tensor = currTensor;
890 res_tensor->setIsParentGraphOutput();
891 }
892 }
893 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700894
Jerry Ge9e94af82022-10-27 09:57:00 -0700895 if (!res_tensor)
896 {
897 WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
898 return nullptr;
899 }
900 return res_tensor;
Eric Kunzee5e26762020-10-13 16:11:07 -0700901}
902
903int SubgraphTraverser::linkTensorsAndNodes()
904{
905 // Nodes have a list of input/output tensor names
906 // For each node, read this list, link up the tensors with their inputs/outputs
907 for (GraphNode* currNode : nodes)
908 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700909 // Link inputs/consuming nodes
910 for (std::string& name : currNode->getInputNames())
911 {
912 TosaReference::Tensor* t = findTensorByName(name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700913 SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
914 name.c_str(), currNode->getID());
915 SUBGRAPH_ERROR_IF(currNode->addInputTensor(t),
916 "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
917 name.c_str(), currNode->getID());
918 SUBGRAPH_ERROR_IF(t->addConsumer(currNode),
919 "SubgraphTraverser::linkTensorsAndNodes(): cannot link consumer node %lu to tensor %s\n",
920 currNode->getID(), name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700921 }
922
923 // Link outputs/producing nodes
924 for (std::string& name : currNode->getOutputNames())
925 {
926 TosaReference::Tensor* t = findTensorByName(name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700927 SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
928 name.c_str(), currNode->getID());
929 SUBGRAPH_ERROR_IF(currNode->addOutputTensor(t),
930 "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
931 name.c_str(), currNode->getID());
Eric Kunzee5e26762020-10-13 16:11:07 -0700932
Kevin Cheng903763c2021-09-28 16:14:52 -0700933 SUBGRAPH_ERROR_IF(
934 t->setProducer(currNode),
935 "SubgraphTraverser::linkTensorsAndNodes(): cannot link producer node %lu to tensor tensor %s\n",
936 currNode->getID(), name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700937 }
938 }
939
940 return 0;
941}
942
943int SubgraphTraverser::validateGraph()
944{
945 // Need to make sure that:
946 // - each tensor is actually used
947 // - input and output tesnsors truly are just input and just output
948 // Graph building already determined that each node has found its input/output tensors
949
950 for (TosaReference::Tensor* currTensor : tensors)
951 {
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700952 // It's okay for block input tensor not being consumed by operators.
953 // This is common in control flow op execution.
954 if (!currTensor->getIsSubgraphInput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700955 {
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700956 if (!currTensor->getProducer() && currTensor->getConsumers().empty())
Eric Kunzee5e26762020-10-13 16:11:07 -0700957 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700958 WARNING("SubgraphTraverser::validateGraph(): TosaReference::Tensor %s has no producers or consumers\n",
Eric Kunzee5e26762020-10-13 16:11:07 -0700959 currTensor->getName().c_str());
960 return 1;
961 }
962 }
963
Eric Kunzee5e26762020-10-13 16:11:07 -0700964 if (g_func_config.tosa_profile == 0)
965 {
Tai Lya4d748b2023-03-28 22:06:56 +0000966 TOSA_REF_TYPE dtype = currTensor->getDtype();
Eric Kunzee5e26762020-10-13 16:11:07 -0700967
968 // Float-point disallowed
Tai Lya4d748b2023-03-28 22:06:56 +0000969 if (dtype == TOSA_REF_TYPE_FP32 || dtype == TOSA_REF_TYPE_FP16)
Eric Kunzee5e26762020-10-13 16:11:07 -0700970 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700971 WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
972 "disabled, but %s tensor %s found\n",
Tai Lya4d748b2023-03-28 22:06:56 +0000973 EnumNameTOSAREFTYPE(dtype), currTensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700974 return 1;
975 }
976 }
977 else if (g_func_config.tosa_profile == 1 || g_func_config.tosa_profile == 2)
978 {
979 // Do nothing. All FP types allowed
980 // Currently no implementation difference between Main Inference and Main Training modes
981 }
982 else
983 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700984 FATAL_ERROR("SubgraphTraverser::validateGraph(): TOSA profile not recognized: %d",
985 g_func_config.tosa_profile);
Eric Kunzee5e26762020-10-13 16:11:07 -0700986 }
987 }
988
989 for (GraphNode* currNode : nodes)
990 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700991 SUBGRAPH_ERROR_IF(currNode->checkTensorAttributes(),
992 "SubgraphTraverser::validateGraph(): TosaReference::Tensor attribute check failed");
Eric Kunzee5e26762020-10-13 16:11:07 -0700993 }
994
995 if (outputTensors.size() <= 0)
996 {
997 DEBUG_MED(GT, "Graph output tensor empty");
998 return 0;
999 }
1000
1001 return 0;
1002}
1003
1004int SubgraphTraverser::dumpGraph(FILE* out) const
1005{
1006 int i = 0;
1007
1008 fprintf(out, "Full graph dump:\n");
1009 for (GraphNode* currNode : nodes)
1010 {
1011 fprintf(out, "Node [%d]: ", i++);
1012 currNode->dumpNode(out);
1013 }
1014
1015 return 0;
1016}
1017
1018int SubgraphTraverser::evaluateAll()
1019{
1020 // evaluation loop
1021 while (!isFullyEvaluated())
1022 {
1023 if (evaluateNextNode())
1024 {
1025 return 1;
1026 }
1027 }
1028
1029 return 0;
1030}