blob: 11a5d1eb87e4c8de40478e66d614194ee5195f59 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnOnnxParser/IOnnxParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12struct SimpleConv2DFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 SimpleConv2DFixture()
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
28 elem_type: FLOAT
29 shape {
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 1
35 }
36 dim {
37 dim_value: 3
38 }
39 dim {
40 dim_value: 3
41 }
42 }
43 }
44 }
45 }
46 input {
47 name: "Weight"
48 type {
49 tensor_type {
50 elem_type: FLOAT
51 shape {
52 dim {
53 dim_value: 1
54 }
55 dim {
56 dim_value: 1
57 }
58 dim {
59 dim_value: 3
60 }
61 dim {
62 dim_value: 3
63 }
64 }
65 }
66 }
67 }
68 initializer {
69 dims: 1
70 dims: 1
71 dims: 3
72 dims: 3
73 data_type: FLOAT
74 float_data: 2
75 float_data: 1
76 float_data: 0
77 float_data: 6
78 float_data: 2
79 float_data: 1
80 float_data: 4
81 float_data: 1
82 float_data: 2
83 name: "Weight"
84 }
85 node {
86 input: "Input"
87 input: "Weight"
88 output: "Output"
89 name: "Convolution"
90 op_type: "Conv"
91 attribute {
92 name: "kernel_shape"
93 ints: 3
94 ints: 3
95 type: INTS
96 }
97 attribute {
98 name: "strides"
99 ints: 1
100 ints: 1
101 type: INTS
102 }
103 attribute {
104 name: "auto_pad"
105 s: "VALID"
106 type: STRING
107 }
108 attribute {
109 name: "group"
110 i: 1
111 type: INT
112 }
113 attribute {
114 name: "dilations"
115 ints: 1
116 ints: 1
117 type: INTS
118 }
119 doc_string: ""
120 domain: ""
121 }
122 output {
123 name: "Output"
124 type {
125 tensor_type {
126 elem_type: FLOAT
127 shape {
128 dim {
129 dim_value: 1
130 }
131 dim {
132 dim_value: 1
133 }
134 dim {
135 dim_value: 1
136 }
137 dim {
138 dim_value: 1
139 }
140 }
141 }
142 }
143 }
144 }
145 opset_import {
146 version: 7
147 })";
148 Setup();
149 }
150};
151
152struct Conv2DWithBiasesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
153{
154 Conv2DWithBiasesFixture() {
155 m_Prototext = R"(
156 ir_version: 3
157 producer_name: "CNTK"
158 producer_version: "2.5.1"
159 domain: "ai.cntk"
160 model_version: 1
161 graph {
162 name: "CNTKGraph"
163 input {
164 name: "Input"
165 type {
166 tensor_type {
167 elem_type: FLOAT
168 shape {
169 dim {
170 dim_value: 1
171 }
172 dim {
173 dim_value: 1
174 }
175 dim {
176 dim_value: 2
177 }
178 dim {
179 dim_value: 2
180 }
181 }
182 }
183 }
184 }
185 input {
186 name: "Weight"
187 type {
188 tensor_type {
189 elem_type: FLOAT
190 shape {
191 dim {
192 dim_value: 1
193 }
194 dim {
195 dim_value: 1
196 }
197 dim {
198 dim_value: 2
199 }
200 dim {
201 dim_value: 2
202 }
203 }
204 }
205 }
206 }
207 initializer {
208 dims: 1
209 dims: 1
210 dims: 2
211 dims: 2
212 data_type: FLOAT
213 float_data: 2
214 float_data: 1
215 float_data: 0
216 float_data: 6
217 name: "Weight"
218 }
219 input {
220 name: "Bias"
221 type {
222 tensor_type {
223 elem_type: FLOAT
224 shape {
225 dim {
226 dim_value: 4
227 }
228 }
229 }
230 }
231 }
232 initializer {
233 dims: 4
234 data_type: FLOAT
235 float_data: 10
236 float_data: 0
237 float_data: 0
238 float_data: 0
239 name: "Bias"
240 }
241 node {
242 input: "Input"
243 input: "Weight"
244 input: "Bias"
245 output: "Output"
246 name: "Convolution"
247 op_type: "Conv"
248 attribute {
249 name: "kernel_shape"
250 ints: 2
251 ints: 2
252 type: INTS
253 }
254 attribute {
255 name: "strides"
256 ints: 1
257 ints: 1
258 type: INTS
259 }
260 attribute {
261 name: "auto_pad"
262 s: "SAME_UPPER"
263 type: STRING
264 }
265 attribute {
266 name: "group"
267 i: 1
268 type: INT
269 }
270 attribute {
271 name: "dilations"
272 ints: 1
273 ints: 1
274 type: INTS
275 }
276 doc_string: ""
277 domain: ""
278 }
279 output {
280 name: "Output"
281 type {
282 tensor_type {
283 elem_type: FLOAT
284 shape {
285 dim {
286 dim_value: 1
287 }
288 dim {
289 dim_value: 1
290 }
291 dim {
292 dim_value: 2
293 }
294 dim {
295 dim_value: 2
296 }
297 }
298 }
299 }
300 }
301 }
302 opset_import {
303 version: 7
304 })";
305 Setup();
306 }
307};
308
309
310struct Conv2DDimReducingFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
311{
312 Conv2DDimReducingFixture() {
313 m_Prototext = R"(
314 ir_version: 3
315 producer_name: "CNTK"
316 producer_version: "2.5.1"
317 domain: "ai.cntk"
318 model_version: 1
319 graph {
320 name: "CNTKGraph"
321 input {
322 name: "Input"
323 type {
324 tensor_type {
325 elem_type: FLOAT
326 shape {
327 dim {
328 dim_value: 1
329 }
330 dim {
331 dim_value: 3
332 }
333 dim {
334 dim_value: 2
335 }
336 dim {
337 dim_value: 2
338 }
339 }
340 }
341 }
342 }
343 input {
344 name: "Weight"
345 type {
346 tensor_type {
347 elem_type: FLOAT
348 shape {
349 dim {
350 dim_value: 2
351 }
352 dim {
353 dim_value: 3
354 }
355 dim {
356 dim_value: 1
357 }
358 dim {
359 dim_value: 1
360 }
361 }
362 }
363 }
364 }
365 initializer {
366 dims: 2
367 dims: 3
368 dims: 1
369 dims: 1
370 data_type: FLOAT
371 float_data: -1
372 float_data: 2
373 float_data: 0
374 float_data: 1
375 float_data: 0
376 float_data: 0
377 name: "Weight"
378 }
379 node {
380 input: "Input"
381 input: "Weight"
382 output: "Output"
383 name: "Convolution"
384 op_type: "Conv"
385 attribute {
386 name: "kernel_shape"
387 ints: 1
388 ints: 1
389 type: INTS
390 }
391 attribute {
392 name: "strides"
393 ints: 1
394 ints: 1
395 type: INTS
396 }
397 attribute {
398 name: "group"
399 i: 1
400 type: INT
401 }
402 attribute {
403 name: "dilations"
404 ints: 1
405 ints: 1
406 type: INTS
407 }
408 doc_string: ""
409 domain: ""
410 }
411 output {
412 name: "Output"
413 type {
414 tensor_type {
415 elem_type: FLOAT
416 shape {
417 dim {
418 dim_value: 1
419 }
420 dim {
421 dim_value: 2
422 }
423 dim {
424 dim_value: 2
425 }
426 dim {
427 dim_value: 2
428 }
429 }
430 }
431 }
432 }
433 }
434 opset_import {
435 version: 7
436 })";
437 Setup();
438 }
439};
440
441BOOST_FIXTURE_TEST_CASE(ValidConvTest, SimpleConv2DFixture)
442{
443 RunTest<4>({{"Input", {1.0, 2.0, 3.0,
444 4.0, 5.0, 6.0,
445 7.0, 8.0, 9.0}}},
446 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
447 4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
448 7.0 * 4 + 8.0 * 1 + 9.0 * 2}}});
449}
450
451BOOST_FIXTURE_TEST_CASE(ValidConvWithBiasTest, Conv2DWithBiasesFixture)
452{
453 RunTest<4>({{"Input", {1.0, 2.0,
454 3.0, 4.0}}},
455 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 4 * 6 + 10,
456 2.0 * 2 + 0 * 1 + 4.0 * 0 + 0 * 6 + 10,
457 3.0 * 2 + 4.0 * 1 + 0 * 0 + 0 * 6 + 10,
458 4.0 * 2 + 0 * 1 + 0 * 0 + 0 * 6 + 10}}});
459}
460
461BOOST_FIXTURE_TEST_CASE(ValidConvDimReducTest, Conv2DDimReducingFixture)
462{
463 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, -1, -2, 3, 4, 1 , 1, 1, 1 }}},
464 {{"Output", {-1 * 1 + 2 * -1, -1 * 2 + 2 * -2,
465 -1 * 3 + 2 * 3, -1 * 4 + 2 * 4,
466 1, 2, 3, 4}}});
467}
468
469BOOST_AUTO_TEST_SUITE_END()