blob: 2136c446fbd04d740c6765a606da89eea9a51ad6 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Teresa Charlin83b42912022-07-07 14:24:59 +01002// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
Francis Murtaghbee4bc92019-06-18 12:30:37 +01003// SPDX-License-Identifier: MIT
4//
Francis Murtaghbee4bc92019-06-18 12:30:37 +01005
Jan Eilers45274902020-10-15 18:34:43 +01006#pragma once
7
Colm Donelana98e79a2022-12-06 21:32:29 +00008#include <armnn/BackendRegistry.hpp> // for BackendRegistryInstance
9#include <armnn/Logging.hpp> // for ScopedRecord, ARMNN_LOG
10#include <armnn/utility/NumericCast.hpp> // for numeric_cast
11#include <armnn/utility/StringUtils.hpp> // for StringTokenizer
12#include <armnn/BackendId.hpp> // for BackendId, BackendIdSet
13#include <armnn/Optional.hpp> // for Optional, EmptyOptional
14#include <armnn/Tensor.hpp> // for Tensor, TensorInfo
15#include <armnn/TypesUtils.hpp> // for Dequantize
16#include <chrono> // for duration
17#include <functional> // for function
Finn Williams56870182020-11-20 13:57:53 +000018#include <fstream>
Teresa Charlin83b42912022-07-07 14:24:59 +010019#include <iomanip>
Colm Donelana98e79a2022-12-06 21:32:29 +000020#include <iostream> // for ofstream, basic_istream
21#include <ratio> // for milli
22#include <string> // for string, getline, basic_string
23#include <type_traits> // for enable_if_t, is_floating_point
24#include <unordered_set> // for operator!=, operator==, _No...
25#include <vector> // for vector
26#include <math.h> // for pow, sqrt
27#include <stdint.h> // for int32_t
28#include <stdio.h> // for printf, size_t
29#include <stdlib.h> // for abs
30#include <algorithm> // for find, for_each
Francis Murtaghbee4bc92019-06-18 12:30:37 +010031
/**
 * Compares a measured inference duration against a user-supplied threshold and tells the user
 * whether the run succeeded. The definition is not in this header.
 *
 * @param duration      Measured inference duration, expressed in milliseconds (std::milli ratio).
 * @param thresholdTime Threshold to compare against, in milliseconds.
 * @return false if the measured duration exceeded the threshold.
 */
bool CheckInferenceTimeThreshold(const std::chrono::duration<double, std::milli>& duration,
                                 const double& thresholdTime);
41
42inline bool CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId>& backendIds,
43 armnn::Optional<std::string&> invalidBackendIds = armnn::EmptyOptional())
44{
45 if (backendIds.empty())
46 {
47 return false;
48 }
49
50 armnn::BackendIdSet validBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
51
52 bool allValid = true;
53 for (const auto& backendId : backendIds)
54 {
55 if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end())
56 {
57 allValid = false;
58 if (invalidBackendIds)
59 {
60 if (!invalidBackendIds.value().empty())
61 {
62 invalidBackendIds.value() += ", ";
63 }
64 invalidBackendIds.value() += backendId;
65 }
66 }
67 }
68 return allValid;
69}
Francis Murtaghbee4bc92019-06-18 12:30:37 +010070
/// Parses a stream of numbers into a vector of unsigned integers.
/// The definition is not in this header; the delimiter and error-handling rules are
/// presumably those of ParseArrayImpl (TODO confirm against the definition).
std::vector<unsigned int> ParseArray(std::istream& stream);

/// Splits @p inputString at every occurrence of @p delimiter and returns the pieces as a vector of strings.
std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter);

/// Computes a root-mean-square error between two type-erased buffers at byte level.
/// @param expected Reference buffer.
/// @param actual   Buffer to compare against @p expected.
/// @param size     Size of the comparison; presumably a byte count covering both buffers (TODO confirm).
/// @return The RMSE value (definition not in this header).
double ComputeByteLevelRMSE(const void* expected, const void* actual, const size_t size);
77
Teresa Charlin83b42912022-07-07 14:24:59 +010078/// Dequantize an array of a given type
79/// @param array Type erased array to dequantize
80/// @param numElements Elements in the array
81/// @param array Type erased array to dequantize
82template <typename T>
83std::vector<float> DequantizeArray(const void* array, unsigned int numElements, float scale, int32_t offset)
Francis Murtaghbee4bc92019-06-18 12:30:37 +010084{
Teresa Charlin83b42912022-07-07 14:24:59 +010085 const T* quantizedArray = reinterpret_cast<const T*>(array);
86 std::vector<float> dequantizedVector;
87 dequantizedVector.reserve(numElements);
88 for (unsigned int i = 0; i < numElements; ++i)
89 {
90 float f = armnn::Dequantize(*(quantizedArray + i), scale, offset);
91 dequantizedVector.push_back(f);
92 }
93 return dequantizedVector;
94}
Francis Murtaghbee4bc92019-06-18 12:30:37 +010095
/// Reports an error message. Judging by its name, it logs @p eMsg and then throws; the definition
/// is not in this header, so confirm the exception type there.
void LogAndThrow(std::string eMsg);

/**
 * Verifies that the given string is a valid path. Invalid paths are reported as errors
 * (presumably on std::cerr or via ARMNN_LOG; confirm in the definition).
 * @param file string - A string containing the path to check.
 * @param expectFile bool - If true, also checks that the path is a regular file.
 * @return bool - True if the given string is a valid path, false otherwise.
 * */
bool ValidatePath(const std::string& file, const bool expectFile);

/**
 * Verifies that every string in the given vector is a valid path. Invalid paths are reported as
 * errors (presumably on std::cerr or via ARMNN_LOG; confirm in the definition).
 * @param fileVec vector of string - A vector of strings containing the paths to check.
 * @param expectFile bool - If true, also checks that each path is a regular file.
 * @return bool - True if all given strings are valid paths, false otherwise.
 * */
bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
113
Teresa Charlin83b42912022-07-07 14:24:59 +0100114/// Returns a function of read the given type as a string
115template <typename Integer, typename std::enable_if_t<std::is_integral<Integer>::value>* = nullptr>
116std::function<Integer(const std::string&)> GetParseElementFunc()
117{
118 return [](const std::string& s) { return armnn::numeric_cast<Integer>(std::stoi(s)); };
119}
120
/// Returns a parser that converts a decimal string into the floating-point type @p Float.
/// The std::sto* function matching @p Float is used:
/// - std::stof for float,
/// - std::stod for double,
/// - std::stold otherwise.
/// Wider types therefore keep their full precision and range. The parser throws
/// std::invalid_argument or std::out_of_range for malformed or out-of-range text.
template <typename Float, std::enable_if_t<std::is_floating_point<Float>::value>* = nullptr>
std::function<Float(const std::string&)> GetParseElementFunc()
{
    return [](const std::string& s) -> Float
    {
        // Parsing double/long double through std::stof would silently round them to float precision.
        if (std::is_same<Float, float>::value)
        {
            return std::stof(s);
        }
        if (std::is_same<Float, double>::value)
        {
            return static_cast<Float>(std::stod(s));
        }
        return static_cast<Float>(std::stold(s));
    };
}
126
127template <typename T>
128void PopulateTensorWithData(T* tensor,
129 const unsigned int numElements,
130 const armnn::Optional<std::string>& dataFile,
131 const std::string& inputName)
132{
133 const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
134
135 std::ifstream inputTensorFile;
136 if (!readFromFile)
137 {
138 std::fill(tensor, tensor + numElements, 0);
139 return;
140 }
141 else
142 {
143 inputTensorFile = std::ifstream(dataFile.value());
144 }
145
146 auto parseElementFunc = GetParseElementFunc<T>();
147 std::string line;
148 unsigned int index = 0;
149 while (std::getline(inputTensorFile, line))
150 {
151 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, "\t ,:");
152 for (const std::string& token : tokens)
153 {
154 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
155 {
156 try
157 {
158 if (index == numElements)
159 {
160 ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << dataFile.value()
161 << "\" does not match number of elements: " << numElements
162 << " for input \"" << inputName << "\".";
163 }
164 *(tensor + index) = parseElementFunc(token);
165 index++;
166 }
167 catch (const std::exception&)
168 {
169 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
170 }
171 }
172 }
173 }
174
175 if (index != numElements)
176 {
177 ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << inputName
178 << "\" does not match number of elements: " << numElements
179 << " for input \"" << inputName << "\".";
180 }
181}
182
183template<typename T>
184void WriteToFile(const std::string& outputTensorFileName,
185 const std::string& outputName,
186 const T* const array,
187 const unsigned int numElements)
188{
189 std::ofstream outputTensorFile;
190 outputTensorFile.open(outputTensorFileName, std::ofstream::out | std::ofstream::trunc);
191 if (outputTensorFile.is_open())
192 {
193 outputTensorFile << outputName << ": ";
Adam Jalkemo7bbf5652022-10-18 16:56:09 +0200194 for (std::size_t i = 0; i < numElements; ++i)
195 {
196 outputTensorFile << +array[i] << " ";
197 }
Teresa Charlin83b42912022-07-07 14:24:59 +0100198 }
199 else
200 {
201 ARMNN_LOG(info) << "Output Tensor File: " << outputTensorFileName << " could not be opened!";
202 }
203 outputTensorFile.close();
204}
205
206struct OutputWriteInfo
207{
208 const armnn::Optional<std::string>& m_OutputTensorFile;
209 const std::string& m_OutputName;
210 const armnn::Tensor& m_Tensor;
211 const bool m_PrintTensor;
212};
213
214template <typename T>
215void PrintTensor(OutputWriteInfo& info, const char* formatString)
216{
217 const T* array = reinterpret_cast<const T*>(info.m_Tensor.GetMemoryArea());
218
219 if (info.m_OutputTensorFile.has_value())
220 {
221 WriteToFile(info.m_OutputTensorFile.value(),
222 info.m_OutputName,
223 array,
224 info.m_Tensor.GetNumElements());
225 }
226
227 if (info.m_PrintTensor)
228 {
229 for (unsigned int i = 0; i < info.m_Tensor.GetNumElements(); i++)
230 {
231 printf(formatString, array[i]);
232 }
233 }
234}
235
236template <typename T>
237void PrintQuantizedTensor(OutputWriteInfo& info)
238{
239 std::vector<float> dequantizedValues;
240 auto tensor = info.m_Tensor;
241 dequantizedValues = DequantizeArray<T>(tensor.GetMemoryArea(),
242 tensor.GetNumElements(),
243 tensor.GetInfo().GetQuantizationScale(),
244 tensor.GetInfo().GetQuantizationOffset());
245
246 if (info.m_OutputTensorFile.has_value())
247 {
248 WriteToFile(info.m_OutputTensorFile.value(),
249 info.m_OutputName,
250 dequantizedValues.data(),
251 tensor.GetNumElements());
252 }
253
254 if (info.m_PrintTensor)
255 {
256 std::for_each(dequantizedValues.begin(), dequantizedValues.end(), [&](float value)
257 {
258 printf("%f ", value);
259 });
260 }
261}
262
Finn Williams56870182020-11-20 13:57:53 +0000263template<typename T, typename TParseElementFunc>
264std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
265{
266 std::vector<T> result;
267 // Processes line-by-line.
268 std::string line;
269 while (std::getline(stream, line))
270 {
271 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
272 for (const std::string& token : tokens)
273 {
274 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
275 {
276 try
277 {
278 result.push_back(parseElementFunc(token));
279 }
280 catch (const std::exception&)
281 {
282 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
283 }
284 }
285 }
286 }
287
288 return result;
289}