blob: 14d7fe5551b40923ad2f32a066448a122277ac33 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin83b42912022-07-07 14:24:59 +01002// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
Francis Murtaghbee4bc92019-06-18 12:30:37 +01003// SPDX-License-Identifier: MIT
4//
Francis Murtaghbee4bc92019-06-18 12:30:37 +01005
Jan Eilers45274902020-10-15 18:34:43 +01006#pragma once
7
#include <armnn/Logging.hpp>
#include <armnn/utility/StringUtils.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/BackendRegistry.hpp>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <string>
#include <type_traits>
#include <vector>
Francis Murtaghbee4bc92019-06-18 12:30:37 +010017
/// Compares a measured inference duration against a threshold and tells the user
/// whether the run succeeded.
///
/// @param duration      The measured inference duration (milliseconds).
/// @param thresholdTime The threshold time in milliseconds.
/// @return false if the measured time exceeded the threshold.
bool CheckInferenceTimeThreshold(const std::chrono::duration<double, std::milli>& duration,
                                 const double& thresholdTime);
27
28inline bool CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId>& backendIds,
29 armnn::Optional<std::string&> invalidBackendIds = armnn::EmptyOptional())
30{
31 if (backendIds.empty())
32 {
33 return false;
34 }
35
36 armnn::BackendIdSet validBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
37
38 bool allValid = true;
39 for (const auto& backendId : backendIds)
40 {
41 if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end())
42 {
43 allValid = false;
44 if (invalidBackendIds)
45 {
46 if (!invalidBackendIds.value().empty())
47 {
48 invalidBackendIds.value() += ", ";
49 }
50 invalidBackendIds.value() += backendId;
51 }
52 }
53 }
54 return allValid;
55}
Francis Murtaghbee4bc92019-06-18 12:30:37 +010056
/// Parses the contents of @p stream into a vector of unsigned integers.
/// NOTE(review): defined out of line; the accepted separators presumably match ParseArrayImpl's
/// default ("\t ,:") — confirm against the definition.
std::vector<unsigned int> ParseArray(std::istream& stream);

/// Splits @p inputString at every occurrence of @p delimiter and returns the pieces as a vector of strings.
std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter);
Francis Murtaghbee4bc92019-06-18 12:30:37 +010061
Teresa Charlin83b42912022-07-07 14:24:59 +010062/// Dequantize an array of a given type
63/// @param array Type erased array to dequantize
64/// @param numElements Elements in the array
65/// @param array Type erased array to dequantize
66template <typename T>
67std::vector<float> DequantizeArray(const void* array, unsigned int numElements, float scale, int32_t offset)
Francis Murtaghbee4bc92019-06-18 12:30:37 +010068{
Teresa Charlin83b42912022-07-07 14:24:59 +010069 const T* quantizedArray = reinterpret_cast<const T*>(array);
70 std::vector<float> dequantizedVector;
71 dequantizedVector.reserve(numElements);
72 for (unsigned int i = 0; i < numElements; ++i)
73 {
74 float f = armnn::Dequantize(*(quantizedArray + i), scale, offset);
75 dequantizedVector.push_back(f);
76 }
77 return dequantizedVector;
78}
Francis Murtaghbee4bc92019-06-18 12:30:37 +010079
/// NOTE(review): defined out of line; by its name this logs @p eMsg and then throws —
/// confirm the exception type in the definition before catching it.
void LogAndThrow(std::string eMsg);

/**
 * Verifies whether the given string is a valid path. Reports invalid paths to std::cerr.
 * @param file string - A string containing the path to check.
 * @param expectFile bool - If true, checks for a regular file.
 * @return bool - True if the given string is a valid path, false otherwise.
 * */
bool ValidatePath(const std::string& file, const bool expectFile);

/**
 * Verifies whether every string in a vector is a valid path. Reports invalid paths to std::cerr.
 * @param fileVec vector of string - A vector of strings containing the paths to check.
 * @param expectFile bool - If true, checks for regular files.
 * @return bool - True if all given strings are valid paths, false otherwise.
 * */
bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
97
Teresa Charlin83b42912022-07-07 14:24:59 +010098/// Returns a function of read the given type as a string
99template <typename Integer, typename std::enable_if_t<std::is_integral<Integer>::value>* = nullptr>
100std::function<Integer(const std::string&)> GetParseElementFunc()
101{
102 return [](const std::string& s) { return armnn::numeric_cast<Integer>(std::stoi(s)); };
103}
104
105template <typename Float, std::enable_if_t<std::is_floating_point<Float>::value>* = nullptr>
106std::function<Float(const std::string&)> GetParseElementFunc()
107{
108 return [](const std::string& s) { return std::stof(s); };
109}
110
111template <typename T>
112void PopulateTensorWithData(T* tensor,
113 const unsigned int numElements,
114 const armnn::Optional<std::string>& dataFile,
115 const std::string& inputName)
116{
117 const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
118
119 std::ifstream inputTensorFile;
120 if (!readFromFile)
121 {
122 std::fill(tensor, tensor + numElements, 0);
123 return;
124 }
125 else
126 {
127 inputTensorFile = std::ifstream(dataFile.value());
128 }
129
130 auto parseElementFunc = GetParseElementFunc<T>();
131 std::string line;
132 unsigned int index = 0;
133 while (std::getline(inputTensorFile, line))
134 {
135 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, "\t ,:");
136 for (const std::string& token : tokens)
137 {
138 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
139 {
140 try
141 {
142 if (index == numElements)
143 {
144 ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << dataFile.value()
145 << "\" does not match number of elements: " << numElements
146 << " for input \"" << inputName << "\".";
147 }
148 *(tensor + index) = parseElementFunc(token);
149 index++;
150 }
151 catch (const std::exception&)
152 {
153 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
154 }
155 }
156 }
157 }
158
159 if (index != numElements)
160 {
161 ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << inputName
162 << "\" does not match number of elements: " << numElements
163 << " for input \"" << inputName << "\".";
164 }
165}
166
167template<typename T>
168void WriteToFile(const std::string& outputTensorFileName,
169 const std::string& outputName,
170 const T* const array,
171 const unsigned int numElements)
172{
173 std::ofstream outputTensorFile;
174 outputTensorFile.open(outputTensorFileName, std::ofstream::out | std::ofstream::trunc);
175 if (outputTensorFile.is_open())
176 {
177 outputTensorFile << outputName << ": ";
178 std::copy(array, array + numElements, std::ostream_iterator<T>(outputTensorFile, " "));
179 }
180 else
181 {
182 ARMNN_LOG(info) << "Output Tensor File: " << outputTensorFileName << " could not be opened!";
183 }
184 outputTensorFile.close();
185}
186
187struct OutputWriteInfo
188{
189 const armnn::Optional<std::string>& m_OutputTensorFile;
190 const std::string& m_OutputName;
191 const armnn::Tensor& m_Tensor;
192 const bool m_PrintTensor;
193};
194
195template <typename T>
196void PrintTensor(OutputWriteInfo& info, const char* formatString)
197{
198 const T* array = reinterpret_cast<const T*>(info.m_Tensor.GetMemoryArea());
199
200 if (info.m_OutputTensorFile.has_value())
201 {
202 WriteToFile(info.m_OutputTensorFile.value(),
203 info.m_OutputName,
204 array,
205 info.m_Tensor.GetNumElements());
206 }
207
208 if (info.m_PrintTensor)
209 {
210 for (unsigned int i = 0; i < info.m_Tensor.GetNumElements(); i++)
211 {
212 printf(formatString, array[i]);
213 }
214 }
215}
216
217template <typename T>
218void PrintQuantizedTensor(OutputWriteInfo& info)
219{
220 std::vector<float> dequantizedValues;
221 auto tensor = info.m_Tensor;
222 dequantizedValues = DequantizeArray<T>(tensor.GetMemoryArea(),
223 tensor.GetNumElements(),
224 tensor.GetInfo().GetQuantizationScale(),
225 tensor.GetInfo().GetQuantizationOffset());
226
227 if (info.m_OutputTensorFile.has_value())
228 {
229 WriteToFile(info.m_OutputTensorFile.value(),
230 info.m_OutputName,
231 dequantizedValues.data(),
232 tensor.GetNumElements());
233 }
234
235 if (info.m_PrintTensor)
236 {
237 std::for_each(dequantizedValues.begin(), dequantizedValues.end(), [&](float value)
238 {
239 printf("%f ", value);
240 });
241 }
242}
243
Finn Williams56870182020-11-20 13:57:53 +0000244template<typename T, typename TParseElementFunc>
245std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
246{
247 std::vector<T> result;
248 // Processes line-by-line.
249 std::string line;
250 while (std::getline(stream, line))
251 {
252 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
253 for (const std::string& token : tokens)
254 {
255 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
256 {
257 try
258 {
259 result.push_back(parseElementFunc(token));
260 }
261 catch (const std::exception&)
262 {
263 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
264 }
265 }
266 }
267 }
268
269 return result;
270}
271
/// Computes the root-mean-square error (RMSE) between two buffers of elements of type T.
///
/// @tparam T       Element type both buffers are interpreted as.
/// @param expected Reference values.
/// @param actual   Values compared against @p expected.
/// @param size     Size of each buffer in BYTES; size / sizeof(T) elements are compared.
/// @return float   sqrt(mean((expected[i] - actual[i])^2)), or 0 when the buffers hold no
///                 complete element.
template<typename T>
float ComputeRMSE(const void* expected, const void* actual, const size_t size)
{
    auto typedExpected = reinterpret_cast<const T*>(expected);
    auto typedActual = reinterpret_cast<const T*>(actual);

    // size is a byte count: iterating up to it directly would read past the end of both buffers.
    const size_t numElements = size / sizeof(T);
    if (numElements == 0)
    {
        return 0.0f;
    }

    // Accumulate in double. A T accumulator (e.g. int8_t/uint8_t) overflows after a few elements,
    // and converting each difference to double first keeps unsigned differences from wrapping.
    double errorSum = 0.0;
    for (size_t i = 0; i < numElements; ++i)
    {
        const double diff = static_cast<double>(typedExpected[i]) - static_cast<double>(typedActual[i]);
        errorSum += diff * diff;
    }

    return static_cast<float>(std::sqrt(errorSum / static_cast<double>(numElements)));
}