David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 6 | #include "SubgraphViewSelector.hpp" |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 7 | #include "Graph.hpp" |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 8 | |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame] | 9 | #include <armnn/utility/Assert.hpp> |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 10 | #include <armnn/utility/IgnoreUnused.hpp> |
Jan Eilers | bb446e5 | 2020-04-02 13:56:54 +0100 | [diff] [blame] | 11 | #include <armnn/utility/PolymorphicDowncast.hpp> |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 12 | |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 13 | #include <algorithm> |
Matteo Martincigh | f02e6cd | 2019-05-17 12:15:30 +0100 | [diff] [blame] | 14 | #include <map> |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 15 | #include <queue> |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 16 | #include <unordered_set> |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 17 | |
| 18 | namespace armnn |
| 19 | { |
| 20 | |
| 21 | namespace |
| 22 | { |
| 23 | |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 24 | /// Intermediate data-structure to store the subgraph that a layer has been assigned to. |
| 25 | /// This is a "disjoint set" data structure that allows efficient merging of subgraphs, |
| 26 | /// which is a key part of the algorithm. Subgraphs are arranged in singly-linked trees |
| 27 | /// (with each node storing a pointer to its parent). Subgraphs in the same tree are considered |
| 28 | /// to have been merged. Merging subgraphs is performed by attaching one tree to another, |
| 29 | /// which is a simple pointer update. |
| 30 | /// |
| 31 | /// NOTE: Due to the way this is stored, it is almost never correct to directly compare pointers |
| 32 | /// to two PartialSubgraphs to check if two layers belong in the same subgraph. Instead you |
| 33 | /// should use IsMergedWith(). |
| 34 | /// |
| 35 | /// This structure also stores information about the dependencies of each subgraph, which is needed |
| 36 | /// to determine whether certain subgraphs can be merged. Checking whether a subgraph |
| 37 | /// depends on another subgraph is a frequent operation in the algorithm (see AssignSplitId) and so this is optimized |
| 38 | /// in preference to the merging of subgraphs. This leads to an approach where each subgraph stores |
| 39 | /// a set of all the subgraphs it depends on (for a fast lookup). In order to efficiently update this |
| 40 | /// set as subgraphs are merged means we also store a set of subgraphs which *depend on us* (i.e. the |
| 41 | /// complement of our dependencies). |
| 42 | class PartialSubgraph |
| 43 | { |
| 44 | public: |
| 45 | /// If this subgraph has been merged with another then there is an agreed "representative" for the combined |
| 46 | /// subgraph, which uniquely identifies the subgraph. |
| 47 | PartialSubgraph* GetRepresentative() |
| 48 | { |
| 49 | // Recurse up the tree to find the root node. |
| 50 | if (m_Parent == nullptr) |
| 51 | { |
| 52 | return this; |
| 53 | } |
| 54 | else |
| 55 | { |
| 56 | PartialSubgraph* result = m_Parent->GetRepresentative(); |
| 57 | // Update our parent pointer to point directly to the root in order to speed up future calls to this method. |
| 58 | // This essentially "flattens" the tree. |
| 59 | m_Parent = result; |
| 60 | return result; |
| 61 | } |
| 62 | } |
| 63 | |
| 64 | /// Merges this subgraph with another. |
| 65 | void MergeWith(PartialSubgraph* other) |
| 66 | { |
| 67 | if (m_Parent == nullptr) |
| 68 | { |
| 69 | other = other->GetRepresentative(); |
| 70 | if (this == other) |
| 71 | { |
| 72 | // Already merged - no-op |
| 73 | return; |
| 74 | } |
| 75 | m_Parent = other; |
| 76 | |
| 77 | // Update others' dependency sets to point to the new representative rather than us. |
| 78 | // Keeping these up-to-date means we can rely on these sets containing representatives when |
| 79 | // we perform a lookup in HasAntecedent() and so don't need to resolve the representative for each element |
| 80 | // of the set. See description at the top of this class for more rationale. |
| 81 | for (PartialSubgraph* a : m_Antecedents) |
| 82 | { |
| 83 | size_t numErased = a->m_Dependants.erase(this); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame] | 84 | ARMNN_ASSERT(numErased == 1); |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 85 | IgnoreUnused(numErased); |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 86 | a->m_Dependants.insert(m_Parent); |
| 87 | } |
| 88 | for (PartialSubgraph* a : m_Dependants) |
| 89 | { |
| 90 | size_t numErased = a->m_Antecedents.erase(this); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame] | 91 | ARMNN_ASSERT(numErased == 1); |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 92 | IgnoreUnused(numErased); |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 93 | a->m_Antecedents.insert(m_Parent); |
| 94 | } |
| 95 | |
| 96 | // Merge our dependency sets into our new representative. |
| 97 | // We no longer need to maintain our own sets, as requests will always be forwarded to the representative. |
| 98 | m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end()); |
| 99 | m_Antecedents.clear(); |
| 100 | m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end()); |
| 101 | m_Dependants.clear(); |
| 102 | } |
| 103 | else |
| 104 | { |
| 105 | // Defer request to the representative |
| 106 | GetRepresentative()->MergeWith(other); |
| 107 | } |
| 108 | } |
| 109 | |
| 110 | /// Checks if this subgraph has been merged with the given subgraph. |
| 111 | bool IsMergedWith(PartialSubgraph* other) |
| 112 | { |
| 113 | return GetRepresentative() == other->GetRepresentative(); |
| 114 | } |
| 115 | |
| 116 | /// Marks the given subgraph as a direct antecedent (dependency) of this one. |
| 117 | void AddDirectAntecedent(PartialSubgraph* antecedent) |
| 118 | { |
| 119 | if (m_Parent == nullptr) |
| 120 | { |
| 121 | antecedent = antecedent->GetRepresentative(); |
| 122 | |
| 123 | m_Antecedents.insert(antecedent); |
| 124 | // Also record all of its antecedents, so that we end up with direct and indirect antecedents. |
| 125 | // This makes the lookup in HasAntecedent() faster. |
| 126 | m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end()); |
| 127 | // All of our dependents also need to include the new antecedents |
| 128 | for (PartialSubgraph* d : m_Dependants) |
| 129 | { |
| 130 | d->m_Antecedents.insert(antecedent); |
| 131 | d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end()); |
| 132 | } |
| 133 | |
| 134 | // Store reverse dependencies as well, required so that we can efficiently navigate the graph |
| 135 | // when making updates. |
| 136 | antecedent->m_Dependants.insert(this); |
| 137 | antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end()); |
| 138 | for (PartialSubgraph* a : antecedent->m_Antecedents) |
| 139 | { |
| 140 | a->m_Dependants.insert(this); |
| 141 | a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end()); |
| 142 | } |
| 143 | } |
| 144 | else |
| 145 | { |
| 146 | // Defer request to the representative |
| 147 | GetRepresentative()->AddDirectAntecedent(antecedent); |
| 148 | } |
| 149 | } |
| 150 | |
| 151 | /// Checks if this subgraph is dependent on the given subgraph, either directly or indirectly. |
| 152 | bool HasAntecedent(PartialSubgraph* antecedent) |
| 153 | { |
| 154 | if (m_Parent == nullptr) |
| 155 | { |
| 156 | antecedent = antecedent->GetRepresentative(); |
| 157 | // Thanks to keeping this set updated in MergeWith and AddDirectAntecedent, we can do an efficient lookup. |
| 158 | return m_Antecedents.count(antecedent) > 0; |
| 159 | } |
| 160 | else |
| 161 | { |
| 162 | // Defer request to the representative |
| 163 | return GetRepresentative()->HasAntecedent(antecedent); |
| 164 | } |
| 165 | } |
| 166 | |
| 167 | private: |
| 168 | /// Pointer to the parent node in the tree. If this is null then we are the representative for our merged subgraph. |
| 169 | PartialSubgraph* m_Parent; |
| 170 | /// The representatives of all the subgraphs which we depend on, either directly or indirectly. |
| 171 | std::unordered_set<PartialSubgraph*> m_Antecedents; |
| 172 | /// The representatives of all the subgraphs which depend on us, either directly or indirectly. |
| 173 | std::unordered_set<PartialSubgraph*> m_Dependants; |
| 174 | }; |
| 175 | |
| 176 | /// Intermediate data structure to store information associated with a particular layer. |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 177 | struct LayerSelectionInfo |
| 178 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 179 | using LayerInfoContainer = std::map<IConnectableLayer*, LayerSelectionInfo>; |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 180 | using LayerInfoQueue = std::queue<LayerSelectionInfo*>; |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 181 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 182 | LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 183 | : m_Layer{layer} |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 184 | , m_Subgraph{nullptr} |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 185 | , m_IsSelected{selector(*layer)} |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 186 | , m_IsProcessed(false) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 187 | { |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 188 | } |
| 189 | |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 190 | bool IsInputLayer() const |
| 191 | { |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 192 | return m_Layer->GetType() == armnn::LayerType::Input || m_Layer->GetType() == armnn::LayerType::Constant; |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 193 | } |
| 194 | |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 195 | void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos, |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 196 | SubgraphView::IInputSlots& inputSlots) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 197 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 198 | for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginInputSlots(); |
| 199 | slot != PolymorphicDowncast<Layer*>(m_Layer)->EndInputSlots(); |
| 200 | ++slot) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 201 | { |
| 202 | OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot(); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame] | 203 | ARMNN_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here."); |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 204 | if (parentLayerOutputSlot) |
| 205 | { |
| 206 | Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer(); |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 207 | auto parentInfo = layerInfos.find(&parentLayer); |
Matteo Martincigh | f02e6cd | 2019-05-17 12:15:30 +0100 | [diff] [blame] | 208 | if (parentInfo == layerInfos.end() || |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 209 | !m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get())) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 210 | { |
Matteo Martincigh | 05349c5 | 2019-05-21 13:29:00 +0100 | [diff] [blame] | 211 | // Avoid collecting duplicate input slots |
| 212 | InputSlot* inputSlot = &(*slot); |
| 213 | if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end()) |
| 214 | { |
| 215 | inputSlots.push_back(inputSlot); |
| 216 | } |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 217 | } |
| 218 | } |
| 219 | } |
| 220 | } |
| 221 | |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 222 | void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos, |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 223 | SubgraphView::IOutputSlots& outputSlots) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 224 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 225 | for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginOutputSlots(); |
| 226 | slot != PolymorphicDowncast<Layer*>(m_Layer)->EndOutputSlots(); |
| 227 | ++slot) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 228 | { |
| 229 | for (InputSlot* childLayerInputSlot : slot->GetConnections()) |
| 230 | { |
| 231 | Layer& childLayer = childLayerInputSlot->GetOwningLayer(); |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 232 | auto childInfo = layerInfos.find(&childLayer); |
Matteo Martincigh | f02e6cd | 2019-05-17 12:15:30 +0100 | [diff] [blame] | 233 | if (childInfo == layerInfos.end() || |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 234 | !m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get())) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 235 | { |
Matteo Martincigh | 05349c5 | 2019-05-21 13:29:00 +0100 | [diff] [blame] | 236 | // Avoid collecting duplicate output slots |
| 237 | OutputSlot* outputSlot = &(*slot); |
| 238 | if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end()) |
| 239 | { |
| 240 | outputSlots.push_back(outputSlot); |
| 241 | } |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 242 | } |
| 243 | } |
| 244 | } |
| 245 | } |
| 246 | |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 247 | IConnectableLayer* m_Layer; |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 248 | /// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true. |
| 249 | /// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph - |
| 250 | /// see the description of the PartialSubgraph class. |
| 251 | std::shared_ptr<PartialSubgraph> m_Subgraph; |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 252 | bool m_IsSelected; |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 253 | bool m_IsProcessed; |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 254 | }; |
| 255 | |
| 256 | } // namespace <anonymous> |
| 257 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 258 | SubgraphViewSelector::Subgraphs |
| 259 | SubgraphViewSelector::SelectSubgraphs(Graph& graph, const LayerSelectorFunction& selector) |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 260 | { |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 261 | SubgraphView subgraph(graph); |
| 262 | return SubgraphViewSelector::SelectSubgraphs(subgraph, selector); |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 263 | } |
| 264 | |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 265 | |
| 266 | template<typename Delegate> |
| 267 | void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos, |
| 268 | LayerSelectionInfo& layerInfo, |
| 269 | Delegate function) |
| 270 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 271 | Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer); |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 272 | |
| 273 | for (auto inputSlot : layer.GetInputSlots()) |
| 274 | { |
Jan Eilers | bb446e5 | 2020-04-02 13:56:54 +0100 | [diff] [blame] | 275 | auto connectedInput = PolymorphicDowncast<OutputSlot*>(inputSlot.GetConnection()); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame] | 276 | ARMNN_ASSERT_MSG(connectedInput, "Dangling input slot detected."); |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 277 | Layer& inputLayer = connectedInput->GetOwningLayer(); |
| 278 | |
| 279 | auto parentInfo = layerInfos.find(&inputLayer); |
Matteo Martincigh | f02e6cd | 2019-05-17 12:15:30 +0100 | [diff] [blame] | 280 | if (parentInfo != layerInfos.end()) |
| 281 | { |
| 282 | function(parentInfo->second); |
| 283 | } |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 284 | } |
| 285 | } |
| 286 | |
| 287 | template<typename Delegate> |
| 288 | void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos, |
| 289 | LayerSelectionInfo& layerInfo, |
| 290 | Delegate function) |
| 291 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 292 | Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer); |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 293 | |
| 294 | for (auto& outputSlot : layer.GetOutputSlots()) |
| 295 | { |
| 296 | for (auto& output : outputSlot.GetConnections()) |
| 297 | { |
| 298 | Layer& childLayer = output->GetOwningLayer(); |
| 299 | |
| 300 | auto childInfo = layerInfos.find(&childLayer); |
Matteo Martincigh | f02e6cd | 2019-05-17 12:15:30 +0100 | [diff] [blame] | 301 | if (childInfo != layerInfos.end()) |
| 302 | { |
| 303 | function(childInfo->second); |
| 304 | } |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 305 | } |
| 306 | } |
| 307 | } |
| 308 | |
| 309 | void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo) |
| 310 | { |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 311 | // Check each input to see if we can attach ourselves to any of the subgraphs that have already been assigned. |
| 312 | ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo) |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 313 | { |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 314 | // We can only attach ourselves to the subgraph from this input if there isn't a cut here. |
| 315 | if (layerInfo.m_IsSelected == parentInfo.m_IsSelected) |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 316 | { |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 317 | // We also need to check that merging into this subgraph won't cause a dependency cycle between subgraphs. |
| 318 | // This will be the case if the subgraph that we will become part of is already a dependency |
| 319 | // of one of the subgraphs that are input to this layer, e.g: |
| 320 | // |
| 321 | // 0 | The numbers (0, 1) are the subgraph IDs of each layer and we are looking at layer X. |
| 322 | // / \ | |
| 323 | // 1 0 | We can't merge X into subgraph 0, because the left-hand input already depends on subgraph 0. |
| 324 | // \ / | We can however merge X into subgraph 1. |
| 325 | // X | |
| 326 | // |
| 327 | bool dependenciesOk = true; |
| 328 | ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo) |
| 329 | { |
| 330 | // We call HasAntecedent() ~ n^2 times, where n is the number of inputs to this layer. |
| 331 | // Hence it is important that this is efficient - see PartialSubgraph class description. |
| 332 | if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get())) |
| 333 | { |
| 334 | dependenciesOk = false; |
| 335 | } |
| 336 | }); |
| 337 | |
| 338 | if (dependenciesOk) |
| 339 | { |
| 340 | // Merge into the subgraph of this input. If we have already been merged into another subgraph |
| 341 | // (from another input of this layer), then merge both of them together. |
| 342 | if (layerInfo.m_Subgraph == nullptr) |
| 343 | { |
| 344 | layerInfo.m_Subgraph = parentInfo.m_Subgraph; |
| 345 | } |
| 346 | else |
| 347 | { |
| 348 | // We call MergeWith() ~ n times, where n is the number of inputs to this layer. |
| 349 | // Therefore it does not need to be as performant as HasAntecedent(). |
| 350 | layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get()); |
| 351 | } |
| 352 | } |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 353 | } |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 354 | }); |
| 355 | |
| 356 | // If we weren't able to merge into an existing subgraph then we need to make a new one |
| 357 | if (layerInfo.m_Subgraph == nullptr) |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 358 | { |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 359 | layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>(); |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 360 | } |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 361 | |
| 362 | // Record dependencies of the chosen subgraph based on the inputs of this layer. |
| 363 | ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo) |
| 364 | { |
| 365 | // These functions are called ~n times, where n is the number of inputs to this layer. |
| 366 | // Therefore it does not need to be as performant as HasAntecedent(). |
| 367 | if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get())) |
| 368 | { |
| 369 | layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get()); |
| 370 | } |
| 371 | }); |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 372 | } |
| 373 | |
| 374 | bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo) |
| 375 | { |
| 376 | bool ready = true; |
| 377 | ForEachLayerInput(layerInfos, layerInfo, |
| 378 | [&ready](LayerSelectionInfo& parentInfo) |
| 379 | { |
| 380 | if (!parentInfo.m_IsProcessed) |
| 381 | { |
| 382 | ready = false; |
| 383 | } |
| 384 | }); |
| 385 | return ready; |
| 386 | } |
| 387 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 388 | SubgraphViewSelector::Subgraphs |
| 389 | SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelectorFunction& selector) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 390 | { |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 391 | LayerSelectionInfo::LayerInfoContainer layerInfos; |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 392 | |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 393 | LayerSelectionInfo::LayerInfoQueue processQueue; |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 394 | const SubgraphView::IConnectableLayers& subgraphLayers = subgraph.GetIConnectableLayers(); |
| 395 | for (auto& layer : subgraphLayers) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 396 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 397 | |
| 398 | auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{PolymorphicDowncast<Layer*>(layer), selector}); |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 399 | LayerSelectionInfo& layerInfo = emplaced.first->second; |
| 400 | |
| 401 | // Start with Input type layers |
| 402 | if (layerInfo.IsInputLayer()) |
| 403 | { |
| 404 | processQueue.push(&layerInfo); |
| 405 | } |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 406 | } |
| 407 | |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 408 | const SubgraphView::IInputSlots& subgraphInputSlots = subgraph.GetIInputSlots(); |
Matteo Martincigh | f02e6cd | 2019-05-17 12:15:30 +0100 | [diff] [blame] | 409 | for (auto& inputSlot : subgraphInputSlots) |
| 410 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 411 | Layer& layer = PolymorphicDowncast<InputSlot*>(inputSlot)->GetOwningLayer(); |
Matteo Martincigh | f02e6cd | 2019-05-17 12:15:30 +0100 | [diff] [blame] | 412 | auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector}); |
| 413 | LayerSelectionInfo& layerInfo = emplaced.first->second; |
| 414 | |
| 415 | processQueue.push(&layerInfo); |
| 416 | } |
| 417 | |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 418 | while (!processQueue.empty()) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 419 | { |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 420 | LayerSelectionInfo& layerInfo = *processQueue.front(); |
| 421 | processQueue.pop(); // remove front from queue |
| 422 | |
| 423 | // This layerInfo may have been added to the queue multiple times, so skip if we have already processed it |
| 424 | if (!layerInfo.m_IsProcessed) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 425 | { |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 426 | // Only process this layerInfo if all inputs have been processed |
| 427 | if (!IsReadyForSplitAssignment(layerInfos, layerInfo)) |
| 428 | { |
| 429 | // Put back of the process queue if we can't process it just yet |
| 430 | processQueue.push(&layerInfo); |
| 431 | continue; // Skip to next iteration |
| 432 | } |
| 433 | |
| 434 | // Now we do the processing |
| 435 | AssignSplitId(layerInfos, layerInfo); |
| 436 | |
| 437 | // Queue any child nodes for processing |
| 438 | ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo) |
| 439 | { |
| 440 | processQueue.push(&childInfo); |
| 441 | }); |
| 442 | |
| 443 | // We don't need to process this node again |
| 444 | layerInfo.m_IsProcessed = true; |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 445 | } |
| 446 | } |
| 447 | |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 448 | // Collect all selected layers keyed by subgraph representative into a map |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 449 | using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>; |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 450 | std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap; |
Derek Lamberti | 5cf4d1c | 2019-05-03 18:57:12 +0100 | [diff] [blame] | 451 | for (auto& info : layerInfos) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 452 | { |
| 453 | if (info.second.m_IsSelected) |
| 454 | { |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 455 | auto it = splitMap.find(info.second.m_Subgraph->GetRepresentative()); |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 456 | if (it == splitMap.end()) |
| 457 | { |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 458 | splitMap.insert( |
| 459 | std::make_pair(info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second})); |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 460 | } |
| 461 | else |
| 462 | { |
| 463 | it->second.push_back(&info.second); |
| 464 | } |
| 465 | } |
| 466 | } |
| 467 | |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 468 | // Now each entry in splitMap represents a subgraph |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 469 | Subgraphs result; |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 470 | for (auto& splitGraph : splitMap) |
| 471 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 472 | SubgraphView::IInputSlots inputs; |
| 473 | SubgraphView::IOutputSlots outputs; |
| 474 | SubgraphView::IConnectableLayers layers; |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 475 | for (auto&& infoPtr : splitGraph.second) |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 476 | { |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 477 | infoPtr->CollectNonSelectedInputs(layerInfos, inputs); |
| 478 | infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs); |
| 479 | layers.push_back(infoPtr->m_Layer); |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 480 | } |
Rob Hughes | 1addbf3 | 2021-02-19 09:24:44 +0000 | [diff] [blame] | 481 | |
| 482 | // Sort lists into deterministic order, not relying on pointer values which may be different on each execution. |
| 483 | // This makes debugging the optimised graph much easier as subsequent stages can also be deterministic. |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 484 | std::sort(inputs.begin(), inputs.end(), [](const IInputSlot* a, const IInputSlot* b) |
Rob Hughes | 1addbf3 | 2021-02-19 09:24:44 +0000 | [diff] [blame] | 485 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 486 | auto* castA = PolymorphicDowncast<const InputSlot*>(a); |
| 487 | auto* castB = PolymorphicDowncast<const InputSlot*>(b); |
| 488 | const LayerGuid guidA = castA->GetOwningLayer().GetGuid(); |
| 489 | const LayerGuid guidB = castB->GetOwningLayer().GetGuid(); |
Rob Hughes | 1addbf3 | 2021-02-19 09:24:44 +0000 | [diff] [blame] | 490 | if (guidA < guidB) |
| 491 | { |
| 492 | return true; |
| 493 | } |
| 494 | else if (guidA == guidB) |
| 495 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 496 | return (castA->GetSlotIndex() < castB->GetSlotIndex()); |
Rob Hughes | 1addbf3 | 2021-02-19 09:24:44 +0000 | [diff] [blame] | 497 | } |
| 498 | return false; |
| 499 | }); |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 500 | std::sort(outputs.begin(), outputs.end(), [](const IOutputSlot* a, const IOutputSlot* b) |
Rob Hughes | 1addbf3 | 2021-02-19 09:24:44 +0000 | [diff] [blame] | 501 | { |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 502 | auto* castA = PolymorphicDowncast<const OutputSlot*>(a); |
| 503 | auto* castB = PolymorphicDowncast<const OutputSlot*>(b); |
| 504 | const LayerGuid guidA = castA->GetOwningLayer().GetGuid(); |
| 505 | const LayerGuid guidB = castB->GetOwningLayer().GetGuid(); |
Rob Hughes | 1addbf3 | 2021-02-19 09:24:44 +0000 | [diff] [blame] | 506 | if (guidA < guidB) |
| 507 | { |
| 508 | return true; |
| 509 | } |
| 510 | else if (guidA == guidB) |
| 511 | { |
| 512 | return (a->CalculateIndexOnOwner() < b->CalculateIndexOnOwner()); |
| 513 | } |
| 514 | return false; |
| 515 | }); |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 516 | layers.sort([](const IConnectableLayer* a, const IConnectableLayer* b) { return a->GetGuid() < b->GetGuid(); }); |
Rob Hughes | 1addbf3 | 2021-02-19 09:24:44 +0000 | [diff] [blame] | 517 | |
Rob Hughes | 30db8ad | 2019-11-08 15:50:10 +0000 | [diff] [blame] | 518 | // Create a new sub-graph with the new lists of input/output slots and layer |
Francis Murtagh | 56ccf68 | 2021-12-13 18:48:12 +0000 | [diff] [blame] | 519 | result.emplace_back(std::make_unique<SubgraphView>(std::move(layers), |
| 520 | std::move(inputs), |
| 521 | std::move(outputs))); |
David Beck | f98d21a | 2018-10-26 16:03:03 +0100 | [diff] [blame] | 522 | } |
| 523 | |
| 524 | return result; |
| 525 | } |
| 526 | |
| 527 | } // namespace armnn |