blob: 8ad1036ef167c976de947033e8a0d491202e7242 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9#include <string>
10#include <iostream>
11
12BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
telsoa01c577f2c2018-08-31 09:22:23 +010014struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010015{
16 explicit Convolution2dFixture(const char* paddingType)
17 : Convolution2dFixture(paddingType, 1)
18 {}
19
telsoa01c577f2c2018-08-31 09:22:23 +010020 // Dilation: 0 - dilations attribute is not included;
21 // Dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
surmeh01bceff2f2018-03-29 16:29:27 +010022 explicit Convolution2dFixture(const char* paddingType, int stride, int dilation = 0)
23 {
24 std::string strideString = std::to_string(stride);
25 std::string dilationString = std::to_string(dilation);
26 m_Prototext = "node { \n"
27 " name: \"graphInput\" \n"
28 " op: \"Placeholder\" \n"
29 " attr { \n"
30 " key: \"dtype\" \n"
31 " value { \n"
32 " type: DT_FLOAT \n"
33 " } \n"
34 " } \n"
35 " attr { \n"
36 " key: \"shape\" \n"
37 " value { \n"
38 " shape { \n"
39 " } \n"
40 " } \n"
41 " } \n"
42 " } \n"
43 " node { \n"
44 " name: \"Const_1\" \n"
45 " op: \"Const\" \n"
46 " attr { \n"
47 " key: \"dtype\" \n"
48 " value { \n"
49 " type: DT_FLOAT \n"
50 " } \n"
51 " } \n"
52 " attr { \n"
53 " key: \"value\" \n"
54 " value { \n"
55 " tensor { \n"
56 " dtype: DT_FLOAT \n"
57 " tensor_shape { \n"
58 " dim { \n"
59 " size: 1 \n"
60 " } \n"
61 " dim { \n"
62 " size: 3 \n"
63 " } \n"
64 " dim { \n"
65 " size: 1 \n"
66 " } \n"
67 " dim { \n"
68 " size: 1 \n"
69 " } \n"
70 " } \n"
71 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
72 " } \n"
73 " } \n"
74 " } \n"
75 "} \n"
76 "node { \n"
77 " name: \"potato\" \n"
78 " op: \"Conv2D\" \n"
79 " input: \"graphInput\" \n"
80 " input: \"Const_1\" \n"
81 " attr { \n"
82 " key: \"T\" \n"
83 " value { \n"
84 " type: DT_FLOAT \n"
85 " } \n"
86 " } \n"
87 " attr { \n"
88 " key: \"data_format\" \n"
89 " value { \n"
90 " s: \"NHWC\" \n"
91 " } \n"
92 " } \n"
93 " attr { \n"
94 " key: \"padding\" \n"
95 " value { \n"
96 " s: \"";
97 m_Prototext.append(paddingType);
98 m_Prototext.append("\"\n"
99 " } \n"
100 " } \n"
101 " attr { \n"
102 " key: \"strides\" \n"
103 " value { \n"
104 " list { \n"
105 " i: 1 \n"
106 " i: 1 \n"
107 " i: ");
108 m_Prototext.append(strideString);
109 m_Prototext.append(" \n"
110 " i: 1 \n"
111 " } \n"
112 " } \n"
113 " } \n");
114
115 if (dilation > 0)
116 {
117 m_Prototext.append(" attr { \n"
118 " key: \"dilations\" \n"
119 " value { \n"
120 " list { \n"
121 " i: 1 \n"
122 " i: ");
123 m_Prototext.append(dilationString);
124 m_Prototext.append(" \n"
125 " i: ");
126 m_Prototext.append(dilationString);
127 m_Prototext.append(" \n"
128 " i: 1 \n"
129 " } \n"
130 " } \n"
131 " } \n");
132 }
133 m_Prototext.append(" attr { \n"
134 " key: \"use_cudnn_on_gpu\" \n"
135 " value { \n"
136 " b: false \n"
137 " } \n"
138 " } \n"
139 "} \n");
140
141 // Manual height computation based on stride parameter.
142 BOOST_ASSERT_MSG(stride == 1 || stride==2, "Add support for strides other than 1 or 2.");
143 unsigned int dims[] = {1,2,3,1};
144 if (stride == 2)
145 {
146 dims[1]=3;
147 }
148
149 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims), "graphInput", "potato");
150 }
151};
152
153
154struct Convolution2dSameFixture : Convolution2dFixture
155{
156 Convolution2dSameFixture() : Convolution2dFixture("SAME", 1){}
157};
158BOOST_FIXTURE_TEST_CASE(ParseConv2DSame, Convolution2dSameFixture)
159{
160 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
161}
162
163struct Convolution2dValidFixture : Convolution2dFixture
164{
165 Convolution2dValidFixture() : Convolution2dFixture("VALID", 1){}
166};
167BOOST_FIXTURE_TEST_CASE(ParseConv2DValid, Convolution2dValidFixture)
168{
169 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
170}
171
172
173struct Convolution2dStride2SameFixture : Convolution2dFixture
174{
175 Convolution2dStride2SameFixture() : Convolution2dFixture("SAME", 2){}
176};
177BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Same, Convolution2dStride2SameFixture)
178{
179 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
180}
181
182
183struct Convolution2dStride2ValidFixture : Convolution2dFixture
184{
185 Convolution2dStride2ValidFixture() : Convolution2dFixture("VALID", 2){}
186};
187BOOST_FIXTURE_TEST_CASE(ParseConv2DStride2Valid, Convolution2dStride2ValidFixture)
188{
189 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
190}
191
192
193struct Convolution2dDilation1Fixture : Convolution2dFixture
194{
195 Convolution2dDilation1Fixture() : Convolution2dFixture("SAME", 1, 1){}
196};
197BOOST_FIXTURE_TEST_CASE(ParseConv2DDilation1, Convolution2dDilation1Fixture)
198{
199 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
200}
201
202BOOST_AUTO_TEST_CASE(ParseConv2DDilation2)
203{
204 const char* prototext = ""
205 "node {\n"
206 " name: \"graphInput\"\n"
207 " op: \"Placeholder\"\n"
208 " attr {\n"
209 " key: \"dtype\"\n"
210 " value {\n"
211 " type: DT_FLOAT\n"
212 " }\n"
213 " }\n"
214 " attr {\n"
215 " key: \"shape\"\n"
216 " value {\n"
217 " shape {\n"
218 " }\n"
219 " }\n"
220 " }\n"
221 "}\n"
222 "node {\n"
223 " name: \"Const_1\"\n"
224 " op: \"Const\"\n"
225 " attr {\n"
226 " key: \"dtype\"\n"
227 " value {\n"
228 " type: DT_FLOAT\n"
229 " }\n"
230 " }\n"
231 " attr {\n"
232 " key: \"value\"\n"
233 " value {\n"
234 " tensor {\n"
235 " dtype: DT_FLOAT\n"
236 " tensor_shape {\n"
237 " dim {\n"
238 " size: 1\n"
239 " }\n"
240 " dim {\n"
241 " size: 3\n"
242 " }\n"
243 " dim {\n"
244 " size: 1\n"
245 " }\n"
246 " dim {\n"
247 " size: 1\n"
248 " }\n"
249 " }\n"
250 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
251 " }\n"
252 " }\n"
253 " }\n"
254 "}\n"
255 "node {\n"
256 " name: \"potato\"\n"
257 " op: \"Conv2D\"\n"
258 " input: \"graphInput\"\n"
259 " input: \"Const_1\"\n"
260 " attr {\n"
261 " key: \"T\"\n"
262 " value {\n"
263 " type: DT_FLOAT\n"
264 " }\n"
265 " }\n"
266 " attr {\n"
267 " key: \"data_format\"\n"
268 " value {\n"
269 " s: \"NHWC\"\n"
270 " }\n"
271 " }\n"
272 " attr {\n"
273 " key: \"padding\"\n"
274 " value {\n"
275 " s: \"SAME\"\n"
276 " }\n"
277 " }\n"
278 " attr {\n"
279 " key: \"strides\"\n"
280 " value {\n"
281 " list {\n"
282 " i: 1\n"
283 " i: 1\n"
284 " i: 1\n"
285 " i: 1\n"
286 " }\n"
287 " }\n"
288 " }\n"
289 " attr {\n"
290 " key: \"dilations\"\n"
291 " value {\n"
292 " list {\n"
293 " i: 1\n"
294 " i: 2\n"
295 " i: 2\n"
296 " i: 1\n"
297 " }\n"
298 " }\n"
299 " }\n"
300 " attr {\n"
301 " key: \"use_cudnn_on_gpu\"\n"
302 " value {\n"
303 " b: false\n"
304 " }\n"
305 " }\n"
306 "}\n";
307
308 std::map<std::string, armnn::TensorShape> inputShapes;
309 armnn::TensorShape tensorShape = { 1, 3, 3, 1 };
310 inputShapes["graphInput"] = tensorShape;
311 armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
telsoa01c577f2c2018-08-31 09:22:23 +0100312 BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }),
313 armnn::ParseException);
surmeh01bceff2f2018-03-29 16:29:27 +0100314}
315
316
317BOOST_AUTO_TEST_SUITE_END()