blob: 13ab17c5b6d5026f3425016b878d00c8e11bfea7 [file] [log] [blame]
surmeh01bceff2f2018-03-29 16:29:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12// Graph which tests that nodes are re-ordered in the queue when they are encountered a second time.
13// In this case R0 will be encountered first via R1 and then via R2. At that time
14// we need to make sure that R0 (and the I on which it is dependent) is moved to the front again
15// so that it is before both R1 and R2.
16// I
17// |
18// R0
19// / \'
20// R1 R2
21// \ |
22// \ R3
23// \|
24// O
25struct RediscoveredDependenciesFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
26{
27 RediscoveredDependenciesFixture()
28 {
29 // input = tf.placeholder(tf.float32, 1, "input")
30 // relu0 = tf.nn.relu(input, "relu0")
31 // relu1 = tf.nn.relu(relu0, "relu1")
32 // relu2 = tf.nn.relu(relu0, "relu2")
33 // relu3 = tf.nn.relu(relu2, "relu3")
34 // output = tf.add(relu1, relu3, "output")
35 m_Prototext = R"(
36 node {
37 name: "input"
38 op: "Placeholder"
39 attr {
40 key: "dtype"
41 value {
42 type: DT_FLOAT
43 }
44 }
45 attr {
46 key: "shape"
47 value {
48 shape {
49 dim {
50 size: 1
51 }
52 }
53 }
54 }
55 }
56 node {
57 name: "relu0"
58 op: "Relu"
59 input: "input"
60 attr {
61 key: "T"
62 value {
63 type: DT_FLOAT
64 }
65 }
66 }
67 node {
68 name: "relu1"
69 op: "Relu"
70 input: "relu0"
71 attr {
72 key: "T"
73 value {
74 type: DT_FLOAT
75 }
76 }
77 }
78 node {
79 name: "relu2"
80 op: "Relu"
81 input: "relu0"
82 attr {
83 key: "T"
84 value {
85 type: DT_FLOAT
86 }
87 }
88 }
89 node {
90 name: "relu3"
91 op: "Relu"
92 input: "relu2"
93 attr {
94 key: "T"
95 value {
96 type: DT_FLOAT
97 }
98 }
99 }
100 node {
101 name: "output"
102 op: "Add"
103 input: "relu1"
104 input: "relu3"
105 attr {
106 key: "T"
107 value {
108 type: DT_FLOAT
109 }
110 }
111 }
112 )";
113 SetupSingleInputSingleOutput({ 1 }, "input", "output");
114 }
115};
116
117BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture)
118{
119 RunTest<1>({1}, {2});
120}
121
122// Tests that a simple cycle in the tensorflow graph will be detected and an exception thrown, rather than the TfParser
123// getting stuck in an infinite loop.
124BOOST_AUTO_TEST_CASE(SimpleCycle)
125{
126 const char* prototext = R"(
127node {
128 name: "r1"
129 op: "Relu"
130 input: "r2"
131 attr {
132 key: "T"
133 value {
134 type: DT_FLOAT
135 }
136 }
137}
138node {
139 name: "r2"
140 op: "Relu"
141 input: "r1"
142 attr {
143 key: "T"
144 value {
145 type: DT_FLOAT
146 }
147 }
148}
149 )";
150 armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
151 BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException);
152}
153
154// Similar to the above SimpleCycle test, but has a single node which connects to itself.
155BOOST_AUTO_TEST_CASE(SingleNodeCycle)
156{
157 const char* prototext = R"(
158node {
159 name: "r1"
160 op: "Relu"
161 input: "r1"
162 attr {
163 key: "T"
164 value {
165 type: DT_FLOAT
166 }
167 }
168}
169 )";
170 armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
171 BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
172}
173
174// Similar to the above SimpleCycle test, but with a more complicated graph.
175// I
176// |
177// A2---<---<-
178// / \' |
179// R1 R2 |
180// \ | |
181// \ R3 |
182// \| |
183// A1-->--->|
184//
185BOOST_AUTO_TEST_CASE(ComplexCycle)
186{
187 // input = tf.placeholder(tf.float32, 1, "input")
188 // add2 = tf.nn.relu(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined
189 // relu1 = tf.nn.relu(relu0, "relu1")
190 // relu2 = tf.nn.relu(relu0, "relu2")
191 // relu3 = tf.nn.relu(relu2, "relu3")
192 // add1 = tf.add(relu1, relu3, "add1")
193 const char* prototext = R"(
194 node {
195 name: "input"
196 op: "Placeholder"
197 attr {
198 key: "dtype"
199 value {
200 type: DT_FLOAT
201 }
202 }
203 attr {
204 key: "shape"
205 value {
206 shape {
207 dim {
208 size: 1
209 }
210 }
211 }
212 }
213 }
214 node {
215 name: "add2"
216 op: "Add"
217 input: "input"
218 input: "add1"
219 attr {
220 key: "T"
221 value {
222 type: DT_FLOAT
223 }
224 }
225 }
226 node {
227 name: "relu1"
228 op: "Relu"
229 input: "add2"
230 attr {
231 key: "T"
232 value {
233 type: DT_FLOAT
234 }
235 }
236 }
237 node {
238 name: "relu2"
239 op: "Relu"
240 input: "add2"
241 attr {
242 key: "T"
243 value {
244 type: DT_FLOAT
245 }
246 }
247 }
248 node {
249 name: "relu3"
250 op: "Relu"
251 input: "relu2"
252 attr {
253 key: "T"
254 value {
255 type: DT_FLOAT
256 }
257 }
258 }
259 node {
260 name: "add1"
261 op: "Add"
262 input: "relu1"
263 input: "relu3"
264 attr {
265 key: "T"
266 value {
267 type: DT_FLOAT
268 }
269 }
270 }
271 )";
272 armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
273 BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException);
274}
275
276// Tests that a graph with an input that is not present throws a ParseException.
277BOOST_AUTO_TEST_CASE(InvalidInput)
278{
279 const char* prototext = R"(
280node {
281 name: "r1"
282 op: "Relu"
283 input: "a-node-that-does-not-exist"
284 attr {
285 key: "T"
286 value {
287 type: DT_FLOAT
288 }
289 }
290}
291 )";
292 armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
293 BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
294}
295
296BOOST_AUTO_TEST_SUITE_END()