blob: a69a0098ce91e7eb6ffa23b08f3e7edb444b506b [file] [log] [blame]
//
// Copyright © 2019,2021-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Narumol Prangnawarat02807852019-09-11 16:43:09 +01006#include <armnn/Types.hpp>
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/TensorUtils.hpp>
9
Sadik Armagan1625efc2021-06-10 18:24:34 +010010#include <doctest/doctest.h>
Narumol Prangnawarat02807852019-09-11 16:43:09 +010011
12using namespace armnn;
13using namespace armnnUtils;
14
Sadik Armagan1625efc2021-06-10 18:24:34 +010015TEST_SUITE("TensorUtilsSuite")
16{
17TEST_CASE("ExpandDimsAxis0Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010018{
19 armnn::TensorShape inputShape({ 2, 3, 4 });
20
21 // Expand dimension 0
22 armnn::TensorShape outputShape = ExpandDims(inputShape, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +010023 CHECK(outputShape.GetNumDimensions() == 4);
24 CHECK(outputShape[0] == 1);
25 CHECK(outputShape[1] == 2);
26 CHECK(outputShape[2] == 3);
27 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010028}
29
Sadik Armagan1625efc2021-06-10 18:24:34 +010030TEST_CASE("ExpandDimsAxis1Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010031{
32 armnn::TensorShape inputShape({ 2, 3, 4 });
33
34 // Expand dimension 1
35 armnn::TensorShape outputShape = ExpandDims(inputShape, 1);
Sadik Armagan1625efc2021-06-10 18:24:34 +010036 CHECK(outputShape.GetNumDimensions() == 4);
37 CHECK(outputShape[0] == 2);
38 CHECK(outputShape[1] == 1);
39 CHECK(outputShape[2] == 3);
40 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010041}
42
Sadik Armagan1625efc2021-06-10 18:24:34 +010043TEST_CASE("ExpandDimsAxis2Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010044{
45 armnn::TensorShape inputShape({ 2, 3, 4 });
46
47 // Expand dimension 2
48 armnn::TensorShape outputShape = ExpandDims(inputShape, 2);
Sadik Armagan1625efc2021-06-10 18:24:34 +010049 CHECK(outputShape.GetNumDimensions() == 4);
50 CHECK(outputShape[0] == 2);
51 CHECK(outputShape[1] == 3);
52 CHECK(outputShape[2] == 1);
53 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010054}
55
Sadik Armagan1625efc2021-06-10 18:24:34 +010056TEST_CASE("ExpandDimsAxis3Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010057{
58 armnn::TensorShape inputShape({ 2, 3, 4 });
59
60 // Expand dimension 3
61 armnn::TensorShape outputShape = ExpandDims(inputShape, 3);
Sadik Armagan1625efc2021-06-10 18:24:34 +010062 CHECK(outputShape.GetNumDimensions() == 4);
63 CHECK(outputShape[0] == 2);
64 CHECK(outputShape[1] == 3);
65 CHECK(outputShape[2] == 4);
66 CHECK(outputShape[3] == 1);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010067}
68
Sadik Armagan1625efc2021-06-10 18:24:34 +010069TEST_CASE("ExpandDimsNegativeAxis1Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010070{
71 armnn::TensorShape inputShape({ 2, 3, 4 });
72
73 // Expand dimension -1
74 armnn::TensorShape outputShape = ExpandDims(inputShape, -1);
Sadik Armagan1625efc2021-06-10 18:24:34 +010075 CHECK(outputShape.GetNumDimensions() == 4);
76 CHECK(outputShape[0] == 2);
77 CHECK(outputShape[1] == 3);
78 CHECK(outputShape[2] == 4);
79 CHECK(outputShape[3] == 1);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010080}
81
Sadik Armagan1625efc2021-06-10 18:24:34 +010082TEST_CASE("ExpandDimsNegativeAxis2Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010083{
84 armnn::TensorShape inputShape({ 2, 3, 4 });
85
86 // Expand dimension -2
87 armnn::TensorShape outputShape = ExpandDims(inputShape, -2);
Sadik Armagan1625efc2021-06-10 18:24:34 +010088 CHECK(outputShape.GetNumDimensions() == 4);
89 CHECK(outputShape[0] == 2);
90 CHECK(outputShape[1] == 3);
91 CHECK(outputShape[2] == 1);
92 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010093}
94
Sadik Armagan1625efc2021-06-10 18:24:34 +010095TEST_CASE("ExpandDimsNegativeAxis3Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010096{
97 armnn::TensorShape inputShape({ 2, 3, 4 });
98
99 // Expand dimension -3
100 armnn::TensorShape outputShape = ExpandDims(inputShape, -3);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100101 CHECK(outputShape.GetNumDimensions() == 4);
102 CHECK(outputShape[0] == 2);
103 CHECK(outputShape[1] == 1);
104 CHECK(outputShape[2] == 3);
105 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100106}
107
Sadik Armagan1625efc2021-06-10 18:24:34 +0100108TEST_CASE("ExpandDimsNegativeAxis4Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100109{
110 armnn::TensorShape inputShape({ 2, 3, 4 });
111
112 // Expand dimension -4
113 armnn::TensorShape outputShape = ExpandDims(inputShape, -4);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100114 CHECK(outputShape.GetNumDimensions() == 4);
115 CHECK(outputShape[0] == 1);
116 CHECK(outputShape[1] == 2);
117 CHECK(outputShape[2] == 3);
118 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100119}
120
Sadik Armagan1625efc2021-06-10 18:24:34 +0100121TEST_CASE("ExpandDimsInvalidAxisTest")
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100122{
123 armnn::TensorShape inputShape({ 2, 3, 4 });
124
125 // Invalid expand dimension 4
Sadik Armagan1625efc2021-06-10 18:24:34 +0100126 CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100127}
128
Mike Kelly0e3fe102023-01-23 19:32:06 +0000129TEST_CASE("ReduceDimsShapeAll1s")
130{
131 armnn::TensorShape inputShape({ 1, 1, 1 });
132
133 // Invalid expand dimension 4
134 armnn::TensorShape outputShape = ReduceDims(inputShape, 2);
135 CHECK(outputShape.GetNumDimensions() == 2);
136 CHECK(outputShape[0] == 1);
137 CHECK(outputShape[1] == 1);
138}
139
140TEST_CASE("ReduceDimsShapeNotEnough1s")
141{
142 armnn::TensorShape inputShape({ 1, 2, 1 });
143
144 // Invalid expand dimension 4
145 armnn::TensorShape outputShape = ReduceDims(inputShape, 1);
146 CHECK(outputShape.GetNumDimensions() == 2);
147 CHECK(outputShape[0] == 2);
148 CHECK(outputShape[1] == 1);
149}
150
151TEST_CASE("ReduceDimsInfoAll1s")
152{
153 armnn::TensorInfo inputInfo({ 1, 1, 1 }, DataType::Float32);
154
155 // Invalid expand dimension 4
156 armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 2);
157 CHECK(outputInfo.GetShape().GetNumDimensions() == 2);
158 CHECK(outputInfo.GetShape()[0] == 1);
159 CHECK(outputInfo.GetShape()[1] == 1);
160}
161
162TEST_CASE("ReduceDimsInfoNotEnough1s")
163{
164 armnn::TensorInfo inputInfo({ 1, 2, 1 }, DataType::Float32);
165
166 // Invalid expand dimension 4
167 armnn::TensorInfo outputInfo = ReduceDims(inputInfo, 1);
168 CHECK(outputInfo.GetNumDimensions() == 2);
169 CHECK(outputInfo.GetShape()[0] == 2);
170 CHECK(outputInfo.GetShape()[1] == 1);
171}
172
173TEST_CASE("ReduceDimsShapeDimensionGreaterThanSize")
174{
175 armnn::TensorShape inputShape({ 1, 1, 1 });
176
177 // Invalid expand dimension 4
178 armnn::TensorShape outputShape = ReduceDims(inputShape, 4);
179 CHECK(outputShape.GetNumDimensions() == 3);
180 CHECK(outputShape[0] == 1);
181 CHECK(outputShape[1] == 1);
182 CHECK(outputShape[2] == 1);
183}
184
Sadik Armagan1625efc2021-06-10 18:24:34 +0100185TEST_CASE("ExpandDimsInvalidNegativeAxisTest")
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100186{
187 armnn::TensorShape inputShape({ 2, 3, 4 });
188
189 // Invalid expand dimension -5
Sadik Armagan1625efc2021-06-10 18:24:34 +0100190 CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100191}
192
Mike Kelly0506ef02023-01-03 16:29:44 +0000193TEST_CASE("ToFloatArrayInvalidDataType")
194{
195 armnn::TensorInfo info({ 2, 3, 4 }, armnn::DataType::BFloat16);
196 std::vector<uint8_t> data {1,2,3,4,5,6,7,8,9,10};
197
198 // Invalid argument
199 CHECK_THROWS_AS(ToFloatArray(data, info), armnn::InvalidArgumentException);
200}
201
202TEST_CASE("ToFloatArrayQSymmS8PerAxis")
203{
204 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
205 unsigned int quantizationDim = 1;
206
207 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QSymmS8, quantizationScales, quantizationDim);
208 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
209 float expected[] { 10.0f, 24.0f, -37.8f, -46.4f, -10.6f, -19.2f, -25.8f, -30.4f, -6.6f, -11.2f, -13.8f, -14.4f };
210
211 std::unique_ptr<float[]> result = ToFloatArray(data, info);
212
213 for (uint i = 0; i < info.GetNumElements(); ++i)
214 {
215 CHECK_EQ(result[i], doctest::Approx(expected[i]));
216 }
217}
218
219TEST_CASE("ToFloatArrayQSymmS8")
220{
221 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QSymmS8, 0.1f);
222 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
223 float expected[] { 10.0f, 12.0f, -12.6f, -11.6f, -10.6f, -9.6f, -8.6f, -7.6f, -6.6f, -5.6f, -4.6f, -3.6f };
224
225 std::unique_ptr<float[]> result = ToFloatArray(data, info);
226
227 for (uint i = 0; i < info.GetNumElements(); ++i)
228 {
229 CHECK_EQ(result[i], doctest::Approx(expected[i]));
230 }
231}
232
233TEST_CASE("ToFloatArrayQAsymmS8PerAxis")
234{
235 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
236 unsigned int quantizationDim = 1;
237
238 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmS8, quantizationScales, quantizationDim);
239 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
240 float expected[] { 10.0f, 24.0f, -37.8f, -46.4f, -10.6f, -19.2f, -25.8f, -30.4f, -6.6f, -11.2f, -13.8f, -14.4f };
241
242 std::unique_ptr<float[]> result = ToFloatArray(data, info);
243
244 for (uint i = 0; i < info.GetNumElements(); ++i)
245 {
246 CHECK_EQ(result[i], doctest::Approx(expected[i]));
247 }
248}
249
250TEST_CASE("ToFloatArrayQAsymmS8")
251{
252 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmS8, 0.1f);
253 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170 ,180, 190, 200, 210, 220 };
254 float expected[] { 10.0f, 12.0f, -12.6f, -11.6f, -10.6f, -9.6f, -8.6f, -7.6f, -6.6f, -5.6f, -4.6f, -3.6f };
255
256 std::unique_ptr<float[]> result = ToFloatArray(data, info);
257
258 for (uint i = 0; i < info.GetNumElements(); ++i)
259 {
260 CHECK_EQ(result[i], doctest::Approx(expected[i]));
261 }
262}
263
264TEST_CASE("ToFloatArrayQASymmU8PerAxis")
265{
266 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
267 unsigned int quantizationDim = 1;
268
269 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmU8, quantizationScales, quantizationDim);
270 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220 };
271 float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
272
273 std::unique_ptr<float[]> result = ToFloatArray(data, info);
274
275 for (uint i = 0; i < info.GetNumElements(); ++i)
276 {
277 CHECK_EQ(result[i], doctest::Approx(expected[i]));
278 }
279}
280
281TEST_CASE("ToFloatArrayQAsymmU8")
282{
283 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::QAsymmU8, 0.1f);
284 std::vector<uint8_t> data { 100, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220 };
285 float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
286
287 std::unique_ptr<float[]> result = ToFloatArray(data, info);
288
289 for (uint i = 0; i < info.GetNumElements(); ++i)
290 {
291 CHECK_EQ(result[i], doctest::Approx(expected[i]));
292 }
293}
294
295TEST_CASE("ToFloatArraySigned32PerAxis")
296{
297 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
298 unsigned int quantizationDim = 1;
299
300 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed32, quantizationScales, quantizationDim);
301 std::vector<uint8_t> data { 100, 0, 0, 0, 120, 0, 0, 0, 130, 0, 0, 0, 140, 0, 0, 0, 150, 0, 0, 0, 160, 0, 0, 0,
302 170, 0, 0, 0, 180, 0, 0, 0, 190, 0, 0, 0, 200, 0, 0, 0, 210, 0, 0, 0, 220, 0, 0, 0 };
303 float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
304
305 std::unique_ptr<float[]> result = ToFloatArray(data, info);
306
307 for (uint i = 0; i < info.GetNumElements(); ++i)
308 {
309 CHECK_EQ(result[i], doctest::Approx(expected[i]));
310 }
311}
312
313TEST_CASE("ToFloatArraySigned32")
314{
315 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed32, 0.1f);
316 std::vector<uint8_t> data { 100, 0, 0, 0, 120, 0, 0, 0, 130, 0, 0, 0, 140, 0, 0, 0, 150, 0, 0, 0, 160, 0, 0, 0,
317 170, 0, 0, 0, 180, 0, 0, 0, 190, 0, 0, 0, 200, 0, 0, 0, 210, 0, 0, 0, 220, 0, 0, 0 };
318 float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
319
320 std::unique_ptr<float[]> result = ToFloatArray(data, info);
321
322 for (uint i = 0; i < info.GetNumElements(); ++i)
323 {
324 CHECK_EQ(result[i], doctest::Approx(expected[i]));
325 }
326}
327
328TEST_CASE("ToFloatArraySigned64PerAxis")
329{
330 std::vector<float> quantizationScales { 0.1f, 0.2f, 0.3f, 0.4f };
331 unsigned int quantizationDim = 1;
332
333 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed64, quantizationScales, quantizationDim);
334 std::vector<uint8_t> data { 100, 0, 0, 0, 0, 0, 0, 0, 120, 0, 0, 0, 0, 0, 0, 0, 130, 0, 0, 0, 0, 0, 0, 0,
335 140, 0, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0,
336 170, 0, 0, 0, 0, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 190, 0, 0, 0, 0, 0, 0, 0,
337 200, 0, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 0, 220, 0, 0, 0, 0, 0, 0, 0 };
338 float expected[] { 10.0f, 24.0f, 39.0f, 56.0f, 15.0f, 32.0f, 51.0f, 72.0f, 19.0f, 40.0f, 63.0f, 88.0f };
339
340 std::unique_ptr<float[]> result = ToFloatArray(data, info);
341
342 for (uint i = 0; i < info.GetNumElements(); ++i)
343 {
344 CHECK_EQ(result[i], doctest::Approx(expected[i]));
345 }
346}
347
348TEST_CASE("ToFloatArraySigned64")
349{
350 armnn::TensorInfo info({ 3, 4 }, armnn::DataType::Signed64, 0.1f);
351 std::vector<uint8_t> data { 100, 0, 0, 0, 0, 0, 0, 0, 120, 0, 0, 0, 0, 0, 0, 0, 130, 0, 0, 0, 0, 0, 0, 0,
352 140, 0, 0, 0, 0, 0, 0, 0, 150, 0, 0, 0, 0, 0, 0, 0, 160, 0, 0, 0, 0, 0, 0, 0,
353 170, 0, 0, 0, 0, 0, 0, 0, 180, 0, 0, 0, 0, 0, 0, 0, 190, 0, 0, 0, 0, 0, 0, 0,
354 200, 0, 0, 0, 0, 0, 0, 0, 210, 0, 0, 0, 0, 0, 0, 0, 220, 0, 0, 0, 0, 0, 0, 0 };
355 float expected[] { 10.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f };
356
357 std::unique_ptr<float[]> result = ToFloatArray(data, info);
358
359 for (uint i = 0; i < info.GetNumElements(); ++i)
360 {
361 CHECK_EQ(result[i], doctest::Approx(expected[i]));
362 }
363}
Sadik Armagan1625efc2021-06-10 18:24:34 +0100364}