//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "armnn/IRuntime.hpp"
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "test/TensorHelpers.hpp"
#include "VerificationHelpers.hpp"

#include <boost/format.hpp>

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace armnnUtils
{

template<typename TParser>
struct ParserPrototxtFixture
{
    ParserPrototxtFixture()
        : m_Parser(TParser::Create())
        , m_NetworkIdentifier(-1)
    {
        armnn::IRuntime::CreationOptions options;
        m_Runtimes.push_back(std::make_pair(armnn::IRuntime::Create(options), armnn::Compute::CpuRef));

#if ARMCOMPUTENEON_ENABLED
        m_Runtimes.push_back(std::make_pair(armnn::IRuntime::Create(options), armnn::Compute::CpuAcc));
#endif

#if ARMCOMPUTECL_ENABLED
        m_Runtimes.push_back(std::make_pair(armnn::IRuntime::Create(options), armnn::Compute::GpuAcc));
#endif
    }

    /// Parses and loads the network defined by the m_Prototext string.
    /// @{
    void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName);
    void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
                                      const std::string& inputName,
                                      const std::string& outputName);
    void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
               const std::vector<std::string>& requestedOutputs);
    void Setup();
    /// @}

    /// Executes the network with the given input tensor and checks the result against the given output tensor.
    /// This overload assumes that the network has a single input and a single output.
    template <std::size_t NumOutputDimensions>
    void RunTest(const std::vector<float>& inputData, const std::vector<float>& expectedOutputData);

    /// Executes the network with the given input tensors and checks the results against the given output tensors.
    /// This overload supports multiple inputs and multiple outputs, identified by name.
    template <std::size_t NumOutputDimensions>
    void RunTest(const std::map<std::string, std::vector<float>>& inputData,
                 const std::map<std::string, std::vector<float>>& expectedOutputData);

    std::string m_Prototext;
    std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser;
    std::vector<std::pair<armnn::IRuntimePtr, armnn::Compute>> m_Runtimes;
    armnn::NetworkId m_NetworkIdentifier;

    /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
    /// so they don't need to be passed to the single-input-single-output overload of RunTest().
    /// @{
    std::string m_SingleInputName;
    std::string m_SingleOutputName;
    /// @}
};
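
// Example usage (a minimal sketch, not part of the fixture itself): a parser test typically derives from
// ParserPrototxtFixture, fills in m_Prototext with the model text, calls one of the Setup overloads, and then
// calls RunTest with input values and the expected output values. The fixture name, model text and tensor
// values below are hypothetical placeholders, shown only to illustrate the intended call sequence.
//
//     struct HypotheticalDoubleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
//     {
//         HypotheticalDoubleFixture()
//         {
//             m_Prototext = "...";  // Model description in the parser's text format (elided here).
//             SetupSingleInputSingleOutput("input", "output");
//         }
//     };
//
//     BOOST_FIXTURE_TEST_CASE(HypotheticalDouble, HypotheticalDoubleFixture)
//     {
//         // One-dimensional output: the model is assumed to double each input element.
//         RunTest<1>({ 1.0f, 2.0f, 3.0f }, { 2.0f, 4.0f, 6.0f });
//     }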

template<typename TParser>
void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const std::string& inputName,
                                                                  const std::string& outputName)
{
    // Stores the input and output name so they don't need to be passed to the single-input-single-output RunTest().
    m_SingleInputName = inputName;
    m_SingleOutputName = outputName;
    Setup({ }, { outputName });
}

template<typename TParser>
void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
                                                                  const std::string& inputName,
                                                                  const std::string& outputName)
{
    // Stores the input and output name so they don't need to be passed to the single-input-single-output RunTest().
    m_SingleInputName = inputName;
    m_SingleOutputName = outputName;
    Setup({ { inputName, inputTensorShape } }, { outputName });
}

template<typename TParser>
void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
                                           const std::vector<std::string>& requestedOutputs)
{
    for (auto&& runtime : m_Runtimes)
    {
        std::string errorMessage;

        armnn::INetworkPtr network =
            m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes, requestedOutputs);
        auto optimized = Optimize(*network, { runtime.second, armnn::Compute::CpuRef }, runtime.first->GetDeviceSpec());
        armnn::Status ret = runtime.first->LoadNetwork(m_NetworkIdentifier, std::move(optimized), errorMessage);
        if (ret != armnn::Status::Success)
        {
            throw armnn::Exception(boost::str(
                boost::format("LoadNetwork failed with error: '%1%' %2%")
                % errorMessage
                % CHECK_LOCATION().AsString()));
        }
    }
}

template<typename TParser>
void ParserPrototxtFixture<TParser>::Setup()
{
    for (auto&& runtime : m_Runtimes)
    {
        std::string errorMessage;

        armnn::INetworkPtr network =
            m_Parser->CreateNetworkFromString(m_Prototext.c_str());
        auto optimized = Optimize(*network, { runtime.second, armnn::Compute::CpuRef }, runtime.first->GetDeviceSpec());
        armnn::Status ret = runtime.first->LoadNetwork(m_NetworkIdentifier, std::move(optimized), errorMessage);
        if (ret != armnn::Status::Success)
        {
            throw armnn::Exception(boost::str(
                boost::format("LoadNetwork failed with error: '%1%' %2%")
                % errorMessage
                % CHECK_LOCATION().AsString()));
        }
    }
}

template<typename TParser>
template <std::size_t NumOutputDimensions>
void ParserPrototxtFixture<TParser>::RunTest(const std::vector<float>& inputData,
                                             const std::vector<float>& expectedOutputData)
{
    RunTest<NumOutputDimensions>({ { m_SingleInputName, inputData } }, { { m_SingleOutputName, expectedOutputData } });
}

template<typename TParser>
template <std::size_t NumOutputDimensions>
void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::vector<float>>& inputData,
                                             const std::map<std::string, std::vector<float>>& expectedOutputData)
{
    for (auto&& runtime : m_Runtimes)
    {
        using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;

        // Sets up the armnn input tensors from the given vectors.
        armnn::InputTensors inputTensors;
        for (auto&& it : inputData)
        {
            BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(it.first);
            inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
        }

        // Allocates storage for the output tensors to be written to and sets up the armnn output tensors.
        std::map<std::string, boost::multi_array<float, NumOutputDimensions>> outputStorage;
        armnn::OutputTensors outputTensors;
        for (auto&& it : expectedOutputData)
        {
            BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
            outputStorage.emplace(it.first, MakeTensor<float, NumOutputDimensions>(bindingInfo.second));
            outputTensors.push_back(
                { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
        }

        runtime.first->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);

        // Compares each output tensor to the expected values.
        for (auto&& it : expectedOutputData)
        {
            BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
            if (bindingInfo.second.GetNumElements() != it.second.size())
            {
                throw armnn::Exception(
                    boost::str(
                        boost::format("Output tensor %1% is expected to have %2% elements. "
                                      "%3% elements supplied. %4%") %
                        it.first %
                        bindingInfo.second.GetNumElements() %
                        it.second.size() %
                        CHECK_LOCATION().AsString()));
            }
            auto outputExpected = MakeTensor<float, NumOutputDimensions>(bindingInfo.second, it.second);
            BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
        }
    }
}

} // namespace armnnUtils