blob: 9b86784dce0991da460fe20df73a5decab48e805 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
David Monahan005288d2019-05-14 10:42:38 +01006
Sadik Armagana097d2a2021-11-24 15:47:28 +00007#include <CommonTestUtils.hpp>
David Monahan005288d2019-05-14 10:42:38 +01008
Jan Eilersbb446e52020-04-02 13:56:54 +01009#include <Graph.hpp>
10#include <Network.hpp>
Jan Eilersbb446e52020-04-02 13:56:54 +010011#include <SubgraphViewSelector.hpp>
12
Francis Murtagha49ff082022-01-17 17:08:01 +000013#include <armnn/backends/OptimizationViews.hpp>
14#include <armnn/backends/SubgraphView.hpp>
15#include <armnn/utility/PolymorphicDowncast.hpp>
Cathal Corbett3464ba12022-03-04 11:36:39 +000016#include <armnnTestUtils/MockBackend.hpp>
Francis Murtagha49ff082022-01-17 17:08:01 +000017
Sadik Armagan1625efc2021-06-10 18:24:34 +010018#include <doctest/doctest.h>
Jan Eilersbb446e52020-04-02 13:56:54 +010019
David Monahan005288d2019-05-14 10:42:38 +010020using namespace armnn;
21
David Monahan41f00f12019-05-27 09:44:52 +010022void CheckLayers(Graph& graph)
23{
24 unsigned int m_inputLayerCount = 0, m_outputLayerCount = 0, m_addLayerCount = 0;
25 for(auto layer : graph)
26 {
27 switch(layer->GetType())
28 {
29 case LayerType::Input:
30 ++m_inputLayerCount;
Sadik Armagan1625efc2021-06-10 18:24:34 +010031 CHECK((layer->GetName() == std::string("inLayer0") ||
Narumol Prangnawarat60a20fb2019-12-09 17:24:41 +000032 layer->GetName() == std::string("inLayer1")));
David Monahan41f00f12019-05-27 09:44:52 +010033 break;
34 // The Addition layer should become a PreCompiled Layer after Optimisation
35 case LayerType::PreCompiled:
36 ++m_addLayerCount;
Sadik Armagan1625efc2021-06-10 18:24:34 +010037 CHECK(std::string(layer->GetName()) == "pre-compiled");
David Monahan41f00f12019-05-27 09:44:52 +010038 break;
39 case LayerType::Output:
40 ++m_outputLayerCount;
Sadik Armagan1625efc2021-06-10 18:24:34 +010041 CHECK(std::string(layer->GetName()) == "outLayer");
David Monahan41f00f12019-05-27 09:44:52 +010042 break;
43 default:
44 //Fail for anything else
Sadik Armagan1625efc2021-06-10 18:24:34 +010045 CHECK(false);
David Monahan41f00f12019-05-27 09:44:52 +010046 }
47 }
Sadik Armagan1625efc2021-06-10 18:24:34 +010048 CHECK(m_inputLayerCount == 2);
49 CHECK(m_outputLayerCount == 1);
50 CHECK(m_addLayerCount == 1);
David Monahan41f00f12019-05-27 09:44:52 +010051}
52
Sadik Armagan1625efc2021-06-10 18:24:34 +010053TEST_SUITE("OptimizationViewsTestSuite")
54{
55TEST_CASE("OptimizedViewsSubgraphLayerCount")
David Monahan005288d2019-05-14 10:42:38 +010056{
57 OptimizationViews view;
58 // Construct a graph with 3 layers
Cathal Corbettcbfd7182021-12-15 17:12:59 +000059 Graph baseGraph;
David Monahan005288d2019-05-14 10:42:38 +010060
61 Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
62
63 Convolution2dDescriptor convDescriptor;
Keith Davisb4dd5cc2022-04-07 11:32:00 +010064 PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
David Monahan005288d2019-05-14 10:42:38 +010065 Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
66 Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
Keith Davisb4dd5cc2022-04-07 11:32:00 +010067 Layer* const weightsLayer1 = baseGraph.AddLayer<ConstantLayer>("weights1");
68 Layer* const weightsLayer2 = baseGraph.AddLayer<ConstantLayer>("weights2");
David Monahan005288d2019-05-14 10:42:38 +010069 Layer* const substitutableCompiledLayer =
70 baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
71
72 Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
73
74 inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
Keith Davisb4dd5cc2022-04-07 11:32:00 +010075 weightsLayer1->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(1));
David Monahan005288d2019-05-14 10:42:38 +010076 convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
Keith Davisb4dd5cc2022-04-07 11:32:00 +010077 weightsLayer2->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(1));
David Monahan005288d2019-05-14 10:42:38 +010078 convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
79
80 // Subgraph for a failed layer
81 SubgraphViewSelector::SubgraphViewPtr failedSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +010082 CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
David Monahan005288d2019-05-14 10:42:38 +010083 CreateOutputsFrom({convLayer1}),
84 {convLayer1});
85 // Subgraph for an untouched layer
86 SubgraphViewSelector::SubgraphViewPtr untouchedSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +010087 CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
David Monahan005288d2019-05-14 10:42:38 +010088 CreateOutputsFrom({convLayer2}),
89 {convLayer2});
90 // Subgraph for a substitutable layer
91 SubgraphViewSelector::SubgraphViewPtr substitutableSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +010092 CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
David Monahan005288d2019-05-14 10:42:38 +010093 CreateOutputsFrom({convLayer2}),
94 {substitutableCompiledLayer});
95 // Create a Graph containing a layer to substitute in
96 Graph substitutableGraph;
97 Layer* const substitutionpreCompiledLayer =
98 substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
99
100 // Subgraph for a substitution layer
101 SubgraphViewSelector::SubgraphViewPtr substitutionSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100102 CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
David Monahan005288d2019-05-14 10:42:38 +0100103 CreateOutputsFrom({substitutionpreCompiledLayer}),
104 {substitutionpreCompiledLayer});
105
106 // Sub in the graph
107 baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
108
109 view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
110 view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
111
112 SubgraphViewSelector::SubgraphViewPtr baseSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100113 CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
David Monahan005288d2019-05-14 10:42:38 +0100114 CreateOutputsFrom({convLayer2}),
115 {substitutionpreCompiledLayer});
116 view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
117
118 // Construct original subgraph to compare against
119 SubgraphViewSelector::SubgraphViewPtr originalSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100120 CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
David Monahan005288d2019-05-14 10:42:38 +0100121 CreateOutputsFrom({convLayer2}),
122 {convLayer1, convLayer2, substitutionpreCompiledLayer});
123
Sadik Armagan1625efc2021-06-10 18:24:34 +0100124 CHECK(view.Validate(*originalSubgraph));
David Monahan005288d2019-05-14 10:42:38 +0100125}
126
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000127
128TEST_CASE("OptimizedViewsSubgraphLayerCountUsingGetINetwork")
129{
130 OptimizationViews view;
131
132 IConnectableLayer* const inputLayer = view.GetINetwork()->AddInputLayer(0, "input");
133
134 DepthwiseConvolution2dDescriptor convDescriptor;
Cathal Corbett06902652022-04-14 17:55:11 +0100135 PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000136 CompiledBlobPtr blobPtr;
137 BackendId backend = Compute::CpuRef;
138
139 Layer* convLayer1 = PolymorphicDowncast<Layer*>(
140 view.GetINetwork()->AddDepthwiseConvolution2dLayer(convDescriptor,
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000141 "conv1"));
142
143 Layer* convLayer2 = PolymorphicDowncast<Layer*>(
144 view.GetINetwork()->AddDepthwiseConvolution2dLayer(convDescriptor,
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000145 "conv2"));
146
147 IConnectableLayer* const outputLayer = view.GetINetwork()->AddOutputLayer(0, "output");
148
149 inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
150 convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
151 convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
152
153 // Subgraph for a failed layer
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100154 SubgraphViewSelector::SubgraphViewPtr failedSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000155 CreateOutputsFrom({convLayer1}),
156 {convLayer1});
157 // Subgraph for an untouched layer
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100158 SubgraphViewSelector::SubgraphViewPtr untouchedSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000159 CreateOutputsFrom({convLayer2}),
160 {convLayer2});
161
162 // Create a Network containing a layer to substitute in
163 NetworkImpl net;
164 Layer* substitutionpreCompiledLayer = PolymorphicDowncast<Layer*>(
Cathal Corbett3ea01072022-01-06 10:29:43 +0000165 net.AddPrecompiledLayer(substitutionLayerDescriptor, std::move(blobPtr), backend));
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000166
167 // Subgraph for a substitution layer
168 SubgraphViewSelector::SubgraphViewPtr substitutionSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100169 CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000170 CreateOutputsFrom({substitutionpreCompiledLayer}),
171 {substitutionpreCompiledLayer});
172
173 view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
174 view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
175
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100176 SubgraphViewSelector::SubgraphViewPtr baseSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000177 CreateOutputsFrom({convLayer2}),
178 {substitutionpreCompiledLayer});
179 view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
180
181 // Construct original subgraph to compare against
182 SubgraphViewSelector::SubgraphViewPtr originalSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100183 CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000184 CreateOutputsFrom({convLayer2}),
185 {convLayer1, convLayer2, substitutionpreCompiledLayer});
186
187 CHECK(view.Validate(*originalSubgraph));
188}
189
Sadik Armagan1625efc2021-06-10 18:24:34 +0100190TEST_CASE("OptimizedViewsSubgraphLayerCountFailValidate")
David Monahan005288d2019-05-14 10:42:38 +0100191{
192 OptimizationViews view;
193 // Construct a graph with 3 layers
Cathal Corbettcbfd7182021-12-15 17:12:59 +0000194 Graph baseGraph;
David Monahan005288d2019-05-14 10:42:38 +0100195
196 Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
197
198 Convolution2dDescriptor convDescriptor;
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100199 PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
David Monahan005288d2019-05-14 10:42:38 +0100200 Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
201 Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100202 Layer* const weightsLayer1 = baseGraph.AddLayer<ConstantLayer>("weights1");
203 Layer* const weightsLayer2 = baseGraph.AddLayer<ConstantLayer>("weights2");
David Monahan005288d2019-05-14 10:42:38 +0100204 Layer* const substitutableCompiledLayer =
205 baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
206
207 Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
208
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100209
David Monahan005288d2019-05-14 10:42:38 +0100210 inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100211 weightsLayer1->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(1));
David Monahan005288d2019-05-14 10:42:38 +0100212 convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100213 weightsLayer2->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(1));
David Monahan005288d2019-05-14 10:42:38 +0100214 convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
215
216 // Subgraph for an untouched layer
217 SubgraphViewSelector::SubgraphViewPtr untouchedSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100218 CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
David Monahan005288d2019-05-14 10:42:38 +0100219 CreateOutputsFrom({convLayer2}),
220 {convLayer2});
221 // Subgraph for a substitutable layer
222 SubgraphViewSelector::SubgraphViewPtr substitutableSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100223 CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
David Monahan005288d2019-05-14 10:42:38 +0100224 CreateOutputsFrom({convLayer2}),
225 {substitutableCompiledLayer});
226 // Create a Graph containing a layer to substitute in
227 Graph substitutableGraph;
228 Layer* const substitutionpreCompiledLayer =
229 substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
230
231 // Subgraph for a substitution layer
232 SubgraphViewSelector::SubgraphViewPtr substitutionSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100233 CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
David Monahan005288d2019-05-14 10:42:38 +0100234 CreateOutputsFrom({substitutionpreCompiledLayer}),
235 {substitutionpreCompiledLayer});
236
237 // Sub in the graph
238 baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
239
240 view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
241
242 SubgraphViewSelector::SubgraphViewPtr baseSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100243 CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
David Monahan005288d2019-05-14 10:42:38 +0100244 CreateOutputsFrom({convLayer2}),
245 {substitutionpreCompiledLayer});
246 view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
247
248 // Construct original subgraph to compare against
249 SubgraphViewSelector::SubgraphViewPtr originalSubgraph =
Keith Davisb4dd5cc2022-04-07 11:32:00 +0100250 CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
David Monahan005288d2019-05-14 10:42:38 +0100251 CreateOutputsFrom({convLayer2}),
252 {convLayer1, convLayer2, substitutionpreCompiledLayer});
253
254 // Validate should fail as convLayer1 is not counted
Sadik Armagan1625efc2021-06-10 18:24:34 +0100255 CHECK(!view.Validate(*originalSubgraph));
David Monahan005288d2019-05-14 10:42:38 +0100256}
257
Sadik Armagan1625efc2021-06-10 18:24:34 +0100258TEST_CASE("OptimizeViewsValidateDeviceMockBackend")
David Monahan41f00f12019-05-27 09:44:52 +0100259{
260 // build up the structure of the network
261 armnn::INetworkPtr net(armnn::INetwork::Create());
262
263 armnn::IConnectableLayer* input = net->AddInputLayer(0, "inLayer0");
264 armnn::IConnectableLayer* input1 = net->AddInputLayer(1, "inLayer1");
265
266 armnn::IConnectableLayer* addition = net->AddAdditionLayer("addLayer");
267
268 armnn::IConnectableLayer* output = net->AddOutputLayer(0, "outLayer");
269
270 input->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
271 input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
272 addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));
273
274 input->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
275 input1->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
276 addition->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
277
David Monahanc1536d62020-02-12 15:52:35 +0000278 armnn::MockBackendInitialiser initialiser;
David Monahan41f00f12019-05-27 09:44:52 +0100279 armnn::IRuntime::CreationOptions options;
280 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
281
282 std::vector<armnn::BackendId> backends = { MockBackend().GetIdStatic() };
283 armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100284 CHECK(optNet);
David Monahan41f00f12019-05-27 09:44:52 +0100285
286 // Check the optimised graph
Francis Murtagh3d2b4b22021-02-15 18:23:17 +0000287 armnn::Graph& graph = GetGraphForTesting(optNet.get());
288 CheckLayers(graph);
David Monahan41f00f12019-05-27 09:44:52 +0100289}
290
Sadik Armagan1625efc2021-06-10 18:24:34 +0100291}