blob: 53094c1efe1e7353e7101f9f2dd630d68cd483e5 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("OnnxParser_FullyConnected")
10{
telsoa01c577f2c2018-08-31 09:22:23 +010011// A MatMul in isolation, not connected to an add. Should result in a non-biased FullyConnected layer.
12struct MatMulFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 MatMulFixture()
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK "
19 producer_version: "2.5.1 "
20 domain: "ai.cntk "
21 model_version: 1
22 graph {
23 name: "CNTKGraph "
24 input {
25 name: "Input"
26 type {
27 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000028 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010029 shape {
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 1
35 }
36 }
37 }
38 }
39 }
40 input {
41 name: "Const"
42 type {
43 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000044 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010045 shape {
46 dim {
47 dim_value: 1
48 }
49 dim {
50 dim_value: 1
51 }
52 }
53 }
54 }
55 }
56 initializer {
57 dims: 1
Tee Jungd94efa82019-11-01 11:55:21 +000058 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +000059 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010060 float_data: 17.0
61 name: "Const"
62 }
63 node {
64 input: "Input"
65 input: "Const"
66 output: "Output"
67 name: "SimpleMatmul"
68 op_type: "MatMul"
69 }
70 output {
71 name: "Output"
72 type {
73 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000074 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010075 shape {
76 dim {
77 dim_value: 1
78 }
79 dim {
80 dim_value: 1
81 }
82 }
83 }
84 }
85 }
86 }
87 opset_import {
88 version: 7
89 })";
90
91 Setup();
92 }
93};
94
Sadik Armagan1625efc2021-06-10 18:24:34 +010095TEST_CASE_FIXTURE(MatMulFixture, "MatMul")
telsoa01c577f2c2018-08-31 09:22:23 +010096{
97 RunTest<1>({{"Input", { 2 }}}, {{"Output", { 34 }}});
98}
99
100// In Onnx fully connected layers are expressed as a MatMul followed by an Add.
101// The OnnxParser must detect this case and convert them to a FullyConnected layer.
102struct FullyConnectedFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
103{
104 FullyConnectedFixture()
105 {
106 m_Prototext = R"(
107 ir_version: 3
108 producer_name: "CNTK "
109 producer_version: "2.5.1 "
110 domain: "ai.cntk "
111 model_version: 1
112 graph {
113 name: "CNTKGraph "
114 input {
115 name: "Input"
116 type {
117 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000118 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100119 shape {
120 dim {
121 dim_value: 1
122 }
123 dim {
124 dim_value: 1
125 }
126 }
127 }
128 }
129 }
130 input {
131 name: "Weight"
132 type {
133 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000134 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100135 shape {
136 dim {
137 dim_value: 1
138 }
139 dim {
140 dim_value: 1
141 }
142 }
143 }
144 }
145 }
146 initializer {
147 dims: 1
Tee Jungd94efa82019-11-01 11:55:21 +0000148 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000149 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100150 float_data: 2
151 name: "Weight"
152 }
153 input {
154 name: "Bias"
155 type {
156 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000157 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100158 shape {
159 dim {
160 dim_value: 1
161 }
162 }
163 }
164 }
165 }
166 initializer {
167 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000168 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100169 float_data: 1
170 name: "Bias"
171 }
172 node {
173 input: "Input"
174 input: "Weight"
175 output: "AddInput"
176 name: "FCMatmul"
177 op_type: "MatMul"
178 }
179 node {
180 input: "AddInput"
181 input: "Bias"
182 output: "Output"
183 name: "FCAdd"
184 op_type: "Add"
185 }
186 value_info {
187 name: "AddInput"
188 type {
189 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000190 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100191 shape {
192 dim {
193 dim_value: 1
194 }
195 dim {
196 dim_value: 1
197 }
198 }
199 }
200 }
201 }
202 output {
203 name: "Output"
204 type {
205 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000206 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100207 shape {
208 dim {
209 dim_value: 1
210 }
211 dim {
212 dim_value: 1
213 }
214 }
215 }
216 }
217 }
218 }
219 opset_import {
220 version: 7
221 })";
222
223 Setup();
224 }
225};
226
Sadik Armagan1625efc2021-06-10 18:24:34 +0100227TEST_CASE_FIXTURE(FullyConnectedFixture, "FullyConnected")
telsoa01c577f2c2018-08-31 09:22:23 +0100228{
229 RunTest<1>({{"Input", { 3 }}}, {{"Output", { 7 }}});
230}
231
232
233// Similar to FullyConnectedFixture, but this time the MatMul's output is used by two Adds. This should result
234// in two FullyConnected layers being created.
235// I
236// |
237// M -- C
238// / \'
239// C-- A A -- C
240// \ /
241// A
242struct MatMulUsedInTwoFcFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
243{
244 MatMulUsedInTwoFcFixture()
245 {
246 m_Prototext = R"(
247 ir_version: 3
248 producer_name: "CNTK "
249 producer_version: "2.5.1 "
250 domain: "ai.cntk "
251 model_version: 1
252 graph {
253 name: "CNTKGraph "
254 input {
255 name: "Input"
256 type {
257 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000258 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100259 shape {
260 dim {
261 dim_value: 1
262 }
263 dim {
264 dim_value: 1
265 }
266 }
267 }
268 }
269 }
270 input {
271 name: "Weight"
272 type {
273 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000274 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100275 shape {
276 dim {
277 dim_value: 1
278 }
279 dim {
280 dim_value: 1
281 }
282 }
283 }
284 }
285 }
286 initializer {
287 dims: 1
Tee Jungd94efa82019-11-01 11:55:21 +0000288 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000289 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100290 float_data: 2
291 name: "Weight"
292 }
293 input {
294 name: "Bias"
295 type {
296 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000297 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100298 shape {
299 dim {
300 dim_value: 1
301 }
302 }
303 }
304 }
305 }
306 initializer {
307 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000308 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100309 float_data: 1
310 name: "Bias"
311 }
312 input {
313 name: "Bias_1"
314 type {
315 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000316 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100317 shape {
318 dim {
319 dim_value: 1
320 }
321 }
322 }
323 }
324 }
325 initializer {
326 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000327 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100328 float_data: 10.0
329 name: "Bias_1"
330 }
331 node {
332 input: "Input"
333 input: "Weight"
334 output: "AddInput"
335 name: "FCMatmul"
336 op_type: "MatMul"
337 }
338 node {
339 input: "AddInput"
340 input: "Bias"
341 output: "AddOutput"
342 name: "FCAdd"
343 op_type: "Add"
344 }
345 node {
346 input: "AddInput"
347 input: "Bias_1"
348 output: "AddOutput_1"
349 name: "FCAdd_1"
350 op_type: "Add"
351 }
352 node {
353 input: "AddOutput"
354 input: "AddOutput_1"
355 output: "Output"
356 name: "FinalAdd"
357 op_type: "Add"
358 }
359 value_info {
360 name: "AddInput"
361 type {
362 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000363 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100364 shape {
365 dim {
366 dim_value: 1
367 }
368 dim {
369 dim_value: 1
370 }
371 }
372 }
373 }
374 }
375 value_info {
376 name: "AddOutput"
377 type {
378 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000379 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100380 shape {
381 dim {
382 dim_value: 1
383 }
384 dim {
385 dim_value: 1
386 }
387 }
388 }
389 }
390 }
391 value_info {
392 name: "AddOutput_1"
393 type {
394 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000395 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100396 shape {
397 dim {
398 dim_value: 1
399 }
400 dim {
401 dim_value: 1
402 }
403 }
404 }
405 }
406 }
407 output {
408 name: "Output"
409 type {
410 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000411 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100412 shape {
413 dim {
414 dim_value: 1
415 }
416 dim {
417 dim_value: 1
418 }
419 }
420 }
421 }
422 }
423 }
424 opset_import {
425 version: 7
426 })";
427
428 Setup();
429 }
430};
431
Sadik Armagan1625efc2021-06-10 18:24:34 +0100432TEST_CASE_FIXTURE(MatMulUsedInTwoFcFixture, "MatMulUsedInTwoFc")
telsoa01c577f2c2018-08-31 09:22:23 +0100433{
434 RunTest<1>({{"Input", { 3 }}}, {{"Output", { 23 }}});
435}
436
437
438// Similar to MatMulUsedInTwoFc, but this time the Adds are 'staggered' (see diagram), which means that only one
439// FullyConnected layer can be created (the other should just be an Add).
440// I
441// |
442// M -- C1
443// / \'
444// C2 -- A |
445// \ /
446// A
447struct MatMulUsedInTwoFcStaggeredFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
448{
449 MatMulUsedInTwoFcStaggeredFixture()
450 {
451 m_Prototext = R"(
452 ir_version: 3
453 producer_name: "CNTK "
454 producer_version: "2.5.1 "
455 domain: "ai.cntk "
456 model_version: 1
457 graph {
458 name: "CNTKGraph "
459 input {
460 name: "Input"
461 type {
462 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000463 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100464 shape {
465 dim {
466 dim_value: 1
467 }
468 dim {
469 dim_value: 1
470 }
471 }
472 }
473 }
474 }
475 input {
476 name: "Weight"
477 type {
478 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000479 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100480 shape {
481 dim {
482 dim_value: 1
483 }
484 dim {
485 dim_value: 1
486 }
487 }
488 }
489 }
490 }
491 initializer {
492 dims: 1
Tee Jungd94efa82019-11-01 11:55:21 +0000493 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000494 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100495 float_data: 2
496 name: "Weight"
497 }
498 input {
499 name: "Bias"
500 type {
501 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000502 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100503 shape {
504 dim {
505 dim_value: 1
506 }
507 }
508 }
509 }
510 }
511 initializer {
512 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000513 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100514 float_data: 1
515 name: "Bias"
516 }
517 node {
518 input: "Input"
519 input: "Weight"
520 output: "AddInput"
521 name: "MatmulFC&NFC"
522 op_type: "MatMul"
523 }
524 node {
525 input: "AddInput"
526 input: "Bias"
527 output: "AddOutput"
528 name: "FCAdd"
529 op_type: "Add"
530 }
531
532 node {
533 input: "AddInput"
534 input: "AddOutput"
535 output: "Output"
536 name: "FinalAdd"
537 op_type: "Add"
538 }
539 value_info {
540 name: "AddInput"
541 type {
542 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000543 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100544 shape {
545 dim {
546 dim_value: 1
547 }
548 dim {
549 dim_value: 1
550 }
551 }
552 }
553 }
554 }
555 value_info {
556 name: "AddOutput"
557 type {
558 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000559 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100560 shape {
561 dim {
562 dim_value: 1
563 }
564 dim {
565 dim_value: 1
566 }
567 }
568 }
569 }
570 }
571 output {
572 name: "Output"
573 type {
574 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000575 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100576 shape {
577 dim {
578 dim_value: 1
579 }
580 dim {
581 dim_value: 1
582 }
583 }
584 }
585 }
586 }
587 }
588 opset_import {
589 version: 7
590 })";
591 Setup();
592 }
593};
594
Sadik Armagan1625efc2021-06-10 18:24:34 +0100595TEST_CASE_FIXTURE(MatMulUsedInTwoFcStaggeredFixture, "MatMulUsedInTwoFcStaggered")
telsoa01c577f2c2018-08-31 09:22:23 +0100596{
597 RunTest<1>({{"Input", { 3 }}}, {{"Output", { 13 }}});
598}
599
Sadik Armagan1625efc2021-06-10 18:24:34 +0100600}