blob: 73d113d8e122526e9f23c909afd93365a81428bc [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5
telsoa01c577f2c2018-08-31 09:22:23 +01006#include "armnnOnnxParser/IOnnxParser.hpp"
7#include "ParserPrototxtFixture.hpp"
8
Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("OnnxParser_Pooling")
10{
telsoa01c577f2c2018-08-31 09:22:23 +010011struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12{
13 PoolingMainFixture(const std::string& dataType, const std::string& op)
14 {
15 m_Prototext = R"(
16 ir_version: 3
17 producer_name: "CNTK"
18 producer_version: "2.5.1"
19 domain: "ai.cntk"
20 model_version: 1
21 graph {
22 name: "CNTKGraph"
23 input {
24 name: "Input"
25 type {
26 tensor_type {
27 elem_type: )" + dataType + R"(
28 shape {
29 dim {
30 dim_value: 1
31 }
32 dim {
33 dim_value: 1
34 }
35 dim {
36 dim_value: 2
37 }
38 dim {
39 dim_value: 2
40 }
41 }
42 }
43 }
44 }
45 node {
46 input: "Input"
47 output: "Output"
48 name: "Pooling"
49 op_type: )" + op + R"(
50 attribute {
51 name: "kernel_shape"
52 ints: 2
53 ints: 2
54 type: INTS
55 }
56 attribute {
57 name: "strides"
58 ints: 1
59 ints: 1
60 type: INTS
61 }
62 attribute {
63 name: "pads"
64 ints: 0
65 ints: 0
66 ints: 0
67 ints: 0
68 type: INTS
69 }
70 }
71 output {
72 name: "Output"
73 type {
74 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +000075 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +010076 shape {
77 dim {
78 dim_value: 1
79 }
80 dim {
81 dim_value: 1
82 }
83 dim {
84 dim_value: 1
85 }
86 dim {
87 dim_value: 1
88 }
89 }
90 }
91 }
92 }
93 }
94 opset_import {
95 version: 7
96 })";
97 }
98};
99
100struct MaxPoolValidFixture : PoolingMainFixture
101{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000102 MaxPoolValidFixture() : PoolingMainFixture("1", "\"MaxPool\"") {
telsoa01c577f2c2018-08-31 09:22:23 +0100103 Setup();
104 }
105};
106
107struct MaxPoolInvalidFixture : PoolingMainFixture
108{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000109 MaxPoolInvalidFixture() : PoolingMainFixture("10", "\"MaxPool\"") { }
telsoa01c577f2c2018-08-31 09:22:23 +0100110};
111
Sadik Armagan1625efc2021-06-10 18:24:34 +0100112TEST_CASE_FIXTURE(MaxPoolValidFixture, "ValidMaxPoolTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100113{
114 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}});
115}
116
117struct AvgPoolValidFixture : PoolingMainFixture
118{
Matteo Martincigh44a71672018-12-11 13:46:52 +0000119 AvgPoolValidFixture() : PoolingMainFixture("1", "\"AveragePool\"") {
telsoa01c577f2c2018-08-31 09:22:23 +0100120 Setup();
121 }
122};
123
124struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
125{
126 PoolingWithPadFixture()
127 {
128 m_Prototext = R"(
129 ir_version: 3
130 producer_name: "CNTK"
131 producer_version: "2.5.1"
132 domain: "ai.cntk"
133 model_version: 1
134 graph {
135 name: "CNTKGraph"
136 input {
137 name: "Input"
138 type {
139 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000140 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100141 shape {
142 dim {
143 dim_value: 1
144 }
145 dim {
146 dim_value: 1
147 }
148 dim {
149 dim_value: 2
150 }
151 dim {
152 dim_value: 2
153 }
154 }
155 }
156 }
157 }
158 node {
159 input: "Input"
160 output: "Output"
161 name: "Pooling"
162 op_type: "AveragePool"
163 attribute {
164 name: "kernel_shape"
165 ints: 4
166 ints: 4
167 type: INTS
168 }
169 attribute {
170 name: "strides"
171 ints: 1
172 ints: 1
173 type: INTS
174 }
175 attribute {
176 name: "pads"
177 ints: 1
178 ints: 1
179 ints: 1
180 ints: 1
181 type: INTS
182 }
183 attribute {
184 name: "count_include_pad"
185 i: 1
186 type: INT
187 }
188 }
189 output {
190 name: "Output"
191 type {
192 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000193 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100194 shape {
195 dim {
196 dim_value: 1
197 }
198 dim {
199 dim_value: 1
200 }
201 dim {
202 dim_value: 1
203 }
204 dim {
205 dim_value: 1
206 }
207 }
208 }
209 }
210 }
211 }
212 opset_import {
213 version: 7
214 })";
215 Setup();
216 }
217};
218
Sadik Armagan1625efc2021-06-10 18:24:34 +0100219TEST_CASE_FIXTURE(AvgPoolValidFixture, "AveragePoolValid")
telsoa01c577f2c2018-08-31 09:22:23 +0100220{
221 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}});
222}
223
Sadik Armagan1625efc2021-06-10 18:24:34 +0100224TEST_CASE_FIXTURE(PoolingWithPadFixture, "ValidAvgWithPadTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100225{
226 RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}});
227}
228
229struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
230{
231 GlobalAvgFixture()
232 {
233 m_Prototext = R"(
234 ir_version: 3
235 producer_name: "CNTK"
236 producer_version: "2.5.1"
237 domain: "ai.cntk"
238 model_version: 1
239 graph {
240 name: "CNTKGraph"
241 input {
242 name: "Input"
243 type {
244 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000245 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100246 shape {
247 dim {
248 dim_value: 1
249 }
250 dim {
251 dim_value: 2
252 }
253 dim {
254 dim_value: 2
255 }
256 dim {
257 dim_value: 2
258 }
259 }
260 }
261 }
262 }
263 node {
264 input: "Input"
265 output: "Output"
266 name: "Pooling"
267 op_type: "GlobalAveragePool"
268 }
269 output {
270 name: "Output"
271 type {
272 tensor_type {
Matteo Martincigh44a71672018-12-11 13:46:52 +0000273 elem_type: 1
telsoa01c577f2c2018-08-31 09:22:23 +0100274 shape {
275 dim {
276 dim_value: 1
277 }
278 dim {
279 dim_value: 2
280 }
281 dim {
282 dim_value: 1
283 }
284 dim {
285 dim_value: 1
286 }
287 }
288 }
289 }
290 }
291 }
292 opset_import {
293 version: 7
294 })";
295 Setup();
296 }
297};
298
Sadik Armagan1625efc2021-06-10 18:24:34 +0100299TEST_CASE_FIXTURE(GlobalAvgFixture, "GlobalAvgTest")
telsoa01c577f2c2018-08-31 09:22:23 +0100300{
301 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}});
302}
303
Sadik Armagan1625efc2021-06-10 18:24:34 +0100304TEST_CASE_FIXTURE(MaxPoolInvalidFixture, "IncorrectDataTypeMaxPool")
telsoa01c577f2c2018-08-31 09:22:23 +0100305{
Sadik Armagan1625efc2021-06-10 18:24:34 +0100306 CHECK_THROWS_AS(Setup(), armnn::ParseException);
telsoa01c577f2c2018-08-31 09:22:23 +0100307}
308
Sadik Armagan1625efc2021-06-10 18:24:34 +0100309}