telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 4 | // |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 5 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 6 | #include "Graph.hpp" |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 7 | #include "SubgraphView.hpp" |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 8 | #include "LayersFwd.hpp" |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 9 | |
Matteo Martincigh | e5b8eb9 | 2019-11-28 15:45:42 +0000 | [diff] [blame] | 10 | #include <armnn/backends/IBackendInternal.hpp> |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 11 | |
| 12 | #include <armnn/BackendId.hpp> |
Matthew Bentham | f48afc6 | 2020-01-15 17:55:08 +0000 | [diff] [blame] | 13 | #include <armnn/Logging.hpp> |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 14 | #include <armnn/TypesUtils.hpp> |
Matthew Bentham | f48afc6 | 2020-01-15 17:55:08 +0000 | [diff] [blame] | 15 | #include <armnn/Utils.hpp> |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 16 | #include <armnn/utility/Assert.hpp> |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 17 | |
| 18 | #include <boost/polymorphic_cast.hpp> |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 19 | #include <boost/format.hpp> |
| 20 | |
| 21 | #include <unordered_map> |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 22 | #include <DotSerializer.hpp> |
| 23 | #include <sstream> |
| 24 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 25 | namespace armnn |
| 26 | { |
| 27 | |
| 28 | Graph::Graph(const Graph& other) |
| 29 | : m_LayersInOrder(other.m_LayersInOrder) |
| 30 | { |
| 31 | std::unordered_map<const Layer*, Layer*> otherToClonedMap; |
| 32 | |
| 33 | for (auto&& otherLayer : other.m_Layers) |
| 34 | { |
| 35 | Layer* const layer = otherLayer->Clone(*this); |
| 36 | otherToClonedMap.emplace(otherLayer, layer); |
| 37 | } |
| 38 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 39 | // Copies slot connections. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 40 | for (auto&& otherLayer : other.m_Layers) |
| 41 | { |
| 42 | Layer* const thisLayer = otherToClonedMap[otherLayer]; |
| 43 | |
| 44 | auto outputSlot = thisLayer->BeginOutputSlots(); |
| 45 | for (auto&& otherOutputSlot : otherLayer->GetOutputSlots()) |
| 46 | { |
| 47 | for (auto&& otherInputSlot : otherOutputSlot.GetConnections()) |
| 48 | { |
| 49 | const Layer& otherTgtLayer = otherInputSlot->GetOwningLayer(); |
| 50 | Layer* const thisTgtLayer = otherToClonedMap[&otherTgtLayer]; |
| 51 | |
| 52 | InputSlot& inputSlot = thisTgtLayer->GetInputSlot(otherInputSlot->GetSlotIndex()); |
| 53 | outputSlot->Connect(inputSlot); |
| 54 | } |
| 55 | outputSlot->SetTensorInfo(otherOutputSlot.GetTensorInfo()); |
| 56 | ++outputSlot; |
| 57 | } |
| 58 | } |
| 59 | } |
| 60 | |
| 61 | Status Graph::Print() const |
| 62 | { |
| 63 | if (m_Layers.empty()) |
| 64 | { |
Derek Lamberti | 0844697 | 2019-11-26 16:38:31 +0000 | [diff] [blame] | 65 | ARMNN_LOG(info) << "\n Graph is empty.\n"; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 66 | return Status::Success; |
| 67 | } |
Derek Lamberti | 0844697 | 2019-11-26 16:38:31 +0000 | [diff] [blame] | 68 | ARMNN_LOG(info) << "\n"; |
| 69 | ARMNN_LOG(info) << "Walking Pattern: \n"; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 70 | |
| 71 | for (auto&& it : TopologicalSort()) |
| 72 | { |
Derek Lamberti | 0844697 | 2019-11-26 16:38:31 +0000 | [diff] [blame] | 73 | ARMNN_LOG(info) << it->GetName() << ":" << GetLayerTypeAsCString(it->GetType()) |
David Beck | 33f0ae0 | 2018-10-18 15:13:56 +0100 | [diff] [blame] | 74 | << ":" << it->GetBackendId().Get(); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 75 | } |
Derek Lamberti | 0844697 | 2019-11-26 16:38:31 +0000 | [diff] [blame] | 76 | ARMNN_LOG(info) << "\n\n"; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 77 | |
| 78 | return Status::Success; |
| 79 | } |
| 80 | |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 81 | Status Graph::SerializeToDot(std::ostream& stream) |
| 82 | { |
| 83 | { |
| 84 | DotGraph graph(stream, "Optimized"); |
| 85 | |
| 86 | { |
| 87 | // Default node attributes: |
| 88 | DotDefaults nodes(stream, "node"); |
| 89 | nodes.GetAttributeSet() |
| 90 | .AddAttribute("shape", "record"); |
| 91 | } |
| 92 | |
| 93 | { |
| 94 | // Default edge attributes: |
| 95 | DotDefaults edges(stream, "edge"); |
| 96 | edges.GetAttributeSet() |
| 97 | .AddAttribute("fontsize", 8) |
| 98 | .AddAttribute("fontcolor", "blue") |
| 99 | .AddAttribute("fontname", "arial-bold"); |
| 100 | } |
| 101 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 102 | // First declares the nodes. |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 103 | for (auto&& layer : m_Layers) |
| 104 | { |
| 105 | DotNode node(stream, layer->GetGuid(), GetLayerTypeAsCString(layer->GetType())); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 106 | // Extracts the layer parameters. |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 107 | ParameterStringifyFunction extractParams = [&node](const std::string & name, const std::string & value){ |
| 108 | node.GetContents().AddContent(name + " : " + value); |
| 109 | }; |
| 110 | layer->SerializeLayerParameters(extractParams); |
| 111 | } |
| 112 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 113 | // Second declares the edges. |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 114 | for (auto&& layer : m_Layers) |
| 115 | { |
| 116 | LayerGuid toId = layer->GetGuid(); |
| 117 | |
| 118 | for (unsigned int i=0;i<layer->GetNumInputSlots(); i++) |
| 119 | { |
| 120 | OutputSlot* outputSlot = static_cast<OutputSlot*>(layer->GetInputSlot(i).GetConnection()); |
| 121 | LayerGuid fromId = outputSlot->GetOwningLayer().GetGuid(); |
| 122 | DotEdge edge(stream, fromId, toId); |
| 123 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 124 | // Now print the tensor shape on the edge. |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 125 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 126 | // Constructs the label attribute with HTML markup. |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 127 | std::stringstream ss; |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 128 | ss << "< " << outputSlot->GetTensorInfo().GetShape() << " >"; |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 129 | edge.GetAttributeSet().AddAttribute("label", ss); |
| 130 | } |
| 131 | } |
| 132 | } |
| 133 | } |
| 134 | |
| 135 | if (stream.bad()) |
| 136 | { |
| 137 | return Status::Failure; |
| 138 | } |
| 139 | return Status::Success; |
| 140 | } |
| 141 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 142 | Status Graph::AllocateDynamicBuffers() |
| 143 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 144 | // Layers must be sorted in topological order |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 145 | ARMNN_ASSERT(m_LayersInOrder); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 146 | |
| 147 | std::unordered_set<const ITensorHandle*> preallocatedTensors; |
| 148 | std::unordered_map<const ITensorHandle*, unsigned int> handleReferenceCounts; |
| 149 | |
| 150 | // Finds the first TensorHandle ancestor of a SubTensorHandle. If the ITensorHandle provided |
| 151 | // is a TensorHandle, the function just returns it |
| 152 | auto TraceSubTensorHandleAncestry = [](ITensorHandle* const subTensorHandle) |
| 153 | { |
| 154 | ITensorHandle* ancestor = subTensorHandle; |
| 155 | while (ancestor && ancestor->GetParent()) |
| 156 | { |
| 157 | ancestor = ancestor->GetParent(); |
| 158 | } |
| 159 | return ancestor; |
| 160 | }; |
| 161 | |
| 162 | // Checks whether a TensorHandle has been pre-allocated |
| 163 | auto IsPreallocated = [&](ITensorHandle* const tensorHandle) |
| 164 | { |
| 165 | return tensorHandle && preallocatedTensors.find(tensorHandle) != preallocatedTensors.end(); |
| 166 | }; |
| 167 | |
| 168 | // Constant tensor handles need to last from the beginning of execution till the end, |
| 169 | // therefore we pre-allocate them upfront |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 170 | for (auto&& layer : m_Layers) |
| 171 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 172 | if (layer->GetType() == LayerType::Constant) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 173 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 174 | for (auto&& slot = layer->BeginOutputSlots(); slot != layer->EndOutputSlots(); ++slot) |
| 175 | { |
| 176 | ITensorHandle *tensorHandle = TraceSubTensorHandleAncestry(slot->GetOutputHandler().GetData()); |
| 177 | |
| 178 | if (tensorHandle && !IsPreallocated(tensorHandle)) |
| 179 | { |
| 180 | tensorHandle->Allocate(); |
| 181 | preallocatedTensors.insert(tensorHandle); |
| 182 | } |
| 183 | } |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 184 | } |
| 185 | } |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 186 | |
| 187 | // Iterate over the network in topological order |
| 188 | for (auto&& layer : m_Layers) |
| 189 | { |
| 190 | // Count the amount of times each output slot references a certain buffer (ITensorHandle). |
| 191 | // The first time we encounter a new tensor handle, we start managing its lifetime. |
| 192 | for (auto&& slot = layer->BeginOutputSlots(); slot != layer->EndOutputSlots(); ++slot) |
| 193 | { |
| 194 | ITensorHandle *tensorHandle = TraceSubTensorHandleAncestry(slot->GetOutputHandler().GetData()); |
| 195 | |
| 196 | if (tensorHandle && !IsPreallocated(tensorHandle)) |
| 197 | { |
| 198 | unsigned int numConnections = slot->GetNumConnections(); |
| 199 | if (handleReferenceCounts.find(tensorHandle) == handleReferenceCounts.end()) |
| 200 | { |
| 201 | handleReferenceCounts[tensorHandle] = numConnections; |
| 202 | tensorHandle->Manage(); |
Pablo Tello | 9af1dcd | 2019-12-03 15:46:50 +0000 | [diff] [blame] | 203 | if (handleReferenceCounts[tensorHandle] == 0u) |
| 204 | { |
| 205 | // if nobody consumes this tensor we call Allocate() |
| 206 | tensorHandle->Allocate(); |
| 207 | } |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 208 | } |
| 209 | else |
| 210 | { |
| 211 | handleReferenceCounts[tensorHandle] += numConnections; |
| 212 | } |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | // Loop through the input slots in the same layer and decrement the reference counter associated |
| 217 | // to each tensor handle we encounter. Once it reaches zero, we end the lifetime of the tensor handle |
| 218 | for (auto&& slot = layer->BeginInputSlots(); slot != layer->EndInputSlots(); ++slot) |
| 219 | { |
| 220 | ITensorHandle *tensorHandle = TraceSubTensorHandleAncestry( |
| 221 | slot->GetConnectedOutputSlot()->GetOutputHandler().GetData()); |
| 222 | |
| 223 | if (tensorHandle && !IsPreallocated(tensorHandle)) |
| 224 | { |
| 225 | --handleReferenceCounts[tensorHandle]; |
| 226 | |
| 227 | if (handleReferenceCounts[tensorHandle] == 0u) |
| 228 | { |
| 229 | // Stop managing lifetime of tensor handle |
| 230 | tensorHandle->Allocate(); |
| 231 | handleReferenceCounts.erase(tensorHandle); |
| 232 | } |
| 233 | } |
| 234 | } |
| 235 | } |
| 236 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 237 | return Status::Success; |
| 238 | } |
| 239 | |
| 240 | const Graph& Graph::TopologicalSort() const |
| 241 | { |
| 242 | if (!m_LayersInOrder) |
| 243 | { |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 244 | // Resets layer order. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 245 | for (auto&& it : m_Layers) |
| 246 | { |
| 247 | it->ResetPriority(); |
| 248 | } |
| 249 | |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 250 | auto compareLayerPriority = [](const LayerList::value_type& layerA, const LayerList::value_type& layerB) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 251 | { |
| 252 | return layerA->GetPriority() < layerB->GetPriority(); |
| 253 | }; |
| 254 | |
| 255 | m_Layers.sort(compareLayerPriority); |
| 256 | |
| 257 | m_LayersInOrder = true; |
| 258 | } |
| 259 | |
| 260 | return *this; |
| 261 | } |
| 262 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 263 | void Graph::AddCompatibilityLayers(std::map<BackendId, std::unique_ptr<IBackendInternal>>& backends, |
| 264 | TensorHandleFactoryRegistry& registry) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 265 | { |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 266 | // Returns true if the given layer could potentially need an intermediate copy/import layer (depending on its |
| 267 | // connections to other layers). |
| 268 | auto MayNeedCompatibilityLayer = [](const Layer& layer) |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 269 | { |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 270 | // All layers should have been associated with a valid compute device at this point. |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 271 | ARMNN_ASSERT(layer.GetBackendId() != Compute::Undefined); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 272 | // Does not need another compatibility layer if a copy or import layer is already present. |
| 273 | return layer.GetType() != LayerType::MemCopy && |
| 274 | layer.GetType() != LayerType::MemImport; |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 275 | }; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 276 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 277 | auto IsCompatibilityStrategy = [](EdgeStrategy strategy) |
| 278 | { |
| 279 | return strategy == EdgeStrategy::CopyToTarget || |
| 280 | strategy == EdgeStrategy::ExportToTarget; |
| 281 | }; |
| 282 | |
| 283 | ForEachLayer([this, &backends, ®istry, MayNeedCompatibilityLayer, IsCompatibilityStrategy](Layer* srcLayer) |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 284 | { |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 285 | ARMNN_ASSERT(srcLayer); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 286 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 287 | if (!MayNeedCompatibilityLayer(*srcLayer)) |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 288 | { |
| 289 | // The current layer does not need copy layers, move to the next one |
| 290 | return; |
| 291 | } |
| 292 | |
| 293 | const std::vector<OutputSlot>& srcOutputSlots = srcLayer->GetOutputSlots(); |
| 294 | for (unsigned int srcOutputIndex = 0; srcOutputIndex < srcOutputSlots.size(); srcOutputIndex++) |
| 295 | { |
| 296 | OutputSlot& srcOutputSlot = srcLayer->GetOutputSlot(srcOutputIndex); |
| 297 | const std::vector<InputSlot*> srcConnections = srcOutputSlot.GetConnections(); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 298 | const std::vector<EdgeStrategy> srcEdgeStrategies = srcOutputSlot.GetEdgeStrategies(); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 299 | for (unsigned int srcConnectionIndex = 0; srcConnectionIndex < srcConnections.size(); srcConnectionIndex++) |
| 300 | { |
| 301 | InputSlot* dstInputSlot = srcConnections[srcConnectionIndex]; |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 302 | ARMNN_ASSERT(dstInputSlot); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 303 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 304 | EdgeStrategy strategy = srcEdgeStrategies[srcConnectionIndex]; |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 305 | ARMNN_ASSERT_MSG(strategy != EdgeStrategy::Undefined, |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 306 | "Undefined memory strategy found while adding copy layers for compatibility"); |
| 307 | |
| 308 | const Layer& dstLayer = dstInputSlot->GetOwningLayer(); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 309 | if (MayNeedCompatibilityLayer(dstLayer) && |
| 310 | IsCompatibilityStrategy(strategy)) |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 311 | { |
| 312 | // A copy layer is needed in between the source and destination layers. |
| 313 | // Record the operation rather than attempting to modify the graph as we go. |
| 314 | // (invalidating iterators) |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 315 | const std::string compLayerName = boost::str(boost::format("[ %1% (%2%) -> %3% (%4%) ]") |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 316 | % srcLayer->GetName() |
| 317 | % srcOutputIndex |
| 318 | % dstLayer.GetName() |
| 319 | % dstInputSlot->GetSlotIndex()); |
| 320 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 321 | Layer* compLayer = nullptr; |
| 322 | if (strategy == EdgeStrategy::CopyToTarget) |
| 323 | { |
| 324 | compLayer = InsertNewLayer<MemCopyLayer>(*dstInputSlot, compLayerName.c_str()); |
| 325 | } |
| 326 | else |
| 327 | { |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 328 | ARMNN_ASSERT_MSG(strategy == EdgeStrategy::ExportToTarget, "Invalid edge strategy found."); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 329 | compLayer = InsertNewLayer<MemImportLayer>(*dstInputSlot, compLayerName.c_str()); |
| 330 | } |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 331 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 332 | compLayer->SetBackendId(dstLayer.GetBackendId()); |
| 333 | |
| 334 | OutputSlot& compOutputSlot = compLayer->GetOutputSlot(0); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 335 | auto backendIt = backends.find(dstLayer.GetBackendId()); |
| 336 | if (backendIt != backends.end() && |
| 337 | backendIt->second && |
| 338 | backendIt->second->SupportsTensorAllocatorAPI()) |
| 339 | { |
| 340 | auto backend = backendIt->second.get(); |
| 341 | auto tensorHandleFactoryIds = backend->GetHandleFactoryPreferences(); |
| 342 | bool found = false; |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 343 | |
| 344 | for (auto preference : tensorHandleFactoryIds) |
| 345 | { |
| 346 | auto factory = registry.GetFactory(preference); |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 347 | if (factory) |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 348 | { |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 349 | auto srcPref = srcOutputSlot.GetTensorHandleFactoryId(); |
| 350 | auto srcFactory = registry.GetFactory(srcPref); |
Ferran Balaguer | bfeb271 | 2019-08-07 15:14:56 +0100 | [diff] [blame] | 351 | |
Ferran Balaguer | 9752010 | 2019-08-14 12:11:27 +0100 | [diff] [blame] | 352 | if (srcFactory) |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 353 | { |
Ferran Balaguer | 9752010 | 2019-08-14 12:11:27 +0100 | [diff] [blame] | 354 | bool canExportImport = |
Ferran Balaguer | bfeb271 | 2019-08-07 15:14:56 +0100 | [diff] [blame] | 355 | (factory->GetImportFlags() & srcFactory->GetExportFlags()) != 0; |
| 356 | |
Ferran Balaguer | 9752010 | 2019-08-14 12:11:27 +0100 | [diff] [blame] | 357 | if (factory->SupportsMapUnmap() || canExportImport) |
| 358 | { |
| 359 | compOutputSlot.SetTensorHandleFactory(preference); |
| 360 | found = true; |
| 361 | break; |
| 362 | } |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 363 | } |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 364 | } |
| 365 | } |
| 366 | |
Ferran Balaguer | 9752010 | 2019-08-14 12:11:27 +0100 | [diff] [blame] | 367 | if (!found) |
| 368 | { |
| 369 | compOutputSlot.SetTensorHandleFactory(ITensorHandleFactory::LegacyFactoryId); |
| 370 | } |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 371 | } |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 372 | else |
| 373 | { |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 374 | compOutputSlot.SetTensorHandleFactory(ITensorHandleFactory::LegacyFactoryId); |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 375 | } |
| 376 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 377 | // The output strategy of a compatibility layer is always DirectCompatibility. |
| 378 | compOutputSlot.SetEdgeStrategy(0, EdgeStrategy::DirectCompatibility); |
Matthew Bentham | 0cf01dc | 2019-07-30 08:24:12 +0000 | [diff] [blame] | 379 | |
| 380 | // Recalculate the connection index on the previous layer as we have just inserted into it. |
| 381 | const std::vector<InputSlot*>& newSourceConnections = srcOutputSlot.GetConnections(); |
| 382 | long newSrcConnectionIndex = std::distance(newSourceConnections.begin(), |
| 383 | std::find(newSourceConnections.begin(), |
| 384 | newSourceConnections.end(), |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 385 | &compLayer->GetInputSlot(0))); |
Matthew Bentham | 0cf01dc | 2019-07-30 08:24:12 +0000 | [diff] [blame] | 386 | |
Derek Lamberti | f674aa0 | 2019-08-01 15:56:25 +0100 | [diff] [blame] | 387 | // The input strategy of a compatibility layer is always DirectCompatibilty. |
| 388 | srcOutputSlot.SetEdgeStrategy(boost::numeric_cast<unsigned int>(newSrcConnectionIndex), |
| 389 | EdgeStrategy::DirectCompatibility); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 390 | } |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 391 | } |
| 392 | } |
Derek Lamberti | 84da38b | 2019-06-13 11:40:08 +0100 | [diff] [blame] | 393 | }); |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 394 | } |
| 395 | |
Derek Lamberti | c2fe5fb | 2019-05-08 10:23:08 +0100 | [diff] [blame] | 396 | void Graph::SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer) |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 397 | { |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 398 | ARMNN_ASSERT(substituteLayer != nullptr); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 399 | |
Derek Lamberti | c2fe5fb | 2019-05-08 10:23:08 +0100 | [diff] [blame] | 400 | ReplaceSubgraphConnections(subgraph, substituteLayer); |
| 401 | EraseSubgraphLayers(subgraph); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 402 | } |
| 403 | |
Derek Lamberti | c2fe5fb | 2019-05-08 10:23:08 +0100 | [diff] [blame] | 404 | void Graph::SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph) |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 405 | { |
David Monahan | 5200afa | 2019-05-10 11:52:14 +0100 | [diff] [blame] | 406 | // Look through each layer in the new subgraph and add any that are not already a member of this graph |
| 407 | substituteSubgraph.ForEachLayer([this](Layer* layer) |
| 408 | { |
| 409 | if (std::find(std::begin(m_Layers), std::end(m_Layers), layer) == std::end(m_Layers)) |
| 410 | { |
| 411 | layer->Reparent(*this, m_Layers.end()); |
| 412 | m_LayersInOrder = false; |
| 413 | } |
| 414 | }); |
| 415 | |
Derek Lamberti | c2fe5fb | 2019-05-08 10:23:08 +0100 | [diff] [blame] | 416 | ReplaceSubgraphConnections(subgraph, substituteSubgraph); |
| 417 | EraseSubgraphLayers(subgraph); |
Matteo Martincigh | e3a4245 | 2019-05-23 13:56:01 +0100 | [diff] [blame] | 418 | TopologicalSort(); |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 419 | } |
| 420 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 421 | void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, IConnectableLayer* substituteLayer) |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 422 | { |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 423 | ARMNN_ASSERT(substituteLayer != nullptr); |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 424 | |
Matteo Martincigh | 0c051f9 | 2019-01-31 12:09:49 +0000 | [diff] [blame] | 425 | // Create a new sub-graph with only the given layer, using |
| 426 | // the given sub-graph as a reference of which parent graph to use |
Matteo Martincigh | 602af09 | 2019-05-01 10:31:27 +0100 | [diff] [blame] | 427 | SubgraphView substituteSubgraph(substituteLayer); |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 428 | ReplaceSubgraphConnections(subgraph, substituteSubgraph); |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 429 | } |
| 430 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 431 | void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const SubgraphView& substituteSubgraph) |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 432 | { |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 433 | ARMNN_ASSERT_MSG(!substituteSubgraph.GetLayers().empty(), "New sub-graph used for substitution must not be empty"); |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 434 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 435 | const SubgraphView::Layers& substituteSubgraphLayers = substituteSubgraph.GetLayers(); |
| 436 | std::for_each(substituteSubgraphLayers.begin(), substituteSubgraphLayers.end(), [&](Layer* layer) |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 437 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 438 | IgnoreUnused(layer); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 439 | ARMNN_ASSERT_MSG(std::find(m_Layers.begin(), m_Layers.end(), layer) != m_Layers.end(), |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 440 | "Substitute layer is not a member of graph"); |
| 441 | }); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 442 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 443 | const SubgraphView::InputSlots& subgraphInputSlots = subgraph.GetInputSlots(); |
| 444 | const SubgraphView::OutputSlots& subgraphOutputSlots = subgraph.GetOutputSlots(); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 445 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 446 | unsigned int subgraphNumInputSlots = boost::numeric_cast<unsigned int>(subgraphInputSlots.size()); |
| 447 | unsigned int subgraphNumOutputSlots = boost::numeric_cast<unsigned int>(subgraphOutputSlots.size()); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 448 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 449 | const SubgraphView::InputSlots& substituteSubgraphInputSlots = substituteSubgraph.GetInputSlots(); |
| 450 | const SubgraphView::OutputSlots& substituteSubgraphOutputSlots = substituteSubgraph.GetOutputSlots(); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 451 | |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 452 | ARMNN_ASSERT(subgraphNumInputSlots == substituteSubgraphInputSlots.size()); |
| 453 | ARMNN_ASSERT(subgraphNumOutputSlots == substituteSubgraphOutputSlots.size()); |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 454 | |
| 455 | // Disconnect the sub-graph and replace it with the substitute sub-graph |
| 456 | |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 457 | // Step 1: process input slots |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 458 | for (unsigned int inputSlotIdx = 0; inputSlotIdx < subgraphNumInputSlots; ++inputSlotIdx) |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 459 | { |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 460 | InputSlot* subgraphInputSlot = subgraphInputSlots.at(inputSlotIdx); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 461 | ARMNN_ASSERT(subgraphInputSlot); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 462 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 463 | IOutputSlot* connectedOutputSlot = subgraphInputSlot->GetConnection(); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 464 | ARMNN_ASSERT(connectedOutputSlot); |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 465 | connectedOutputSlot->Disconnect(*subgraphInputSlot); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 466 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 467 | IInputSlot* substituteInputSlot = substituteSubgraphInputSlots.at(inputSlotIdx); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 468 | ARMNN_ASSERT(substituteInputSlot); |
Matteo Martincigh | adddddb | 2019-01-24 14:06:23 +0000 | [diff] [blame] | 469 | connectedOutputSlot->Connect(*substituteInputSlot); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 470 | } |
| 471 | |
| 472 | // Step 2: process output slots |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 473 | for(unsigned int outputSlotIdx = 0; outputSlotIdx < subgraphNumOutputSlots; ++outputSlotIdx) |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 474 | { |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 475 | OutputSlot* subgraphOutputSlot = subgraphOutputSlots.at(outputSlotIdx); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 476 | ARMNN_ASSERT(subgraphOutputSlot); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 477 | |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 478 | OutputSlot* substituteOutputSlot = substituteSubgraphOutputSlots.at(outputSlotIdx); |
Narumol Prangnawarat | ac2770a | 2020-04-01 16:51:23 +0100 | [diff] [blame^] | 479 | ARMNN_ASSERT(substituteOutputSlot); |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 480 | subgraphOutputSlot->MoveAllConnections(*substituteOutputSlot); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 481 | } |
| 482 | } |
| 483 | |
Derek Lamberti | c2fe5fb | 2019-05-08 10:23:08 +0100 | [diff] [blame] | 484 | void Graph::EraseSubgraphLayers(SubgraphView &subgraph) |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 485 | { |
Derek Lamberti | ff05cc5 | 2019-04-26 13:05:17 +0100 | [diff] [blame] | 486 | for (auto layer : subgraph.GetLayers()) |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 487 | { |
| 488 | EraseLayer(layer); |
| 489 | } |
Derek Lamberti | c2fe5fb | 2019-05-08 10:23:08 +0100 | [diff] [blame] | 490 | subgraph.Clear(); |
Matteo Martincigh | 4912402 | 2019-01-11 13:25:59 +0000 | [diff] [blame] | 491 | } |
| 492 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 493 | void Graph::InferTensorInfos() |
| 494 | { |
| 495 | for (auto&& layer : TopologicalSort()) |
| 496 | { |
| 497 | for (auto&& input : layer->GetInputSlots()) |
| 498 | { |
Matthew Bentham | 584a2b8 | 2019-11-01 13:29:48 +0000 | [diff] [blame] | 499 | const IOutputSlot* source = input.GetConnectedOutputSlot(); |
| 500 | if (source == NULL) |
| 501 | { |
| 502 | std::ostringstream message; |
| 503 | message << "Input not connected on " |
| 504 | << GetLayerTypeAsCString(layer->GetType()) |
| 505 | << " layer \"" |
| 506 | << layer->GetName() |
| 507 | << "\""; |
| 508 | throw LayerValidationException(message.str()); |
| 509 | } |
| 510 | |
| 511 | if (!source->IsTensorInfoSet()) |
| 512 | { |
| 513 | throw LayerValidationException("All inputs must have the TensorInfo set at this point."); |
| 514 | } |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 515 | } |
| 516 | layer->ValidateTensorShapesFromInputs(); |
| 517 | } |
| 518 | } |
| 519 | |
| 520 | } // namespace armnn |