blob: e0d6ff9ba9b4871e1508301349997f5350712ce7 [file] [log] [blame]
SiCong Lib63b1192022-01-28 18:24:39 +00001/*
2 * Copyright (c) 2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
SiCong Li4e9f5682022-05-10 10:15:59 +010024#ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION
SiCong Lib63b1192022-01-28 18:24:39 +000025#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_DEPENDENCYGRAPH_H
26#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_DEPENDENCYGRAPH_H
27
28#include "arm_compute/core/Error.h"
29
30#include <algorithm>
31#include <map>
32#include <vector>
33
34namespace arm_compute
35{
36namespace experimental
37{
38namespace dynamic_fusion
39{
40template <typename T>
41bool is_in(const T &v, const std::vector<T> &vec)
42{
43 return std::find(std::begin(vec), std::end(vec), v) != std::end(vec);
44}
45
46/** The dependency graph of a workload, where the nodes are of 2 types: Tensor or Operator
47 * Represented as a doubly-linked adjacency list with the differentiation between source and destination
48 *
49 * A "Merge Tensor" is an external tensor associated with the tensor within the graph, and serve as a merge point
50 */
51class DependencyGraph
52{
53public:
54 /** A serial Id allocator
55 *
56 */
57 class SerialIdAllocator
58 {
59 public:
60 using Id = int;
61 Id alloc()
62 {
63 return _counter++;
64 }
65 constexpr static Id empty()
66 {
67 return -1;
68 }
69
70 private:
71 Id _counter{ 0 };
72 };
73 using Id = SerialIdAllocator::Id;
74 /** Adjacency list
75 *
76 */
77 using AdjList = std::map<Id, std::vector<Id>>;
78
79 /** A pack of operator including its input and output tensors, used by traversing through the graph in topological order
80 *
81 */
82 struct OpPack
83 {
84 Id op{};
85 std::vector<Id> inputs{};
86 std::vector<Id> outputs{};
87 friend bool operator==(const OpPack &opp0, const OpPack &opp1)
88 {
89 return std::make_tuple(
90 opp0.op, opp0.inputs, opp0.outputs)
91 == std::make_tuple(
92 opp1.op, opp1.inputs, opp1.outputs);
93 }
94 };
95
96public:
97 constexpr static Id empty_id()
98 {
99 return SerialIdAllocator::empty();
100 }
101
102 DependencyGraph() = default;
103 // Used in cases where two DependencyGraphs may want to share the same configuration of tensors
104 explicit DependencyGraph(const std::vector<Id> &imported_tensors);
105 // Testing only
106 DependencyGraph(const AdjList &adj_src_tensors, const AdjList &adj_dst_tensors, const AdjList &adj_src_ops, const AdjList &adj_dst_ops, std::map<Id, Id> merge_points = {});
107
108 /** Add a new tensor
109 *
110 * @param merge_tensor The external merge point associated with the tensor. Leave empty if not needed.
111 * @return Id The newly allocated tensor, or a previously added tensor associated with @p merge_tensor
112 */
113 Id add_tensor(Id merge_tensor = empty_id());
114
115 void remove_tensor(Id tensor);
116
117 /** Add a new operator
118 *
119 * @param inputs Input tensors to the operator
120 * @param outputs Output tensors to the operator
121 * @return std::pair<Status, DependencyGraph::Id> where id is the newly allocated operator
122 */
123 std::pair<Status, DependencyGraph::Id> add_operator(const std::vector<Id> &inputs, const std::vector<Id> &outputs);
124
125 void remove_operator(Id op);
126 /** Sort the graph in a topological order
127 *
128 * @return std::pair<Status, std::vector<OpPack>>
129 */
130 std::pair<Status, std::vector<OpPack>> topological_sort() const;
131
132 std::vector<Id> src_ops(Id op) const;
133 std::vector<Id> dst_ops(Id op) const;
134
135 std::vector<Id> src_ops_from_tensor(Id tensor) const;
136 std::vector<Id> dst_ops_from_tensor(Id tensor) const;
137 /** Get the merge points object
138 *
139 * @return std::map<Id, Id>
140 */
141 std::map<Id, Id> get_merge_points() const;
142 /** Get all root ops. Root ops can also be referred to as "src ops" of the whole graph
143 *
144 * @return std::vector<Id>
145 */
146 std::vector<Id> get_root_ops() const;
147 /** Get all dst ops of the whole graph
148 *
149 * @return std::vector<Id>
150 */
151 std::vector<Id> get_dst_ops() const;
152
153 /** Get source tensors to an operator
154 *
155 * @param op
156 * @return std::vector<Id>
157 */
158 std::vector<Id> src_tensors(Id op) const;
159 /** Get destination tensors to an operator
160 *
161 * @param op
162 * @return std::vector<Id>
163 */
164 std::vector<Id> dst_tensors(Id op) const;
165 /** Get source tensors of the whole graph
166 *
167 * @return std::vector<Id>
168 */
169 std::vector<Id> src_tensors() const;
170 /** Get destination tensors of the whole graph
171 *
172 * @return std::vector<Id>
173 */
174 std::vector<Id> dst_tensors() const;
175 /** Get all operators
176 *
177 * @return std::vector<Id>
178 */
179 std::vector<Id> all_ops() const;
180 /** Get all tensors
181 *
182 * @return std::vector<Id>
183 */
184 std::vector<Id> all_tensors() const;
185 /** Number of operators
186 *
187 * @return unsigned int
188 */
189 unsigned int number_of_ops() const;
190 /** Number of tensors
191 *
192 * @return unsigned int
193 */
194 unsigned int number_of_tensors() const;
195
196 /** Update @p merge_point to point to @p t_id
197 *
198 * @param t_id
199 * @param merge_point
200 */
201 Status update_merge_point(Id t_id, Id merge_point);
202
203 /** Strict equality comparison (all internal ids and order of insertion matter).
204 * In the future this may be replaced with a topological comparison, allowing equivalent graphs with different internal ids to be equal
205 *
206 *
207 * @param g0
208 * @param g1
209 * @return true
210 * @return false
211 */
212 friend bool operator==(const DependencyGraph &g0, const DependencyGraph &g1)
213 {
214 // Do not compare id allocators
215 return std::make_tuple(
216 g0._adj_src_tensors, g0._adj_dst_tensors, g0._adj_src_ops, g0._adj_dst_ops, g0._merge_to_internal)
217 == std::make_tuple(
218 g1._adj_src_tensors, g1._adj_dst_tensors, g1._adj_src_ops, g1._adj_dst_ops, g1._merge_to_internal);
219 }
220 void link_input(Id op, Id in_tensor);
221 void link_output(Id op, Id out_tensor);
222 /** Check if there's a path from @p src_tensor to @p dst_op
223 *
224 * @param src_tensor
225 * @param dst_op
226 * @return true
227 * @return false
228 */
229 bool path_exists_from_tensor_to_op(Id src_tensor, Id dst_op) const;
230 /** Check if there's a path from @p src_op to @p dst_op
231 *
232 * @param src_op
233 * @param dst_op
234 * @return true
235 * @return false
236 */
237 bool path_exists_from_op_to_op(Id src_op, Id dst_op) const;
238 /** Check if tensor is the src tensor of the entire graph
239 *
240 * @param tensor
241 * @return true
242 * @return false
243 */
244 bool is_src_tensor(Id tensor) const;
245 /** Check if tensor is the dst tensor of the entire graph
246 *
247 * @param tensor
248 * @return true
249 * @return false
250 */
251 bool is_dst_tensor(Id tensor) const;
252
253private:
254 Id insert_new_tensor();
255 Id insert_new_op();
256 bool tensor_exists(Id tensor) const;
257 bool operator_exists(Id op) const;
258 bool is_src_tensor_of(Id op, Id tensor) const;
259 bool is_dst_tensor_of(Id op, Id tensor) const;
260 bool are_connected(Id op, Id tensor) const;
261
262private:
263 AdjList _adj_src_tensors{};
264 AdjList _adj_dst_tensors{};
265 AdjList _adj_src_ops{};
266 AdjList _adj_dst_ops{};
267 std::map<Id, Id> _merge_to_internal{}; // From merge tensor to internal tensor
268 SerialIdAllocator _operator_id{};
269 SerialIdAllocator _tensor_id{};
270};
271
272} // namespace dynamic_fusion
273} // namespace experimental
274} // namespace arm_compute
275
SiCong Li4e9f5682022-05-10 10:15:59 +0100276#endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_DEPENDENCYGRAPH_H
277#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */