//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnOnnxParser/IOnnxParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 PoolingMainFixture(const std::string& dataType, const std::string& op)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 1
35 }
36 dim {
37 dim_value: 2
38 }
39 dim {
40 dim_value: 2
41 }
42 }
43 }
44 }
45 }
46 node {
47 input: "Input"
48 output: "Output"
49 name: "Pooling"
50 op_type: )" + op + R"(
51 attribute {
52 name: "kernel_shape"
53 ints: 2
54 ints: 2
55 type: INTS
56 }
57 attribute {
58 name: "strides"
59 ints: 1
60 ints: 1
61 type: INTS
62 }
63 attribute {
64 name: "pads"
65 ints: 0
66 ints: 0
67 ints: 0
68 ints: 0
69 type: INTS
70 }
71 }
72 output {
73 name: "Output"
74 type {
75 tensor_type {
76 elem_type: FLOAT
77 shape {
78 dim {
79 dim_value: 1
80 }
81 dim {
82 dim_value: 1
83 }
84 dim {
85 dim_value: 1
86 }
87 dim {
88 dim_value: 1
89 }
90 }
91 }
92 }
93 }
94 }
95 opset_import {
96 version: 7
97 })";
98 }
99};
100
101struct MaxPoolValidFixture : PoolingMainFixture
102{
103 MaxPoolValidFixture() : PoolingMainFixture("FLOAT", "\"MaxPool\"") {
104 Setup();
105 }
106};
107
108struct MaxPoolInvalidFixture : PoolingMainFixture
109{
110 MaxPoolInvalidFixture() : PoolingMainFixture("FLOAT16", "\"MaxPool\"") { }
111};
112
113BOOST_FIXTURE_TEST_CASE(ValidMaxPoolTest, MaxPoolValidFixture)
114{
115 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}});
116}
117
118struct AvgPoolValidFixture : PoolingMainFixture
119{
120 AvgPoolValidFixture() : PoolingMainFixture("FLOAT", "\"AveragePool\"") {
121 Setup();
122 }
123};
124
125struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
126{
127 PoolingWithPadFixture()
128 {
129 m_Prototext = R"(
130 ir_version: 3
131 producer_name: "CNTK"
132 producer_version: "2.5.1"
133 domain: "ai.cntk"
134 model_version: 1
135 graph {
136 name: "CNTKGraph"
137 input {
138 name: "Input"
139 type {
140 tensor_type {
141 elem_type: FLOAT
142 shape {
143 dim {
144 dim_value: 1
145 }
146 dim {
147 dim_value: 1
148 }
149 dim {
150 dim_value: 2
151 }
152 dim {
153 dim_value: 2
154 }
155 }
156 }
157 }
158 }
159 node {
160 input: "Input"
161 output: "Output"
162 name: "Pooling"
163 op_type: "AveragePool"
164 attribute {
165 name: "kernel_shape"
166 ints: 4
167 ints: 4
168 type: INTS
169 }
170 attribute {
171 name: "strides"
172 ints: 1
173 ints: 1
174 type: INTS
175 }
176 attribute {
177 name: "pads"
178 ints: 1
179 ints: 1
180 ints: 1
181 ints: 1
182 type: INTS
183 }
184 attribute {
185 name: "count_include_pad"
186 i: 1
187 type: INT
188 }
189 }
190 output {
191 name: "Output"
192 type {
193 tensor_type {
194 elem_type: FLOAT
195 shape {
196 dim {
197 dim_value: 1
198 }
199 dim {
200 dim_value: 1
201 }
202 dim {
203 dim_value: 1
204 }
205 dim {
206 dim_value: 1
207 }
208 }
209 }
210 }
211 }
212 }
213 opset_import {
214 version: 7
215 })";
216 Setup();
217 }
218};
219
220BOOST_FIXTURE_TEST_CASE(AveragePoolValid, AvgPoolValidFixture)
221{
222 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}});
223}
224
225BOOST_FIXTURE_TEST_CASE(ValidAvgWithPadTest, PoolingWithPadFixture)
226{
227 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}});
228}
229
230struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
231{
232 GlobalAvgFixture()
233 {
234 m_Prototext = R"(
235 ir_version: 3
236 producer_name: "CNTK"
237 producer_version: "2.5.1"
238 domain: "ai.cntk"
239 model_version: 1
240 graph {
241 name: "CNTKGraph"
242 input {
243 name: "Input"
244 type {
245 tensor_type {
246 elem_type: FLOAT
247 shape {
248 dim {
249 dim_value: 1
250 }
251 dim {
252 dim_value: 2
253 }
254 dim {
255 dim_value: 2
256 }
257 dim {
258 dim_value: 2
259 }
260 }
261 }
262 }
263 }
264 node {
265 input: "Input"
266 output: "Output"
267 name: "Pooling"
268 op_type: "GlobalAveragePool"
269 }
270 output {
271 name: "Output"
272 type {
273 tensor_type {
274 elem_type: FLOAT
275 shape {
276 dim {
277 dim_value: 1
278 }
279 dim {
280 dim_value: 2
281 }
282 dim {
283 dim_value: 1
284 }
285 dim {
286 dim_value: 1
287 }
288 }
289 }
290 }
291 }
292 }
293 opset_import {
294 version: 7
295 })";
296 Setup();
297 }
298};
299
300BOOST_FIXTURE_TEST_CASE(GlobalAvgTest, GlobalAvgFixture)
301{
302 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}});
303}
304
305BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeMaxPool, MaxPoolInvalidFixture)
306{
307 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
308}
309
310BOOST_AUTO_TEST_SUITE_END()