blob: 8867ada766dbd9b40a14da307c8af6205dcb3eac [file] [log] [blame]
Eric Kunzee5e26762020-10-13 16:11:07 -07001
Jerry Ge9e94af82022-10-27 09:57:00 -07002// Copyright (c) 2020-2023, ARM Limited.
Eric Kunzee5e26762020-10-13 16:11:07 -07003//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16#include "subgraph_traverser.h"
Eric Kunze41964912022-10-14 13:33:58 -070017#include "tosa_model_types.h"
James Ward24dbc422022-10-19 12:20:31 +010018#include "arith_util.h"
Eric Kunzee5e26762020-10-13 16:11:07 -070019
#ifndef SUBGRAPH_ERROR_IF
// If COND holds: mark the graph TOSA_ERROR (unless it is already TOSA_UNPREDICTABLE), print the failing
// condition plus the formatted message and a backtrace to the debug file, then `return 1` from the
// enclosing function. Relies on `this->getGraphStatus()` / `this->setGraphStatus()`, so it must be used
// inside such a member function.
// The body is wrapped in do { } while (0) so that the macro expands to exactly one statement and can be
// used safely as the body of an unbraced if/else.
#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
    do \
    { \
        if ((COND)) \
        { \
            if (this->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \
            { \
                this->setGraphStatus(GraphStatus::TOSA_ERROR); \
            } \
            fprintf(g_func_debug.func_debug_file, COL_FATAL("SUBGRAPH_ERROR_IF() fails AT %s:%d %s(): (%s)\n"), \
                    __FILE__, __LINE__, __func__, #COND); \
            fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
            func_print_backtrace(g_func_debug.func_debug_file); \
            return 1; \
        } \
    } while (0)
#endif
35
Eric Kunzee5e26762020-10-13 16:11:07 -070036using namespace TosaReference;
37using namespace Eigen;
38using namespace tosa;
39
Jerry Ge9e94af82022-10-27 09:57:00 -070040SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh, SubgraphTraverser* _parent_sgt)
Eric Kunzee5e26762020-10-13 16:11:07 -070041{
Jerry Ge9e94af82022-10-27 09:57:00 -070042
Kevin Chengacb550f2021-06-29 15:32:19 -070043 graph_status = GraphStatus::TOSA_VALID;
Eric Kunzee5e26762020-10-13 16:11:07 -070044 block = _block;
Eric Kunzee5e26762020-10-13 16:11:07 -070045
Jerry Ge9e94af82022-10-27 09:57:00 -070046 tsh = _tsh;
47 parent_sgt = _parent_sgt;
Eric Kunzee5e26762020-10-13 16:11:07 -070048 tensors.clear();
49 nodes.clear();
50 nextNodeList.clear();
51}
52
53SubgraphTraverser::~SubgraphTraverser()
54{
55 nextNodeList.clear();
56
57 for (GraphNode* n : nodes)
58 {
59 delete n;
60 }
61 nodes.clear();
62
63 for (TosaReference::Tensor* t : tensors)
64 {
65 if (t->is_allocated())
66 {
67 t->deallocate();
68 }
69 delete t;
70 }
71 tensors.clear();
72}
73
74int SubgraphTraverser::getNumInputTensors() const
75{
76 return inputTensors.size();
77}
78
79TosaReference::Tensor* SubgraphTraverser::getInputTensor(const unsigned int idx) const
80{
81 return inputTensors[idx];
82}
83
84TosaReference::Tensor* SubgraphTraverser::getInputTensorByName(const std::string name) const
85{
86 for (auto t : inputTensors)
87 {
88 if (t->getName() == name)
89 {
90 return t;
91 }
92 }
93
94 return nullptr;
95}
96
97int SubgraphTraverser::getNumOutputTensors() const
98{
99 return outputTensors.size();
100}
101
102TosaReference::Tensor* SubgraphTraverser::getOutputTensor(const unsigned int idx) const
103{
104 return outputTensors[idx];
105}
106
107TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::string name) const
108{
109 for (auto t : outputTensors)
110 {
111 if (t->getName() == name)
112 {
113 return t;
114 }
115 }
116
117 return nullptr;
118}
119
120int SubgraphTraverser::initializeGraph()
121{
Eric Kunzee5e26762020-10-13 16:11:07 -0700122 int idx = 0;
Kevin Chengc72b59c2021-09-29 16:57:55 -0700123
Jerry Ge9e94af82022-10-27 09:57:00 -0700124 std::vector<TosaSerializationTensor*> ser_tensor_vec;
125 // Get all the serialized tensors from TosaSerializationHandler.
126 for (auto block: tsh->GetMainRegion()->GetBlocks())
127 {
128 for (auto ser_tensor : block->GetTensors())
129 {
130 ser_tensor_vec.push_back(ser_tensor);
131 }
132 }
133
134 std::vector<GraphNode*> non_const_node_vec;
Eric Kunzee5e26762020-10-13 16:11:07 -0700135 for (auto op : block->GetOperators())
136 {
137 // translated TosaSerializationOperator to GraphNode
Kevin Cheng550ccc52021-03-03 11:21:43 -0800138 DType input_dtype = DType_UNKNOWN;
139 DType output_dtype = DType_UNKNOWN;
140 DType weight_dtype = DType_UNKNOWN;
141 uint32_t input_rank = 0;
142 uint32_t output_rank = 0;
143 uint32_t weight_rank = 0;
144 int32_t input_index = -1;
145 int32_t weight_index = -1;
146
147 switch (op->GetOp())
Eric Kunzee5e26762020-10-13 16:11:07 -0700148 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800149 case Op_CONV2D:
Kevin Cheng1533b852021-09-01 12:51:58 -0700150 case Op_CONV3D:
Kevin Cheng550ccc52021-03-03 11:21:43 -0800151 case Op_DEPTHWISE_CONV2D:
152 case Op_TRANSPOSE_CONV2D:
153 case Op_FULLY_CONNECTED:
154 input_index = 0;
155 weight_index = 1;
156 break;
157 case Op_SELECT:
158 input_index = 1;
159 break;
160 default:
161 if (!op->GetInputTensorNames().empty())
162 input_index = 0;
163 break;
Eric Kunzee5e26762020-10-13 16:11:07 -0700164 }
165
Kevin Cheng550ccc52021-03-03 11:21:43 -0800166 if (input_index != -1)
Kevin Chengdf862692021-02-22 15:22:22 -0800167 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700168 SUBGRAPH_ERROR_IF(
169 (size_t)input_index >= op->GetInputTensorNames().size(),
170 "SubgraphTraverser::initializeGraph(): Op=%s, input_index %d must be within [0, num_input - 1]",
171 EnumNamesOp()[op->GetOp()], input_index);
Kevin Chengdf862692021-02-22 15:22:22 -0800172
Kevin Cheng550ccc52021-03-03 11:21:43 -0800173 std::string input_name = op->GetInputTensorNames()[input_index];
Jerry Ge9e94af82022-10-27 09:57:00 -0700174 TosaSerializationTensor* input_tensor = nullptr;
175 for (auto ser_tensor : ser_tensor_vec) {
176 if (ser_tensor->GetName() == input_name) {
177 input_tensor = ser_tensor;
178 }
179 }
180
Kevin Cheng903763c2021-09-28 16:14:52 -0700181 SUBGRAPH_ERROR_IF(
182 !input_tensor,
183 "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
184 input_name.c_str());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800185 input_dtype = input_tensor->GetDtype();
186 input_rank = input_tensor->GetShape().size();
Kevin Chengdf862692021-02-22 15:22:22 -0800187 }
188
Kevin Cheng550ccc52021-03-03 11:21:43 -0800189 if (weight_index != -1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700190 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700191 SUBGRAPH_ERROR_IF(
192 (size_t)weight_index >= op->GetInputTensorNames().size(),
193 "SubgraphTraverser::initializeGraph(): Op=%s, weight_index %d must be within [0, num_input - 1]",
194 EnumNamesOp()[op->GetOp()], weight_index);
Kevin Cheng550ccc52021-03-03 11:21:43 -0800195 std::string weight_name = op->GetInputTensorNames()[weight_index];
Jerry Ge9e94af82022-10-27 09:57:00 -0700196 TosaSerializationTensor* weight_tensor = nullptr;
197 for (auto ser_tensor : ser_tensor_vec) {
198 if (ser_tensor->GetName() == weight_name) {
199 weight_tensor = ser_tensor;
200 }
201 }
202
Kevin Cheng903763c2021-09-28 16:14:52 -0700203 SUBGRAPH_ERROR_IF(
204 !weight_tensor,
205 "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
206 weight_name.c_str());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800207 weight_dtype = weight_tensor->GetDtype();
208 weight_rank = weight_tensor->GetShape().size();
Eric Kunzee5e26762020-10-13 16:11:07 -0700209 }
210
Kevin Cheng478101b2021-10-04 10:43:14 -0700211 SUBGRAPH_ERROR_IF(op->GetOutputTensorNames().size() == 0,
212 "SubgraphTraverser::initializeGraph(): Op=%s must have at least one output tensor.",
213 EnumNamesOp()[op->GetOp()]);
Kevin Cheng550ccc52021-03-03 11:21:43 -0800214 std::string output_name = op->GetOutputTensorNames()[0];
215 TosaSerializationTensor* output_tensor = block->GetTensorByName(output_name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700216 SUBGRAPH_ERROR_IF(
217 !output_tensor,
218 "SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler",
219 output_name.c_str());
Kevin Cheng550ccc52021-03-03 11:21:43 -0800220 output_dtype = output_tensor->GetDtype();
221 output_rank = output_tensor->GetShape().size();
222
Eric Kunzee5e26762020-10-13 16:11:07 -0700223 DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
224 EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
225
Jerry Ge9e94af82022-10-27 09:57:00 -0700226 GraphNode* node = nullptr;
227 if (this->parent_sgt) {
228 node = OpFactory::newOp(this->parent_sgt, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
Kevin Cheng550ccc52021-03-03 11:21:43 -0800229 input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
Jerry Ge9e94af82022-10-27 09:57:00 -0700230 node->setInMainBlock(false);
231 } else {
232 node = OpFactory::newOp(this, tsh, op->GetOp(), op->GetAttribute(), idx, input_dtype,
233 input_rank, output_dtype, output_rank, weight_dtype, weight_rank);
234 if (node) {
235 node->setInMainBlock(true);
236 }
237 }
238
Kevin Cheng550ccc52021-03-03 11:21:43 -0800239 if (!node)
Eric Kunzee5e26762020-10-13 16:11:07 -0700240 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800241 if (weight_index == -1)
Eric Kunzee5e26762020-10-13 16:11:07 -0700242 {
243 fprintf(g_func_debug.func_debug_file,
Kevin Cheng903763c2021-09-28 16:14:52 -0700244 "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) "
245 "-> (%s rank %d)",
Kevin Cheng550ccc52021-03-03 11:21:43 -0800246 EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
247 EnumNamesDType()[output_dtype], output_rank);
Eric Kunzee5e26762020-10-13 16:11:07 -0700248 }
249 else
250 {
251 fprintf(g_func_debug.func_debug_file,
Kevin Cheng903763c2021-09-28 16:14:52 -0700252 "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), "
253 "weight=(%s rank %d) -> (%s rank %d)",
Kevin Cheng550ccc52021-03-03 11:21:43 -0800254 EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
255 EnumNamesDType()[weight_dtype], weight_rank, EnumNamesDType()[output_dtype], output_rank);
Eric Kunzee5e26762020-10-13 16:11:07 -0700256 }
257
Kevin Cheng550ccc52021-03-03 11:21:43 -0800258 for (auto& ts : op->GetInputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700259 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700260 fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Input: %s\n", ts.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700261 }
262
Kevin Cheng550ccc52021-03-03 11:21:43 -0800263 for (auto& ts : op->GetOutputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700264 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700265 fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): Output: %s\n", ts.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700266 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700267 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported operation type or rank.");
Eric Kunzee5e26762020-10-13 16:11:07 -0700268 }
269
Kevin Chengc72b59c2021-09-29 16:57:55 -0700270 // Elementwise operator might set TOSA_ERROR when registering lambda function when creating the op.
271 // Check graph status after the op being constructed.
272 SUBGRAPH_ERROR_IF(getGraphStatus() == GraphStatus::TOSA_ERROR,
273 "SubgraphTraverser::initializeGraph(): Op %8s triggered ERROR_IF() when constructing the op.",
274 EnumNamesOp()[op->GetOp()]);
275
Kevin Cheng550ccc52021-03-03 11:21:43 -0800276 for (auto& name : op->GetInputTensorNames())
Eric Kunzee5e26762020-10-13 16:11:07 -0700277 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800278 node->addInputName(name);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700279 used_tensor_name_set.insert(name);
Eric Kunzee5e26762020-10-13 16:11:07 -0700280 }
281
282 for (auto name : op->GetOutputTensorNames())
283 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800284 node->addOutputName(name);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700285 used_tensor_name_set.insert(name);
Eric Kunzee5e26762020-10-13 16:11:07 -0700286 }
287
Kevin Cheng550ccc52021-03-03 11:21:43 -0800288 addNode(node);
Eric Kunzee5e26762020-10-13 16:11:07 -0700289
290 // if node doesn't have any inputs (i.e. CONST)
291 // it should be ready for evaluation
Kevin Cheng550ccc52021-03-03 11:21:43 -0800292 if (op->GetInputTensorNames().empty() && !node->getOnNextNodeList())
Eric Kunzee5e26762020-10-13 16:11:07 -0700293 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800294 addToNextNodeList(node);
Jerry Ge9e94af82022-10-27 09:57:00 -0700295 } else if (!node->getInMainBlock()) {
296 non_const_node_vec.push_back(node);
Eric Kunzee5e26762020-10-13 16:11:07 -0700297 }
298
299 idx++;
300 }
301
302 for (auto ts : block->GetTensors())
303 {
Kevin Chengcc61be32021-10-14 17:09:57 -0700304 DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
305 TosaReference::Tensor* tensor =
306 TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
307
308 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
309 ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
310
Kevin Chengcc61be32021-10-14 17:09:57 -0700311 addTensor(tensor);
312 }
313
314 DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
315 for (auto& input_name : block->GetInputs())
316 {
317 TosaReference::Tensor* tensor = findTensorByName(input_name);
318 DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
319 if (tensor)
320 {
321 tensor->setIsSubgraphInput();
322 inputTensors.push_back(tensor);
323 }
324 else
325 {
326 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find input tensor by name %s",
327 input_name.c_str());
328 }
329 }
330
331 DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
332 for (auto& output_name : block->GetOutputs())
333 {
334 TosaReference::Tensor* tensor = findTensorByName(output_name);
Jerry Ge9e94af82022-10-27 09:57:00 -0700335 DEBUG_INFO(GT, "output tensor name=%s", output_name.c_str());
Kevin Chengcc61be32021-10-14 17:09:57 -0700336 if (tensor)
337 {
338 tensor->setIsSubgraphOutput();
339 outputTensors.push_back(tensor);
340 }
341 else
342 {
343 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Failed to find output tensor by name %s",
344 output_name.c_str());
345 }
346 }
347
348 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
349 {
350 dumpNextNodeList(g_func_debug.func_debug_file);
351 }
352
Jerry Ge9e94af82022-10-27 09:57:00 -0700353 // If the node is not in mainblock and not const
354 for (auto node : non_const_node_vec) {
355 bool all_inputs_from_parent = true;
356 for (std::string& name : node->getInputNames())
357 {
358 TosaReference::Tensor* t = findTensorByName(name);
359 if (!t->getIsParentGraphOutput()) {
360 all_inputs_from_parent = false;
361 }
362 }
363 // In the children block, when a node has all its inputs from parent
364 // block, we have to manually add this node to the evaluation list
365 if (all_inputs_from_parent && !node->getOnNextNodeList()) {
366 addToNextNodeList(node);
367 }
368 }
Kevin Chengcc61be32021-10-14 17:09:57 -0700369 return 0;
370}
371
372int SubgraphTraverser::allocateTensor()
373{
374 for (auto ts : block->GetTensors())
375 {
Kevin Chengc72b59c2021-09-29 16:57:55 -0700376 // Bail out if tensor is used and any of its dimension is invalid.
377 auto got = used_tensor_name_set.find(ts->GetName());
378 if (got != used_tensor_name_set.end())
Kevin Chengacb550f2021-06-29 15:32:19 -0700379 {
Eric Kunze41964912022-10-14 13:33:58 -0700380 uint32_t elements = 1;
Kevin Chengc72b59c2021-09-29 16:57:55 -0700381 for (auto& dim : ts->GetShape())
Kevin Chengacb550f2021-06-29 15:32:19 -0700382 {
Kevin Chengc72b59c2021-09-29 16:57:55 -0700383 if (dim <= 0)
384 {
Jeremy Johnson93d43902022-09-27 12:26:14 +0100385 DEBUG_INFO(GT, "Failed to allocate tensor %s with invalid dimension of %d", ts->GetName().c_str(), dim);
Kevin Chengc72b59c2021-09-29 16:57:55 -0700386 this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
387 return 1;
388 }
Eric Kunze41964912022-10-14 13:33:58 -0700389 if (dim > static_cast<int32_t>(TOSA_MAX_TENSOR_SIZE / elements))
390 {
391 // Size greather than maximum defined in spec
392 DEBUG_INFO(GT, "Tensor %s size is greater than allowed maximum", ts->GetName().c_str());
393 this->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE);
394 return 1;
395 }
Kevin Chengacb550f2021-06-29 15:32:19 -0700396 }
397 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700398
Kevin Chengcc61be32021-10-14 17:09:57 -0700399 TosaReference::Tensor* tensor = findTensorByName(ts->GetName());
400 SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::allocateTensor(): can't find tensor %s.", ts->GetName().c_str());
Kevin Chengacb550f2021-06-29 15:32:19 -0700401
Kevin Chengcc61be32021-10-14 17:09:57 -0700402 DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
403 if (tensor->allocate())
404 {
405 FATAL_ERROR("Failed to allocate tensor %s", tensor->getName().c_str());
406 }
Kevin Cheng903763c2021-09-28 16:14:52 -0700407
Kevin Cheng82507d72021-06-17 16:01:59 -0700408 if (!ts->GetData().empty())
Eric Kunzee5e26762020-10-13 16:11:07 -0700409 {
Kevin Chengcc61be32021-10-14 17:09:57 -0700410 DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
Kevin Cheng82507d72021-06-17 16:01:59 -0700411 switch (ts->GetDtype())
Eric Kunzee5e26762020-10-13 16:11:07 -0700412 {
Kevin Chenga9017402021-07-28 17:19:23 -0700413 case DType_INT4:
414 {
415 std::vector<int8_t> i4_data;
416 TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
417 std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
418 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
419 }
420 break;
Kevin Cheng82507d72021-06-17 16:01:59 -0700421 case DType_INT8:
422 {
423 std::vector<int8_t> i8_data;
424 TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
425 std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
426 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
427 }
428 break;
429 case DType_INT16:
430 {
431 std::vector<int16_t> i16_data;
432 TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
433 std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
434 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
435 }
436 break;
437 case DType_INT32:
438 {
439 std::vector<int32_t> i32_data;
440 TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
441 tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
442 }
443 break;
444 case DType_INT48:
445 {
446 std::vector<int64_t> i64_data;
447 TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
448 tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
449 }
450 break;
James Ward8b390432022-08-12 20:48:56 +0100451 case DType_FP16:
452 {
453 // Interpret f16 data as float
454 std::vector<float> f16_data;
455 TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
456 tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
457 }
458 break;
James Ward24dbc422022-10-19 12:20:31 +0100459 case DType_BF16:
460 {
461 std::vector<float> fp32_data;
462 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
463 // Ensure valid bfloat16 stored in each float
464 for (auto f : fp32_data)
465 ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
466 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
467 }
468 break;
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100469 case DType_FP32:
Kevin Cheng82507d72021-06-17 16:01:59 -0700470 {
471 std::vector<float> fp32_data;
472 TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
473 tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
474 }
475 break;
476 case DType_BOOL:
477 {
478 std::vector<bool> bool_data;
479 TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
480
481 // std::vector<bool>::data() will return bit mask instead of array of bool array.
482 // Need to translate manually.
483 bool* bool_array = (bool*)calloc(bool_data.size(), sizeof(bool));
484 for (size_t i = 0; i < bool_data.size(); i++)
485 {
486 bool_array[i] = bool_data[i];
487 }
488 tensor->setTensorValueBool(bool_data.size(), bool_array);
489 }
490 break;
491 default:
Kevin Cheng903763c2021-09-28 16:14:52 -0700492 SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
493 EnumNamesDType()[ts->GetDtype()]);
Eric Kunzee5e26762020-10-13 16:11:07 -0700494 }
495 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700496 }
497
498 return 0;
499}
500
501int SubgraphTraverser::isFullyEvaluated() const
502{
503 return nextNodeList.empty();
504}
505
506GraphNode* SubgraphTraverser::getNextNode()
507{
508 GraphNode* nextNode = nextNodeList.front();
509 ASSERT_MSG(nextNode, "SubgraphTraverser::getNextNode(): called with empty next node list");
510 ASSERT_MSG(nextNode->getOnNextNodeList(),
511 "SubgraphTraverser::getNextNode(): internal state error: node is not listed as being on next node list");
512
513 nextNodeList.pop_front();
514
515 nextNode->clearOnNextNodeList();
516 return nextNode;
517}
518
519int SubgraphTraverser::addToNextNodeList(GraphNode* nextNode)
520{
521 ASSERT_MSG(nextNode, "SubgraphTraverser::addToNextNodeList(): called with no node");
522 ASSERT_MSG(!nextNode->getOnNextNodeList(),
523 "SubgraphTraverser::addToNextNodeList(): internal state error: node is already on next node list");
524
525 nextNode->setOnNextNodeList();
526 nextNodeList.push_back(nextNode);
527
528 return 0;
529}
530
531int SubgraphTraverser::evaluateNextNode()
532{
533 if (isFullyEvaluated())
534 return 0;
535
536 GraphNode* currNode = getNextNode();
537
538 DEBUG_INFO(GT, "Evaluating node_%03lu, %8s, output tensor=%s", currNode->getID(), EnumNamesOp()[currNode->getOp()],
539 currNode->getOutputNames()[0].c_str());
540
541 // Sanity check for never-ending loops
542 if (currNode->getEvalCount() >= MAX_EVAL_COUNT && (currNode->getEvalCount() % MAX_EVAL_COUNT) == 0)
543 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700544 WARNING("SubgraphTraverser::evaluateNextNode(): Node %lu has been evaluated %d times. Loop suspected.",
545 currNode->getID(), currNode->getEvalCount());
Eric Kunzee5e26762020-10-13 16:11:07 -0700546 }
547
Kevin Cheng550ccc52021-03-03 11:21:43 -0800548 for (auto tensor : currNode->getOutputs())
Eric Kunzee5e26762020-10-13 16:11:07 -0700549 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800550 if (!tensor->is_allocated())
551 if (tensor->allocate())
Eric Kunzee5e26762020-10-13 16:11:07 -0700552 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700553 FATAL_ERROR("SubgraphTraverser::evaluateNextNode(): Failed to allocate Eigen tensor %s",
554 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700555 }
556 }
557
558 if (currNode->eval())
559 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700560 WARNING("SubgraphTraverser::evaluateNextNode(): Failed to evaluate node: %lu", currNode->getID());
Kevin Chengacb550f2021-06-29 15:32:19 -0700561 return 1;
Eric Kunzee5e26762020-10-13 16:11:07 -0700562 }
563
564 // free input tensor if all of its consumers have all of their outputs ready and it's not block's output
Jerry Ge9e94af82022-10-27 09:57:00 -0700565 if (!currNode->getInMainBlock()) { // we don't free it if the node is in main block and has nested blocks
566 for (auto tensor : currNode->getInputs())
Eric Kunzee5e26762020-10-13 16:11:07 -0700567 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700568 bool in_use = false;
569
570 auto tensor_check = findTensorByName(tensor->getName());
571 if (tensor_check->getIsParentGraphOutput()) {
572 // if it's parent's block output tensor, we can't free it
573 continue;
Eric Kunzee5e26762020-10-13 16:11:07 -0700574 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700575
576 for (auto node : tensor->getConsumers())
Eric Kunzee5e26762020-10-13 16:11:07 -0700577 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700578 // If the node is inside a loop, the input tensor is still needed
579 if (!node->hasAllOutputsReady())
580 {
581 in_use = true;
582 }
583
Eric Kunzee5e26762020-10-13 16:11:07 -0700584 }
Jerry Ge9e94af82022-10-27 09:57:00 -0700585 for (auto name : block->GetOutputs())
586 {
587 if (name == tensor->getName())
588 {
589 in_use = true;
590 }
591 }
592
593 if (!in_use)
594 {
595 tensor->deallocate();
596 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700597 }
598 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700599 // Search the output tensors of this node to see if
600 // there are now new ready nodes available from completing this node
601 for (TosaReference::Tensor* tensor : currNode->getOutputs())
602 {
603 for (GraphNode* node : tensor->getConsumers())
604 {
605 if (!node->getOnNextNodeList() && node->hasAllInputsReady())
606 {
607 addToNextNodeList(node);
608 }
609 }
610 }
611
612 if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
613 {
614 dumpNextNodeList(g_func_debug.func_debug_file);
615 }
616
617 if (g_func_config.dump_intermediates)
618 {
619 currNode->dumpNode(g_func_debug.func_debug_file);
620 for (auto outs : currNode->getOutputs())
621 {
622 outs->dumpTensorParams(g_func_debug.func_debug_file);
623 outs->dumpTensor(g_func_debug.func_debug_file);
624 fprintf(g_func_debug.func_debug_file, "\n");
625 }
626 }
627
628 return 0;
629}
630
631int SubgraphTraverser::dumpNextNodeList(FILE* out) const
632{
633
634 // Dump next node list
635 fprintf(out, "Next node list\n");
636
637 if (nextNodeList.empty())
638 {
639 fprintf(out, "<empty>\n");
640 }
641
642 for (auto gn : nextNodeList)
643 {
644 gn->dumpNode(out);
645 }
646
647 fprintf(out, "Done.\n");
648 return 0;
649}
650
651int SubgraphTraverser::clearAllNodeMarkings()
652{
653 for (GraphNode* currNode : nodes)
654 {
655 currNode->clearNodeMarked();
656 }
657
658 return false;
659}
660
Kevin Cheng550ccc52021-03-03 11:21:43 -0800661int SubgraphTraverser::addTensor(TosaReference::Tensor* tensor)
Eric Kunzee5e26762020-10-13 16:11:07 -0700662{
663 // Enforce no duplicate tensors/tensor names
664 // O(N), but the number of tensors is small
665 for (TosaReference::Tensor* currTensor : tensors)
666 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800667 if (tensor == currTensor || currTensor->getName() == tensor->getName())
Eric Kunzee5e26762020-10-13 16:11:07 -0700668 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700669 FATAL_ERROR("SubgraphTraverser::addTensor(): Duplicate tensor or tensor name being added to graph: %s\n",
670 tensor->getName().c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700671 return 1;
672 }
673 }
674
Kevin Cheng550ccc52021-03-03 11:21:43 -0800675 tensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700676
Kevin Cheng550ccc52021-03-03 11:21:43 -0800677 if (tensor->getIsSubgraphInput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700678 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800679 inputTensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700680 }
681
Kevin Cheng550ccc52021-03-03 11:21:43 -0800682 if (tensor->getIsSubgraphOutput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700683 {
Kevin Cheng550ccc52021-03-03 11:21:43 -0800684 outputTensors.push_back(tensor);
Eric Kunzee5e26762020-10-13 16:11:07 -0700685 }
686
687 return 0;
688}
689int SubgraphTraverser::addNode(GraphNode* newNode)
690{
691 // Enforce no duplicate nodes
692 for (GraphNode* currNode : nodes)
693 {
694 if (currNode == newNode)
695 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700696 FATAL_ERROR("SubgraphTraverser::addTensor(): duplicate node being added to graph");
Eric Kunzee5e26762020-10-13 16:11:07 -0700697 return 1;
698 }
699 }
700
701 nodes.push_back(newNode);
702
703 return 0;
704}
705
706TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
707{
Jerry Ge9e94af82022-10-27 09:57:00 -0700708 TosaReference::Tensor* res_tensor = nullptr;
709
Eric Kunzee5e26762020-10-13 16:11:07 -0700710 for (TosaReference::Tensor* currTensor : tensors)
711 {
712 if (currTensor->getName() == name)
713 {
Jerry Ge9e94af82022-10-27 09:57:00 -0700714 res_tensor = currTensor;
715 return res_tensor;
Eric Kunzee5e26762020-10-13 16:11:07 -0700716 }
717 }
718
Jerry Ge9e94af82022-10-27 09:57:00 -0700719 if (parent_sgt)
720 {
721 for (TosaReference::Tensor* currTensor : parent_sgt->tensors)
722 {
723 if (currTensor->getName() == name)
724 {
725 res_tensor = currTensor;
726 res_tensor->setIsParentGraphOutput();
727 }
728 }
729 }
Eric Kunzee5e26762020-10-13 16:11:07 -0700730
Jerry Ge9e94af82022-10-27 09:57:00 -0700731 if (!res_tensor)
732 {
733 WARNING("SubgraphTraverser::findTensorByName(): Unable to find tensor with name: %s\n", name.c_str());
734 return nullptr;
735 }
736 return res_tensor;
Eric Kunzee5e26762020-10-13 16:11:07 -0700737}
738
739int SubgraphTraverser::linkTensorsAndNodes()
740{
741 // Nodes have a list of input/output tensor names
742 // For each node, read this list, link up the tensors with their inputs/outputs
743 for (GraphNode* currNode : nodes)
744 {
Eric Kunzee5e26762020-10-13 16:11:07 -0700745 // Link inputs/consuming nodes
746 for (std::string& name : currNode->getInputNames())
747 {
748 TosaReference::Tensor* t = findTensorByName(name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700749 SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
750 name.c_str(), currNode->getID());
751 SUBGRAPH_ERROR_IF(currNode->addInputTensor(t),
752 "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
753 name.c_str(), currNode->getID());
754 SUBGRAPH_ERROR_IF(t->addConsumer(currNode),
755 "SubgraphTraverser::linkTensorsAndNodes(): cannot link consumer node %lu to tensor %s\n",
756 currNode->getID(), name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700757 }
758
759 // Link outputs/producing nodes
760 for (std::string& name : currNode->getOutputNames())
761 {
762 TosaReference::Tensor* t = findTensorByName(name);
Kevin Cheng903763c2021-09-28 16:14:52 -0700763 SUBGRAPH_ERROR_IF(!t, "SubgraphTraverser::linkTensorsAndNodes(): Cannot find tensor %s in node %lu\n",
764 name.c_str(), currNode->getID());
765 SUBGRAPH_ERROR_IF(currNode->addOutputTensor(t),
766 "SubgraphTraverser::linkTensorsAndNodes(): cannot link tensor %s to node %lu\n",
767 name.c_str(), currNode->getID());
Eric Kunzee5e26762020-10-13 16:11:07 -0700768
Kevin Cheng903763c2021-09-28 16:14:52 -0700769 SUBGRAPH_ERROR_IF(
770 t->setProducer(currNode),
771 "SubgraphTraverser::linkTensorsAndNodes(): cannot link producer node %lu to tensor tensor %s\n",
772 currNode->getID(), name.c_str());
Eric Kunzee5e26762020-10-13 16:11:07 -0700773 }
774 }
775
776 return 0;
777}
778
779int SubgraphTraverser::validateGraph()
780{
781 // Need to make sure that:
782 // - each tensor is actually used
783 // - input and output tesnsors truly are just input and just output
784 // Graph building already determined that each node has found its input/output tensors
785
786 for (TosaReference::Tensor* currTensor : tensors)
787 {
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700788 // It's okay for block input tensor not being consumed by operators.
789 // This is common in control flow op execution.
790 if (!currTensor->getIsSubgraphInput())
Eric Kunzee5e26762020-10-13 16:11:07 -0700791 {
Kevin Cheng14d7f7a2021-05-12 10:44:49 -0700792 if (!currTensor->getProducer() && currTensor->getConsumers().empty())
Eric Kunzee5e26762020-10-13 16:11:07 -0700793 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700794 WARNING("SubgraphTraverser::validateGraph(): TosaReference::Tensor %s has no producers or consumers\n",
Eric Kunzee5e26762020-10-13 16:11:07 -0700795 currTensor->getName().c_str());
796 return 1;
797 }
798 }
799
Eric Kunzee5e26762020-10-13 16:11:07 -0700800 if (g_func_config.tosa_profile == 0)
801 {
802 DType dtype = currTensor->getDtype();
803
804 // Float-point disallowed
Jeremy Johnsonbc2a3db2022-09-27 13:50:00 +0100805 if (dtype == DType_FP32 || dtype == DType_FP16)
Eric Kunzee5e26762020-10-13 16:11:07 -0700806 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700807 WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
808 "disabled, but %s tensor %s found\n",
Eric Kunzee5e26762020-10-13 16:11:07 -0700809 EnumNamesDType()[dtype], currTensor->getName().c_str());
810 return 1;
811 }
812 }
813 else if (g_func_config.tosa_profile == 1 || g_func_config.tosa_profile == 2)
814 {
815 // Do nothing. All FP types allowed
816 // Currently no implementation difference between Main Inference and Main Training modes
817 }
818 else
819 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700820 FATAL_ERROR("SubgraphTraverser::validateGraph(): TOSA profile not recognized: %d",
821 g_func_config.tosa_profile);
Eric Kunzee5e26762020-10-13 16:11:07 -0700822 }
823 }
824
825 for (GraphNode* currNode : nodes)
826 {
Kevin Cheng903763c2021-09-28 16:14:52 -0700827 SUBGRAPH_ERROR_IF(currNode->checkTensorAttributes(),
828 "SubgraphTraverser::validateGraph(): TosaReference::Tensor attribute check failed");
Eric Kunzee5e26762020-10-13 16:11:07 -0700829 }
830
831 if (outputTensors.size() <= 0)
832 {
833 DEBUG_MED(GT, "Graph output tensor empty");
834 return 0;
835 }
836
837 return 0;
838}
839
840int SubgraphTraverser::dumpGraph(FILE* out) const
841{
842 int i = 0;
843
844 fprintf(out, "Full graph dump:\n");
845 for (GraphNode* currNode : nodes)
846 {
847 fprintf(out, "Node [%d]: ", i++);
848 currNode->dumpNode(out);
849 }
850
851 return 0;
852}
853
854int SubgraphTraverser::evaluateAll()
855{
856 // evaluation loop
857 while (!isFullyEvaluated())
858 {
859 if (evaluateNextNode())
860 {
861 return 1;
862 }
863 }
864
865 return 0;
866}