blob: 10ff04df894502ae3903ffe1e90975de6d1611d8 [file] [log] [blame]
Sadik Armagan2ad6cb42018-12-27 11:23:44 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
Saoirse Stewart91c0eff2019-02-27 11:07:57 +000014 SplitFixture(bool withDimZero=false) {
15 m_Prototext = R"(
16 node {
17 name: "graphInput"
18 op: "Placeholder"
19 attr {
20 key: "dtype"
21 value {
22 type: DT_FLOAT
23 }
24 }
25 attr {
26 key: "shape"
27 value {
28 shape {
29 }
30 }
31 }
32 }
33 node {
34 name: "graphInput2"
35 op: "Placeholder"
36 attr {
37 key: "dtype"
38 value {
39 type: DT_FLOAT
40 }
41 }
42 attr {
43 key: "shape"
44 value {
45 shape {
46 }
47 }
48 }
49 }
50 node {
51 name: "multiplication"
52 op : "Mul"
53 input: "graphInput"
54 input: "graphInput2"
55 attr {
56 key: "T"
57 value {
58 type: DT_FLOAT
59 }
60 }
61 }
62 node {
63 name: "SplitInput"
64 op: "Const"
65 attr {
66 key: "dtype"
67 value {
68 type: DT_INT32
69 }
70 }
71 attr {
72 key: "value"
73 value {
74 tensor {
75 dtype: DT_INT32
76 tensor_shape {
77 }
78 int_val: )";
Sadik Armagan2ad6cb42018-12-27 11:23:44 +000079
Saoirse Stewart91c0eff2019-02-27 11:07:57 +000080 if(withDimZero)
81 {
82 m_Prototext += std::to_string(3);
83 }
84 else
85 {
86 m_Prototext += std::to_string(1);
87 }
88
89 m_Prototext += R"(
90 }
91 }
92 }
93 }
94 node {
95 name: "Split"
96 op: "Split" )";
97 if(withDimZero)
98 {
99 m_Prototext += "input: \"SplitInput\"\n";
100 m_Prototext += "input: \"multiplication\"\n";
101 }
102 else
103 {
104 m_Prototext += "input: \"graphInput\"\n";
105 m_Prototext += "input: \"SplitInput\"\n";
106 }
107 m_Prototext += R"(
108 attr {
Saoirse Stewart315258e2019-02-28 11:32:41 +0000109 key: "num_split"
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000110 value {
111 i: 2
112 }
113 }
114 }
115 node {
116 name: "Relu_1"
117 op: "Relu"
118 input: "Split:0"
119 attr {
120 key: "T"
121 value {
122 type: DT_FLOAT
123 }
124 }
125 }
126 node {
127 name: "Relu_2"
128 op: "Relu"
129 input:"Split:1"
130 attr {
131 key: "T"
132 value {
133 type: DT_FLOAT
134 }
135 }
136 } )";
137
138 Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }},
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000139 { "Relu_1", "Relu_2" });
140 }
141};
142
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000143struct InputFirstSplitFixture : SplitFixture
144{
145 InputFirstSplitFixture() : SplitFixture(true) {}
146};
147
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000148BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
149{
150 BOOST_TEST(
151 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
152
153 BOOST_TEST(
154 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 1, 2, 2 })));
155
156 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f, 1.75f } } },
157 { { "Relu_1", { 0.0f, 0.0f, 1.25f, 0.0f } },
158 { "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
159}
160
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000161BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
162{
163
164 BOOST_TEST(
165 (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
166
167 BOOST_TEST(
168 (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
169
170 RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
171 { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
172 { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
173 { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
174}
175
Narumol Prangnawarat15eb5832019-05-20 15:31:05 +0100176struct SplitLastDimFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
177{
178 SplitLastDimFixture(bool withDimZero=false) {
179 m_Prototext = R"(
180 node {
181 name: "Placeholder"
182 op: "Placeholder"
183 attr {
184 key: "dtype"
185 value {
186 type: DT_FLOAT
187 }
188 }
189 attr {
190 key: "shape"
191 value {
192 shape {
193 dim {
194 size: 1
195 }
196 dim {
197 size: 2
198 }
199 dim {
200 size: 2
201 }
202 dim {
203 size: 3
204 }
205 }
206 }
207 }
208 }
209 node {
210 name: "Const"
211 op: "Const"
212 attr {
213 key: "dtype"
214 value {
215 type: DT_INT32
216 }
217 }
218 attr {
219 key: "value"
220 value {
221 tensor {
222 dtype: DT_INT32
223 tensor_shape {
224 }
225 int_val: 3
226 }
227 }
228 }
229 }
230 node {
231 name: "split/split_dim"
232 op: "Const"
233 attr {
234 key: "dtype"
235 value {
236 type: DT_INT32
237 }
238 }
239 attr {
240 key: "value"
241 value {
242 tensor {
243 dtype: DT_INT32
244 tensor_shape {
245 }
246 int_val: 3
247 }
248 }
249 }
250 }
251 node {
252 name: "split"
253 op: "Split"
254 input: "split/split_dim"
255 input: "Placeholder"
256 attr {
257 key: "T"
258 value {
259 type: DT_FLOAT
260 }
261 }
262 attr {
263 key: "num_split"
264 value {
265 i: 3
266 }
267 }
268 }
269 node {
270 name: "sub0/y"
271 op: "Const"
272 attr {
273 key: "dtype"
274 value {
275 type: DT_FLOAT
276 }
277 }
278 attr {
279 key: "value"
280 value {
281 tensor {
282 dtype: DT_FLOAT
283 tensor_shape {
284 }
285 float_val: 3.0
286 }
287 }
288 }
289 }
290 node {
291 name: "sub0"
292 op: "Sub"
293 input: "split"
294 input: "sub0/y"
295 attr {
296 key: "T"
297 value {
298 type: DT_FLOAT
299 }
300 }
301 }
302 node {
303 name: "sub1/y"
304 op: "Const"
305 attr {
306 key: "dtype"
307 value {
308 type: DT_FLOAT
309 }
310 }
311 attr {
312 key: "value"
313 value {
314 tensor {
315 dtype: DT_FLOAT
316 tensor_shape {
317 }
318 float_val: 2.0
319 }
320 }
321 }
322 }
323 node {
324 name: "sub1"
325 op: "Sub"
326 input: "split:1"
327 input: "sub1/y"
328 attr {
329 key: "T"
330 value {
331 type: DT_FLOAT
332 }
333 }
334 }
335 node {
336 name: "sub2/y"
337 op: "Const"
338 attr {
339 key: "dtype"
340 value {
341 type: DT_FLOAT
342 }
343 }
344 attr {
345 key: "value"
346 value {
347 tensor {
348 dtype: DT_FLOAT
349 tensor_shape {
350 }
351 float_val: 1.0
352 }
353 }
354 }
355 }
356 node {
357 name: "sub2"
358 op: "Sub"
359 input: "split:2"
360 input: "sub2/y"
361 attr {
362 key: "T"
363 value {
364 type: DT_FLOAT
365 }
366 }
367 }
368 versions {
369 producer: 27
370 } )";
371
372 Setup( { { "Placeholder", { 1, 2, 2 , 3} } },
373 { "sub0", "sub1", "sub2" });
374 }
375};
376
377BOOST_FIXTURE_TEST_CASE(SplitLastDimTest, SplitLastDimFixture)
378{
379 BOOST_TEST(
380 (m_Parser->GetNetworkOutputBindingInfo("sub0").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
381
382 BOOST_TEST(
383 (m_Parser->GetNetworkOutputBindingInfo("sub1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
384
385 BOOST_TEST(
386 (m_Parser->GetNetworkOutputBindingInfo("sub2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
387
388 RunTest<4>({ { "Placeholder", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f } } },
389 { { "sub0", { -2.0f, 1.0f, 4.0f, 7.0f } },
390 { "sub1", { 0.0f, 3.0f, 6.0f, 9.0f } },
391 { "sub2", { 2.0f, 5.0f, 8.0f, 11.0f } } });
392}
393
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000394BOOST_AUTO_TEST_SUITE_END()