blob: 246cb509c3f73f48d9745a1ce8d1139b411b9134 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
David Monahan005288d2019-05-14 10:42:38 +01006
7#include "CommonTestUtils.hpp"
David Monahan41f00f12019-05-27 09:44:52 +01008#include "MockBackend.hpp"
David Monahan005288d2019-05-14 10:42:38 +01009
Jan Eilersbb446e52020-04-02 13:56:54 +010010#include <armnn/backends/OptimizationViews.hpp>
11#include <armnn/utility/PolymorphicDowncast.hpp>
12#include <Graph.hpp>
13#include <Network.hpp>
14#include <SubgraphView.hpp>
15#include <SubgraphViewSelector.hpp>
16
Sadik Armagan1625efc2021-06-10 18:24:34 +010017#include <doctest/doctest.h>
Jan Eilersbb446e52020-04-02 13:56:54 +010018
David Monahan005288d2019-05-14 10:42:38 +010019using namespace armnn;
20
David Monahan41f00f12019-05-27 09:44:52 +010021void CheckLayers(Graph& graph)
22{
23 unsigned int m_inputLayerCount = 0, m_outputLayerCount = 0, m_addLayerCount = 0;
24 for(auto layer : graph)
25 {
26 switch(layer->GetType())
27 {
28 case LayerType::Input:
29 ++m_inputLayerCount;
Sadik Armagan1625efc2021-06-10 18:24:34 +010030 CHECK((layer->GetName() == std::string("inLayer0") ||
Narumol Prangnawarat60a20fb2019-12-09 17:24:41 +000031 layer->GetName() == std::string("inLayer1")));
David Monahan41f00f12019-05-27 09:44:52 +010032 break;
33 // The Addition layer should become a PreCompiled Layer after Optimisation
34 case LayerType::PreCompiled:
35 ++m_addLayerCount;
Sadik Armagan1625efc2021-06-10 18:24:34 +010036 CHECK(std::string(layer->GetName()) == "pre-compiled");
David Monahan41f00f12019-05-27 09:44:52 +010037 break;
38 case LayerType::Output:
39 ++m_outputLayerCount;
Sadik Armagan1625efc2021-06-10 18:24:34 +010040 CHECK(std::string(layer->GetName()) == "outLayer");
David Monahan41f00f12019-05-27 09:44:52 +010041 break;
42 default:
43 //Fail for anything else
Sadik Armagan1625efc2021-06-10 18:24:34 +010044 CHECK(false);
David Monahan41f00f12019-05-27 09:44:52 +010045 }
46 }
Sadik Armagan1625efc2021-06-10 18:24:34 +010047 CHECK(m_inputLayerCount == 2);
48 CHECK(m_outputLayerCount == 1);
49 CHECK(m_addLayerCount == 1);
David Monahan41f00f12019-05-27 09:44:52 +010050}
51
Sadik Armagan1625efc2021-06-10 18:24:34 +010052TEST_SUITE("OptimizationViewsTestSuite")
53{
54TEST_CASE("OptimizedViewsSubgraphLayerCount")
David Monahan005288d2019-05-14 10:42:38 +010055{
56 OptimizationViews view;
57 // Construct a graph with 3 layers
58 Graph& baseGraph = view.GetGraph();
59
60 Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
61
62 Convolution2dDescriptor convDescriptor;
63 PreCompiledDescriptor substitutionLayerDescriptor(1, 1);
64 Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
65 Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
66 Layer* const substitutableCompiledLayer =
67 baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
68
69 Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
70
71 inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
72 convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
73 convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
74
75 // Subgraph for a failed layer
76 SubgraphViewSelector::SubgraphViewPtr failedSubgraph =
77 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
78 CreateOutputsFrom({convLayer1}),
79 {convLayer1});
80 // Subgraph for an untouched layer
81 SubgraphViewSelector::SubgraphViewPtr untouchedSubgraph =
82 CreateSubgraphViewFrom(CreateInputsFrom({convLayer2}),
83 CreateOutputsFrom({convLayer2}),
84 {convLayer2});
85 // Subgraph for a substitutable layer
86 SubgraphViewSelector::SubgraphViewPtr substitutableSubgraph =
87 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
88 CreateOutputsFrom({convLayer2}),
89 {substitutableCompiledLayer});
90 // Create a Graph containing a layer to substitute in
91 Graph substitutableGraph;
92 Layer* const substitutionpreCompiledLayer =
93 substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
94
95 // Subgraph for a substitution layer
96 SubgraphViewSelector::SubgraphViewPtr substitutionSubgraph =
97 CreateSubgraphViewFrom(CreateInputsFrom({substitutionpreCompiledLayer}),
98 CreateOutputsFrom({substitutionpreCompiledLayer}),
99 {substitutionpreCompiledLayer});
100
101 // Sub in the graph
102 baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
103
104 view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
105 view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
106
107 SubgraphViewSelector::SubgraphViewPtr baseSubgraph =
108 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
109 CreateOutputsFrom({convLayer2}),
110 {substitutionpreCompiledLayer});
111 view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
112
113 // Construct original subgraph to compare against
114 SubgraphViewSelector::SubgraphViewPtr originalSubgraph =
115 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
116 CreateOutputsFrom({convLayer2}),
117 {convLayer1, convLayer2, substitutionpreCompiledLayer});
118
Sadik Armagan1625efc2021-06-10 18:24:34 +0100119 CHECK(view.Validate(*originalSubgraph));
David Monahan005288d2019-05-14 10:42:38 +0100120}
121
Sadik Armagan1625efc2021-06-10 18:24:34 +0100122TEST_CASE("OptimizedViewsSubgraphLayerCountFailValidate")
David Monahan005288d2019-05-14 10:42:38 +0100123{
124 OptimizationViews view;
125 // Construct a graph with 3 layers
126 Graph& baseGraph = view.GetGraph();
127
128 Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
129
130 Convolution2dDescriptor convDescriptor;
131 PreCompiledDescriptor substitutionLayerDescriptor(1, 1);
132 Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
133 Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
134 Layer* const substitutableCompiledLayer =
135 baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
136
137 Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
138
139 inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
140 convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
141 convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
142
143 // Subgraph for an untouched layer
144 SubgraphViewSelector::SubgraphViewPtr untouchedSubgraph =
145 CreateSubgraphViewFrom(CreateInputsFrom({convLayer2}),
146 CreateOutputsFrom({convLayer2}),
147 {convLayer2});
148 // Subgraph for a substitutable layer
149 SubgraphViewSelector::SubgraphViewPtr substitutableSubgraph =
150 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
151 CreateOutputsFrom({convLayer2}),
152 {substitutableCompiledLayer});
153 // Create a Graph containing a layer to substitute in
154 Graph substitutableGraph;
155 Layer* const substitutionpreCompiledLayer =
156 substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
157
158 // Subgraph for a substitution layer
159 SubgraphViewSelector::SubgraphViewPtr substitutionSubgraph =
160 CreateSubgraphViewFrom(CreateInputsFrom({substitutionpreCompiledLayer}),
161 CreateOutputsFrom({substitutionpreCompiledLayer}),
162 {substitutionpreCompiledLayer});
163
164 // Sub in the graph
165 baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
166
167 view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
168
169 SubgraphViewSelector::SubgraphViewPtr baseSubgraph =
170 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
171 CreateOutputsFrom({convLayer2}),
172 {substitutionpreCompiledLayer});
173 view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
174
175 // Construct original subgraph to compare against
176 SubgraphViewSelector::SubgraphViewPtr originalSubgraph =
177 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
178 CreateOutputsFrom({convLayer2}),
179 {convLayer1, convLayer2, substitutionpreCompiledLayer});
180
181 // Validate should fail as convLayer1 is not counted
Sadik Armagan1625efc2021-06-10 18:24:34 +0100182 CHECK(!view.Validate(*originalSubgraph));
David Monahan005288d2019-05-14 10:42:38 +0100183}
184
Sadik Armagan1625efc2021-06-10 18:24:34 +0100185TEST_CASE("OptimizeViewsValidateDeviceMockBackend")
David Monahan41f00f12019-05-27 09:44:52 +0100186{
187 // build up the structure of the network
188 armnn::INetworkPtr net(armnn::INetwork::Create());
189
190 armnn::IConnectableLayer* input = net->AddInputLayer(0, "inLayer0");
191 armnn::IConnectableLayer* input1 = net->AddInputLayer(1, "inLayer1");
192
193 armnn::IConnectableLayer* addition = net->AddAdditionLayer("addLayer");
194
195 armnn::IConnectableLayer* output = net->AddOutputLayer(0, "outLayer");
196
197 input->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
198 input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
199 addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));
200
201 input->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
202 input1->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
203 addition->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
204
David Monahanc1536d62020-02-12 15:52:35 +0000205 armnn::MockBackendInitialiser initialiser;
David Monahan41f00f12019-05-27 09:44:52 +0100206 armnn::IRuntime::CreationOptions options;
207 armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
208
209 std::vector<armnn::BackendId> backends = { MockBackend().GetIdStatic() };
210 armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100211 CHECK(optNet);
David Monahan41f00f12019-05-27 09:44:52 +0100212
213 // Check the optimised graph
Francis Murtagh3d2b4b22021-02-15 18:23:17 +0000214 armnn::Graph& graph = GetGraphForTesting(optNet.get());
215 CheckLayers(graph);
David Monahan41f00f12019-05-27 09:44:52 +0100216}
217
Sadik Armagan1625efc2021-06-10 18:24:34 +0100218}