//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "QuantizationDataSet.hpp"
#include "CsvReader.hpp"

#include <armnn/utility/IgnoreUnused.hpp>
#include <Filesystem.hpp>

#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

12namespace armnnQuantizer
13{
14
15QuantizationDataSet::QuantizationDataSet()
16{
17}
18
19QuantizationDataSet::QuantizationDataSet(const std::string csvFilePath):
20 m_QuantizationInputs(),
21 m_CsvFilePath(csvFilePath)
22{
23 ParseCsvFile();
24}
25
26void AddInputData(unsigned int passId,
27 armnn::LayerBindingId bindingId,
28 const std::string& inputFilePath,
29 std::map<unsigned int, QuantizationInput>& passIdToQuantizationInput)
30{
31 auto iterator = passIdToQuantizationInput.find(passId);
32 if (iterator == passIdToQuantizationInput.end())
33 {
34 QuantizationInput input(passId, bindingId, inputFilePath);
35 passIdToQuantizationInput.emplace(passId, input);
36 }
37 else
38 {
39 auto existingQuantizationInput = iterator->second;
40 existingQuantizationInput.AddEntry(bindingId, inputFilePath);
41 }
42}
43
44QuantizationDataSet::~QuantizationDataSet()
45{
46}
47
48void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
49 armnn::LayerBindingId id,
50 const char* name)
51{
Jan Eilers8eb25602020-03-09 12:13:48 +000052 armnn::IgnoreUnused(name);
Narumol Prangnawaratcd05f3e2019-05-10 17:19:58 +010053 m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
Nina Drozd59e15b02019-04-25 15:45:20 +010054}
55
56armnn::TensorInfo InputLayerVisitor::GetTensorInfo(armnn::LayerBindingId layerBindingId)
57{
58 auto iterator = m_TensorInfos.find(layerBindingId);
59 if (iterator != m_TensorInfos.end())
60 {
61 return m_TensorInfos.at(layerBindingId);
62 }
63 else
64 {
65 throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
66 }
67}
68
69
70unsigned int GetPassIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
71{
72 unsigned int passId;
73 try
74 {
75 passId = static_cast<unsigned int>(std::stoi(csvRows[rowIndex].values[0]));
76 }
Matthew Bentham31b2e132019-05-22 17:20:55 +010077 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +010078 {
79 throw armnn::ParseException("Pass ID [" + csvRows[rowIndex].values[0] + "]" +
80 " is not correct format on CSV row " + std::to_string(rowIndex));
81 }
82 return passId;
83}
84
85armnn::LayerBindingId GetBindingIdFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
86{
87 armnn::LayerBindingId bindingId;
88 try
89 {
90 bindingId = std::stoi(csvRows[rowIndex].values[1]);
91 }
Matthew Bentham31b2e132019-05-22 17:20:55 +010092 catch (const std::invalid_argument&)
Nina Drozd59e15b02019-04-25 15:45:20 +010093 {
94 throw armnn::ParseException("Binding ID [" + csvRows[rowIndex].values[0] + "]" +
95 " is not correct format on CSV row " + std::to_string(rowIndex));
96 }
97 return bindingId;
98}
99
100std::string GetFileNameFromCsvRow(std::vector<armnnUtils::CsvRow> csvRows, unsigned int rowIndex)
101{
102 std::string fileName = csvRows[rowIndex].values[2];
103
Francis Murtagh532a29d2020-06-29 11:50:01 +0100104 if (!fs::exists(fileName))
Nina Drozd59e15b02019-04-25 15:45:20 +0100105 {
106 throw armnn::ParseException("File [ " + fileName + "] provided on CSV row " + std::to_string(rowIndex) +
107 " does not exist.");
108 }
109
110 if (fileName.empty())
111 {
112 throw armnn::ParseException("Filename cannot be empty on CSV row " + std::to_string(rowIndex));
113 }
114 return fileName;
115}
116
117
118void QuantizationDataSet::ParseCsvFile()
119{
120 std::map<unsigned int, QuantizationInput> passIdToQuantizationInput;
121 armnnUtils::CsvReader reader;
122
123 if (m_CsvFilePath == "")
124 {
125 throw armnn::Exception("CSV file not specified.");
126 }
127
128 // Parse CSV file and extract data
129 std::vector<armnnUtils::CsvRow> csvRows = reader.ParseFile(m_CsvFilePath);
130 if (csvRows.empty())
131 {
132 throw armnn::Exception("CSV file [" + m_CsvFilePath + "] is empty.");
133 }
134
135 for (unsigned int i = 0; i < csvRows.size(); ++i)
136 {
137 if (csvRows[i].values.size() != 3)
138 {
139 throw armnn::Exception("CSV file [" + m_CsvFilePath + "] does not have correct number of entries " +
140 "on line " + std::to_string(i) + ". Expected 3 entries " +
141 "but was " + std::to_string(csvRows[i].values.size()));
142 }
143
144 unsigned int passId = GetPassIdFromCsvRow(csvRows, i);
145 armnn::LayerBindingId bindingId = GetBindingIdFromCsvRow(csvRows, i);
146 std::string rawFileName = GetFileNameFromCsvRow(csvRows, i);
147
148 AddInputData(passId, bindingId, rawFileName, passIdToQuantizationInput);
149 }
150
151 if (passIdToQuantizationInput.empty())
152 {
153 throw armnn::Exception("Could not parse CSV file.");
154 }
155
156 // Once all entries in CSV file are parsed successfully and QuantizationInput map is populated, populate
157 // QuantizationInputs iterator for easier access and clear the map
158 for (auto itr = passIdToQuantizationInput.begin(); itr != passIdToQuantizationInput.end(); ++itr)
159 {
160 m_QuantizationInputs.emplace_back(itr->second);
161 }
162}
163
164}