IVGCVSW-2834 Add dynamic quantization via datasets
* Add QuantizationDataSet class for quantization data parsed from CSV file
* Add QuantizationInput class for retrieving quantization data for each layer binding ID
* Add unit tests for command line processor and QuantizationDataSet
Change-Id: Iaf0a747b5f25a59a766ac04f7158e8cb7909d179
Signed-off-by: Nina Drozd <nina.drozd@arm.com>
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
index 9ac8966..103597a 100644
--- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp
+++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp
@@ -7,6 +7,8 @@
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnn/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
+#include "QuantizationDataSet.hpp"
+#include "QuantizationInput.hpp"
#include <algorithm>
#include <fstream>
@@ -41,31 +43,32 @@
armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
- std::string csvFileName = cmdline.GetCsvFileName();
- if (csvFileName != "")
+ if (cmdline.HasQuantizationData())
{
- // Call the Quantizer::Refine() function which will update the min/max ranges for the quantize constants
- std::ifstream csvFileStream(csvFileName);
- std::string line;
- std::string csvDirectory = cmdline.GetCsvFileDirectory();
- while(getline(csvFileStream, line))
+ armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet();
+ if (!dataSet.IsEmpty())
{
- std::istringstream s(line);
- std::vector<std::string> row;
- std::string entry;
- while(getline(s, entry, ','))
+ // Get the Input Tensor Infos
+ armnnQuantizer::InputLayerVisitor inputLayerVisitor;
+ network->Accept(inputLayerVisitor);
+
+ for(armnnQuantizer::QuantizationInput quantizationInput : dataSet)
{
- entry.erase(std::remove(entry.begin(), entry.end(), ' '), entry.end());
- entry.erase(std::remove(entry.begin(), entry.end(), '"'), entry.end());
- row.push_back(entry);
+ armnn::InputTensors inputTensors;
+ std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs());
+ std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds();
+ unsigned int count = 0;
+ for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
+ {
+ armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId);
+ inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
+ armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
+ inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
+ count++;
+ }
+ quantizer->Refine(inputTensors);
}
- std::string rawFileName = cmdline.GetCsvFileDirectory() + "/" + row[2];
- // passId: row[0]
- // bindingId: row[1]
- // rawFileName: file contains the RAW input tensor data
- // LATER: Quantizer::Refine() function will be called with those arguments when it is implemented
}
- csvFileStream.close();
}
armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();