//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#pragma once

#include "armnn/IRuntime.hpp"
#include "test/TensorHelpers.hpp"

#include <map>
#include <memory>
#include <string>
#include <vector>

// TODO davbec01 (14/05/18) : put these into armnnUtils namespace

template<typename TParser>
struct ParserPrototxtFixture
{
    ParserPrototxtFixture()
        : m_Parser(TParser::Create())
        , m_NetworkIdentifier(-1)
    {
        m_Runtimes.push_back(armnn::IRuntime::Create(armnn::Compute::CpuRef));

#if ARMCOMPUTENEON_ENABLED
        m_Runtimes.push_back(armnn::IRuntime::Create(armnn::Compute::CpuAcc));
#endif

#if ARMCOMPUTECL_ENABLED
        m_Runtimes.push_back(armnn::IRuntime::Create(armnn::Compute::GpuAcc));
#endif
    }

    /// Parses and loads the network defined by the m_Prototext string.
    /// @{
    void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName);
    void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
                                      const std::string& inputName,
                                      const std::string& outputName);
    void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
               const std::vector<std::string>& requestedOutputs);
    /// @}

    /// Executes the network with the given input tensor and checks the result against the given output tensor.
    /// This overload assumes the network has a single input and a single output.
    template <std::size_t NumOutputDimensions>
    void RunTest(const std::vector<float>& inputData, const std::vector<float>& expectedOutputData);

    /// Executes the network with the given input tensors and checks the results against the given output tensors.
    /// This overload supports multiple inputs and multiple outputs, identified by name.
    template <std::size_t NumOutputDimensions>
    void RunTest(const std::map<std::string, std::vector<float>>& inputData,
                 const std::map<std::string, std::vector<float>>& expectedOutputData);

    std::string m_Prototext;
    std::unique_ptr<TParser, void(*)(TParser* parser)> m_Parser;
    std::vector<armnn::IRuntimePtr> m_Runtimes;
    armnn::NetworkId m_NetworkIdentifier;

    /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
    /// so they don't need to be passed to the single-input-single-output overload of RunTest().
    /// @{
    std::string m_SingleInputName;
    std::string m_SingleOutputName;
    /// @}
};
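// A minimal usage sketch (illustrative only, not part of this fixture): a test fixture derives from
// ParserPrototxtFixture with a concrete parser type, fills in m_Prototext, and calls one of the Setup
// overloads before RunTest(). The parser type, tensor names, prototxt contents and data values below
// are hypothetical placeholders.
//
//     struct MyLayerFixture : public ParserPrototxtFixture<armnnSomeParser::IMyParser>
//     {
//         MyLayerFixture()
//         {
//             m_Prototext = "... network description in prototxt form ...";
//             SetupSingleInputSingleOutput("input", "output");
//         }
//     };
//
//     BOOST_FIXTURE_TEST_CASE(ParseMyLayer, MyLayerFixture)
//     {
//         // Runs the network on every runtime in m_Runtimes and compares the result
//         // against the expected values with CompareTensors().
//         RunTest<1>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 1.0f, 2.0f, 3.0f, 4.0f });
//     }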

template<typename TParser>
void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const std::string& inputName,
                                                                  const std::string& outputName)
{
    // Store the input and output name so they don't need to be passed to the single-input-single-output RunTest().
    m_SingleInputName = inputName;
    m_SingleOutputName = outputName;
    Setup({ }, { outputName });
}

template<typename TParser>
void ParserPrototxtFixture<TParser>::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape,
                                                                  const std::string& inputName,
                                                                  const std::string& outputName)
{
    // Store the input and output name so they don't need to be passed to the single-input-single-output RunTest().
    m_SingleInputName = inputName;
    m_SingleOutputName = outputName;
    Setup({ { inputName, inputTensorShape } }, { outputName });
}

template<typename TParser>
void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
                                           const std::vector<std::string>& requestedOutputs)
{
    for (auto&& runtime : m_Runtimes)
    {
        armnn::INetworkPtr network =
            m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes, requestedOutputs);

        auto optimized = Optimize(*network, runtime->GetDeviceSpec());

        armnn::Status ret = runtime->LoadNetwork(m_NetworkIdentifier, std::move(optimized));

        if (ret != armnn::Status::Success)
        {
            throw armnn::Exception("LoadNetwork failed");
        }
    }
}

template<typename TParser>
template <std::size_t NumOutputDimensions>
void ParserPrototxtFixture<TParser>::RunTest(const std::vector<float>& inputData,
                                             const std::vector<float>& expectedOutputData)
{
    RunTest<NumOutputDimensions>({ { m_SingleInputName, inputData } }, { { m_SingleOutputName, expectedOutputData } });
}

template<typename TParser>
template <std::size_t NumOutputDimensions>
void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::vector<float>>& inputData,
                                             const std::map<std::string, std::vector<float>>& expectedOutputData)
{
    for (auto&& runtime : m_Runtimes)
    {
        using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;

        // Set up the armnn input tensors from the given vectors.
        armnn::InputTensors inputTensors;
        for (auto&& it : inputData)
        {
            BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(it.first);
            inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
        }

        // Allocate storage for the output tensors to be written to and set up the armnn output tensors.
        std::map<std::string, boost::multi_array<float, NumOutputDimensions>> outputStorage;
        armnn::OutputTensors outputTensors;
        for (auto&& it : expectedOutputData)
        {
            BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
            outputStorage.emplace(it.first, MakeTensor<float, NumOutputDimensions>(bindingInfo.second));
            outputTensors.push_back(
                { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
        }

        runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);

        // Compare each output tensor to the expected values.
        for (auto&& it : expectedOutputData)
        {
            BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(it.first);
            auto outputExpected = MakeTensor<float, NumOutputDimensions>(bindingInfo.second, it.second);
            BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first]));
        }
    }
}
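
// Illustrative sketch of the named-tensor overload above (tensor names, shapes and values are
// hypothetical placeholders, not part of this file): networks with several inputs or outputs are
// driven through Setup() and the map-based RunTest() overload, keyed by tensor name.
//
//     Setup({ { "inputA", armnn::TensorShape({ 1, 2 }) },
//             { "inputB", armnn::TensorShape({ 1, 2 }) } },
//           { "output" });
//
//     RunTest<2>({ { "inputA", { 1.0f, 2.0f } }, { "inputB", { 3.0f, 4.0f } } },
//                { { "output", { 4.0f, 6.0f } } });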