blob: b87e2b73b11e7157b61c7554b907bfd1bb533a82 [file] [log] [blame]
David Beckf98d21a2018-10-26 16:03:03 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "SubGraphSelector.hpp"
7#include "Graph.hpp"
8#include <boost/assert.hpp>
9#include <algorithm>
10#include <unordered_map>
11
12namespace armnn
13{
14
15namespace
16{
17
18struct LayerSelectionInfo
19{
20 using LayerInfoContainer = std::unordered_map<Layer*, LayerSelectionInfo>;
21 static constexpr uint32_t InitialSplitId() { return 1; }
22
23 LayerSelectionInfo(Layer* layer, const SubGraphSelector::LayerSelectorFunction& selector)
24 : m_Layer{layer}
25 , m_SplitId{0}
26 , m_IsSelected{selector(*layer)}
27 {
28 // fill topology information by storing direct children
29 for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
30 {
31 for (InputSlot* childLayerInputSlot : slot->GetConnections())
32 {
33 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
34 m_DirectChildren.push_back(&childLayer);
35 }
36 }
37 }
38
39 void MarkChildrenSplits(LayerInfoContainer& network,
40 uint32_t splitId,
41 bool prevSelected)
42 {
43 if (m_SplitId < splitId)
44 {
45 m_SplitId = splitId;
46 }
47
48 // introduce a new split point at all non-selected points, but only if the
49 // previous point was selected. this prevents creating a new subgraph at
50 // every non-selected layer
51 if (!m_IsSelected && prevSelected)
52 {
53 ++m_SplitId;
54 }
55
56 for (auto& layer : m_DirectChildren)
57 {
58 auto it = network.find(layer);
59 BOOST_ASSERT_MSG(it != network.end(), "All layers must be part of the topology.");
60 if (it != network.end())
61 {
62 it->second.MarkChildrenSplits(network, m_SplitId, m_IsSelected);
63 }
64 }
65 }
66
67 bool IsInputLayer() const
68 {
69 return m_Layer->GetType() == armnn::LayerType::Input;
70 }
71
72 void CollectNonSelectedInputs(SubGraph::InputSlots& slots,
73 const SubGraphSelector::LayerSelectorFunction& selector)
74 {
75 for (auto&& slot = m_Layer->BeginInputSlots(); slot != m_Layer->EndInputSlots(); ++slot)
76 {
77 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
78 BOOST_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The slots must be connected here.");
79 if (parentLayerOutputSlot)
80 {
81 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
82 if (selector(parentLayer) == false)
83 {
84 slots.push_back(&(*slot));
85 }
86 }
87 }
88 }
89
90 void CollectNonSelectedOutputSlots(SubGraph::OutputSlots& slots,
91 const SubGraphSelector::LayerSelectorFunction& selector)
92 {
93 for (auto&& slot = m_Layer->BeginOutputSlots(); slot != m_Layer->EndOutputSlots(); ++slot)
94 {
95 for (InputSlot* childLayerInputSlot : slot->GetConnections())
96 {
97 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
98 if (selector(childLayer) == false)
99 {
100 slots.push_back(&(*slot));
101 }
102 }
103 }
104 }
105
106 std::vector<Layer*> m_DirectChildren;
107 Layer* m_Layer;
108 uint32_t m_SplitId;
109 bool m_IsSelected;
110};
111
112} // namespace <anonymous>
113
114SubGraphSelector::SubGraphs
115SubGraphSelector::SelectSubGraphs(Graph& graph,
116 const LayerSelectorFunction& selector)
117{
118 LayerSelectionInfo::LayerInfoContainer layerInfo;
119
120 for (auto& layer : graph)
121 {
122 layerInfo.emplace(layer, LayerSelectionInfo{layer, selector});
123 }
124
125 uint32_t splitNo = LayerSelectionInfo::InitialSplitId();
126 for (auto& info : layerInfo)
127 {
128 if (info.second.IsInputLayer())
129 {
130 // for each input layer we mark the graph where subgraph
131 // splits need to happen because of the dependency between
132 // the selected and non-selected nodes
133 info.second.MarkChildrenSplits(layerInfo, splitNo, false);
134 }
135 }
136
137 // Collect all selected layers keyed by split id into a map
138 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
139 std::unordered_map<uint32_t, SelectionInfoPtrs> splitMap;
140 for (auto& info : layerInfo)
141 {
142 if (info.second.m_IsSelected)
143 {
144 auto it = splitMap.find(info.second.m_SplitId);
145 if (it == splitMap.end())
146 {
147 splitMap.insert(std::make_pair(info.second.m_SplitId, SelectionInfoPtrs{&info.second}));
148 }
149 else
150 {
151 it->second.push_back(&info.second);
152 }
153 }
154 }
155
156 // Now each non-empty split id represents a subgraph
157 SubGraphs result;
158 for (auto& splitGraph : splitMap)
159 {
160 if (splitGraph.second.empty() == false)
161 {
162 SubGraph::OutputSlots outputs;
163 SubGraph::InputSlots inputs;
164 SubGraph::Layers layers;
165 for (auto&& infoPtr : splitGraph.second)
166 {
167 infoPtr->CollectNonSelectedOutputSlots(outputs, selector);
168 infoPtr->CollectNonSelectedInputs(inputs, selector);
169 layers.insert(infoPtr->m_Layer);
170 }
171 result.emplace_back(
172 std::make_unique<SubGraph>(
173 std::move(inputs),
174 std::move(outputs),
175 std::move(layers)));
176 }
177 }
178
179 return result;
180}
181
182} // namespace armnn