//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
Matteo Martincigh46315822018-11-28 16:22:36 +00009
10#include <array>
surmeh01bceff2f2018-03-29 16:29:27 +010011#include <string>
12#include <iostream>
13
14BOOST_AUTO_TEST_SUITE(TensorflowParser)
15
telsoa01c577f2c2018-08-31 09:22:23 +010016struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010017{
Matteo Martincigh46315822018-11-28 16:22:36 +000018 explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType)
19 : Convolution2dFixture(dataLayout, paddingType, 1)
surmeh01bceff2f2018-03-29 16:29:27 +010020 {}
21
telsoa01c577f2c2018-08-31 09:22:23 +010022 // Dilation: 0 - dilations attribute is not included;
23 // Dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
Matteo Martincigh46315822018-11-28 16:22:36 +000024 explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType,
25 int stride, int dilation = 0)
surmeh01bceff2f2018-03-29 16:29:27 +010026 {
Matteo Martincigh46315822018-11-28 16:22:36 +000027 std::string strideString (" i: 1 \n"
28 " i: 1 \n");
29 if (dataLayout == "NHWC")
30 {
31 strideString.append(" i: " + std::to_string(stride) + " \n"
32 " i: 1 \n");
33 }
34 else // dataLayout == "NCHW"
35 {
36 strideString.append(" i: 1 \n"
37 " i: " + std::to_string(stride) + " \n");
38 }
39
Sadik Armagan60bb9d82021-01-11 15:15:01 +000040 std::string dilationString;
41 if (dataLayout == "NHWC")
42 {
43 dilationString.append(" i: 1 \n"
44 " i: " + std::to_string(dilation) + " \n"
45 " i: " + std::to_string(dilation) + " \n"
46 " i: 1 \n");
47 }
48 else // dataLayout == "NCHW"
49 {
50 dilationString.append(" i: 1 \n"
51 " i: 1 \n"
52 " i: " + std::to_string(dilation) + " \n"
53 " i: " + std::to_string(dilation) + " \n");
54 }
55
surmeh01bceff2f2018-03-29 16:29:27 +010056 m_Prototext = "node { \n"
57 " name: \"graphInput\" \n"
58 " op: \"Placeholder\" \n"
59 " attr { \n"
60 " key: \"dtype\" \n"
61 " value { \n"
62 " type: DT_FLOAT \n"
63 " } \n"
64 " } \n"
65 " attr { \n"
66 " key: \"shape\" \n"
67 " value { \n"
68 " shape { \n"
69 " } \n"
70 " } \n"
71 " } \n"
72 " } \n"
73 " node { \n"
74 " name: \"Const_1\" \n"
75 " op: \"Const\" \n"
76 " attr { \n"
77 " key: \"dtype\" \n"
78 " value { \n"
79 " type: DT_FLOAT \n"
80 " } \n"
81 " } \n"
82 " attr { \n"
83 " key: \"value\" \n"
84 " value { \n"
85 " tensor { \n"
86 " dtype: DT_FLOAT \n"
87 " tensor_shape { \n"
88 " dim { \n"
89 " size: 1 \n"
90 " } \n"
91 " dim { \n"
92 " size: 3 \n"
93 " } \n"
94 " dim { \n"
95 " size: 1 \n"
96 " } \n"
97 " dim { \n"
98 " size: 1 \n"
99 " } \n"
100 " } \n"
101 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
102 " } \n"
103 " } \n"
104 " } \n"
105 "} \n"
106 "node { \n"
107 " name: \"potato\" \n"
108 " op: \"Conv2D\" \n"
109 " input: \"graphInput\" \n"
110 " input: \"Const_1\" \n"
111 " attr { \n"
112 " key: \"T\" \n"
113 " value { \n"
114 " type: DT_FLOAT \n"
115 " } \n"
116 " } \n"
117 " attr { \n"
118 " key: \"data_format\" \n"
119 " value { \n"
surmeh01bceff2f2018-03-29 16:29:27 +0100120 " s: \"";
Matteo Martincigh46315822018-11-28 16:22:36 +0000121 m_Prototext.append(dataLayout);
122 m_Prototext.append("\"\n"
123 " } \n"
124 " } \n"
125 " attr { \n"
126 " key: \"padding\" \n"
127 " value { \n"
128 " s: \"");
surmeh01bceff2f2018-03-29 16:29:27 +0100129 m_Prototext.append(paddingType);
130 m_Prototext.append("\"\n"
131 " } \n"
132 " } \n"
133 " attr { \n"
134 " key: \"strides\" \n"
135 " value { \n"
Matteo Martincigh46315822018-11-28 16:22:36 +0000136 " list { \n");
surmeh01bceff2f2018-03-29 16:29:27 +0100137 m_Prototext.append(strideString);
Matteo Martincigh46315822018-11-28 16:22:36 +0000138
139 m_Prototext.append(" } \n"
surmeh01bceff2f2018-03-29 16:29:27 +0100140 " } \n"
141 " } \n");
142
143 if (dilation > 0)
144 {
145 m_Prototext.append(" attr { \n"
146 " key: \"dilations\" \n"
147 " value { \n"
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000148 " list { \n");
surmeh01bceff2f2018-03-29 16:29:27 +0100149 m_Prototext.append(dilationString);
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000150
151 m_Prototext.append(" } \n"
surmeh01bceff2f2018-03-29 16:29:27 +0100152 " } \n"
153 " } \n");
154 }
155 m_Prototext.append(" attr { \n"
156 " key: \"use_cudnn_on_gpu\" \n"
157 " value { \n"
158 " b: false \n"
159 " } \n"
160 " } \n"
161 "} \n");
162
163 // Manual height computation based on stride parameter.
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100164 ARMNN_ASSERT_MSG(stride == 1 || stride == 2, "Add support for strides other than 1 or 2.");
Matteo Martincigh46315822018-11-28 16:22:36 +0000165 std::array<unsigned int, 4> dims;
166 if (dataLayout == "NHWC")
surmeh01bceff2f2018-03-29 16:29:27 +0100167 {
Matteo Martincigh46315822018-11-28 16:22:36 +0000168 dims = { 1u, (stride == 2 ? 3u : 2u), 3u, 1u };
169 }
170 else // dataLayout == "NCHW"
171 {
172 dims = { 1u, 1u, (stride == 2 ? 3u : 2u), 3u };
surmeh01bceff2f2018-03-29 16:29:27 +0100173 }
174
Matteo Martincigh46315822018-11-28 16:22:36 +0000175 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
surmeh01bceff2f2018-03-29 16:29:27 +0100176 }
177};
178
Matteo Martincigh46315822018-11-28 16:22:36 +0000179struct Convolution2dNhwcSameFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100180{
Matteo Martincigh46315822018-11-28 16:22:36 +0000181 Convolution2dNhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 1){}
surmeh01bceff2f2018-03-29 16:29:27 +0100182};
Matteo Martincigh46315822018-11-28 16:22:36 +0000183BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame, Convolution2dNhwcSameFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100184{
185 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
186}
187
Matteo Martincigh46315822018-11-28 16:22:36 +0000188struct Convolution2dNchwSameFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100189{
Matteo Martincigh46315822018-11-28 16:22:36 +0000190 Convolution2dNchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 1){}
surmeh01bceff2f2018-03-29 16:29:27 +0100191};
Matteo Martincigh46315822018-11-28 16:22:36 +0000192BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwSame, Convolution2dNchwSameFixture)
193{
194 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
195}
196
197
198struct Convolution2dNhwcValidFixture : Convolution2dFixture
199{
200 Convolution2dNhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 1){}
201};
202BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcValid, Convolution2dNhwcValidFixture)
203{
204 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
205}
206
207struct Convolution2dNchwValidFixture : Convolution2dFixture
208{
209 Convolution2dNchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 1){}
210};
211BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwValid, Convolution2dNchwValidFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100212{
213 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
214}
215
216
Matteo Martincigh46315822018-11-28 16:22:36 +0000217struct Convolution2dStride2NhwcSameFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100218{
Matteo Martincigh46315822018-11-28 16:22:36 +0000219 Convolution2dStride2NhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 2){}
surmeh01bceff2f2018-03-29 16:29:27 +0100220};
Matteo Martincigh46315822018-11-28 16:22:36 +0000221BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcSame, Convolution2dStride2NhwcSameFixture)
222{
223 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
224}
225
226struct Convolution2dStride2NchwSameFixture : Convolution2dFixture
227{
228 Convolution2dStride2NchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 2){}
229};
230BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwSame, Convolution2dStride2NchwSameFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100231{
232 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
233}
234
235
Matteo Martincigh46315822018-11-28 16:22:36 +0000236struct Convolution2dStride2NhwcValidFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100237{
Matteo Martincigh46315822018-11-28 16:22:36 +0000238 Convolution2dStride2NhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 2){}
surmeh01bceff2f2018-03-29 16:29:27 +0100239};
Matteo Martincigh46315822018-11-28 16:22:36 +0000240BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcValid, Convolution2dStride2NhwcValidFixture)
241{
242 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
243}
244
245struct Convolution2dStride2NchwValidFixture : Convolution2dFixture
246{
247 Convolution2dStride2NchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 2){}
248};
249BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwValid, Convolution2dStride2NchwValidFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100250{
251 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
252}
253
254
Matteo Martincigh46315822018-11-28 16:22:36 +0000255struct Convolution2dDilation1NhwcFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100256{
Matteo Martincigh46315822018-11-28 16:22:36 +0000257 Convolution2dDilation1NhwcFixture() : Convolution2dFixture("NHWC", "SAME", 1, 1){}
surmeh01bceff2f2018-03-29 16:29:27 +0100258};
Matteo Martincigh46315822018-11-28 16:22:36 +0000259BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nhwc, Convolution2dDilation1NhwcFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100260{
261 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
262}
263
Matteo Martincigh46315822018-11-28 16:22:36 +0000264struct Convolution2dDilation1NchwFixture : Convolution2dFixture
265{
266 Convolution2dDilation1NchwFixture() : Convolution2dFixture("NCHW", "SAME", 1, 1){}
267};
268BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nchw, Convolution2dDilation1NchwFixture)
269{
270 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
271}
272
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000273struct Convolution2dDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +0100274{
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000275 explicit Convolution2dDilationFixture(const std::string& dataLayout, const std::string& paddingType)
276 : Convolution2dDilationFixture(dataLayout, paddingType, 1)
277 {}
surmeh01bceff2f2018-03-29 16:29:27 +0100278
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000279 explicit Convolution2dDilationFixture(const std::string& dataLayout, const std::string& paddingType,
280 int stride, int dilation = 0)
281 {
282 std::string strideString;
283 if (dataLayout == "NHWC")
284 {
285 strideString.append(" i: 1 \n"
286 " i: " + std::to_string(stride) + " \n"
287 " i: " + std::to_string(stride) + " \n"
288 " i: 1 \n");
289 }
290 else // dataLayout == "NCHW"
291 {
292 strideString.append(" i: 1 \n"
293 " i: 1 \n"
294 " i: " + std::to_string(stride) + " \n"
295 " i: " + std::to_string(stride) + " \n");
296 }
297
298 std::string dilationString;
299 if (dataLayout == "NHWC")
300 {
301 dilationString.append(" i: 1 \n"
302 " i: " + std::to_string(dilation) + " \n"
303 " i: " + std::to_string(dilation) + " \n"
304 " i: 1 \n");
305 }
306 else // dataLayout == "NCHW"
307 {
308 dilationString.append(" i: 1 \n"
309 " i: 1 \n"
310 " i: " + std::to_string(dilation) + " \n"
311 " i: " + std::to_string(dilation) + " \n");
312 }
313
314 m_Prototext = "node { \n"
315 " name: \"graphInput\" \n"
316 " op: \"Placeholder\" \n"
317 " attr { \n"
318 " key: \"dtype\" \n"
319 " value { \n"
320 " type: DT_FLOAT \n"
321 " } \n"
322 " } \n"
323 " attr { \n"
324 " key: \"shape\" \n"
325 " value { \n"
326 " shape { \n"
327 " } \n"
328 " } \n"
329 " } \n"
330 " } \n"
331 " node { \n"
332 " name: \"Const_1\" \n"
333 " op: \"Const\" \n"
334 " attr { \n"
335 " key: \"dtype\" \n"
336 " value { \n"
337 " type: DT_FLOAT \n"
338 " } \n"
339 " } \n"
340 " attr { \n"
341 " key: \"value\" \n"
342 " value { \n"
343 " tensor { \n"
344 " dtype: DT_FLOAT \n"
345 " tensor_shape { \n"
346 " dim { \n"
347 " size: 3 \n"
348 " } \n"
349 " dim { \n"
350 " size: 1 \n"
351 " } \n"
352 " dim { \n"
353 " size: 1 \n"
354 " } \n"
355 " dim { \n"
356 " size: 1 \n"
357 " } \n"
358 " } \n"
359 " tensor_content: \"\\001\\000\\000?\\000\\000\\000?\\001\\000\\000?\" \n"
360 " } \n"
361 " } \n"
362 " } \n"
363 "} \n"
364 "node { \n"
365 " name: \"potato\" \n"
366 " op: \"Conv2D\" \n"
367 " input: \"graphInput\" \n"
368 " input: \"Const_1\" \n"
369 " attr { \n"
370 " key: \"T\" \n"
371 " value { \n"
372 " type: DT_FLOAT \n"
373 " } \n"
374 " } \n"
375 " attr { \n"
376 " key: \"data_format\" \n"
377 " value { \n"
378 " s: \"";
379 m_Prototext.append(dataLayout);
380 m_Prototext.append("\"\n"
381 " } \n"
382 " } \n"
383 " attr { \n"
384 " key: \"padding\" \n"
385 " value { \n"
386 " s: \"");
387 m_Prototext.append(paddingType);
388 m_Prototext.append("\"\n"
389 " } \n"
390 " } \n"
391 " attr { \n"
392 " key: \"strides\" \n"
393 " value { \n"
394 " list { \n");
395 m_Prototext.append(strideString);
396
397 m_Prototext.append(" } \n"
398 " } \n"
399 " } \n");
400
401 if (dilation > 0)
402 {
403 m_Prototext.append(" attr { \n"
404 " key: \"dilations\" \n"
405 " value { \n"
406 " list { \n");
407 m_Prototext.append(dilationString);
408
409 m_Prototext.append(" } \n"
410 " } \n"
411 " } \n");
412 }
413 m_Prototext.append(" attr { \n"
414 " key: \"use_cudnn_on_gpu\" \n"
415 " value { \n"
416 " b: false \n"
417 " } \n"
418 " } \n"
419 "} \n");
420
421 // Manual height computation based on stride parameter.
422 std::array<unsigned int, 4> dims = { 1u, 1u, 6u, 6u };;
423
424 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
425 }
426};
427
428struct Convolution2dDilation2NchwValidFixture : Convolution2dDilationFixture
429{
430 Convolution2dDilation2NchwValidFixture() : Convolution2dDilationFixture("NCHW", "VALID", 1, 2){}
431};
432BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation2NchwValid, Convolution2dDilation2NchwValidFixture)
433{
434 RunTest<4>({1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
435 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
436 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
437 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
438 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
439 7.0, 8.0, 9.0, 10.0, 11.0, 12.0},
440 {1.5f, 3.0f, 4.5f, 6.0f, 7.5f, 9.0f, 10.5f, 12.f, 13.5f, 15.0f, 16.5f, 18.0f});
surmeh01bceff2f2018-03-29 16:29:27 +0100441}
442
443
444BOOST_AUTO_TEST_SUITE_END()