/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef SRC_DYNAMIC_FUSION_SKETCH_UTILS_DEPENDENCYGRAPH
#define SRC_DYNAMIC_FUSION_SKETCH_UTILS_DEPENDENCYGRAPH

#include "arm_compute/core/Error.h"

#include <algorithm>
#include <cstdint>
#include <iosfwd>
#include <map>
#include <set>
#include <tuple>
#include <vector>

namespace arm_compute
{
namespace experimental
{
namespace dynamic_fusion
{
namespace
{
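/** Check whether @p v is an element of @p vec */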
template <typename T>
bool is_in(const T &v, const std::vector<T> &vec)
{
    return std::find(std::begin(vec), std::end(vec), v) != std::end(vec);
}
} // namespace

/** A directed acyclic graph with multiple input tensors and multiple output tensors
 * Represented as a doubly-linked adjacency list that distinguishes between source and destination
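 *
 * Example usage (an illustrative sketch only; the integer tensor / operator ids are hypothetical):
 * @code
 * DependencyGraph g{};
 * g.add_operator(0, { 0, 1 }, { 2 });            // op0: (t0, t1) -> t2
 * g.add_operator(1, { 2 }, { 3 }, true);         // op1: t2 -> t3, marked as an output operator
 * const auto seq = g.build_operators_sequence(); // [ {op0, ...}, {op1, ...} ]
 * @endcode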
 */
class DependencyGraph
{
public:
    using Id         = int32_t;
    using TensorId   = Id;
    using OperatorId = Id;
    /** Adjacency list */
    using AdjList = std::map<Id, std::vector<Id>>;

    /** A pack consisting of an operator together with its input and output tensors, used when traversing the graph in topological order */
    struct OpPack
    {
        OperatorId            op{};
        std::vector<TensorId> inputs{};
        std::vector<TensorId> outputs{};
        friend bool operator==(const OpPack &opp0, const OpPack &opp1)
        {
            return std::make_tuple(
                       opp0.op, opp0.inputs, opp0.outputs)
                   == std::make_tuple(
                       opp1.op, opp1.inputs, opp1.outputs);
        }
    };

public:
    DependencyGraph() = default;
    friend std::ostream &operator<<(std::ostream &os, const DependencyGraph &);

    /** Check whether an operator can be added (without actually adding it) while keeping the graph a "linear sequence" / list
     *
     * Rule: If the new operator is not the first operator, at least one input tensor must be
     *       the output tensor of the last non-output operator. All other input tensors must be
     *       global inputs of the graph (i.e. not the output of any operator).
     *
     * Rule: The output tensors of the new operator must not be input tensors of any previously
     *       added operator.
     *
     * PRECONDITION: The current graph is already linear
     *
     * @return true  If the operator can be added while keeping the graph as a linear sequence
     * @return false Otherwise
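     *
     * Example (hypothetical ids):
     * @code
     * DependencyGraph g{};
     * g.add_operator_as_linear(0, { 0 }, { 1 });        // op0: t0 -> t1
     * g.try_add_operator_as_linear(1, { 1, 2 }, { 3 }); // true: consumes t1, the last operator's output
     * g.try_add_operator_as_linear(2, { 0 }, { 4 });    // false: no input comes from t1
     * @endcode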
     */
    bool try_add_operator_as_linear(OperatorId op, const std::vector<TensorId> &inputs, const std::vector<TensorId> &outputs, bool is_output = false) const
    {
        ARM_COMPUTE_UNUSED(op, is_output);
        if(all_ops().empty())
        {
            return true;
        }

        // If the new operator is not the first operator, at least one input tensor must be
        // the output tensor of the last non-output operator. All other input tensors must be
        // the global input of the graph (i.e. not the output of any operator).
        if(_last_op_available)
        {
            auto use_input_from_last_op = false;

            for(auto src_tensor : inputs)
            {
                const auto src_ops = _adj_src_ops.find(src_tensor);

                if(src_ops != _adj_src_ops.end())
                {
                    ARM_COMPUTE_ERROR_ON(src_ops->second.size() > 1);

                    if(!src_ops->second.empty())
                    {
                        const auto src_op = src_ops->second[0];

                        if(src_op == _last_op)
                        {
                            if(use_input_from_last_op)
                            {
                                // To be safe, we also forbid using the output tensor
                                // of the last operator twice.
                                return false;
                            }

                            use_input_from_last_op = true;
                        }
                        else
                        {
                            // The input tensor of this operator must not be the output tensor
                            // of any other operator except the last non-output operator.
                            return false;
                        }
                    }
                }
            }

            if(!use_input_from_last_op)
            {
                // At least one input tensor must be the output tensor of the last non-output operator.
                return false;
            }
        }

        // The output tensor of the new operator must not be the input tensor of any previously
        // added operator.
        for(auto dst_tensor : outputs)
        {
            if(_adj_dst_ops.find(dst_tensor) != _adj_dst_ops.end())
            {
                return false;
            }
        }

        return true;
    }
    /** Add an operator, while keeping the graph as a "linear sequence"
     *
     * PRECONDITION: The current graph is already linear
     * INVARIANT: The list can only grow from head to tail
     * POSTCONDITION: The graph is still linear
     */
    void add_operator_as_linear(OperatorId op, const std::vector<TensorId> &inputs, const std::vector<TensorId> &outputs, bool is_output = false)
    {
        const auto success = add_operator(op, inputs, outputs, is_output);
        ARM_COMPUTE_UNUSED(success);
        ARM_COMPUTE_ERROR_ON(!success);
    }
    /** Add a new operator
     * Adding is rejected if it would violate the DAG invariant (i.e. introduce a cycle);
     * a rejected operation leaves the graph unchanged
     *
     * @param[in] op        Operator to add
     * @param[in] inputs    Input tensors to the operator
     * @param[in] outputs   Output tensors to the operator
     * @param[in] is_output Whether this is an output operator
     *
     * @return true  If the operator was added successfully
     * @return false Otherwise
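     *
     * Example (hypothetical ids):
     * @code
     * DependencyGraph g{};
     * g.add_operator(0, { 0 }, { 1 }); // true:  op0: t0 -> t1
     * g.add_operator(1, { 1 }, { 0 }); // false: t0 -> op0 -> t1 -> op1 -> t0 would be a cycle
     * @endcode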
     */
    bool add_operator(OperatorId op, const std::vector<TensorId> &inputs, const std::vector<TensorId> &outputs, bool is_output = false)
    {
        if(operator_exists(op))
        {
            return false;
        }
        _adj_src_tensors[op] = {};
        _adj_dst_tensors[op] = {};
        for(auto in_tensor : inputs)
        {
            // Linking input tensor to operator node will never create a cycle / loop because we guarantee
            // each op is newly created, so every <input, op> pair / edge is new
            link_input(op, in_tensor);
        }
        for(auto out_tensor : outputs)
        {
            // If there exists a back path from op's output tensor to op already, then linking the two will create a loop / cycle
            if(path_exists_from_tensor_to_op(out_tensor, op))
            {
                remove_operator(op);
                return false;
            }
            else
            {
                link_output(op, out_tensor);
            }
        }

        if(!is_output)
        {
            _last_op_available = true;
            _last_op           = op;
        }

        return true;
    }

    /** Build a sequence of operators from the acyclic graph of operators.
     *
     * The graph is visited with a depth-first strategy. An operator can only be added to
     * the sequence once all operators that supply its input tensors have been added; otherwise,
     * the operator is skipped and visited again later. In other words, the dependencies between
     * operators are preserved in the sequence.
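     *
     * Example (hypothetical ids): for a diamond-shaped graph
     *   t0 -> op0 -> (t1, t2),  t1 -> op1 -> t3,  t2 -> op2 -> t4,  (t3, t4) -> op3 -> t5
     * a valid returned sequence is [ op0, op1, op2, op3 ]: op3 is only appended once both t3 and t4 are done.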
     */
    std::vector<OpPack> build_operators_sequence() const
    {
        std::vector<OpPack> ops_seq;
        std::set<Id>        done_ops;
        std::set<Id>        done_tensors;

        const auto input_tensors = global_src_tensors();

        for(auto tensor : input_tensors)
        {
            done_tensors.insert(tensor);

            for(auto op : _adj_dst_ops.at(tensor))
            {
                build_operators_sequence_from_op(op, ops_seq, done_ops, done_tensors);
            }
        }

        return ops_seq;
    }

    /** Strict equality comparison (all internal ids and order of insertion matter).
     * In the future this may be replaced with a topological comparison, allowing equivalent graphs with different internal ids to be equal
     *
     * @param[in] g0 Graph to compare
     * @param[in] g1 Graph to compare against
     * @return true  If the same
     * @return false Otherwise
     */
    friend bool operator==(const DependencyGraph &g0, const DependencyGraph &g1)
    {
        // Do not compare id allocators
        return std::make_tuple(
                   g0._adj_src_tensors, g0._adj_dst_tensors, g0._adj_src_ops, g0._adj_dst_ops)
               == std::make_tuple(
                   g1._adj_src_tensors, g1._adj_dst_tensors, g1._adj_src_ops, g1._adj_dst_ops);
    }
    std::vector<OperatorId> src_ops_from_tensor(TensorId tensor) const
    {
        return _adj_src_ops.at(tensor);
    }
    std::vector<OperatorId> dst_ops_from_tensor(TensorId tensor) const
    {
        return _adj_dst_ops.at(tensor);
    }
    /** Get all tensors
     *
     * @return std::vector<TensorId>
     */
    std::vector<TensorId> all_tensors() const
    {
        std::vector<TensorId> tensors{};
        std::transform(std::begin(_adj_src_ops), std::end(_adj_src_ops), std::back_inserter(tensors), [](const auto & it)
        {
            return it.first;
        });
        return tensors;
    }
    /** Get source tensors of the whole graph
     *
     * @return std::vector<TensorId>
     */
    std::vector<TensorId> global_src_tensors() const
    {
        std::vector<TensorId> tensors;
        for(auto tensor_src_ops : _adj_src_ops)
        {
            if(tensor_src_ops.second.empty())
            {
                tensors.push_back(tensor_src_ops.first);
            }
        }
        return tensors;
    }
    /** Get destination tensors of the whole graph
     *
     * @return std::vector<TensorId>
     */
    std::vector<TensorId> global_dst_tensors() const
    {
        std::vector<TensorId> tensors;
        for(auto tensor_dst_ops : _adj_dst_ops)
        {
            if(tensor_dst_ops.second.empty())
            {
                tensors.push_back(tensor_dst_ops.first);
            }
        }
        return tensors;
    }
    /** Get intermediate tensors of the whole graph.
     *
     * @return std::vector<TensorId>
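     *
     * E.g. in a chain t0 -> op0 -> t1 -> op1 -> t2 (hypothetical ids), t1 is the only intermediate tensor.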
     */
    std::vector<TensorId> intermediate_tensors() const
    {
        std::vector<TensorId> tensors;

        // If a tensor connects the output of one operator to the input of another,
        // it is not allocated in memory; it exists only as a temporary variable.
        for(auto src_tensor : _adj_src_ops)
        {
            if(!src_tensor.second.empty())
            {
                const auto dst_tensor = _adj_dst_ops.find(src_tensor.first);
                if(dst_tensor != _adj_dst_ops.end())
                {
                    if(!dst_tensor->second.empty())
                    {
                        tensors.push_back(src_tensor.first);
                    }
                }
            }
        }

        return tensors;
    }
    /** Get all root ops. Root ops can also be referred to as "src ops" of the whole graph
     *
     * @return std::vector<OperatorId>
     */
    std::vector<OperatorId> get_root_ops() const
    {
        std::vector<OperatorId> ops{};
        const auto op_list = all_ops();

        for(auto op : op_list)
        {
            if(src_ops(op).empty())
            {
                ops.emplace_back(op);
            }
        }
        return ops;
    }

private:
    void link_input(OperatorId op, TensorId in_tensor)
    {
        ARM_COMPUTE_ERROR_ON(!operator_exists(op));
        if(!tensor_exists(in_tensor))
        {
            insert_new_tensor(in_tensor);
        }
        ARM_COMPUTE_ERROR_ON(are_connected(op, in_tensor)); // Prevent repetitive linking
        _adj_src_tensors[op].push_back(in_tensor);
        _adj_dst_ops[in_tensor].push_back(op);
    }
    void link_output(OperatorId op, TensorId out_tensor)
    {
        ARM_COMPUTE_ERROR_ON(!operator_exists(op));
        if(!tensor_exists(out_tensor))
        {
            insert_new_tensor(out_tensor);
        }
        ARM_COMPUTE_ERROR_ON(are_connected(op, out_tensor)); // Prevent repetitive linking
        _adj_dst_tensors[op].push_back(out_tensor);
        _adj_src_ops[out_tensor].push_back(op);
    }

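    /** Get the operators that produce the source tensors of @p op */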
    std::vector<OperatorId> src_ops(OperatorId op) const
    {
        ARM_COMPUTE_ERROR_ON(!operator_exists(op));
        std::vector<OperatorId> ops{};
        for(TensorId src_tensor : src_tensors(op))
        {
            ops.insert(ops.end(), std::begin(_adj_src_ops.at(src_tensor)), std::end(_adj_src_ops.at(src_tensor)));
        }
        return ops;
    }
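    /** Get the operators that consume the destination tensors of @p op */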
    std::vector<OperatorId> dst_ops(OperatorId op) const
    {
        ARM_COMPUTE_ERROR_ON(!operator_exists(op));
        std::vector<OperatorId> ops{};
        for(TensorId dst_tensor : _adj_dst_tensors.at(op))
        {
            ops.insert(ops.end(), std::begin(_adj_dst_ops.at(dst_tensor)), std::end(_adj_dst_ops.at(dst_tensor)));
        }
        return ops;
    }

    /** Get source tensors of an operator
     *
     * @param[in] op
     * @return std::vector<TensorId>
     */
    std::vector<TensorId> src_tensors(OperatorId op) const
    {
        ARM_COMPUTE_ERROR_ON(!operator_exists(op));
        return _adj_src_tensors.at(op);
    }
    /** Get destination tensors of an operator
     *
     * @param[in] op
     * @return std::vector<TensorId>
     */
    std::vector<TensorId> dst_tensors(OperatorId op) const
    {
        ARM_COMPUTE_ERROR_ON(!operator_exists(op));
        return _adj_dst_tensors.at(op);
    }
    /** Get all operators
     *
     * @return std::vector<OperatorId>
     */
    std::vector<OperatorId> all_ops() const
    {
        std::vector<OperatorId> ops{};
        std::transform(std::begin(_adj_src_tensors), std::end(_adj_src_tensors), std::back_inserter(ops), [](const auto & it)
        {
            return it.first;
        });
        return ops;
    }
    /** Remove an operator from graph.
     *
     * @param[in] op
     */
    void remove_operator(OperatorId op)
    {
        for(auto src_tensor : _adj_src_tensors.at(op))
        {
            auto &dst_ops = _adj_dst_ops.at(src_tensor);
            dst_ops.erase(
                std::remove(std::begin(dst_ops), std::end(dst_ops), op),
                std::end(dst_ops));
        }
        for(auto dst_tensor : _adj_dst_tensors.at(op))
        {
            auto &src_ops = _adj_src_ops.at(dst_tensor);
            src_ops.erase(
                std::remove(std::begin(src_ops), std::end(src_ops), op),
                std::end(src_ops));
        }
        // Remove any isolated tensors
        // An isolated tensor is one where both its _adj_src_ops and _adj_dst_ops are empty
        for(auto t : all_tensors())
        {
            if(_adj_src_ops.at(t).empty() && _adj_dst_ops.at(t).empty())
            {
                _adj_src_ops.erase(t);
                _adj_dst_ops.erase(t);
            }
        }
        _adj_src_tensors.erase(op);
        _adj_dst_tensors.erase(op);
    }
    void insert_new_tensor(TensorId tensor)
    {
        _adj_src_ops[tensor] = {};
        _adj_dst_ops[tensor] = {};
    }
    bool tensor_exists(TensorId tensor) const
    {
        return _adj_src_ops.find(tensor) != _adj_src_ops.end() && _adj_dst_ops.find(tensor) != _adj_dst_ops.end();
    }
    bool operator_exists(OperatorId op) const
    {
        return _adj_src_tensors.find(op) != _adj_src_tensors.end() && _adj_dst_tensors.find(op) != _adj_dst_tensors.end();
    }
    bool is_src_tensor_of(OperatorId op, TensorId tensor) const
    {
        if(!operator_exists(op) || !tensor_exists(tensor))
        {
            return false;
        }
        const auto op_inputs = src_tensors(op);
        return std::find(op_inputs.begin(), op_inputs.end(), tensor) != op_inputs.end();
    }
    bool is_dst_tensor_of(OperatorId op, TensorId tensor) const
    {
        if(!operator_exists(op) || !tensor_exists(tensor))
        {
            return false;
        }
        const auto op_outputs = dst_tensors(op);
        return std::find(op_outputs.begin(), op_outputs.end(), tensor) != op_outputs.end();
    }
    bool are_connected(OperatorId op, TensorId tensor) const
    {
        return is_src_tensor_of(op, tensor) || is_dst_tensor_of(op, tensor);
    }
    /** Check if op is a destination / leaf operator of the whole graph
     *
     * @param[in] op
     * @return true  If op is a leaf operator
     * @return false Otherwise
     */
    bool is_dst_op(OperatorId op) const
    {
        return dst_ops(op).empty();
    }
    std::vector<OperatorId> get_dst_ops() const
    {
        std::vector<OperatorId> ops{};
        const auto op_list = all_ops();

        for(auto op : op_list)
        {
            if(is_dst_op(op))
            {
                ops.emplace_back(op);
            }
        }
        return ops;
    }
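    /** Check if any operator consuming @p src_tensor can reach @p dst_op; used (mutually recursively with path_exists_from_op_to_op()) for cycle detection */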
    bool path_exists_from_tensor_to_op(TensorId src_tensor, OperatorId dst_op) const
    {
        if(!tensor_exists(src_tensor) || !operator_exists(dst_op))
        {
            return false;
        }
        for(auto child_op : dst_ops_from_tensor(src_tensor))
        {
            if(path_exists_from_op_to_op(child_op, dst_op))
            {
                return true;
            }
        }
        return false;
    }

    bool path_exists_from_op_to_op(OperatorId src_op, OperatorId dst_op) const
    {
        if(!operator_exists(src_op) || !operator_exists(dst_op))
        {
            return false;
        }
        if(src_op == dst_op)
        {
            return true;
        }
        if(is_in(src_op, get_dst_ops()))
        {
            return false;
        }
        for(auto child_tensor : dst_tensors(src_op))
        {
            if(path_exists_from_tensor_to_op(child_tensor, dst_op))
            {
                return true;
            }
        }
        return false;
    }

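    /** Depth-first visit starting at @p op, appending each visited operator to @p ops_seq once all of its input tensors are in @p done_tensors */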
    void build_operators_sequence_from_op(
        Id                   op,
        std::vector<OpPack> &ops_seq,
        std::set<Id>        &done_ops,
        std::set<Id>        &done_tensors) const
    {
        while(true)
        {
            // If the operator has already been added to the sequence, ignore it.
            if(done_ops.find(op) != done_ops.end())
            {
                return;
            }

            // If not all the input tensors of the operator are available, this operator cannot be
            // added to the sequence for now. It will be visited again after the source operator
            // is added to the sequence.
            const auto src_tensors = _adj_src_tensors.at(op);

            for(auto src : src_tensors)
            {
                if(done_tensors.find(src) == done_tensors.end())
                {
                    return;
                }
            }

            // This operator is ready to be added to the sequence.
            const auto dst_tensors = _adj_dst_tensors.at(op);

            done_ops.insert(op);

            OpPack pack{ op, src_tensors, dst_tensors };
            ops_seq.push_back(pack);

            done_tensors.insert(dst_tensors.begin(), dst_tensors.end());

            // Visit all the sink operators.
            // If there is exactly one sink, continue this loop on it instead of recursing (manual tail-call elimination);
            // otherwise recurse into each sink.
            if(dst_tensors.size() == 1 && _adj_dst_ops.at(dst_tensors[0]).size() == 1)
            {
                op = _adj_dst_ops.at(dst_tensors[0])[0];
            }
            else
            {
                for(auto dst_tensor : dst_tensors)
                {
                    const auto dst_ops = _adj_dst_ops.at(dst_tensor);

                    for(auto dst_op : dst_ops)
                    {
                        build_operators_sequence_from_op(dst_op, ops_seq, done_ops, done_tensors);
                    }
                }

                return;
            }
        }
    }

private:
    AdjList _adj_src_tensors{};
    AdjList _adj_dst_tensors{};
    AdjList _adj_src_ops{};
    AdjList _adj_dst_ops{};

    bool       _last_op_available{ false };
    OperatorId _last_op{ 0 };
};

} // namespace dynamic_fusion
} // namespace experimental
} // namespace arm_compute
#endif /* SRC_DYNAMIC_FUSION_SKETCH_UTILS_DEPENDENCYGRAPH */