blob: 2fdda5e7595c9d2fc1561116a03527331dc353b3 [file] [log] [blame]
David Monahan005288d2019-05-14 10:42:38 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include <armnn/ArmNN.hpp>
8#include <Graph.hpp>
9#include <SubgraphView.hpp>
10#include <SubgraphViewSelector.hpp>
11#include <backendsCommon/OptimizationViews.hpp>
12#include <Network.hpp>
13
14#include "CommonTestUtils.hpp"
15
16using namespace armnn;
17
18BOOST_AUTO_TEST_SUITE(OptimizationViewsTestSuite)
19
20BOOST_AUTO_TEST_CASE(OptimizedViewsSubgraphLayerCount)
21{
22 OptimizationViews view;
23 // Construct a graph with 3 layers
24 Graph& baseGraph = view.GetGraph();
25
26 Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
27
28 Convolution2dDescriptor convDescriptor;
29 PreCompiledDescriptor substitutionLayerDescriptor(1, 1);
30 Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
31 Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
32 Layer* const substitutableCompiledLayer =
33 baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
34
35 Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
36
37 inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
38 convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
39 convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
40
41 // Subgraph for a failed layer
42 SubgraphViewSelector::SubgraphViewPtr failedSubgraph =
43 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
44 CreateOutputsFrom({convLayer1}),
45 {convLayer1});
46 // Subgraph for an untouched layer
47 SubgraphViewSelector::SubgraphViewPtr untouchedSubgraph =
48 CreateSubgraphViewFrom(CreateInputsFrom({convLayer2}),
49 CreateOutputsFrom({convLayer2}),
50 {convLayer2});
51 // Subgraph for a substitutable layer
52 SubgraphViewSelector::SubgraphViewPtr substitutableSubgraph =
53 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
54 CreateOutputsFrom({convLayer2}),
55 {substitutableCompiledLayer});
56 // Create a Graph containing a layer to substitute in
57 Graph substitutableGraph;
58 Layer* const substitutionpreCompiledLayer =
59 substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
60
61 // Subgraph for a substitution layer
62 SubgraphViewSelector::SubgraphViewPtr substitutionSubgraph =
63 CreateSubgraphViewFrom(CreateInputsFrom({substitutionpreCompiledLayer}),
64 CreateOutputsFrom({substitutionpreCompiledLayer}),
65 {substitutionpreCompiledLayer});
66
67 // Sub in the graph
68 baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
69
70 view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
71 view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
72
73 SubgraphViewSelector::SubgraphViewPtr baseSubgraph =
74 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
75 CreateOutputsFrom({convLayer2}),
76 {substitutionpreCompiledLayer});
77 view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
78
79 // Construct original subgraph to compare against
80 SubgraphViewSelector::SubgraphViewPtr originalSubgraph =
81 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
82 CreateOutputsFrom({convLayer2}),
83 {convLayer1, convLayer2, substitutionpreCompiledLayer});
84
85 BOOST_CHECK(view.Validate(*originalSubgraph));
86}
87
88BOOST_AUTO_TEST_CASE(OptimizedViewsSubgraphLayerCountFailValidate)
89{
90 OptimizationViews view;
91 // Construct a graph with 3 layers
92 Graph& baseGraph = view.GetGraph();
93
94 Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
95
96 Convolution2dDescriptor convDescriptor;
97 PreCompiledDescriptor substitutionLayerDescriptor(1, 1);
98 Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
99 Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
100 Layer* const substitutableCompiledLayer =
101 baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
102
103 Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
104
105 inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
106 convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
107 convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
108
109 // Subgraph for an untouched layer
110 SubgraphViewSelector::SubgraphViewPtr untouchedSubgraph =
111 CreateSubgraphViewFrom(CreateInputsFrom({convLayer2}),
112 CreateOutputsFrom({convLayer2}),
113 {convLayer2});
114 // Subgraph for a substitutable layer
115 SubgraphViewSelector::SubgraphViewPtr substitutableSubgraph =
116 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
117 CreateOutputsFrom({convLayer2}),
118 {substitutableCompiledLayer});
119 // Create a Graph containing a layer to substitute in
120 Graph substitutableGraph;
121 Layer* const substitutionpreCompiledLayer =
122 substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
123
124 // Subgraph for a substitution layer
125 SubgraphViewSelector::SubgraphViewPtr substitutionSubgraph =
126 CreateSubgraphViewFrom(CreateInputsFrom({substitutionpreCompiledLayer}),
127 CreateOutputsFrom({substitutionpreCompiledLayer}),
128 {substitutionpreCompiledLayer});
129
130 // Sub in the graph
131 baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
132
133 view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
134
135 SubgraphViewSelector::SubgraphViewPtr baseSubgraph =
136 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
137 CreateOutputsFrom({convLayer2}),
138 {substitutionpreCompiledLayer});
139 view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
140
141 // Construct original subgraph to compare against
142 SubgraphViewSelector::SubgraphViewPtr originalSubgraph =
143 CreateSubgraphViewFrom(CreateInputsFrom({convLayer1}),
144 CreateOutputsFrom({convLayer2}),
145 {convLayer1, convLayer2, substitutionpreCompiledLayer});
146
147 // Validate should fail as convLayer1 is not counted
148 BOOST_CHECK(!view.Validate(*originalSubgraph));
149}
150
151BOOST_AUTO_TEST_SUITE_END()