blob: ed21bbe93c5188830349a798d3d42dc103002f2e [file] [log] [blame]
Narumol Prangnawarat02807852019-09-11 16:43:09 +01001//
Mike Kelly0e3fe102023-01-23 19:32:06 +00002// Copyright © 2019,2021-2023 Arm Ltd and Contributors. All rights reserved.
Narumol Prangnawarat02807852019-09-11 16:43:09 +01003// SPDX-License-Identifier: MIT
4//
5
Narumol Prangnawarat02807852019-09-11 16:43:09 +01006#include <armnn/Types.hpp>
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/TensorUtils.hpp>
9
Sadik Armagan1625efc2021-06-10 18:24:34 +010010#include <doctest/doctest.h>
Narumol Prangnawarat02807852019-09-11 16:43:09 +010011
12using namespace armnn;
13using namespace armnnUtils;
14
Sadik Armagan1625efc2021-06-10 18:24:34 +010015TEST_SUITE("TensorUtilsSuite")
16{
17TEST_CASE("ExpandDimsAxis0Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010018{
19 armnn::TensorShape inputShape({ 2, 3, 4 });
20
21 // Expand dimension 0
22 armnn::TensorShape outputShape = ExpandDims(inputShape, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +010023 CHECK(outputShape.GetNumDimensions() == 4);
24 CHECK(outputShape[0] == 1);
25 CHECK(outputShape[1] == 2);
26 CHECK(outputShape[2] == 3);
27 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010028}
29
Sadik Armagan1625efc2021-06-10 18:24:34 +010030TEST_CASE("ExpandDimsAxis1Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010031{
32 armnn::TensorShape inputShape({ 2, 3, 4 });
33
34 // Expand dimension 1
35 armnn::TensorShape outputShape = ExpandDims(inputShape, 1);
Sadik Armagan1625efc2021-06-10 18:24:34 +010036 CHECK(outputShape.GetNumDimensions() == 4);
37 CHECK(outputShape[0] == 2);
38 CHECK(outputShape[1] == 1);
39 CHECK(outputShape[2] == 3);
40 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010041}
42
Sadik Armagan1625efc2021-06-10 18:24:34 +010043TEST_CASE("ExpandDimsAxis2Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010044{
45 armnn::TensorShape inputShape({ 2, 3, 4 });
46
47 // Expand dimension 2
48 armnn::TensorShape outputShape = ExpandDims(inputShape, 2);
Sadik Armagan1625efc2021-06-10 18:24:34 +010049 CHECK(outputShape.GetNumDimensions() == 4);
50 CHECK(outputShape[0] == 2);
51 CHECK(outputShape[1] == 3);
52 CHECK(outputShape[2] == 1);
53 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010054}
55
Sadik Armagan1625efc2021-06-10 18:24:34 +010056TEST_CASE("ExpandDimsAxis3Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010057{
58 armnn::TensorShape inputShape({ 2, 3, 4 });
59
60 // Expand dimension 3
61 armnn::TensorShape outputShape = ExpandDims(inputShape, 3);
Sadik Armagan1625efc2021-06-10 18:24:34 +010062 CHECK(outputShape.GetNumDimensions() == 4);
63 CHECK(outputShape[0] == 2);
64 CHECK(outputShape[1] == 3);
65 CHECK(outputShape[2] == 4);
66 CHECK(outputShape[3] == 1);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010067}
68
Sadik Armagan1625efc2021-06-10 18:24:34 +010069TEST_CASE("ExpandDimsNegativeAxis1Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010070{
71 armnn::TensorShape inputShape({ 2, 3, 4 });
72
73 // Expand dimension -1
74 armnn::TensorShape outputShape = ExpandDims(inputShape, -1);
Sadik Armagan1625efc2021-06-10 18:24:34 +010075 CHECK(outputShape.GetNumDimensions() == 4);
76 CHECK(outputShape[0] == 2);
77 CHECK(outputShape[1] == 3);
78 CHECK(outputShape[2] == 4);
79 CHECK(outputShape[3] == 1);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010080}
81
Sadik Armagan1625efc2021-06-10 18:24:34 +010082TEST_CASE("ExpandDimsNegativeAxis2Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010083{
84 armnn::TensorShape inputShape({ 2, 3, 4 });
85
86 // Expand dimension -2
87 armnn::TensorShape outputShape = ExpandDims(inputShape, -2);
Sadik Armagan1625efc2021-06-10 18:24:34 +010088 CHECK(outputShape.GetNumDimensions() == 4);
89 CHECK(outputShape[0] == 2);
90 CHECK(outputShape[1] == 3);
91 CHECK(outputShape[2] == 1);
92 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010093}
94
Sadik Armagan1625efc2021-06-10 18:24:34 +010095TEST_CASE("ExpandDimsNegativeAxis3Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010096{
97 armnn::TensorShape inputShape({ 2, 3, 4 });
98
99 // Expand dimension -3
100 armnn::TensorShape outputShape = ExpandDims(inputShape, -3);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100101 CHECK(outputShape.GetNumDimensions() == 4);
102 CHECK(outputShape[0] == 2);
103 CHECK(outputShape[1] == 1);
104 CHECK(outputShape[2] == 3);
105 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100106}
107
Sadik Armagan1625efc2021-06-10 18:24:34 +0100108TEST_CASE("ExpandDimsNegativeAxis4Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100109{
110 armnn::TensorShape inputShape({ 2, 3, 4 });
111
112 // Expand dimension -4
113 armnn::TensorShape outputShape = ExpandDims(inputShape, -4);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100114 CHECK(outputShape.GetNumDimensions() == 4);
115 CHECK(outputShape[0] == 1);
116 CHECK(outputShape[1] == 2);
117 CHECK(outputShape[2] == 3);
118 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100119}
120
Sadik Armagan1625efc2021-06-10 18:24:34 +0100121TEST_CASE("ExpandDimsInvalidAxisTest")
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100122{
123 armnn::TensorShape inputShape({ 2, 3, 4 });
124
125 // Invalid expand dimension 4
Sadik Armagan1625efc2021-06-10 18:24:34 +0100126 CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100127}
128
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000129TEST_CASE("ExpandDimsInvalidNegativeAxisTest")
130{
131 armnn::TensorShape inputShape({ 2, 3, 4 });
132
133 // Invalid expand dimension -5
134 CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException);
135}
136
137TEST_CASE("ExpandDimsBy1Rank")
138{
139 armnn::TensorShape inputShape({ 2, 3, 4 });
140
141 // Expand by 1 dimension
142 armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
143 CHECK(outputShape.GetNumDimensions() == 4);
144 CHECK(outputShape[0] == 1);
145 CHECK(outputShape[1] == 2);
146 CHECK(outputShape[2] == 3);
147 CHECK(outputShape[3] == 4);
148}
149
150TEST_CASE("ExpandDimsBy2Ranks")
151{
152 armnn::TensorShape inputShape({ 3, 4 });
153
154 // Expand 2 dimensions
155 armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
156 CHECK(outputShape.GetNumDimensions() == 4);
157 CHECK(outputShape[0] == 1);
158 CHECK(outputShape[1] == 1);
159 CHECK(outputShape[2] == 3);
160 CHECK(outputShape[3] == 4);
161}
162
163TEST_CASE("ExpandDimsBy3Ranks")
164{
165 armnn::TensorShape inputShape({ 4 });
166
167 // Expand 3 dimensions
168 armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 4);
169 CHECK(outputShape.GetNumDimensions() == 4);
170 CHECK(outputShape[0] == 1);
171 CHECK(outputShape[1] == 1);
172 CHECK(outputShape[2] == 1);
173 CHECK(outputShape[3] == 4);
174}
175
176TEST_CASE("ExpandDimsInvalidRankAmount")
177{
178 armnn::TensorShape inputShape({ 2, 3, 4 });
179
180 // Don't expand because target rank is smaller than current rank
181 armnn::TensorShape outputShape = ExpandDimsToRank(inputShape, 2);
182 CHECK(outputShape.GetNumDimensions() == 3);
183 CHECK(outputShape[0] == 2);
184 CHECK(outputShape[1] == 3);
185 CHECK(outputShape[2] == 4);
186}
187
188TEST_CASE("ExpandDimsToRankInvalidTensorShape")
189{
190 armnn::TensorShape inputShape({ 2, 3, 4 });
191
192 // Throw exception because rank 6 tensors are unsupported by armnn
193 CHECK_THROWS_AS(ExpandDimsToRank(inputShape, 6), armnn::InvalidArgumentException);
194}
195
196
Mike Kelly0e3fe102023-01-23 19:32:06 +0000197TEST_CASE("ReduceDimsShapeAll1s")
198{
199 armnn::TensorShape inputShape({ 1, 1, 1 });
200
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000201 // Reduce dimension 2
Mike Kelly0e3fe102023-01-23 19:32:06 +0000202 armnn::TensorShape outputShape = ReduceDims(inputShape, 2);
203 CHECK(outputShape.GetNumDimensions() == 2);
204 CHECK(outputShape[0] == 1);
205 CHECK(outputShape[1] == 1);
206}
207
208TEST_CASE("ReduceDimsShapeNotEnough1s")
209{
210 armnn::TensorShape inputShape({ 1, 2, 1 });
211
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000212 // Reduce dimension 1
Mike Kelly0e3fe102023-01-23 19:32:06 +0000213 armnn::TensorShape outputShape = ReduceDims(inputShape, 1);
214 CHECK(outputShape.GetNumDimensions() == 2);
215 CHECK(outputShape[0] == 2);
216 CHECK(outputShape[1] == 1);
217}
218
219TEST_CASE("ReduceDimsInfoAll1s")
220{
221 armnn::TensorInfo inputInfo({ 1, 1, 1 }, DataType::Float32);
222
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000223 // Reduce dimension 2
Mike Kelly0e3fe102023-01-23 19:32:06 +0000224 armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 2);
225 CHECK(outputInfo.GetShape().GetNumDimensions() == 2);
226 CHECK(outputInfo.GetShape()[0] == 1);
227 CHECK(outputInfo.GetShape()[1] == 1);
228}
229
230TEST_CASE("ReduceDimsInfoNotEnough1s")
231{
232 armnn::TensorInfo inputInfo({ 1, 2, 1 }, DataType::Float32);
233
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000234 // Reduce dimension 1
Mike Kelly0e3fe102023-01-23 19:32:06 +0000235 armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 1);
236 CHECK(outputInfo.GetNumDimensions() == 2);
237 CHECK(outputInfo.GetShape()[0] == 2);
238 CHECK(outputInfo.GetShape()[1] == 1);
239}
240
241TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize")
242{
243 armnn::TensorShape inputShape({ 1, 1, 1 });
244
Ryan OSheaa544f0f2023-01-25 18:10:20 +0000245 // Do not reduce because dimension does not exist
Mike Kelly0e3fe102023-01-23 19:32:06 +0000246 armnn::TensorShape outputShape = ReduceDims(inputShape, 4);
247 CHECK(outputShape.GetNumDimensions() == 3);
248 CHECK(outputShape[0] == 1);
249 CHECK(outputShape[1] == 1);
250 CHECK(outputShape[2] == 1);
251}
252
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100253
Mike Kelly0506ef02023-01-03 16:29:44 +0000254TEST_CASE("ToFloatArrayInvalidDataType")
255{
256 armnn::TensorInfo info({ 2, 3, 4 }, armnn::DataType::BFloat16);
257 std::vector<uint8_t> data {1,2,3,4,5,6,7,8,9,10};
258
259 // Invalid argument
260 CHECK_THROWS_AS(ToFloatArray(data, info), armnn::InvalidArgumentException);
261}
262
263TEST_CASE("ToFloatArrayQSymmS8PerAxis")
264{
265 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
266 unsigned int quantizationDim = 1;
267
268 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QSymmS8, quantizationScales, quantizationDim);
269 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
270 float expected[] { 10.0f, 24.0f, -37.8f, -46.4f, -10.6f, -19.2f, -25.8f, -30.4f, -6.6f, -11.2f, -13.8f, -14.4f };
271
272 std::unique_ptr<float[]> result = ToFloatArray(data, info);
273
274 for (uint i = 0; i < info.GetNumElements(); ++i)
275 {
276 CHECK_EQ(result[i], doctest::Approx(expected[i]));
277 }
278}
279
280TEST_CASE("ToFloatArrayQSymmS8")
281{
282 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QSymmS8, 0.1f);
283 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
284 float expected[] { 10.0f, 12.0f, -12.6f, -11.6f, -10.6f, -9.6f, -8.6f, -7.6f, -6.6f, -5.6f, -4.6f, -3.6f };
285
286 std::unique_ptr<float[]> result = ToFloatArray(data, info);
287
288 for (uint i = 0; i < info.GetNumElements(); ++i)
289 {
290 CHECK_EQ(result[i], doctest::Approx(expected[i]));
291 }
292}
293
294TEST_CASE("ToFloatArrayQAsymmS8PerAxis")
295{
296 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
297 unsigned int quantizationDim = 1;
298
299 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmS8, quantizationScales, quantizationDim);
300 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
301 float expected[] { 10.0f, 24.0f, -37.8f, -46.4f, -10.6f, -19.2f, -25.8f, -30.4f, -6.6f, -11.2f, -13.8f, -14.4f };
302
303 std::unique_ptr<float[]> result = ToFloatArray(data, info);
304
305 for (uint i = 0; i < info.GetNumElements(); ++i)
306 {
307 CHECK_EQ(result[i], doctest::Approx(expected[i]));
308 }
309}
310
311TEST_CASE("ToFloatArrayQAsymmS8")
312{
313 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmS8, 0.1f);
314 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
315 float expected[] { 10.0f, 12.0f, -12.6f, -11.6f, -10.6f, -9.6f, -8.6f, -7.6f, -6.6f, -5.6f, -4.6f, -3.6f };
316
317 std::unique_ptr<float[]> result = ToFloatArray(data, info);
318
319 for (uint i = 0; i < info.GetNumElements(); ++i)
320 {
321 CHECK_EQ(result[i], doctest::Approx(expected[i]));
322 }
323}
324
325TEST_CASE("ToFloatArrayQASymmU8PerAxis")
326{
327 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
328 unsigned int quantizationDim = 1;
329
330 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmU8, quantizationScales, quantizationDim);
331 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220 };
332 float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
333
334 std::unique_ptr<float[]> result = ToFloatArray(data, info);
335
336 for (uint i = 0; i < info.GetNumElements(); ++i)
337 {
338 CHECK_EQ(result[i], doctest::Approx(expected[i]));
339 }
340}
341
342TEST_CASE("ToFloatArrayQAsymmU8")
343{
344 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmU8, 0.1f);
345 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220 };
346 float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
347
348 std::unique_ptr<float[]> result = ToFloatArray(data, info);
349
350 for (uint i = 0; i < info.GetNumElements(); ++i)
351 {
352 CHECK_EQ(result[i], doctest::Approx(expected[i]));
353 }
354}
355
356TEST_CASE("ToFloatArraySigned32PerAxis")
357{
358 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
359 unsigned int quantizationDim = 1;
360
361 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed32, quantizationScales, quantizationDim);
362 std::vector<uint8_t> data { 100, 0, 0, 0, 120, 0, 0, 0, 130, 0, 0, 0, 140, 0, 0, 0, 150, 0, 0, 0, 160, 0, 0, 0,
363 170, 0, 0, 0, 180, 0, 0, 0, 190, 0, 0, 0, 200, 0, 0, 0, 210, 0, 0, 0, 220, 0, 0, 0 };
364 float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
365
366 std::unique_ptr<float[]> result = ToFloatArray(data, info);
367
368 for (uint i = 0; i < info.GetNumElements(); ++i)
369 {
370 CHECK_EQ(result[i], doctest::Approx(expected[i]));
371 }
372}
373
374TEST_CASE("ToFloatArraySigned32")
375{
376 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed32, 0.1f);
377 std::vector<uint8_t> data { 100, 0, 0, 0, 120, 0, 0, 0, 130, 0, 0, 0, 140, 0, 0, 0, 150, 0, 0, 0, 160, 0, 0, 0,
378 170, 0, 0, 0, 180, 0, 0, 0, 190, 0, 0, 0, 200, 0, 0, 0, 210, 0, 0, 0, 220, 0, 0, 0 };
379 float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
380
381 std::unique_ptr<float[]> result = ToFloatArray(data, info);
382
383 for (uint i = 0; i < info.GetNumElements(); ++i)
384 {
385 CHECK_EQ(result[i], doctest::Approx(expected[i]));
386 }
387}
388
389TEST_CASE("ToFloatArraySigned64PerAxis")
390{
391 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
392 unsigned int quantizationDim = 1;
393
394 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed64, quantizationScales, quantizationDim);
395 std::vector<uint8_t> data { 100, 0, 0, 0, 0, 0, 0, 0, 120, 0, 0, 0, 0, 0, 0, 0, 130, 0, 0, 0, 0, 0, 0, 0,
396 140, 0, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0,
397 170, 0, 0, 0, 0, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 190, 0, 0, 0, 0, 0, 0, 0,
398 200, 0, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 0, 220, 0, 0, 0, 0, 0, 0, 0 };
399 float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
400
401 std::unique_ptr<float[]> result = ToFloatArray(data, info);
402
403 for (uint i = 0; i < info.GetNumElements(); ++i)
404 {
405 CHECK_EQ(result[i], doctest::Approx(expected[i]));
406 }
407}
408
409TEST_CASE("ToFloatArraySigned64")
410{
411 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed64, 0.1f);
412 std::vector<uint8_t> data { 100, 0, 0, 0, 0, 0, 0, 0, 120, 0, 0, 0, 0, 0, 0, 0, 130, 0, 0, 0, 0, 0, 0, 0,
413 140, 0, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0,
414 170, 0, 0, 0, 0, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 190, 0, 0, 0, 0, 0, 0, 0,
415 200, 0, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 0, 220, 0, 0, 0, 0, 0, 0, 0 };
416 float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
417
418 std::unique_ptr<float[]> result = ToFloatArray(data, info);
419
420 for (uint i = 0; i < info.GetNumElements(); ++i)
421 {
422 CHECK_EQ(result[i], doctest::Approx(expected[i]));
423 }
424}
Sadik Armagan1625efc2021-06-10 18:24:34 +0100425}