blob: 97198761e5b9f950f104d8f1e473a3b8fa1c5cdc [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
Narumol Prangnawarat4b536e32021-10-18 12:35:19 +01008#include "OnnxParserTestUtils.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01009
Sadik Armagan1625efc2021-06-10 18:24:34 +010010TEST_SUITE("OnnxParser_Reshape")
11{
telsoa01c577f2c2018-08-31 09:22:23 +010012struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 ReshapeMainFixture(const std::string& dataType)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 4
32 }
33 }
34 }
35 }
36 }
37 input {
38 name: "Shape"
39 type {
40 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000041 elem_type: 7
telsoa01c577f2c2018-08-31 09:22:23 +010042 shape {
43 dim {
44 dim_value: 2
45 }
46 }
47 }
48 }
49 }
50 node {
51 input: "Input"
52 input: "Shape"
53 output: "Output"
54 name: "reshape"
55 op_type: "Reshape"
56
57 }
58 initializer {
59 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +000060 data_type: 7
telsoa01c577f2c2018-08-31 09:22:23 +010061 int64_data: 2
62 int64_data: 2
63 name: "Shape"
64 }
65 output {
66 name: "Output"
67 type {
68 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000069 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010070 shape {
71 dim {
72 dim_value: 2
73 }
74 dim {
75 dim_value: 2
76 }
77 }
78 }
79 }
80 }
81 }
82 opset_import {
83 version: 7
84 })";
85 }
86};
87
Ryan OSheaed27ee72020-04-22 16:37:29 +010088struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
89{
90 ReshapeRank4Fixture(const std::string& dataType)
91 {
92 m_Prototext = R"(
93 ir_version: 3
94 producer_name: "CNTK"
95 producer_version: "2.5.1"
96 domain: "ai.cntk"
97 model_version: 1
98 graph {
99 name: "CNTKGraph"
100 input {
101 name: "Input"
102 type {
103 tensor_type {
104 elem_type: )" + dataType + R"(
105 shape {
106 dim {
107 dim_value: 2
108 }
109 dim {
110 dim_value: 2
111 }
112 dim {
113 dim_value: 3
114 }
115 dim {
116 dim_value: 3
117 }
118 }
119 }
120 }
121 }
122 input {
123 name: "Shape"
124 type {
125 tensor_type {
126 elem_type: 7
127 shape {
128 dim {
129 dim_value: 2
130 }
131 }
132 }
133 }
134 }
135 node {
136 input: "Input"
137 input: "Shape"
138 output: "Output"
139 name: "reshape"
140 op_type: "Reshape"
141
142 }
143 initializer {
144 dims: 2
145 data_type: 7
146 int64_data: 2
147 int64_data: 2
148 name: "Shape"
149 }
150 output {
151 name: "Output"
152 type {
153 tensor_type {
154 elem_type: 1
155 shape {
156 dim {
157 dim_value: 6
158 }
159 dim {
160 dim_value: 6
161 }
162 }
163 }
164 }
165 }
166 }
167 opset_import {
168 version: 7
169 })";
170 }
171};
172
telsoa01c577f2c2018-08-31 09:22:23 +0100173struct ReshapeValidFixture : ReshapeMainFixture
174{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000175 ReshapeValidFixture() : ReshapeMainFixture("1") {
telsoa01c577f2c2018-08-31 09:22:23 +0100176 Setup();
177 }
178};
179
Ryan OSheaed27ee72020-04-22 16:37:29 +0100180struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
181{
182 ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
183 Setup();
184 }
185};
186
telsoa01c577f2c2018-08-31 09:22:23 +0100187struct ReshapeInvalidFixture : ReshapeMainFixture
188{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000189 ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
telsoa01c577f2c2018-08-31 09:22:23 +0100190};
191
Sadik Armagan1625efc2021-06-10 18:24:34 +0100192TEST_CASE_FIXTURE(ReshapeValidFixture, "ValidReshapeTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100193{
194 RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
195}
196
Sadik Armagan1625efc2021-06-10 18:24:34 +0100197TEST_CASE_FIXTURE(ReshapeValidRank4Fixture, "ValidRank4ReshapeTest")
Ryan OSheaed27ee72020-04-22 16:37:29 +0100198{
199 RunTest<2>(
200 {{"Input",
201 {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
202 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
203 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
204 {{"Output",
205 {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
206 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
207 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
208}
209
Sadik Armagan1625efc2021-06-10 18:24:34 +0100210TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape")
telsoa01c577f2c2018-08-31 09:22:23 +0100211{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100212 CHECK_THROWS_AS(Setup(), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100213}
214
Narumol Prangnawarat4b536e32021-10-18 12:35:19 +0100215struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
216{
217 ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape,
218 const std::vector<int>& shapeInputShape,
219 const std::vector<int>& outputShape,
220 const std::string& shape)
221 {
222 m_Prototext = R"(
223 ir_version: 3
224 producer_name: "onnx-example"
225 graph {
226 name: "ReshapeGrapn"
227 input {
228 name: "Input"
229 type {
230 tensor_type {
231 elem_type: 1
232 shape {
233 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
234 }
235 }
236 }
237 }
238 input {
239 name: "Shape"
240 type {
241 tensor_type {
242 elem_type: 7
243 shape {
244 )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
245 }
246 }
247 }
248 }
249 node {
250 input: "Input"
251 input: "Shape"
252 output: "Output"
253 name: "reshape"
254 op_type: "Reshape"
255 }
256 initializer {
257 dims: 2
258 data_type: 7
259 )" + shape + R"(
260 name: "Shape"
261 }
262 output {
263 name: "Output"
264 type {
265 tensor_type {
266 elem_type: 1
267 shape {
268 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
269 }
270 }
271 }
272 }
273 }
274 opset_import {
275 version: 7
276 })";
277 }
278};
279
280struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture
281{
282 ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1")
283 {
284 Setup();
285 }
286};
287
288struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture
289{
290 ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
291 { 2 },
292 { 2, 6 },
293 "int64_data: -1 int64_data: 6")
294 {
295 Setup();
296 }
297};
298
299struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture
300{
301 ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
302 { 3 },
303 { 3, 1, 4 },
304 "int64_data: 3 int64_data: -1 int64_data: 4")
305 {
306 Setup();
307 }
308};
309
310struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture
311{
312 ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture(
313 { 2, 3, 1, 2 },
314 { 4 },
315 { 3, 1, 2, 2 },
316 "int64_data: 3 int64_data: 1 int64_data: 2 int64_data: -1")
317 {
318 Setup();
319 }
320};
321
322TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest")
323{
324 RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
325 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
326}
327
328TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest")
329{
330 RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
331 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
332 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
333 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
334}
335
336TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest")
337{
338 RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
339 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
340 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
341 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
342}
343
344TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest")
345{
346 RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
347 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
348 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
349 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
350}
351
352struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
353{
354 ReshapeNonConstShapeFixture(const std::vector<int>& inputShape,
355 const std::vector<int>& shapeInputShape,
356 const std::vector<int>& outputShape)
357 {
358 m_Prototext = R"(
359 ir_version: 3
360 producer_name: "onnx-example"
361 graph {
362 name: "ReshapeGrapn"
363 input {
364 name: "Input"
365 type {
366 tensor_type {
367 elem_type: 1
368 shape {
369 )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
370 }
371 }
372 }
373 }
374 input {
375 name: "Shape"
376 type {
377 tensor_type {
378 elem_type: 7
379 shape {
380 )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
381 }
382 }
383 }
384 }
385 node {
386 input: "Input"
387 input: "Shape"
388 output: "Output"
389 name: "reshape"
390 op_type: "Reshape"
391 }
392 output {
393 name: "Output"
394 type {
395 tensor_type {
396 elem_type: 1
397 shape {
398 )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
399 }
400 }
401 }
402 }
403 }
404 opset_import {
405 version: 7
406 })";
407 }
408};
409
410struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture
411{
412 ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 })
413 {
414 Setup();
415 }
416};
417
418struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture
419{
420 ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 })
421 {
422 Setup();
423 }
424};
425
426struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture
427{
428 ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 })
429 {
430 }
431};
432
433struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture
434{
435 ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 })
436 {
437 }
438};
439
440TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest")
441{
442 RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
443 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
444}
445
446TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest")
447{
448 RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
449 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
450 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
451 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
452 {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
453 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
454 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
455 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}});
456}
457
458TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest")
459{
460 CHECK_THROWS_AS(Setup(), armnn::ParseException);
461}
462
463TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest")
464{
465 CHECK_THROWS_AS(Setup(), armnn::ParseException);
466}
467
Sadik Armagan1625efc2021-06-10 18:24:34 +0100468}