blob: 119a406d7ec9bab620b6e253bd59988bce8a2e86 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnOnnxParser/IOnnxParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 ReshapeMainFixture(const std::string& dataType)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 4
32 }
33 }
34 }
35 }
36 }
37 input {
38 name: "Shape"
39 type {
40 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000041 elem_type: 7
telsoa01c577f2c2018-08-31 09:22:23 +010042 shape {
43 dim {
44 dim_value: 2
45 }
46 }
47 }
48 }
49 }
50 node {
51 input: "Input"
52 input: "Shape"
53 output: "Output"
54 name: "reshape"
55 op_type: "Reshape"
56
57 }
58 initializer {
59 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +000060 data_type: 7
telsoa01c577f2c2018-08-31 09:22:23 +010061 int64_data: 2
62 int64_data: 2
63 name: "Shape"
64 }
65 output {
66 name: "Output"
67 type {
68 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000069 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010070 shape {
71 dim {
72 dim_value: 2
73 }
74 dim {
75 dim_value: 2
76 }
77 }
78 }
79 }
80 }
81 }
82 opset_import {
83 version: 7
84 })";
85 }
86};
87
Ryan OSheaed27ee72020-04-22 16:37:29 +010088struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
89{
90 ReshapeRank4Fixture(const std::string& dataType)
91 {
92 m_Prototext = R"(
93 ir_version: 3
94 producer_name: "CNTK"
95 producer_version: "2.5.1"
96 domain: "ai.cntk"
97 model_version: 1
98 graph {
99 name: "CNTKGraph"
100 input {
101 name: "Input"
102 type {
103 tensor_type {
104 elem_type: )" + dataType + R"(
105 shape {
106 dim {
107 dim_value: 2
108 }
109 dim {
110 dim_value: 2
111 }
112 dim {
113 dim_value: 3
114 }
115 dim {
116 dim_value: 3
117 }
118 }
119 }
120 }
121 }
122 input {
123 name: "Shape"
124 type {
125 tensor_type {
126 elem_type: 7
127 shape {
128 dim {
129 dim_value: 2
130 }
131 }
132 }
133 }
134 }
135 node {
136 input: "Input"
137 input: "Shape"
138 output: "Output"
139 name: "reshape"
140 op_type: "Reshape"
141
142 }
143 initializer {
144 dims: 2
145 data_type: 7
146 int64_data: 2
147 int64_data: 2
148 name: "Shape"
149 }
150 output {
151 name: "Output"
152 type {
153 tensor_type {
154 elem_type: 1
155 shape {
156 dim {
157 dim_value: 6
158 }
159 dim {
160 dim_value: 6
161 }
162 }
163 }
164 }
165 }
166 }
167 opset_import {
168 version: 7
169 })";
170 }
171};
172
telsoa01c577f2c2018-08-31 09:22:23 +0100173struct ReshapeValidFixture : ReshapeMainFixture
174{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000175 ReshapeValidFixture() : ReshapeMainFixture("1") {
telsoa01c577f2c2018-08-31 09:22:23 +0100176 Setup();
177 }
178};
179
Ryan OSheaed27ee72020-04-22 16:37:29 +0100180struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
181{
182 ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
183 Setup();
184 }
185};
186
telsoa01c577f2c2018-08-31 09:22:23 +0100187struct ReshapeInvalidFixture : ReshapeMainFixture
188{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000189 ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
telsoa01c577f2c2018-08-31 09:22:23 +0100190};
191
192BOOST_FIXTURE_TEST_CASE(ValidReshapeTest, ReshapeValidFixture)
193{
194 RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
195}
196
Ryan OSheaed27ee72020-04-22 16:37:29 +0100197BOOST_FIXTURE_TEST_CASE(ValidRank4ReshapeTest, ReshapeValidRank4Fixture)
198{
199 RunTest<2>(
200 {{"Input",
201 {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
202 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
203 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
204 {{"Output",
205 {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
206 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
207 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
208}
209
telsoa01c577f2c2018-08-31 09:22:23 +0100210BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeReshape, ReshapeInvalidFixture)
211{
212 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
213}
214
215BOOST_AUTO_TEST_SUITE_END()