blob: 3e7c87d65311d91e3526fcdb69eb5585a43ba75f [file] [log] [blame]
Jan Eilers45274902020-10-15 18:34:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "NetworkExecutionUtils.hpp"

#include <limits>
#include <stdexcept>

#include <Filesystem.hpp>
#include <InferenceTest.hpp>
#include <ResolveType.hpp>

#if defined(ARMNN_SERIALIZER)
#include "armnnDeserializer/IDeserializer.hpp"
#endif
#if defined(ARMNN_CAFFE_PARSER)
#include "armnnCaffeParser/ICaffeParser.hpp"
#endif
#if defined(ARMNN_TF_PARSER)
#include "armnnTfParser/ITfParser.hpp"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#endif
#if defined(ARMNN_ONNX_PARSER)
#include "armnnOnnxParser/IOnnxParser.hpp"
#endif
27
28
29template<typename T, typename TParseElementFunc>
30std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
31{
32 std::vector<T> result;
33 // Processes line-by-line.
34 std::string line;
35 while (std::getline(stream, line))
36 {
37 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
38 for (const std::string& token : tokens)
39 {
40 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
41 {
42 try
43 {
44 result.push_back(parseElementFunc(token));
45 }
46 catch (const std::exception&)
47 {
48 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
49 }
50 }
51 }
52 }
53
54 return result;
55}
56
57
58template<armnn::DataType NonQuantizedType>
59auto ParseDataArray(std::istream& stream);
60
61template<armnn::DataType QuantizedType>
62auto ParseDataArray(std::istream& stream,
63 const float& quantizationScale,
64 const int32_t& quantizationOffset);
65
66template<>
67auto ParseDataArray<armnn::DataType::Float32>(std::istream& stream)
68{
69 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
70}
71
72template<>
73auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream)
74{
75 return ParseArrayImpl<int>(stream, [](const std::string& s) { return std::stoi(s); });
76}
77
78template<>
79auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream)
80{
81 return ParseArrayImpl<uint8_t>(stream,
82 [](const std::string& s) { return armnn::numeric_cast<uint8_t>(std::stoi(s)); });
83}
84
85template<>
86auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
87 const float& quantizationScale,
88 const int32_t& quantizationOffset)
89{
90 return ParseArrayImpl<uint8_t>(stream,
91 [&quantizationScale, &quantizationOffset](const std::string& s)
92 {
93 return armnn::numeric_cast<uint8_t>(
94 armnn::Quantize<uint8_t>(std::stof(s),
95 quantizationScale,
96 quantizationOffset));
97 });
98}
99
100template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
101std::vector<T> GenerateDummyTensorData(unsigned int numElements)
102{
103 return std::vector<T>(numElements, static_cast<T>(0));
104}
105
106
107std::vector<unsigned int> ParseArray(std::istream& stream)
108{
109 return ParseArrayImpl<unsigned int>(
110 stream,
111 [](const std::string& s) { return armnn::numeric_cast<unsigned int>(std::stoi(s)); });
112}
113
114std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter)
115{
116 std::stringstream stream(inputString);
117 return ParseArrayImpl<std::string>(stream, [](const std::string& s) {
118 return armnn::stringUtils::StringTrimCopy(s); }, delimiter);
119}
120
121
122TensorPrinter::TensorPrinter(const std::string& binding,
123 const armnn::TensorInfo& info,
124 const std::string& outputTensorFile,
125 bool dequantizeOutput)
126 : m_OutputBinding(binding)
127 , m_Scale(info.GetQuantizationScale())
128 , m_Offset(info.GetQuantizationOffset())
129 , m_OutputTensorFile(outputTensorFile)
130 , m_DequantizeOutput(dequantizeOutput) {}
131
132void TensorPrinter::operator()(const std::vector<float>& values)
133{
134 ForEachValue(values, [](float value)
135 {
136 printf("%f ", value);
137 });
138 WriteToFile(values);
139}
140
141void TensorPrinter::operator()(const std::vector<uint8_t>& values)
142{
143 if(m_DequantizeOutput)
144 {
145 auto& scale = m_Scale;
146 auto& offset = m_Offset;
147 std::vector<float> dequantizedValues;
148 ForEachValue(values, [&scale, &offset, &dequantizedValues](uint8_t value)
149 {
150 auto dequantizedValue = armnn::Dequantize(value, scale, offset);
151 printf("%f ", dequantizedValue);
152 dequantizedValues.push_back(dequantizedValue);
153 });
154 WriteToFile(dequantizedValues);
155 }
156 else
157 {
158 const std::vector<int> intValues(values.begin(), values.end());
159 operator()(intValues);
160 }
161}
162
163void TensorPrinter::operator()(const std::vector<int>& values)
164{
165 ForEachValue(values, [](int value)
166 {
167 printf("%d ", value);
168 });
169 WriteToFile(values);
170}
171
172template<typename Container, typename Delegate>
173void TensorPrinter::ForEachValue(const Container& c, Delegate delegate)
174{
175 std::cout << m_OutputBinding << ": ";
176 for (const auto& value : c)
177 {
178 delegate(value);
179 }
180 printf("\n");
181}
182
183template<typename T>
184void TensorPrinter::WriteToFile(const std::vector<T>& values)
185{
186 if (!m_OutputTensorFile.empty())
187 {
188 std::ofstream outputTensorFile;
189 outputTensorFile.open(m_OutputTensorFile, std::ofstream::out | std::ofstream::trunc);
190 if (outputTensorFile.is_open())
191 {
192 outputTensorFile << m_OutputBinding << ": ";
193 std::copy(values.begin(), values.end(), std::ostream_iterator<T>(outputTensorFile, " "));
194 }
195 else
196 {
197 ARMNN_LOG(info) << "Output Tensor File: " << m_OutputTensorFile << " could not be opened!";
198 }
199 outputTensorFile.close();
200 }
201}
202
203using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
204using QuantizationParams = std::pair<float, int32_t>;
205
206void PopulateTensorWithData(TContainer& tensorData,
207 unsigned int numElements,
208 const std::string& dataTypeStr,
209 const armnn::Optional<QuantizationParams>& qParams,
210 const armnn::Optional<std::string>& dataFile)
211{
212 const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
213 const bool quantizeData = qParams.has_value();
214
215 std::ifstream inputTensorFile;
216 if (readFromFile)
217 {
218 inputTensorFile = std::ifstream(dataFile.value());
219 }
220
221 if (dataTypeStr.compare("float") == 0)
222 {
223 if (quantizeData)
224 {
225 const float qScale = qParams.value().first;
226 const int qOffset = qParams.value().second;
227
228 tensorData = readFromFile ?
229 ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile, qScale, qOffset) :
230 GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
231 }
232 else
233 {
234 tensorData = readFromFile ?
235 ParseDataArray<armnn::DataType::Float32>(inputTensorFile) :
236 GenerateDummyTensorData<armnn::DataType::Float32>(numElements);
237 }
238 }
239 else if (dataTypeStr.compare("int") == 0)
240 {
241 tensorData = readFromFile ?
242 ParseDataArray<armnn::DataType::Signed32>(inputTensorFile) :
243 GenerateDummyTensorData<armnn::DataType::Signed32>(numElements);
244 }
245 else if (dataTypeStr.compare("qasymm8") == 0)
246 {
247 tensorData = readFromFile ?
248 ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) :
249 GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
250 }
251 else
252 {
253 std::string errorMessage = "Unsupported tensor data type " + dataTypeStr;
254 ARMNN_LOG(fatal) << errorMessage;
255
256 inputTensorFile.close();
257 throw armnn::Exception(errorMessage);
258 }
259
260 inputTensorFile.close();
261}
262
263bool ValidatePath(const std::string& file, const bool expectFile)
264{
265 if (!fs::exists(file))
266 {
267 std::cerr << "Given file path '" << file << "' does not exist" << std::endl;
268 return false;
269 }
270 if (!fs::is_regular_file(file) && expectFile)
271 {
272 std::cerr << "Given file path '" << file << "' is not a regular file" << std::endl;
273 return false;
274 }
275 return true;
276}
277
278bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile)
279{
280 bool allPathsValid = true;
281 for (auto const& file : fileVec)
282 {
283 if(!ValidatePath(file, expectFile))
284 {
285 allPathsValid = false;
286 }
287 }
288 return allPathsValid;
289}
290
291
292