IVGCVSW-4893 Refactor ILayerVisitor using unified interface strategy.
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: Id7bc8255a8e3f9e5aac65d510bec8a559bf37246
diff --git a/src/armnnQuantizer/QuantizationDataSet.cpp b/src/armnnQuantizer/QuantizationDataSet.cpp
index acd301a..99fc021 100644
--- a/src/armnnQuantizer/QuantizationDataSet.cpp
+++ b/src/armnnQuantizer/QuantizationDataSet.cpp
@@ -47,6 +47,36 @@
{
}
+
+/// Gathers the TensorInfo of each input layer, keyed by LayerBindingId, for creation of ConstTensor for Refine.
+/// Layers other than inputs may also be passed in (e.g. outputs, which have no output slot 0), so they are skipped.
+void InputLayerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
+                                         const armnn::BaseDescriptor& descriptor,
+                                         const std::vector<armnn::ConstTensor>& constants,
+                                         const char* name,
+                                         const armnn::LayerBindingId id)
+{
+    armnn::IgnoreUnused(name, descriptor, constants);
+    if (layer->GetType() != armnn::LayerType::Input)
+    {
+        return;
+    }
+    m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
+}
+
+/// Returns the TensorInfo stored in m_TensorInfos for layerBindingId.
+/// @throws armnn::Exception if no TensorInfo has been recorded for layerBindingId.
+armnn::TensorInfo InputLayerStrategy::GetTensorInfo(armnn::LayerBindingId layerBindingId)
+{
+    const auto iterator = m_TensorInfos.find(layerBindingId);
+    if (iterator == m_TensorInfos.end())
+    {
+        throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
+    }
+    // Reuse the iterator from find() instead of performing a second lookup through at().
+    return iterator->second;
+}
+
void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
armnn::LayerBindingId id,
const char* name)