//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <string>

#include <boost/test/unit_test.hpp>

#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"

10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
12struct AddMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
14 AddMainFixture(const std::string& dataType)
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input0"
26 type {
27 tensor_type {
28 elem_type: )" + dataType + R"(
29 shape {
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 1
35 }
36 dim {
37 dim_value: 2
38 }
39 dim {
40 dim_value: 2
41 }
42 }
43 }
44 }
45 }
46 input {
47 name: "Input1"
48 type {
49 tensor_type {
50 elem_type: )" + dataType + R"(
51 shape {
52 dim {
53 dim_value: 1
54 }
55 dim {
56 dim_value: 1
57 }
58 dim {
59 dim_value: 2
60 }
61 dim {
62 dim_value: 2
63 }
64 }
65 }
66 }
67 }
68 node {
69 input: "Input0"
70 input: "Input1"
71 output: "Output"
72 name: "addition"
73 op_type: "Add"
74 doc_string: ""
75 domain: ""
76 }
77 output {
78 name: "Output"
79 type {
80 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000081 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010082 shape {
83 dim {
84 dim_value: 1
85 }
86 dim {
87 dim_value: 1
88 }
89 dim {
90 dim_value: 2
91 }
92 dim {
93 dim_value: 2
94 }
95 }
96 }
97 }
98 }
99 }
100 opset_import {
101 version: 7
102 })";
103 }
104};
105
106struct AddValidFixture : AddMainFixture
107{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000108 AddValidFixture() : AddMainFixture("1") {
telsoa01c577f2c2018-08-31 09:22:23 +0100109 Setup();
110 }
111};
112
113struct AddInvalidFixture : AddMainFixture
114{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000115 AddInvalidFixture() : AddMainFixture("6") { }
telsoa01c577f2c2018-08-31 09:22:23 +0100116};
117
118struct AddValidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
119{
120 AddValidBroadcastFixture() {
121
122 m_Prototext = R"(
123 ir_version: 3
124 producer_name: "CNTK"
125 producer_version: "2.5.1"
126 domain: "ai.cntk"
127 model_version: 1
128 graph {
129 name: "CNTKGraph"
130 input {
131 name: "Input0"
132 type {
133 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000134 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100135 shape {
136 dim {
137 dim_value: 1
138 }
139 dim {
140 dim_value: 1
141 }
142 dim {
143 dim_value: 1
144 }
145 dim {
146 dim_value: 4
147 }
148 }
149 }
150 }
151 }
152 input {
153 name: "Input1"
154 type {
155 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000156 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100157 shape {
158 dim {
159 dim_value: 4
160 }
161 }
162 }
163 }
164 }
165 node {
166 input: "Input0"
167 input: "Input1"
168 output: "Output"
169 name: "addition"
170 op_type: "Add"
171 doc_string: ""
172 domain: ""
173 }
174 output {
175 name: "Output"
176 type {
177 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000178 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100179 shape {
180 dim {
181 dim_value: 1
182 }
183 dim {
184 dim_value: 1
185 }
186 dim {
187 dim_value: 1
188 }
189 dim {
190 dim_value: 4
191 }
192 }
193 }
194 }
195 }
196 }
197 opset_import {
198 version: 7
199 })";
200 Setup();
201 }
202};
203
204struct AddInvalidBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
205{
206 AddInvalidBroadcastFixture() {
207
208 m_Prototext = R"(
209 ir_version: 3
210 producer_name: "CNTK"
211 producer_version: "2.5.1"
212 domain: "ai.cntk"
213 model_version: 1
214 graph {
215 name: "CNTKGraph"
216 input {
217 name: "Input0"
218 type {
219 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000220 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100221 shape {
222 dim {
223 dim_value: 1
224 }
225 dim {
226 dim_value: 1
227 }
228 dim {
229 dim_value: 1
230 }
231 dim {
232 dim_value: 3
233 }
234 }
235 }
236 }
237 }
238 input {
239 name: "Input1"
240 type {
241 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000242 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100243 shape {
244 dim {
245 dim_value: 4
246 }
247 }
248 }
249 }
250 }
251 node {
252 input: "Input0"
253 input: "Input1"
254 output: "Output"
255 name: "addition"
256 op_type: "Add"
257 doc_string: ""
258 domain: ""
259 }
260 output {
261 name: "Output"
262 type {
263 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000264 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100265 shape {
266 dim {
267 dim_value: 1
268 }
269 dim {
270 dim_value: 1
271 }
272 dim {
273 dim_value: 1
274 }
275 dim {
276 dim_value: 4
277 }
278 }
279 }
280 }
281 }
282 }
283 opset_import {
284 version: 7
285 })";
286 }
287};
288
289BOOST_FIXTURE_TEST_CASE(ValidAddTest, AddValidFixture)
290{
291 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
292 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
293}
294
295BOOST_FIXTURE_TEST_CASE(IncorrectDataTypeAdd, AddInvalidFixture)
296{
297 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
298}
299
300BOOST_FIXTURE_TEST_CASE(InvalidBroadcastAdd, AddInvalidBroadcastFixture)
301{
302 BOOST_CHECK_THROW(Setup(), armnn::ParseException);
303}
304
305BOOST_FIXTURE_TEST_CASE(ValidBroadcastAdd, AddValidBroadcastFixture)
306{
307 RunTest<4>({{"Input0", {1.0f, 2.0f, -3.0f, -4.0f}},
308 {"Input1", {1.0f, 2.0f, 3.0, 4.0f}}}, {{"Output", {2.0, 4.0, 0, 0.0}}});
309}
310
311BOOST_AUTO_TEST_SUITE_END()