IVGCVSW-3031 Reparent layer to new graph
Change-Id: Ic4423b8d21d794f44ddae291853e0e3b89d11bc0
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 88d2002..062d727 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -235,18 +235,40 @@
protected:
template <typename... Args>
LayerInGraphBase(Graph& graph, Iterator insertBefore, Args&&... args)
- : LayerT(std::forward<Args>(args)...), m_Graph(graph)
+ : LayerT(std::forward<Args>(args)...), m_Graph(&graph)
{
- m_Graph.m_PosInGraphMap.emplace(this, m_Graph.m_Layers.emplace(insertBefore, this));
+ Insert(*m_Graph, insertBefore);
}
~LayerInGraphBase()
{
- const size_t numErased = m_Graph.m_PosInGraphMap.erase(this);
+ Remove(*m_Graph);
+ }
+
+ void Reparent(Graph& destGraph, Iterator insertBefore) override
+ {
+ Insert(destGraph, insertBefore);
+ Remove(*m_Graph);
+ m_Graph->m_Layers.erase(m_Graph->GetPosInGraph(*this));
+
+ m_Graph = &destGraph;
+ }
+
+private:
+ void Insert(Graph& graph, Iterator insertBefore)
+ {
+ graph.m_PosInGraphMap.emplace(this, graph.m_Layers.emplace(insertBefore, this));
+ }
+
+ void Remove(Graph& graph)
+ {
+ const size_t numErased = graph.m_PosInGraphMap.erase(this);
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}
- Graph& m_Graph;
+protected:
+ Graph* m_Graph;
+
};
/// Input/Output layers specialize this template.
@@ -284,7 +306,7 @@
std::next(graph.begin(), IteratorDifference(graph.GetNumInputs())),
std::forward<Args>(args)...)
{
- const bool isNewId = m_Graph.m_InputIds.emplace(GetBindingId()).second;
+ const bool isNewId = m_Graph->m_InputIds.emplace(GetBindingId()).second;
if (!isNewId)
{
throw InvalidArgumentException("A layer already exists with the specified id");
@@ -298,7 +320,7 @@
}
~LayerInGraph() override
{
- const size_t numErased = m_Graph.m_InputIds.erase(GetBindingId());
+ const size_t numErased = m_Graph->m_InputIds.erase(GetBindingId());
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}
@@ -316,7 +338,7 @@
graph.end(),
std::forward<Args>(args)...)
{
- const bool isNewId = m_Graph.m_OutputIds.emplace(GetBindingId()).second;
+ const bool isNewId = m_Graph->m_OutputIds.emplace(GetBindingId()).second;
if (!isNewId)
{
throw InvalidArgumentException("A layer already exists with the specified id");
@@ -324,7 +346,7 @@
}
~LayerInGraph() override
{
- const size_t numErased = m_Graph.m_OutputIds.erase(GetBindingId());
+ const size_t numErased = m_Graph->m_OutputIds.erase(GetBindingId());
boost::ignore_unused(numErased);
BOOST_ASSERT(numErased == 1);
}
diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp
index 507b37b..cbb1771 100644
--- a/src/armnn/Layer.hpp
+++ b/src/armnn/Layer.hpp
@@ -298,6 +298,7 @@
const std::list<std::string>& GetRelatedLayerNames() { return m_RelatedLayerNames; }
+ virtual void Reparent(Graph& dest, std::list<Layer*>::const_iterator iterator) = 0;
protected:
// Graph needs access to the virtual destructor.
friend class Graph;
diff --git a/src/backends/backendsCommon/IBackendInternal.hpp b/src/backends/backendsCommon/IBackendInternal.hpp
index f49a210..5316f68 100644
--- a/src/backends/backendsCommon/IBackendInternal.hpp
+++ b/src/backends/backendsCommon/IBackendInternal.hpp
@@ -65,7 +65,7 @@
// Default implementation of OptimizeSubgraphView for backward compatibility with old API.
// Override this method with a custom optimization implementation.
- virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph)
+ virtual OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const
{
bool attempted=false;
SubgraphViewUniquePtr optSubgraph = OptimizeSubgraphView(subgraph, attempted);
diff --git a/src/backends/backendsCommon/OptimizationViews.hpp b/src/backends/backendsCommon/OptimizationViews.hpp
index cf7051d..e1b59ed 100644
--- a/src/backends/backendsCommon/OptimizationViews.hpp
+++ b/src/backends/backendsCommon/OptimizationViews.hpp
@@ -45,9 +45,13 @@
Subgraphs GetUntouchedSubgraphs() const { return m_UntouchedSubgraphs; }
bool Validate(const SubgraphView& originalSubgraph) const;
+ Graph& GetGraph() { return m_Graph; };
+
private:
Substitutions m_SuccesfulOptimizations; ///< Proposed substitutions from successful optimizations
Subgraphs m_FailedOptimizations; ///< Subgraphs from the original subgraph which cannot be supported
Subgraphs m_UntouchedSubgraphs; ///< Subgraphs from the original subgraph which remain unmodified
+
+ Graph m_Graph;
};
} //namespace armnn
\ No newline at end of file