blob: acd301a4705314ee3556d2eb54c9603eb338ac78 [file] [log] [blame]
Nina Drozd59e15b02019-04-25 15:45:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "QuantizationDataSet.hpp"

#include <fstream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <fmt/format.h>

#include <armnn/utility/StringUtils.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <Filesystem.hpp>
Nina Drozd59e15b02019-04-25 15:45:20 +010013
14namespace armnnQuantizer
15{
16
17QuantizationDataSet::QuantizationDataSet()
18{
19}
20
21QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
22 m_QuantizationInputs(),
23 m_CsvFilePath(csvFilePath)
24{
25 ParseCsvFile();
26}
27
28void AddInputData(unsigned int passId,
29 armnn::LayerBindingId bindingId,
30 const std::string& inputFilePath,
31 std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
32{
33 auto iterator = passIdToQuantizationInput.find(passId);
34 if (iterator == passIdToQuantizationInput.end())
35 {
36 QuantizationInput input(passId, bindingId, inputFilePath);
37 passIdToQuantizationInput.emplace(passId, input);
38 }
39 else
40 {
41 auto existingQuantizationInput = iterator->second;
42 existingQuantizationInput.AddEntry(bindingId, inputFilePath);
43 }
44}
45
46QuantizationDataSet::~QuantizationDataSet()
47{
48}
49
50void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
51 armnn::LayerBindingId id,
52 const char* name)
53{
Jan Eilers8eb25602020-03-09 12:13:48 +000054 armnn::IgnoreUnused(name);
Narumol Prangnawaratcd05f3e2019-05-10 17:19:58 +010055 m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
Nina Drozd59e15b02019-04-25 15:45:20 +010056}
57
58armnn::TensorInfo InputLayerVisitor::GetTensorInfo(armnn::LayerBindingId layerBindingId)
59{
60 auto iterator = m_TensorInfos.find(layerBindingId);
61 if (iterator != m_TensorInfos.end())
62 {
63 return m_TensorInfos.at(layerBindingId);
64 }
65 else
66 {
67 throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
68 }
69}
70
71
James Ward5ea9f312020-10-29 16:19:02 +000072unsigned int GetPassIdFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
Nina Drozd59e15b02019-04-25 15:45:20 +010073{
74 unsigned int passId;
75 try
76 {
James Ward5ea9f312020-10-29 16:19:02 +000077 passId = static_cast<unsigned int>(std::stoi(tokens[0]));
Nina Drozd59e15b02019-04-25 15:45:20 +010078 }
Matthew Bentham31b2e132019-05-22 17:20:55 +010079 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +010080 {
James Ward5ea9f312020-10-29 16:19:02 +000081 throw armnn::ParseException(fmt::format("Pass ID [{}] is not correct format on CSV row {}",
82 tokens[0], lineIndex));
Nina Drozd59e15b02019-04-25 15:45:20 +010083 }
84 return passId;
85}
86
James Ward5ea9f312020-10-29 16:19:02 +000087armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
Nina Drozd59e15b02019-04-25 15:45:20 +010088{
89 armnn::LayerBindingId bindingId;
90 try
91 {
James Ward5ea9f312020-10-29 16:19:02 +000092 bindingId = std::stoi(tokens[1]);
Nina Drozd59e15b02019-04-25 15:45:20 +010093 }
Matthew Bentham31b2e132019-05-22 17:20:55 +010094 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +010095 {
James Ward5ea9f312020-10-29 16:19:02 +000096 throw armnn::ParseException(fmt::format("Binding ID [{}] is not correct format on CSV row {}",
97 tokens[1], lineIndex));
Nina Drozd59e15b02019-04-25 15:45:20 +010098 }
99 return bindingId;
100}
101
James Ward5ea9f312020-10-29 16:19:02 +0000102std::string GetFileNameFromCsvRow(std::vector<std::string> tokens, unsigned int lineIndex)
Nina Drozd59e15b02019-04-25 15:45:20 +0100103{
James Ward5ea9f312020-10-29 16:19:02 +0000104 std::string fileName = armnn::stringUtils::StringTrim(tokens[2]);
Nina Drozd59e15b02019-04-25 15:45:20 +0100105
Francis Murtagh532a29d2020-06-29 11:50:01 +0100106 if (!fs::exists(fileName))
Nina Drozd59e15b02019-04-25 15:45:20 +0100107 {
James Ward5ea9f312020-10-29 16:19:02 +0000108 throw armnn::ParseException(fmt::format("File [{}] provided on CSV row {} does not exist.",
109 fileName, lineIndex));
Nina Drozd59e15b02019-04-25 15:45:20 +0100110 }
111
112 if (fileName.empty())
113 {
James Ward5ea9f312020-10-29 16:19:02 +0000114 throw armnn::ParseException(fmt::format("Filename cannot be empty on CSV row {} ", lineIndex));
Nina Drozd59e15b02019-04-25 15:45:20 +0100115 }
116 return fileName;
117}
118
119
120void QuantizationDataSet::ParseCsvFile()
121{
122 std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
Nina Drozd59e15b02019-04-25 15:45:20 +0100123
124 if (m_CsvFilePath == "")
125 {
126 throw armnn::Exception("CSV file not specified.");
127 }
128
James Ward5ea9f312020-10-29 16:19:02 +0000129 std::ifstream inf (m_CsvFilePath.c_str());
130 std::string line;
131 std::vector<std::string> tokens;
132 unsigned int lineIndex = 0;
133
134 if (!inf)
Nina Drozd59e15b02019-04-25 15:45:20 +0100135 {
James Ward5ea9f312020-10-29 16:19:02 +0000136 throw armnn::Exception(fmt::format("CSV file {} not found.", m_CsvFilePath));
Nina Drozd59e15b02019-04-25 15:45:20 +0100137 }
138
James Ward5ea9f312020-10-29 16:19:02 +0000139 while (getline(inf, line))
Nina Drozd59e15b02019-04-25 15:45:20 +0100140 {
James Ward5ea9f312020-10-29 16:19:02 +0000141 tokens = armnn::stringUtils::StringTokenizer(line, ",");
142
143 if (tokens.size() != 3)
Nina Drozd59e15b02019-04-25 15:45:20 +0100144 {
James Ward5ea9f312020-10-29 16:19:02 +0000145 throw armnn::Exception(fmt::format("CSV file [{}] does not have correct number of entries" \
146 "on line {}. Expected 3 entries but was {}.",
147 m_CsvFilePath, lineIndex, tokens.size()));
148
Nina Drozd59e15b02019-04-25 15:45:20 +0100149 }
150
James Ward5ea9f312020-10-29 16:19:02 +0000151 unsigned int passId = GetPassIdFromCsvRow(tokens, lineIndex);
152 armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(tokens, lineIndex);
153 std::string rawFileName = GetFileNameFromCsvRow(tokens, lineIndex);
Nina Drozd59e15b02019-04-25 15:45:20 +0100154
155 AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
James Ward5ea9f312020-10-29 16:19:02 +0000156
157 ++lineIndex;
Nina Drozd59e15b02019-04-25 15:45:20 +0100158 }
159
160 if (passIdToQuantizationInput.empty())
161 {
162 throw armnn::Exception("Could not parse CSV file.");
163 }
164
165 // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
166 // QuantizationInputs iterator for easier access and clear the map
167 for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
168 {
169 m_QuantizationInputs.emplace_back(itr->second);
170 }
171}
172
173}