blob: d53ae672eb628f9e28a2ca5dacf022d275552218 [file] [log] [blame]
Sadik Armagan2ad6cb42018-12-27 11:23:44 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
Saoirse Stewart91c0eff2019-02-27 11:07:57 +000014 SplitFixture(bool withDimZero=false) {
15 m_Prototext = R"(
16 node {
17 name: "graphInput"
18 op: "Placeholder"
19 attr {
20 key: "dtype"
21 value {
22 type: DT_FLOAT
23 }
24 }
25 attr {
26 key: "shape"
27 value {
28 shape {
29 }
30 }
31 }
32 }
33 node {
34 name: "graphInput2"
35 op: "Placeholder"
36 attr {
37 key: "dtype"
38 value {
39 type: DT_FLOAT
40 }
41 }
42 attr {
43 key: "shape"
44 value {
45 shape {
46 }
47 }
48 }
49 }
50 node {
51 name: "multiplication"
52 op : "Mul"
53 input: "graphInput"
54 input: "graphInput2"
55 attr {
56 key: "T"
57 value {
58 type: DT_FLOAT
59 }
60 }
61 }
62 node {
63 name: "SplitInput"
64 op: "Const"
65 attr {
66 key: "dtype"
67 value {
68 type: DT_INT32
69 }
70 }
71 attr {
72 key: "value"
73 value {
74 tensor {
75 dtype: DT_INT32
76 tensor_shape {
77 }
78 int_val: )";
Sadik Armagan2ad6cb42018-12-27 11:23:44 +000079
Saoirse Stewart91c0eff2019-02-27 11:07:57 +000080 if(withDimZero)
81 {
82 m_Prototext += std::to_string(3);
83 }
84 else
85 {
86 m_Prototext += std::to_string(1);
87 }
88
89 m_Prototext += R"(
90 }
91 }
92 }
93 }
94 node {
95 name: "Split"
96 op: "Split" )";
97 if(withDimZero)
98 {
99 m_Prototext += "input: \"SplitInput\"\n";
100 m_Prototext += "input: \"multiplication\"\n";
101 }
102 else
103 {
104 m_Prototext += "input: \"graphInput\"\n";
105 m_Prototext += "input: \"SplitInput\"\n";
106 }
107 m_Prototext += R"(
108 attr {
Saoirse Stewart315258e2019-02-28 11:32:41 +0000109 key: "num_split"
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000110 value {
111 i: 2
112 }
113 }
114 }
115 node {
116 name: "Relu_1"
117 op: "Relu"
118 input: "Split:0"
119 attr {
120 key: "T"
121 value {
122 type: DT_FLOAT
123 }
124 }
125 }
126 node {
127 name: "Relu_2"
128 op: "Relu"
129 input:"Split:1"
130 attr {
131 key: "T"
132 value {
133 type: DT_FLOAT
134 }
135 }
136 } )";
137
138 Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }},
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000139 { "Relu_1", "Relu_2" });
140 }
141};
142
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000143struct InputFirstSplitFixture : SplitFixture
144{
145 InputFirstSplitFixture() : SplitFixture(true) {}
146};
147
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000148BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
149{
150 BOOST_TEST(
151 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
152
153 BOOST_TEST(
154 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
155
156 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
157 { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
158 { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
159}
160
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000161BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
162{
163
164 BOOST_TEST(
165 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
166
167 BOOST_TEST(
168 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
169
170 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
171 { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
172 { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
173 { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
174}
175
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100176struct SplitLastDimFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
177{
178 SplitLastDimFixture(bool withDimZero=false) {
Derek Lambertiba25aee2019-12-10 22:20:54 +0000179 boost::ignore_unused(withDimZero);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100180 m_Prototext = R"(
181 node {
182 name: "Placeholder"
183 op: "Placeholder"
184 attr {
185 key: "dtype"
186 value {
187 type: DT_FLOAT
188 }
189 }
190 attr {
191 key: "shape"
192 value {
193 shape {
194 dim {
195 size: 1
196 }
197 dim {
198 size: 2
199 }
200 dim {
201 size: 2
202 }
203 dim {
204 size: 3
205 }
206 }
207 }
208 }
209 }
210 node {
211 name: "Const"
212 op: "Const"
213 attr {
214 key: "dtype"
215 value {
216 type: DT_INT32
217 }
218 }
219 attr {
220 key: "value"
221 value {
222 tensor {
223 dtype: DT_INT32
224 tensor_shape {
225 }
226 int_val: 3
227 }
228 }
229 }
230 }
231 node {
232 name: "split/split_dim"
233 op: "Const"
234 attr {
235 key: "dtype"
236 value {
237 type: DT_INT32
238 }
239 }
240 attr {
241 key: "value"
242 value {
243 tensor {
244 dtype: DT_INT32
245 tensor_shape {
246 }
247 int_val: 3
248 }
249 }
250 }
251 }
252 node {
253 name: "split"
254 op: "Split"
255 input: "split/split_dim"
256 input: "Placeholder"
257 attr {
258 key: "T"
259 value {
260 type: DT_FLOAT
261 }
262 }
263 attr {
264 key: "num_split"
265 value {
266 i: 3
267 }
268 }
269 }
270 node {
271 name: "sub0/y"
272 op: "Const"
273 attr {
274 key: "dtype"
275 value {
276 type: DT_FLOAT
277 }
278 }
279 attr {
280 key: "value"
281 value {
282 tensor {
283 dtype: DT_FLOAT
284 tensor_shape {
285 }
286 float_val: 3.0
287 }
288 }
289 }
290 }
291 node {
292 name: "sub0"
293 op: "Sub"
294 input: "split"
295 input: "sub0/y"
296 attr {
297 key: "T"
298 value {
299 type: DT_FLOAT
300 }
301 }
302 }
303 node {
304 name: "sub1/y"
305 op: "Const"
306 attr {
307 key: "dtype"
308 value {
309 type: DT_FLOAT
310 }
311 }
312 attr {
313 key: "value"
314 value {
315 tensor {
316 dtype: DT_FLOAT
317 tensor_shape {
318 }
319 float_val: 2.0
320 }
321 }
322 }
323 }
324 node {
325 name: "sub1"
326 op: "Sub"
327 input: "split:1"
328 input: "sub1/y"
329 attr {
330 key: "T"
331 value {
332 type: DT_FLOAT
333 }
334 }
335 }
336 node {
337 name: "sub2/y"
338 op: "Const"
339 attr {
340 key: "dtype"
341 value {
342 type: DT_FLOAT
343 }
344 }
345 attr {
346 key: "value"
347 value {
348 tensor {
349 dtype: DT_FLOAT
350 tensor_shape {
351 }
352 float_val: 1.0
353 }
354 }
355 }
356 }
357 node {
358 name: "sub2"
359 op: "Sub"
360 input: "split:2"
361 input: "sub2/y"
362 attr {
363 key: "T"
364 value {
365 type: DT_FLOAT
366 }
367 }
368 }
369 versions {
370 producer: 27
371 } )";
372
373 Setup( { { "Placeholder", { 1, 2, 2 , 3} } },
374 { "sub0", "sub1", "sub2" });
375 }
376};
377
378BOOST_FIXTURE_TEST_CASE(SplitLastDimTest, SplitLastDimFixture)
379{
380 BOOST_TEST(
381 (m_Parser->GetNetworkOutputBindingInfo("sub0").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
382
383 BOOST_TEST(
384 (m_Parser->GetNetworkOutputBindingInfo("sub1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
385
386 BOOST_TEST(
387 (m_Parser->GetNetworkOutputBindingInfo("sub2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
388
389 RunTest<4>({ { "Placeholder", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f } } },
390 { { "sub0", { -2.0f, 1.0f, 4.0f, 7.0f } },
391 { "sub1", { 0.0f, 3.0f, 6.0f, 9.0f } },
392 { "sub2", { 2.0f, 5.0f, 8.0f, 11.0f } } });
393}
394
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000395BOOST_AUTO_TEST_SUITE_END()