blob: 9694342fe827044d5065e7816ab5c5c6b259262a [file] [log] [blame]
Nina Drozd59e15b02019-04-25 15:45:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "QuantizationDataSet.hpp"
7#include "CsvReader.hpp"
8
9#define BOOST_FILESYSTEM_NO_DEPRECATED
10
Derek Lamberti859f9ce2019-12-10 22:05:21 +000011#include <boost/core/ignore_unused.hpp>
Nina Drozd59e15b02019-04-25 15:45:20 +010012#include <boost/filesystem/operations.hpp>
13#include <boost/filesystem/path.hpp>
14
15namespace armnnQuantizer
16{
17
18QuantizationDataSet::QuantizationDataSet()
19{
20}
21
22QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
23 m_QuantizationInputs(),
24 m_CsvFilePath(csvFilePath)
25{
26 ParseCsvFile();
27}
28
29void AddInputData(unsigned int passId,
30 armnn::LayerBindingId bindingId,
31 const std::string& inputFilePath,
32 std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
33{
34 auto iterator = passIdToQuantizationInput.find(passId);
35 if (iterator == passIdToQuantizationInput.end())
36 {
37 QuantizationInput input(passId, bindingId, inputFilePath);
38 passIdToQuantizationInput.emplace(passId, input);
39 }
40 else
41 {
42 auto existingQuantizationInput = iterator->second;
43 existingQuantizationInput.AddEntry(bindingId, inputFilePath);
44 }
45}
46
47QuantizationDataSet::~QuantizationDataSet()
48{
49}
50
51void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
52 armnn::LayerBindingId id,
53 const char* name)
54{
Derek Lamberti859f9ce2019-12-10 22:05:21 +000055 boost::ignore_unused(name);
Narumol Prangnawaratcd05f3e2019-05-10 17:19:58 +010056 m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
Nina Drozd59e15b02019-04-25 15:45:20 +010057}
58
59armnn::TensorInfo InputLayerVisitor::GetTensorInfo(armnn::LayerBindingId layerBindingId)
60{
61 auto iterator = m_TensorInfos.find(layerBindingId);
62 if (iterator != m_TensorInfos.end())
63 {
64 return m_TensorInfos.at(layerBindingId);
65 }
66 else
67 {
68 throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
69 }
70}
71
72
73unsigned int GetPassIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
74{
75 unsigned int passId;
76 try
77 {
78 passId = static_cast<unsigned int>(std::stoi(csvRows[rowIndex].values[0]));
79 }
Matthew Bentham31b2e132019-05-22 17:20:55 +010080 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +010081 {
82 throw armnn::ParseException("Pass ID [" + csvRows[rowIndex].values[0] + "]" +
83 " is not correct format on CSV row " + std::to_string(rowIndex));
84 }
85 return passId;
86}
87
88armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
89{
90 armnn::LayerBindingId bindingId;
91 try
92 {
93 bindingId = std::stoi(csvRows[rowIndex].values[1]);
94 }
Matthew Bentham31b2e132019-05-22 17:20:55 +010095 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +010096 {
97 throw armnn::ParseException("Binding ID [" + csvRows[rowIndex].values[0] + "]" +
98 " is not correct format on CSV row " + std::to_string(rowIndex));
99 }
100 return bindingId;
101}
102
103std::string GetFileNameFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
104{
105 std::string fileName = csvRows[rowIndex].values[2];
106
107 if (!boost::filesystem::exists(fileName))
108 {
109 throw armnn::ParseException("File [ " + fileName + "] provided on CSV row " + std::to_string(rowIndex) +
110 " does not exist.");
111 }
112
113 if (fileName.empty())
114 {
115 throw armnn::ParseException("Filename cannot be empty on CSV row " + std::to_string(rowIndex));
116 }
117 return fileName;
118}
119
120
121void QuantizationDataSet::ParseCsvFile()
122{
123 std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
124 armnnUtils::CsvReader reader;
125
126 if (m_CsvFilePath == "")
127 {
128 throw armnn::Exception("CSV file not specified.");
129 }
130
131 // Parse CSV file and extract data
132 std::vector<armnnUtils::CsvRow> csvRows = reader.ParseFile(m_CsvFilePath);
133 if (csvRows.empty())
134 {
135 throw armnn::Exception("CSV file [" + m_CsvFilePath + "] is empty.");
136 }
137
138 for (unsigned int i = 0; i < csvRows.size(); ++i)
139 {
140 if (csvRows[i].values.size() != 3)
141 {
142 throw armnn::Exception("CSV file [" + m_CsvFilePath + "] does not have correct number of entries " +
143 "on line " + std::to_string(i) + ". Expected 3 entries " +
144 "but was " + std::to_string(csvRows[i].values.size()));
145 }
146
147 unsigned int passId = GetPassIdFromCsvRow(csvRows, i);
148 armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(csvRows, i);
149 std::string rawFileName = GetFileNameFromCsvRow(csvRows, i);
150
151 AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
152 }
153
154 if (passIdToQuantizationInput.empty())
155 {
156 throw armnn::Exception("Could not parse CSV file.");
157 }
158
159 // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
160 // QuantizationInputs iterator for easier access and clear the map
161 for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
162 {
163 m_QuantizationInputs.emplace_back(itr->second);
164 }
165}
166
167}