blob: 46ac0dfeee8482135f9b714711df86e3ffeba9ef [file] [log] [blame]
Ryan OSheaed27ee72020-04-22 16:37:29 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <string>
#include <vector>
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("OnnxParser_Flatter")
10{
Ryan OSheaed27ee72020-04-22 16:37:29 +010011struct FlattenMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12{
13 FlattenMainFixture(const std::string& dataType)
14 {
15 m_Prototext = R"(
16 ir_version: 3
17 producer_name: "CNTK"
18 producer_version: "2.5.1"
19 domain: "ai.cntk"
20 model_version: 1
21 graph {
22 name: "CNTKGraph"
23 input {
24 name: "Input"
25 type {
26 tensor_type {
27 elem_type: )" + dataType + R"(
28 shape {
29 dim {
30 dim_value: 2
31 }
32 dim {
33 dim_value: 2
34 }
35 dim {
36 dim_value: 3
37 }
38 dim {
39 dim_value: 3
40 }
41 }
42 }
43 }
44 }
45 node {
46 input: "Input"
47 output: "Output"
48 name: "flatten"
49 op_type: "Flatten"
50 attribute {
51 name: "axis"
52 i: 2
53 type: INT
54 }
55 }
56 output {
57 name: "Output"
58 type {
59 tensor_type {
60 elem_type: 1
61 shape {
62 dim {
63 dim_value: 4
64 }
65 dim {
66 dim_value: 9
67 }
68 }
69 }
70 }
71 }
72 }
73 opset_import {
74 version: 7
75 })";
76 }
77};
78
79struct FlattenDefaultAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
80{
81 FlattenDefaultAxisFixture(const std::string& dataType)
82 {
83 m_Prototext = R"(
84 ir_version: 3
85 producer_name: "CNTK"
86 producer_version: "2.5.1"
87 domain: "ai.cntk"
88 model_version: 1
89 graph {
90 name: "CNTKGraph"
91 input {
92 name: "Input"
93 type {
94 tensor_type {
95 elem_type: )" + dataType + R"(
96 shape {
97 dim {
98 dim_value: 2
99 }
100 dim {
101 dim_value: 2
102 }
103 dim {
104 dim_value: 3
105 }
106 dim {
107 dim_value: 3
108 }
109 }
110 }
111 }
112 }
113 node {
114 input: "Input"
115 output: "Output"
116 name: "flatten"
117 op_type: "Flatten"
118 }
119 output {
120 name: "Output"
121 type {
122 tensor_type {
123 elem_type: 1
124 shape {
125 dim {
126 dim_value: 2
127 }
128 dim {
129 dim_value: 18
130 }
131 }
132 }
133 }
134 }
135 }
136 opset_import {
137 version: 7
138 })";
139 }
140};
141
142struct FlattenAxisZeroFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
143{
144 FlattenAxisZeroFixture(const std::string& dataType)
145 {
146 m_Prototext = R"(
147 ir_version: 3
148 producer_name: "CNTK"
149 producer_version: "2.5.1"
150 domain: "ai.cntk"
151 model_version: 1
152 graph {
153 name: "CNTKGraph"
154 input {
155 name: "Input"
156 type {
157 tensor_type {
158 elem_type: )" + dataType + R"(
159 shape {
160 dim {
161 dim_value: 2
162 }
163 dim {
164 dim_value: 2
165 }
166 dim {
167 dim_value: 3
168 }
169 dim {
170 dim_value: 3
171 }
172 }
173 }
174 }
175 }
176 node {
177 input: "Input"
178 output: "Output"
179 name: "flatten"
180 op_type: "Flatten"
181 attribute {
182 name: "axis"
183 i: 0
184 type: INT
185 }
186 }
187 output {
188 name: "Output"
189 type {
190 tensor_type {
191 elem_type: 1
192 shape {
193 dim {
194 dim_value: 1
195 }
196 dim {
197 dim_value: 36
198 }
199 }
200 }
201 }
202 }
203 }
204 opset_import {
205 version: 7
206 })";
207 }
208};
209
210struct FlattenNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
211{
212 FlattenNegativeAxisFixture(const std::string& dataType)
213 {
214 m_Prototext = R"(
215 ir_version: 3
216 producer_name: "CNTK"
217 producer_version: "2.5.1"
218 domain: "ai.cntk"
219 model_version: 1
220 graph {
221 name: "CNTKGraph"
222 input {
223 name: "Input"
224 type {
225 tensor_type {
226 elem_type: )" + dataType + R"(
227 shape {
228 dim {
229 dim_value: 2
230 }
231 dim {
232 dim_value: 2
233 }
234 dim {
235 dim_value: 3
236 }
237 dim {
238 dim_value: 3
239 }
240 }
241 }
242 }
243 }
244 node {
245 input: "Input"
246 output: "Output"
247 name: "flatten"
248 op_type: "Flatten"
249 attribute {
250 name: "axis"
251 i: -1
252 type: INT
253 }
254 }
255 output {
256 name: "Output"
257 type {
258 tensor_type {
259 elem_type: 1
260 shape {
261 dim {
262 dim_value: 12
263 }
264 dim {
265 dim_value: 3
266 }
267 }
268 }
269 }
270 }
271 }
272 opset_import {
273 version: 7
274 })";
275 }
276};
277
278struct FlattenInvalidNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
279{
280 FlattenInvalidNegativeAxisFixture(const std::string& dataType)
281 {
282 m_Prototext = R"(
283 ir_version: 3
284 producer_name: "CNTK"
285 producer_version: "2.5.1"
286 domain: "ai.cntk"
287 model_version: 1
288 graph {
289 name: "CNTKGraph"
290 input {
291 name: "Input"
292 type {
293 tensor_type {
294 elem_type: )" + dataType + R"(
295 shape {
296 dim {
297 dim_value: 2
298 }
299 dim {
300 dim_value: 2
301 }
302 dim {
303 dim_value: 3
304 }
305 dim {
306 dim_value: 3
307 }
308 }
309 }
310 }
311 }
312 node {
313 input: "Input"
314 output: "Output"
315 name: "flatten"
316 op_type: "Flatten"
317 attribute {
318 name: "axis"
319 i: -5
320 type: INT
321 }
322 }
323 output {
324 name: "Output"
325 type {
326 tensor_type {
327 elem_type: 1
328 shape {
329 dim {
330 dim_value: 12
331 }
332 dim {
333 dim_value: 3
334 }
335 }
336 }
337 }
338 }
339 }
340 opset_import {
341 version: 7
342 })";
343 }
344};
345
346struct FlattenValidFixture : FlattenMainFixture
347{
348 FlattenValidFixture() : FlattenMainFixture("1") {
349 Setup();
350 }
351};
352
353struct FlattenDefaultValidFixture : FlattenDefaultAxisFixture
354{
355 FlattenDefaultValidFixture() : FlattenDefaultAxisFixture("1") {
356 Setup();
357 }
358};
359
360struct FlattenAxisZeroValidFixture : FlattenAxisZeroFixture
361{
362 FlattenAxisZeroValidFixture() : FlattenAxisZeroFixture("1") {
363 Setup();
364 }
365};
366
367struct FlattenNegativeAxisValidFixture : FlattenNegativeAxisFixture
368{
369 FlattenNegativeAxisValidFixture() : FlattenNegativeAxisFixture("1") {
370 Setup();
371 }
372};
373
374struct FlattenInvalidFixture : FlattenMainFixture
375{
376 FlattenInvalidFixture() : FlattenMainFixture("10") { }
377};
378
379struct FlattenInvalidAxisFixture : FlattenInvalidNegativeAxisFixture
380{
381 FlattenInvalidAxisFixture() : FlattenInvalidNegativeAxisFixture("1") { }
382};
383
Sadik Armagan1625efc2021-06-10 18:24:34 +0100384TEST_CASE_FIXTURE(FlattenValidFixture, "ValidFlattenTest")
Ryan OSheaed27ee72020-04-22 16:37:29 +0100385{
386 RunTest<2>({{"Input",
387 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
388 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
389 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
390 {{"Output",
391 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
392 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
393 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
394}
395
Sadik Armagan1625efc2021-06-10 18:24:34 +0100396TEST_CASE_FIXTURE(FlattenDefaultValidFixture, "ValidFlattenDefaultTest")
Ryan OSheaed27ee72020-04-22 16:37:29 +0100397{
398 RunTest<2>({{"Input",
399 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
400 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
401 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
402 {{"Output",
403 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
404 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
405 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
406}
407
Sadik Armagan1625efc2021-06-10 18:24:34 +0100408TEST_CASE_FIXTURE(FlattenAxisZeroValidFixture, "ValidFlattenAxisZeroTest")
Ryan OSheaed27ee72020-04-22 16:37:29 +0100409{
410 RunTest<2>({{"Input",
411 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
412 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
413 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
414 {{"Output",
415 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
416 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
417 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
418}
419
Sadik Armagan1625efc2021-06-10 18:24:34 +0100420TEST_CASE_FIXTURE(FlattenNegativeAxisValidFixture, "ValidFlattenNegativeAxisTest")
Ryan OSheaed27ee72020-04-22 16:37:29 +0100421{
422 RunTest<2>({{"Input",
423 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
424 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
425 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
426 {{"Output",
427 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
428 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
429 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
430}
431
Sadik Armagan1625efc2021-06-10 18:24:34 +0100432TEST_CASE_FIXTURE(FlattenInvalidFixture, "IncorrectDataTypeFlatten")
Ryan OSheaed27ee72020-04-22 16:37:29 +0100433{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100434 CHECK_THROWS_AS(Setup(), armnn::ParseException);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100435}
436
Sadik Armagan1625efc2021-06-10 18:24:34 +0100437TEST_CASE_FIXTURE(FlattenInvalidAxisFixture, "IncorrectAxisFlatten")
Ryan OSheaed27ee72020-04-22 16:37:29 +0100438{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100439 CHECK_THROWS_AS(Setup(), armnn::ParseException);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100440}
441
Sadik Armagan1625efc2021-06-10 18:24:34 +0100442}