blob: 7042d74d00ebb11dc395d53b473233ff95c79c24 [file] [log] [blame]
Nina Drozd59e15b02019-04-25 15:45:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "QuantizationDataSet.hpp"
7#include "CsvReader.hpp"
8
9#define BOOST_FILESYSTEM_NO_DEPRECATED
10
Jan Eilers8eb25602020-03-09 12:13:48 +000011#include <armnn/utility/IgnoreUnused.hpp>
12
Nina Drozd59e15b02019-04-25 15:45:20 +010013#include <boost/filesystem/operations.hpp>
14#include <boost/filesystem/path.hpp>
15
16namespace armnnQuantizer
17{
18
19QuantizationDataSet::QuantizationDataSet()
20{
21}
22
23QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
24 m_QuantizationInputs(),
25 m_CsvFilePath(csvFilePath)
26{
27 ParseCsvFile();
28}
29
30void AddInputData(unsigned int passId,
31 armnn::LayerBindingId bindingId,
32 const std::string& inputFilePath,
33 std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
34{
35 auto iterator = passIdToQuantizationInput.find(passId);
36 if (iterator == passIdToQuantizationInput.end())
37 {
38 QuantizationInput input(passId, bindingId, inputFilePath);
39 passIdToQuantizationInput.emplace(passId, input);
40 }
41 else
42 {
43 auto existingQuantizationInput = iterator->second;
44 existingQuantizationInput.AddEntry(bindingId, inputFilePath);
45 }
46}
47
48QuantizationDataSet::~QuantizationDataSet()
49{
50}
51
52void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
53 armnn::LayerBindingId id,
54 const char* name)
55{
Jan Eilers8eb25602020-03-09 12:13:48 +000056 armnn::IgnoreUnused(name);
Narumol Prangnawaratcd05f3e2019-05-10 17:19:58 +010057 m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
Nina Drozd59e15b02019-04-25 15:45:20 +010058}
59
60armnn::TensorInfo InputLayerVisitor::GetTensorInfo(armnn::LayerBindingId layerBindingId)
61{
62 auto iterator = m_TensorInfos.find(layerBindingId);
63 if (iterator != m_TensorInfos.end())
64 {
65 return m_TensorInfos.at(layerBindingId);
66 }
67 else
68 {
69 throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
70 }
71}
72
73
74unsigned int GetPassIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
75{
76 unsigned int passId;
77 try
78 {
79 passId = static_cast<unsigned int>(std::stoi(csvRows[rowIndex].values[0]));
80 }
Matthew Bentham31b2e132019-05-22 17:20:55 +010081 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +010082 {
83 throw armnn::ParseException("Pass ID [" + csvRows[rowIndex].values[0] + "]" +
84 " is not correct format on CSV row " + std::to_string(rowIndex));
85 }
86 return passId;
87}
88
89armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
90{
91 armnn::LayerBindingId bindingId;
92 try
93 {
94 bindingId = std::stoi(csvRows[rowIndex].values[1]);
95 }
Matthew Bentham31b2e132019-05-22 17:20:55 +010096 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +010097 {
98 throw armnn::ParseException("Binding ID [" + csvRows[rowIndex].values[0] + "]" +
99 " is not correct format on CSV row " + std::to_string(rowIndex));
100 }
101 return bindingId;
102}
103
104std::string GetFileNameFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
105{
106 std::string fileName = csvRows[rowIndex].values[2];
107
108 if (!boost::filesystem::exists(fileName))
109 {
110 throw armnn::ParseException("File [ " + fileName + "] provided on CSV row " + std::to_string(rowIndex) +
111 " does not exist.");
112 }
113
114 if (fileName.empty())
115 {
116 throw armnn::ParseException("Filename cannot be empty on CSV row " + std::to_string(rowIndex));
117 }
118 return fileName;
119}
120
121
122void QuantizationDataSet::ParseCsvFile()
123{
124 std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
125 armnnUtils::CsvReader reader;
126
127 if (m_CsvFilePath == "")
128 {
129 throw armnn::Exception("CSV file not specified.");
130 }
131
132 // Parse CSV file and extract data
133 std::vector<armnnUtils::CsvRow> csvRows = reader.ParseFile(m_CsvFilePath);
134 if (csvRows.empty())
135 {
136 throw armnn::Exception("CSV file [" + m_CsvFilePath + "] is empty.");
137 }
138
139 for (unsigned int i = 0; i < csvRows.size(); ++i)
140 {
141 if (csvRows[i].values.size() != 3)
142 {
143 throw armnn::Exception("CSV file [" + m_CsvFilePath + "] does not have correct number of entries " +
144 "on line " + std::to_string(i) + ". Expected 3 entries " +
145 "but was " + std::to_string(csvRows[i].values.size()));
146 }
147
148 unsigned int passId = GetPassIdFromCsvRow(csvRows, i);
149 armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(csvRows, i);
150 std::string rawFileName = GetFileNameFromCsvRow(csvRows, i);
151
152 AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
153 }
154
155 if (passIdToQuantizationInput.empty())
156 {
157 throw armnn::Exception("Could not parse CSV file.");
158 }
159
160 // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
161 // QuantizationInputs iterator for easier access and clear the map
162 for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
163 {
164 m_QuantizationInputs.emplace_back(itr->second);
165 }
166}
167
168}