blob: 6d5f719eb1e3057e9b2c4831fcf79998dbdff6de [file] [log] [blame]
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
Narumol Prangnawarat02807852019-09-11 16:43:09 +01006#include <armnn/Types.hpp>
7
Matteo Martincighe011d202019-11-28 11:35:47 +00008#include <armnnUtils/TensorUtils.hpp>
9
Sadik Armagan1625efc2021-06-10 18:24:34 +010010#include <doctest/doctest.h>
Narumol Prangnawarat02807852019-09-11 16:43:09 +010011
12using namespace armnn;
13using namespace armnnUtils;
14
Sadik Armagan1625efc2021-06-10 18:24:34 +010015TEST_SUITE("TensorUtilsSuite")
16{
17TEST_CASE("ExpandDimsAxis0Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010018{
19 armnn::TensorShape inputShape({ 2, 3, 4 });
20
21 // Expand dimension 0
22 armnn::TensorShape outputShape = ExpandDims(inputShape, 0);
Sadik Armagan1625efc2021-06-10 18:24:34 +010023 CHECK(outputShape.GetNumDimensions() == 4);
24 CHECK(outputShape[0] == 1);
25 CHECK(outputShape[1] == 2);
26 CHECK(outputShape[2] == 3);
27 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010028}
29
Sadik Armagan1625efc2021-06-10 18:24:34 +010030TEST_CASE("ExpandDimsAxis1Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010031{
32 armnn::TensorShape inputShape({ 2, 3, 4 });
33
34 // Expand dimension 1
35 armnn::TensorShape outputShape = ExpandDims(inputShape, 1);
Sadik Armagan1625efc2021-06-10 18:24:34 +010036 CHECK(outputShape.GetNumDimensions() == 4);
37 CHECK(outputShape[0] == 2);
38 CHECK(outputShape[1] == 1);
39 CHECK(outputShape[2] == 3);
40 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010041}
42
Sadik Armagan1625efc2021-06-10 18:24:34 +010043TEST_CASE("ExpandDimsAxis2Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010044{
45 armnn::TensorShape inputShape({ 2, 3, 4 });
46
47 // Expand dimension 2
48 armnn::TensorShape outputShape = ExpandDims(inputShape, 2);
Sadik Armagan1625efc2021-06-10 18:24:34 +010049 CHECK(outputShape.GetNumDimensions() == 4);
50 CHECK(outputShape[0] == 2);
51 CHECK(outputShape[1] == 3);
52 CHECK(outputShape[2] == 1);
53 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010054}
55
Sadik Armagan1625efc2021-06-10 18:24:34 +010056TEST_CASE("ExpandDimsAxis3Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010057{
58 armnn::TensorShape inputShape({ 2, 3, 4 });
59
60 // Expand dimension 3
61 armnn::TensorShape outputShape = ExpandDims(inputShape, 3);
Sadik Armagan1625efc2021-06-10 18:24:34 +010062 CHECK(outputShape.GetNumDimensions() == 4);
63 CHECK(outputShape[0] == 2);
64 CHECK(outputShape[1] == 3);
65 CHECK(outputShape[2] == 4);
66 CHECK(outputShape[3] == 1);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010067}
68
Sadik Armagan1625efc2021-06-10 18:24:34 +010069TEST_CASE("ExpandDimsNegativeAxis1Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010070{
71 armnn::TensorShape inputShape({ 2, 3, 4 });
72
73 // Expand dimension -1
74 armnn::TensorShape outputShape = ExpandDims(inputShape, -1);
Sadik Armagan1625efc2021-06-10 18:24:34 +010075 CHECK(outputShape.GetNumDimensions() == 4);
76 CHECK(outputShape[0] == 2);
77 CHECK(outputShape[1] == 3);
78 CHECK(outputShape[2] == 4);
79 CHECK(outputShape[3] == 1);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010080}
81
Sadik Armagan1625efc2021-06-10 18:24:34 +010082TEST_CASE("ExpandDimsNegativeAxis2Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010083{
84 armnn::TensorShape inputShape({ 2, 3, 4 });
85
86 // Expand dimension -2
87 armnn::TensorShape outputShape = ExpandDims(inputShape, -2);
Sadik Armagan1625efc2021-06-10 18:24:34 +010088 CHECK(outputShape.GetNumDimensions() == 4);
89 CHECK(outputShape[0] == 2);
90 CHECK(outputShape[1] == 3);
91 CHECK(outputShape[2] == 1);
92 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +010093}
94
Sadik Armagan1625efc2021-06-10 18:24:34 +010095TEST_CASE("ExpandDimsNegativeAxis3Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +010096{
97 armnn::TensorShape inputShape({ 2, 3, 4 });
98
99 // Expand dimension -3
100 armnn::TensorShape outputShape = ExpandDims(inputShape, -3);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100101 CHECK(outputShape.GetNumDimensions() == 4);
102 CHECK(outputShape[0] == 2);
103 CHECK(outputShape[1] == 1);
104 CHECK(outputShape[2] == 3);
105 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100106}
107
Sadik Armagan1625efc2021-06-10 18:24:34 +0100108TEST_CASE("ExpandDimsNegativeAxis4Test")
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100109{
110 armnn::TensorShape inputShape({ 2, 3, 4 });
111
112 // Expand dimension -4
113 armnn::TensorShape outputShape = ExpandDims(inputShape, -4);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100114 CHECK(outputShape.GetNumDimensions() == 4);
115 CHECK(outputShape[0] == 1);
116 CHECK(outputShape[1] == 2);
117 CHECK(outputShape[2] == 3);
118 CHECK(outputShape[3] == 4);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100119}
120
Sadik Armagan1625efc2021-06-10 18:24:34 +0100121TEST_CASE("ExpandDimsInvalidAxisTest")
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100122{
123 armnn::TensorShape inputShape({ 2, 3, 4 });
124
125 // Invalid expand dimension 4
Sadik Armagan1625efc2021-06-10 18:24:34 +0100126 CHECK_THROWS_AS(ExpandDims(inputShape, 4), armnn::InvalidArgumentException);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100127}
128
Sadik Armagan1625efc2021-06-10 18:24:34 +0100129TEST_CASE("ExpandDimsInvalidNegativeAxisTest")
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100130{
131 armnn::TensorShape inputShape({ 2, 3, 4 });
132
133 // Invalid expand dimension -5
Sadik Armagan1625efc2021-06-10 18:24:34 +0100134 CHECK_THROWS_AS(ExpandDims(inputShape, -5), armnn::InvalidArgumentException);
Narumol Prangnawarat02807852019-09-11 16:43:09 +0100135}
136
Sadik Armagan1625efc2021-06-10 18:24:34 +0100137}