blob: 85ebcc307b052b7826dd239fe4f28e68e5519fb3 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8#include "OnnxParserTestUtils.hpp"
9
10TEST_SUITE("OnnxParser_Concat")
11{
12
13struct ConcatFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14{
15 ConcatFixture(const std::string& axis,
16 const std::vector<int>& input0Shape,
17 const std::vector<int>& input1Shape,
18 const std::vector<int>& outputShape)
19 {
20 m_Prototext = R"(
21 ir_version: 8
22 producer_name: "onnx-example"
23 graph {
24 node {
25 input: "Input0"
26 input: "Input1"
27 output: "Output"
28 op_type: "Concat"
29 attribute {
30 name: "axis"
31 i: )" + axis + R"(
32 type: INT
33 }
34 }
35 name: "concat-model"
36 input {
37 name: "Input0"
38 type {
39 tensor_type {
40 elem_type: 1
41 shape {
42 )" + armnnUtils::ConstructTensorShapeString(input0Shape) + R"(
43 }
44 }
45 }
46 }
47 input {
48 name: "Input1"
49 type {
50 tensor_type {
51 elem_type: 1
52 shape {
53 )" + armnnUtils::ConstructTensorShapeString(input1Shape) + R"(
54 }
55 }
56 }
57 }
58 output {
59 name: "Output"
60 type {
61 tensor_type {
62 elem_type: 1
63 shape {
64 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
65 }
66 }
67 }
68 }
69 })";
70 Setup();
71 }
72};
73
74struct ConcatAxis0Fixture : ConcatFixture
75{
76 ConcatAxis0Fixture() : ConcatFixture("0", { 1, 3, 2, 5 }, { 1, 3, 2, 5 }, { 2, 3, 2, 5 }) {}
77};
78
79struct ConcatAxis1Fixture : ConcatFixture
80{
81 ConcatAxis1Fixture() : ConcatFixture("1", { 2, 2, 1, 3 }, { 2, 1, 1, 3 }, { 2, 3, 1, 3 }) {}
82};
83
84struct ConcatAxis2Fixture : ConcatFixture
85{
86 ConcatAxis2Fixture() : ConcatFixture("2", { 2, 3, 1, 1 }, { 2, 3, 2, 1 }, { 2, 3, 3, 1 }) {}
87};
88
89struct ConcatAxis3Fixture : ConcatFixture
90{
91 ConcatAxis3Fixture() : ConcatFixture("3", { 1, 3, 2, 2 }, { 1, 3, 2, 2 }, { 1, 3, 2, 4 }) {}
92};
93
94struct ConcatNegativeAxisFixture : ConcatFixture
95{
96 ConcatNegativeAxisFixture() : ConcatFixture("-1", { 1, 2, 5 }, { 1, 2, 3 }, { 1, 2, 8 }) {}
97};
98
99TEST_CASE_FIXTURE(ConcatAxis0Fixture, "ConcatAxis0Test")
100{
101 RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
102 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
103 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
104 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
105 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
106 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }},
107 {"Input1", { 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
108 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
109 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
110 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
111 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
112 56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}},
113 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
114 6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
115 11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
116 16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
117 21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
118 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
119 31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
120 36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
121 41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
122 46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
123 51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
124 56.0f, 57.0f, 58.0f, 59.0f, 60.0f }}});
125}
126
127TEST_CASE_FIXTURE(ConcatAxis1Fixture, "ConcatAxis1est")
128{
129 RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
130 {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
131 {{"Output", { 1.0f, 2.0f, 3.0f,
132 4.0f, 5.0f, 6.0f,
133 13.0f, 14.0f, 15.0f,
134 7.0f, 8.0f, 9.0f,
135 10.0f, 11.0f, 12.0f,
136 16.0f, 17.0f, 18.0f }}});
137}
138
139TEST_CASE_FIXTURE(ConcatAxis2Fixture, "ConcatAxis2Test")
140{
141 RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }},
142 {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
143 {{"Output", { 1.0f, 7.0f, 8.0f,
144 2.0f, 9.0f, 10.0f,
145 3.0f, 11.0f, 12.0f,
146 4.0f, 13.0f, 14.0f,
147 5.0f, 15.0f, 16.0f,
148 6.0f, 17.0f, 18.0f }}});
149}
150
151TEST_CASE_FIXTURE(ConcatAxis3Fixture, "ConcatAxis3Test")
152{
153 RunTest<4, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
154 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }},
155 {"Input1", { 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
156 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
157 {{"Output", { 1.0f, 2.0f, 13.0f, 14.0f,
158 3.0f, 4.0f, 15.0f, 16.0f,
159 5.0f, 6.0f, 17.0f, 18.0f,
160 7.0f, 8.0f, 19.0f, 20.0f,
161 9.0f, 10.0f, 21.0f, 22.0f,
162 11.0f, 12.0f, 23.0f, 24.0f }}});
163}
164
165TEST_CASE_FIXTURE(ConcatNegativeAxisFixture, "ConcatNegativeAxisTest")
166{
167 RunTest<3, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
168 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }},
169 {"Input1", { 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f }}},
170 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f, 12.0f, 13.0f,
171 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 14.0f, 15.0f, 16.0f }}});
172}
173
174struct ConcatMultipleInputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
175{
176 ConcatMultipleInputsFixture()
177 {
178 m_Prototext = R"(
179 ir_version: 8
180 producer_name: "onnx-example"
181 graph {
182 node {
183 input: "Input0"
184 input: "Input1"
185 input: "Input2"
186 output: "Output"
187 op_type: "Concat"
188 attribute {
189 name: "axis"
190 i: 1
191 type: INT
192 }
193 }
194 name: "concat-model"
195 input {
196 name: "Input0"
197 type {
198 tensor_type {
199 elem_type: 1
200 shape {
201 dim {
202 dim_value: 3
203 }
204 dim {
205 dim_value: 2
206 }
207 }
208 }
209 }
210 }
211 input {
212 name: "Input1"
213 type {
214 tensor_type {
215 elem_type: 1
216 shape {
217 dim {
218 dim_value: 3
219 }
220 dim {
221 dim_value: 3
222 }
223 }
224 }
225 }
226 }
227 input {
228 name: "Input2"
229 type {
230 tensor_type {
231 elem_type: 1
232 shape {
233 dim {
234 dim_value: 3
235 }
236 dim {
237 dim_value: 1
238 }
239 }
240 }
241 }
242 }
243 output {
244 name: "Output"
245 type {
246 tensor_type {
247 elem_type: 1
248 shape {
249 dim {
250 dim_value: 3
251 }
252 dim {
253 dim_value: 6
254 }
255 }
256 }
257 }
258 }
259 })";
260 Setup();
261 }
262};
263
264TEST_CASE_FIXTURE(ConcatMultipleInputsFixture, "ConcatMultipleInputsTest")
265{
266 RunTest<2, float>({{"Input0", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }},
267 {"Input1", { 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f }},
268 {"Input2", { 16.0f, 17.0f, 18.0f }}},
269 {{"Output", { 1.0f, 2.0f, 7.0f, 8.0f, 9.0f, 16.0f,
270 3.0f, 4.0f, 10.0f, 11.0f, 12.0f, 17.0f,
271 5.0f, 6.0f, 13.0f, 14.0f, 15.0f, 18.0f }}});
272}
273
274}