blob: bbe961604cae846f777fc7482c094f033f30b155 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnOnnxParser/IOnnxParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(OnnxParser)
11
// Test fixture holding a single-node ONNX model in protobuf text form.
//
// The constructor assigns the model text to m_Prototext and then calls Setup().
// The graph described by the text contains:
//   - "Input" and "Output" tensors declared with dims 1 x 1 x 3 x 3, elem_type 1
//     (FLOAT in the ONNX TensorProto.DataType enum);
//   - one "BatchNormalization" node whose inputs are listed in the order
//     Input, scale, bias, mean, var, with an "epsilon" attribute of 0.0010000000475;
//   - one-element initializers: mean = 5.0, var = 2.0, bias = 0.0, scale = 1.0;
//   - opset_import version 7.
12struct BatchNormalizationMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13{
// Stores the model text, then calls Setup().
14 BatchNormalizationMainFixture()
15 {
16 m_Prototext = R"(
17 ir_version: 3
18 producer_name: "CNTK"
19 producer_version: "2.5.1"
20 domain: "ai.cntk"
21 model_version: 1
22 graph {
23 name: "CNTKGraph"
24 input {
25 name: "Input"
26 type {
27 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000028 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010029 shape {
30 dim {
31 dim_value: 1
32 }
33 dim {
34 dim_value: 1
35 }
36 dim {
37 dim_value: 3
38 }
39 dim {
40 dim_value: 3
41 }
42 }
43 }
44 }
45 }
46 input {
47 name: "mean"
48 type {
49 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000050 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010051 shape {
52 dim {
53 dim_value: 1
54 }
55 }
56 }
57 }
58 }
59 input {
60 name: "var"
61 type {
62 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000063 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010064 shape {
65 dim {
66 dim_value: 1
67 }
68 }
69 }
70 }
71 }
72 input {
73 name: "scale"
74 type {
75 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000076 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010077 shape {
78 dim {
79 dim_value: 1
80 }
81 }
82 }
83 }
84 }
85 input {
86 name: "bias"
87 type {
88 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000089 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010090 shape {
91 dim {
92 dim_value: 1
93 }
94 }
95 }
96 }
97 }
98 node {
99 input: "Input"
100 input: "scale"
101 input: "bias"
102 input: "mean"
103 input: "var"
104 output: "Output"
105 name: "batchNorm"
106 op_type: "BatchNormalization"
107 attribute {
108 name: "epsilon"
109 f: 0.0010000000475
Matteo Martincigh44a71672018-12-11 13:46:52 +0000110 type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100111 }
112 }
113 initializer {
114 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000115 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100116 float_data: 5.0
117 name: "mean"
118 }
119 initializer {
120 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000121 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100122 float_data: 2.0
123 name: "var"
124 }
125 initializer {
126 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000127 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100128 float_data: 0.0
129 name: "bias"
130 }
131 initializer {
132 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000133 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100134 float_data: 1.0
135 name: "scale"
136 }
137 output {
138 name: "Output"
139 type {
140 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000141 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100142 shape {
143 dim {
144 dim_value: 1
145 }
146 dim {
147 dim_value: 1
148 }
149 dim {
150 dim_value: 3
151 }
152 dim {
153 dim_value: 3
154 }
155 }
156 }
157 }
158 }
159 }
160 opset_import {
161 version: 7
162 })";
// Setup() is not defined in this block; presumably inherited from ParserPrototxtFixture
// and responsible for parsing m_Prototext - confirm against ParserPrototxtFixture.hpp.
163 Setup();
164 }
165};
166
167BOOST_FIXTURE_TEST_CASE(ValidBatchNormalizationTest, BatchNormalizationMainFixture)
168{
169 RunTest<4>({{"Input", {1, 2, 3, 4, 5, 6, 7, 8, 9}}}, // Input data.
170 {{"Output", {-2.8277204f, -2.12079024f, -1.4138602f,
171 -0.7069301f, 0.0f, 0.7069301f,
172 1.4138602f, 2.12079024f, 2.8277204f}}}); // Expected output data.
173}
174
175
// Test fixture holding a single-node ONNX model in protobuf text form, with two-element
// normalization parameters.
//
// The constructor assigns the model text to m_Prototext and then calls Setup().
// The graph described by the text contains:
//   - "Input" and "Output" tensors declared with dims 1 x 2 x 1 x 3, elem_type 1
//     (FLOAT in the ONNX TensorProto.DataType enum);
//   - one "BatchNormalization" node whose inputs are listed in the order
//     Input, scale, bias, mean, var, with an "epsilon" attribute of 0.00001;
//   - two-element initializers: mean = {0.0, 3.0}, var = {1.0, 1.5},
//     bias = {0.0, 1.0}, scale = {1.0, 1.5};
//   - opset_import version 7.
176struct BatchNormalizationBisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
177{
// Stores the model text, then calls Setup().
178 BatchNormalizationBisFixture()
179 {
180 m_Prototext = R"(
181 ir_version: 3
182 producer_name: "CNTK"
183 producer_version: "2.5.1"
184 domain: "ai.cntk"
185 model_version: 1
186 graph {
187 name: "CNTKGraph"
188 input {
189 name: "Input"
190 type {
191 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000192 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100193 shape {
194 dim {
195 dim_value: 1
196 }
197 dim {
198 dim_value: 2
199 }
200 dim {
201 dim_value: 1
202 }
203 dim {
204 dim_value: 3
205 }
206 }
207 }
208 }
209 }
210 input {
211 name: "mean"
212 type {
213 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000214 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100215 shape {
216 dim {
217 dim_value: 2
218 }
219 }
220 }
221 }
222 }
223 input {
224 name: "var"
225 type {
226 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000227 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100228 shape {
229 dim {
230 dim_value: 2
231 }
232 }
233 }
234 }
235 }
236 input {
237 name: "scale"
238 type {
239 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000240 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100241 shape {
242 dim {
243 dim_value: 2
244 }
245 }
246 }
247 }
248 }
249 input {
250 name: "bias"
251 type {
252 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000253 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100254 shape {
255 dim {
256 dim_value: 2
257 }
258 }
259 }
260 }
261 }
262 node {
263 input: "Input"
264 input: "scale"
265 input: "bias"
266 input: "mean"
267 input: "var"
268 output: "Output"
269 name: "batchNorm"
270 op_type: "BatchNormalization"
271 attribute {
272 name: "epsilon"
273 f: 0.00001
Matteo Martincigh44a71672018-12-11 13:46:52 +0000274 type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100275 }
276 }
277 initializer {
278 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000279 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100280 float_data: 0.0
281 float_data: 3.0
282 name: "mean"
283 }
284 initializer {
285 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000286 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100287 float_data: 1.0
288 float_data: 1.5
289 name: "var"
290 }
291 initializer {
292 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000293 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100294 float_data: 0.0
295 float_data: 1.0
296 name: "bias"
297 }
298 initializer {
299 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000300 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100301 float_data: 1.0
302 float_data: 1.5
303 name: "scale"
304 }
305 output {
306 name: "Output"
307 type {
308 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000309 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100310 shape {
311 dim {
312 dim_value: 1
313 }
314 dim {
315 dim_value: 2
316 }
317 dim {
318 dim_value: 1
319 }
320 dim {
321 dim_value: 3
322 }
323 }
324 }
325 }
326 }
327 }
328 opset_import {
329 version: 7
330 })";
// Setup() is not defined in this block; presumably inherited from ParserPrototxtFixture
// and responsible for parsing m_Prototext - confirm against ParserPrototxtFixture.hpp.
331 Setup();
332 }
333};
334
335BOOST_FIXTURE_TEST_CASE(ValidBatchNormalizationBisTest, BatchNormalizationBisFixture)
336{
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100337 RunTest<4>({{"Input", {-1, 0.0, 1, 2, 3.0, 4.0}}}, // Input data.
telsoa01c577f2c2018-08-31 09:22:23 +0100338 {{"Output", {-0.999995f, 0.0, 0.999995f,
339 -0.22474074f, 1.0f, 2.2247407f}}}); // Expected output data.
340}
341
342BOOST_AUTO_TEST_SUITE_END()