blob: eeef90a6259ed83b675c3f39ddd4a226a568d9ff [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Sadik Armagan2ad6cb42018-12-27 11:23:44 +00006#include "armnnTfParser/ITfParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
Jan Eilers8eb25602020-03-09 12:13:48 +00009#include <armnn/utility/IgnoreUnused.hpp>
10
11#include <boost/test/unit_test.hpp>
12
Sadik Armagan2ad6cb42018-12-27 11:23:44 +000013BOOST_AUTO_TEST_SUITE(TensorflowParser)
14
15struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
16{
Saoirse Stewart91c0eff2019-02-27 11:07:57 +000017 SplitFixture(bool withDimZero=false) {
18 m_Prototext = R"(
19 node {
20 name: "graphInput"
21 op: "Placeholder"
22 attr {
23 key: "dtype"
24 value {
25 type: DT_FLOAT
26 }
27 }
28 attr {
29 key: "shape"
30 value {
31 shape {
32 }
33 }
34 }
35 }
36 node {
37 name: "graphInput2"
38 op: "Placeholder"
39 attr {
40 key: "dtype"
41 value {
42 type: DT_FLOAT
43 }
44 }
45 attr {
46 key: "shape"
47 value {
48 shape {
49 }
50 }
51 }
52 }
53 node {
54 name: "multiplication"
55 op : "Mul"
56 input: "graphInput"
57 input: "graphInput2"
58 attr {
59 key: "T"
60 value {
61 type: DT_FLOAT
62 }
63 }
64 }
65 node {
66 name: "SplitInput"
67 op: "Const"
68 attr {
69 key: "dtype"
70 value {
71 type: DT_INT32
72 }
73 }
74 attr {
75 key: "value"
76 value {
77 tensor {
78 dtype: DT_INT32
79 tensor_shape {
80 }
81 int_val: )";
Sadik Armagan2ad6cb42018-12-27 11:23:44 +000082
Saoirse Stewart91c0eff2019-02-27 11:07:57 +000083 if(withDimZero)
84 {
85 m_Prototext += std::to_string(3);
86 }
87 else
88 {
89 m_Prototext += std::to_string(1);
90 }
91
92 m_Prototext += R"(
93 }
94 }
95 }
96 }
97 node {
98 name: "Split"
99 op: "Split" )";
100 if(withDimZero)
101 {
102 m_Prototext += "input: \"SplitInput\"\n";
103 m_Prototext += "input: \"multiplication\"\n";
104 }
105 else
106 {
107 m_Prototext += "input: \"graphInput\"\n";
108 m_Prototext += "input: \"SplitInput\"\n";
109 }
110 m_Prototext += R"(
111 attr {
Saoirse Stewart315258e2019-02-28 11:32:41 +0000112 key: "num_split"
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000113 value {
114 i: 2
115 }
116 }
117 }
118 node {
119 name: "Relu_1"
120 op: "Relu"
121 input: "Split:0"
122 attr {
123 key: "T"
124 value {
125 type: DT_FLOAT
126 }
127 }
128 }
129 node {
130 name: "Relu_2"
131 op: "Relu"
132 input:"Split:1"
133 attr {
134 key: "T"
135 value {
136 type: DT_FLOAT
137 }
138 }
139 } )";
140
141 Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }},
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000142 { "Relu_1", "Relu_2" });
143 }
144};
145
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000146struct InputFirstSplitFixture : SplitFixture
147{
148 InputFirstSplitFixture() : SplitFixture(true) {}
149};
150
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000151BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
152{
153 BOOST_TEST(
154 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
155
156 BOOST_TEST(
157 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
158
159 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
160 { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
161 { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
162}
163
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000164BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
165{
166
167 BOOST_TEST(
168 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
169
170 BOOST_TEST(
171 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
172
173 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
174 { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
175 { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
176 { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
177}
178
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100179struct SplitLastDimFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
180{
181 SplitLastDimFixture(bool withDimZero=false) {
Jan Eilers8eb25602020-03-09 12:13:48 +0000182 armnn::IgnoreUnused(withDimZero);
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100183 m_Prototext = R"(
184 node {
185 name: "Placeholder"
186 op: "Placeholder"
187 attr {
188 key: "dtype"
189 value {
190 type: DT_FLOAT
191 }
192 }
193 attr {
194 key: "shape"
195 value {
196 shape {
197 dim {
198 size: 1
199 }
200 dim {
201 size: 2
202 }
203 dim {
204 size: 2
205 }
206 dim {
207 size: 3
208 }
209 }
210 }
211 }
212 }
213 node {
214 name: "Const"
215 op: "Const"
216 attr {
217 key: "dtype"
218 value {
219 type: DT_INT32
220 }
221 }
222 attr {
223 key: "value"
224 value {
225 tensor {
226 dtype: DT_INT32
227 tensor_shape {
228 }
229 int_val: 3
230 }
231 }
232 }
233 }
234 node {
235 name: "split/split_dim"
236 op: "Const"
237 attr {
238 key: "dtype"
239 value {
240 type: DT_INT32
241 }
242 }
243 attr {
244 key: "value"
245 value {
246 tensor {
247 dtype: DT_INT32
248 tensor_shape {
249 }
250 int_val: 3
251 }
252 }
253 }
254 }
255 node {
256 name: "split"
257 op: "Split"
258 input: "split/split_dim"
259 input: "Placeholder"
260 attr {
261 key: "T"
262 value {
263 type: DT_FLOAT
264 }
265 }
266 attr {
267 key: "num_split"
268 value {
269 i: 3
270 }
271 }
272 }
273 node {
274 name: "sub0/y"
275 op: "Const"
276 attr {
277 key: "dtype"
278 value {
279 type: DT_FLOAT
280 }
281 }
282 attr {
283 key: "value"
284 value {
285 tensor {
286 dtype: DT_FLOAT
287 tensor_shape {
288 }
289 float_val: 3.0
290 }
291 }
292 }
293 }
294 node {
295 name: "sub0"
296 op: "Sub"
297 input: "split"
298 input: "sub0/y"
299 attr {
300 key: "T"
301 value {
302 type: DT_FLOAT
303 }
304 }
305 }
306 node {
307 name: "sub1/y"
308 op: "Const"
309 attr {
310 key: "dtype"
311 value {
312 type: DT_FLOAT
313 }
314 }
315 attr {
316 key: "value"
317 value {
318 tensor {
319 dtype: DT_FLOAT
320 tensor_shape {
321 }
322 float_val: 2.0
323 }
324 }
325 }
326 }
327 node {
328 name: "sub1"
329 op: "Sub"
330 input: "split:1"
331 input: "sub1/y"
332 attr {
333 key: "T"
334 value {
335 type: DT_FLOAT
336 }
337 }
338 }
339 node {
340 name: "sub2/y"
341 op: "Const"
342 attr {
343 key: "dtype"
344 value {
345 type: DT_FLOAT
346 }
347 }
348 attr {
349 key: "value"
350 value {
351 tensor {
352 dtype: DT_FLOAT
353 tensor_shape {
354 }
355 float_val: 1.0
356 }
357 }
358 }
359 }
360 node {
361 name: "sub2"
362 op: "Sub"
363 input: "split:2"
364 input: "sub2/y"
365 attr {
366 key: "T"
367 value {
368 type: DT_FLOAT
369 }
370 }
371 }
372 versions {
373 producer: 27
374 } )";
375
376 Setup( { { "Placeholder", { 1, 2, 2 , 3} } },
377 { "sub0", "sub1", "sub2" });
378 }
379};
380
381BOOST_FIXTURE_TEST_CASE(SplitLastDimTest, SplitLastDimFixture)
382{
383 BOOST_TEST(
384 (m_Parser->GetNetworkOutputBindingInfo("sub0").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
385
386 BOOST_TEST(
387 (m_Parser->GetNetworkOutputBindingInfo("sub1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
388
389 BOOST_TEST(
390 (m_Parser->GetNetworkOutputBindingInfo("sub2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
391
392 RunTest<4>({ { "Placeholder", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f } } },
393 { { "sub0", { -2.0f, 1.0f, 4.0f, 7.0f } },
394 { "sub1", { 0.0f, 3.0f, 6.0f, 9.0f } },
395 { "sub2", { 2.0f, 5.0f, 8.0f, 11.0f } } });
396}
397
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000398BOOST_AUTO_TEST_SUITE_END()