blob: 49652efe25b8945af59b50874a9cf842ab1ca835 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "CommandLineProcessor.hpp"
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnnQuantizer/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
#include "QuantizationDataSet.hpp"
#include "QuantizationInput.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

17int main(int argc, char* argv[])
18{
19 armnnQuantizer::CommandLineProcessor cmdline;
20 if (!cmdline.ProcessCommandLine(argc, argv))
21 {
22 return -1;
23 }
24 armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
25 std::ifstream inputFileStream(cmdline.GetInputFileName(), std::ios::binary);
26 std::vector<std::uint8_t> binaryContent;
27 while (inputFileStream)
28 {
29 char c;
30 inputFileStream.get(c);
31 if (inputFileStream)
32 {
33 binaryContent.push_back(static_cast<std::uint8_t>(c));
34 }
35 }
36 inputFileStream.close();
Sadik Armagandc2f7f42019-04-26 17:11:47 +010037
38 armnn::QuantizerOptions quantizerOptions;
Francis Murtaghddb1d062020-03-10 13:51:45 +000039
40 if (cmdline.GetQuantizationScheme() == "QAsymmS8")
41 {
42 quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmS8;
43 }
44 else if (cmdline.GetQuantizationScheme() == "QSymmS16")
45 {
46 quantizerOptions.m_ActivationFormat = armnn::DataType::QSymmS16;
47 }
48 else
49 {
50 quantizerOptions.m_ActivationFormat = armnn::DataType::QAsymmU8;
51 }
Sadik Armagandc2f7f42019-04-26 17:11:47 +010052
Éanna Ó Catháin5696bff2019-05-10 13:29:13 +010053 quantizerOptions.m_PreserveType = cmdline.HasPreservedDataType();
54
Jim Flynn3091b062019-02-15 14:45:04 +000055 armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
Sadik Armagandc2f7f42019-04-26 17:11:47 +010056 armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
Jim Flynn3091b062019-02-15 14:45:04 +000057
Nina Drozd59e15b02019-04-25 15:45:20 +010058 if (cmdline.HasQuantizationData())
Sadik Armagan2b03d642019-04-12 15:17:02 +010059 {
Nina Drozd59e15b02019-04-25 15:45:20 +010060 armnnQuantizer::QuantizationDataSet dataSet = cmdline.GetQuantizationDataSet();
61 if (!dataSet.IsEmpty())
Sadik Armagan2b03d642019-04-12 15:17:02 +010062 {
Nina Drozd59e15b02019-04-25 15:45:20 +010063 // Get the Input Tensor Infos
Finn Williamsb454c5c2021-02-09 15:56:23 +000064 armnnQuantizer::InputLayerStrategy inputLayerStrategy;
65 network->ExecuteStrategy(inputLayerStrategy);
Nina Drozd59e15b02019-04-25 15:45:20 +010066
Jim Flynnf92dfce2019-05-02 11:33:25 +010067 for (armnnQuantizer::QuantizationInput quantizationInput : dataSet)
Sadik Armagan2b03d642019-04-12 15:17:02 +010068 {
Nina Drozd59e15b02019-04-25 15:45:20 +010069 armnn::InputTensors inputTensors;
70 std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs());
71 std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds();
72 unsigned int count = 0;
73 for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
74 {
Finn Williamsb454c5c2021-02-09 15:56:23 +000075 armnn::TensorInfo tensorInfo = inputLayerStrategy.GetTensorInfo(layerBindingId);
Nina Drozd59e15b02019-04-25 15:45:20 +010076 inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
77 armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
78 inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
79 count++;
80 }
81 quantizer->Refine(inputTensors);
Sadik Armagan2b03d642019-04-12 15:17:02 +010082 }
Sadik Armagan2b03d642019-04-12 15:17:02 +010083 }
Sadik Armagan2b03d642019-04-12 15:17:02 +010084 }
85
86 armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();
Jim Flynn3091b062019-02-15 14:45:04 +000087 armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
88 serializer->Serialize(*quantizedNetwork);
89
90 std::string output(cmdline.GetOutputDirectoryName());
91 output.append(cmdline.GetOutputFileName());
92 std::ofstream outputFileStream;
93 outputFileStream.open(output);
94 serializer->SaveSerializedToStream(outputFileStream);
95 outputFileStream.flush();
96 outputFileStream.close();
97
98 return 0;
99}