IVGCVSW-2607 Refactor range tracking into its own class

Change-Id: I1b409e5dac7922859e04a554893b982afc5ad1e7
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp
index bc25d5e..bf5c9ef 100644
--- a/src/armnn/NetworkQuantizer.cpp
+++ b/src/armnn/NetworkQuantizer.cpp
@@ -45,7 +45,7 @@
     auto inputLayers = graph.GetInputLayers();
 
     // Walk the input layers of the graph and override the quantization parameters of the one with the given id
-    OverrideInputRangeVisitor overrideInputRangeVisitor(m_GuidToRangesMap, layerId, MinMaxRange{min, max});
+    OverrideInputRangeVisitor overrideInputRangeVisitor(m_Ranges, layerId, RangeTracker::MinMaxRange{min, max});
     VisitLayers(inputLayers, overrideInputRangeVisitor);
 }
 
@@ -54,11 +54,11 @@
     const Graph& graph = boost::polymorphic_downcast<const Network*>(m_InputNetwork)->GetGraph().TopologicalSort();
 
     // Step 1) Walk the graph and register min/max values for intermediate tensors
-    StaticRangeVisitor rangeVisitor(m_GuidToRangesMap);
+    StaticRangeVisitor rangeVisitor(m_Ranges);
     VisitLayers(graph, rangeVisitor);
 
     // Step 2) Convert input InputNetwork to Quantized InputNetwork
-    QuantizerVisitor quantizerVisitor(&rangeVisitor);
+    QuantizerVisitor quantizerVisitor(m_Ranges); // Reads the ranges that rangeVisitor recorded in Step 1
     VisitLayers(graph, quantizerVisitor);
 
     return quantizerVisitor.RetrieveFinalNetwork();
diff --git a/src/armnn/NetworkQuantizer.hpp b/src/armnn/NetworkQuantizer.hpp
index 2f7d365..5b87851 100644
--- a/src/armnn/NetworkQuantizer.hpp
+++ b/src/armnn/NetworkQuantizer.hpp
@@ -9,7 +9,7 @@
 #include <armnn/INetworkQuantizer.hpp>
 #include <armnn/Types.hpp>
 
-#include <unordered_map>
+#include "RangeTracker.hpp"
 
 namespace armnn
 {
@@ -23,13 +23,11 @@
     INetworkPtr ExportNetwork() override;
 
 private:
-    using MinMaxRange  = std::pair<float, float>;
-    using MinMaxRanges = std::vector<MinMaxRange>;
-
+    /// Original input network to quantize
     INetwork* m_InputNetwork;
 
     /// Mapping from Guid to an array of ranges for outputs
-    std::unordered_map<LayerGuid, MinMaxRanges> m_GuidToRangesMap;
+    RangeTracker m_Ranges; // Written by the range visitors and read back by QuantizerVisitor
 };
 
 } //namespace armnn
diff --git a/src/armnn/OverrideInputRangeVisitor.cpp b/src/armnn/OverrideInputRangeVisitor.cpp
index dba233f..058e630 100644
--- a/src/armnn/OverrideInputRangeVisitor.cpp
+++ b/src/armnn/OverrideInputRangeVisitor.cpp
@@ -12,36 +12,20 @@
 namespace armnn
 {
 
-OverrideInputRangeVisitor::OverrideInputRangeVisitor(std::unordered_map<LayerGuid, MinMaxRanges>& guidToRangesMap,
+OverrideInputRangeVisitor::OverrideInputRangeVisitor(RangeTracker& ranges,
                                                      LayerBindingId layerId,
                                                      const MinMaxRange& minMaxRange)
-    : m_GuidToRangesMap(guidToRangesMap)
+    : m_Ranges(ranges)
     , m_LayerId(layerId)
     , m_MinMaxRange(minMaxRange)
 {}
 
 void OverrideInputRangeVisitor::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
 {
-    if (m_LayerId != id)
+    if (m_LayerId == id) // Only the input layer bound to m_LayerId gets its output slot 0 range overridden
     {
-        // Not the layer we are looking for
-        return;
+        m_Ranges.SetRange(layer, 0, m_MinMaxRange.first, m_MinMaxRange.second);
     }
-
-    SetRange(layer);
-}
-
-void OverrideInputRangeVisitor::SetRange(const IConnectableLayer* layer)
-{
-    BOOST_ASSERT(layer);
-
-    auto& ranges = m_GuidToRangesMap[layer->GetGuid()];
-
-    if (ranges.size() < layer->GetNumOutputSlots())
-    {
-        ranges.resize(layer->GetNumOutputSlots());
-    }
-    ranges[0] = m_MinMaxRange;
 }
 
 } // namespace armnn
diff --git a/src/armnn/OverrideInputRangeVisitor.hpp b/src/armnn/OverrideInputRangeVisitor.hpp
index 72396b4..f09eeb9 100644
--- a/src/armnn/OverrideInputRangeVisitor.hpp
+++ b/src/armnn/OverrideInputRangeVisitor.hpp
@@ -7,6 +7,7 @@
 
 #include "NetworkQuantizer.hpp"
 #include "armnn/LayerVisitorBase.hpp"
+#include "RangeTracker.hpp"
 
 #include <unordered_map>
 
@@ -21,7 +22,7 @@
     using MinMaxRanges = std::vector<MinMaxRange>;
 
 public:
-    OverrideInputRangeVisitor(std::unordered_map<LayerGuid, MinMaxRanges>& guidToRangesMap,
+    OverrideInputRangeVisitor(RangeTracker& ranges,
                               LayerBindingId layerId,
                               const MinMaxRange& minMaxRange);
     ~OverrideInputRangeVisitor() = default;
@@ -29,11 +30,8 @@
     void VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name = nullptr) override;
 
 private:
-    /// Sets the range for the given input layer
-    void SetRange(const IConnectableLayer* layer);
-
-    /// Mapping from a layer Guid to an array of ranges for outputs
-    std::unordered_map<LayerGuid, MinMaxRanges>& m_GuidToRangesMap;
+    /// Tracker that receives the overridden input range; held by reference, so it must outlive this visitor
+    RangeTracker& m_Ranges;
 
     /// The id of the input layer of which to override the input range
     LayerBindingId m_LayerId;
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index ae0d438..af01092 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -11,11 +11,10 @@
 namespace armnn
 {
 
-QuantizerVisitor::QuantizerVisitor(const StaticRangeVisitor* staticRangeVisitor)
-    : m_StaticRangeVisitor(staticRangeVisitor)
+QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker)
+    : m_Ranges(rangeTracker)
     , m_QuantizedNetwork(INetwork::Create())
 {
-    BOOST_ASSERT(m_StaticRangeVisitor);
 }
 
 void QuantizerVisitor::SetQuantizedInputConnections(const IConnectableLayer* srcLayer,
@@ -45,7 +44,7 @@
         newOutputSlot.Connect(newInputSlot);
 
         // Fetch the min/max ranges that were computed earlier
-        auto range = m_StaticRangeVisitor->GetRange(layerToFind.GetGuid(), i);
+        auto range = m_Ranges.GetRange(layerToFind.GetGuid(), i); // Default range if the layer has none recorded
         auto qParams = ComputeQAsymmParams(8, range.first, range.second);
 
         // Set the quantization params
diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp
index c55ef6d..121ee17 100644
--- a/src/armnn/QuantizerVisitor.hpp
+++ b/src/armnn/QuantizerVisitor.hpp
@@ -24,7 +24,7 @@
 class QuantizerVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>
 {
 public:
-    QuantizerVisitor(const StaticRangeVisitor* staticRangeVisitor);
+    explicit QuantizerVisitor(const RangeTracker& rangeTracker);
     ~QuantizerVisitor() = default;
 
     /// Functions to quantize the individual layers, overridden from ILayerVisitor
@@ -79,7 +79,7 @@
     void RecordLayer(const IConnectableLayer* srcLayer, IConnectableLayer* qLayer);
 
-    /// Reference to the static range visitor used to retrieve the quantization ranges
-    const StaticRangeVisitor* m_StaticRangeVisitor;
+    /// Range tracker used to retrieve the quantization ranges; held by reference, so it must outlive this visitor
+    const RangeTracker& m_Ranges;
 
     /// Quantized version of the model we are building up
     INetworkPtr m_QuantizedNetwork;
diff --git a/src/armnn/RangeTracker.cpp b/src/armnn/RangeTracker.cpp
new file mode 100644
index 0000000..2025103
--- /dev/null
+++ b/src/armnn/RangeTracker.cpp
@@ -0,0 +1,46 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RangeTracker.hpp"
+
+#include <algorithm>
+#include <cstddef>
+
+namespace armnn
+{
+
+/// Records [min, max] as the range of output slot outputIdx of the given layer.
+/// Slots of the layer that have not been set yet hold the default range.
+void RangeTracker::SetRange(const armnn::IConnectableLayer* layer, unsigned int outputIdx, float min, float max)
+{
+    auto& ranges = m_GuidToRangesMap[layer->GetGuid()];
+
+    // Size the entry for all of the layer's output slots, and never smaller than outputIdx + 1: a layer that
+    // reports fewer output slots than the index being set (e.g. zero) would otherwise be written out of bounds.
+    const std::size_t numOutputSlots = layer->GetNumOutputSlots();
+    const std::size_t requiredSize   = std::max(numOutputSlots, static_cast<std::size_t>(outputIdx) + 1);
+    if (ranges.size() < requiredSize)
+    {
+        // Fill new slots with the default range rather than (0, 0), so that unset slots read back the same
+        // range that GetRange reports for layers with no entry at all.
+        ranges.resize(requiredSize, DefaultRange());
+    }
+    ranges[outputIdx] = std::make_pair(min, max);
+}
+
+/// Returns the range recorded for output slot idx of the layer identified by guid, or the default range if
+/// nothing has been recorded for that layer. Throws std::out_of_range (from vector::at) if the layer has an
+/// entry but idx is past its last slot.
+RangeTracker::MinMaxRange RangeTracker::GetRange(armnn::LayerGuid guid, unsigned int idx) const
+{
+    auto search = m_GuidToRangesMap.find(guid);
+    if (search == m_GuidToRangesMap.end())
+    {
+        return DefaultRange();
+    }
+    return search->second.at(idx);
+}
+
+} //namespace armnn
diff --git a/src/armnn/RangeTracker.hpp b/src/armnn/RangeTracker.hpp
new file mode 100644
index 0000000..2e8b33a
--- /dev/null
+++ b/src/armnn/RangeTracker.hpp
@@ -0,0 +1,47 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/INetwork.hpp>
+#include <armnn/Types.hpp>
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace armnn
+{
+
+/// Records the min/max range of each output slot of each layer, keyed by layer guid.
+class RangeTracker
+{
+public:
+    using MinMaxRange  = std::pair<float, float>;
+
+    /// Retrieve the Range for a particular output slot on a particular layer.
+    /// Returns the default range if no ranges have been recorded for the layer.
+    MinMaxRange GetRange(LayerGuid guid, unsigned int idx) const;
+
+    /// Set the range for an output slot on a layer
+    void SetRange(const IConnectableLayer* layer, unsigned int outputIdx, float min, float max);
+
+    /// Query function to check that the RangeTracker is empty.
+    bool IsEmpty() const { return m_GuidToRangesMap.empty(); }
+
+    /// Query that there is an entry for a layer
+    bool HasRanges(LayerGuid guid) const { return m_GuidToRangesMap.find(guid) != m_GuidToRangesMap.end(); }
+
+private:
+    using MinMaxRanges = std::vector<MinMaxRange>;
+
+    /// Retrieve the default range
+    MinMaxRange DefaultRange() const { return std::make_pair(-15.0f, 15.0f); }
+
+    /// Mapping from a layer Guid to an array of ranges for outputs
+    std::unordered_map<LayerGuid, MinMaxRanges> m_GuidToRangesMap;
+};
+
+} //namespace armnn
diff --git a/src/armnn/StaticRangeVisitor.cpp b/src/armnn/StaticRangeVisitor.cpp
index 6eab200..2365e1b 100644
--- a/src/armnn/StaticRangeVisitor.cpp
+++ b/src/armnn/StaticRangeVisitor.cpp
@@ -12,29 +12,13 @@
 namespace armnn
 {
 
-StaticRangeVisitor::StaticRangeVisitor(std::unordered_map<LayerGuid, MinMaxRanges>& guidToRangesMap)
-    : m_GuidToRangesMap(guidToRangesMap)
+StaticRangeVisitor::StaticRangeVisitor(RangeTracker& rangeTracker)
+    : m_RangeTracker(rangeTracker)
 {}
 
-StaticRangeVisitor::MinMaxRange StaticRangeVisitor::GetRange(LayerGuid guid, unsigned int idx) const
-{
-    auto search = m_GuidToRangesMap.find(guid);
-    if (search == m_GuidToRangesMap.end())
-    {
-        return DefaultRange();
-    }
-    return search->second.at(idx);
-}
-
 void StaticRangeVisitor::SetRange(const IConnectableLayer* layer, unsigned int outputIdx, float min, float max)
 {
-    auto& ranges = m_GuidToRangesMap[layer->GetGuid()];
-
-    if (ranges.size() < layer->GetNumOutputSlots())
-    {
-        ranges.resize(layer->GetNumOutputSlots());
-    }
-    ranges[outputIdx] = std::make_pair(min, max);
+    m_RangeTracker.SetRange(layer, outputIdx, min, max); // Recorded in the caller-supplied tracker
 }
 
 void StaticRangeVisitor::VisitAdditionLayer(const IConnectableLayer* layer, const char* name)
diff --git a/src/armnn/StaticRangeVisitor.hpp b/src/armnn/StaticRangeVisitor.hpp
index 2b01437..e1f68f3 100644
--- a/src/armnn/StaticRangeVisitor.hpp
+++ b/src/armnn/StaticRangeVisitor.hpp
@@ -6,11 +6,10 @@
 #pragma once
 
 #include "armnn/LayerVisitorBase.hpp"
+#include "RangeTracker.hpp"
 
 #include <armnn/INetwork.hpp>
 #include <armnn/INetworkQuantizer.hpp>
 
-#include <unordered_map>
-
 namespace armnn
 {
@@ -18,12 +17,8 @@
 /// Visitor class to establish min/max ranges based on the type of the layer
 class StaticRangeVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>
 {
-private:
-    using MinMaxRange  = std::pair<float, float>;
-    using MinMaxRanges = std::vector<MinMaxRange>;
-
 public:
-    StaticRangeVisitor(std::unordered_map<LayerGuid, MinMaxRanges>& guidToRangesMap);
+    explicit StaticRangeVisitor(RangeTracker& rangeTracker);
     ~StaticRangeVisitor() = default;
 
     /// Functions to set the Range on a per-layer-type basis
@@ -61,18 +56,13 @@
     void VisitSoftmaxLayer(const IConnectableLayer* layer,
                            const SoftmaxDescriptor& softmaxDescriptor,
                            const char* name = nullptr) override;
-    /// Retrieve the default range
-    MinMaxRange DefaultRange() const { return std::make_pair(-15.0f, 15.0f); }
-
-    /// Retrieve the Range for a particular output slot on a particular layer
-    MinMaxRange GetRange(LayerGuid guid, unsigned int idx) const;
 
 private:
     /// Set the range for an output slot on a layer
     void SetRange(const IConnectableLayer* layer, unsigned int outputIdx, float min, float max);
 
-    /// Mapping from a layer Guid to an array of ranges for outputs
-    std::unordered_map<LayerGuid, MinMaxRanges>& m_GuidToRangesMap;
+    /// Tracker that receives the computed ranges; held by reference, so it must outlive this visitor
+    RangeTracker& m_RangeTracker;
 };
 
 } //namespace armnn
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index ac9ea1d..a130c1f 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -13,6 +13,7 @@
 #include "../Graph.hpp"
 #include "../NetworkQuantizerUtils.hpp"
 #include "../OverrideInputRangeVisitor.hpp"
+#include "../RangeTracker.hpp"
 
 #include <boost/test/unit_test.hpp>
 
@@ -377,36 +378,36 @@
 
 BOOST_AUTO_TEST_CASE(OverrideInputRangeEmptyNetwork)
 {
-    MinMaxRangeMap guidToRangesMap; // Empty map of ranges
-    MinMaxRange minMaxRange(-12.3f, 45.6f); // Range to use for the override
+    RangeTracker ranges; // Empty tracker of ranges
+    RangeTracker::MinMaxRange minMaxRange(-12.3f, 45.6f); // Range to use for the override
 
     Network network; // Empty network
     auto inputLayers = network.GetGraph().GetInputLayers(); // Empty list of input layers
 
-    OverrideInputRangeVisitor overrideInputRangeVisitor(guidToRangesMap, 0, minMaxRange);
+    OverrideInputRangeVisitor overrideInputRangeVisitor(ranges, 0, minMaxRange);
     VisitLayers(inputLayers, overrideInputRangeVisitor);
 
-    BOOST_CHECK(guidToRangesMap.empty()); // Check that the map of ranges remained untouched
+    BOOST_CHECK(ranges.IsEmpty()); // Check that the map of ranges remained untouched
 }
 
 BOOST_AUTO_TEST_CASE(OverrideInputRangeNoInputLayers)
 {
-    MinMaxRangeMap guidToRangesMap; // Empty map of ranges
-    MinMaxRange minMaxRange(-12.3f, 45.6f); // Range to use for the override
+    RangeTracker ranges; // Empty tracker of ranges
+    RangeTracker::MinMaxRange minMaxRange(-12.3f, 45.6f); // Range to use for the override
 
     Network network;
     network.AddAdditionLayer(); // Network with no input layers
     auto inputLayers = network.GetGraph().GetInputLayers(); // Empty list of input layers
 
-    OverrideInputRangeVisitor overrideInputRangeVisitor(guidToRangesMap, 0, minMaxRange);
+    OverrideInputRangeVisitor overrideInputRangeVisitor(ranges, 0, minMaxRange);
     VisitLayers(inputLayers, overrideInputRangeVisitor);
 
-    BOOST_CHECK(guidToRangesMap.empty()); // Check that the map of ranges remained untouched
+    BOOST_CHECK(ranges.IsEmpty()); // Check that the map of ranges remained untouched
 }
 
 BOOST_AUTO_TEST_CASE(OverrideInputRangeInputLayers)
 {
-    MinMaxRangeMap guidToRangesMap; // Empty map of ranges
-    MinMaxRange minMaxRange(-12.3f, 45.6f); // Range to use for the override
+    RangeTracker ranges; // Empty tracker of ranges
+    RangeTracker::MinMaxRange minMaxRange(-12.3f, 45.6f); // Range to use for the override
 
     Network network;
@@ -432,31 +433,27 @@
     auto inputLayers = network.GetGraph().GetInputLayers(); // List of input layers
 
     // Trying to override the input range for the input layer with binding id 3 (does not exist in the network)
-    OverrideInputRangeVisitor overrideInputRangeVisitorLayer3(guidToRangesMap, 3, minMaxRange);
+    OverrideInputRangeVisitor overrideInputRangeVisitorLayer3(ranges, 3, minMaxRange);
     VisitLayers(inputLayers, overrideInputRangeVisitorLayer3);
 
     // Check that the map of ranges remained untouched
-    BOOST_CHECK(guidToRangesMap.empty());
+    BOOST_CHECK(ranges.IsEmpty());
 
     // Override the input range for the input layer with binding id 1
-    OverrideInputRangeVisitor overrideInputRangeVisitorLayer1(guidToRangesMap, 1, minMaxRange);
+    OverrideInputRangeVisitor overrideInputRangeVisitorLayer1(ranges, 1, minMaxRange);
     VisitLayers(inputLayers, overrideInputRangeVisitorLayer1);
 
     // Check that the map of ranges has been populated
-    BOOST_CHECK(!guidToRangesMap.empty());
+    BOOST_CHECK(!ranges.IsEmpty());
 
     // Check that an entry for the input layer with binding id 0 does not exist
-    BOOST_CHECK(guidToRangesMap.find(input0->GetGuid()) == guidToRangesMap.end());
+    BOOST_CHECK(!ranges.HasRanges(input0->GetGuid()));
 
     // Check that an entry for the input layer with binding id 1 exists
-    BOOST_CHECK(guidToRangesMap.find(input1->GetGuid()) != guidToRangesMap.end());
-
-    // Check that at least a value has been added for the input layer with binding id 1
-    BOOST_CHECK(!guidToRangesMap[input1->GetGuid()].empty());
+    BOOST_CHECK(ranges.HasRanges(input1->GetGuid()));
 
-    // Check the the overridden values are what we intended to set
-    BOOST_CHECK(guidToRangesMap[input1->GetGuid()].at(0).first  == minMaxRange.first);
-    BOOST_CHECK(guidToRangesMap[input1->GetGuid()].at(0).second == minMaxRange.second);
+    // Check that the overridden values are what we intended to set
+    BOOST_CHECK(ranges.GetRange(input1->GetGuid(), 0) == minMaxRange);
 }
 
 INetworkPtr CreateNetworkWithFullyConnectedLayer(const bool biasEnabled)