blob: f582dbd7135cac4ca3e23c0bc4d79b8da2ceaee4 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("OnnxParser_BatchNorm")
10{
// Test fixture holding a single-node ONNX model (text-format protobuf) with one
// BatchNormalization operator.
//
// Model summary, as written in the prototxt below:
//   - "Input" and "Output" both have shape 1x1x3x3.
//   - "mean", "var", "scale" and "bias" are 1-element tensors supplied as
//     initializers: mean = 5.0, var = 2.0, bias = 0.0, scale = 1.0.
//   - The node lists its inputs in the order Input, scale, bias, mean, var and
//     carries an "epsilon" attribute of 0.0010000000475.
//   - opset_import version is 7.
// NOTE(review): every tensor uses elem_type/data_type 1, presumably FLOAT in the
// ONNX TensorProto enum - confirm against the ONNX schema.
telsoa01c577f2c2018-08-31 09:22:23 +010011struct BatchNormalizationMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12{
    // Stores the model text in m_Prototext and then calls Setup().
13 BatchNormalizationMainFixture()
14 {
15 m_Prototext = R"(
16 ir_version: 3
17 producer_name: "CNTK"
18 producer_version: "2.5.1"
19 domain: "ai.cntk"
20 model_version: 1
21 graph {
22 name: "CNTKGraph"
23 input {
24 name: "Input"
25 type {
26 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000027 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010028 shape {
29 dim {
30 dim_value: 1
31 }
32 dim {
33 dim_value: 1
34 }
35 dim {
36 dim_value: 3
37 }
38 dim {
39 dim_value: 3
40 }
41 }
42 }
43 }
44 }
45 input {
46 name: "mean"
47 type {
48 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000049 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010050 shape {
51 dim {
52 dim_value: 1
53 }
54 }
55 }
56 }
57 }
58 input {
59 name: "var"
60 type {
61 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000062 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010063 shape {
64 dim {
65 dim_value: 1
66 }
67 }
68 }
69 }
70 }
71 input {
72 name: "scale"
73 type {
74 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000075 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010076 shape {
77 dim {
78 dim_value: 1
79 }
80 }
81 }
82 }
83 }
84 input {
85 name: "bias"
86 type {
87 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000088 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010089 shape {
90 dim {
91 dim_value: 1
92 }
93 }
94 }
95 }
96 }
97 node {
98 input: "Input"
99 input: "scale"
100 input: "bias"
101 input: "mean"
102 input: "var"
103 output: "Output"
104 name: "batchNorm"
105 op_type: "BatchNormalization"
106 attribute {
107 name: "epsilon"
108 f: 0.0010000000475
Matteo Martincigh44a71672018-12-11 13:46:52 +0000109 type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100110 }
111 }
112 initializer {
113 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000114 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100115 float_data: 5.0
116 name: "mean"
117 }
118 initializer {
119 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000120 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100121 float_data: 2.0
122 name: "var"
123 }
124 initializer {
125 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000126 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100127 float_data: 0.0
128 name: "bias"
129 }
130 initializer {
131 dims: 1
Matteo Martincigh44a71672018-12-11 13:46:52 +0000132 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100133 float_data: 1.0
134 name: "scale"
135 }
136 output {
137 name: "Output"
138 type {
139 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000140 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100141 shape {
142 dim {
143 dim_value: 1
144 }
145 dim {
146 dim_value: 1
147 }
148 dim {
149 dim_value: 3
150 }
151 dim {
152 dim_value: 3
153 }
154 }
155 }
156 }
157 }
158 }
159 opset_import {
160 version: 7
161 })";
    // NOTE(review): Setup() is not defined in this file; presumably inherited from
    // ParserPrototxtFixture and responsible for parsing m_Prototext - confirm there.
162 Setup();
163 }
164};
165
// Runs the single-channel fixture on the values 1..9 and checks the output.
// The expected values agree with (x - 5) / sqrt(2 + 0.001), i.e. the fixture's
// mean = 5, var = 2, epsilon ~= 0.001, scale = 1 and bias = 0; the centre element
// (x == mean) therefore maps to exactly 0.
// NOTE(review): the template argument 4 presumably is the tensor rank - confirm
// against RunTest in ParserPrototxtFixture.hpp.
Sadik Armagan1625efc2021-06-10 18:24:34 +0100166TEST_CASE_FIXTURE(BatchNormalizationMainFixture, "ValidBatchNormalizationTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100167{
168 RunTest<4>({{"Input", {1, 2, 3, 4, 5, 6, 7, 8, 9}}}, // Input data.
169 {{"Output", {-2.8277204f, -2.12079024f, -1.4138602f,
170 -0.7069301f, 0.0f, 0.7069301f,
171 1.4138602f, 2.12079024f, 2.8277204f}}}); // Expected output data.
172}
173
174
// Second test fixture: the same single BatchNormalization node, but with a
// two-element parameter set and a different epsilon.
//
// Model summary, as written in the prototxt below:
//   - "Input" and "Output" both have shape 1x2x1x3.
//   - "mean", "var", "scale" and "bias" are 2-element tensors supplied as
//     initializers: mean = {0.0, 3.0}, var = {1.0, 1.5}, bias = {0.0, 1.0},
//     scale = {1.0, 1.5}.
//   - The node lists its inputs in the order Input, scale, bias, mean, var and
//     carries an "epsilon" attribute of 0.00001.
//   - opset_import version is 7.
// NOTE(review): every tensor uses elem_type/data_type 1, presumably FLOAT in the
// ONNX TensorProto enum - confirm against the ONNX schema.
175struct BatchNormalizationBisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
176{
    // Stores the model text in m_Prototext and then calls Setup().
177 BatchNormalizationBisFixture()
178 {
179 m_Prototext = R"(
180 ir_version: 3
181 producer_name: "CNTK"
182 producer_version: "2.5.1"
183 domain: "ai.cntk"
184 model_version: 1
185 graph {
186 name: "CNTKGraph"
187 input {
188 name: "Input"
189 type {
190 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000191 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100192 shape {
193 dim {
194 dim_value: 1
195 }
196 dim {
197 dim_value: 2
198 }
199 dim {
200 dim_value: 1
201 }
202 dim {
203 dim_value: 3
204 }
205 }
206 }
207 }
208 }
209 input {
210 name: "mean"
211 type {
212 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000213 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100214 shape {
215 dim {
216 dim_value: 2
217 }
218 }
219 }
220 }
221 }
222 input {
223 name: "var"
224 type {
225 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000226 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100227 shape {
228 dim {
229 dim_value: 2
230 }
231 }
232 }
233 }
234 }
235 input {
236 name: "scale"
237 type {
238 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000239 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100240 shape {
241 dim {
242 dim_value: 2
243 }
244 }
245 }
246 }
247 }
248 input {
249 name: "bias"
250 type {
251 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000252 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100253 shape {
254 dim {
255 dim_value: 2
256 }
257 }
258 }
259 }
260 }
261 node {
262 input: "Input"
263 input: "scale"
264 input: "bias"
265 input: "mean"
266 input: "var"
267 output: "Output"
268 name: "batchNorm"
269 op_type: "BatchNormalization"
270 attribute {
271 name: "epsilon"
272 f: 0.00001
Matteo Martincigh44a71672018-12-11 13:46:52 +0000273 type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100274 }
275 }
276 initializer {
277 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000278 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100279 float_data: 0.0
280 float_data: 3.0
281 name: "mean"
282 }
283 initializer {
284 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000285 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100286 float_data: 1.0
287 float_data: 1.5
288 name: "var"
289 }
290 initializer {
291 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000292 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100293 float_data: 0.0
294 float_data: 1.0
295 name: "bias"
296 }
297 initializer {
298 dims: 2
Matteo Martincigh44a71672018-12-11 13:46:52 +0000299 data_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100300 float_data: 1.0
301 float_data: 1.5
302 name: "scale"
303 }
304 output {
305 name: "Output"
306 type {
307 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000308 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100309 shape {
310 dim {
311 dim_value: 1
312 }
313 dim {
314 dim_value: 2
315 }
316 dim {
317 dim_value: 1
318 }
319 dim {
320 dim_value: 3
321 }
322 }
323 }
324 }
325 }
326 }
327 opset_import {
328 version: 7
329 })";
    // NOTE(review): Setup() is not defined in this file; presumably inherited from
    // ParserPrototxtFixture and responsible for parsing m_Prototext - confirm there.
330 Setup();
331 }
332};
333
// Runs the two-element-parameter fixture on six input values and checks the output.
// The expected values agree with scale * (x - mean) / sqrt(var + 0.00001) + bias,
// where the first three elements use the first parameter set
// (mean 0, var 1, scale 1, bias 0) and the last three use the second
// (mean 3, var 1.5, scale 1.5, bias 1).
// NOTE(review): the template argument 4 presumably is the tensor rank - confirm
// against RunTest in ParserPrototxtFixture.hpp.
Sadik Armagan1625efc2021-06-10 18:24:34 +0100334TEST_CASE_FIXTURE(BatchNormalizationBisFixture, "ValidBatchNormalizationBisTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100335{
Matthew Bentham7c1603a2019-06-21 17:22:23 +0100336 RunTest<4>({{"Input", {-1, 0.0, 1, 2, 3.0, 4.0}}}, // Input data.
telsoa01c577f2c2018-08-31 09:22:23 +0100337 {{"Output", {-0.999995f, 0.0, 0.999995f,
338 -0.22474074f, 1.0f, 2.2247407f}}}); // Expected output data.
339}
340
Sadik Armagan1625efc2021-06-10 18:24:34 +0100341}