blob: 99fc021a514c4e9d7c7ae934b999c143243ba4c4 [file] [log] [blame]
Nina Drozd59e15b02019-04-25 15:45:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "QuantizationDataSet.hpp"

#include <fmt/format.h>

#include <armnn/utility/StringUtils.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <Filesystem.hpp>

#include <fstream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
Nina Drozd59e15b02019-04-25 15:45:20 +010013
14namespace armnnQuantizer
15{
16
17QuantizationDataSet::QuantizationDataSet()
18{
19}
20
21QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
22 m_QuantizationInputs(),
23 m_CsvFilePath(csvFilePath)
24{
25 ParseCsvFile();
26}
27
28void AddInputData(unsigned int passId,
29 armnn::LayerBindingId bindingId,
30 const std::string& inputFilePath,
31 std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
32{
33 auto iterator = passIdToQuantizationInput.find(passId);
34 if (iterator == passIdToQuantizationInput.end())
35 {
36 QuantizationInput input(passId, bindingId, inputFilePath);
37 passIdToQuantizationInput.emplace(passId, input);
38 }
39 else
40 {
41 auto existingQuantizationInput = iterator->second;
42 existingQuantizationInput.AddEntry(bindingId, inputFilePath);
43 }
44}
45
46QuantizationDataSet::~QuantizationDataSet()
47{
48}
49
Finn Williamsb454c5c2021-02-09 15:56:23 +000050
51/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
52
53void InputLayerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
54 const armnn::BaseDescriptor& descriptor,
55 const std::vector<armnn::ConstTensor>& constants,
56 const char* name,
57 const armnn::LayerBindingId id)
58{
59 armnn::IgnoreUnused(name, descriptor, constants);
60
61 m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
62}
63
64
65
66
67armnn::TensorInfo InputLayerStrategy::GetTensorInfo(armnn::LayerBindingId layerBindingId)
68{
69 auto iterator = m_TensorInfos.find(layerBindingId);
70 if (iterator != m_TensorInfos.end())
71 {
72 return m_TensorInfos.at(layerBindingId);
73 }
74 else
75 {
76 throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
77 }
78}
79
Nina Drozd59e15b02019-04-25 15:45:20 +010080void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
81 armnn::LayerBindingId id,
82 const char* name)
83{
Jan Eilers8eb25602020-03-09 12:13:48 +000084 armnn::IgnoreUnused(name);
Narumol Prangnawaratcd05f3e2019-05-10 17:19:58 +010085 m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
Nina Drozd59e15b02019-04-25 15:45:20 +010086}
87
88armnn::TensorInfo InputLayerVisitor::GetTensorInfo(armnn::LayerBindingId layerBindingId)
89{
90 auto iterator = m_TensorInfos.find(layerBindingId);
91 if (iterator != m_TensorInfos.end())
92 {
93 return m_TensorInfos.at(layerBindingId);
94 }
95 else
96 {
97 throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
98 }
99}
100
101
James Ward5ea9f312020-10-29 16:19:02 +0000102unsigned int GetPassIdFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
Nina Drozd59e15b02019-04-25 15:45:20 +0100103{
104 unsigned int passId;
105 try
106 {
James Ward5ea9f312020-10-29 16:19:02 +0000107 passId = static_cast<unsigned int>(std::stoi(tokens[0]));
Nina Drozd59e15b02019-04-25 15:45:20 +0100108 }
Matthew Bentham31b2e132019-05-22 17:20:55 +0100109 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +0100110 {
James Ward5ea9f312020-10-29 16:19:02 +0000111 throw armnn::ParseException(fmt::format("Pass ID [{}] is not correct format on CSV row {}",
112 tokens[0], lineIndex));
Nina Drozd59e15b02019-04-25 15:45:20 +0100113 }
114 return passId;
115}
116
James Ward5ea9f312020-10-29 16:19:02 +0000117armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
Nina Drozd59e15b02019-04-25 15:45:20 +0100118{
119 armnn::LayerBindingId bindingId;
120 try
121 {
James Ward5ea9f312020-10-29 16:19:02 +0000122 bindingId = std::stoi(tokens[1]);
Nina Drozd59e15b02019-04-25 15:45:20 +0100123 }
Matthew Bentham31b2e132019-05-22 17:20:55 +0100124 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +0100125 {
James Ward5ea9f312020-10-29 16:19:02 +0000126 throw armnn::ParseException(fmt::format("Binding ID [{}] is not correct format on CSV row {}",
127 tokens[1], lineIndex));
Nina Drozd59e15b02019-04-25 15:45:20 +0100128 }
129 return bindingId;
130}
131
James Ward5ea9f312020-10-29 16:19:02 +0000132std::string GetFileNameFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
Nina Drozd59e15b02019-04-25 15:45:20 +0100133{
James Ward5ea9f312020-10-29 16:19:02 +0000134 std::string fileName = armnn::stringUtils::StringTrim(tokens[2]);
Nina Drozd59e15b02019-04-25 15:45:20 +0100135
Francis Murtagh532a29d2020-06-29 11:50:01 +0100136 if (!fs::exists(fileName))
Nina Drozd59e15b02019-04-25 15:45:20 +0100137 {
James Ward5ea9f312020-10-29 16:19:02 +0000138 throw armnn::ParseException(fmt::format("File [{}] provided on CSV row {} does not exist.",
139 fileName, lineIndex));
Nina Drozd59e15b02019-04-25 15:45:20 +0100140 }
141
142 if (fileName.empty())
143 {
James Ward5ea9f312020-10-29 16:19:02 +0000144 throw armnn::ParseException(fmt::format("Filename cannot be empty on CSV row {} ", lineIndex));
Nina Drozd59e15b02019-04-25 15:45:20 +0100145 }
146 return fileName;
147}
148
149
150void QuantizationDataSet::ParseCsvFile()
151{
152 std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
Nina Drozd59e15b02019-04-25 15:45:20 +0100153
154 if (m_CsvFilePath == "")
155 {
156 throw armnn::Exception("CSV file not specified.");
157 }
158
James Ward5ea9f312020-10-29 16:19:02 +0000159 std::ifstream inf (m_CsvFilePath.c_str());
160 std::string line;
161 std::vector<std::string> tokens;
162 unsigned int lineIndex = 0;
163
164 if (!inf)
Nina Drozd59e15b02019-04-25 15:45:20 +0100165 {
James Ward5ea9f312020-10-29 16:19:02 +0000166 throw armnn::Exception(fmt::format("CSV file {} not found.", m_CsvFilePath));
Nina Drozd59e15b02019-04-25 15:45:20 +0100167 }
168
James Ward5ea9f312020-10-29 16:19:02 +0000169 while (getline(inf, line))
Nina Drozd59e15b02019-04-25 15:45:20 +0100170 {
James Ward5ea9f312020-10-29 16:19:02 +0000171 tokens = armnn::stringUtils::StringTokenizer(line, ",");
172
173 if (tokens.size() != 3)
Nina Drozd59e15b02019-04-25 15:45:20 +0100174 {
James Ward5ea9f312020-10-29 16:19:02 +0000175 throw armnn::Exception(fmt::format("CSV file [{}] does not have correct number of entries" \
176 "on line {}. Expected 3 entries but was {}.",
177 m_CsvFilePath, lineIndex, tokens.size()));
178
Nina Drozd59e15b02019-04-25 15:45:20 +0100179 }
180
James Ward5ea9f312020-10-29 16:19:02 +0000181 unsigned int passId = GetPassIdFromCsvRow(tokens, lineIndex);
182 armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(tokens, lineIndex);
183 std::string rawFileName = GetFileNameFromCsvRow(tokens, lineIndex);
Nina Drozd59e15b02019-04-25 15:45:20 +0100184
185 AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
James Ward5ea9f312020-10-29 16:19:02 +0000186
187 ++lineIndex;
Nina Drozd59e15b02019-04-25 15:45:20 +0100188 }
189
190 if (passIdToQuantizationInput.empty())
191 {
192 throw armnn::Exception("Could not parse CSV file.");
193 }
194
195 // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
196 // QuantizationInputs iterator for easier access and clear the map
197 for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
198 {
199 m_QuantizationInputs.emplace_back(itr->second);
200 }
201}
202
203}