//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CommandLineProcessor.hpp"
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnnQuantizer/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
#include "QuantizationDataSet.hpp"
#include "QuantizationInput.hpp"

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

17int main(int argc, char* argv[])
18{
19 armnnQuantizer::CommandLineProcessor cmdline;
20 if (!cmdline.ProcessCommandLine(argc, argv))
21 {
22 return -1;
23 }
24 armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
25 std::ifstream inputFileStream(cmdline.GetInputFileName(), std::ios::binary);
26 std::vector<std::uint8_t> binaryContent;
27 while (inputFileStream)
28 {
29 char c;
30 inputFileStream.get(c);
31 if (inputFileStream)
32 {
33 binaryContent.push_back(static_cast<std::uint8_t>(c));
34 }
35 }
36 inputFileStream.close();
Sadik Armagandc2f7f42019-04-26 17:11:47 +010037
38 armnn::QuantizerOptions quantizerOptions;
39 quantizerOptions.m_ActivationFormat = cmdline.GetQuantizationScheme() == "QSymm16"
40 ? armnn::DataType::QuantisedSymm16
41 : armnn::DataType::QuantisedAsymm8;
42
Éanna Ó Catháin5696bff2019-05-10 13:29:13 +010043 quantizerOptions.m_PreserveType = cmdline.HasPreservedDataType();
44
Jim Flynn3091b062019-02-15 14:45:04 +000045 armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
Sadik Armagandc2f7f42019-04-26 17:11:47 +010046 armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
Jim Flynn3091b062019-02-15 14:45:04 +000047
Nina Drozd59e15b02019-04-25 15:45:20 +010048 if (cmdline.HasQuantizationData())
Sadik Armagan2b03d642019-04-12 15:17:02 +010049 {
Nina Drozd59e15b02019-04-25 15:45:20 +010050 armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet();
51 if (!dataSet.IsEmpty())
Sadik Armagan2b03d642019-04-12 15:17:02 +010052 {
Nina Drozd59e15b02019-04-25 15:45:20 +010053 // Get the Input Tensor Infos
54 armnnQuantizer::InputLayerVisitor inputLayerVisitor;
55 network->Accept(inputLayerVisitor);
56
Jim Flynnf92dfce2019-05-02 11:33:25 +010057 for (armnnQuantizer::QuantizationInput quantizationInput : dataSet)
Sadik Armagan2b03d642019-04-12 15:17:02 +010058 {
Nina Drozd59e15b02019-04-25 15:45:20 +010059 armnn::InputTensors inputTensors;
60 std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs());
61 std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds();
62 unsigned int count = 0;
63 for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
64 {
65 armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId);
66 inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
67 armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
68 inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
69 count++;
70 }
71 quantizer->Refine(inputTensors);
Sadik Armagan2b03d642019-04-12 15:17:02 +010072 }
Sadik Armagan2b03d642019-04-12 15:17:02 +010073 }
Sadik Armagan2b03d642019-04-12 15:17:02 +010074 }
75
76 armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();
Jim Flynn3091b062019-02-15 14:45:04 +000077 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
78 serializer->Serialize(*quantizedNetwork);
79
80 std::string output(cmdline.GetOutputDirectoryName());
81 output.append(cmdline.GetOutputFileName());
82 std::ofstream outputFileStream;
83 outputFileStream.open(output);
84 serializer->SaveSerializedToStream(outputFileStream);
85 outputFileStream.flush();
86 outputFileStream.close();
87
88 return 0;
89}