blob: 8c8633b002dc6cb60e04d3fd31593aed6ece0558 [file] [log] [blame]
Ryan OShea49ed0df2022-09-21 16:09:41 +01001//
Mike Kelly0e3fe102023-01-23 19:32:06 +00002// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
Ryan OShea49ed0df2022-09-21 16:09:41 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "BatchMatMulTestHelper.hpp"
7
Ryan OShea49ed0df2022-09-21 16:09:41 +01008#include <flatbuffers/flatbuffers.h>
Ryan OShea49ed0df2022-09-21 16:09:41 +01009
10#include <doctest/doctest.h>
11
12namespace armnnDelegate
13{
14
Colm Donelaneff204a2023-11-28 15:46:09 +000015TEST_SUITE("BATCH_MATMUL_Tests")
16{
17 TEST_CASE("BatchMatMul2DFp32SimpleTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +010018 {
19 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +000020 std::vector<int32_t> LHSInputShape{ 2, 2 };
21 std::vector<int32_t> RHSInputShape{ 2, 2 };
22 std::vector<int32_t> outputShape{ 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010023
Colm Donelaneff204a2023-11-28 15:46:09 +000024 std::vector<float> LHSInputValues = { 1, 2, 3, 4 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010025
Colm Donelaneff204a2023-11-28 15:46:09 +000026 std::vector<float> RHSInputValues = { 5, 6, 7, 8 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010027
Colm Donelaneff204a2023-11-28 15:46:09 +000028 std::vector<float> expectedOutputValues = { 19, 22, 43, 50 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010029
Colm Donelaneff204a2023-11-28 15:46:09 +000030 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
31 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
Ryan OShea49ed0df2022-09-21 16:09:41 +010032 false);
33 }
34
Colm Donelaneff204a2023-11-28 15:46:09 +000035 TEST_CASE("BatchMatMul2DInt8SimpleTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +010036 {
37 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +000038 std::vector<int32_t> LHSInputShape{ 2, 2 };
39 std::vector<int32_t> RHSInputShape{ 2, 2 };
40 std::vector<int32_t> outputShape{ 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010041
Colm Donelaneff204a2023-11-28 15:46:09 +000042 std::vector<int8_t> LHSInputValues = { 1, 2, 3, 4 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010043
Colm Donelaneff204a2023-11-28 15:46:09 +000044 std::vector<int8_t> RHSInputValues = { 5, 6, 7, 8 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010045
Colm Donelaneff204a2023-11-28 15:46:09 +000046 std::vector<int8_t> expectedOutputValues = { 19, 22, 43, 50 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010047
Colm Donelaneff204a2023-11-28 15:46:09 +000048 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
49 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
50 false);
51 }
52
53 TEST_CASE("BatchMatMul3DFp32SimpleTest")
54 {
55 // Set input data
56 std::vector<int32_t> LHSInputShape{ 1, 2, 2 };
57 std::vector<int32_t> RHSInputShape{ 1, 2, 2 };
58 std::vector<int32_t> outputShape{ 1, 2, 2 };
59
60 std::vector<float> LHSInputValues = { 1, 2, 3, 4 };
61
62 std::vector<float> RHSInputValues = { 5, 6, 7, 8 };
63
64 std::vector<float> expectedOutputValues = { 19, 22, 43, 50 };
65
66 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
67 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
Ryan OShea49ed0df2022-09-21 16:09:41 +010068 false);
69 }
70
Colm Donelaneff204a2023-11-28 15:46:09 +000071 TEST_CASE("BatchMatMul3DInt8SimpleTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +010072 {
73 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +000074 std::vector<int32_t> LHSInputShape{ 1, 2, 2 };
75 std::vector<int32_t> RHSInputShape{ 1, 2, 2 };
76 std::vector<int32_t> outputShape{ 1, 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010077
Colm Donelaneff204a2023-11-28 15:46:09 +000078 std::vector<int8_t> LHSInputValues = { 1, 2, 3, 4 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010079
Colm Donelaneff204a2023-11-28 15:46:09 +000080 std::vector<int8_t> RHSInputValues = { 5, 6, 7, 8 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010081
Colm Donelaneff204a2023-11-28 15:46:09 +000082 std::vector<int8_t> expectedOutputValues = { 19, 22, 43, 50 };
Ryan OShea49ed0df2022-09-21 16:09:41 +010083
Colm Donelaneff204a2023-11-28 15:46:09 +000084 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
85 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
86 false);
87 }
88
89 TEST_CASE("BatchMatMul4DFp32SimpleTest")
90 {
91 // Set input data
92 std::vector<int32_t> LHSInputShape{ 1, 1, 2, 2 };
93 std::vector<int32_t> RHSInputShape{ 1, 1, 2, 2 };
94 std::vector<int32_t> outputShape{ 1, 1, 2, 2 };
95
96 std::vector<float> LHSInputValues = { 1, 2, 3, 4 };
97
98 std::vector<float> RHSInputValues = { 5, 6, 7, 8 };
99
100 std::vector<float> expectedOutputValues = { 19, 22, 43, 50 };
101
102 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
103 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100104 false);
105 }
106
Colm Donelaneff204a2023-11-28 15:46:09 +0000107 TEST_CASE("BatchMatMul4DInt8SimpleTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100108 {
109 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000110 std::vector<int32_t> LHSInputShape{ 1, 1, 2, 2 };
111 std::vector<int32_t> RHSInputShape{ 1, 1, 2, 2 };
112 std::vector<int32_t> outputShape{ 1, 1, 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100113
Colm Donelaneff204a2023-11-28 15:46:09 +0000114 std::vector<int8_t> LHSInputValues = { 1, 2, 3, 4 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100115
Colm Donelaneff204a2023-11-28 15:46:09 +0000116 std::vector<int8_t> RHSInputValues = { 5, 6, 7, 8 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100117
Colm Donelaneff204a2023-11-28 15:46:09 +0000118 std::vector<int8_t> expectedOutputValues = { 19, 22, 43, 50 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100119
Colm Donelaneff204a2023-11-28 15:46:09 +0000120 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
121 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
122 false);
123 }
124
125 TEST_CASE("BatchMatMul3DFp32BatchTest")
126 {
127 // Set input data
128 std::vector<int32_t> LHSInputShape{ 2, 2, 2 };
129 std::vector<int32_t> RHSInputShape{ 2, 2, 2 };
130 std::vector<int32_t> outputShape{ 2, 2, 2 };
131
132 std::vector<float> LHSInputValues = { 1, 2, 3, 4,
133
134 9, 10, 11, 12 };
135
136 std::vector<float> RHSInputValues = { 5, 6, 7, 8,
137
138 13, 14, 15, 16 };
139
140 std::vector<float> expectedOutputValues = { 19, 22, 43, 50,
141
142 267, 286, 323, 346 };
143
144 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
145 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100146 false);
147 }
148
Colm Donelaneff204a2023-11-28 15:46:09 +0000149 TEST_CASE("BatchMatMul3DInt8BatchTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100150 {
151 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000152 std::vector<int32_t> LHSInputShape{ 2, 2, 2 };
153 std::vector<int32_t> RHSInputShape{ 2, 2, 2 };
154 std::vector<int32_t> outputShape{ 2, 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100155
Colm Donelaneff204a2023-11-28 15:46:09 +0000156 std::vector<int8_t> LHSInputValues = { 1, 2, 3, 4,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100157
Colm Donelaneff204a2023-11-28 15:46:09 +0000158 9, 10, 11, 12 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100159
Colm Donelaneff204a2023-11-28 15:46:09 +0000160 std::vector<int8_t> RHSInputValues = { 5, 6, 7, 8,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100161
Colm Donelaneff204a2023-11-28 15:46:09 +0000162 1, 2, 3, 4 };
163
164 std::vector<int8_t> expectedOutputValues = { 19, 22, 43, 50,
165
166 39, 58, 47, 70 };
167
168 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
169 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
170 false);
Ryan OShea49ed0df2022-09-21 16:09:41 +0100171 }
172
Colm Donelaneff204a2023-11-28 15:46:09 +0000173 TEST_CASE("BatchMatMul3DFp32BroadcastTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100174 {
175 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000176 std::vector<int32_t> LHSInputShape{ 2, 2, 2 };
177 std::vector<int32_t> RHSInputShape{ 2, 2 };
178 std::vector<int32_t> outputShape{ 2, 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100179
Colm Donelaneff204a2023-11-28 15:46:09 +0000180 std::vector<float> LHSInputValues = { 1, 2, 3, 4,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100181
Colm Donelaneff204a2023-11-28 15:46:09 +0000182 9, 10, 11, 12 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100183
Colm Donelaneff204a2023-11-28 15:46:09 +0000184 std::vector<float> RHSInputValues = { 13, 14, 15, 16 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100185
Colm Donelaneff204a2023-11-28 15:46:09 +0000186 std::vector<float> expectedOutputValues = { 43, 46, 99, 106,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100187
Colm Donelaneff204a2023-11-28 15:46:09 +0000188 267, 286, 323, 346 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100189
Colm Donelaneff204a2023-11-28 15:46:09 +0000190 // We know that this is only supported on CpuRef. To enable on all backends just remoev the last parameter.
191 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
192 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
193 false, 1.0f, 0,{armnn::Compute::CpuRef});
Ryan OShea49ed0df2022-09-21 16:09:41 +0100194 }
195
Colm Donelaneff204a2023-11-28 15:46:09 +0000196 TEST_CASE("BatchMatMul3DInt8BroadcastTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100197 {
198 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000199 std::vector<int32_t> LHSInputShape{ 2, 2, 2 };
200 std::vector<int32_t> RHSInputShape{ 2, 2 };
201 std::vector<int32_t> outputShape{ 2, 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100202
Colm Donelaneff204a2023-11-28 15:46:09 +0000203 std::vector<int8_t> LHSInputValues = { 1, 2, 3, 4,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100204
Colm Donelaneff204a2023-11-28 15:46:09 +0000205 9, 10, 11, 12 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100206
Colm Donelaneff204a2023-11-28 15:46:09 +0000207 std::vector<int8_t> RHSInputValues = { 1, 2, 3, 4 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100208
Colm Donelaneff204a2023-11-28 15:46:09 +0000209 std::vector<int8_t> expectedOutputValues = { 7, 10, 15, 22,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100210
Colm Donelaneff204a2023-11-28 15:46:09 +0000211 39, 58, 47, 70 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100212
Colm Donelaneff204a2023-11-28 15:46:09 +0000213 // We know that this is only supported on CpuRef. To enable on all backends just remoev the last parameter.
214 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
215 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
216 false, 1.0f, 0,{armnn::Compute::CpuRef});
Ryan OShea49ed0df2022-09-21 16:09:41 +0100217 }
218
Colm Donelaneff204a2023-11-28 15:46:09 +0000219 TEST_CASE("BatchMatMul3D2DFp32BroadcastTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100220 {
221 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000222 std::vector<int32_t> LHSInputShape{ 2, 2, 2 };
223 std::vector<int32_t> RHSInputShape{ 2, 2 };
224 std::vector<int32_t> outputShape{ 2, 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100225
Colm Donelaneff204a2023-11-28 15:46:09 +0000226 std::vector<float> LHSInputValues = { 1, 2, 3, 4,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100227
Colm Donelaneff204a2023-11-28 15:46:09 +0000228 9, 10, 11, 12 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100229
Colm Donelaneff204a2023-11-28 15:46:09 +0000230 std::vector<float> RHSInputValues = { 13, 14, 15, 16 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100231
Colm Donelaneff204a2023-11-28 15:46:09 +0000232 std::vector<float> expectedOutputValues = { 43, 46, 99, 106,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100233
Colm Donelaneff204a2023-11-28 15:46:09 +0000234 267, 286, 323, 346 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100235
Colm Donelaneff204a2023-11-28 15:46:09 +0000236 // We know that this is only supported on CpuRef. To enable on all backends just remoev the last parameter.
237 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
238 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
239 false, 1.0f, 0,{armnn::Compute::CpuRef});
Ryan OShea49ed0df2022-09-21 16:09:41 +0100240 }
241
Colm Donelaneff204a2023-11-28 15:46:09 +0000242 TEST_CASE("BatchMatMul3D2DInt8BroadcastTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100243 {
244 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000245 std::vector<int32_t> LHSInputShape{ 2, 2, 2 };
246 std::vector<int32_t> RHSInputShape{ 2, 2 };
247 std::vector<int32_t> outputShape{ 2, 2, 2 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100248
Colm Donelaneff204a2023-11-28 15:46:09 +0000249 std::vector<int8_t> LHSInputValues = { 1, 2, 3, 4,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100250
Colm Donelaneff204a2023-11-28 15:46:09 +0000251 9, 10, 11, 12 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100252
Colm Donelaneff204a2023-11-28 15:46:09 +0000253 std::vector<int8_t> RHSInputValues = { 1, 2, 3, 4 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100254
Colm Donelaneff204a2023-11-28 15:46:09 +0000255 std::vector<int8_t> expectedOutputValues = { 7, 10, 15, 22,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100256
Colm Donelaneff204a2023-11-28 15:46:09 +0000257 39, 58, 47, 70 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100258
Colm Donelaneff204a2023-11-28 15:46:09 +0000259 // We know that this is only supported on CpuRef. To enable on all backends just remoev the last parameter.
260 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
261 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
262 false, 1.0f, 0,{armnn::Compute::CpuRef});
Ryan OShea49ed0df2022-09-21 16:09:41 +0100263 }
264
Colm Donelaneff204a2023-11-28 15:46:09 +0000265 TEST_CASE("BatchMatMul2DFp32TinyTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100266 {
267 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000268 std::vector<int32_t> LHSInputShape{ 1, 1 };
269 std::vector<int32_t> RHSInputShape{ 1, 1 };
270 std::vector<int32_t> outputShape{ 1, 1 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100271
272 std::vector<float> LHSInputValues = { 3 };
273
274 std::vector<float> RHSInputValues = { 5 };
275
276 std::vector<float> expectedOutputValues = { 15 };
277
Colm Donelaneff204a2023-11-28 15:46:09 +0000278 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
279 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100280 false);
281 }
Colm Donelaneff204a2023-11-28 15:46:09 +0000282
283 TEST_CASE("BatchMatMul2DInt8TinyTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100284 {
285 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000286 std::vector<int32_t> LHSInputShape{ 1, 1 };
287 std::vector<int32_t> RHSInputShape{ 1, 1 };
288 std::vector<int32_t> outputShape{ 1, 1 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100289
290 std::vector<int8_t> LHSInputValues = { 3 };
291
292 std::vector<int8_t> RHSInputValues = { 5 };
293
294 std::vector<int8_t> expectedOutputValues = { 15 };
295
Colm Donelaneff204a2023-11-28 15:46:09 +0000296 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
297 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100298 false);
299 }
300
Colm Donelaneff204a2023-11-28 15:46:09 +0000301 TEST_CASE("BatchMatMulNonSquareFp32Test")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100302 {
303 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000304 std::vector<int32_t> LHSInputShape{ 2, 5, 3 };
305 std::vector<int32_t> RHSInputShape{ 2, 3, 4 };
306 std::vector<int32_t> outputShape{ 2, 5, 4 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100307
Colm Donelaneff204a2023-11-28 15:46:09 +0000308 std::vector<float> LHSInputValues = { 8, 8, 4, 6, 1, 3, 8, 8, 3, 8, 9, 8, 5, 4, 4,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100309
Colm Donelaneff204a2023-11-28 15:46:09 +0000310 1, 8, 5, 7, 1, 1, 8, 7, 9, 3, 2, 7, 8, 5, 3 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100311
Colm Donelaneff204a2023-11-28 15:46:09 +0000312 std::vector<float> RHSInputValues = { 6, 2, 3, 2, 6, 2, 2, 8, 3, 7, 8, 1,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100313
Colm Donelaneff204a2023-11-28 15:46:09 +0000314 7, 2, 9, 5, 2, 3, 1, 3, 2, 7, 7, 5 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100315
Colm Donelaneff204a2023-11-28 15:46:09 +0000316 std::vector<float> expectedOutputValues = { 108, 60, 72, 84, 51, 35, 44, 23, 105, 53,
317 64, 83, 126, 90, 106, 96, 66, 46, 55, 46,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100318
Colm Donelaneff204a2023-11-28 15:46:09 +0000319 33, 61, 52, 54, 53, 24, 71, 43, 88, 100,
320 142, 106, 39, 61, 78, 56, 72, 52, 98, 70 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100321
Colm Donelaneff204a2023-11-28 15:46:09 +0000322 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
323 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100324 false);
325 }
326
Colm Donelaneff204a2023-11-28 15:46:09 +0000327 TEST_CASE("BatchMatMulNonSquareInt8Test")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100328 {
329 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000330 std::vector<int32_t> LHSInputShape{ 2, 5, 3 };
331 std::vector<int32_t> RHSInputShape{ 2, 3, 4 };
332 std::vector<int32_t> outputShape{ 2, 5, 4 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100333
Colm Donelaneff204a2023-11-28 15:46:09 +0000334 std::vector<int8_t> LHSInputValues = { 8, 8, 4, 6, 1, 3, 8, 8, 3, 8, 9, 8, 5, 4, 4,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100335
Colm Donelaneff204a2023-11-28 15:46:09 +0000336 1, 8, 5, 7, 1, 1, 8, 7, 9, 3, 2, 7, 8, 5, 3 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100337
Colm Donelaneff204a2023-11-28 15:46:09 +0000338 std::vector<int8_t> RHSInputValues = { 6, 2, 3, 2, 6, 2, 2, 8, 3, 7, 8, 1,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100339
Colm Donelaneff204a2023-11-28 15:46:09 +0000340 7, 2, 3, 5, 2, 3, 1, 3, 2, 7, 7, 5 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100341
Colm Donelaneff204a2023-11-28 15:46:09 +0000342 std::vector<int8_t> expectedOutputValues = { 108, 60, 72, 84, 51, 35, 44, 23, 105, 53,
343 64, 83, 126, 90, 106, 96, 66, 46, 55, 46,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100344
Colm Donelaneff204a2023-11-28 15:46:09 +0000345 33, 61, 46, 54, 53, 24, 29, 43, 88, 100,
346 94, 106, 39, 61, 60, 56, 72, 52, 50, 70 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100347
Colm Donelaneff204a2023-11-28 15:46:09 +0000348 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
349 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, false,
350 false);
351 }
352
353 TEST_CASE("BatchMatMul2DFp32SimpleAdjointTest")
354 {
355 // Set input data
356 std::vector<int32_t> LHSInputShape{ 3, 3 };
357 std::vector<int32_t> RHSInputShape{ 3, 3 };
358 std::vector<int32_t> outputShape{ 3, 3 };
359
360 std::vector<float> LHSInputValues = { 3, 1, 1, 1, 3, -1, 2, 4, 1 };
361
362 std::vector<float> RHSInputValues = { 1, 0, 0, 0, 1, 0, 0, 0, 1 };
363
364 std::vector<float> expectedOutputValues = { 3, 1, 2, 1, 3, 4, 1, -1, 1 };
365
366 BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_FLOAT32, LHSInputShape,
367 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, true,
Ryan OShea49ed0df2022-09-21 16:09:41 +0100368 false);
369 }
370
Colm Donelaneff204a2023-11-28 15:46:09 +0000371 TEST_CASE("BatchMatMul2DInt8SimpleAdjointTest")
Ryan OShea49ed0df2022-09-21 16:09:41 +0100372 {
373 // Set input data
Colm Donelaneff204a2023-11-28 15:46:09 +0000374 std::vector<int32_t> LHSInputShape{ 3, 3 };
375 std::vector<int32_t> RHSInputShape{ 3, 3 };
376 std::vector<int32_t> outputShape{ 3, 3 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100377
Colm Donelaneff204a2023-11-28 15:46:09 +0000378 std::vector<int8_t> LHSInputValues = { 3, 1, 1, 1, 3, -1, 2, 4, 1 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100379
Colm Donelaneff204a2023-11-28 15:46:09 +0000380 std::vector<int8_t> RHSInputValues = { 1, 0, 0, 0, 1, 0, 0, 0, 1 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100381
Colm Donelaneff204a2023-11-28 15:46:09 +0000382 std::vector<int8_t> expectedOutputValues = { 3, 1, 2, 1, 3, 4, 1, -1, 1 };
Ryan OShea49ed0df2022-09-21 16:09:41 +0100383
Colm Donelaneff204a2023-11-28 15:46:09 +0000384 BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, ::tflite::TensorType_INT8, LHSInputShape,
385 RHSInputShape, outputShape, LHSInputValues, RHSInputValues, expectedOutputValues, true,
386 false);
Teresa Charlin94916a52022-10-19 08:48:07 +0100387 }
Ryan OShea49ed0df2022-09-21 16:09:41 +0100388}
Colm Donelaneff204a2023-11-28 15:46:09 +0000389}