IVGCVSW-3061 Modify NetworkQuantizer to support option to preserve input/output types

* Also add unit tests for the new preserve type option
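* With the option enabled and Float32 or Float16 tensors at the network
  boundary, the quantizer inserts a Quantize layer after each input and
  a Dequantize layer before each output, so the quantized network keeps
  the original input/output data types; boundaries that are already
  quantized are left unchanged.

Example usage (a minimal sketch; assumes an existing Float32 network
held in an INetworkPtr named 'network', a name used here only for
illustration):

    // Quantize to QuantisedAsymm8 internally, but keep Float32 tensors
    // at the network inputs and outputs.
    armnn::QuantizerOptions options(armnn::DataType::QuantisedAsymm8, true);
    armnn::INetworkPtr quantizedNetwork =
        armnn::INetworkQuantizer::Create(network.get(), options)->ExportNetwork();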

Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Change-Id: I860759072f2e3546698118d1bcd5e79eb4e805ec
diff --git a/include/armnnQuantizer/INetworkQuantizer.hpp b/include/armnnQuantizer/INetworkQuantizer.hpp
index 89548d1..826b077 100644
--- a/include/armnnQuantizer/INetworkQuantizer.hpp
+++ b/include/armnnQuantizer/INetworkQuantizer.hpp
@@ -14,10 +14,16 @@
 
 struct QuantizerOptions
 {
-    QuantizerOptions() : m_ActivationFormat(DataType::QuantisedAsymm8) {}
-    QuantizerOptions(DataType activationFormat) : m_ActivationFormat(activationFormat) {}
+    QuantizerOptions() : QuantizerOptions(DataType::QuantisedAsymm8, false) {}
+
+    QuantizerOptions(DataType activationFormat) : QuantizerOptions(activationFormat, false) {}
+
+    QuantizerOptions(DataType activationFormat, bool preserveType)
+    : m_ActivationFormat(activationFormat)
+    , m_PreserveType(preserveType) {}
 
     DataType m_ActivationFormat;
+    bool m_PreserveType;
 };
 
 using INetworkQuantizerPtr = std::unique_ptr<class INetworkQuantizer, void(*)(INetworkQuantizer* quantizer)>;
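
Note: the delegating constructors above keep existing call sites source
compatible. A sketch of the three forms (the first two are equivalent):

    armnn::QuantizerOptions a;                                          // QuantisedAsymm8, types not preserved
    armnn::QuantizerOptions b(armnn::DataType::QuantisedAsymm8);        // same defaults as 'a'
    armnn::QuantizerOptions c(armnn::DataType::QuantisedAsymm8, true);  // preserve input/output types
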
diff --git a/src/armnn/NetworkQuantizer.cpp b/src/armnn/NetworkQuantizer.cpp
index 12e459d..f308d54 100644
--- a/src/armnn/NetworkQuantizer.cpp
+++ b/src/armnn/NetworkQuantizer.cpp
@@ -171,7 +171,7 @@
             throw InvalidArgumentException("Unsupported quantization target");
     }
 
-    QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get());
+    QuantizerVisitor quantizerVisitor(m_Ranges, quantizationScheme.get(), m_Options.m_PreserveType);
     VisitLayers(graph, quantizerVisitor);
 
     // clear the ranges
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index 919eda1..61e0e60 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -11,10 +11,13 @@
 namespace armnn
 {
 
-QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme)
+QuantizerVisitor::QuantizerVisitor(const RangeTracker& rangeTracker,
+                                   const IQuantizationScheme* quantizationScheme,
+                                   bool preserveType)
     : m_Ranges(rangeTracker)
     , m_QuantizedNetwork(INetwork::Create())
     , m_QuantizationScheme(quantizationScheme)
+    , m_PreserveType(preserveType)
 {
 }
 
@@ -106,15 +109,41 @@
 
 void QuantizerVisitor::VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name)
 {
-    IConnectableLayer* newLayer = m_QuantizedNetwork->AddInputLayer(id, name);
-    RecordLayer(layer, newLayer);
+    const DataType dataType = layer->GetOutputSlot(0).GetTensorInfo().GetDataType();
+    IConnectableLayer* inputLayer = m_QuantizedNetwork->AddInputLayer(id, name);
+
+    if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+    {
+        IConnectableLayer* quantizeLayer = m_QuantizedNetwork->AddQuantizeLayer();
+        inputLayer->GetOutputSlot(0).Connect(quantizeLayer->GetInputSlot(0));
+        inputLayer->GetOutputSlot(0).SetTensorInfo(layer->GetOutputSlot(0).GetTensorInfo());
+        RecordLayer(layer, quantizeLayer);
+    }
+    else
+    {
+        RecordLayer(layer, inputLayer);
+    }
 }
 
 void QuantizerVisitor::VisitOutputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name)
 {
-    IConnectableLayer* newLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
-    RecordLayer(layer, newLayer);
-    SetQuantizedInputConnections(layer, newLayer);
+    const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+    const DataType& dataType = info.GetDataType();
+    IConnectableLayer* outputLayer = m_QuantizedNetwork->AddOutputLayer(id, name);
+
+    if (m_PreserveType && (dataType == DataType::Float32 || dataType == DataType::Float16))
+    {
+        IConnectableLayer* dequantizeLayer = m_QuantizedNetwork->AddDequantizeLayer();
+        RecordLayer(layer, dequantizeLayer);
+        SetQuantizedInputConnections(layer, dequantizeLayer);
+        dequantizeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+        dequantizeLayer->GetOutputSlot(0).SetTensorInfo(info);
+    }
+    else
+    {
+        RecordLayer(layer, outputLayer);
+        SetQuantizedInputConnections(layer, outputLayer);
+    }
 }
 
 void QuantizerVisitor::VisitBatchNormalizationLayer(const IConnectableLayer* layer,
diff --git a/src/armnn/QuantizerVisitor.hpp b/src/armnn/QuantizerVisitor.hpp
index eb9ebac..300ac16 100644
--- a/src/armnn/QuantizerVisitor.hpp
+++ b/src/armnn/QuantizerVisitor.hpp
@@ -25,7 +25,10 @@
 class QuantizerVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>
 {
 public:
-    QuantizerVisitor(const RangeTracker& rangeTracker, const IQuantizationScheme* quantizationScheme);
+    QuantizerVisitor(const RangeTracker& rangeTracker,
+                     const IQuantizationScheme* quantizationScheme,
+                     bool preserveType = false);
+
     ~QuantizerVisitor() = default;
 
     /// Functions to quantize the individual layers, overridden from ILayerVisitor
@@ -132,6 +135,8 @@
     std::unordered_map<LayerGuid, IConnectableLayer*> m_QuantizedGuidToLayerMap;
 
     const IQuantizationScheme* m_QuantizationScheme;
+
+    const bool m_PreserveType;
 };
 
 } //namespace armnn
diff --git a/src/armnn/test/QuantizerTest.cpp b/src/armnn/test/QuantizerTest.cpp
index 259e90f..2103de0 100644
--- a/src/armnn/test/QuantizerTest.cpp
+++ b/src/armnn/test/QuantizerTest.cpp
@@ -38,15 +38,15 @@
 public:
     TestQuantization(const TensorShape& inputShape, const TensorShape& outputShape)
     : LayerVisitorBase<VisitorThrowingPolicy>()
-    , m_QuantizerOptions(QuantizerOptions())
     , m_InputShape(inputShape)
-    , m_OutputShape(outputShape) {}
+    , m_OutputShape(outputShape)
+    , m_QuantizerOptions(QuantizerOptions()) {}
 
     TestQuantization(const QuantizerOptions& options, const TensorShape& inputShape, const TensorShape& outputShape)
     : LayerVisitorBase<VisitorThrowingPolicy>()
-    , m_QuantizerOptions(options)
     , m_InputShape(inputShape)
-    , m_OutputShape(outputShape) {}
+    , m_OutputShape(outputShape)
+    , m_QuantizerOptions(options) {}
 
     void VisitInputLayer(const IConnectableLayer* layer,
                          LayerBindingId id,
@@ -91,6 +91,9 @@
         TestQuantizationParamsImpl(info, DataType::QuantisedAsymm8, params.first, params.second);
     }
 
+    TensorShape m_InputShape;
+    TensorShape m_OutputShape;
+
 private:
     void TestQuantizationParamsImpl(const TensorInfo& info, DataType dataType, float scale, int32_t offset)
     {
@@ -100,8 +103,6 @@
     }
 
     QuantizerOptions m_QuantizerOptions;
-    TensorShape m_InputShape;
-    TensorShape m_OutputShape;
 };
 
 void VisitLayersTopologically(const INetwork* inputNetwork, ILayerVisitor& visitor)
@@ -1574,5 +1575,104 @@
     BOOST_CHECK_EQUAL(SetupQuantize(-1 * std::numeric_limits<float>::infinity())[0], 0);
 }
 
+class TestPreserveType : public TestAdditionQuantization
+{
+public:
+    TestPreserveType(const QuantizerOptions& options,
+                     const DataType& dataType,
+                     const TensorShape& inputShape,
+                     const TensorShape& outputShape)
+    : TestAdditionQuantization(options, inputShape, outputShape)
+    , m_DataType(dataType)
+    , m_VisitedQuantizeLayer(false)
+    , m_VisitedDequantizeLayer(false) {}
+
+    void VisitInputLayer(const IConnectableLayer* layer,
+                         LayerBindingId id,
+                         const char* name = nullptr) override
+    {
+        const TensorInfo& info = layer->GetOutputSlot(0).GetTensorInfo();
+        BOOST_TEST(GetDataTypeName(info.GetDataType()) == GetDataTypeName(m_DataType));
+        BOOST_TEST(m_InputShape == info.GetShape());
+    }
+
+    void VisitOutputLayer(const IConnectableLayer* layer,
+                          LayerBindingId id,
+                          const char* name = nullptr) override
+    {
+        const TensorInfo& info = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+        BOOST_TEST(GetDataTypeName(info.GetDataType()) == GetDataTypeName(m_DataType));
+        BOOST_TEST(m_OutputShape == info.GetShape());
+    }
+
+    void VisitQuantizeLayer(const IConnectableLayer* layer,
+                            const char* name = nullptr) override
+    {
+        m_VisitedQuantizeLayer = true;
+    }
+
+    void VisitDequantizeLayer(const IConnectableLayer* layer,
+                              const char* name = nullptr) override
+    {
+        m_VisitedDequantizeLayer = true;
+    }
+
+    void CheckQuantizeDequantizeLayerVisited(bool expected)
+    {
+        if (expected)
+        {
+            BOOST_CHECK(m_VisitedQuantizeLayer);
+            BOOST_CHECK(m_VisitedDequantizeLayer);
+        }
+        else
+        {
+            BOOST_CHECK(!m_VisitedQuantizeLayer);
+            BOOST_CHECK(!m_VisitedDequantizeLayer);
+        }
+    }
+private:
+    const DataType m_DataType;
+    bool m_VisitedQuantizeLayer;
+    bool m_VisitedDequantizeLayer;
+};
+
+void PreserveTypeTestImpl(const DataType& dataType)
+{
+    INetworkPtr network = INetwork::Create();
+
+    // Add the layers
+    IConnectableLayer* input0 = network->AddInputLayer(0);
+    IConnectableLayer* input1 = network->AddInputLayer(1);
+    IConnectableLayer* addition = network->AddAdditionLayer();
+    IConnectableLayer* output = network->AddOutputLayer(2);
+
+    input0->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
+    input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
+    addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+    const TensorShape shape{1U, 2U, 3U};
+    const TensorInfo info(shape, dataType);
+    input0->GetOutputSlot(0).SetTensorInfo(info);
+    input1->GetOutputSlot(0).SetTensorInfo(info);
+    addition->GetOutputSlot(0).SetTensorInfo(info);
+
+    const QuantizerOptions options(DataType::QuantisedAsymm8, true);
+    INetworkPtr quantizedNetworkQAsymm8 = INetworkQuantizer::Create(network.get(), options)->ExportNetwork();
+    TestPreserveType validatorQAsymm8(options, dataType, shape, shape);
+    VisitLayersTopologically(quantizedNetworkQAsymm8.get(), validatorQAsymm8);
+    validatorQAsymm8.CheckQuantizeDequantizeLayerVisited(
+        dataType == DataType::Float32 || dataType == DataType::Float16);
+}
+
+BOOST_AUTO_TEST_CASE(PreserveTypeFloat32)
+{
+    PreserveTypeTestImpl(DataType::Float32);
+}
+
+BOOST_AUTO_TEST_CASE(PreserveTypeQAsymm8)
+{
+    PreserveTypeTestImpl(DataType::QuantisedAsymm8);
+}
+
 BOOST_AUTO_TEST_SUITE_END()
 } // namespace armnn