blob: f68758f42e10ebfb8997d31138daa390ee2ddfec [file] [log] [blame]
Narumol Prangnawarat1112b012021-09-30 12:10:50 +01001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8#include "OnnxParserTestUtils.hpp"
9
10TEST_SUITE("OnnxParser_Gemm")
11{
12
13struct GemmFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14{
15 GemmFixture(const std::string& alpha,
16 const std::string& beta,
17 const std::string& transA,
18 const std::string& transB,
19 const std::vector<int>& inputAShape,
20 const std::vector<int>& inputBShape,
21 const std::vector<int>& inputCShape,
22 const std::vector<int>& outputShape)
23 {
24 m_Prototext = R"(
25 ir_version: 8
26 producer_name: "onnx-example"
27 graph {
28 node {
29 input: "A"
30 input: "B"
31 input: "C"
32 output: "Output"
33 op_type: "Gemm"
34 attribute {
35 name: "alpha"
36 f: )" + alpha + R"(
37 type: FLOAT
38 }
39 attribute {
40 name: "beta"
41 f: )" + beta + R"(
42 type: FLOAT
43 }
44 attribute {
45 name: "transA"
46 i: )" + transA + R"(
47 type: INT
48 }
49 attribute {
50 name: "transB"
51 i: )" + transB + R"(
52 type: INT
53 }
54 }
55 name: "gem-model"
56 input {
57 name: "A"
58 type {
59 tensor_type {
60 elem_type: 1
61 shape {
62 )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
63 }
64 }
65 }
66 }
67 input {
68 name: "B"
69 type {
70 tensor_type {
71 elem_type: 1
72 shape {
73 )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
74 }
75 }
76 }
77 }
78 input {
79 name: "C"
80 type {
81 tensor_type {
82 elem_type: 1
83 shape {
84 )" + armnnUtils::ConstructTensorShapeString(inputCShape) + R"(
85 }
86 }
87 }
88 }
89 output {
90 name: "Output"
91 type {
92 tensor_type {
93 elem_type: 1
94 shape {
95 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
96 }
97 }
98 }
99 }
100 })";
101 }
102};
103
104struct GemmAllAttributesFixture : GemmFixture
105{
106 GemmAllAttributesFixture() : GemmFixture("0.25", "0.35", "1", "1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 })
107 {
108 Setup();
109 }
110};
111
112struct GemmSimpleFixture : GemmFixture
113{
114 GemmSimpleFixture() : GemmFixture("1", "1", "0", "0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 })
115 {
116 Setup();
117 }
118};
119
120struct GemmTransAFixture : GemmFixture
121{
122 GemmTransAFixture() : GemmFixture("1", "1", "1", "0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 })
123 {
124 Setup();
125 }
126};
127
128struct GemmTransBFixture : GemmFixture
129{
130 GemmTransBFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 })
131 {
132 Setup();
133 }
134};
135
136struct GemmParseExceptionFixture : GemmFixture
137{
138 GemmParseExceptionFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {}
139};
140
141TEST_CASE_FIXTURE(GemmAllAttributesFixture, "GemmTest")
142{
143 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
144 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
145 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
146 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
147 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
148 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
149 {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
150 {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
151 12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
152 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}});
153}
154
155TEST_CASE_FIXTURE(GemmSimpleFixture, "GemmSimpleTest")
156{
157 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
158 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
159 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
160 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
161 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
162 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
163 {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
164 {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
165 196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
166 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
167}
168
169TEST_CASE_FIXTURE(GemmTransAFixture, "GemmTransposeATest")
170{
171 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
172 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
173 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
174 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
175 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
176 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
177 {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
178 {{"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f,
179 146.1f, 172.2f, 198.3f, 224.4f, 250.5f,
180 112.1f, 134.2f, 156.3f, 178.4f, 200.5f }}});
181}
182
183TEST_CASE_FIXTURE(GemmTransBFixture, "GemmTransposeBTest")
184{
185 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
186 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
187 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
188 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
189 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
190 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
191 {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
192 {{"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f,
193 60.1f, 164.2f, 268.3f, 372.4f, 476.5f,
194 20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}});
195}
196
197TEST_CASE_FIXTURE(GemmParseExceptionFixture, "GemmParseExceptionTest")
198{
199 // ParseException because Input C is non-constant and has 2 dimension (should be 1 dimension)
200 CHECK_THROWS_AS(Setup(), armnn::ParseException);
201}
202
203struct GemmConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
204{
205 GemmConstantFixture()
206 {
207 m_Prototext = R"(
208 ir_version: 8
209 producer_name: "onnx-example"
210 graph {
211 node {
212 input: "A"
213 input: "B"
214 input: "C"
215 output: "Output"
216 op_type: "Gemm"
217 attribute {
218 name: "alpha"
219 f: 0.25
220 type: FLOAT
221 }
222 attribute {
223 name: "beta"
224 f: 0.35
225 type: FLOAT
226 }
227 attribute {
228 name: "transA"
229 i: 1
230 type: INT
231 }
232 attribute {
233 name: "transB"
234 i: 1
235 type: INT
236 }
237 }
238 name: "gem-model"
239 initializer {
240 dims: 5
241 dims: 4
242 data_type: 1
243 float_data: 1.0
244 float_data: 2.0
245 float_data: 3.0
246 float_data: 4.0
247 float_data: 5.0
248 float_data: 6.0
249 float_data: 7.0
250 float_data: 8.0
251 float_data: 9.0
252 float_data: 10.0
253 float_data: 11.0
254 float_data: 12.0
255 float_data: 13.0
256 float_data: 14.0
257 float_data: 15.0
258 float_data: 16.0
259 float_data: 17.0
260 float_data: 18.0
261 float_data: 19.0
262 float_data: 20.0
263 name: "B"
264 }
265 initializer {
266 dims: 1
267 dims: 5
268 data_type: 1
269 float_data: 0.1
270 float_data: 0.2
271 float_data: 0.3
272 float_data: 0.4
273 float_data: 0.5
274 name: "C"
275 }
276 input {
277 name: "A"
278 type {
279 tensor_type {
280 elem_type: 1
281 shape {
282 dim {
283 dim_value: 4
284 }
285 dim {
286 dim_value: 3
287 }
288 }
289 }
290 }
291 }
292 output {
293 name: "Output"
294 type {
295 tensor_type {
296 elem_type: 1
297 shape {
298 dim {
299 dim_value: 3
300 }
301 dim {
302 dim_value: 5
303 }
304 }
305 }
306 }
307 }
308 })";
309 Setup();
310 }
311};
312
313TEST_CASE_FIXTURE(GemmConstantFixture, "GemmConstantTest")
314{
315 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
316 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
317 {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
318 12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
319 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}});
320}
321
322struct GemmConstantSimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
323{
324 GemmConstantSimpleFixture()
325 {
326 m_Prototext = R"(
327 ir_version: 8
328 producer_name: "onnx-example"
329 graph {
330 node {
331 input: "A"
332 input: "B"
333 input: "C"
334 output: "Output"
335 op_type: "Gemm"
336 attribute {
337 name: "alpha"
338 f: 1
339 type: FLOAT
340 }
341 attribute {
342 name: "beta"
343 f: 1
344 type: FLOAT
345 }
346 attribute {
347 name: "transA"
348 i: 0
349 type: INT
350 }
351 attribute {
352 name: "transB"
353 i: 0
354 type: INT
355 }
356 }
357 name: "gem-model"
358 initializer {
359 dims: 4
360 dims: 5
361 data_type: 1
362 float_data: 1.0
363 float_data: 2.0
364 float_data: 3.0
365 float_data: 4.0
366 float_data: 5.0
367 float_data: 6.0
368 float_data: 7.0
369 float_data: 8.0
370 float_data: 9.0
371 float_data: 10.0
372 float_data: 11.0
373 float_data: 12.0
374 float_data: 13.0
375 float_data: 14.0
376 float_data: 15.0
377 float_data: 16.0
378 float_data: 17.0
379 float_data: 18.0
380 float_data: 19.0
381 float_data: 20.0
382 name: "B"
383 }
384 initializer {
385 dims: 1
386 dims: 5
387 data_type: 1
388 float_data: 0.1
389 float_data: 0.2
390 float_data: 0.3
391 float_data: 0.4
392 float_data: 0.5
393 name: "C"
394 }
395 input {
396 name: "A"
397 type {
398 tensor_type {
399 elem_type: 1
400 shape {
401 dim {
402 dim_value: 3
403 }
404 dim {
405 dim_value: 4
406 }
407 }
408 }
409 }
410 }
411 output {
412 name: "Output"
413 type {
414 tensor_type {
415 elem_type: 1
416 shape {
417 dim {
418 dim_value: 3
419 }
420 dim {
421 dim_value: 5
422 }
423 }
424 }
425 }
426 }
427 })";
428 Setup();
429 }
430};
431
432TEST_CASE_FIXTURE(GemmConstantSimpleFixture, "GemmConstantSimpleTest")
433{
434 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
435 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
436 {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
437 196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
438 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
439}
440
441struct GemmABFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
442{
443 GemmABFixture(const std::string& alpha,
444 const std::string& beta,
445 const std::string& transA,
446 const std::string& transB,
447 const std::vector<int>& inputAShape,
448 const std::vector<int>& inputBShape,
449 const std::vector<int>& outputShape)
450 {
451 m_Prototext = R"(
452 ir_version: 8
453 producer_name: "onnx-example"
454 graph {
455 node {
456 input: "A"
457 input: "B"
458 output: "Output"
459 op_type: "Gemm"
460 attribute {
461 name: "alpha"
462 f: )" + alpha + R"(
463 type: FLOAT
464 }
465 attribute {
466 name: "beta"
467 f: )" + beta + R"(
468 type: FLOAT
469 }
470 attribute {
471 name: "transA"
472 i: )" + transA + R"(
473 type: INT
474 }
475 attribute {
476 name: "transB"
477 i: )" + transB + R"(
478 type: INT
479 }
480 }
481 name: "gem-model"
482 input {
483 name: "A"
484 type {
485 tensor_type {
486 elem_type: 1
487 shape {
488 )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
489 }
490 }
491 }
492 }
493 input {
494 name: "B"
495 type {
496 tensor_type {
497 elem_type: 1
498 shape {
499 )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
500 }
501 }
502 }
503 }
504 output {
505 name: "Output"
506 type {
507 tensor_type {
508 elem_type: 1
509 shape {
510 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
511 }
512 }
513 }
514 }
515 })";
516 Setup();
517 }
518};
519
520struct GemmAlphaTransAFixture : GemmABFixture
521{
522 GemmAlphaTransAFixture() : GemmABFixture("0.25", "0.35", "1", "0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {}
523};
524
525struct GemmAlphaTransBFixture : GemmABFixture
526{
527 GemmAlphaTransBFixture() : GemmABFixture("0.25", "0.35", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {}
528};
529
530TEST_CASE_FIXTURE(GemmAlphaTransAFixture, "GemmAlphaTransATest")
531{
532 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
533 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
534 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
535 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
536 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
537 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
538 {{"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f,
539 36.5f, 43.0f, 49.5f, 56.0f, 62.5f,
540 28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}});
541}
542
543TEST_CASE_FIXTURE(GemmAlphaTransBFixture, "GemmAlphaTransBTest")
544{
545 RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
546 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
547 {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
548 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
549 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
550 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
551 {{"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f,
552 15.0f, 41.0f, 67.0f, 93.0f, 119.0f,
553 5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}});
554}
555
556}