blob: 6fc8eb115107c45e24f6d4a9ed7030936b44276a [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnOnnxParser/IOnnxParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12struct AddMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 AddMainFixture(const std::string& dataType)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input0"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 1
35 }
36 dim {
37 dim_value: 2
38 }
39 dim {
40 dim_value: 2
41 }
42 }
43 }
44 }
45 }
46 input {
47 name: "Input1"
48 type {
49 tensor_type {
50 elem_type: )" + dataType + R"(
51 shape {
52 dim {
53 dim_value: 1
54 }
55 dim {
56 dim_value: 1
57 }
58 dim {
59 dim_value: 2
60 }
61 dim {
62 dim_value: 2
63 }
64 }
65 }
66 }
67 }
68 node {
69 input: "Input0"
70 input: "Input1"
71 output: "Output"
72 name: "addition"
73 op_type: "Add"
74 doc_string: ""
75 domain: ""
76 }
77 output {
78 name: "Output"
79 type {
80 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000081 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010082 shape {
83 dim {
84 dim_value: 1
85 }
86 dim {
87 dim_value: 1
88 }
89 dim {
90 dim_value: 2
91 }
92 dim {
93 dim_value: 2
94 }
95 }
96 }
97 }
98 }
99 }
100 opset_import {
101 version: 7
102 })";
103 }
104};
105
106struct AddValidFixture : AddMainFixture
107{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000108 AddValidFixture() : AddMainFixture("1") {
telsoa01c577f2c2018-08-31 09:22:23 +0100109 Setup();
110 }
111};
112
113struct AddInvalidFixture : AddMainFixture
114{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000115 AddInvalidFixture() : AddMainFixture("6") { }
telsoa01c577f2c2018-08-31 09:22:23 +0100116};
117
118struct AddValidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
119{
120 AddValidBroadcastFixture() {
121
122 m_Prototext = R"(
123 ir_version: 3
124 producer_name: "CNTK"
125 producer_version: "2.5.1"
126 domain: "ai.cntk"
127 model_version: 1
128 graph {
129 name: "CNTKGraph"
130 input {
131 name: "Input0"
132 type {
133 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000134 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100135 shape {
136 dim {
137 dim_value: 1
138 }
139 dim {
140 dim_value: 1
141 }
142 dim {
143 dim_value: 1
144 }
145 dim {
146 dim_value: 4
147 }
148 }
149 }
150 }
151 }
152 input {
153 name: "Input1"
154 type {
155 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000156 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100157 shape {
158 dim {
159 dim_value: 4
160 }
161 }
162 }
163 }
164 }
165 node {
166 input: "Input0"
167 input: "Input1"
168 output: "Output"
169 name: "addition"
170 op_type: "Add"
171 doc_string: ""
172 domain: ""
173 }
174 output {
175 name: "Output"
176 type {
177 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000178 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100179 shape {
180 dim {
181 dim_value: 1
182 }
183 dim {
184 dim_value: 1
185 }
186 dim {
187 dim_value: 1
188 }
189 dim {
190 dim_value: 4
191 }
192 }
193 }
194 }
195 }
196 }
197 opset_import {
198 version: 7
199 })";
200 Setup();
201 }
202};
203
204struct AddInvalidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
205{
206 AddInvalidBroadcastFixture() {
207
208 m_Prototext = R"(
209 ir_version: 3
210 producer_name: "CNTK"
211 producer_version: "2.5.1"
212 domain: "ai.cntk"
213 model_version: 1
214 graph {
215 name: "CNTKGraph"
216 input {
217 name: "Input0"
218 type {
219 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000220 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100221 shape {
222 dim {
223 dim_value: 1
224 }
225 dim {
226 dim_value: 1
227 }
228 dim {
229 dim_value: 1
230 }
231 dim {
232 dim_value: 3
233 }
234 }
235 }
236 }
237 }
238 input {
239 name: "Input1"
240 type {
241 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000242 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100243 shape {
244 dim {
245 dim_value: 4
246 }
247 }
248 }
249 }
250 }
251 node {
252 input: "Input0"
253 input: "Input1"
254 output: "Output"
255 name: "addition"
256 op_type: "Add"
257 doc_string: ""
258 domain: ""
259 }
260 output {
261 name: "Output"
262 type {
263 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000264 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100265 shape {
266 dim {
267 dim_value: 1
268 }
269 dim {
270 dim_value: 1
271 }
272 dim {
273 dim_value: 1
274 }
275 dim {
276 dim_value: 4
277 }
278 }
279 }
280 }
281 }
282 }
283 opset_import {
284 version: 7
285 })";
286 }
287};
288
Ryan OShea337c17f2020-02-21 12:33:17 +0000289struct AddScalarFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
290{
291 AddScalarFixture(const std::string& dataType)
292 {
293 m_Prototext = R"(
294 ir_version: 3
295 producer_name: "CNTK"
296 producer_version: "2.5.1"
297 domain: "ai.cntk"
298 model_version: 1
299 graph {
300 name: "CNTKGraph"
301 input {
302 name: "Input0"
303 type {
304 tensor_type {
305 elem_type: )" + dataType + R"(
306 shape {
307 dim {
308 dim_value: 1
309 }
310 dim {
311 dim_value: 1
312 }
313 dim {
314 dim_value: 2
315 }
316 dim {
317 dim_value: 2
318 }
319 }
320 }
321 }
322 }
323 input {
324 name: "Input1"
325 type {
326 tensor_type {
327 elem_type: )" + dataType + R"(
328 shape {
329 dim {
330 dim_value: 1
331 }
332 }
333 }
334 }
335 }
336 node {
337 input: "Input0"
338 input: "Input1"
339 output: "Output"
340 name: "addition"
341 op_type: "Add"
342 doc_string: ""
343 domain: ""
344 }
345 output {
346 name: "Output"
347 type {
348 tensor_type {
349 elem_type: 1
350 shape {
351 dim {
352 dim_value: 1
353 }
354 dim {
355 dim_value: 1
356 }
357 dim {
358 dim_value: 2
359 }
360 dim {
361 dim_value: 2
362 }
363 }
364 }
365 }
366 }
367 }
368 opset_import {
369 version: 7
370 })";
371 }
372};
373
374struct AddValidScalarFixture : AddScalarFixture
375{
376 AddValidScalarFixture() : AddScalarFixture("1") {
377 Setup();
378 }
379};
380
381struct AddInvalidScalarFixture : AddScalarFixture
382{
383 AddInvalidScalarFixture() : AddScalarFixture("6") { }
384};
385
telsoa01c577f2c2018-08-31 09:22:23 +0100386BOOST_FIXTURE_TEST_CASE(ValidAddTest, AddValidFixture)
387{
388 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
389 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
390}
391
392BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeAdd, AddInvalidFixture)
393{
394 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
395}
396
397BOOST_FIXTURE_TEST_CASE(InvalidBroadcastAdd, AddInvalidBroadcastFixture)
398{
399 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
400}
401
402BOOST_FIXTURE_TEST_CASE(ValidBroadcastAdd, AddValidBroadcastFixture)
403{
404 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
405 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
406}
407
Ryan OShea337c17f2020-02-21 12:33:17 +0000408BOOST_FIXTURE_TEST_CASE(ValidAddScalarTest, AddValidScalarFixture)
409{
410 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
411 {"Input1", {-8.0f}}}, {{"Output", {-7.0, -6.0, -11.0, -12.0}}});
412}
413
414BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeAddScalar, AddInvalidScalarFixture)
415{
416 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
417}
418
419BOOST_AUTO_TEST_SUITE_END()