IVGCVSW-4893 Refactor ILayerVisitor using unified interface strategy.

Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: Id7bc8255a8e3f9e5aac65d510bec8a559bf37246
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
index 219363e..49652ef 100644
--- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp
+++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
@@ -61,8 +61,8 @@
         if (!dataSet.IsEmpty())
         {
             // Get the Input Tensor Infos
-            armnnQuantizer::InputLayerVisitor inputLayerVisitor;
-            network->Accept(inputLayerVisitor);
+            armnnQuantizer::InputLayerStrategy inputLayerStrategy;
+            network->ExecuteStrategy(inputLayerStrategy);
 
             for (armnnQuantizer::QuantizationInput quantizationInput : dataSet)
             {
@@ -72,7 +72,7 @@
                 unsigned int count = 0;
                 for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
                 {
-                    armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId);
+                    armnn::TensorInfo tensorInfo = inputLayerStrategy.GetTensorInfo(layerBindingId);
                     inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
                     armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
                     inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
diff --git a/src/armnnQuantizer/QuantizationDataSet.cpp b/src/armnnQuantizer/QuantizationDataSet.cpp
index acd301a..99fc021 100644
--- a/src/armnnQuantizer/QuantizationDataSet.cpp
+++ b/src/armnnQuantizer/QuantizationDataSet.cpp
@@ -47,6 +47,36 @@
 {
 }
 
+
+/// Strategy callback that records the output TensorInfo of each input layer keyed by its LayerBindingId,
+/// for later creation of ConstTensors for Refine. Non-input layers are skipped: their id is not a real
+/// binding id (it defaults to 0) and they may have no output slot 0 to query.
+void InputLayerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
+                                         const armnn::BaseDescriptor& descriptor,
+                                         const std::vector<armnn::ConstTensor>& constants,
+                                         const char* name,
+                                         const armnn::LayerBindingId id)
+{
+    armnn::IgnoreUnused(name, descriptor, constants);
+
+    if (layer->GetType() == armnn::LayerType::Input)
+    {
+        m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
+    }
+}
+
+/// Returns the TensorInfo recorded by ExecuteStrategy for the given binding id.
+/// @throws armnn::Exception if nothing was recorded for layerBindingId.
+armnn::TensorInfo InputLayerStrategy::GetTensorInfo(armnn::LayerBindingId layerBindingId)
+{
+    auto iterator = m_TensorInfos.find(layerBindingId);
+    if (iterator == m_TensorInfos.end())
+    {
+        throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
+    }
+    // Reuse the iterator from find() rather than repeating the lookup with at().
+    return iterator->second;
+}
+
 void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
                                         armnn::LayerBindingId id,
                                         const char* name)
diff --git a/src/armnnQuantizer/QuantizationDataSet.hpp b/src/armnnQuantizer/QuantizationDataSet.hpp
index 3a97630..47b893a 100644
--- a/src/armnnQuantizer/QuantizationDataSet.hpp
+++ b/src/armnnQuantizer/QuantizationDataSet.hpp
@@ -43,6 +43,22 @@
 };
 
 /// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
+class InputLayerStrategy : public armnn::IStrategy
+{
+public:
+    /// Strategy callback used with INetwork::ExecuteStrategy to gather input TensorInfos keyed by LayerBindingId.
+    void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+                         const armnn::BaseDescriptor& descriptor,
+                         const std::vector<armnn::ConstTensor>& constants,
+                         const char* name,
+                         const armnn::LayerBindingId id = 0) override;
+    /// Returns the TensorInfo gathered for layerBindingId; throws armnn::Exception if none was gathered.
+    armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId layerBindingId);
+private:
+    std::map<armnn::LayerBindingId, armnn::TensorInfo> m_TensorInfos;
+};
+
+/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
 class InputLayerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
 {
 public: