blob: d18c2774571e92668148923d0eb2416223cbda87 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("OnnxParser_Addition")
10{
telsoa01c577f2c2018-08-31 09:22:23 +010011struct AddMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12{
13 AddMainFixture(const std::string& dataType)
14 {
15 m_Prototext = R"(
16 ir_version: 3
17 producer_name: "CNTK"
18 producer_version: "2.5.1"
19 domain: "ai.cntk"
20 model_version: 1
21 graph {
22 name: "CNTKGraph"
23 input {
24 name: "Input0"
25 type {
26 tensor_type {
27 elem_type: )" + dataType + R"(
28 shape {
29 dim {
30 dim_value: 1
31 }
32 dim {
33 dim_value: 1
34 }
35 dim {
36 dim_value: 2
37 }
38 dim {
39 dim_value: 2
40 }
41 }
42 }
43 }
44 }
45 input {
46 name: "Input1"
47 type {
48 tensor_type {
49 elem_type: )" + dataType + R"(
50 shape {
51 dim {
52 dim_value: 1
53 }
54 dim {
55 dim_value: 1
56 }
57 dim {
58 dim_value: 2
59 }
60 dim {
61 dim_value: 2
62 }
63 }
64 }
65 }
66 }
67 node {
68 input: "Input0"
69 input: "Input1"
70 output: "Output"
71 name: "addition"
72 op_type: "Add"
73 doc_string: ""
74 domain: ""
75 }
76 output {
77 name: "Output"
78 type {
79 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000080 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010081 shape {
82 dim {
83 dim_value: 1
84 }
85 dim {
86 dim_value: 1
87 }
88 dim {
89 dim_value: 2
90 }
91 dim {
92 dim_value: 2
93 }
94 }
95 }
96 }
97 }
98 }
99 opset_import {
100 version: 7
101 })";
102 }
103};
104
105struct AddValidFixture : AddMainFixture
106{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000107 AddValidFixture() : AddMainFixture("1") {
telsoa01c577f2c2018-08-31 09:22:23 +0100108 Setup();
109 }
110};
111
112struct AddInvalidFixture : AddMainFixture
113{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000114 AddInvalidFixture() : AddMainFixture("6") { }
telsoa01c577f2c2018-08-31 09:22:23 +0100115};
116
117struct AddValidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
118{
119 AddValidBroadcastFixture() {
120
121 m_Prototext = R"(
122 ir_version: 3
123 producer_name: "CNTK"
124 producer_version: "2.5.1"
125 domain: "ai.cntk"
126 model_version: 1
127 graph {
128 name: "CNTKGraph"
129 input {
130 name: "Input0"
131 type {
132 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000133 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100134 shape {
135 dim {
136 dim_value: 1
137 }
138 dim {
139 dim_value: 1
140 }
141 dim {
142 dim_value: 1
143 }
144 dim {
145 dim_value: 4
146 }
147 }
148 }
149 }
150 }
151 input {
152 name: "Input1"
153 type {
154 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000155 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100156 shape {
157 dim {
158 dim_value: 4
159 }
160 }
161 }
162 }
163 }
164 node {
165 input: "Input0"
166 input: "Input1"
167 output: "Output"
168 name: "addition"
169 op_type: "Add"
170 doc_string: ""
171 domain: ""
172 }
173 output {
174 name: "Output"
175 type {
176 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000177 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100178 shape {
179 dim {
180 dim_value: 1
181 }
182 dim {
183 dim_value: 1
184 }
185 dim {
186 dim_value: 1
187 }
188 dim {
189 dim_value: 4
190 }
191 }
192 }
193 }
194 }
195 }
196 opset_import {
197 version: 7
198 })";
199 Setup();
200 }
201};
202
203struct AddInvalidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
204{
205 AddInvalidBroadcastFixture() {
206
207 m_Prototext = R"(
208 ir_version: 3
209 producer_name: "CNTK"
210 producer_version: "2.5.1"
211 domain: "ai.cntk"
212 model_version: 1
213 graph {
214 name: "CNTKGraph"
215 input {
216 name: "Input0"
217 type {
218 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000219 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100220 shape {
221 dim {
222 dim_value: 1
223 }
224 dim {
225 dim_value: 1
226 }
227 dim {
228 dim_value: 1
229 }
230 dim {
231 dim_value: 3
232 }
233 }
234 }
235 }
236 }
237 input {
238 name: "Input1"
239 type {
240 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000241 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100242 shape {
243 dim {
244 dim_value: 4
245 }
246 }
247 }
248 }
249 }
250 node {
251 input: "Input0"
252 input: "Input1"
253 output: "Output"
254 name: "addition"
255 op_type: "Add"
256 doc_string: ""
257 domain: ""
258 }
259 output {
260 name: "Output"
261 type {
262 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000263 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100264 shape {
265 dim {
266 dim_value: 1
267 }
268 dim {
269 dim_value: 1
270 }
271 dim {
272 dim_value: 1
273 }
274 dim {
275 dim_value: 4
276 }
277 }
278 }
279 }
280 }
281 }
282 opset_import {
283 version: 7
284 })";
285 }
286};
287
Ryan OShea337c17f2020-02-21 12:33:17 +0000288struct AddScalarFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
289{
290 AddScalarFixture(const std::string& dataType)
291 {
292 m_Prototext = R"(
293 ir_version: 3
294 producer_name: "CNTK"
295 producer_version: "2.5.1"
296 domain: "ai.cntk"
297 model_version: 1
298 graph {
299 name: "CNTKGraph"
300 input {
301 name: "Input0"
302 type {
303 tensor_type {
304 elem_type: )" + dataType + R"(
305 shape {
306 dim {
307 dim_value: 1
308 }
309 dim {
310 dim_value: 1
311 }
312 dim {
313 dim_value: 2
314 }
315 dim {
316 dim_value: 2
317 }
318 }
319 }
320 }
321 }
322 input {
323 name: "Input1"
324 type {
325 tensor_type {
326 elem_type: )" + dataType + R"(
327 shape {
328 dim {
329 dim_value: 1
330 }
331 }
332 }
333 }
334 }
335 node {
336 input: "Input0"
337 input: "Input1"
338 output: "Output"
339 name: "addition"
340 op_type: "Add"
341 doc_string: ""
342 domain: ""
343 }
344 output {
345 name: "Output"
346 type {
347 tensor_type {
348 elem_type: 1
349 shape {
350 dim {
351 dim_value: 1
352 }
353 dim {
354 dim_value: 1
355 }
356 dim {
357 dim_value: 2
358 }
359 dim {
360 dim_value: 2
361 }
362 }
363 }
364 }
365 }
366 }
367 opset_import {
368 version: 7
369 })";
370 }
371};
372
373struct AddValidScalarFixture : AddScalarFixture
374{
375 AddValidScalarFixture() : AddScalarFixture("1") {
376 Setup();
377 }
378};
379
380struct AddInvalidScalarFixture : AddScalarFixture
381{
382 AddInvalidScalarFixture() : AddScalarFixture("6") { }
383};
384
Sadik Armagan1625efc2021-06-10 18:24:34 +0100385TEST_CASE_FIXTURE(AddValidFixture, "ValidAddTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100386{
387 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
388 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
389}
390
Sadik Armagan1625efc2021-06-10 18:24:34 +0100391TEST_CASE_FIXTURE(AddInvalidFixture, "IncorrectDataTypeAdd")
telsoa01c577f2c2018-08-31 09:22:23 +0100392{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100393 CHECK_THROWS_AS(Setup(), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100394}
395
Sadik Armagan1625efc2021-06-10 18:24:34 +0100396TEST_CASE_FIXTURE(AddInvalidBroadcastFixture, "InvalidBroadcastAdd")
telsoa01c577f2c2018-08-31 09:22:23 +0100397{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100398 CHECK_THROWS_AS(Setup(), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100399}
400
Sadik Armagan1625efc2021-06-10 18:24:34 +0100401TEST_CASE_FIXTURE(AddValidBroadcastFixture, "ValidBroadcastAdd")
telsoa01c577f2c2018-08-31 09:22:23 +0100402{
403 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
404 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
405}
406
Sadik Armagan1625efc2021-06-10 18:24:34 +0100407TEST_CASE_FIXTURE(AddValidScalarFixture, "ValidAddScalarTest")
Ryan OShea337c17f2020-02-21 12:33:17 +0000408{
409 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
410 {"Input1", {-8.0f}}}, {{"Output", {-7.0, -6.0, -11.0, -12.0}}});
411}
412
Sadik Armagan1625efc2021-06-10 18:24:34 +0100413TEST_CASE_FIXTURE(AddInvalidScalarFixture, "IncorrectDataTypeAddScalar")
Ryan OShea337c17f2020-02-21 12:33:17 +0000414{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100415 CHECK_THROWS_AS(Setup(), armnn::ParseException);
Ryan OShea337c17f2020-02-21 12:33:17 +0000416}
417
Sadik Armagan1625efc2021-06-10 18:24:34 +0100418}