blob: 96d6515ba0fa954e39c7f462d72778d2400eaef3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CommandLineProcessor.hpp"
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnnQuantizer/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
#include "QuantizationDataSet.hpp"
#include "QuantizationInput.hpp"

#include <algorithm>
#include <cstdint>
#include <exception>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

17int main(int argc, char* argv[])
18{
19 armnnQuantizer::CommandLineProcessor cmdline;
20 if (!cmdline.ProcessCommandLine(argc, argv))
21 {
22 return -1;
23 }
24 armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
25 std::ifstream inputFileStream(cmdline.GetInputFileName(), std::ios::binary);
26 std::vector<std::uint8_t> binaryContent;
27 while (inputFileStream)
28 {
29 char c;
30 inputFileStream.get(c);
31 if (inputFileStream)
32 {
33 binaryContent.push_back(static_cast<std::uint8_t>(c));
34 }
35 }
36 inputFileStream.close();
Sadik Armagandc2f7f42019-04-26 17:11:47 +010037
38 armnn::QuantizerOptions quantizerOptions;
39 quantizerOptions.m_ActivationFormat = cmdline.GetQuantizationScheme() == "QSymm16"
40 ? armnn::DataType::QuantisedSymm16
41 : armnn::DataType::QuantisedAsymm8;
42
Jim Flynn3091b062019-02-15 14:45:04 +000043 armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
Sadik Armagandc2f7f42019-04-26 17:11:47 +010044 armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
Jim Flynn3091b062019-02-15 14:45:04 +000045
Nina Drozd59e15b02019-04-25 15:45:20 +010046 if (cmdline.HasQuantizationData())
Sadik Armagan2b03d642019-04-12 15:17:02 +010047 {
Nina Drozd59e15b02019-04-25 15:45:20 +010048 armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet();
49 if (!dataSet.IsEmpty())
Sadik Armagan2b03d642019-04-12 15:17:02 +010050 {
Nina Drozd59e15b02019-04-25 15:45:20 +010051 // Get the Input Tensor Infos
52 armnnQuantizer::InputLayerVisitor inputLayerVisitor;
53 network->Accept(inputLayerVisitor);
54
Jim Flynnf92dfce2019-05-02 11:33:25 +010055 for (armnnQuantizer::QuantizationInput quantizationInput : dataSet)
Sadik Armagan2b03d642019-04-12 15:17:02 +010056 {
Nina Drozd59e15b02019-04-25 15:45:20 +010057 armnn::InputTensors inputTensors;
58 std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs());
59 std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds();
60 unsigned int count = 0;
61 for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
62 {
63 armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId);
64 inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
65 armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
66 inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
67 count++;
68 }
69 quantizer->Refine(inputTensors);
Sadik Armagan2b03d642019-04-12 15:17:02 +010070 }
Sadik Armagan2b03d642019-04-12 15:17:02 +010071 }
Sadik Armagan2b03d642019-04-12 15:17:02 +010072 }
73
74 armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();
Jim Flynn3091b062019-02-15 14:45:04 +000075 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
76 serializer->Serialize(*quantizedNetwork);
77
78 std::string output(cmdline.GetOutputDirectoryName());
79 output.append(cmdline.GetOutputFileName());
80 std::ofstream outputFileStream;
81 outputFileStream.open(output);
82 serializer->SaveSerializedToStream(outputFileStream);
83 outputFileStream.flush();
84 outputFileStream.close();
85
86 return 0;
87}