SiCong Li | b63b119 | 2022-01-28 18:24:39 +0000 | [diff] [blame^] | 1 | /* |
| 2 | * Copyright (c) 2022 Arm Limited. |
| 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
| 24 | #ifndef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION |
| 25 | #error "This experimental feature must be enabled with -DENABLE_EXPERIMENTAL_DYNAMIC_FUSION" |
| 26 | #endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ |
| 27 | #include "arm_compute/core/experimental/DependencyGraph.h" |
| 28 | |
#include <algorithm>
#include <deque>
#include <iterator>
#include <map>
#include <set>
#include <utility>
#include <vector>
| 32 | |
| 33 | namespace arm_compute |
| 34 | { |
| 35 | namespace experimental |
| 36 | { |
| 37 | namespace dynamic_fusion |
| 38 | { |
| 39 | DependencyGraph::DependencyGraph(const AdjList &adj_src_tensors, const AdjList &adj_dst_tensors, const AdjList &adj_src_ops, const AdjList &adj_dst_ops, std::map<Id, Id> merge_points) |
| 40 | : _adj_src_tensors{ adj_src_tensors }, _adj_dst_tensors{ adj_dst_tensors }, _adj_src_ops{ adj_src_ops }, _adj_dst_ops{ adj_dst_ops }, _merge_to_internal{ merge_points }, _operator_id{}, _tensor_id{} |
| 41 | { |
| 42 | } |
| 43 | DependencyGraph::DependencyGraph(const std::vector<Id> &imported_tensors) |
| 44 | : _adj_src_tensors{}, _adj_dst_tensors{}, _adj_src_ops{}, _adj_dst_ops{}, _merge_to_internal{}, _operator_id{}, _tensor_id{} |
| 45 | { |
| 46 | for(auto t : imported_tensors) |
| 47 | { |
| 48 | _adj_src_ops[t] = {}; |
| 49 | _adj_dst_ops[t] = {}; |
| 50 | } |
| 51 | } |
| 52 | |
| 53 | Status DependencyGraph::update_merge_point(Id t_id, Id merge_point) |
| 54 | { |
| 55 | if(_merge_to_internal.find(merge_point) == _merge_to_internal.end()) |
| 56 | { |
| 57 | return Status{ ErrorCode::RUNTIME_ERROR, "Merge point does not exist" }; |
| 58 | } |
| 59 | _merge_to_internal[merge_point] = t_id; |
| 60 | return Status{}; |
| 61 | } |
| 62 | |
| 63 | DependencyGraph::Id DependencyGraph::add_tensor(Id merge_tensor) |
| 64 | { |
| 65 | Id new_tensor{ empty_id() }; |
| 66 | if(merge_tensor != empty_id()) |
| 67 | { |
| 68 | if(_merge_to_internal.find(merge_tensor) != _merge_to_internal.end()) |
| 69 | { |
| 70 | new_tensor = _merge_to_internal[merge_tensor]; |
| 71 | } |
| 72 | else |
| 73 | { |
| 74 | new_tensor = insert_new_tensor(); |
| 75 | _merge_to_internal[merge_tensor] = new_tensor; |
| 76 | } |
| 77 | } |
| 78 | else |
| 79 | { |
| 80 | new_tensor = insert_new_tensor(); |
| 81 | } |
| 82 | return new_tensor; |
| 83 | } |
| 84 | |
| 85 | void DependencyGraph::remove_tensor(Id tensor) |
| 86 | { |
| 87 | for(auto src_op : _adj_src_ops.at(tensor)) |
| 88 | { |
| 89 | auto &dst_tensors = _adj_dst_tensors.at(src_op); |
| 90 | dst_tensors.erase( |
| 91 | std::remove(std::begin(dst_tensors), std::end(dst_tensors), tensor), |
| 92 | std::end(dst_tensors)); |
| 93 | } |
| 94 | for(auto dst_op : _adj_dst_ops.at(tensor)) |
| 95 | { |
| 96 | auto &src_tensors = _adj_src_tensors.at(dst_op); |
| 97 | src_tensors.erase( |
| 98 | std::remove(std::begin(src_tensors), std::end(src_tensors), tensor), |
| 99 | std::end(src_tensors)); |
| 100 | } |
| 101 | _adj_src_ops.erase(tensor); |
| 102 | _adj_dst_ops.erase(tensor); |
| 103 | } |
| 104 | |
| 105 | std::pair<Status, DependencyGraph::Id> DependencyGraph::add_operator(const std::vector<Id> &inputs, const std::vector<Id> &outputs) |
| 106 | { |
| 107 | Id new_op = insert_new_op(); |
| 108 | for(Id tensor : inputs) |
| 109 | { |
| 110 | link_input(new_op, tensor); |
| 111 | } |
| 112 | for(Id tensor : outputs) |
| 113 | { |
| 114 | link_output(new_op, tensor); |
| 115 | } |
| 116 | |
| 117 | // Use topological sort in order to detect possible loops / cycles. |
| 118 | // NOTE: This is unscalable. We'll need to have a better way of detecting loops or relax this invariant during operation, and add a validate method instead |
| 119 | return std::pair<Status, DependencyGraph::Id>(topological_sort().first, new_op); |
| 120 | } |
| 121 | |
| 122 | void DependencyGraph::remove_operator(Id op) |
| 123 | { |
| 124 | for(auto src_tensor : _adj_src_tensors.at(op)) |
| 125 | { |
| 126 | auto &dst_ops = _adj_dst_ops.at(src_tensor); |
| 127 | dst_ops.erase( |
| 128 | std::remove(std::begin(dst_ops), std::end(dst_ops), op), |
| 129 | std::end(dst_ops)); |
| 130 | } |
| 131 | for(auto dst_tensor : _adj_dst_tensors.at(op)) |
| 132 | { |
| 133 | auto &src_ops = _adj_src_ops.at(dst_tensor); |
| 134 | src_ops.erase( |
| 135 | std::remove(std::begin(src_ops), std::end(src_ops), op), |
| 136 | std::end(src_ops)); |
| 137 | } |
| 138 | _adj_src_tensors.erase(op); |
| 139 | _adj_dst_tensors.erase(op); |
| 140 | } |
| 141 | |
| 142 | std::map<DependencyGraph::Id, DependencyGraph::Id> DependencyGraph::get_merge_points() const |
| 143 | { |
| 144 | return _merge_to_internal; |
| 145 | } |
| 146 | |
| 147 | std::vector<DependencyGraph::Id> DependencyGraph::get_root_ops() const |
| 148 | { |
| 149 | std::vector<Id> ops{}; |
| 150 | const auto op_list = all_ops(); |
| 151 | |
| 152 | for(auto op : op_list) |
| 153 | { |
| 154 | if(src_ops(op).empty()) |
| 155 | { |
| 156 | ops.emplace_back(op); |
| 157 | } |
| 158 | } |
| 159 | return ops; |
| 160 | } |
| 161 | |
| 162 | std::vector<DependencyGraph::Id> DependencyGraph::get_dst_ops() const |
| 163 | { |
| 164 | std::vector<Id> ops{}; |
| 165 | const auto op_list = all_ops(); |
| 166 | |
| 167 | for(auto op : op_list) |
| 168 | { |
| 169 | if(dst_ops(op).empty()) |
| 170 | { |
| 171 | ops.emplace_back(op); |
| 172 | } |
| 173 | } |
| 174 | return ops; |
| 175 | } |
| 176 | |
| 177 | std::vector<DependencyGraph::Id> DependencyGraph::src_tensors(Id op) const |
| 178 | { |
| 179 | ARM_COMPUTE_ERROR_ON(!operator_exists(op)); |
| 180 | return _adj_src_tensors.at(op); |
| 181 | } |
| 182 | |
| 183 | std::vector<DependencyGraph::Id> DependencyGraph::dst_tensors(Id op) const |
| 184 | { |
| 185 | ARM_COMPUTE_ERROR_ON(!operator_exists(op)); |
| 186 | return _adj_dst_tensors.at(op); |
| 187 | } |
| 188 | |
| 189 | std::vector<DependencyGraph::Id> DependencyGraph::src_tensors() const |
| 190 | { |
| 191 | std::vector<Id> tensors; |
| 192 | for(auto tensor_src_ops : _adj_src_ops) |
| 193 | { |
| 194 | if(tensor_src_ops.second.empty()) |
| 195 | tensors.push_back(tensor_src_ops.first); |
| 196 | } |
| 197 | return tensors; |
| 198 | } |
| 199 | |
| 200 | std::vector<DependencyGraph::Id> DependencyGraph::dst_tensors() const |
| 201 | { |
| 202 | std::vector<Id> tensors; |
| 203 | for(auto tensor_dst_ops : _adj_dst_ops) |
| 204 | { |
| 205 | if(tensor_dst_ops.second.empty()) |
| 206 | tensors.push_back(tensor_dst_ops.first); |
| 207 | } |
| 208 | return tensors; |
| 209 | } |
| 210 | |
| 211 | std::vector<DependencyGraph::Id> DependencyGraph::src_ops_from_tensor(Id tensor) const |
| 212 | { |
| 213 | return _adj_src_ops.at(tensor); |
| 214 | } |
| 215 | std::vector<DependencyGraph::Id> DependencyGraph::dst_ops_from_tensor(Id tensor) const |
| 216 | { |
| 217 | return _adj_dst_ops.at(tensor); |
| 218 | } |
| 219 | |
| 220 | std::vector<DependencyGraph::Id> DependencyGraph::all_ops() const |
| 221 | { |
| 222 | std::vector<Id> ops{}; |
| 223 | std::transform(std::begin(_adj_src_tensors), std::end(_adj_src_tensors), std::back_inserter(ops), [](const auto & it) |
| 224 | { |
| 225 | return it.first; |
| 226 | }); |
| 227 | return ops; |
| 228 | } |
| 229 | |
| 230 | bool DependencyGraph::path_exists_from_tensor_to_op(Id src_tensor, Id dst_op) const |
| 231 | { |
| 232 | for(auto child_op : dst_ops_from_tensor(src_tensor)) |
| 233 | { |
| 234 | if(path_exists_from_op_to_op(child_op, dst_op)) |
| 235 | { |
| 236 | return true; |
| 237 | } |
| 238 | } |
| 239 | return false; |
| 240 | } |
| 241 | |
| 242 | bool DependencyGraph::path_exists_from_op_to_op(Id src_op, Id dst_op) const |
| 243 | { |
| 244 | if(src_op == dst_op) |
| 245 | { |
| 246 | return true; |
| 247 | } |
| 248 | if(is_in(src_op, get_dst_ops())) |
| 249 | { |
| 250 | return false; |
| 251 | } |
| 252 | for(auto child_tensor : dst_tensors(src_op)) |
| 253 | { |
| 254 | if(path_exists_from_tensor_to_op(child_tensor, dst_op)) |
| 255 | { |
| 256 | return true; |
| 257 | } |
| 258 | } |
| 259 | return false; |
| 260 | } |
| 261 | |
| 262 | std::vector<DependencyGraph::Id> DependencyGraph::all_tensors() const |
| 263 | { |
| 264 | std::vector<Id> tensors{}; |
| 265 | std::transform(std::begin(_adj_src_ops), std::end(_adj_src_ops), std::back_inserter(tensors), [](const auto & it) |
| 266 | { |
| 267 | return it.first; |
| 268 | }); |
| 269 | return tensors; |
| 270 | } |
| 271 | |
| 272 | unsigned int DependencyGraph::number_of_ops() const |
| 273 | { |
| 274 | return _adj_src_tensors.size(); |
| 275 | } |
| 276 | |
| 277 | unsigned int DependencyGraph::number_of_tensors() const |
| 278 | { |
| 279 | return _adj_src_ops.size(); |
| 280 | } |
| 281 | |
| 282 | DependencyGraph::Id DependencyGraph::insert_new_tensor() |
| 283 | { |
| 284 | Id new_tensor = _tensor_id.alloc(); |
| 285 | _adj_src_ops[new_tensor] = {}; |
| 286 | _adj_dst_ops[new_tensor] = {}; |
| 287 | return new_tensor; |
| 288 | } |
| 289 | DependencyGraph::Id DependencyGraph::insert_new_op() |
| 290 | { |
| 291 | Id new_op = _operator_id.alloc(); |
| 292 | _adj_src_tensors[new_op] = {}; |
| 293 | _adj_dst_tensors[new_op] = {}; |
| 294 | return new_op; |
| 295 | } |
| 296 | void DependencyGraph::link_input(Id op, Id in_tensor) |
| 297 | { |
| 298 | ARM_COMPUTE_ERROR_ON(!operator_exists(op)); |
| 299 | ARM_COMPUTE_ERROR_ON(!tensor_exists(in_tensor)); |
| 300 | ARM_COMPUTE_ERROR_ON(are_connected(op, in_tensor)); |
| 301 | _adj_src_tensors[op].push_back(in_tensor); |
| 302 | _adj_dst_ops[in_tensor].push_back(op); |
| 303 | } |
| 304 | void DependencyGraph::link_output(Id op, Id out_tensor) |
| 305 | { |
| 306 | ARM_COMPUTE_ERROR_ON(!operator_exists(op)); |
| 307 | ARM_COMPUTE_ERROR_ON(!tensor_exists(out_tensor)); |
| 308 | ARM_COMPUTE_ERROR_ON(are_connected(op, out_tensor)); |
| 309 | _adj_dst_tensors[op].push_back(out_tensor); |
| 310 | _adj_src_ops[out_tensor].push_back(op); |
| 311 | } |
| 312 | bool DependencyGraph::tensor_exists(Id tensor) const |
| 313 | { |
| 314 | return _adj_src_ops.find(tensor) != _adj_src_ops.end() && _adj_dst_ops.find(tensor) != _adj_dst_ops.end(); |
| 315 | } |
| 316 | bool DependencyGraph::operator_exists(Id op) const |
| 317 | { |
| 318 | return _adj_src_tensors.find(op) != _adj_src_tensors.end() && _adj_dst_tensors.find(op) != _adj_dst_tensors.end(); |
| 319 | } |
| 320 | |
| 321 | bool DependencyGraph::is_src_tensor(Id tensor) const |
| 322 | { |
| 323 | if(!tensor_exists(tensor)) |
| 324 | { |
| 325 | return false; |
| 326 | } |
| 327 | return _adj_src_ops.at(tensor).empty(); |
| 328 | } |
| 329 | |
| 330 | bool DependencyGraph::is_dst_tensor(Id tensor) const |
| 331 | { |
| 332 | if(!tensor_exists(tensor)) |
| 333 | { |
| 334 | return false; |
| 335 | } |
| 336 | return _adj_dst_ops.at(tensor).empty(); |
| 337 | } |
| 338 | bool DependencyGraph::is_src_tensor_of(Id op, Id tensor) const |
| 339 | { |
| 340 | if(!operator_exists(op) || !tensor_exists(tensor)) |
| 341 | { |
| 342 | return false; |
| 343 | } |
| 344 | const auto op_inputs = src_tensors(op); |
| 345 | return std::find(op_inputs.begin(), op_inputs.end(), tensor) != op_inputs.end(); |
| 346 | } |
| 347 | bool DependencyGraph::is_dst_tensor_of(Id op, Id tensor) const |
| 348 | { |
| 349 | if(!operator_exists(op) || !tensor_exists(tensor)) |
| 350 | { |
| 351 | return false; |
| 352 | } |
| 353 | const auto op_outputs = dst_tensors(op); |
| 354 | return std::find(op_outputs.begin(), op_outputs.end(), tensor) != op_outputs.end(); |
| 355 | } |
| 356 | bool DependencyGraph::are_connected(Id op, Id tensor) const |
| 357 | { |
| 358 | return is_src_tensor_of(op, tensor) || is_dst_tensor_of(op, tensor); |
| 359 | } |
| 360 | std::vector<DependencyGraph::Id> DependencyGraph::src_ops(Id op) const |
| 361 | { |
| 362 | ARM_COMPUTE_ERROR_ON(!operator_exists(op)); |
| 363 | std::vector<Id> ops{}; |
| 364 | for(Id src_tensor : src_tensors(op)) |
| 365 | { |
| 366 | ops.insert(ops.end(), std::begin(_adj_src_ops.at(src_tensor)), std::end(_adj_src_ops.at(src_tensor))); |
| 367 | } |
| 368 | return ops; |
| 369 | } |
| 370 | |
| 371 | std::vector<DependencyGraph::Id> DependencyGraph::dst_ops(Id op) const |
| 372 | { |
| 373 | ARM_COMPUTE_ERROR_ON(!operator_exists(op)); |
| 374 | std::vector<Id> ops{}; |
| 375 | for(Id dst_tensor : _adj_dst_tensors.at(op)) |
| 376 | { |
| 377 | ops.insert(ops.end(), std::begin(_adj_dst_ops.at(dst_tensor)), std::end(_adj_dst_ops.at(dst_tensor))); |
| 378 | } |
| 379 | return ops; |
| 380 | } |
| 381 | |
| 382 | std::pair<Status, std::vector<DependencyGraph::OpPack>> DependencyGraph::topological_sort() const |
| 383 | { |
| 384 | // Incident degree (number of source operators to an op) |
| 385 | std::map<Id, unsigned int> in_degree{}; |
| 386 | std::set<Id> visited_ops{}; |
| 387 | std::deque<Id> zero_in_degree_ops{}; |
| 388 | std::vector<OpPack> sorted_op_packs{}; |
| 389 | for(auto op : all_ops()) |
| 390 | { |
| 391 | const auto degree = src_ops(op).size(); |
| 392 | in_degree[op] = degree; |
| 393 | if(degree == 0) |
| 394 | { |
| 395 | zero_in_degree_ops.push_back(op); |
| 396 | visited_ops.insert(op); |
| 397 | } |
| 398 | } |
| 399 | |
| 400 | while(!zero_in_degree_ops.empty()) |
| 401 | { |
| 402 | const Id op = zero_in_degree_ops.front(); |
| 403 | zero_in_degree_ops.pop_front(); |
| 404 | sorted_op_packs.push_back(OpPack{ op, src_tensors(op), dst_tensors(op) }); |
| 405 | |
| 406 | for(const auto next_op : dst_ops(op)) |
| 407 | { |
| 408 | if(in_degree[next_op] > 0) |
| 409 | { |
| 410 | in_degree[next_op]--; |
| 411 | } |
| 412 | if(in_degree[next_op] == 0 && visited_ops.find(next_op) == visited_ops.end()) |
| 413 | { |
| 414 | zero_in_degree_ops.push_back(next_op); |
| 415 | visited_ops.insert(op); |
| 416 | } |
| 417 | } |
| 418 | } |
| 419 | |
| 420 | // If there are remaining ops with in_degree > 0, then it's indication that there are cycles in the graph |
| 421 | Status st{}; |
| 422 | if(sorted_op_packs.size() != number_of_ops()) |
| 423 | { |
| 424 | st = Status{ ErrorCode::RUNTIME_ERROR, "Cycles or loops are not allowed in a DependencyGraph" }; |
| 425 | } |
| 426 | return std::make_pair(st, sorted_op_packs); |
| 427 | } |
| 428 | |
| 429 | } // namespace dynamic_fusion |
| 430 | } // namespace experimental |
| 431 | } // namespace arm_compute |