blob: d902d23d86866c552b3a50dddab85e588b79dfd9 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "NetworkExecutionUtils.hpp"
7
8#include <Filesystem.hpp>
9#include <InferenceTest.hpp>
10#include <ResolveType.hpp>
11
12#if defined(ARMNN_SERIALIZER)
13#include "armnnDeserializer/IDeserializer.hpp"
14#endif
15#if defined(ARMNN_CAFFE_PARSER)
16#include "armnnCaffeParser/ICaffeParser.hpp"
17#endif
18#if defined(ARMNN_TF_PARSER)
19#include "armnnTfParser/ITfParser.hpp"
20#endif
21#if defined(ARMNN_TF_LITE_PARSER)
22#include "armnnTfLiteParser/ITfLiteParser.hpp"
23#endif
24#if defined(ARMNN_ONNX_PARSER)
25#include "armnnOnnxParser/IOnnxParser.hpp"
26#endif
27
Jan Eilers45274902020-10-15 18:34:43 +010028template<armnn::DataType NonQuantizedType>
29auto ParseDataArray(std::istream& stream);
30
31template<armnn::DataType QuantizedType>
32auto ParseDataArray(std::istream& stream,
33 const float& quantizationScale,
34 const int32_t& quantizationOffset);
35
36template<>
37auto ParseDataArray<armnn::DataType::Float32>(std::istream& stream)
38{
39 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
40}
41
42template<>
43auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream)
44{
45 return ParseArrayImpl<int>(stream, [](const std::string& s) { return std::stoi(s); });
46}
47
48template<>
49auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream)
50{
51 return ParseArrayImpl<uint8_t>(stream,
52 [](const std::string& s) { return armnn::numeric_cast<uint8_t>(std::stoi(s)); });
53}
54
Finn Williamsf806c4d2021-02-22 15:13:12 +000055
56template<>
57auto ParseDataArray<armnn::DataType::QSymmS8>(std::istream& stream)
58{
59 return ParseArrayImpl<int8_t>(stream,
60 [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
61}
62
63
64
Jan Eilers45274902020-10-15 18:34:43 +010065template<>
66auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
67 const float& quantizationScale,
68 const int32_t& quantizationOffset)
69{
70 return ParseArrayImpl<uint8_t>(stream,
71 [&quantizationScale, &quantizationOffset](const std::string& s)
72 {
73 return armnn::numeric_cast<uint8_t>(
74 armnn::Quantize<uint8_t>(std::stof(s),
75 quantizationScale,
76 quantizationOffset));
77 });
78}
79
80template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
81std::vector<T> GenerateDummyTensorData(unsigned int numElements)
82{
83 return std::vector<T>(numElements, static_cast<T>(0));
84}
85
86
87std::vector<unsigned int> ParseArray(std::istream& stream)
88{
89 return ParseArrayImpl<unsigned int>(
90 stream,
91 [](const std::string& s) { return armnn::numeric_cast<unsigned int>(std::stoi(s)); });
92}
93
94std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter)
95{
96 std::stringstream stream(inputString);
97 return ParseArrayImpl<std::string>(stream, [](const std::string& s) {
98 return armnn::stringUtils::StringTrimCopy(s); }, delimiter);
99}
100
101
102TensorPrinter::TensorPrinter(const std::string& binding,
103 const armnn::TensorInfo& info,
104 const std::string& outputTensorFile,
105 bool dequantizeOutput)
106 : m_OutputBinding(binding)
107 , m_Scale(info.GetQuantizationScale())
108 , m_Offset(info.GetQuantizationOffset())
109 , m_OutputTensorFile(outputTensorFile)
110 , m_DequantizeOutput(dequantizeOutput) {}
111
112void TensorPrinter::operator()(const std::vector<float>& values)
113{
114 ForEachValue(values, [](float value)
115 {
116 printf("%f ", value);
117 });
118 WriteToFile(values);
119}
120
121void TensorPrinter::operator()(const std::vector<uint8_t>& values)
122{
123 if(m_DequantizeOutput)
124 {
125 auto& scale = m_Scale;
126 auto& offset = m_Offset;
127 std::vector<float> dequantizedValues;
128 ForEachValue(values, [&scale, &offset, &dequantizedValues](uint8_t value)
129 {
130 auto dequantizedValue = armnn::Dequantize(value, scale, offset);
131 printf("%f ", dequantizedValue);
132 dequantizedValues.push_back(dequantizedValue);
133 });
134 WriteToFile(dequantizedValues);
135 }
136 else
137 {
138 const std::vector<int> intValues(values.begin(), values.end());
139 operator()(intValues);
140 }
141}
142
Finn Williamsf806c4d2021-02-22 15:13:12 +0000143void TensorPrinter::operator()(const std::vector<int8_t>& values)
144{
145 ForEachValue(values, [](int8_t value)
146 {
147 printf("%d ", value);
148 });
149 WriteToFile(values);
150}
151
Jan Eilers45274902020-10-15 18:34:43 +0100152void TensorPrinter::operator()(const std::vector<int>& values)
153{
154 ForEachValue(values, [](int value)
155 {
156 printf("%d ", value);
157 });
158 WriteToFile(values);
159}
160
161template<typename Container, typename Delegate>
162void TensorPrinter::ForEachValue(const Container& c, Delegate delegate)
163{
164 std::cout << m_OutputBinding << ": ";
165 for (const auto& value : c)
166 {
167 delegate(value);
168 }
169 printf("\n");
170}
171
172template<typename T>
173void TensorPrinter::WriteToFile(const std::vector<T>& values)
174{
175 if (!m_OutputTensorFile.empty())
176 {
177 std::ofstream outputTensorFile;
178 outputTensorFile.open(m_OutputTensorFile, std::ofstream::out | std::ofstream::trunc);
179 if (outputTensorFile.is_open())
180 {
181 outputTensorFile << m_OutputBinding << ": ";
182 std::copy(values.begin(), values.end(), std::ostream_iterator<T>(outputTensorFile, " "));
183 }
184 else
185 {
186 ARMNN_LOG(info) << "Output Tensor File: " << m_OutputTensorFile << " could not be opened!";
187 }
188 outputTensorFile.close();
189 }
190}
191
Finn Williamsf806c4d2021-02-22 15:13:12 +0000192using TContainer =
193 mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>, std::vector<int8_t>>;
Jan Eilers45274902020-10-15 18:34:43 +0100194using QuantizationParams = std::pair<float, int32_t>;
195
196void PopulateTensorWithData(TContainer& tensorData,
197 unsigned int numElements,
198 const std::string& dataTypeStr,
199 const armnn::Optional<QuantizationParams>& qParams,
200 const armnn::Optional<std::string>& dataFile)
201{
202 const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
203 const bool quantizeData = qParams.has_value();
204
205 std::ifstream inputTensorFile;
206 if (readFromFile)
207 {
208 inputTensorFile = std::ifstream(dataFile.value());
209 }
210
211 if (dataTypeStr.compare("float") == 0)
212 {
213 if (quantizeData)
214 {
215 const float qScale = qParams.value().first;
216 const int qOffset = qParams.value().second;
217
218 tensorData = readFromFile ?
219 ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile, qScale, qOffset) :
220 GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
221 }
222 else
223 {
224 tensorData = readFromFile ?
225 ParseDataArray<armnn::DataType::Float32>(inputTensorFile) :
226 GenerateDummyTensorData<armnn::DataType::Float32>(numElements);
227 }
228 }
229 else if (dataTypeStr.compare("int") == 0)
230 {
231 tensorData = readFromFile ?
232 ParseDataArray<armnn::DataType::Signed32>(inputTensorFile) :
233 GenerateDummyTensorData<armnn::DataType::Signed32>(numElements);
234 }
Finn Williamsf806c4d2021-02-22 15:13:12 +0000235 else if (dataTypeStr.compare("qsymms8") == 0)
236 {
237 tensorData = readFromFile ?
238 ParseDataArray<armnn::DataType::QSymmS8>(inputTensorFile) :
239 GenerateDummyTensorData<armnn::DataType::QSymmS8>(numElements);
240 }
Jan Eilers45274902020-10-15 18:34:43 +0100241 else if (dataTypeStr.compare("qasymm8") == 0)
242 {
243 tensorData = readFromFile ?
244 ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) :
245 GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
246 }
247 else
248 {
249 std::string errorMessage = "Unsupported tensor data type " + dataTypeStr;
250 ARMNN_LOG(fatal) << errorMessage;
251
252 inputTensorFile.close();
253 throw armnn::Exception(errorMessage);
254 }
255
256 inputTensorFile.close();
257}
258
/// Check that @p file exists and, when @p expectFile is true, that it is a regular file.
/// Prints a diagnostic to std::cerr and returns false on the first failed check.
bool ValidatePath(const std::string& file, const bool expectFile)
{
    const bool pathExists = fs::exists(file);
    if (!pathExists)
    {
        std::cerr << "Given file path '" << file << "' does not exist" << std::endl;
        return false;
    }

    const bool isRegularFile = fs::is_regular_file(file);
    if (expectFile && !isRegularFile)
    {
        std::cerr << "Given file path '" << file << "' is not a regular file" << std::endl;
        return false;
    }

    return true;
}
273
274bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile)
275{
276 bool allPathsValid = true;
277 for (auto const& file : fileVec)
278 {
279 if(!ValidatePath(file, expectFile))
280 {
281 allPathsValid = false;
282 }
283 }
284 return allPathsValid;
285}
286
287
288