blob: a12a66ea256b4549a5517247979aaadd9a6aae3f [file] [log] [blame]
//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +01008#include <armnn/IRuntime.hpp>
Colm Donelanc42a9872022-02-02 16:35:09 +00009#include <armnnTestUtils/TensorHelpers.hpp>
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +010010
narpra016f37f832018-12-21 18:30:00 +000011#include <Network.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000012#include <VerificationHelpers.hpp>
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +010013
Sadik Armagan1625efc2021-06-10 18:24:34 +010014#include <doctest/doctest.h>
Colm Donelan5b5c2222020-09-09 12:48:16 +010015#include <fmt/format.h>
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +010016
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000017#include <iomanip>
telsoa014fcda012018-03-09 14:13:49 +000018#include <string>
19
telsoa01c577f2c2018-08-31 09:22:23 +010020namespace armnnUtils
21{
surmeh013537c2c2018-05-18 16:31:43 +010022
telsoa014fcda012018-03-09 14:13:49 +000023template<typename TParser>
24struct ParserPrototxtFixture
25{
26 ParserPrototxtFixture()
27 : m_Parser(TParser::Create())
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +000028 , m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions()))
telsoa014fcda012018-03-09 14:13:49 +000029 , m_NetworkIdentifier(-1)
surmeh013537c2c2018-05-18 16:31:43 +010030 {
surmeh013537c2c2018-05-18 16:31:43 +010031 }
telsoa014fcda012018-03-09 14:13:49 +000032
33 /// Parses and loads the network defined by the m_Prototext string.
34 /// @{
35 void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName);
36 void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
37 const std::string& inputName,
38 const std::string& outputName);
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000039 void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
40 const armnn::TensorShape& outputTensorShape,
41 const std::string& inputName,
42 const std::string& outputName);
telsoa014fcda012018-03-09 14:13:49 +000043 void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
44 const std::vector<std::string>& requestedOutputs);
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +010045 void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes);
telsoa01c577f2c2018-08-31 09:22:23 +010046 void Setup();
narpra016f37f832018-12-21 18:30:00 +000047 armnn::IOptimizedNetworkPtr SetupOptimizedNetwork(
48 const std::map<std::string,armnn::TensorShape>& inputShapes,
49 const std::vector<std::string>& requestedOutputs);
telsoa014fcda012018-03-09 14:13:49 +000050 /// @}
51
52 /// Executes the network with the given input tensor and checks the result against the given output tensor.
telsoa01c577f2c2018-08-31 09:22:23 +010053 /// This overload assumes that the network has a single input and a single output.
telsoa014fcda012018-03-09 14:13:49 +000054 template <std::size_t NumOutputDimensions>
55 void RunTest(const std::vector<float>& inputData, const std::vector<float>& expectedOutputData);
56
kevmay012b4d88e2019-01-24 14:05:09 +000057 /// Executes the network with the given input tensor and checks the result against the given output tensor.
58 /// Calls RunTest with output type of uint8_t for checking comparison operators.
59 template <std::size_t NumOutputDimensions>
60 void RunComparisonTest(const std::map<std::string, std::vector<float>>& inputData,
61 const std::map<std::string, std::vector<uint8_t>>& expectedOutputData);
62
telsoa014fcda012018-03-09 14:13:49 +000063 /// Executes the network with the given input tensors and checks the results against the given output tensors.
64 /// This overload supports multiple inputs and multiple outputs, identified by name.
kevmay012b4d88e2019-01-24 14:05:09 +000065 template <std::size_t NumOutputDimensions, typename T = float>
telsoa014fcda012018-03-09 14:13:49 +000066 void RunTest(const std::map<std::string, std::vector<float>>& inputData,
kevmay012b4d88e2019-01-24 14:05:09 +000067 const std::map<std::string, std::vector<T>>& expectedOutputData);
telsoa014fcda012018-03-09 14:13:49 +000068
surmeh013537c2c2018-05-18 16:31:43 +010069 std::string m_Prototext;
70 std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser;
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +000071 armnn::IRuntimePtr m_Runtime;
surmeh013537c2c2018-05-18 16:31:43 +010072 armnn::NetworkId m_NetworkIdentifier;
telsoa014fcda012018-03-09 14:13:49 +000073
74 /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
75 /// so they don't need to be passed to the single-input-single-output overload of RunTest().
76 /// @{
77 std::string m_SingleInputName;
78 std::string m_SingleOutputName;
79 /// @}
Ferran Balaguer51dd62f2019-01-11 19:29:18 +000080
81 /// This will store the output shape so it don't need to be passed to the single-input-single-output overload
82 /// of RunTest().
83 armnn::TensorShape m_SingleOutputShape;
telsoa014fcda012018-03-09 14:13:49 +000084};
85
86template<typename TParser>
87void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const std::string& inputName,
88 const std::string& outputName)
89{
telsoa01c577f2c2018-08-31 09:22:23 +010090 // Stores the input and output name so they don't need to be passed to the single-input-single-output RunTest().
telsoa014fcda012018-03-09 14:13:49 +000091 m_SingleInputName = inputName;
92 m_SingleOutputName = outputName;
93 Setup({ }, { outputName });
94}
95
96template<typename TParser>
97void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
98 const std::string& inputName,
99 const std::string& outputName)
100{
telsoa01c577f2c2018-08-31 09:22:23 +0100101 // Stores the input and output name so they don't need to be passed to the single-input-single-output RunTest().
telsoa014fcda012018-03-09 14:13:49 +0000102 m_SingleInputName = inputName;
103 m_SingleOutputName = outputName;
104 Setup({ { inputName, inputTensorShape } }, { outputName });
105}
106
107template<typename TParser>
Ferran Balaguer51dd62f2019-01-11 19:29:18 +0000108void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
109 const armnn::TensorShape& outputTensorShape,
110 const std::string& inputName,
111 const std::string& outputName)
112{
113 // Stores the input name, the output name and the output tensor shape
114 // so they don't need to be passed to the single-input-single-output RunTest().
115 m_SingleInputName = inputName;
116 m_SingleOutputName = outputName;
117 m_SingleOutputShape = outputTensorShape;
118 Setup({ { inputName, inputTensorShape } }, { outputName });
119}
120
121template<typename TParser>
telsoa014fcda012018-03-09 14:13:49 +0000122void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
123 const std::vector<std::string>& requestedOutputs)
124{
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000125 std::string errorMessage;
telsoa01c577f2c2018-08-31 09:22:23 +0100126
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000127 armnn::INetworkPtr network =
128 m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes, requestedOutputs);
129 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
Mike Kellya9c32672023-12-04 17:23:09 +0000130 armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optimized), errorMessage);
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000131 if (ret != armnn::Status::Success)
132 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100133 throw armnn::Exception(fmt::format("LoadNetwork failed with error: '{0}' {1}",
134 errorMessage,
135 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100136 }
137}
138
139template<typename TParser>
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +0100140void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes)
141{
142 std::string errorMessage;
143
144 armnn::INetworkPtr network =
145 m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes);
146 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
Mike Kellya9c32672023-12-04 17:23:09 +0000147 armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optimized), errorMessage);
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +0100148 if (ret != armnn::Status::Success)
149 {
150 throw armnn::Exception(fmt::format("LoadNetwork failed with error: '{0}' {1}",
151 errorMessage,
152 CHECK_LOCATION().AsString()));
153 }
154}
155
156template<typename TParser>
telsoa01c577f2c2018-08-31 09:22:23 +0100157void ParserPrototxtFixture<TParser>::Setup()
158{
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000159 std::string errorMessage;
telsoa01c577f2c2018-08-31 09:22:23 +0100160
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000161 armnn::INetworkPtr network =
162 m_Parser->CreateNetworkFromString(m_Prototext.c_str());
163 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
Mike Kellya9c32672023-12-04 17:23:09 +0000164 armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optimized), errorMessage);
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000165 if (ret != armnn::Status::Success)
166 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100167 throw armnn::Exception(fmt::format("LoadNetwork failed with error: '{0}' {1}",
168 errorMessage,
169 CHECK_LOCATION().AsString()));
telsoa014fcda012018-03-09 14:13:49 +0000170 }
171}
172
173template<typename TParser>
narpra016f37f832018-12-21 18:30:00 +0000174armnn::IOptimizedNetworkPtr ParserPrototxtFixture<TParser>::SetupOptimizedNetwork(
175 const std::map<std::string,armnn::TensorShape>& inputShapes,
176 const std::vector<std::string>& requestedOutputs)
177{
178 armnn::INetworkPtr network =
179 m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes, requestedOutputs);
180 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
181 return optimized;
182}
183
184template<typename TParser>
telsoa014fcda012018-03-09 14:13:49 +0000185template <std::size_t NumOutputDimensions>
186void ParserPrototxtFixture<TParser>::RunTest(const std::vector<float>& inputData,
kevmay012b4d88e2019-01-24 14:05:09 +0000187 const std::vector<float>& expectedOutputData)
telsoa014fcda012018-03-09 14:13:49 +0000188{
189 RunTest<NumOutputDimensions>({ { m_SingleInputName, inputData } }, { { m_SingleOutputName, expectedOutputData } });
190}
191
192template<typename TParser>
193template <std::size_t NumOutputDimensions>
kevmay012b4d88e2019-01-24 14:05:09 +0000194void ParserPrototxtFixture<TParser>::RunComparisonTest(const std::map<std::string, std::vector<float>>& inputData,
195 const std::map<std::string, std::vector<uint8_t>>&
196 expectedOutputData)
197{
198 RunTest<NumOutputDimensions, uint8_t>(inputData, expectedOutputData);
199}
200
201template<typename TParser>
202template <std::size_t NumOutputDimensions, typename T>
telsoa014fcda012018-03-09 14:13:49 +0000203void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::vector<float>>& inputData,
kevmay012b4d88e2019-01-24 14:05:09 +0000204 const std::map<std::string, std::vector<T>>& expectedOutputData)
telsoa014fcda012018-03-09 14:13:49 +0000205{
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000206 // Sets up the armnn input tensors from the given vectors.
207 armnn::InputTensors inputTensors;
208 for (auto&& it : inputData)
telsoa014fcda012018-03-09 14:13:49 +0000209 {
Jim Flynnb4d7eae2019-05-01 14:44:27 +0100210 armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(it.first);
Cathal Corbett5b8093c2021-10-22 11:12:07 +0100211 bindingInfo.second.SetConstant(true);
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000212 inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +0100213 if (bindingInfo.second.GetNumElements() != it.second.size())
214 {
215 throw armnn::Exception(fmt::format("Input tensor {0} is expected to have {1} elements. "
216 "{2} elements supplied. {3}",
217 it.first,
218 bindingInfo.second.GetNumElements(),
219 it.second.size(),
220 CHECK_LOCATION().AsString()));
221 }
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000222 }
telsoa014fcda012018-03-09 14:13:49 +0000223
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000224 // Allocates storage for the output tensors to be written to and sets up the armnn output tensors.
Sadik Armagan483c8112021-06-01 09:24:52 +0100225 std::map<std::string, std::vector<T>> outputStorage;
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000226 armnn::OutputTensors outputTensors;
227 for (auto&& it : expectedOutputData)
228 {
Jim Flynnb4d7eae2019-05-01 14:44:27 +0100229 armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
Sadik Armagan483c8112021-06-01 09:24:52 +0100230 outputStorage.emplace(it.first, std::vector<T>(bindingInfo.second.GetNumElements()));
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000231 outputTensors.push_back(
232 { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
233 }
234
235 m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
236
237 // Compares each output tensor to the expected values.
238 for (auto&& it : expectedOutputData)
239 {
Jim Flynnb4d7eae2019-05-01 14:44:27 +0100240 armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000241 if (bindingInfo.second.GetNumElements() != it.second.size())
surmeh013537c2c2018-05-18 16:31:43 +0100242 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100243 throw armnn::Exception(fmt::format("Output tensor {0} is expected to have {1} elements. "
244 "{2} elements supplied. {3}",
245 it.first,
246 bindingInfo.second.GetNumElements(),
247 it.second.size(),
248 CHECK_LOCATION().AsString()));
surmeh013537c2c2018-05-18 16:31:43 +0100249 }
Ferran Balaguer51dd62f2019-01-11 19:29:18 +0000250
251 // If the expected output shape is set, the output tensor checks will be carried out.
252 if (m_SingleOutputShape.GetNumDimensions() != 0)
253 {
254
255 if (bindingInfo.second.GetShape().GetNumDimensions() == NumOutputDimensions &&
256 bindingInfo.second.GetShape().GetNumDimensions() == m_SingleOutputShape.GetNumDimensions())
257 {
258 for (unsigned int i = 0; i < m_SingleOutputShape.GetNumDimensions(); ++i)
259 {
260 if (m_SingleOutputShape[i] != bindingInfo.second.GetShape()[i])
261 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100262 // This exception message could not be created by fmt:format because of an oddity in
263 // the operator << of TensorShape.
264 std::stringstream message;
265 message << "Output tensor " << it.first << " is expected to have "
266 << bindingInfo.second.GetShape() << "shape. "
267 << m_SingleOutputShape << " shape supplied. "
268 << CHECK_LOCATION().AsString();
269 throw armnn::Exception(message.str());
Ferran Balaguer51dd62f2019-01-11 19:29:18 +0000270 }
271 }
272 }
273 else
274 {
Colm Donelan5b5c2222020-09-09 12:48:16 +0100275 throw armnn::Exception(fmt::format("Output tensor {0} is expected to have {1} dimensions. "
276 "{2} dimensions supplied. {3}",
277 it.first,
278 bindingInfo.second.GetShape().GetNumDimensions(),
279 NumOutputDimensions,
280 CHECK_LOCATION().AsString()));
Ferran Balaguer51dd62f2019-01-11 19:29:18 +0000281 }
282 }
283
Sadik Armagan483c8112021-06-01 09:24:52 +0100284 auto outputExpected = it.second;
285 auto shape = bindingInfo.second.GetShape();
kevmay012b4d88e2019-01-24 14:05:09 +0000286 if (std::is_same<T, uint8_t>::value)
287 {
Sadik Armagan483c8112021-06-01 09:24:52 +0100288 auto result = CompareTensors(outputExpected, outputStorage[it.first], shape, shape, true);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100289 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
kevmay012b4d88e2019-01-24 14:05:09 +0000290 }
291 else
292 {
Sadik Armagan483c8112021-06-01 09:24:52 +0100293 auto result = CompareTensors(outputExpected, outputStorage[it.first], shape, shape);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100294 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
kevmay012b4d88e2019-01-24 14:05:09 +0000295 }
telsoa014fcda012018-03-09 14:13:49 +0000296 }
297}
telsoa01c577f2c2018-08-31 09:22:23 +0100298
299} // namespace armnnUtils