blob: 7ba87bc68081553c8203aac6ae74446875c15cbd [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8#include "OnnxParserTestUtils.hpp"
9
10TEST_SUITE("OnnxParser_Unsqueeze")
11{
12
13struct UnsqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14{
15 UnsqueezeFixture(const std::vector<int>& axes,
16 const std::vector<int>& inputShape,
17 const std::vector<int>& outputShape)
18 {
19 m_Prototext = R"(
20 ir_version: 8
21 producer_name: "onnx-example"
22 graph {
23 node {
24 input: "Input"
25 output: "Output"
26 op_type: "Unsqueeze"
27 )" + armnnUtils::ConstructIntsAttribute("axes", axes) + R"(
28 }
29 name: "test-model"
30 input {
31 name: "Input"
32 type {
33 tensor_type {
34 elem_type: 1
35 shape {
36 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
37 }
38 }
39 }
40 }
41 output {
42 name: "Output"
43 type {
44 tensor_type {
45 elem_type: 1
46 shape {
47 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
48 }
49 }
50 }
51 }
52 })";
53 }
54};
55
56struct UnsqueezeSingleAxesFixture : UnsqueezeFixture
57{
58 UnsqueezeSingleAxesFixture() : UnsqueezeFixture({ 0 }, { 2, 3 }, { 1, 2, 3 })
59 {
60 Setup();
61 }
62};
63
64struct UnsqueezeMultiAxesFixture : UnsqueezeFixture
65{
66 UnsqueezeMultiAxesFixture() : UnsqueezeFixture({ 1, 3 }, { 3, 2, 5 }, { 3, 1, 2, 1, 5 })
67 {
68 Setup();
69 }
70};
71
72struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture
73{
74 UnsqueezeUnsortedAxesFixture() : UnsqueezeFixture({ 3, 0, 1 }, { 2, 5 }, { 1, 1, 2, 1, 5 })
75 {
76 Setup();
77 }
78};
79
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010080struct UnsqueezeScalarFixture : UnsqueezeFixture
81{
82 UnsqueezeScalarFixture() : UnsqueezeFixture({ 0 }, { }, { 1 })
83 {
84 Setup();
85 }
86};
87
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +010088TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest")
89{
90 RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
91 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
92}
93
94TEST_CASE_FIXTURE(UnsqueezeMultiAxesFixture, "UnsqueezeMultiAxesTest")
95{
96 RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
97 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
98 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
99 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
100 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
101 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}},
102 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
103 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
104 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
105 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
106 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
107 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}});
108}
109
110TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest")
111{
112 RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
113 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
114 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
115 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
116}
117
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100118TEST_CASE_FIXTURE(UnsqueezeScalarFixture, "UnsqueezeScalarTest")
119{
120 RunTest<1, float>({{"Input", { 1.0f }}},
121 {{"Output", { 1.0f }}});
122}
123
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +0100124struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
125{
126 UnsqueezeInputAxesFixture()
127 {
128 m_Prototext = R"(
129 ir_version: 8
130 producer_name: "onnx-example"
131 graph {
132 node {
133 input: "Input"
134 input: "Axes"
135 output: "Output"
136 op_type: "Unsqueeze"
137 }
138 initializer {
139 dims: 2
140 data_type: 7
141 int64_data: 0
142 int64_data: 3
143 name: "Axes"
144 }
145 name: "test-model"
146 input {
147 name: "Input"
148 type {
149 tensor_type {
150 elem_type: 1
151 shape {
152 dim {
153 dim_value: 3
154 }
155 dim {
156 dim_value: 2
157 }
158 dim {
159 dim_value: 5
160 }
161 }
162 }
163 }
164 }
165 output {
166 name: "Output"
167 type {
168 tensor_type {
169 elem_type: 1
170 shape {
171 dim {
172 dim_value: 1
173 }
174 dim {
175 dim_value: 3
176 }
177 dim {
178 dim_value: 2
179 }
180 dim {
181 dim_value: 1
182 }
183 dim {
184 dim_value: 5
185 }
186 }
187 }
188 }
189 }
190 })";
191 Setup();
192 }
193};
194
195TEST_CASE_FIXTURE(UnsqueezeInputAxesFixture, "UnsqueezeInputAxesTest")
196{
197 RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
198 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
199 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
200 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
201 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
202 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}},
203 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
204 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
205 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
206 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
207 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
208 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}});
209}
210
211}