blob: 060388167f7cd3b9bcf770191aa714cb5459d23c [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("OnnxParser_Conv2D")
10{
telsoa01c577f2c2018-08-31 09:22:23 +010011struct SimpleConv2DFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12{
13 SimpleConv2DFixture()
14 {
15 m_Prototext = R"(
16 ir_version: 3
17 producer_name: "CNTK"
18 producer_version: "2.5.1"
19 domain: "ai.cntk"
20 model_version: 1
21 graph {
22 name: "CNTKGraph"
23 input {
24 name: "Input"
25 type {
26 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000027 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010028 shape {
29 dim {
30 dim_value: 1
31 }
32 dim {
33 dim_value: 1
34 }
35 dim {
36 dim_value: 3
37 }
38 dim {
39 dim_value: 3
40 }
41 }
42 }
43 }
44 }
45 input {
46 name: "Weight"
47 type {
48 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000049 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010050 shape {
51 dim {
52 dim_value: 1
53 }
54 dim {
55 dim_value: 1
56 }
57 dim {
58 dim_value: 3
59 }
60 dim {
61 dim_value: 3
62 }
63 }
64 }
65 }
66 }
67 initializer {
68 dims: 1
69 dims: 1
70 dims: 3
71 dims: 3
Matteo Martincigh44a71672018-12-11 13:46:52 +000072 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010073 float_data: 2
74 float_data: 1
75 float_data: 0
76 float_data: 6
77 float_data: 2
78 float_data: 1
79 float_data: 4
80 float_data: 1
81 float_data: 2
82 name: "Weight"
83 }
84 node {
85 input: "Input"
86 input: "Weight"
87 output: "Output"
88 name: "Convolution"
89 op_type: "Conv"
90 attribute {
91 name: "kernel_shape"
92 ints: 3
93 ints: 3
94 type: INTS
95 }
96 attribute {
97 name: "strides"
98 ints: 1
99 ints: 1
100 type: INTS
101 }
102 attribute {
103 name: "auto_pad"
104 s: "VALID"
105 type: STRING
106 }
107 attribute {
108 name: "group"
109 i: 1
110 type: INT
111 }
112 attribute {
113 name: "dilations"
114 ints: 1
115 ints: 1
116 type: INTS
117 }
118 doc_string: ""
119 domain: ""
120 }
121 output {
122 name: "Output"
123 type {
124 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000125 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100126 shape {
127 dim {
128 dim_value: 1
129 }
130 dim {
131 dim_value: 1
132 }
133 dim {
134 dim_value: 1
135 }
136 dim {
137 dim_value: 1
138 }
139 }
140 }
141 }
142 }
143 }
144 opset_import {
145 version: 7
146 })";
147 Setup();
148 }
149};
150
151struct Conv2DWithBiasesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
152{
153 Conv2DWithBiasesFixture() {
154 m_Prototext = R"(
155 ir_version: 3
156 producer_name: "CNTK"
157 producer_version: "2.5.1"
158 domain: "ai.cntk"
159 model_version: 1
160 graph {
161 name: "CNTKGraph"
162 input {
163 name: "Input"
164 type {
165 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000166 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100167 shape {
168 dim {
169 dim_value: 1
170 }
171 dim {
172 dim_value: 1
173 }
174 dim {
175 dim_value: 2
176 }
177 dim {
178 dim_value: 2
179 }
180 }
181 }
182 }
183 }
184 input {
185 name: "Weight"
186 type {
187 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000188 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100189 shape {
190 dim {
191 dim_value: 1
192 }
193 dim {
194 dim_value: 1
195 }
196 dim {
197 dim_value: 2
198 }
199 dim {
200 dim_value: 2
201 }
202 }
203 }
204 }
205 }
206 initializer {
207 dims: 1
208 dims: 1
209 dims: 2
210 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000211 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100212 float_data: 2
213 float_data: 1
214 float_data: 0
215 float_data: 6
216 name: "Weight"
217 }
218 input {
219 name: "Bias"
220 type {
221 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000222 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100223 shape {
224 dim {
225 dim_value: 4
226 }
227 }
228 }
229 }
230 }
231 initializer {
232 dims: 4
Matteo Martincigh44a71672018-12-11 13:46:52 +0000233 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100234 float_data: 10
235 float_data: 0
236 float_data: 0
237 float_data: 0
238 name: "Bias"
239 }
240 node {
241 input: "Input"
242 input: "Weight"
243 input: "Bias"
244 output: "Output"
245 name: "Convolution"
246 op_type: "Conv"
247 attribute {
248 name: "kernel_shape"
249 ints: 2
250 ints: 2
251 type: INTS
252 }
253 attribute {
254 name: "strides"
255 ints: 1
256 ints: 1
257 type: INTS
258 }
259 attribute {
260 name: "auto_pad"
261 s: "SAME_UPPER"
262 type: STRING
263 }
264 attribute {
265 name: "group"
266 i: 1
267 type: INT
268 }
269 attribute {
270 name: "dilations"
271 ints: 1
272 ints: 1
273 type: INTS
274 }
275 doc_string: ""
276 domain: ""
277 }
278 output {
279 name: "Output"
280 type {
281 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000282 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100283 shape {
284 dim {
285 dim_value: 1
286 }
287 dim {
288 dim_value: 1
289 }
290 dim {
291 dim_value: 2
292 }
293 dim {
294 dim_value: 2
295 }
296 }
297 }
298 }
299 }
300 }
301 opset_import {
302 version: 7
303 })";
304 Setup();
305 }
306};
307
308
309struct Conv2DDimReducingFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
310{
311 Conv2DDimReducingFixture() {
312 m_Prototext = R"(
313 ir_version: 3
314 producer_name: "CNTK"
315 producer_version: "2.5.1"
316 domain: "ai.cntk"
317 model_version: 1
318 graph {
319 name: "CNTKGraph"
320 input {
321 name: "Input"
322 type {
323 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000324 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100325 shape {
326 dim {
327 dim_value: 1
328 }
329 dim {
330 dim_value: 3
331 }
332 dim {
333 dim_value: 2
334 }
335 dim {
336 dim_value: 2
337 }
338 }
339 }
340 }
341 }
342 input {
343 name: "Weight"
344 type {
345 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000346 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100347 shape {
348 dim {
349 dim_value: 2
350 }
351 dim {
352 dim_value: 3
353 }
354 dim {
355 dim_value: 1
356 }
357 dim {
358 dim_value: 1
359 }
360 }
361 }
362 }
363 }
364 initializer {
365 dims: 2
366 dims: 3
367 dims: 1
368 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000369 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100370 float_data: -1
371 float_data: 2
372 float_data: 0
373 float_data: 1
374 float_data: 0
375 float_data: 0
376 name: "Weight"
377 }
378 node {
379 input: "Input"
380 input: "Weight"
381 output: "Output"
382 name: "Convolution"
383 op_type: "Conv"
384 attribute {
385 name: "kernel_shape"
386 ints: 1
387 ints: 1
388 type: INTS
389 }
390 attribute {
391 name: "strides"
392 ints: 1
393 ints: 1
394 type: INTS
395 }
396 attribute {
397 name: "group"
398 i: 1
399 type: INT
400 }
401 attribute {
402 name: "dilations"
403 ints: 1
404 ints: 1
405 type: INTS
406 }
407 doc_string: ""
408 domain: ""
409 }
410 output {
411 name: "Output"
412 type {
413 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000414 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100415 shape {
416 dim {
417 dim_value: 1
418 }
419 dim {
420 dim_value: 2
421 }
422 dim {
423 dim_value: 2
424 }
425 dim {
426 dim_value: 2
427 }
428 }
429 }
430 }
431 }
432 }
433 opset_import {
434 version: 7
435 })";
436 Setup();
437 }
438};
439
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000440struct Conv2DwithDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
441{
442 Conv2DwithDilationFixture()
443 {
444 m_Prototext = R"(
445 ir_version: 3
446 producer_name: "CNTK"
447 producer_version: "2.5.1"
448 domain: "ai.cntk"
449 model_version: 1
450 graph {
451 name: "CNTKGraph"
452 input {
453 name: "Input"
454 type {
455 tensor_type {
456 elem_type: 1
457 shape {
458 dim {
459 dim_value: 1
460 }
461 dim {
462 dim_value: 1
463 }
464 dim {
465 dim_value: 6
466 }
467 dim {
468 dim_value: 6
469 }
470 }
471 }
472 }
473 }
474 input {
475 name: "Weight"
476 type {
477 tensor_type {
478 elem_type: 1
479 shape {
480 dim {
481 dim_value: 1
482 }
483 dim {
484 dim_value: 1
485 }
486 dim {
487 dim_value: 3
488 }
489 dim {
490 dim_value: 3
491 }
492 }
493 }
494 }
495 }
496 initializer {
497 dims: 1
498 dims: 1
499 dims: 3
500 dims: 3
501 data_type: 1
502 float_data: 2
503 float_data: 1
504 float_data: 0
505 float_data: 6
506 float_data: 2
507 float_data: 1
508 float_data: 4
509 float_data: 1
510 float_data: 2
511 name: "Weight"
512 }
513 node {
514 input: "Input"
515 input: "Weight"
516 output: "Output"
517 name: "Convolution"
518 op_type: "Conv"
519 attribute {
520 name: "kernel_shape"
521 ints: 3
522 ints: 3
523 type: INTS
524 }
525 attribute {
526 name: "strides"
527 ints: 1
528 ints: 1
529 type: INTS
530 }
531 attribute {
532 name: "auto_pad"
533 s: "VALID"
534 type: STRING
535 }
536 attribute {
537 name: "group"
538 i: 1
539 type: INT
540 }
541 attribute {
542 name: "dilations"
543 ints: 2
544 ints: 2
545 type: INTS
546 }
547 doc_string: ""
548 domain: ""
549 }
550 output {
551 name: "Output"
552 type {
553 tensor_type {
554 elem_type: 1
555 shape {
556 dim {
557 dim_value: 1
558 }
559 dim {
560 dim_value: 1
561 }
562 dim {
563 dim_value: 2
564 }
565 dim {
566 dim_value: 2
567 }
568 }
569 }
570 }
571 }
572 }
573 opset_import {
574 version: 7
575 })";
576 Setup();
577 }
578};
579
Sadik Armagan1625efc2021-06-10 18:24:34 +0100580TEST_CASE_FIXTURE(SimpleConv2DFixture, "ValidConvTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100581{
582 RunTest<4>({{"Input", {1.0, 2.0, 3.0,
583 4.0, 5.0, 6.0,
584 7.0, 8.0, 9.0}}},
585 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
586 4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
587 7.0 * 4 + 8.0 * 1 + 9.0 * 2}}});
588}
589
Sadik Armagan1625efc2021-06-10 18:24:34 +0100590TEST_CASE_FIXTURE(Conv2DWithBiasesFixture, "ValidConvWithBiasTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100591{
592 RunTest<4>({{"Input", {1.0, 2.0,
593 3.0, 4.0}}},
594 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 4 * 6 + 10,
595 2.0 * 2 + 0 * 1 + 4.0 * 0 + 0 * 6 + 10,
596 3.0 * 2 + 4.0 * 1 + 0 * 0 + 0 * 6 + 10,
597 4.0 * 2 + 0 * 1 + 0 * 0 + 0 * 6 + 10}}});
598}
599
Sadik Armagan1625efc2021-06-10 18:24:34 +0100600TEST_CASE_FIXTURE(Conv2DDimReducingFixture, "ValidConvDimReducTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100601{
602 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, -1, -2, 3, 4, 1 , 1, 1, 1 }}},
603 {{"Output", {-1 * 1 + 2 * -1, -1 * 2 + 2 * -2,
604 -1 * 3 + 2 * 3, -1 * 4 + 2 * 4,
605 1, 2, 3, 4}}});
606}
607
Sadik Armagan1625efc2021-06-10 18:24:34 +0100608TEST_CASE_FIXTURE(Conv2DwithDilationFixture, "ValidConvWithDilationTest")
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000609{
610 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
611 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
612 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
613 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
614 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
615 7.0, 8.0, 9.0, 10.0, 11.0, 12.0}}},
616 {{"Output", {39.0, 58.0, 153.0, 172.0 }}});
617}
618
Sadik Armagan1625efc2021-06-10 18:24:34 +0100619}