blob: 23b892ffb457411786acb27738936ce6846e7a80 [file] [log] [blame]
Jan Eilers45274902020-10-15 18:34:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "NetworkExecutionUtils.hpp"
7
Rob Hughes9542f902021-07-14 09:48:54 +01008#include <armnnUtils/Filesystem.hpp>
Jan Eilers45274902020-10-15 18:34:43 +01009#include <InferenceTest.hpp>
10#include <ResolveType.hpp>
11
12#if defined(ARMNN_SERIALIZER)
13#include "armnnDeserializer/IDeserializer.hpp"
14#endif
Jan Eilers45274902020-10-15 18:34:43 +010015#if defined(ARMNN_TF_LITE_PARSER)
16#include "armnnTfLiteParser/ITfLiteParser.hpp"
17#endif
18#if defined(ARMNN_ONNX_PARSER)
19#include "armnnOnnxParser/IOnnxParser.hpp"
20#endif
21
Jan Eilers45274902020-10-15 18:34:43 +010022template<armnn::DataType NonQuantizedType>
23auto ParseDataArray(std::istream& stream);
24
25template<armnn::DataType QuantizedType>
26auto ParseDataArray(std::istream& stream,
27 const float& quantizationScale,
28 const int32_t& quantizationOffset);
29
30template<>
31auto ParseDataArray<armnn::DataType::Float32>(std::istream& stream)
32{
33 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
34}
35
36template<>
37auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream)
38{
39 return ParseArrayImpl<int>(stream, [](const std::string& s) { return std::stoi(s); });
40}
41
42template<>
43auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream)
44{
45 return ParseArrayImpl<uint8_t>(stream,
46 [](const std::string& s) { return armnn::numeric_cast<uint8_t>(std::stoi(s)); });
47}
48
Finn Williamsf806c4d2021-02-22 15:13:12 +000049
50template<>
51auto ParseDataArray<armnn::DataType::QSymmS8>(std::istream& stream)
52{
53 return ParseArrayImpl<int8_t>(stream,
54 [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
55}
56
57
58
Jan Eilers45274902020-10-15 18:34:43 +010059template<>
60auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
61 const float& quantizationScale,
62 const int32_t& quantizationOffset)
63{
64 return ParseArrayImpl<uint8_t>(stream,
65 [&quantizationScale, &quantizationOffset](const std::string& s)
66 {
67 return armnn::numeric_cast<uint8_t>(
68 armnn::Quantize<uint8_t>(std::stof(s),
69 quantizationScale,
70 quantizationOffset));
71 });
72}
73
74template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
75std::vector<T> GenerateDummyTensorData(unsigned int numElements)
76{
77 return std::vector<T>(numElements, static_cast<T>(0));
78}
79
80
81std::vector<unsigned int> ParseArray(std::istream& stream)
82{
83 return ParseArrayImpl<unsigned int>(
84 stream,
85 [](const std::string& s) { return armnn::numeric_cast<unsigned int>(std::stoi(s)); });
86}
87
88std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter)
89{
90 std::stringstream stream(inputString);
91 return ParseArrayImpl<std::string>(stream, [](const std::string& s) {
92 return armnn::stringUtils::StringTrimCopy(s); }, delimiter);
93}
94
95
96TensorPrinter::TensorPrinter(const std::string& binding,
97 const armnn::TensorInfo& info,
98 const std::string& outputTensorFile,
99 bool dequantizeOutput)
100 : m_OutputBinding(binding)
101 , m_Scale(info.GetQuantizationScale())
102 , m_Offset(info.GetQuantizationOffset())
103 , m_OutputTensorFile(outputTensorFile)
104 , m_DequantizeOutput(dequantizeOutput) {}
105
106void TensorPrinter::operator()(const std::vector<float>& values)
107{
108 ForEachValue(values, [](float value)
109 {
110 printf("%f ", value);
111 });
112 WriteToFile(values);
113}
114
115void TensorPrinter::operator()(const std::vector<uint8_t>& values)
116{
117 if(m_DequantizeOutput)
118 {
119 auto& scale = m_Scale;
120 auto& offset = m_Offset;
121 std::vector<float> dequantizedValues;
122 ForEachValue(values, [&scale, &offset, &dequantizedValues](uint8_t value)
123 {
124 auto dequantizedValue = armnn::Dequantize(value, scale, offset);
125 printf("%f ", dequantizedValue);
126 dequantizedValues.push_back(dequantizedValue);
127 });
128 WriteToFile(dequantizedValues);
129 }
130 else
131 {
132 const std::vector<int> intValues(values.begin(), values.end());
133 operator()(intValues);
134 }
135}
136
Finn Williamsf806c4d2021-02-22 15:13:12 +0000137void TensorPrinter::operator()(const std::vector<int8_t>& values)
138{
139 ForEachValue(values, [](int8_t value)
140 {
141 printf("%d ", value);
142 });
143 WriteToFile(values);
144}
145
Jan Eilers45274902020-10-15 18:34:43 +0100146void TensorPrinter::operator()(const std::vector<int>& values)
147{
148 ForEachValue(values, [](int value)
149 {
150 printf("%d ", value);
151 });
152 WriteToFile(values);
153}
154
155template<typename Container, typename Delegate>
156void TensorPrinter::ForEachValue(const Container& c, Delegate delegate)
157{
158 std::cout << m_OutputBinding << ": ";
159 for (const auto& value : c)
160 {
161 delegate(value);
162 }
163 printf("\n");
164}
165
166template<typename T>
167void TensorPrinter::WriteToFile(const std::vector<T>& values)
168{
169 if (!m_OutputTensorFile.empty())
170 {
171 std::ofstream outputTensorFile;
172 outputTensorFile.open(m_OutputTensorFile, std::ofstream::out | std::ofstream::trunc);
173 if (outputTensorFile.is_open())
174 {
175 outputTensorFile << m_OutputBinding << ": ";
176 std::copy(values.begin(), values.end(), std::ostream_iterator<T>(outputTensorFile, " "));
177 }
178 else
179 {
180 ARMNN_LOG(info) << "Output Tensor File: " << m_OutputTensorFile << " could not be opened!";
181 }
182 outputTensorFile.close();
183 }
184}
185
Finn Williamsf806c4d2021-02-22 15:13:12 +0000186using TContainer =
187 mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>, std::vector<int8_t>>;
Jan Eilers45274902020-10-15 18:34:43 +0100188using QuantizationParams = std::pair<float, int32_t>;
189
190void PopulateTensorWithData(TContainer& tensorData,
191 unsigned int numElements,
192 const std::string& dataTypeStr,
193 const armnn::Optional<QuantizationParams>& qParams,
194 const armnn::Optional<std::string>& dataFile)
195{
196 const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
197 const bool quantizeData = qParams.has_value();
198
199 std::ifstream inputTensorFile;
200 if (readFromFile)
201 {
202 inputTensorFile = std::ifstream(dataFile.value());
203 }
204
205 if (dataTypeStr.compare("float") == 0)
206 {
207 if (quantizeData)
208 {
209 const float qScale = qParams.value().first;
210 const int qOffset = qParams.value().second;
211
212 tensorData = readFromFile ?
213 ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile, qScale, qOffset) :
214 GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
215 }
216 else
217 {
218 tensorData = readFromFile ?
219 ParseDataArray<armnn::DataType::Float32>(inputTensorFile) :
220 GenerateDummyTensorData<armnn::DataType::Float32>(numElements);
221 }
222 }
223 else if (dataTypeStr.compare("int") == 0)
224 {
225 tensorData = readFromFile ?
226 ParseDataArray<armnn::DataType::Signed32>(inputTensorFile) :
227 GenerateDummyTensorData<armnn::DataType::Signed32>(numElements);
228 }
Finn Williamsf806c4d2021-02-22 15:13:12 +0000229 else if (dataTypeStr.compare("qsymms8") == 0)
230 {
231 tensorData = readFromFile ?
232 ParseDataArray<armnn::DataType::QSymmS8>(inputTensorFile) :
233 GenerateDummyTensorData<armnn::DataType::QSymmS8>(numElements);
234 }
Jan Eilers45274902020-10-15 18:34:43 +0100235 else if (dataTypeStr.compare("qasymm8") == 0)
236 {
237 tensorData = readFromFile ?
238 ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) :
239 GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
240 }
241 else
242 {
243 std::string errorMessage = "Unsupported tensor data type " + dataTypeStr;
244 ARMNN_LOG(fatal) << errorMessage;
245
246 inputTensorFile.close();
247 throw armnn::Exception(errorMessage);
248 }
249
250 inputTensorFile.close();
251}
252
253bool ValidatePath(const std::string& file, const bool expectFile)
254{
255 if (!fs::exists(file))
256 {
257 std::cerr << "Given file path '" << file << "' does not exist" << std::endl;
258 return false;
259 }
260 if (!fs::is_regular_file(file) && expectFile)
261 {
262 std::cerr << "Given file path '" << file << "' is not a regular file" << std::endl;
263 return false;
264 }
265 return true;
266}
267
268bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile)
269{
270 bool allPathsValid = true;
271 for (auto const& file : fileVec)
272 {
273 if(!ValidatePath(file, expectFile))
274 {
275 allPathsValid = false;
276 }
277 }
278 return allPathsValid;
279}
280
281
282