blob: 0df3bf5ef5601388cbb0115938e011fc5f1e295c [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
Colm Donelan0dfb2652023-06-22 10:19:17 +01002// Copyright © 2022, 2023 Arm Ltd and Contributors. All rights reserved.
Francis Murtaghbee4bc92019-06-18 12:30:37 +01003// SPDX-License-Identifier: MIT
4//
Francis Murtaghbee4bc92019-06-18 12:30:37 +01005
Jan Eilers45274902020-10-15 18:34:43 +01006#pragma once
7
Colm Donelana98e79a2022-12-06 21:32:29 +00008#include <armnn/BackendRegistry.hpp> // for BackendRegistryInstance
9#include <armnn/Logging.hpp> // for ScopedRecord, ARMNN_LOG
10#include <armnn/utility/NumericCast.hpp> // for numeric_cast
11#include <armnn/utility/StringUtils.hpp> // for StringTokenizer
12#include <armnn/BackendId.hpp> // for BackendId, BackendIdSet
13#include <armnn/Optional.hpp> // for Optional, EmptyOptional
14#include <armnn/Tensor.hpp> // for Tensor, TensorInfo
15#include <armnn/TypesUtils.hpp> // for Dequantize
16#include <chrono> // for duration
17#include <functional> // for function
Finn Williams56870182020-11-20 13:57:53 +000018#include <fstream>
Teresa Charlin83b42912022-07-07 14:24:59 +010019#include <iomanip>
Colm Donelana98e79a2022-12-06 21:32:29 +000020#include <iostream> // for ofstream, basic_istream
21#include <ratio> // for milli
22#include <string> // for string, getline, basic_string
23#include <type_traits> // for enable_if_t, is_floating_point
24#include <unordered_set> // for operator!=, operator==, _No...
25#include <vector> // for vector
26#include <math.h> // for pow, sqrt
27#include <stdint.h> // for int32_t
28#include <stdio.h> // for printf, size_t
29#include <stdlib.h> // for abs
30#include <algorithm> // for find, for_each
Francis Murtaghbee4bc92019-06-18 12:30:37 +010031
Teresa Charlin83b42912022-07-07 14:24:59 +010032/**
33 * Given a measured duration and a threshold time tell the user whether we succeeded or not.
34 *
35 * @param duration the measured inference duration.
36 * @param thresholdTime the threshold time in milliseconds.
37 * @return false if the measured time exceeded the threshold.
38 */
39bool CheckInferenceTimeThreshold(const std::chrono::duration<double, std::milli>& duration,
40 const double& thresholdTime);
41
42inline bool CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId>& backendIds,
43 armnn::Optional<std::string&> invalidBackendIds = armnn::EmptyOptional())
44{
45 if (backendIds.empty())
46 {
47 return false;
48 }
49
50 armnn::BackendIdSet validBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
51
52 bool allValid = true;
53 for (const auto& backendId : backendIds)
54 {
55 if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end())
56 {
57 allValid = false;
58 if (invalidBackendIds)
59 {
60 if (!invalidBackendIds.value().empty())
61 {
62 invalidBackendIds.value() += ", ";
63 }
64 invalidBackendIds.value() += backendId;
65 }
66 }
67 }
68 return allValid;
69}
Francis Murtaghbee4bc92019-06-18 12:30:37 +010070
Jan Eilers45274902020-10-15 18:34:43 +010071std::vector<unsigned int> ParseArray(std::istream& stream);
Francis Murtaghbee4bc92019-06-18 12:30:37 +010072
/// Splits a given string at every occurrence of delimiter into a vector of strings
74std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter);
Francis Murtaghbee4bc92019-06-18 12:30:37 +010075
Colm Doneland0472622023-03-06 12:34:54 +000076double ComputeByteLevelRMSE(const void* expected, const void* actual, const size_t size);
77
Teresa Charlin83b42912022-07-07 14:24:59 +010078/// Dequantize an array of a given type
79/// @param array Type erased array to dequantize
80/// @param numElements Elements in the array
81/// @param array Type erased array to dequantize
82template <typename T>
83std::vector<float> DequantizeArray(const void* array, unsigned int numElements, float scale, int32_t offset)
Francis Murtaghbee4bc92019-06-18 12:30:37 +010084{
Teresa Charlin83b42912022-07-07 14:24:59 +010085 const T* quantizedArray = reinterpret_cast<const T*>(array);
86 std::vector<float> dequantizedVector;
87 dequantizedVector.reserve(numElements);
88 for (unsigned int i = 0; i < numElements; ++i)
89 {
90 float f = armnn::Dequantize(*(quantizedArray + i), scale, offset);
91 dequantizedVector.push_back(f);
92 }
93 return dequantizedVector;
94}
Francis Murtaghbee4bc92019-06-18 12:30:37 +010095
Teresa Charlin83b42912022-07-07 14:24:59 +010096void LogAndThrow(std::string eMsg);
Aron Virginas-Tarc82c8732019-10-24 17:07:43 +010097
/**
 * Verifies whether the given string is a valid path. Reports invalid paths to std::cerr.
 * @param file string - A string containing the path to check
 * @param expectFile bool - If true, checks for a regular file.
 * @return bool - True if the given string is a valid path, false otherwise.
 * */
104bool ValidatePath(const std::string& file, const bool expectFile);
Aron Virginas-Tarc82c8732019-10-24 17:07:43 +0100105
/**
 * Verifies whether every string in the given vector is a valid path. Reports invalid paths to std::cerr.
 * @param fileVec vector of string - A vector of strings containing the paths to check
 * @param expectFile bool - If true, checks for regular files.
 * @return bool - True if all given strings are valid paths, false otherwise.
 * */
Finn Williams56870182020-11-20 13:57:53 +0000112bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile);
113
Teresa Charlin83b42912022-07-07 14:24:59 +0100114/// Returns a function of read the given type as a string
115template <typename Integer, typename std::enable_if_t<std::is_integral<Integer>::value>* = nullptr>
116std::function<Integer(const std::string&)> GetParseElementFunc()
117{
118 return [](const std::string& s) { return armnn::numeric_cast<Integer>(std::stoi(s)); };
119}
120
121template <typename Float, std::enable_if_t<std::is_floating_point<Float>::value>* = nullptr>
122std::function<Float(const std::string&)> GetParseElementFunc()
123{
124 return [](const std::string& s) { return std::stof(s); };
125}
126
127template <typename T>
128void PopulateTensorWithData(T* tensor,
129 const unsigned int numElements,
130 const armnn::Optional<std::string>& dataFile,
131 const std::string& inputName)
132{
133 const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
134
135 std::ifstream inputTensorFile;
136 if (!readFromFile)
137 {
138 std::fill(tensor, tensor + numElements, 0);
139 return;
140 }
141 else
142 {
143 inputTensorFile = std::ifstream(dataFile.value());
144 }
145
146 auto parseElementFunc = GetParseElementFunc<T>();
147 std::string line;
148 unsigned int index = 0;
149 while (std::getline(inputTensorFile, line))
150 {
151 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, "\t ,:");
152 for (const std::string& token : tokens)
153 {
154 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
155 {
156 try
157 {
158 if (index == numElements)
159 {
160 ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << dataFile.value()
161 << "\" does not match number of elements: " << numElements
162 << " for input \"" << inputName << "\".";
163 }
164 *(tensor + index) = parseElementFunc(token);
165 index++;
166 }
167 catch (const std::exception&)
168 {
169 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
170 }
171 }
172 }
173 }
174
175 if (index != numElements)
176 {
177 ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << inputName
178 << "\" does not match number of elements: " << numElements
179 << " for input \"" << inputName << "\".";
180 }
181}
182
183template<typename T>
184void WriteToFile(const std::string& outputTensorFileName,
185 const std::string& outputName,
186 const T* const array,
Colm Donelan0dfb2652023-06-22 10:19:17 +0100187 const unsigned int numElements,
188 armnn::DataType dataType)
Teresa Charlin83b42912022-07-07 14:24:59 +0100189{
190 std::ofstream outputTensorFile;
191 outputTensorFile.open(outputTensorFileName, std::ofstream::out | std::ofstream::trunc);
192 if (outputTensorFile.is_open())
193 {
Colm Donelan0dfb2652023-06-22 10:19:17 +0100194 outputTensorFile << outputName << ", "<< GetDataTypeName(dataType) << " : ";
Adam Jalkemo7bbf5652022-10-18 16:56:09 +0200195 for (std::size_t i = 0; i < numElements; ++i)
196 {
197 outputTensorFile << +array[i] << " ";
198 }
Teresa Charlin83b42912022-07-07 14:24:59 +0100199 }
200 else
201 {
202 ARMNN_LOG(info) << "Output Tensor File: " << outputTensorFileName << " could not be opened!";
203 }
204 outputTensorFile.close();
205}
206
207struct OutputWriteInfo
208{
209 const armnn::Optional<std::string>& m_OutputTensorFile;
210 const std::string& m_OutputName;
211 const armnn::Tensor& m_Tensor;
212 const bool m_PrintTensor;
Colm Donelan0dfb2652023-06-22 10:19:17 +0100213 const armnn::DataType m_DataType;
Teresa Charlin83b42912022-07-07 14:24:59 +0100214};
215
216template <typename T>
217void PrintTensor(OutputWriteInfo& info, const char* formatString)
218{
219 const T* array = reinterpret_cast<const T*>(info.m_Tensor.GetMemoryArea());
220
221 if (info.m_OutputTensorFile.has_value())
222 {
223 WriteToFile(info.m_OutputTensorFile.value(),
224 info.m_OutputName,
225 array,
Colm Donelan0dfb2652023-06-22 10:19:17 +0100226 info.m_Tensor.GetNumElements(),
227 info.m_DataType);
Teresa Charlin83b42912022-07-07 14:24:59 +0100228 }
229
230 if (info.m_PrintTensor)
231 {
232 for (unsigned int i = 0; i < info.m_Tensor.GetNumElements(); i++)
233 {
234 printf(formatString, array[i]);
235 }
236 }
237}
238
239template <typename T>
240void PrintQuantizedTensor(OutputWriteInfo& info)
241{
242 std::vector<float> dequantizedValues;
243 auto tensor = info.m_Tensor;
244 dequantizedValues = DequantizeArray<T>(tensor.GetMemoryArea(),
245 tensor.GetNumElements(),
246 tensor.GetInfo().GetQuantizationScale(),
247 tensor.GetInfo().GetQuantizationOffset());
248
249 if (info.m_OutputTensorFile.has_value())
250 {
251 WriteToFile(info.m_OutputTensorFile.value(),
252 info.m_OutputName,
253 dequantizedValues.data(),
Colm Donelan0dfb2652023-06-22 10:19:17 +0100254 tensor.GetNumElements(),
255 info.m_DataType);
Teresa Charlin83b42912022-07-07 14:24:59 +0100256 }
257
258 if (info.m_PrintTensor)
259 {
260 std::for_each(dequantizedValues.begin(), dequantizedValues.end(), [&](float value)
261 {
262 printf("%f ", value);
263 });
264 }
265}
266
Finn Williams56870182020-11-20 13:57:53 +0000267template<typename T, typename TParseElementFunc>
268std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char* chars = "\t ,:")
269{
270 std::vector<T> result;
271 // Processes line-by-line.
272 std::string line;
273 while (std::getline(stream, line))
274 {
275 std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
276 for (const std::string& token : tokens)
277 {
278 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
279 {
280 try
281 {
282 result.push_back(parseElementFunc(token));
283 }
284 catch (const std::exception&)
285 {
286 ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
287 }
288 }
289 }
290 }
291
292 return result;
293}