blob: e5cb976c45457213fd9cd8d1e1d8a88afa6061eb [file] [log] [blame]
Ryan OShea49ed0df2022-09-21 16:09:41 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "BatchMatMulTestHelper.hpp"
7
8#include <armnn_delegate.hpp>
9
10#include <flatbuffers/flatbuffers.h>
11#include <tensorflow/lite/schema/schema_generated.h>
12
13#include <doctest/doctest.h>
14
15namespace armnnDelegate
16{
17
18 void BatchMatMul2DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
19 {
20 // Set input data
21 std::vector<int32_t> LHSInputShape { 2, 2 };
22 std::vector<int32_t> RHSInputShape { 2, 2 };
23 std::vector<int32_t> outputShape { 2, 2 };
24
25 std::vector<float> LHSInputValues = { 1, 2,
26 3, 4 };
27
28 std::vector<float> RHSInputValues = { 5, 6,
29 7, 8 };
30
31 std::vector<float> expectedOutputValues = { 19, 22,
32 43, 50 };
33
34 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
35 ::tflite::TensorType_FLOAT32,
36 backends,
37 LHSInputShape,
38 RHSInputShape,
39 outputShape,
40 LHSInputValues,
41 RHSInputValues,
42 expectedOutputValues,
43 false,
44 false);
45 }
46 void BatchMatMul2DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
47 {
48 // Set input data
49 std::vector<int32_t> LHSInputShape { 2, 2 };
50 std::vector<int32_t> RHSInputShape { 2, 2 };
51 std::vector<int32_t> outputShape { 2, 2 };
52
53 std::vector<int8_t> LHSInputValues = { 1, 2,
54 3, 4 };
55
56 std::vector<int8_t> RHSInputValues = { 5, 6,
57 7, 8 };
58
59 std::vector<int8_t> expectedOutputValues = { 19, 22,
60 43, 50 };
61
62 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
63 ::tflite::TensorType_INT8,
64 backends,
65 LHSInputShape,
66 RHSInputShape,
67 outputShape,
68 LHSInputValues,
69 RHSInputValues,
70 expectedOutputValues,
71 false,
72 false);
73 }
74
75 void BatchMatMul3DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
76 {
77 // Set input data
78 std::vector<int32_t> LHSInputShape { 1,2,2 };
79 std::vector<int32_t> RHSInputShape { 1,2,2 };
80 std::vector<int32_t> outputShape { 1,2,2 };
81
82 std::vector<float> LHSInputValues = { 1, 2,
83 3, 4 };
84
85 std::vector<float> RHSInputValues = { 5, 6,
86 7, 8 };
87
88 std::vector<float> expectedOutputValues = { 19, 22,
89 43, 50 };
90
91 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
92 ::tflite::TensorType_FLOAT32,
93 backends,
94 LHSInputShape,
95 RHSInputShape,
96 outputShape,
97 LHSInputValues,
98 RHSInputValues,
99 expectedOutputValues,
100 false,
101 false);
102 }
103
104 void BatchMatMul3DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
105 {
106 // Set input data
107 std::vector<int32_t> LHSInputShape { 1,2,2 };
108 std::vector<int32_t> RHSInputShape { 1,2,2 };
109 std::vector<int32_t> outputShape { 1,2,2 };
110
111 std::vector<int8_t> LHSInputValues = { 1, 2,
112 3, 4 };
113
114 std::vector<int8_t> RHSInputValues = { 5, 6,
115 7, 8 };
116
117 std::vector<int8_t> expectedOutputValues = { 19, 22,
118 43, 50 };
119
120 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
121 ::tflite::TensorType_INT8,
122 backends,
123 LHSInputShape,
124 RHSInputShape,
125 outputShape,
126 LHSInputValues,
127 RHSInputValues,
128 expectedOutputValues,
129 false,
130 false);
131 }
132
133 void BatchMatMul4DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
134 {
135 // Set input data
136 std::vector<int32_t> LHSInputShape { 1,1,2,2 };
137 std::vector<int32_t> RHSInputShape { 1,1,2,2 };
138 std::vector<int32_t> outputShape { 1,1,2,2 };
139
140 std::vector<float> LHSInputValues = { 1, 2,
141 3, 4 };
142
143 std::vector<float> RHSInputValues = { 5, 6,
144 7, 8 };
145
146 std::vector<float> expectedOutputValues = { 19, 22,
147 43, 50 };
148
149 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
150 ::tflite::TensorType_FLOAT32,
151 backends,
152 LHSInputShape,
153 RHSInputShape,
154 outputShape,
155 LHSInputValues,
156 RHSInputValues,
157 expectedOutputValues,
158 false,
159 false);
160 }
161
162 void BatchMatMul4DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
163 {
164 // Set input data
165 std::vector<int32_t> LHSInputShape { 1,1,2,2};
166 std::vector<int32_t> RHSInputShape { 1,1,2,2 };
167 std::vector<int32_t> outputShape { 1,1,2,2 };
168
169 std::vector<int8_t> LHSInputValues = { 1, 2,
170 3, 4 };
171
172 std::vector<int8_t> RHSInputValues = { 5, 6,
173 7, 8 };
174
175 std::vector<int8_t> expectedOutputValues = { 19, 22,
176 43, 50 };
177
178 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
179 ::tflite::TensorType_INT8,
180 backends,
181 LHSInputShape,
182 RHSInputShape,
183 outputShape,
184 LHSInputValues,
185 RHSInputValues,
186 expectedOutputValues,
187 false,
188 false);
189 }
190
191 void BatchMatMul3DFp32BatchTest(std::vector<armnn::BackendId>& backends)
192 {
193 // Set input data
194 std::vector<int32_t> LHSInputShape { 2,2,2 };
195 std::vector<int32_t> RHSInputShape { 2,2,2 };
196 std::vector<int32_t> outputShape { 2,2,2 };
197
198 std::vector<float> LHSInputValues = { 1, 2,
199 3, 4,
200
201 9, 10,
202 11, 12 };
203
204 std::vector<float> RHSInputValues = { 5, 6,
205 7, 8,
206
207 13, 14,
208 15, 16 };
209
210 std::vector<float> expectedOutputValues = { 19, 22,
211 43, 50,
212
213 267, 286,
214 323, 346 };
215
216 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
217 ::tflite::TensorType_FLOAT32,
218 backends,
219 LHSInputShape,
220 RHSInputShape,
221 outputShape,
222 LHSInputValues,
223 RHSInputValues,
224 expectedOutputValues,
225 false,
226 false);
227 }
228
229 void BatchMatMul3DInt8BatchTest(std::vector<armnn::BackendId>& backends)
230 {
231 // Set input data
232 std::vector<int32_t> LHSInputShape { 2,2,2 };
233 std::vector<int32_t> RHSInputShape { 2,2,2 };
234 std::vector<int32_t> outputShape { 2,2,2 };
235
236 std::vector<int8_t> LHSInputValues = { 1, 2,
237 3, 4,
238
239 9, 10,
240 11, 12 };
241
242 std::vector<int8_t> RHSInputValues = { 5, 6,
243 7, 8,
244
245 1, 2,
246 3, 4 };
247
248 std::vector<int8_t> expectedOutputValues = { 19, 22,
249 43, 50,
250
251 39, 58,
252 47, 70 };
253
254 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
255 ::tflite::TensorType_INT8,
256 backends,
257 LHSInputShape,
258 RHSInputShape,
259 outputShape,
260 LHSInputValues,
261 RHSInputValues,
262 expectedOutputValues,
263 false,
264 false);
265 }
266
267 void BatchMatMul3DFp32BroadcastTest(std::vector<armnn::BackendId>& backends)
268 {
269 // Set input data
270 std::vector<int32_t> LHSInputShape { 2,2,2 };
271 std::vector<int32_t> RHSInputShape { 1,2,2 };
272 std::vector<int32_t> outputShape { 2,2,2 };
273
274 std::vector<float> LHSInputValues = { 1, 2,
275 3, 4,
276
277 9, 10,
278 11, 12 };
279
280 std::vector<float> RHSInputValues = { 13, 14,
281 15, 16 };
282
283 std::vector<float> expectedOutputValues = { 43, 46,
284 99, 106,
285
286 267, 286,
287 323, 346 };
288
289 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
290 ::tflite::TensorType_FLOAT32,
291 backends,
292 LHSInputShape,
293 RHSInputShape,
294 outputShape,
295 LHSInputValues,
296 RHSInputValues,
297 expectedOutputValues,
298 false,
299 false);
300 }
301
302 void BatchMatMul3DInt8BroadcastTest(std::vector<armnn::BackendId>& backends)
303 {
304 // Set input data
305 std::vector<int32_t> LHSInputShape { 2,2,2 };
306 std::vector<int32_t> RHSInputShape { 1,2,2 };
307 std::vector<int32_t> outputShape { 2,2,2 };
308
309 std::vector<int8_t> LHSInputValues = { 1, 2,
310 3, 4,
311
312 9, 10,
313 11, 12 };
314
315 std::vector<int8_t> RHSInputValues = { 1, 2,
316 3, 4 };
317
318 std::vector<int8_t> expectedOutputValues = { 7, 10,
319 15, 22,
320
321 39, 58,
322 47, 70 };
323
324 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
325 ::tflite::TensorType_INT8,
326 backends,
327 LHSInputShape,
328 RHSInputShape,
329 outputShape,
330 LHSInputValues,
331 RHSInputValues,
332 expectedOutputValues,
333 false,
334 false);
335 }
336
337 void BatchMatMul3D2DFp32BroadcastTest(std::vector<armnn::BackendId>& backends)
338 {
339 // Set input data
340 std::vector<int32_t> LHSInputShape { 2,2,2 };
341 std::vector<int32_t> RHSInputShape { 2,2 };
342 std::vector<int32_t> outputShape { 2,2,2 };
343
344 std::vector<float> LHSInputValues = { 1, 2,
345 3, 4,
346
347 9, 10,
348 11, 12 };
349
350 std::vector<float> RHSInputValues = { 13, 14,
351 15, 16 };
352
353 std::vector<float> expectedOutputValues = { 43, 46,
354 99, 106,
355
356 267, 286,
357 323, 346 };
358
359 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
360 ::tflite::TensorType_FLOAT32,
361 backends,
362 LHSInputShape,
363 RHSInputShape,
364 outputShape,
365 LHSInputValues,
366 RHSInputValues,
367 expectedOutputValues,
368 false,
369 false);
370 }
371
372 void BatchMatMul3D2DInt8BroadcastTest(std::vector<armnn::BackendId>& backends)
373 {
374 // Set input data
375 std::vector<int32_t> LHSInputShape { 2,2,2 };
376 std::vector<int32_t> RHSInputShape { 2,2 };
377 std::vector<int32_t> outputShape { 2,2,2 };
378
379 std::vector<int8_t> LHSInputValues = { 1, 2,
380 3, 4,
381
382 9, 10,
383 11, 12 };
384
385 std::vector<int8_t> RHSInputValues = { 1, 2,
386 3, 4 };
387
388 std::vector<int8_t> expectedOutputValues = { 7, 10,
389 15, 22,
390
391 39, 58,
392 47, 70 };
393
394 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
395 ::tflite::TensorType_INT8,
396 backends,
397 LHSInputShape,
398 RHSInputShape,
399 outputShape,
400 LHSInputValues,
401 RHSInputValues,
402 expectedOutputValues,
403 false,
404 false);
405 }
406
407 void BatchMatMul2DFp32TinyTest(std::vector<armnn::BackendId>& backends)
408 {
409 // Set input data
410 std::vector<int32_t> LHSInputShape { 1,1 };
411 std::vector<int32_t> RHSInputShape { 1,1 };
412 std::vector<int32_t> outputShape { 1,1 };
413
414 std::vector<float> LHSInputValues = { 3 };
415
416 std::vector<float> RHSInputValues = { 5 };
417
418 std::vector<float> expectedOutputValues = { 15 };
419
420 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
421 ::tflite::TensorType_FLOAT32,
422 backends,
423 LHSInputShape,
424 RHSInputShape,
425 outputShape,
426 LHSInputValues,
427 RHSInputValues,
428 expectedOutputValues,
429 false,
430 false);
431 }
432 void BatchMatMul2DInt8TinyTest(std::vector<armnn::BackendId>& backends)
433 {
434 // Set input data
435 std::vector<int32_t> LHSInputShape { 1,1 };
436 std::vector<int32_t> RHSInputShape { 1,1 };
437 std::vector<int32_t> outputShape { 1,1 };
438
439 std::vector<int8_t> LHSInputValues = { 3 };
440
441 std::vector<int8_t> RHSInputValues = { 5 };
442
443 std::vector<int8_t> expectedOutputValues = { 15 };
444
445 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
446 ::tflite::TensorType_INT8,
447 backends,
448 LHSInputShape,
449 RHSInputShape,
450 outputShape,
451 LHSInputValues,
452 RHSInputValues,
453 expectedOutputValues,
454 false,
455 false);
456 }
457
458 void BatchMatMulNonSquareFp32Test(std::vector<armnn::BackendId>& backends)
459 {
460 // Set input data
461 std::vector<int32_t> LHSInputShape { 2,5,3 };
462 std::vector<int32_t> RHSInputShape { 2,3,4 };
463 std::vector<int32_t> outputShape { 2,5,4 };
464
465 std::vector<float> LHSInputValues = { 8, 8, 4,
466 6, 1, 3,
467 8, 8, 3,
468 8, 9, 8,
469 5, 4, 4,
470
471 1, 8, 5,
472 7, 1, 1,
473 8, 7, 9,
474 3, 2, 7,
475 8, 5, 3 };
476
477 std::vector<float> RHSInputValues = { 6, 2, 3, 2,
478 6, 2, 2, 8,
479 3, 7, 8, 1,
480
481 7, 2, 9, 5,
482 2, 3, 1, 3,
483 2, 7, 7, 5 };
484
485 std::vector<float> expectedOutputValues = { 108, 60, 72, 84,
486 51, 35, 44, 23,
487 105, 53, 64, 83,
488 126, 90, 106, 96,
489 66, 46, 55, 46,
490
491 33, 61, 52, 54,
492 53, 24, 71, 43,
493 88, 100, 142, 106,
494 39, 61, 78, 56,
495 72, 52, 98, 70 };
496
497 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
498 ::tflite::TensorType_FLOAT32,
499 backends,
500 LHSInputShape,
501 RHSInputShape,
502 outputShape,
503 LHSInputValues,
504 RHSInputValues,
505 expectedOutputValues,
506 false,
507 false);
508 }
509
510 void BatchMatMulNonSquareInt8Test(std::vector<armnn::BackendId>& backends)
511 {
512 // Set input data
513 std::vector<int32_t> LHSInputShape { 2,5,3 };
514 std::vector<int32_t> RHSInputShape { 2,3,4 };
515 std::vector<int32_t> outputShape { 2,5,4 };
516
517 std::vector<int8_t> LHSInputValues = { 8, 8, 4,
518 6, 1, 3,
519 8, 8, 3,
520 8, 9, 8,
521 5, 4, 4,
522
523 1, 8, 5,
524 7, 1, 1,
525 8, 7, 9,
526 3, 2, 7,
527 8, 5, 3 };
528
529 std::vector<int8_t> RHSInputValues = { 6, 2, 3, 2,
530 6, 2, 2, 8,
531 3, 7, 8, 1,
532
533 7, 2, 3, 5,
534 2, 3, 1, 3,
535 2, 7, 7, 5 };
536
537 std::vector<int8_t> expectedOutputValues = { 108, 60, 72, 84,
538 51, 35, 44, 23,
539 105, 53, 64, 83,
540 126, 90, 106, 96,
541 66, 46, 55, 46,
542
543 33, 61, 46, 54,
544 53, 24, 29, 43,
545 88, 100, 94, 106,
546 39, 61, 60, 56,
547 72, 52, 50, 70 };
548
549 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
550 ::tflite::TensorType_INT8,
551 backends,
552 LHSInputShape,
553 RHSInputShape,
554 outputShape,
555 LHSInputValues,
556 RHSInputValues,
557 expectedOutputValues,
558 false,
559 false);
560 }
561
562 void BatchMatMul2DFp32SimpleAdjointTest(std::vector<armnn::BackendId>& backends)
563 {
564 // Set input data
565 std::vector<int32_t> LHSInputShape { 3,3 };
566 std::vector<int32_t> RHSInputShape { 3,3 };
567 std::vector<int32_t> outputShape { 3,3 };
568
569 std::vector<float> LHSInputValues = { 3, 1, 1,
570 1, 3, -1,
571 2, 4, 1 };
572
573 std::vector<float> RHSInputValues = { 1, 0, 0,
574 0, 1, 0,
575 0, 0, 1 };
576
577 std::vector<float> expectedOutputValues = { 3, 1, 2,
578 1, 3, 4,
579 1, -1, 1 };
580
581 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
582 ::tflite::TensorType_FLOAT32,
583 backends,
584 LHSInputShape,
585 RHSInputShape,
586 outputShape,
587 LHSInputValues,
588 RHSInputValues,
589 expectedOutputValues,
590 true,
591 false);
592 }
593
594 void BatchMatMul2DInt8SimpleAdjointTest(std::vector<armnn::BackendId>& backends)
595 {
596 // Set input data
597 std::vector<int32_t> LHSInputShape { 3,3 };
598 std::vector<int32_t> RHSInputShape { 3,3 };
599 std::vector<int32_t> outputShape { 3,3 };
600
601 std::vector<int8_t> LHSInputValues = { 3, 1, 1,
602 1, 3, -1,
603 2, 4, 1 };
604
605 std::vector<int8_t> RHSInputValues = { 1, 0, 0,
606 0, 1, 0,
607 0, 0, 1 };
608
609 std::vector<int8_t> expectedOutputValues = { 3, 1, 2,
610 1, 3, 4,
611 1, -1, 1 };
612
613 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
614 ::tflite::TensorType_INT8,
615 backends,
616 LHSInputShape,
617 RHSInputShape,
618 outputShape,
619 LHSInputValues,
620 RHSInputValues,
621 expectedOutputValues,
622 true,
623 false);
624 }
625
626 TEST_SUITE("BATCH_MATMUL_CpuRefTests")
627 {
628 TEST_CASE("BATCH_MATMUL_Fp32_CpuRefTests")
629 {
630 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
631 BatchMatMul2DFp32SimpleTest (backends);
632 BatchMatMul3DFp32SimpleTest (backends);
633 BatchMatMul4DFp32SimpleTest (backends);
634 BatchMatMul3DFp32BatchTest (backends);
635 BatchMatMul3DFp32BroadcastTest (backends);
636 BatchMatMul3D2DFp32BroadcastTest (backends);
637 BatchMatMul2DFp32TinyTest (backends);
638 BatchMatMulNonSquareFp32Test (backends);
639 BatchMatMul2DFp32SimpleAdjointTest(backends);
640 }
641
642 TEST_CASE("BATCH_MATMUL_Int8_CpuRefTests")
643 {
644 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
645 BatchMatMul2DInt8SimpleTest (backends);
646 BatchMatMul3DInt8SimpleTest (backends);
647 BatchMatMul4DInt8SimpleTest (backends);
648 BatchMatMul3DInt8BatchTest (backends);
649 BatchMatMul3DInt8BroadcastTest (backends);
650 BatchMatMul3D2DInt8BroadcastTest (backends);
651 BatchMatMul2DInt8TinyTest (backends);
652 BatchMatMulNonSquareInt8Test (backends);
653 BatchMatMul2DInt8SimpleAdjointTest(backends);
654 }
655 }
656
Teresa Charlin0f86ecf2022-10-13 15:47:08 +0100657 TEST_SUITE("BATCH_MATMUL_CpuAccTests")
658 {
659 TEST_CASE("BATCH_MATMUL_Fp32_CpuAccTests")
660 {
661 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
662 BatchMatMul2DFp32SimpleTest (backends);
663 BatchMatMul3DFp32SimpleTest (backends);
664 BatchMatMul4DFp32SimpleTest (backends);
665 BatchMatMul3DFp32BatchTest (backends);
666 BatchMatMul3DFp32BroadcastTest (backends);
667 BatchMatMul3D2DFp32BroadcastTest (backends);
668 BatchMatMul2DFp32TinyTest (backends);
669 BatchMatMulNonSquareFp32Test (backends);
670 BatchMatMul2DFp32SimpleAdjointTest(backends);
671 }
672 }
Ryan OShea49ed0df2022-09-21 16:09:41 +0100673}