blob: 669b1fd0cab54f378dceb77eed7b22efbe6432dc [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include <armnn/IRuntime.hpp>
#include <test/TensorHelpers.hpp>

#include <armnnOnnxParser/IOnnxParser.hpp>

#include <VerificationHelpers.hpp>

#include <backendsCommon/BackendRegistry.hpp>

#include <boost/format.hpp>

#include <string>
#include <utility>
20
telsoa01c577f2c2018-08-31 09:22:23 +010021namespace armnnUtils
22{
surmeh013537c2c2018-05-18 16:31:43 +010023
telsoa014fcda012018-03-09 14:13:49 +000024template<typename TParser>
25struct ParserPrototxtFixture
26{
27 ParserPrototxtFixture()
28 : m_Parser(TParser::Create())
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +000029 , m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions()))
telsoa014fcda012018-03-09 14:13:49 +000030 , m_NetworkIdentifier(-1)
surmeh013537c2c2018-05-18 16:31:43 +010031 {
surmeh013537c2c2018-05-18 16:31:43 +010032 }
telsoa014fcda012018-03-09 14:13:49 +000033
34 /// Parses and loads the network defined by the m_Prototext string.
35 /// @{
36 void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName);
37 void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
38 const std::string& inputName,
39 const std::string& outputName);
40 void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
41 const std::vector<std::string>& requestedOutputs);
telsoa01c577f2c2018-08-31 09:22:23 +010042 void Setup();
telsoa014fcda012018-03-09 14:13:49 +000043 /// @}
44
45 /// Executes the network with the given input tensor and checks the result against the given output tensor.
telsoa01c577f2c2018-08-31 09:22:23 +010046 /// This overload assumes that the network has a single input and a single output.
telsoa014fcda012018-03-09 14:13:49 +000047 template <std::size_t NumOutputDimensions>
48 void RunTest(const std::vector<float>& inputData, const std::vector<float>& expectedOutputData);
49
50 /// Executes the network with the given input tensors and checks the results against the given output tensors.
51 /// This overload supports multiple inputs and multiple outputs, identified by name.
52 template <std::size_t NumOutputDimensions>
53 void RunTest(const std::map<std::string, std::vector<float>>& inputData,
54 const std::map<std::string, std::vector<float>>& expectedOutputData);
55
surmeh013537c2c2018-05-18 16:31:43 +010056 std::string m_Prototext;
57 std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser;
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +000058 armnn::IRuntimePtr m_Runtime;
surmeh013537c2c2018-05-18 16:31:43 +010059 armnn::NetworkId m_NetworkIdentifier;
telsoa014fcda012018-03-09 14:13:49 +000060
61 /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
62 /// so they don't need to be passed to the single-input-single-output overload of RunTest().
63 /// @{
64 std::string m_SingleInputName;
65 std::string m_SingleOutputName;
66 /// @}
67};
68
69template<typename TParser>
70void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const std::string& inputName,
71 const std::string& outputName)
72{
telsoa01c577f2c2018-08-31 09:22:23 +010073 // Stores the input and output name so they don't need to be passed to the single-input-single-output RunTest().
telsoa014fcda012018-03-09 14:13:49 +000074 m_SingleInputName = inputName;
75 m_SingleOutputName = outputName;
76 Setup({ }, { outputName });
77}
78
79template<typename TParser>
80void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
81 const std::string& inputName,
82 const std::string& outputName)
83{
telsoa01c577f2c2018-08-31 09:22:23 +010084 // Stores the input and output name so they don't need to be passed to the single-input-single-output RunTest().
telsoa014fcda012018-03-09 14:13:49 +000085 m_SingleInputName = inputName;
86 m_SingleOutputName = outputName;
87 Setup({ { inputName, inputTensorShape } }, { outputName });
88}
89
90template<typename TParser>
91void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
92 const std::vector<std::string>& requestedOutputs)
93{
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +000094 std::string errorMessage;
telsoa01c577f2c2018-08-31 09:22:23 +010095
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +000096 armnn::INetworkPtr network =
97 m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes, requestedOutputs);
98 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
99 armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage);
100 if (ret != armnn::Status::Success)
101 {
102 throw armnn::Exception(boost::str(
103 boost::format("LoadNetwork failed with error: '%1%' %2%")
104 % errorMessage
105 % CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100106 }
107}
108
109template<typename TParser>
110void ParserPrototxtFixture<TParser>::Setup()
111{
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000112 std::string errorMessage;
telsoa01c577f2c2018-08-31 09:22:23 +0100113
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000114 armnn::INetworkPtr network =
115 m_Parser->CreateNetworkFromString(m_Prototext.c_str());
116 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
117 armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage);
118 if (ret != armnn::Status::Success)
119 {
120 throw armnn::Exception(boost::str(
121 boost::format("LoadNetwork failed with error: '%1%' %2%")
122 % errorMessage
123 % CHECK_LOCATION().AsString()));
telsoa014fcda012018-03-09 14:13:49 +0000124 }
125}
126
127template<typename TParser>
128template <std::size_t NumOutputDimensions>
129void ParserPrototxtFixture<TParser>::RunTest(const std::vector<float>& inputData,
130 const std::vector<float>& expectedOutputData)
131{
132 RunTest<NumOutputDimensions>({ { m_SingleInputName, inputData } }, { { m_SingleOutputName, expectedOutputData } });
133}
134
135template<typename TParser>
136template <std::size_t NumOutputDimensions>
137void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::vector<float>>& inputData,
138 const std::map<std::string, std::vector<float>>& expectedOutputData)
139{
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000140 using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
141
142 // Sets up the armnn input tensors from the given vectors.
143 armnn::InputTensors inputTensors;
144 for (auto&& it : inputData)
telsoa014fcda012018-03-09 14:13:49 +0000145 {
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000146 BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(it.first);
147 inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
148 }
telsoa014fcda012018-03-09 14:13:49 +0000149
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000150 // Allocates storage for the output tensors to be written to and sets up the armnn output tensors.
151 std::map<std::string, boost::multi_array<float, NumOutputDimensions>> outputStorage;
152 armnn::OutputTensors outputTensors;
153 for (auto&& it : expectedOutputData)
154 {
155 BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
156 outputStorage.emplace(it.first, MakeTensor<float, NumOutputDimensions>(bindingInfo.second));
157 outputTensors.push_back(
158 { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
159 }
160
161 m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
162
163 // Compares each output tensor to the expected values.
164 for (auto&& it : expectedOutputData)
165 {
166 BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
167 if (bindingInfo.second.GetNumElements() != it.second.size())
surmeh013537c2c2018-05-18 16:31:43 +0100168 {
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000169 throw armnn::Exception(
170 boost::str(
171 boost::format("Output tensor %1% is expected to have %2% elements. "
172 "%3% elements supplied. %4%") %
173 it.first %
174 bindingInfo.second.GetNumElements() %
175 it.second.size() %
176 CHECK_LOCATION().AsString()));
surmeh013537c2c2018-05-18 16:31:43 +0100177 }
Aron Virginas-Tar1d67a6902018-11-19 10:58:30 +0000178 auto outputExpected = MakeTensor<float, NumOutputDimensions>(bindingInfo.second, it.second);
179 BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
telsoa014fcda012018-03-09 14:13:49 +0000180 }
181}
telsoa01c577f2c2018-08-31 09:22:23 +0100182
183} // namespace armnnUtils