blob: 0665be7c7ec4e9d41e55d7599ed6f0014bf38162 [file] [log] [blame]
narpra016f37f832018-12-21 18:30:00 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "armnnTfParser/ITfParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8#include "test/GraphUtils.hpp"
9
Jan Eilersbb446e52020-04-02 13:56:54 +010010#include <armnn/utility/PolymorphicDowncast.hpp>
11
narpra016f37f832018-12-21 18:30:00 +000012#include <boost/test/unit_test.hpp>
13
14BOOST_AUTO_TEST_SUITE(TensorflowParser)
15
16struct AssertSimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
17{
18 AssertSimpleFixture()
19 {
20 // Placeholder AssertInput
21 // | \ /
22 // Add ------ Assert
23
24 m_Prototext = R"(
25 node {
26 name: "Placeholder"
27 op: "Placeholder"
28 attr {
29 key: "dtype"
30 value {
31 type: DT_FLOAT
32 }
33 }
34 attr {
35 key: "shape"
36 value {
37 shape {
38 unknown_rank: true
39 }
40 }
41 }
42 }
43 node {
44 name: "AssertInput"
45 op: "Const"
46 attr {
47 key: "dtype"
48 value {
49 type: DT_FLOAT
50 }
51 }
52 attr {
53 key: "value"
54 value {
55 tensor {
56 dtype: DT_FLOAT
57 tensor_shape {
58 dim {
59 size: 1
60 }
61 }
62 float_val: 17.0
63 }
64 }
65 }
66 }
67 node {
68 name: "Assert"
69 op: "Assert"
70 input: "Placeholder"
71 input: "AssertInput"
72 attr {
73 key: "T"
74 value {
75 type: DT_FLOAT
76 }
77 }
78 }
79 node {
80 name: "Add"
81 op: "Add"
82 input: "Placeholder"
83 input: "Placeholder"
84 input: "^Assert"
85 attr {
86 key: "T"
87 value {
88 type: DT_FLOAT
89 }
90 }
91 })";
92 }
93};
94
95BOOST_FIXTURE_TEST_CASE(AssertSimpleTest, AssertSimpleFixture)
96{
97 SetupSingleInputSingleOutput({ 1, 1, 1, 4 }, "Placeholder", "Add");
98 RunTest<4>({ 1.0f, 2.0f, 3.0f, 4.0f }, { 2.0f, 4.0f, 6.0f, 8.0f });
99}
100
101BOOST_FIXTURE_TEST_CASE(AssertSimpleGraphStructureTest, AssertSimpleFixture)
102{
103 auto optimized = SetupOptimizedNetwork({ { "Placeholder", { 1, 1, 1, 4 } } }, { "Add" });
104
Francis Murtagh3d2b4b22021-02-15 18:23:17 +0000105 armnn::Graph& graph = GetGraphForTesting(optimized.get());
narpra016f37f832018-12-21 18:30:00 +0000106
107 BOOST_TEST((graph.GetNumInputs() == 1));
108 BOOST_TEST((graph.GetNumOutputs() == 1));
109 BOOST_TEST((graph.GetNumLayers() == 3));
110
111 armnn::Layer* inputLayer = GetFirstLayerWithName(graph, "Placeholder");
112 BOOST_TEST((inputLayer->GetType() == armnn::LayerType::Input));
113 BOOST_TEST(CheckNumberOfInputSlot(inputLayer, 0));
114 BOOST_TEST(CheckNumberOfOutputSlot(inputLayer, 1));
115
116 armnn::Layer* addLayer = GetFirstLayerWithName(graph, "Add");
117 BOOST_TEST((addLayer->GetType() == armnn::LayerType::Addition));
118 BOOST_TEST(CheckNumberOfInputSlot(addLayer, 2));
119 BOOST_TEST(CheckNumberOfOutputSlot(addLayer, 1));
120
121 armnn::TensorInfo tensorInfo(armnn::TensorShape({1, 1, 1, 4}), armnn::DataType::Float32);
122 BOOST_TEST(IsConnected(inputLayer, addLayer, 0, 0, tensorInfo));
123 BOOST_TEST(IsConnected(inputLayer, addLayer, 0, 1, tensorInfo));
124
125 for (auto&& outputLayer : graph.GetOutputLayers())
126 {
127 BOOST_TEST(IsConnected(addLayer, const_cast<armnn::OutputLayer*>(outputLayer), 0, 0, tensorInfo));
128 }
129}
130
131struct AssertFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
132{
133 AssertFixture()
134 {
135 // Input0 Input1 Input2
136 // | \ / |
137 // | Sub ------ Assert
138 // \ / /
139 // Output -------
140
141 m_Prototext = R"(
142 node {
143 name: "Input0"
144 op: "Placeholder"
145 attr {
146 key: "dtype"
147 value {
148 type: DT_FLOAT
149 }
150 }
151 attr {
152 key: "shape"
153 value {
154 shape {
155 unknown_rank: true
156 }
157 }
158 }
159 }
160 node {
161 name: "Input1"
162 op: "Placeholder"
163 attr {
164 key: "dtype"
165 value {
166 type: DT_FLOAT
167 }
168 }
169 attr {
170 key: "shape"
171 value {
172 shape {
173 unknown_rank: true
174 }
175 }
176 }
177 }
178 node {
179 name: "Sub"
180 op: "Sub"
181 input: "Input0"
182 input: "Input1"
183 attr {
184 key: "T"
185 value {
186 type: DT_FLOAT
187 }
188 }
189 }
190 node {
191 name: "Input2"
192 op: "Placeholder"
193 attr {
194 key: "dtype"
195 value {
196 type: DT_FLOAT
197 }
198 }
199 attr {
200 key: "shape"
201 value {
202 shape {
203 unknown_rank: true
204 }
205 }
206 }
207 }
208 node {
209 name: "Assert"
210 op: "Assert"
211 input: "Input2"
212 input: "Sub"
213 attr {
214 key: "T"
215 value {
216 type: DT_FLOAT
217 }
218 }
219 }
220 node {
221 name: "Output"
222 op: "Add"
223 input: "Input0"
224 input: "Sub"
225 input: "^Assert"
226 attr {
227 key: "T"
228 value {
229 type: DT_FLOAT
230 }
231 }
232 })";
233
234
235 }
236};
237
238BOOST_FIXTURE_TEST_CASE(AssertTest, AssertFixture)
239{
240 Setup({ { "Input0", { 1, 1, 2, 2 } },
241 { "Input1", { 1, 1, 2, 2 } } },
242 { "Output" });
243
244 RunTest<4>({ { "Input0", { 4.0f, 3.0f,
245 2.0f, 1.0f } },
246
247 { "Input1", { 1.0f, 2.0f,
248 3.0f, 4.0f } } },
249
250 { { "Output", { 7.0f, 4.0f,
251 1.0f, -2.0f } } });
252}
253
254BOOST_FIXTURE_TEST_CASE(AssertGraphStructureTest, AssertFixture)
255{
256 auto optimized = SetupOptimizedNetwork({ { "Input0", { 1, 1, 2, 2 } },
257 { "Input1", { 1, 1, 2, 2 } } },
258 { "Output" });
259
Francis Murtagh3d2b4b22021-02-15 18:23:17 +0000260 armnn::Graph& graph = GetGraphForTesting(optimized.get());
narpra016f37f832018-12-21 18:30:00 +0000261
262 BOOST_TEST((graph.GetNumInputs() == 2));
263 BOOST_TEST((graph.GetNumOutputs() == 1));
264 BOOST_TEST((graph.GetNumLayers() == 5));
265
266 armnn::Layer* inputLayer0 = GetFirstLayerWithName(graph, "Input0");
267 BOOST_TEST((inputLayer0->GetType() == armnn::LayerType::Input));
268 BOOST_TEST(CheckNumberOfInputSlot(inputLayer0, 0));
269 BOOST_TEST(CheckNumberOfOutputSlot(inputLayer0, 1));
270
271 armnn::Layer* inputLayer1 = GetFirstLayerWithName(graph, "Input1");
272 BOOST_TEST((inputLayer1->GetType() == armnn::LayerType::Input));
273 BOOST_TEST(CheckNumberOfInputSlot(inputLayer1, 0));
274 BOOST_TEST(CheckNumberOfOutputSlot(inputLayer1, 1));
275
276 armnn::Layer* subLayer = GetFirstLayerWithName(graph, "Sub");
277 BOOST_TEST((subLayer->GetType() == armnn::LayerType::Subtraction));
278 BOOST_TEST(CheckNumberOfInputSlot(subLayer, 2));
279 BOOST_TEST(CheckNumberOfOutputSlot(subLayer, 1));
280
281 armnn::Layer* addLayer = GetFirstLayerWithName(graph, "Output");
282 BOOST_TEST((addLayer->GetType() == armnn::LayerType::Addition));
283 BOOST_TEST(CheckNumberOfInputSlot(addLayer, 2));
284 BOOST_TEST(CheckNumberOfOutputSlot(addLayer, 1));
285
286 armnn::TensorInfo tensorInfo(armnn::TensorShape({1, 1, 2, 2}), armnn::DataType::Float32);
287 BOOST_TEST(IsConnected(inputLayer0, subLayer, 0, 0, tensorInfo));
288 BOOST_TEST(IsConnected(inputLayer1, subLayer, 0, 1, tensorInfo));
289 BOOST_TEST(IsConnected(inputLayer0, addLayer, 0, 0, tensorInfo));
290 BOOST_TEST(IsConnected(subLayer, addLayer, 0, 1, tensorInfo));
291
292 for (auto&& outputLayer : graph.GetOutputLayers())
293 {
294 BOOST_TEST(IsConnected(addLayer, const_cast<armnn::OutputLayer*>(outputLayer), 0, 0, tensorInfo));
295 }
296}
297
298
299BOOST_AUTO_TEST_SUITE_END()