blob: 1ba509e5e6b436cc18f21f0093b6c1881363df55 [file] [log] [blame]
Ryan OSheaed27ee72020-04-22 16:37:29 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnOnnxParser/IOnnxParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12struct FlattenMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 FlattenMainFixture(const std::string& dataType)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 2
32 }
33 dim {
34 dim_value: 2
35 }
36 dim {
37 dim_value: 3
38 }
39 dim {
40 dim_value: 3
41 }
42 }
43 }
44 }
45 }
46 node {
47 input: "Input"
48 output: "Output"
49 name: "flatten"
50 op_type: "Flatten"
51 attribute {
52 name: "axis"
53 i: 2
54 type: INT
55 }
56 }
57 output {
58 name: "Output"
59 type {
60 tensor_type {
61 elem_type: 1
62 shape {
63 dim {
64 dim_value: 4
65 }
66 dim {
67 dim_value: 9
68 }
69 }
70 }
71 }
72 }
73 }
74 opset_import {
75 version: 7
76 })";
77 }
78};
79
80struct FlattenDefaultAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
81{
82 FlattenDefaultAxisFixture(const std::string& dataType)
83 {
84 m_Prototext = R"(
85 ir_version: 3
86 producer_name: "CNTK"
87 producer_version: "2.5.1"
88 domain: "ai.cntk"
89 model_version: 1
90 graph {
91 name: "CNTKGraph"
92 input {
93 name: "Input"
94 type {
95 tensor_type {
96 elem_type: )" + dataType + R"(
97 shape {
98 dim {
99 dim_value: 2
100 }
101 dim {
102 dim_value: 2
103 }
104 dim {
105 dim_value: 3
106 }
107 dim {
108 dim_value: 3
109 }
110 }
111 }
112 }
113 }
114 node {
115 input: "Input"
116 output: "Output"
117 name: "flatten"
118 op_type: "Flatten"
119 }
120 output {
121 name: "Output"
122 type {
123 tensor_type {
124 elem_type: 1
125 shape {
126 dim {
127 dim_value: 2
128 }
129 dim {
130 dim_value: 18
131 }
132 }
133 }
134 }
135 }
136 }
137 opset_import {
138 version: 7
139 })";
140 }
141};
142
143struct FlattenAxisZeroFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
144{
145 FlattenAxisZeroFixture(const std::string& dataType)
146 {
147 m_Prototext = R"(
148 ir_version: 3
149 producer_name: "CNTK"
150 producer_version: "2.5.1"
151 domain: "ai.cntk"
152 model_version: 1
153 graph {
154 name: "CNTKGraph"
155 input {
156 name: "Input"
157 type {
158 tensor_type {
159 elem_type: )" + dataType + R"(
160 shape {
161 dim {
162 dim_value: 2
163 }
164 dim {
165 dim_value: 2
166 }
167 dim {
168 dim_value: 3
169 }
170 dim {
171 dim_value: 3
172 }
173 }
174 }
175 }
176 }
177 node {
178 input: "Input"
179 output: "Output"
180 name: "flatten"
181 op_type: "Flatten"
182 attribute {
183 name: "axis"
184 i: 0
185 type: INT
186 }
187 }
188 output {
189 name: "Output"
190 type {
191 tensor_type {
192 elem_type: 1
193 shape {
194 dim {
195 dim_value: 1
196 }
197 dim {
198 dim_value: 36
199 }
200 }
201 }
202 }
203 }
204 }
205 opset_import {
206 version: 7
207 })";
208 }
209};
210
211struct FlattenNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
212{
213 FlattenNegativeAxisFixture(const std::string& dataType)
214 {
215 m_Prototext = R"(
216 ir_version: 3
217 producer_name: "CNTK"
218 producer_version: "2.5.1"
219 domain: "ai.cntk"
220 model_version: 1
221 graph {
222 name: "CNTKGraph"
223 input {
224 name: "Input"
225 type {
226 tensor_type {
227 elem_type: )" + dataType + R"(
228 shape {
229 dim {
230 dim_value: 2
231 }
232 dim {
233 dim_value: 2
234 }
235 dim {
236 dim_value: 3
237 }
238 dim {
239 dim_value: 3
240 }
241 }
242 }
243 }
244 }
245 node {
246 input: "Input"
247 output: "Output"
248 name: "flatten"
249 op_type: "Flatten"
250 attribute {
251 name: "axis"
252 i: -1
253 type: INT
254 }
255 }
256 output {
257 name: "Output"
258 type {
259 tensor_type {
260 elem_type: 1
261 shape {
262 dim {
263 dim_value: 12
264 }
265 dim {
266 dim_value: 3
267 }
268 }
269 }
270 }
271 }
272 }
273 opset_import {
274 version: 7
275 })";
276 }
277};
278
279struct FlattenInvalidNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
280{
281 FlattenInvalidNegativeAxisFixture(const std::string& dataType)
282 {
283 m_Prototext = R"(
284 ir_version: 3
285 producer_name: "CNTK"
286 producer_version: "2.5.1"
287 domain: "ai.cntk"
288 model_version: 1
289 graph {
290 name: "CNTKGraph"
291 input {
292 name: "Input"
293 type {
294 tensor_type {
295 elem_type: )" + dataType + R"(
296 shape {
297 dim {
298 dim_value: 2
299 }
300 dim {
301 dim_value: 2
302 }
303 dim {
304 dim_value: 3
305 }
306 dim {
307 dim_value: 3
308 }
309 }
310 }
311 }
312 }
313 node {
314 input: "Input"
315 output: "Output"
316 name: "flatten"
317 op_type: "Flatten"
318 attribute {
319 name: "axis"
320 i: -5
321 type: INT
322 }
323 }
324 output {
325 name: "Output"
326 type {
327 tensor_type {
328 elem_type: 1
329 shape {
330 dim {
331 dim_value: 12
332 }
333 dim {
334 dim_value: 3
335 }
336 }
337 }
338 }
339 }
340 }
341 opset_import {
342 version: 7
343 })";
344 }
345};
346
347struct FlattenValidFixture : FlattenMainFixture
348{
349 FlattenValidFixture() : FlattenMainFixture("1") {
350 Setup();
351 }
352};
353
354struct FlattenDefaultValidFixture : FlattenDefaultAxisFixture
355{
356 FlattenDefaultValidFixture() : FlattenDefaultAxisFixture("1") {
357 Setup();
358 }
359};
360
361struct FlattenAxisZeroValidFixture : FlattenAxisZeroFixture
362{
363 FlattenAxisZeroValidFixture() : FlattenAxisZeroFixture("1") {
364 Setup();
365 }
366};
367
368struct FlattenNegativeAxisValidFixture : FlattenNegativeAxisFixture
369{
370 FlattenNegativeAxisValidFixture() : FlattenNegativeAxisFixture("1") {
371 Setup();
372 }
373};
374
375struct FlattenInvalidFixture : FlattenMainFixture
376{
377 FlattenInvalidFixture() : FlattenMainFixture("10") { }
378};
379
380struct FlattenInvalidAxisFixture : FlattenInvalidNegativeAxisFixture
381{
382 FlattenInvalidAxisFixture() : FlattenInvalidNegativeAxisFixture("1") { }
383};
384
385BOOST_FIXTURE_TEST_CASE(ValidFlattenTest, FlattenValidFixture)
386{
387 RunTest<2>({{"Input",
388 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
389 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
390 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
391 {{"Output",
392 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
393 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
394 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
395}
396
397BOOST_FIXTURE_TEST_CASE(ValidFlattenDefaultTest, FlattenDefaultValidFixture)
398{
399 RunTest<2>({{"Input",
400 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
401 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
402 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
403 {{"Output",
404 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
405 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
406 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
407}
408
409BOOST_FIXTURE_TEST_CASE(ValidFlattenAxisZeroTest, FlattenAxisZeroValidFixture)
410{
411 RunTest<2>({{"Input",
412 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
413 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
414 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
415 {{"Output",
416 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
417 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
418 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
419}
420
421BOOST_FIXTURE_TEST_CASE(ValidFlattenNegativeAxisTest, FlattenNegativeAxisValidFixture)
422{
423 RunTest<2>({{"Input",
424 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
425 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
426 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
427 {{"Output",
428 { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
429 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
430 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
431}
432
433BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeFlatten, FlattenInvalidFixture)
434{
435 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
436}
437
438BOOST_FIXTURE_TEST_CASE(IncorrectAxisFlatten, FlattenInvalidAxisFixture)
439{
440 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
441}
442
443BOOST_AUTO_TEST_SUITE_END()