blob: 794bf0e344136f079a2a810d44abd10b4c8a6b5c [file] [log] [blame]
SiCong Lib63b1192022-01-28 18:24:39 +00001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#ifndef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION
25#error "This experimental feature must be enabled with -DENABLE_EXPERIMENTAL_DYNAMIC_FUSION"
26#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */
27#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_DEPENDENCYGRAPH_H
28#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_DEPENDENCYGRAPH_H
29
#include "arm_compute/core/Error.h"

#include <algorithm>
#include <map>
#include <tuple>
#include <utility>
#include <vector>
35
36namespace arm_compute
37{
38namespace experimental
39{
40namespace dynamic_fusion
41{
/** Check whether @p v occurs in @p vec
 *
 * Performs a linear scan over @p vec, comparing each element against @p v with operator==.
 *
 * @param[in] v   Value to look for
 * @param[in] vec Vector to search
 *
 * @return true if at least one element of @p vec compares equal to @p v, false otherwise
 */
template <typename T>
bool is_in(const T &v, const std::vector<T> &vec)
{
    const auto first_match = std::find(vec.cbegin(), vec.cend(), v);
    return first_match != vec.cend();
}
47
48/** The dependency graph of a workload, where the nodes are of 2 types: Tensor or Operator
49 * Represented as a doubly-linked adjacency list with the differentiation between source and destination
50 *
51 * A "Merge Tensor" is an external tensor associated with the tensor within the graph, and serve as a merge point
52 */
53class DependencyGraph
54{
55public:
56 /** A serial Id allocator
57 *
58 */
59 class SerialIdAllocator
60 {
61 public:
62 using Id = int;
63 Id alloc()
64 {
65 return _counter++;
66 }
67 constexpr static Id empty()
68 {
69 return -1;
70 }
71
72 private:
73 Id _counter{ 0 };
74 };
75 using Id = SerialIdAllocator::Id;
76 /** Adjacency list
77 *
78 */
79 using AdjList = std::map<Id, std::vector<Id>>;
80
81 /** A pack of operator including its input and output tensors, used by traversing through the graph in topological order
82 *
83 */
84 struct OpPack
85 {
86 Id op{};
87 std::vector<Id> inputs{};
88 std::vector<Id> outputs{};
89 friend bool operator==(const OpPack &opp0, const OpPack &opp1)
90 {
91 return std::make_tuple(
92 opp0.op, opp0.inputs, opp0.outputs)
93 == std::make_tuple(
94 opp1.op, opp1.inputs, opp1.outputs);
95 }
96 };
97
98public:
99 constexpr static Id empty_id()
100 {
101 return SerialIdAllocator::empty();
102 }
103
104 DependencyGraph() = default;
105 // Used in cases where two DependencyGraphs may want to share the same configuration of tensors
106 explicit DependencyGraph(const std::vector<Id> &imported_tensors);
107 // Testing only
108 DependencyGraph(const AdjList &adj_src_tensors, const AdjList &adj_dst_tensors, const AdjList &adj_src_ops, const AdjList &adj_dst_ops, std::map<Id, Id> merge_points = {});
109
110 /** Add a new tensor
111 *
112 * @param merge_tensor The external merge point associated with the tensor. Leave empty if not needed.
113 * @return Id The newly allocated tensor, or a previously added tensor associated with @p merge_tensor
114 */
115 Id add_tensor(Id merge_tensor = empty_id());
116
117 void remove_tensor(Id tensor);
118
119 /** Add a new operator
120 *
121 * @param inputs Input tensors to the operator
122 * @param outputs Output tensors to the operator
123 * @return std::pair<Status, DependencyGraph::Id> where id is the newly allocated operator
124 */
125 std::pair<Status, DependencyGraph::Id> add_operator(const std::vector<Id> &inputs, const std::vector<Id> &outputs);
126
127 void remove_operator(Id op);
128 /** Sort the graph in a topological order
129 *
130 * @return std::pair<Status, std::vector<OpPack>>
131 */
132 std::pair<Status, std::vector<OpPack>> topological_sort() const;
133
134 std::vector<Id> src_ops(Id op) const;
135 std::vector<Id> dst_ops(Id op) const;
136
137 std::vector<Id> src_ops_from_tensor(Id tensor) const;
138 std::vector<Id> dst_ops_from_tensor(Id tensor) const;
139 /** Get the merge points object
140 *
141 * @return std::map<Id, Id>
142 */
143 std::map<Id, Id> get_merge_points() const;
144 /** Get all root ops. Root ops can also be referred to as "src ops" of the whole graph
145 *
146 * @return std::vector<Id>
147 */
148 std::vector<Id> get_root_ops() const;
149 /** Get all dst ops of the whole graph
150 *
151 * @return std::vector<Id>
152 */
153 std::vector<Id> get_dst_ops() const;
154
155 /** Get source tensors to an operator
156 *
157 * @param op
158 * @return std::vector<Id>
159 */
160 std::vector<Id> src_tensors(Id op) const;
161 /** Get destination tensors to an operator
162 *
163 * @param op
164 * @return std::vector<Id>
165 */
166 std::vector<Id> dst_tensors(Id op) const;
167 /** Get source tensors of the whole graph
168 *
169 * @return std::vector<Id>
170 */
171 std::vector<Id> src_tensors() const;
172 /** Get destination tensors of the whole graph
173 *
174 * @return std::vector<Id>
175 */
176 std::vector<Id> dst_tensors() const;
177 /** Get all operators
178 *
179 * @return std::vector<Id>
180 */
181 std::vector<Id> all_ops() const;
182 /** Get all tensors
183 *
184 * @return std::vector<Id>
185 */
186 std::vector<Id> all_tensors() const;
187 /** Number of operators
188 *
189 * @return unsigned int
190 */
191 unsigned int number_of_ops() const;
192 /** Number of tensors
193 *
194 * @return unsigned int
195 */
196 unsigned int number_of_tensors() const;
197
198 /** Update @p merge_point to point to @p t_id
199 *
200 * @param t_id
201 * @param merge_point
202 */
203 Status update_merge_point(Id t_id, Id merge_point);
204
205 /** Strict equality comparison (all internal ids and order of insertion matter).
206 * In the future this may be replaced with a topological comparison, allowing equivalent graphs with different internal ids to be equal
207 *
208 *
209 * @param g0
210 * @param g1
211 * @return true
212 * @return false
213 */
214 friend bool operator==(const DependencyGraph &g0, const DependencyGraph &g1)
215 {
216 // Do not compare id allocators
217 return std::make_tuple(
218 g0._adj_src_tensors, g0._adj_dst_tensors, g0._adj_src_ops, g0._adj_dst_ops, g0._merge_to_internal)
219 == std::make_tuple(
220 g1._adj_src_tensors, g1._adj_dst_tensors, g1._adj_src_ops, g1._adj_dst_ops, g1._merge_to_internal);
221 }
222 void link_input(Id op, Id in_tensor);
223 void link_output(Id op, Id out_tensor);
224 /** Check if there's a path from @p src_tensor to @p dst_op
225 *
226 * @param src_tensor
227 * @param dst_op
228 * @return true
229 * @return false
230 */
231 bool path_exists_from_tensor_to_op(Id src_tensor, Id dst_op) const;
232 /** Check if there's a path from @p src_op to @p dst_op
233 *
234 * @param src_op
235 * @param dst_op
236 * @return true
237 * @return false
238 */
239 bool path_exists_from_op_to_op(Id src_op, Id dst_op) const;
240 /** Check if tensor is the src tensor of the entire graph
241 *
242 * @param tensor
243 * @return true
244 * @return false
245 */
246 bool is_src_tensor(Id tensor) const;
247 /** Check if tensor is the dst tensor of the entire graph
248 *
249 * @param tensor
250 * @return true
251 * @return false
252 */
253 bool is_dst_tensor(Id tensor) const;
254
255private:
256 Id insert_new_tensor();
257 Id insert_new_op();
258 bool tensor_exists(Id tensor) const;
259 bool operator_exists(Id op) const;
260 bool is_src_tensor_of(Id op, Id tensor) const;
261 bool is_dst_tensor_of(Id op, Id tensor) const;
262 bool are_connected(Id op, Id tensor) const;
263
264private:
265 AdjList _adj_src_tensors{};
266 AdjList _adj_dst_tensors{};
267 AdjList _adj_src_ops{};
268 AdjList _adj_dst_ops{};
269 std::map<Id, Id> _merge_to_internal{}; // From merge tensor to internal tensor
270 SerialIdAllocator _operator_id{};
271 SerialIdAllocator _tensor_id{};
272};
273
274} // namespace dynamic_fusion
275} // namespace experimental
276} // namespace arm_compute
277
278#endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_DEPENDENCYGRAPH_H