blob: 2e8292bbfb7f1f4d318c6a01118177f7b7df5418 [file] [log] [blame]
SiCong Lib63b1192022-01-28 18:24:39 +00001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION
25#error "This experimental feature must be enabled with -DENABLE_EXPERIMENTAL_DYNAMIC_FUSION"
26#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */
#include "arm_compute/core/experimental/DependencyGraph.h"

#include <algorithm>
#include <deque>
#include <iterator>
#include <map>
#include <set>
#include <utility>
#include <vector>
32
33namespace arm_compute
34{
35namespace experimental
36{
37namespace dynamic_fusion
38{
39DependencyGraph::DependencyGraph(const AdjList &adj_src_tensors, const AdjList &adj_dst_tensors, const AdjList &adj_src_ops, const AdjList &adj_dst_ops, std::map<Id, Id> merge_points)
40 : _adj_src_tensors{ adj_src_tensors }, _adj_dst_tensors{ adj_dst_tensors }, _adj_src_ops{ adj_src_ops }, _adj_dst_ops{ adj_dst_ops }, _merge_to_internal{ merge_points }, _operator_id{}, _tensor_id{}
41{
42}
43DependencyGraph::DependencyGraph(const std::vector<Id> &imported_tensors)
44 : _adj_src_tensors{}, _adj_dst_tensors{}, _adj_src_ops{}, _adj_dst_ops{}, _merge_to_internal{}, _operator_id{}, _tensor_id{}
45{
46 for(auto t : imported_tensors)
47 {
48 _adj_src_ops[t] = {};
49 _adj_dst_ops[t] = {};
50 }
51}
52
53Status DependencyGraph::update_merge_point(Id t_id, Id merge_point)
54{
55 if(_merge_to_internal.find(merge_point) == _merge_to_internal.end())
56 {
57 return Status{ ErrorCode::RUNTIME_ERROR, "Merge point does not exist" };
58 }
59 _merge_to_internal[merge_point] = t_id;
60 return Status{};
61}
62
63DependencyGraph::Id DependencyGraph::add_tensor(Id merge_tensor)
64{
65 Id new_tensor{ empty_id() };
66 if(merge_tensor != empty_id())
67 {
68 if(_merge_to_internal.find(merge_tensor) != _merge_to_internal.end())
69 {
70 new_tensor = _merge_to_internal[merge_tensor];
71 }
72 else
73 {
74 new_tensor = insert_new_tensor();
75 _merge_to_internal[merge_tensor] = new_tensor;
76 }
77 }
78 else
79 {
80 new_tensor = insert_new_tensor();
81 }
82 return new_tensor;
83}
84
85void DependencyGraph::remove_tensor(Id tensor)
86{
87 for(auto src_op : _adj_src_ops.at(tensor))
88 {
89 auto &dst_tensors = _adj_dst_tensors.at(src_op);
90 dst_tensors.erase(
91 std::remove(std::begin(dst_tensors), std::end(dst_tensors), tensor),
92 std::end(dst_tensors));
93 }
94 for(auto dst_op : _adj_dst_ops.at(tensor))
95 {
96 auto &src_tensors = _adj_src_tensors.at(dst_op);
97 src_tensors.erase(
98 std::remove(std::begin(src_tensors), std::end(src_tensors), tensor),
99 std::end(src_tensors));
100 }
101 _adj_src_ops.erase(tensor);
102 _adj_dst_ops.erase(tensor);
103}
104
105std::pair<Status, DependencyGraph::Id> DependencyGraph::add_operator(const std::vector<Id> &inputs, const std::vector<Id> &outputs)
106{
107 Id new_op = insert_new_op();
108 for(Id tensor : inputs)
109 {
110 link_input(new_op, tensor);
111 }
112 for(Id tensor : outputs)
113 {
114 link_output(new_op, tensor);
115 }
116
117 // Use topological sort in order to detect possible loops / cycles.
118 // NOTE: This is unscalable. We'll need to have a better way of detecting loops or relax this invariant during operation, and add a validate method instead
119 return std::pair<Status, DependencyGraph::Id>(topological_sort().first, new_op);
120}
121
122void DependencyGraph::remove_operator(Id op)
123{
124 for(auto src_tensor : _adj_src_tensors.at(op))
125 {
126 auto &dst_ops = _adj_dst_ops.at(src_tensor);
127 dst_ops.erase(
128 std::remove(std::begin(dst_ops), std::end(dst_ops), op),
129 std::end(dst_ops));
130 }
131 for(auto dst_tensor : _adj_dst_tensors.at(op))
132 {
133 auto &src_ops = _adj_src_ops.at(dst_tensor);
134 src_ops.erase(
135 std::remove(std::begin(src_ops), std::end(src_ops), op),
136 std::end(src_ops));
137 }
138 _adj_src_tensors.erase(op);
139 _adj_dst_tensors.erase(op);
140}
141
142std::map<DependencyGraph::Id, DependencyGraph::Id> DependencyGraph::get_merge_points() const
143{
144 return _merge_to_internal;
145}
146
147std::vector<DependencyGraph::Id> DependencyGraph::get_root_ops() const
148{
149 std::vector<Id> ops{};
150 const auto op_list = all_ops();
151
152 for(auto op : op_list)
153 {
154 if(src_ops(op).empty())
155 {
156 ops.emplace_back(op);
157 }
158 }
159 return ops;
160}
161
162std::vector<DependencyGraph::Id> DependencyGraph::get_dst_ops() const
163{
164 std::vector<Id> ops{};
165 const auto op_list = all_ops();
166
167 for(auto op : op_list)
168 {
169 if(dst_ops(op).empty())
170 {
171 ops.emplace_back(op);
172 }
173 }
174 return ops;
175}
176
177std::vector<DependencyGraph::Id> DependencyGraph::src_tensors(Id op) const
178{
179 ARM_COMPUTE_ERROR_ON(!operator_exists(op));
180 return _adj_src_tensors.at(op);
181}
182
183std::vector<DependencyGraph::Id> DependencyGraph::dst_tensors(Id op) const
184{
185 ARM_COMPUTE_ERROR_ON(!operator_exists(op));
186 return _adj_dst_tensors.at(op);
187}
188
189std::vector<DependencyGraph::Id> DependencyGraph::src_tensors() const
190{
191 std::vector<Id> tensors;
192 for(auto tensor_src_ops : _adj_src_ops)
193 {
194 if(tensor_src_ops.second.empty())
195 tensors.push_back(tensor_src_ops.first);
196 }
197 return tensors;
198}
199
200std::vector<DependencyGraph::Id> DependencyGraph::dst_tensors() const
201{
202 std::vector<Id> tensors;
203 for(auto tensor_dst_ops : _adj_dst_ops)
204 {
205 if(tensor_dst_ops.second.empty())
206 tensors.push_back(tensor_dst_ops.first);
207 }
208 return tensors;
209}
210
211std::vector<DependencyGraph::Id> DependencyGraph::src_ops_from_tensor(Id tensor) const
212{
213 return _adj_src_ops.at(tensor);
214}
215std::vector<DependencyGraph::Id> DependencyGraph::dst_ops_from_tensor(Id tensor) const
216{
217 return _adj_dst_ops.at(tensor);
218}
219
220std::vector<DependencyGraph::Id> DependencyGraph::all_ops() const
221{
222 std::vector<Id> ops{};
223 std::transform(std::begin(_adj_src_tensors), std::end(_adj_src_tensors), std::back_inserter(ops), [](const auto & it)
224 {
225 return it.first;
226 });
227 return ops;
228}
229
230bool DependencyGraph::path_exists_from_tensor_to_op(Id src_tensor, Id dst_op) const
231{
232 for(auto child_op : dst_ops_from_tensor(src_tensor))
233 {
234 if(path_exists_from_op_to_op(child_op, dst_op))
235 {
236 return true;
237 }
238 }
239 return false;
240}
241
242bool DependencyGraph::path_exists_from_op_to_op(Id src_op, Id dst_op) const
243{
244 if(src_op == dst_op)
245 {
246 return true;
247 }
248 if(is_in(src_op, get_dst_ops()))
249 {
250 return false;
251 }
252 for(auto child_tensor : dst_tensors(src_op))
253 {
254 if(path_exists_from_tensor_to_op(child_tensor, dst_op))
255 {
256 return true;
257 }
258 }
259 return false;
260}
261
262std::vector<DependencyGraph::Id> DependencyGraph::all_tensors() const
263{
264 std::vector<Id> tensors{};
265 std::transform(std::begin(_adj_src_ops), std::end(_adj_src_ops), std::back_inserter(tensors), [](const auto & it)
266 {
267 return it.first;
268 });
269 return tensors;
270}
271
272unsigned int DependencyGraph::number_of_ops() const
273{
274 return _adj_src_tensors.size();
275}
276
277unsigned int DependencyGraph::number_of_tensors() const
278{
279 return _adj_src_ops.size();
280}
281
282DependencyGraph::Id DependencyGraph::insert_new_tensor()
283{
284 Id new_tensor = _tensor_id.alloc();
285 _adj_src_ops[new_tensor] = {};
286 _adj_dst_ops[new_tensor] = {};
287 return new_tensor;
288}
289DependencyGraph::Id DependencyGraph::insert_new_op()
290{
291 Id new_op = _operator_id.alloc();
292 _adj_src_tensors[new_op] = {};
293 _adj_dst_tensors[new_op] = {};
294 return new_op;
295}
296void DependencyGraph::link_input(Id op, Id in_tensor)
297{
298 ARM_COMPUTE_ERROR_ON(!operator_exists(op));
299 ARM_COMPUTE_ERROR_ON(!tensor_exists(in_tensor));
300 ARM_COMPUTE_ERROR_ON(are_connected(op, in_tensor));
301 _adj_src_tensors[op].push_back(in_tensor);
302 _adj_dst_ops[in_tensor].push_back(op);
303}
304void DependencyGraph::link_output(Id op, Id out_tensor)
305{
306 ARM_COMPUTE_ERROR_ON(!operator_exists(op));
307 ARM_COMPUTE_ERROR_ON(!tensor_exists(out_tensor));
308 ARM_COMPUTE_ERROR_ON(are_connected(op, out_tensor));
309 _adj_dst_tensors[op].push_back(out_tensor);
310 _adj_src_ops[out_tensor].push_back(op);
311}
312bool DependencyGraph::tensor_exists(Id tensor) const
313{
314 return _adj_src_ops.find(tensor) != _adj_src_ops.end() && _adj_dst_ops.find(tensor) != _adj_dst_ops.end();
315}
316bool DependencyGraph::operator_exists(Id op) const
317{
318 return _adj_src_tensors.find(op) != _adj_src_tensors.end() && _adj_dst_tensors.find(op) != _adj_dst_tensors.end();
319}
320
321bool DependencyGraph::is_src_tensor(Id tensor) const
322{
323 if(!tensor_exists(tensor))
324 {
325 return false;
326 }
327 return _adj_src_ops.at(tensor).empty();
328}
329
330bool DependencyGraph::is_dst_tensor(Id tensor) const
331{
332 if(!tensor_exists(tensor))
333 {
334 return false;
335 }
336 return _adj_dst_ops.at(tensor).empty();
337}
338bool DependencyGraph::is_src_tensor_of(Id op, Id tensor) const
339{
340 if(!operator_exists(op) || !tensor_exists(tensor))
341 {
342 return false;
343 }
344 const auto op_inputs = src_tensors(op);
345 return std::find(op_inputs.begin(), op_inputs.end(), tensor) != op_inputs.end();
346}
347bool DependencyGraph::is_dst_tensor_of(Id op, Id tensor) const
348{
349 if(!operator_exists(op) || !tensor_exists(tensor))
350 {
351 return false;
352 }
353 const auto op_outputs = dst_tensors(op);
354 return std::find(op_outputs.begin(), op_outputs.end(), tensor) != op_outputs.end();
355}
356bool DependencyGraph::are_connected(Id op, Id tensor) const
357{
358 return is_src_tensor_of(op, tensor) || is_dst_tensor_of(op, tensor);
359}
360std::vector<DependencyGraph::Id> DependencyGraph::src_ops(Id op) const
361{
362 ARM_COMPUTE_ERROR_ON(!operator_exists(op));
363 std::vector<Id> ops{};
364 for(Id src_tensor : src_tensors(op))
365 {
366 ops.insert(ops.end(), std::begin(_adj_src_ops.at(src_tensor)), std::end(_adj_src_ops.at(src_tensor)));
367 }
368 return ops;
369}
370
371std::vector<DependencyGraph::Id> DependencyGraph::dst_ops(Id op) const
372{
373 ARM_COMPUTE_ERROR_ON(!operator_exists(op));
374 std::vector<Id> ops{};
375 for(Id dst_tensor : _adj_dst_tensors.at(op))
376 {
377 ops.insert(ops.end(), std::begin(_adj_dst_ops.at(dst_tensor)), std::end(_adj_dst_ops.at(dst_tensor)));
378 }
379 return ops;
380}
381
382std::pair<Status, std::vector<DependencyGraph::OpPack>> DependencyGraph::topological_sort() const
383{
384 // Incident degree (number of source operators to an op)
385 std::map<Id, unsigned int> in_degree{};
386 std::set<Id> visited_ops{};
387 std::deque<Id> zero_in_degree_ops{};
388 std::vector<OpPack> sorted_op_packs{};
389 for(auto op : all_ops())
390 {
391 const auto degree = src_ops(op).size();
392 in_degree[op] = degree;
393 if(degree == 0)
394 {
395 zero_in_degree_ops.push_back(op);
396 visited_ops.insert(op);
397 }
398 }
399
400 while(!zero_in_degree_ops.empty())
401 {
402 const Id op = zero_in_degree_ops.front();
403 zero_in_degree_ops.pop_front();
404 sorted_op_packs.push_back(OpPack{ op, src_tensors(op), dst_tensors(op) });
405
406 for(const auto next_op : dst_ops(op))
407 {
408 if(in_degree[next_op] > 0)
409 {
410 in_degree[next_op]--;
411 }
412 if(in_degree[next_op] == 0 && visited_ops.find(next_op) == visited_ops.end())
413 {
414 zero_in_degree_ops.push_back(next_op);
415 visited_ops.insert(op);
416 }
417 }
418 }
419
420 // If there are remaining ops with in_degree > 0, then it's indication that there are cycles in the graph
421 Status st{};
422 if(sorted_op_packs.size() != number_of_ops())
423 {
424 st = Status{ ErrorCode::RUNTIME_ERROR, "Cycles or loops are not allowed in a DependencyGraph" };
425 }
426 return std::make_pair(st, sorted_op_packs);
427}
428
429} // namespace dynamic_fusion
430} // namespace experimental
431} // namespace arm_compute