blob: d22588385483a3abd271e5e49c5cb436ad530494 [file] [log] [blame]
Nina Drozd59e15b02019-04-25 15:45:20 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "QuantizationDataSet.hpp"
7#include "CsvReader.hpp"
8
9#define BOOST_FILESYSTEM_NO_DEPRECATED
10
11#include <boost/filesystem/operations.hpp>
12#include <boost/filesystem/path.hpp>
13
14namespace armnnQuantizer
15{
16
17QuantizationDataSet::QuantizationDataSet()
18{
19}
20
21QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
22 m_QuantizationInputs(),
23 m_CsvFilePath(csvFilePath)
24{
25 ParseCsvFile();
26}
27
28void AddInputData(unsigned int passId,
29 armnn::LayerBindingId bindingId,
30 const std::string& inputFilePath,
31 std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
32{
33 auto iterator = passIdToQuantizationInput.find(passId);
34 if (iterator == passIdToQuantizationInput.end())
35 {
36 QuantizationInput input(passId, bindingId, inputFilePath);
37 passIdToQuantizationInput.emplace(passId, input);
38 }
39 else
40 {
41 auto existingQuantizationInput = iterator->second;
42 existingQuantizationInput.AddEntry(bindingId, inputFilePath);
43 }
44}
45
46QuantizationDataSet::~QuantizationDataSet()
47{
48}
49
50void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
51 armnn::LayerBindingId id,
52 const char* name)
53{
54 m_TensorInfos.emplace(id, layer->GetInputSlot(0).GetConnection()->GetTensorInfo());
55}
56
57armnn::TensorInfo InputLayerVisitor::GetTensorInfo(armnn::LayerBindingId layerBindingId)
58{
59 auto iterator = m_TensorInfos.find(layerBindingId);
60 if (iterator != m_TensorInfos.end())
61 {
62 return m_TensorInfos.at(layerBindingId);
63 }
64 else
65 {
66 throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
67 }
68}
69
70
71unsigned int GetPassIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
72{
73 unsigned int passId;
74 try
75 {
76 passId = static_cast<unsigned int>(std::stoi(csvRows[rowIndex].values[0]));
77 }
78 catch (std::invalid_argument)
79 {
80 throw armnn::ParseException("Pass ID [" + csvRows[rowIndex].values[0] + "]" +
81 " is not correct format on CSV row " + std::to_string(rowIndex));
82 }
83 return passId;
84}
85
86armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
87{
88 armnn::LayerBindingId bindingId;
89 try
90 {
91 bindingId = std::stoi(csvRows[rowIndex].values[1]);
92 }
93 catch (std::invalid_argument)
94 {
95 throw armnn::ParseException("Binding ID [" + csvRows[rowIndex].values[0] + "]" +
96 " is not correct format on CSV row " + std::to_string(rowIndex));
97 }
98 return bindingId;
99}
100
101std::string GetFileNameFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
102{
103 std::string fileName = csvRows[rowIndex].values[2];
104
105 if (!boost::filesystem::exists(fileName))
106 {
107 throw armnn::ParseException("File [ " + fileName + "] provided on CSV row " + std::to_string(rowIndex) +
108 " does not exist.");
109 }
110
111 if (fileName.empty())
112 {
113 throw armnn::ParseException("Filename cannot be empty on CSV row " + std::to_string(rowIndex));
114 }
115 return fileName;
116}
117
118
119void QuantizationDataSet::ParseCsvFile()
120{
121 std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
122 armnnUtils::CsvReader reader;
123
124 if (m_CsvFilePath == "")
125 {
126 throw armnn::Exception("CSV file not specified.");
127 }
128
129 // Parse CSV file and extract data
130 std::vector<armnnUtils::CsvRow> csvRows = reader.ParseFile(m_CsvFilePath);
131 if (csvRows.empty())
132 {
133 throw armnn::Exception("CSV file [" + m_CsvFilePath + "] is empty.");
134 }
135
136 for (unsigned int i = 0; i < csvRows.size(); ++i)
137 {
138 if (csvRows[i].values.size() != 3)
139 {
140 throw armnn::Exception("CSV file [" + m_CsvFilePath + "] does not have correct number of entries " +
141 "on line " + std::to_string(i) + ". Expected 3 entries " +
142 "but was " + std::to_string(csvRows[i].values.size()));
143 }
144
145 unsigned int passId = GetPassIdFromCsvRow(csvRows, i);
146 armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(csvRows, i);
147 std::string rawFileName = GetFileNameFromCsvRow(csvRows, i);
148
149 AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
150 }
151
152 if (passIdToQuantizationInput.empty())
153 {
154 throw armnn::Exception("Could not parse CSV file.");
155 }
156
157 // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
158 // QuantizationInputs iterator for easier access and clear the map
159 for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
160 {
161 m_QuantizationInputs.emplace_back(itr->second);
162 }
163}
164
165}