blob: 74c878304d089c05c40eaf2e435eb7eb663c7bba [file] [log] [blame]
Jan Eilers45274902020-10-15 18:34:43 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "NetworkExecutionUtils.hpp"
7
8#include <Filesystem.hpp>
9#include <InferenceTest.hpp>
10#include <ResolveType.hpp>
11
12#if defined(ARMNN_SERIALIZER)
13#include "armnnDeserializer/IDeserializer.hpp"
14#endif
Jan Eilers45274902020-10-15 18:34:43 +010015#if defined(ARMNN_TF_PARSER)
16#include "armnnTfParser/ITfParser.hpp"
17#endif
18#if defined(ARMNN_TF_LITE_PARSER)
19#include "armnnTfLiteParser/ITfLiteParser.hpp"
20#endif
21#if defined(ARMNN_ONNX_PARSER)
22#include "armnnOnnxParser/IOnnxParser.hpp"
23#endif
24
Jan Eilers45274902020-10-15 18:34:43 +010025template<armnn::DataType NonQuantizedType>
26auto ParseDataArray(std::istream& stream);
27
28template<armnn::DataType QuantizedType>
29auto ParseDataArray(std::istream& stream,
30 const float& quantizationScale,
31 const int32_t& quantizationOffset);
32
33template<>
34auto ParseDataArray<armnn::DataType::Float32>(std::istream& stream)
35{
36 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
37}
38
39template<>
40auto ParseDataArray<armnn::DataType::Signed32>(std::istream& stream)
41{
42 return ParseArrayImpl<int>(stream, [](const std::string& s) { return std::stoi(s); });
43}
44
45template<>
46auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream)
47{
48 return ParseArrayImpl<uint8_t>(stream,
49 [](const std::string& s) { return armnn::numeric_cast<uint8_t>(std::stoi(s)); });
50}
51
Finn Williamsf806c4d2021-02-22 15:13:12 +000052
53template<>
54auto ParseDataArray<armnn::DataType::QSymmS8>(std::istream& stream)
55{
56 return ParseArrayImpl<int8_t>(stream,
57 [](const std::string& s) { return armnn::numeric_cast<int8_t>(std::stoi(s)); });
58}
59
60
61
Jan Eilers45274902020-10-15 18:34:43 +010062template<>
63auto ParseDataArray<armnn::DataType::QAsymmU8>(std::istream& stream,
64 const float& quantizationScale,
65 const int32_t& quantizationOffset)
66{
67 return ParseArrayImpl<uint8_t>(stream,
68 [&quantizationScale, &quantizationOffset](const std::string& s)
69 {
70 return armnn::numeric_cast<uint8_t>(
71 armnn::Quantize<uint8_t>(std::stof(s),
72 quantizationScale,
73 quantizationOffset));
74 });
75}
76
77template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
78std::vector<T> GenerateDummyTensorData(unsigned int numElements)
79{
80 return std::vector<T>(numElements, static_cast<T>(0));
81}
82
83
84std::vector<unsigned int> ParseArray(std::istream& stream)
85{
86 return ParseArrayImpl<unsigned int>(
87 stream,
88 [](const std::string& s) { return armnn::numeric_cast<unsigned int>(std::stoi(s)); });
89}
90
91std::vector<std::string> ParseStringList(const std::string& inputString, const char* delimiter)
92{
93 std::stringstream stream(inputString);
94 return ParseArrayImpl<std::string>(stream, [](const std::string& s) {
95 return armnn::stringUtils::StringTrimCopy(s); }, delimiter);
96}
97
98
99TensorPrinter::TensorPrinter(const std::string& binding,
100 const armnn::TensorInfo& info,
101 const std::string& outputTensorFile,
102 bool dequantizeOutput)
103 : m_OutputBinding(binding)
104 , m_Scale(info.GetQuantizationScale())
105 , m_Offset(info.GetQuantizationOffset())
106 , m_OutputTensorFile(outputTensorFile)
107 , m_DequantizeOutput(dequantizeOutput) {}
108
109void TensorPrinter::operator()(const std::vector<float>& values)
110{
111 ForEachValue(values, [](float value)
112 {
113 printf("%f ", value);
114 });
115 WriteToFile(values);
116}
117
118void TensorPrinter::operator()(const std::vector<uint8_t>& values)
119{
120 if(m_DequantizeOutput)
121 {
122 auto& scale = m_Scale;
123 auto& offset = m_Offset;
124 std::vector<float> dequantizedValues;
125 ForEachValue(values, [&scale, &offset, &dequantizedValues](uint8_t value)
126 {
127 auto dequantizedValue = armnn::Dequantize(value, scale, offset);
128 printf("%f ", dequantizedValue);
129 dequantizedValues.push_back(dequantizedValue);
130 });
131 WriteToFile(dequantizedValues);
132 }
133 else
134 {
135 const std::vector<int> intValues(values.begin(), values.end());
136 operator()(intValues);
137 }
138}
139
Finn Williamsf806c4d2021-02-22 15:13:12 +0000140void TensorPrinter::operator()(const std::vector<int8_t>& values)
141{
142 ForEachValue(values, [](int8_t value)
143 {
144 printf("%d ", value);
145 });
146 WriteToFile(values);
147}
148
Jan Eilers45274902020-10-15 18:34:43 +0100149void TensorPrinter::operator()(const std::vector<int>& values)
150{
151 ForEachValue(values, [](int value)
152 {
153 printf("%d ", value);
154 });
155 WriteToFile(values);
156}
157
158template<typename Container, typename Delegate>
159void TensorPrinter::ForEachValue(const Container& c, Delegate delegate)
160{
161 std::cout << m_OutputBinding << ": ";
162 for (const auto& value : c)
163 {
164 delegate(value);
165 }
166 printf("\n");
167}
168
169template<typename T>
170void TensorPrinter::WriteToFile(const std::vector<T>& values)
171{
172 if (!m_OutputTensorFile.empty())
173 {
174 std::ofstream outputTensorFile;
175 outputTensorFile.open(m_OutputTensorFile, std::ofstream::out | std::ofstream::trunc);
176 if (outputTensorFile.is_open())
177 {
178 outputTensorFile << m_OutputBinding << ": ";
179 std::copy(values.begin(), values.end(), std::ostream_iterator<T>(outputTensorFile, " "));
180 }
181 else
182 {
183 ARMNN_LOG(info) << "Output Tensor File: " << m_OutputTensorFile << " could not be opened!";
184 }
185 outputTensorFile.close();
186 }
187}
188
Finn Williamsf806c4d2021-02-22 15:13:12 +0000189using TContainer =
190 mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>, std::vector<int8_t>>;
Jan Eilers45274902020-10-15 18:34:43 +0100191using QuantizationParams = std::pair<float, int32_t>;
192
193void PopulateTensorWithData(TContainer& tensorData,
194 unsigned int numElements,
195 const std::string& dataTypeStr,
196 const armnn::Optional<QuantizationParams>& qParams,
197 const armnn::Optional<std::string>& dataFile)
198{
199 const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
200 const bool quantizeData = qParams.has_value();
201
202 std::ifstream inputTensorFile;
203 if (readFromFile)
204 {
205 inputTensorFile = std::ifstream(dataFile.value());
206 }
207
208 if (dataTypeStr.compare("float") == 0)
209 {
210 if (quantizeData)
211 {
212 const float qScale = qParams.value().first;
213 const int qOffset = qParams.value().second;
214
215 tensorData = readFromFile ?
216 ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile, qScale, qOffset) :
217 GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
218 }
219 else
220 {
221 tensorData = readFromFile ?
222 ParseDataArray<armnn::DataType::Float32>(inputTensorFile) :
223 GenerateDummyTensorData<armnn::DataType::Float32>(numElements);
224 }
225 }
226 else if (dataTypeStr.compare("int") == 0)
227 {
228 tensorData = readFromFile ?
229 ParseDataArray<armnn::DataType::Signed32>(inputTensorFile) :
230 GenerateDummyTensorData<armnn::DataType::Signed32>(numElements);
231 }
Finn Williamsf806c4d2021-02-22 15:13:12 +0000232 else if (dataTypeStr.compare("qsymms8") == 0)
233 {
234 tensorData = readFromFile ?
235 ParseDataArray<armnn::DataType::QSymmS8>(inputTensorFile) :
236 GenerateDummyTensorData<armnn::DataType::QSymmS8>(numElements);
237 }
Jan Eilers45274902020-10-15 18:34:43 +0100238 else if (dataTypeStr.compare("qasymm8") == 0)
239 {
240 tensorData = readFromFile ?
241 ParseDataArray<armnn::DataType::QAsymmU8>(inputTensorFile) :
242 GenerateDummyTensorData<armnn::DataType::QAsymmU8>(numElements);
243 }
244 else
245 {
246 std::string errorMessage = "Unsupported tensor data type " + dataTypeStr;
247 ARMNN_LOG(fatal) << errorMessage;
248
249 inputTensorFile.close();
250 throw armnn::Exception(errorMessage);
251 }
252
253 inputTensorFile.close();
254}
255
256bool ValidatePath(const std::string& file, const bool expectFile)
257{
258 if (!fs::exists(file))
259 {
260 std::cerr << "Given file path '" << file << "' does not exist" << std::endl;
261 return false;
262 }
263 if (!fs::is_regular_file(file) && expectFile)
264 {
265 std::cerr << "Given file path '" << file << "' is not a regular file" << std::endl;
266 return false;
267 }
268 return true;
269}
270
271bool ValidatePaths(const std::vector<std::string>& fileVec, const bool expectFile)
272{
273 bool allPathsValid = true;
274 for (auto const& file : fileVec)
275 {
276 if(!ValidatePath(file, expectFile))
277 {
278 allPathsValid = false;
279 }
280 }
281 return allPathsValid;
282}
283
284
285