//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"

10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12struct SimpleConv2DFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 SimpleConv2DFixture()
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000028 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010029 shape {
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 1
35 }
36 dim {
37 dim_value: 3
38 }
39 dim {
40 dim_value: 3
41 }
42 }
43 }
44 }
45 }
46 input {
47 name: "Weight"
48 type {
49 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000050 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010051 shape {
52 dim {
53 dim_value: 1
54 }
55 dim {
56 dim_value: 1
57 }
58 dim {
59 dim_value: 3
60 }
61 dim {
62 dim_value: 3
63 }
64 }
65 }
66 }
67 }
68 initializer {
69 dims: 1
70 dims: 1
71 dims: 3
72 dims: 3
Matteo Martincigh44a71672018-12-11 13:46:52 +000073 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010074 float_data: 2
75 float_data: 1
76 float_data: 0
77 float_data: 6
78 float_data: 2
79 float_data: 1
80 float_data: 4
81 float_data: 1
82 float_data: 2
83 name: "Weight"
84 }
85 node {
86 input: "Input"
87 input: "Weight"
88 output: "Output"
89 name: "Convolution"
90 op_type: "Conv"
91 attribute {
92 name: "kernel_shape"
93 ints: 3
94 ints: 3
95 type: INTS
96 }
97 attribute {
98 name: "strides"
99 ints: 1
100 ints: 1
101 type: INTS
102 }
103 attribute {
104 name: "auto_pad"
105 s: "VALID"
106 type: STRING
107 }
108 attribute {
109 name: "group"
110 i: 1
111 type: INT
112 }
113 attribute {
114 name: "dilations"
115 ints: 1
116 ints: 1
117 type: INTS
118 }
119 doc_string: ""
120 domain: ""
121 }
122 output {
123 name: "Output"
124 type {
125 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000126 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100127 shape {
128 dim {
129 dim_value: 1
130 }
131 dim {
132 dim_value: 1
133 }
134 dim {
135 dim_value: 1
136 }
137 dim {
138 dim_value: 1
139 }
140 }
141 }
142 }
143 }
144 }
145 opset_import {
146 version: 7
147 })";
148 Setup();
149 }
150};
151
152struct Conv2DWithBiasesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
153{
154 Conv2DWithBiasesFixture() {
155 m_Prototext = R"(
156 ir_version: 3
157 producer_name: "CNTK"
158 producer_version: "2.5.1"
159 domain: "ai.cntk"
160 model_version: 1
161 graph {
162 name: "CNTKGraph"
163 input {
164 name: "Input"
165 type {
166 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000167 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100168 shape {
169 dim {
170 dim_value: 1
171 }
172 dim {
173 dim_value: 1
174 }
175 dim {
176 dim_value: 2
177 }
178 dim {
179 dim_value: 2
180 }
181 }
182 }
183 }
184 }
185 input {
186 name: "Weight"
187 type {
188 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000189 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100190 shape {
191 dim {
192 dim_value: 1
193 }
194 dim {
195 dim_value: 1
196 }
197 dim {
198 dim_value: 2
199 }
200 dim {
201 dim_value: 2
202 }
203 }
204 }
205 }
206 }
207 initializer {
208 dims: 1
209 dims: 1
210 dims: 2
211 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000212 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100213 float_data: 2
214 float_data: 1
215 float_data: 0
216 float_data: 6
217 name: "Weight"
218 }
219 input {
220 name: "Bias"
221 type {
222 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000223 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100224 shape {
225 dim {
226 dim_value: 4
227 }
228 }
229 }
230 }
231 }
232 initializer {
233 dims: 4
Matteo Martincigh44a71672018-12-11 13:46:52 +0000234 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100235 float_data: 10
236 float_data: 0
237 float_data: 0
238 float_data: 0
239 name: "Bias"
240 }
241 node {
242 input: "Input"
243 input: "Weight"
244 input: "Bias"
245 output: "Output"
246 name: "Convolution"
247 op_type: "Conv"
248 attribute {
249 name: "kernel_shape"
250 ints: 2
251 ints: 2
252 type: INTS
253 }
254 attribute {
255 name: "strides"
256 ints: 1
257 ints: 1
258 type: INTS
259 }
260 attribute {
261 name: "auto_pad"
262 s: "SAME_UPPER"
263 type: STRING
264 }
265 attribute {
266 name: "group"
267 i: 1
268 type: INT
269 }
270 attribute {
271 name: "dilations"
272 ints: 1
273 ints: 1
274 type: INTS
275 }
276 doc_string: ""
277 domain: ""
278 }
279 output {
280 name: "Output"
281 type {
282 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000283 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100284 shape {
285 dim {
286 dim_value: 1
287 }
288 dim {
289 dim_value: 1
290 }
291 dim {
292 dim_value: 2
293 }
294 dim {
295 dim_value: 2
296 }
297 }
298 }
299 }
300 }
301 }
302 opset_import {
303 version: 7
304 })";
305 Setup();
306 }
307};
308
309
310struct Conv2DDimReducingFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
311{
312 Conv2DDimReducingFixture() {
313 m_Prototext = R"(
314 ir_version: 3
315 producer_name: "CNTK"
316 producer_version: "2.5.1"
317 domain: "ai.cntk"
318 model_version: 1
319 graph {
320 name: "CNTKGraph"
321 input {
322 name: "Input"
323 type {
324 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000325 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100326 shape {
327 dim {
328 dim_value: 1
329 }
330 dim {
331 dim_value: 3
332 }
333 dim {
334 dim_value: 2
335 }
336 dim {
337 dim_value: 2
338 }
339 }
340 }
341 }
342 }
343 input {
344 name: "Weight"
345 type {
346 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000347 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100348 shape {
349 dim {
350 dim_value: 2
351 }
352 dim {
353 dim_value: 3
354 }
355 dim {
356 dim_value: 1
357 }
358 dim {
359 dim_value: 1
360 }
361 }
362 }
363 }
364 }
365 initializer {
366 dims: 2
367 dims: 3
368 dims: 1
369 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000370 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100371 float_data: -1
372 float_data: 2
373 float_data: 0
374 float_data: 1
375 float_data: 0
376 float_data: 0
377 name: "Weight"
378 }
379 node {
380 input: "Input"
381 input: "Weight"
382 output: "Output"
383 name: "Convolution"
384 op_type: "Conv"
385 attribute {
386 name: "kernel_shape"
387 ints: 1
388 ints: 1
389 type: INTS
390 }
391 attribute {
392 name: "strides"
393 ints: 1
394 ints: 1
395 type: INTS
396 }
397 attribute {
398 name: "group"
399 i: 1
400 type: INT
401 }
402 attribute {
403 name: "dilations"
404 ints: 1
405 ints: 1
406 type: INTS
407 }
408 doc_string: ""
409 domain: ""
410 }
411 output {
412 name: "Output"
413 type {
414 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000415 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100416 shape {
417 dim {
418 dim_value: 1
419 }
420 dim {
421 dim_value: 2
422 }
423 dim {
424 dim_value: 2
425 }
426 dim {
427 dim_value: 2
428 }
429 }
430 }
431 }
432 }
433 }
434 opset_import {
435 version: 7
436 })";
437 Setup();
438 }
439};
440
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000441struct Conv2DwithDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
442{
443 Conv2DwithDilationFixture()
444 {
445 m_Prototext = R"(
446 ir_version: 3
447 producer_name: "CNTK"
448 producer_version: "2.5.1"
449 domain: "ai.cntk"
450 model_version: 1
451 graph {
452 name: "CNTKGraph"
453 input {
454 name: "Input"
455 type {
456 tensor_type {
457 elem_type: 1
458 shape {
459 dim {
460 dim_value: 1
461 }
462 dim {
463 dim_value: 1
464 }
465 dim {
466 dim_value: 6
467 }
468 dim {
469 dim_value: 6
470 }
471 }
472 }
473 }
474 }
475 input {
476 name: "Weight"
477 type {
478 tensor_type {
479 elem_type: 1
480 shape {
481 dim {
482 dim_value: 1
483 }
484 dim {
485 dim_value: 1
486 }
487 dim {
488 dim_value: 3
489 }
490 dim {
491 dim_value: 3
492 }
493 }
494 }
495 }
496 }
497 initializer {
498 dims: 1
499 dims: 1
500 dims: 3
501 dims: 3
502 data_type: 1
503 float_data: 2
504 float_data: 1
505 float_data: 0
506 float_data: 6
507 float_data: 2
508 float_data: 1
509 float_data: 4
510 float_data: 1
511 float_data: 2
512 name: "Weight"
513 }
514 node {
515 input: "Input"
516 input: "Weight"
517 output: "Output"
518 name: "Convolution"
519 op_type: "Conv"
520 attribute {
521 name: "kernel_shape"
522 ints: 3
523 ints: 3
524 type: INTS
525 }
526 attribute {
527 name: "strides"
528 ints: 1
529 ints: 1
530 type: INTS
531 }
532 attribute {
533 name: "auto_pad"
534 s: "VALID"
535 type: STRING
536 }
537 attribute {
538 name: "group"
539 i: 1
540 type: INT
541 }
542 attribute {
543 name: "dilations"
544 ints: 2
545 ints: 2
546 type: INTS
547 }
548 doc_string: ""
549 domain: ""
550 }
551 output {
552 name: "Output"
553 type {
554 tensor_type {
555 elem_type: 1
556 shape {
557 dim {
558 dim_value: 1
559 }
560 dim {
561 dim_value: 1
562 }
563 dim {
564 dim_value: 2
565 }
566 dim {
567 dim_value: 2
568 }
569 }
570 }
571 }
572 }
573 }
574 opset_import {
575 version: 7
576 })";
577 Setup();
578 }
579};
580
telsoa01c577f2c2018-08-31 09:22:23 +0100581BOOST_FIXTURE_TEST_CASE(ValidConvTest, SimpleConv2DFixture)
582{
583 RunTest<4>({{"Input", {1.0, 2.0, 3.0,
584 4.0, 5.0, 6.0,
585 7.0, 8.0, 9.0}}},
586 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
587 4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
588 7.0 * 4 + 8.0 * 1 + 9.0 * 2}}});
589}
590
591BOOST_FIXTURE_TEST_CASE(ValidConvWithBiasTest, Conv2DWithBiasesFixture)
592{
593 RunTest<4>({{"Input", {1.0, 2.0,
594 3.0, 4.0}}},
595 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 4 * 6 + 10,
596 2.0 * 2 + 0 * 1 + 4.0 * 0 + 0 * 6 + 10,
597 3.0 * 2 + 4.0 * 1 + 0 * 0 + 0 * 6 + 10,
598 4.0 * 2 + 0 * 1 + 0 * 0 + 0 * 6 + 10}}});
599}
600
601BOOST_FIXTURE_TEST_CASE(ValidConvDimReducTest, Conv2DDimReducingFixture)
602{
603 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, -1, -2, 3, 4, 1 , 1, 1, 1 }}},
604 {{"Output", {-1 * 1 + 2 * -1, -1 * 2 + 2 * -2,
605 -1 * 3 + 2 * 3, -1 * 4 + 2 * 4,
606 1, 2, 3, 4}}});
607}
608
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000609BOOST_FIXTURE_TEST_CASE(ValidConvWithDilationTest, Conv2DwithDilationFixture)
610{
611 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
612 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
613 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
614 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
615 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
616 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}}},
617 {{"Output", {39.0, 58.0, 153.0, 172.0 }}});
618}
619
telsoa01c577f2c2018-08-31 09:22:23 +0100620BOOST_AUTO_TEST_SUITE_END()