blob: 5675be94cb5bad4ca9891d0322e4886a8101afe5 [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "subgraph_traverser.h"
James Ward24dbc422022-10-19 12:20:31 +010017#include "arith_util.h"
Jerry Ge9c9c8da2023-07-19 23:08:16 +000018#include "tosa_model_types.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070019
#ifndef SUBGRAPH_ERROR_IF
// SUBGRAPH_ERROR_IF(COND, fmt, ...)
//
// If COND is true: set the graph status to TOSA_ERROR (unless it is already TOSA_UNPREDICTABLE),
// print the failing condition, the printf-style message and a backtrace to the debug file, and
// execute `return 1;` in the enclosing function. The macro uses `this`, so it is meant for
// SubgraphTraverser member functions whose return type accepts the value 1.
//
// The body is wrapped in do { ... } while (0) so that `SUBGRAPH_ERROR_IF(...);` is exactly one
// statement: it is safe inside an unbraced if/else and cannot capture a following `else`.
#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
    do \
    { \
        if ((COND)) \
        { \
            if (this->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \
            { \
                this->setGraphStatus(GraphStatus::TOSA_ERROR); \
            } \
            fprintf(g_func_debug.func_debug_file, COL_FATAL("SUBGRAPH_ERROR_IF() fails AT %s:%d %s(): (%s)\n"), \
                    __FILE__, __LINE__, __func__, #COND); \
            fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
            func_print_backtrace(g_func_debug.func_debug_file); \
            return 1; \
        } \
    } while (0)
#endif
35
Eric Kunzee5e26762020-10-13 16:11:07 -070036using namespace TosaReference;
37using namespace Eigen;
38using namespace tosa;
39
Jerry Ge9c9c8da2023-07-19 23:08:16 +000040SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block,
41 TosaSerializationHandler* _tsh,
42 SubgraphTraverser* _parent_sgt)
Eric Kunzee5e26762020-10-13 16:11:07 -070043{
Jerry Ge9e94af82022-10-27 09:57:00 -070044
Kevin Chengacb550f2021-06-29 15:32:19 -070045 graph_status = GraphStatus::TOSA_VALID;
Jerry Ge9c9c8da2023-07-19 23:08:16 +000046 block = _block;
Eric Kunzee5e26762020-10-13 16:11:07 -070047
Jerry Ge9c9c8da2023-07-19 23:08:16 +000048 tsh = _tsh;
Jerry Ge9e94af82022-10-27 09:57:00 -070049 parent_sgt = _parent_sgt;
Eric Kunzee5e26762020-10-13 16:11:07 -070050 tensors.clear();
51 nodes.clear();
52 nextNodeList.clear();
53}
54
55SubgraphTraverser::~SubgraphTraverser()
56{
57 nextNodeList.clear();
58
59 for (GraphNode* n : nodes)
60 {
61 delete n;
62 }
63 nodes.clear();
64
65 for (TosaReference::Tensor* t : tensors)
66 {
67 if (t->is_allocated())
68 {
69 t->deallocate();
70 }
71 delete t;
72 }
73 tensors.clear();
74}
75
76int SubgraphTraverser::getNumInputTensors() const
77{
78 return inputTensors.size();
79}
80
81TosaReference::Tensor* SubgraphTraverser::getInputTensor(const unsigned int idx) const
82{
83 return inputTensors[idx];
84}
85
86TosaReference::Tensor* SubgraphTraverser::getInputTensorByName(const std::string name) const
87{
88 for (auto t : inputTensors)
89 {
90 if (t->getName() == name)
91 {
92 return t;
93 }
94 }
95
96 return nullptr;
97}
98
99int SubgraphTraverser::getNumOutputTensors() const
100{
101 return outputTensors.size();
102}
103
104TosaReference::Tensor* SubgraphTraverser::getOutputTensor(const unsigned int idx) const
105{
106 return outputTensors[idx];
107}
108
109TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::string name) const
110{
111 for (auto t : outputTensors)
112 {
113 if (t->getName() == name)
114 {
115 return t;
116 }
117 }
118
119 return nullptr;
120}
121
122int SubgraphTraverser::initializeGraph()
123{
Eric Kunzee5e26762020-10-13 16:11:07 -0700124 int idx = 0;
Kevin Chengc72b59c2021-09-29 16:57:55 -0700125
Jerry Ge9e94af82022-10-27 09:57:00 -0700126 std::vector<TosaSerializationTensor*> ser_tensor_vec;
127 // Get all the serialized tensors from TosaSerializationHandler.
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100128 if (tsh)
Jerry Ge9e94af82022-10-27 09:57:00 -0700129 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100130 for (auto region : tsh->GetRegions())
Jerry Ge9e94af82022-10-27 09:57:00 -0700131 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100132 for (auto block : region->GetBlocks())
Tai Ly4e9a9772023-03-16 22:24:05 +0000133 {
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100134 for (auto ser_tensor : block->GetTensors())
135 {
136 ser_tensor_vec.push_back(ser_tensor);
137 }
Tai Ly4e9a9772023-03-16 22:24:05 +0000138 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700139 }
140 }
Jiacheng Liangeb52cc12023-05-17 16:49:44 +0100141 else
142 {
143 for (auto ser_tensor : block->GetTensors())
144 {
145 ser_tensor_vec.push_back(ser_tensor);
146 }
147 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700148
149 std::vector<GraphNode*> non_const_node_vec;
Eric Kunzee5e26762020-10-13 16:11:07 -0700150 for (auto op : block->GetOperators())
151 {
152 // translated TosaSerializationOperator to GraphNode
Tai Lya4d748b2023-03-28 22:06:56 +0000153 TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN;
154 TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN;
155 TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000156 uint32_t input_rank = 0;
157 uint32_t output_rank = 0;
158 uint32_t weight_rank = 0;
159 int32_t input_index = -1;
160 int32_t weight_index = -1;
Kevin Cheng550ccc52021-03-03 11:21:43 -0800161
162 switch (op->GetOp())
Eric Kunzee5e26762020-10-13 16:11:07 -0700163 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800164 case Op_CONV2D:
Kevin Cheng1533b852021-09-01 12:51:58 -0700165 case Op_CONV3D:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800166 case Op_DEPTHWISE_CONV2D:
167 case Op_TRANSPOSE_CONV2D:
168 case Op_FULLY_CONNECTED:
169 input_index = 0;
170 weight_index = 1;
171 break;
172 case Op_SELECT:
173 input_index = 1;
174 break;
175 default:
176 if (!op->GetInputTensorNames().empty())
177 input_index = 0;
178 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700179 }
180
Kevin Cheng550ccc52021-03-03 11:21:43 -0800181 if (input_index != -1)
Kevin Chengdf862692021-02-22 15:22:22 -0800182 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700183 SUBGRAPH_ERROR_IF(
184 (size_t)input_index >= op->GetInputTensorNames().size(),
185 "SubgraphTraverser::initializeGraph(): Op=%s, input_index %d must be within [0, num_input - 1]",
186 EnumNamesOp()[op->GetOp()], input_index);
Kevin Chengdf862692021-02-22 15:22:22 -0800187
Kevin Cheng550ccc52021-03-03 11:21:43 -0800188 std::string input_name = op->GetInputTensorNames()[input_index];
Jerry Ge9e94af82022-10-27 09:57:00 -0700189 TosaSerializationTensor* input_tensor = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000190 for (auto ser_tensor : ser_tensor_vec)
191 {
192 if (ser_tensor->GetName() == input_name)
193 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700194 input_tensor = ser_tensor;
195 }
196 }
197
Kevin Cheng903763c2021-09-28 16:14:52 -0700198 SUBGRAPH_ERROR_IF(
199 !input_tensor,
200 "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
201 input_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000202 input_dtype = ConvertDType(input_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800203 input_rank = input_tensor->GetShape().size();
Kevin Chengdf862692021-02-22 15:22:22 -0800204 }
205
Kevin Cheng550ccc52021-03-03 11:21:43 -0800206 if (weight_index != -1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700207 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700208 SUBGRAPH_ERROR_IF(
209 (size_t)weight_index >= op->GetInputTensorNames().size(),
210 "SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
211 EnumNamesOp()[op->GetOp()], weight_index);
Kevin Cheng550ccc52021-03-03 11:21:43 -0800212 std::string weight_name = op->GetInputTensorNames()[weight_index];
Jerry Ge9e94af82022-10-27 09:57:00 -0700213 TosaSerializationTensor* weight_tensor = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000214 for (auto ser_tensor : ser_tensor_vec)
215 {
216 if (ser_tensor->GetName() == weight_name)
217 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700218 weight_tensor = ser_tensor;
219 }
220 }
221
Kevin Cheng903763c2021-09-28 16:14:52 -0700222 SUBGRAPH_ERROR_IF(
223 !weight_tensor,
224 "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
225 weight_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000226 weight_dtype = ConvertDType(weight_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800227 weight_rank = weight_tensor->GetShape().size();
Eric Kunzee5e26762020-10-13 16:11:07 -0700228 }
229
Kevin Cheng478101b2021-10-04 10:43:14 -0700230 SUBGRAPH_ERROR_IF(op->GetOutputTensorNames().size() == 0,
231 "SubgraphTraverser::initializeGraph(): Op=%s must have at least one output tensor.",
232 EnumNamesOp()[op->GetOp()]);
Kevin Cheng550ccc52021-03-03 11:21:43 -0800233 std::string output_name = op->GetOutputTensorNames()[0];
234 TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700235 SUBGRAPH_ERROR_IF(
236 !output_tensor,
237 "SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler",
238 output_name.c_str());
Tai Lya4d748b2023-03-28 22:06:56 +0000239 output_dtype = ConvertDType(output_tensor->GetDtype());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800240 output_rank = output_tensor->GetShape().size();
241
Eric Kunzee5e26762020-10-13 16:11:07 -0700242 DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
243 EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
244
Jerry Ge9e94af82022-10-27 09:57:00 -0700245 GraphNode* node = nullptr;
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000246 if (this->parent_sgt)
247 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700248 node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000249 input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
Jerry Ge9e94af82022-10-27 09:57:00 -0700250 node->setInMainBlock(false);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000251 }
252 else
253 {
254 node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype, input_rank,
255 output_dtype, output_rank, weight_dtype, weight_rank);
256 if (node)
257 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700258 node->setInMainBlock(true);
259 }
260 }
261
Kevin Cheng550ccc52021-03-03 11:21:43 -0800262 if (!node)
Eric Kunzee5e26762020-10-13 16:11:07 -0700263 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800264 if (weight_index == -1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700265 {
266 fprintf(g_func_debug.func_debug_file,
Kevin Cheng903763c2021-09-28 16:14:52 -0700267 "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) "
268 "-> (%s rank %d)",
Tai Lya4d748b2023-03-28 22:06:56 +0000269 EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
270 EnumNameTOSAREFTYPE(output_dtype), output_rank);
Eric Kunzee5e26762020-10-13 16:11:07 -0700271 }
272 else
273 {
274 fprintf(g_func_debug.func_debug_file,
Kevin Cheng903763c2021-09-28 16:14:52 -0700275 "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), "
276 "weight=(%s rank %d) -> (%s rank %d)",
Tai Lya4d748b2023-03-28 22:06:56 +0000277 EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
278 EnumNameTOSAREFTYPE(weight_dtype), weight_rank, EnumNameTOSAREFTYPE(output_dtype), output_rank);
Eric Kunzee5e26762020-10-13 16:11:07 -0700279 }
280
Kevin Cheng550ccc52021-03-03 11:21:43 -0800281 for (auto& ts : op->GetInputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700282 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700283 fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Input: %s\n", ts.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700284 }
285
Kevin Cheng550ccc52021-03-03 11:21:43 -0800286 for (auto& ts : op->GetOutputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700287 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700288 fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Output: %s\n", ts.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700289 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700290 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported operation type or rank.");
Eric Kunzee5e26762020-10-13 16:11:07 -0700291 }
292
Kevin Chengc72b59c2021-09-29 16:57:55 -0700293 // Elementwise operator might set TOSA_ERROR when registering lambda function when creating the op.
294 // Check graph status after the op being constructed.
295 SUBGRAPH_ERROR_IF(getGraphStatus() == GraphStatus::TOSA_ERROR,
296 "SubgraphTraverser::initializeGraph(): Op %8s triggered ERROR_IF() when constructing the op.",
297 EnumNamesOp()[op->GetOp()]);
298
Kevin Cheng550ccc52021-03-03 11:21:43 -0800299 for (auto& name : op->GetInputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700300 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800301 node->addInputName(name);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700302 used_tensor_name_set.insert(name);
Eric Kunzee5e26762020-10-13 16:11:07 -0700303 }
304
305 for (auto name : op->GetOutputTensorNames())
306 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800307 node->addOutputName(name);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700308 used_tensor_name_set.insert(name);
Eric Kunzee5e26762020-10-13 16:11:07 -0700309 }
310
Kevin Cheng550ccc52021-03-03 11:21:43 -0800311 addNode(node);
Eric Kunzee5e26762020-10-13 16:11:07 -0700312
313 // if node doesn't have any inputs (i.e. CONST)
314 // it should be ready for evaluation
Kevin Cheng550ccc52021-03-03 11:21:43 -0800315 if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
Eric Kunzee5e26762020-10-13 16:11:07 -0700316 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800317 addToNextNodeList(node);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000318 }
319 else if (!node->getInMainBlock())
320 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700321 non_const_node_vec.push_back(node);
Eric Kunzee5e26762020-10-13 16:11:07 -0700322 }
323
324 idx++;
325 }
326
327 for (auto ts : block->GetTensors())
328 {
Kevin Chengcc61be32021-10-14 17:09:57 -0700329 DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
330 TosaReference::Tensor* tensor =
331 TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
332
333 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
Tai Lya4d748b2023-03-28 22:06:56 +0000334 ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
Kevin Chengcc61be32021-10-14 17:09:57 -0700335
Kevin Chengcc61be32021-10-14 17:09:57 -0700336 addTensor(tensor);
337 }
338
339 DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
340 for (auto& input_name : block->GetInputs())
341 {
342 TosaReference::Tensor* tensor = findTensorByName(input_name);
343 DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
344 if (tensor)
345 {
346 tensor->setIsSubgraphInput();
347 inputTensors.push_back(tensor);
348 }
349 else
350 {
351 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
352 input_name.c_str());
353 }
354 }
355
356 DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
357 for (auto& output_name : block->GetOutputs())
358 {
359 TosaReference::Tensor* tensor = findTensorByName(output_name);
Jerry Ge9e94af82022-10-27 09:57:00 -0700360 DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str());
Kevin Chengcc61be32021-10-14 17:09:57 -0700361 if (tensor)
362 {
363 tensor->setIsSubgraphOutput();
364 outputTensors.push_back(tensor);
365 }
366 else
367 {
368 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
369 output_name.c_str());
370 }
371 }
372
373 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
374 {
375 dumpNextNodeList(g_func_debug.func_debug_file);
376 }
377
Jerry Ge9e94af82022-10-27 09:57:00 -0700378 // If the node is not in mainblock and not const
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000379 for (auto node : non_const_node_vec)
380 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700381 bool all_inputs_from_parent = true;
382 for (std::string& name : node->getInputNames())
383 {
384 TosaReference::Tensor* t = findTensorByName(name);
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000385 if (!t->getIsParentGraphOutput())
386 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700387 all_inputs_from_parent = false;
388 }
389 }
390 // In the children block, when a node has all its inputs from parent
391 // block, we have to manually add this node to the evaluation list
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000392 if (all_inputs_from_parent && !node->getOnNextNodeList())
393 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700394 addToNextNodeList(node);
395 }
396 }
Kevin Chengcc61be32021-10-14 17:09:57 -0700397 return 0;
398}
399
Jerry Gee5cabbf2023-07-17 21:33:17 +0000400int SubgraphTraverser::allocateInputTensors()
Kevin Chengcc61be32021-10-14 17:09:57 -0700401{
Jerry Gee5cabbf2023-07-17 21:33:17 +0000402 auto input_tensor_names_vec = block->GetInputs();
403
404 for (auto input_tensor_name : input_tensor_names_vec)
Kevin Chengcc61be32021-10-14 17:09:57 -0700405 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000406 this->allocateTensor(input_tensor_name);
407 }
408
409 return 0;
410}
411
412int SubgraphTraverser::allocateTensor(std::string name)
413{
414 auto ts = block->GetTensorByName(name);
415
416 // Bail out if tensor is used and any of its dimension is invalid.
417 auto got = used_tensor_name_set.find(ts->GetName());
418 if (got != used_tensor_name_set.end())
419 {
420 uint32_t elements = 1;
421 for (auto& dim : ts->GetShape())
Kevin Chengacb550f2021-06-29 15:32:19 -0700422 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000423 if (dim <= 0)
Kevin Chengacb550f2021-06-29 15:32:19 -0700424 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000425 DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim);
426 this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
427 return 1;
428 }
429 if (dim > static_cast<int32_t>(TOSA_MAX_TENSOR_SIZE / elements))
430 {
431 // Size greather than maximum defined in spec
432 DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str());
433 this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
434 return 1;
Kevin Chengacb550f2021-06-29 15:32:19 -0700435 }
436 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000437 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700438
Jerry Gee5cabbf2023-07-17 21:33:17 +0000439 TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
440 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700441
Jerry Gee5cabbf2023-07-17 21:33:17 +0000442 DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
443 if (tensor->allocate())
444 {
445 FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
446 }
447
448 if (!ts->GetData().empty())
449 {
450 DEBUG_INFO(GT, "Setting data for tensor %s", tensor->getName().c_str());
451 auto serialization_dtype = ts->GetDtype();
452 switch (serialization_dtype)
Kevin Chengcc61be32021-10-14 17:09:57 -0700453 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000454 case DType_INT4: {
455 std::vector<int8_t> i4_data;
456 TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
457 std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
458 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
Eric Kunzee5e26762020-10-13 16:11:07 -0700459 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000460 break;
461 case DType_INT8: {
462 std::vector<int8_t> i8_data;
463 TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
464 std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
465 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
466 }
467 break;
468 case DType_INT16: {
469 std::vector<int16_t> i16_data;
470 TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
471 std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
472 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
473 }
474 break;
475 case DType_INT32: {
476 std::vector<int32_t> i32_data;
477 TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
478 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
479 }
480 break;
481 case DType_INT48: {
482 std::vector<int64_t> i64_data;
483 TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
484 tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
485 }
486 break;
487 case DType_FP16: {
488 // Interpret f16 data as float
489 std::vector<float> f16_data;
490 TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
491 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
492 {
493 std::vector<double> f64_data(f16_data.begin(), f16_data.end());
494 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
495 }
496 else
497 {
498 tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
499 }
500 }
501 break;
502 case DType_BF16: {
503 std::vector<float> fp32_data;
504 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
505 // Ensure valid bfloat16 stored in each float
506 for (auto f : fp32_data)
507 ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
508 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
509 {
510 std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
511 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
512 }
513 else
514 {
515 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
516 }
517 }
518 break;
519 case DType_FP32: {
520 std::vector<float> fp32_data;
521 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
522 if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
523 {
524 std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
525 tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
526 }
527 else
528 {
529 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
530 }
531 }
532 break;
533 case DType_BOOL: {
534 std::vector<bool> bool_data;
535 TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
536
537 // std::vector<bool>::data() will return bit mask instead of array of bool array.
538 // Need to translate manually.
539 bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
540 for (size_t i = 0; i < bool_data.size(); i++)
541 {
542 bool_array[i] = bool_data[i];
543 }
544 tensor->setTensorValueBool(bool_data.size(), bool_array);
545 }
546 break;
547 default:
548 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
549 EnumNameDType(ts->GetDtype()));
Eric Kunzee5e26762020-10-13 16:11:07 -0700550 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700551 }
552
553 return 0;
554}
555
556int SubgraphTraverser::isFullyEvaluated() const
557{
558 return nextNodeList.empty();
559}
560
561GraphNode* SubgraphTraverser::getNextNode()
562{
563 GraphNode* nextNode = nextNodeList.front();
564 ASSERT_MSG(nextNode, "SubgraphTraverser::getNextNode(): called with empty next node list");
565 ASSERT_MSG(nextNode->getOnNextNodeList(),
566 "SubgraphTraverser::getNextNode(): internal state error: node is not listed as being on next node list");
567
568 nextNodeList.pop_front();
569
570 nextNode->clearOnNextNodeList();
571 return nextNode;
572}
573
574int SubgraphTraverser::addToNextNodeList(GraphNode* nextNode)
575{
576 ASSERT_MSG(nextNode, "SubgraphTraverser::addToNextNodeList(): called with no node");
577 ASSERT_MSG(!nextNode->getOnNextNodeList(),
578 "SubgraphTraverser::addToNextNodeList(): internal state error: node is already on next node list");
579
580 nextNode->setOnNextNodeList();
581 nextNodeList.push_back(nextNode);
582
583 return 0;
584}
585
586int SubgraphTraverser::evaluateNextNode()
587{
588 if (isFullyEvaluated())
589 return 0;
590
591 GraphNode* currNode = getNextNode();
592
593 DEBUG_INFO(GT, "Evaluating node_%03lu, %8s, output tensor=%s", currNode->getID(), EnumNamesOp()[currNode->getOp()],
594 currNode->getOutputNames()[0].c_str());
595
596 // Sanity check for never-ending loops
597 if (currNode->getEvalCount() >= MAX_EVAL_COUNT && (currNode->getEvalCount() % MAX_EVAL_COUNT) == 0)
598 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700599 WARNING("SubgraphTraverser::evaluateNextNode(): Node %lu has been evaluated %d times. Loop suspected.",
600 currNode->getID(), currNode->getEvalCount());
Eric Kunzee5e26762020-10-13 16:11:07 -0700601 }
602
Kevin Cheng550ccc52021-03-03 11:21:43 -0800603 for (auto tensor : currNode->getOutputs())
Eric Kunzee5e26762020-10-13 16:11:07 -0700604 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800605 if (!tensor->is_allocated())
Jerry Gee5cabbf2023-07-17 21:33:17 +0000606 {
607 if (this->allocateTensor(tensor->getName()))
Eric Kunzee5e26762020-10-13 16:11:07 -0700608 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700609 FATAL_ERROR("SubgraphTraverser::evaluateNextNode(): Failed to allocate Eigen tensor %s",
610 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700611 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000612 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700613 }
614
615 if (currNode->eval())
616 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700617 WARNING("SubgraphTraverser::evaluateNextNode(): Failed to evaluate node: %lu", currNode->getID());
Kevin Chengacb550f2021-06-29 15:32:19 -0700618 return 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700619 }
620
621 // free input tensor if all of its consumers have all of their outputs ready and it's not block's output
Jerry Gee5cabbf2023-07-17 21:33:17 +0000622 for (auto tensor : currNode->getInputs())
623 {
624 bool in_use = false;
625
626 auto tensor_check = findTensorByName(tensor->getName());
627 if (tensor_check->getIsParentGraphOutput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700628 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000629 // if it's parent's block output tensor, we can't free it
630 continue;
631 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700632
Jerry Gee5cabbf2023-07-17 21:33:17 +0000633 for (auto node : tensor->getConsumers())
634 {
635 // If the node is inside a loop, the input tensor is still needed
636 if (!node->hasAllOutputsReady())
Jerry Ge9c9c8da2023-07-19 23:08:16 +0000637 {
Jerry Gee5cabbf2023-07-17 21:33:17 +0000638 in_use = true;
Jerry Ge9e94af82022-10-27 09:57:00 -0700639 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700640 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000641
642 for (auto name : block->GetOutputs())
643 {
644 if (name == tensor->getName())
645 {
646 in_use = true;
647 }
648 }
649
650 if (!in_use)
651 {
652 tensor->deallocate();
653 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700654 }
Jerry Gee5cabbf2023-07-17 21:33:17 +0000655
Eric Kunzee5e26762020-10-13 16:11:07 -0700656 // Search the output tensors of this node to see if
657 // there are now new ready nodes available from completing this node
658 for (TosaReference::Tensor* tensor : currNode->getOutputs())
659 {
660 for (GraphNode* node : tensor->getConsumers())
661 {
662 if (!node->getOnNextNodeList() && node->hasAllInputsReady())
663 {
664 addToNextNodeList(node);
665 }
666 }
667 }
668
669 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
670 {
671 dumpNextNodeList(g_func_debug.func_debug_file);
672 }
673
674 if (g_func_config.dump_intermediates)
675 {
676 currNode->dumpNode(g_func_debug.func_debug_file);
677 for (auto outs : currNode->getOutputs())
678 {
679 outs->dumpTensorParams(g_func_debug.func_debug_file);
680 outs->dumpTensor(g_func_debug.func_debug_file);
681 fprintf(g_func_debug.func_debug_file, "\n");
682 }
683 }
684
685 return 0;
686}
687
688int SubgraphTraverser::dumpNextNodeList(FILE* out) const
689{
690
691 // Dump next node list
692 fprintf(out, "Next node list\n");
693
694 if (nextNodeList.empty())
695 {
696 fprintf(out, "<empty>\n");
697 }
698
699 for (auto gn : nextNodeList)
700 {
701 gn->dumpNode(out);
702 }
703
704 fprintf(out, "Done.\n");
705 return 0;
706}
707
708int SubgraphTraverser::clearAllNodeMarkings()
709{
710 for (GraphNode* currNode : nodes)
711 {
712 currNode->clearNodeMarked();
713 }
714
715 return false;
716}
717
Kevin Cheng550ccc52021-03-03 11:21:43 -0800718int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor)
Eric Kunzee5e26762020-10-13 16:11:07 -0700719{
720 // Enforce no duplicate tensors/tensor names
721 // O(N), but the number of tensors is small
722 for (TosaReference::Tensor* currTensor : tensors)
723 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800724 if (tensor == currTensor || currTensor->getName() == tensor->getName())
Eric Kunzee5e26762020-10-13 16:11:07 -0700725 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700726 FATAL_ERROR("SubgraphTraverser::addTensor(): Duplicate tensor or tensor name being added to graph: %s\n",
727 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700728 return 1;
729 }
730 }
731
Kevin Cheng550ccc52021-03-03 11:21:43 -0800732 tensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700733
Kevin Cheng550ccc52021-03-03 11:21:43 -0800734 if (tensor->getIsSubgraphInput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700735 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800736 inputTensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700737 }
738
Kevin Cheng550ccc52021-03-03 11:21:43 -0800739 if (tensor->getIsSubgraphOutput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700740 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800741 outputTensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700742 }
743
744 return 0;
745}
746int SubgraphTraverser::addNode(GraphNode* newNode)
747{
748 // Enforce no duplicate nodes
749 for (GraphNode* currNode : nodes)
750 {
751 if (currNode == newNode)
752 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700753 FATAL_ERROR("SubgraphTraverser::addTensor(): duplicate node being added to graph");
Eric Kunzee5e26762020-10-13 16:11:07 -0700754 return 1;
755 }
756 }
757
758 nodes.push_back(newNode);
759
760 return 0;
761}
762
763TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
764{
Jerry Ge9e94af82022-10-27 09:57:00 -0700765 TosaReference::Tensor* res_tensor = nullptr;
766
Eric Kunzee5e26762020-10-13 16:11:07 -0700767 for (TosaReference::Tensor* currTensor : tensors)
768 {
769 if (currTensor->getName() == name)
770 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700771 res_tensor = currTensor;
772 return res_tensor;
Eric Kunzee5e26762020-10-13 16:11:07 -0700773 }
774 }
775
Jerry Ge9e94af82022-10-27 09:57:00 -0700776 if (parent_sgt)
777 {
778 for (TosaReference::Tensor* currTensor : parent_sgt->tensors)
779 {
780 if (currTensor->getName() == name)
781 {
782 res_tensor = currTensor;
783 res_tensor->setIsParentGraphOutput();
784 }
785 }
786 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700787
Jerry Ge9e94af82022-10-27 09:57:00 -0700788 if (!res_tensor)
789 {
790 WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
791 return nullptr;
792 }
793 return res_tensor;
Eric Kunzee5e26762020-10-13 16:11:07 -0700794}
795
796int SubgraphTraverser::linkTensorsAndNodes()
797{
798 // Nodes have a list of input/output tensor names
799 // For each node, read this list, link up the tensors with their inputs/outputs
800 for (GraphNode* currNode : nodes)
801 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700802 // Link inputs/consuming nodes
803 for (std::string& name : currNode->getInputNames())
804 {
805 TosaReference::Tensor* t = findTensorByName(name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700806 SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
807 name.c_str(), currNode->getID());
808 SUBGRAPH_ERROR_IF(currNode->addInputTensor(t),
809 "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
810 name.c_str(), currNode->getID());
811 SUBGRAPH_ERROR_IF(t->addConsumer(currNode),
812 "SubgraphTraverser::linkTensorsAndNodes(): cannot link consumer node %lu to tensor %s\n",
813 currNode->getID(), name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700814 }
815
816 // Link outputs/producing nodes
817 for (std::string& name : currNode->getOutputNames())
818 {
819 TosaReference::Tensor* t = findTensorByName(name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700820 SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
821 name.c_str(), currNode->getID());
822 SUBGRAPH_ERROR_IF(currNode->addOutputTensor(t),
823 "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
824 name.c_str(), currNode->getID());
Eric Kunzee5e26762020-10-13 16:11:07 -0700825
Kevin Cheng903763c2021-09-28 16:14:52 -0700826 SUBGRAPH_ERROR_IF(
827 t->setProducer(currNode),
828 "SubgraphTraverser::linkTensorsAndNodes(): cannot link producer node %lu to tensor tensor %s\n",
829 currNode->getID(), name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700830 }
831 }
832
833 return 0;
834}
835
836int SubgraphTraverser::validateGraph()
837{
838 // Need to make sure that:
839 // - each tensor is actually used
840 // - input and output tesnsors truly are just input and just output
841 // Graph building already determined that each node has found its input/output tensors
842
843 for (TosaReference::Tensor* currTensor : tensors)
844 {
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700845 // It's okay for block input tensor not being consumed by operators.
846 // This is common in control flow op execution.
847 if (!currTensor->getIsSubgraphInput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700848 {
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700849 if (!currTensor->getProducer() && currTensor->getConsumers().empty())
Eric Kunzee5e26762020-10-13 16:11:07 -0700850 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700851 WARNING("SubgraphTraverser::validateGraph(): TosaReference::Tensor %s has no producers or consumers\n",
Eric Kunzee5e26762020-10-13 16:11:07 -0700852 currTensor->getName().c_str());
853 return 1;
854 }
855 }
856
Eric Kunzee5e26762020-10-13 16:11:07 -0700857 if (g_func_config.tosa_profile == 0)
858 {
Tai Lya4d748b2023-03-28 22:06:56 +0000859 TOSA_REF_TYPE dtype = currTensor->getDtype();
Eric Kunzee5e26762020-10-13 16:11:07 -0700860
861 // Float-point disallowed
Tai Lya4d748b2023-03-28 22:06:56 +0000862 if (dtype == TOSA_REF_TYPE_FP32 || dtype == TOSA_REF_TYPE_FP16)
Eric Kunzee5e26762020-10-13 16:11:07 -0700863 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700864 WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
865 "disabled, but %s tensor %s found\n",
Tai Lya4d748b2023-03-28 22:06:56 +0000866 EnumNameTOSAREFTYPE(dtype), currTensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700867 return 1;
868 }
869 }
870 else if (g_func_config.tosa_profile == 1 || g_func_config.tosa_profile == 2)
871 {
872 // Do nothing. All FP types allowed
873 // Currently no implementation difference between Main Inference and Main Training modes
874 }
875 else
876 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700877 FATAL_ERROR("SubgraphTraverser::validateGraph(): TOSA profile not recognized: %d",
878 g_func_config.tosa_profile);
Eric Kunzee5e26762020-10-13 16:11:07 -0700879 }
880 }
881
882 for (GraphNode* currNode : nodes)
883 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700884 SUBGRAPH_ERROR_IF(currNode->checkTensorAttributes(),
885 "SubgraphTraverser::validateGraph(): TosaReference::Tensor attribute check failed");
Eric Kunzee5e26762020-10-13 16:11:07 -0700886 }
887
888 if (outputTensors.size() <= 0)
889 {
890 DEBUG_MED(GT, "Graph output tensor empty");
891 return 0;
892 }
893
894 return 0;
895}
896
897int SubgraphTraverser::dumpGraph(FILE* out) const
898{
899 int i = 0;
900
901 fprintf(out, "Full graph dump:\n");
902 for (GraphNode* currNode : nodes)
903 {
904 fprintf(out, "Node [%d]: ", i++);
905 currNode->dumpNode(out);
906 }
907
908 return 0;
909}
910
911int SubgraphTraverser::evaluateAll()
912{
913 // evaluation loop
914 while (!isFullyEvaluated())
915 {
916 if (evaluateNextNode())
917 {
918 return 1;
919 }
920 }
921
922 return 0;
923}