blob: b10590342e95b4ca6607aebf04eeb91935d21b33 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5
6#pragma once
7
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +01008#include <armnn/IRuntime.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <test/TensorHelpers.hpp>
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +010010
11#include <armnnOnnxParser/IOnnxParser.hpp>
12
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000013#include <VerificationHelpers.hpp>
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +010014
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +000015#include <backendsCommon/BackendRegistry.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010016
17#include <boost/format.hpp>
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +010018
telsoa014fcda012018-03-09 14:13:49 +000019#include <string>
20
telsoa01c577f2c2018-08-31 09:22:23 +010021namespace armnnUtils
22{
surmeh013537c2c2018-05-18 16:31:43 +010023
telsoa014fcda012018-03-09 14:13:49 +000024template<typename TParser>
25struct ParserPrototxtFixture
26{
27 ParserPrototxtFixture()
28 : m_Parser(TParser::Create())
telsoa014fcda012018-03-09 14:13:49 +000029 , m_NetworkIdentifier(-1)
surmeh013537c2c2018-05-18 16:31:43 +010030 {
telsoa01c577f2c2018-08-31 09:22:23 +010031 armnn::IRuntime::CreationOptions options;
surmeh013537c2c2018-05-18 16:31:43 +010032
Aron Virginas-Tar9c5db112018-10-25 11:10:49 +010033 // Create runtimes for each available backend
34 const armnn::BackendIdSet availableBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
35 for (auto& backendId : availableBackendIds)
36 {
37 m_Runtimes.push_back(std::make_pair(armnn::IRuntime::Create(options), backendId));
38 }
surmeh013537c2c2018-05-18 16:31:43 +010039 }
telsoa014fcda012018-03-09 14:13:49 +000040
41 /// Parses and loads the network defined by the m_Prototext string.
42 /// @{
43 void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName);
44 void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
45 const std::string& inputName,
46 const std::string& outputName);
47 void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
48 const std::vector<std::string>& requestedOutputs);
telsoa01c577f2c2018-08-31 09:22:23 +010049 void Setup();
telsoa014fcda012018-03-09 14:13:49 +000050 /// @}
51
52 /// Executes the network with the given input tensor and checks the result against the given output tensor.
telsoa01c577f2c2018-08-31 09:22:23 +010053 /// This overload assumes that the network has a single input and a single output.
telsoa014fcda012018-03-09 14:13:49 +000054 template <std::size_t NumOutputDimensions>
55 void RunTest(const std::vector<float>& inputData, const std::vector<float>& expectedOutputData);
56
57 /// Executes the network with the given input tensors and checks the results against the given output tensors.
58 /// This overload supports multiple inputs and multiple outputs, identified by name.
59 template <std::size_t NumOutputDimensions>
60 void RunTest(const std::map<std::string, std::vector<float>>& inputData,
61 const std::map<std::string, std::vector<float>>& expectedOutputData);
62
surmeh013537c2c2018-05-18 16:31:43 +010063 std::string m_Prototext;
64 std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser;
David Beckf0b48452018-10-19 15:20:56 +010065 std::vector<std::pair<armnn::IRuntimePtr, armnn::BackendId>> m_Runtimes;
surmeh013537c2c2018-05-18 16:31:43 +010066 armnn::NetworkId m_NetworkIdentifier;
telsoa014fcda012018-03-09 14:13:49 +000067
68 /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
69 /// so they don't need to be passed to the single-input-single-output overload of RunTest().
70 /// @{
71 std::string m_SingleInputName;
72 std::string m_SingleOutputName;
73 /// @}
74};
75
76template<typename TParser>
77void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const std::string& inputName,
78 const std::string& outputName)
79{
telsoa01c577f2c2018-08-31 09:22:23 +010080 // Stores the input and output name so they don't need to be passed to the single-input-single-output RunTest().
telsoa014fcda012018-03-09 14:13:49 +000081 m_SingleInputName = inputName;
82 m_SingleOutputName = outputName;
83 Setup({ }, { outputName });
84}
85
86template<typename TParser>
87void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
88 const std::string& inputName,
89 const std::string& outputName)
90{
telsoa01c577f2c2018-08-31 09:22:23 +010091 // Stores the input and output name so they don't need to be passed to the single-input-single-output RunTest().
telsoa014fcda012018-03-09 14:13:49 +000092 m_SingleInputName = inputName;
93 m_SingleOutputName = outputName;
94 Setup({ { inputName, inputTensorShape } }, { outputName });
95}
96
97template<typename TParser>
98void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
99 const std::vector<std::string>& requestedOutputs)
100{
surmeh013537c2c2018-05-18 16:31:43 +0100101 for (auto&& runtime : m_Runtimes)
telsoa014fcda012018-03-09 14:13:49 +0000102 {
telsoa01c577f2c2018-08-31 09:22:23 +0100103 std::string errorMessage;
104
surmeh013537c2c2018-05-18 16:31:43 +0100105 armnn::INetworkPtr network =
106 m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes, requestedOutputs);
jimfly016b0b53d2018-10-08 14:43:01 +0100107 auto optimized = Optimize(*network,
108 { runtime.second, armnn::Compute::CpuRef }, runtime.first->GetDeviceSpec());
telsoa01c577f2c2018-08-31 09:22:23 +0100109 armnn::Status ret = runtime.first->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage);
surmeh013537c2c2018-05-18 16:31:43 +0100110 if (ret != armnn::Status::Success)
111 {
telsoa01c577f2c2018-08-31 09:22:23 +0100112 throw armnn::Exception(boost::str(
113 boost::format("LoadNetwork failed with error: '%1%' %2%")
114 % errorMessage
115 % CHECK_LOCATION().AsString()));
116 }
117 }
118}
119
120template<typename TParser>
121void ParserPrototxtFixture<TParser>::Setup()
122{
123 for (auto&& runtime : m_Runtimes)
124 {
125 std::string errorMessage;
126
127 armnn::INetworkPtr network =
128 m_Parser->CreateNetworkFromString(m_Prototext.c_str());
jimfly016b0b53d2018-10-08 14:43:01 +0100129 auto optimized = Optimize(*network,
130 { runtime.second, armnn::Compute::CpuRef }, runtime.first->GetDeviceSpec());
telsoa01c577f2c2018-08-31 09:22:23 +0100131 armnn::Status ret = runtime.first->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage);
132 if (ret != armnn::Status::Success)
133 {
134 throw armnn::Exception(boost::str(
135 boost::format("LoadNetwork failed with error: '%1%' %2%")
136 % errorMessage
137 % CHECK_LOCATION().AsString()));
surmeh013537c2c2018-05-18 16:31:43 +0100138 }
telsoa014fcda012018-03-09 14:13:49 +0000139 }
140}
141
142template<typename TParser>
143template <std::size_t NumOutputDimensions>
144void ParserPrototxtFixture<TParser>::RunTest(const std::vector<float>& inputData,
145 const std::vector<float>& expectedOutputData)
146{
147 RunTest<NumOutputDimensions>({ { m_SingleInputName, inputData } }, { { m_SingleOutputName, expectedOutputData } });
148}
149
150template<typename TParser>
151template <std::size_t NumOutputDimensions>
152void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::vector<float>>& inputData,
153 const std::map<std::string, std::vector<float>>& expectedOutputData)
154{
surmeh013537c2c2018-05-18 16:31:43 +0100155 for (auto&& runtime : m_Runtimes)
telsoa014fcda012018-03-09 14:13:49 +0000156 {
surmeh013537c2c2018-05-18 16:31:43 +0100157 using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
telsoa014fcda012018-03-09 14:13:49 +0000158
telsoa01c577f2c2018-08-31 09:22:23 +0100159 // Sets up the armnn input tensors from the given vectors.
surmeh013537c2c2018-05-18 16:31:43 +0100160 armnn::InputTensors inputTensors;
161 for (auto&& it : inputData)
162 {
163 BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(it.first);
164 inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
165 }
telsoa014fcda012018-03-09 14:13:49 +0000166
telsoa01c577f2c2018-08-31 09:22:23 +0100167 // Allocates storage for the output tensors to be written to and sets up the armnn output tensors.
surmeh013537c2c2018-05-18 16:31:43 +0100168 std::map<std::string, boost::multi_array<float, NumOutputDimensions>> outputStorage;
169 armnn::OutputTensors outputTensors;
170 for (auto&& it : expectedOutputData)
171 {
172 BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
173 outputStorage.emplace(it.first, MakeTensor<float, NumOutputDimensions>(bindingInfo.second));
174 outputTensors.push_back(
175 { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
176 }
telsoa014fcda012018-03-09 14:13:49 +0000177
telsoa01c577f2c2018-08-31 09:22:23 +0100178 runtime.first->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
surmeh013537c2c2018-05-18 16:31:43 +0100179
telsoa01c577f2c2018-08-31 09:22:23 +0100180 // Compares each output tensor to the expected values.
surmeh013537c2c2018-05-18 16:31:43 +0100181 for (auto&& it : expectedOutputData)
182 {
183 BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
telsoa01c577f2c2018-08-31 09:22:23 +0100184 if (bindingInfo.second.GetNumElements() != it.second.size())
185 {
186 throw armnn::Exception(
187 boost::str(
188 boost::format("Output tensor %1% is expected to have %2% elements. "
189 "%3% elements supplied. %4%") %
190 it.first %
191 bindingInfo.second.GetNumElements() %
192 it.second.size() %
193 CHECK_LOCATION().AsString()));
194 }
surmeh013537c2c2018-05-18 16:31:43 +0100195 auto outputExpected = MakeTensor<float, NumOutputDimensions>(bindingInfo.second, it.second);
196 BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
197 }
telsoa014fcda012018-03-09 14:13:49 +0000198 }
199}
telsoa01c577f2c2018-08-31 09:22:23 +0100200
201} // namespace armnnUtils