blob: aead1fe96510c9addb5385365e40b202d3394b97 [file] [log] [blame]
surmeh01bceff2f2018-03-29 16:29:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
surmeh01bceff2f2018-03-29 16:29:27 +01004//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
Matteo Martincigh46315822018-11-28 16:22:36 +00009
10#include <array>
surmeh01bceff2f2018-03-29 16:29:27 +010011#include <string>
12#include <iostream>
13
14BOOST_AUTO_TEST_SUITE(TensorflowParser)
15
telsoa01c577f2c2018-08-31 09:22:23 +010016struct Convolution2dFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010017{
Matteo Martincigh46315822018-11-28 16:22:36 +000018 explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType)
19 : Convolution2dFixture(dataLayout, paddingType, 1)
surmeh01bceff2f2018-03-29 16:29:27 +010020 {}
21
telsoa01c577f2c2018-08-31 09:22:23 +010022 // Dilation: 0 - dilations attribute is not included;
23 // Dilation: >0 - dilations attribute set to [1,v,v,1], where v is the value of the dilation arg
Matteo Martincigh46315822018-11-28 16:22:36 +000024 explicit Convolution2dFixture(const std::string& dataLayout, const std::string& paddingType,
25 int stride, int dilation = 0)
surmeh01bceff2f2018-03-29 16:29:27 +010026 {
Matteo Martincigh46315822018-11-28 16:22:36 +000027 std::string strideString (" i: 1 \n"
28 " i: 1 \n");
29 if (dataLayout == "NHWC")
30 {
31 strideString.append(" i: " + std::to_string(stride) + " \n"
32 " i: 1 \n");
33 }
34 else // dataLayout == "NCHW"
35 {
36 strideString.append(" i: 1 \n"
37 " i: " + std::to_string(stride) + " \n");
38 }
39
surmeh01bceff2f2018-03-29 16:29:27 +010040 std::string dilationString = std::to_string(dilation);
41 m_Prototext = "node { \n"
42 " name: \"graphInput\" \n"
43 " op: \"Placeholder\" \n"
44 " attr { \n"
45 " key: \"dtype\" \n"
46 " value { \n"
47 " type: DT_FLOAT \n"
48 " } \n"
49 " } \n"
50 " attr { \n"
51 " key: \"shape\" \n"
52 " value { \n"
53 " shape { \n"
54 " } \n"
55 " } \n"
56 " } \n"
57 " } \n"
58 " node { \n"
59 " name: \"Const_1\" \n"
60 " op: \"Const\" \n"
61 " attr { \n"
62 " key: \"dtype\" \n"
63 " value { \n"
64 " type: DT_FLOAT \n"
65 " } \n"
66 " } \n"
67 " attr { \n"
68 " key: \"value\" \n"
69 " value { \n"
70 " tensor { \n"
71 " dtype: DT_FLOAT \n"
72 " tensor_shape { \n"
73 " dim { \n"
74 " size: 1 \n"
75 " } \n"
76 " dim { \n"
77 " size: 3 \n"
78 " } \n"
79 " dim { \n"
80 " size: 1 \n"
81 " } \n"
82 " dim { \n"
83 " size: 1 \n"
84 " } \n"
85 " } \n"
86 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\" \n"
87 " } \n"
88 " } \n"
89 " } \n"
90 "} \n"
91 "node { \n"
92 " name: \"potato\" \n"
93 " op: \"Conv2D\" \n"
94 " input: \"graphInput\" \n"
95 " input: \"Const_1\" \n"
96 " attr { \n"
97 " key: \"T\" \n"
98 " value { \n"
99 " type: DT_FLOAT \n"
100 " } \n"
101 " } \n"
102 " attr { \n"
103 " key: \"data_format\" \n"
104 " value { \n"
surmeh01bceff2f2018-03-29 16:29:27 +0100105 " s: \"";
Matteo Martincigh46315822018-11-28 16:22:36 +0000106 m_Prototext.append(dataLayout);
107 m_Prototext.append("\"\n"
108 " } \n"
109 " } \n"
110 " attr { \n"
111 " key: \"padding\" \n"
112 " value { \n"
113 " s: \"");
surmeh01bceff2f2018-03-29 16:29:27 +0100114 m_Prototext.append(paddingType);
115 m_Prototext.append("\"\n"
116 " } \n"
117 " } \n"
118 " attr { \n"
119 " key: \"strides\" \n"
120 " value { \n"
Matteo Martincigh46315822018-11-28 16:22:36 +0000121 " list { \n");
surmeh01bceff2f2018-03-29 16:29:27 +0100122 m_Prototext.append(strideString);
Matteo Martincigh46315822018-11-28 16:22:36 +0000123
124 m_Prototext.append(" } \n"
surmeh01bceff2f2018-03-29 16:29:27 +0100125 " } \n"
126 " } \n");
127
128 if (dilation > 0)
129 {
130 m_Prototext.append(" attr { \n"
131 " key: \"dilations\" \n"
132 " value { \n"
133 " list { \n"
134 " i: 1 \n"
135 " i: ");
136 m_Prototext.append(dilationString);
137 m_Prototext.append(" \n"
138 " i: ");
139 m_Prototext.append(dilationString);
140 m_Prototext.append(" \n"
141 " i: 1 \n"
142 " } \n"
143 " } \n"
144 " } \n");
145 }
146 m_Prototext.append(" attr { \n"
147 " key: \"use_cudnn_on_gpu\" \n"
148 " value { \n"
149 " b: false \n"
150 " } \n"
151 " } \n"
152 "} \n");
153
154 // Manual height computation based on stride parameter.
Matteo Martincigh46315822018-11-28 16:22:36 +0000155 BOOST_ASSERT_MSG(stride == 1 || stride == 2, "Add support for strides other than 1 or 2.");
156 std::array<unsigned int, 4> dims;
157 if (dataLayout == "NHWC")
surmeh01bceff2f2018-03-29 16:29:27 +0100158 {
Matteo Martincigh46315822018-11-28 16:22:36 +0000159 dims = { 1u, (stride == 2 ? 3u : 2u), 3u, 1u };
160 }
161 else // dataLayout == "NCHW"
162 {
163 dims = { 1u, 1u, (stride == 2 ? 3u : 2u), 3u };
surmeh01bceff2f2018-03-29 16:29:27 +0100164 }
165
Matteo Martincigh46315822018-11-28 16:22:36 +0000166 SetupSingleInputSingleOutput(armnn::TensorShape(4, dims.data()), "graphInput", "potato");
surmeh01bceff2f2018-03-29 16:29:27 +0100167 }
168};
169
170
Matteo Martincigh46315822018-11-28 16:22:36 +0000171struct Convolution2dNhwcSameFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100172{
Matteo Martincigh46315822018-11-28 16:22:36 +0000173 Convolution2dNhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 1){}
surmeh01bceff2f2018-03-29 16:29:27 +0100174};
Matteo Martincigh46315822018-11-28 16:22:36 +0000175BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcSame, Convolution2dNhwcSameFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100176{
177 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
178}
179
Matteo Martincigh46315822018-11-28 16:22:36 +0000180struct Convolution2dNchwSameFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100181{
Matteo Martincigh46315822018-11-28 16:22:36 +0000182 Convolution2dNchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 1){}
surmeh01bceff2f2018-03-29 16:29:27 +0100183};
Matteo Martincigh46315822018-11-28 16:22:36 +0000184BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwSame, Convolution2dNchwSameFixture)
185{
186 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
187}
188
189
190struct Convolution2dNhwcValidFixture : Convolution2dFixture
191{
192 Convolution2dNhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 1){}
193};
194BOOST_FIXTURE_TEST_CASE(ParseConv2dNhwcValid, Convolution2dNhwcValidFixture)
195{
196 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
197}
198
199struct Convolution2dNchwValidFixture : Convolution2dFixture
200{
201 Convolution2dNchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 1){}
202};
203BOOST_FIXTURE_TEST_CASE(ParseConv2dNchwValid, Convolution2dNchwValidFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100204{
205 RunTest<4>({1, 2, 3, 4, 5, 6}, {4, 10});
206}
207
208
Matteo Martincigh46315822018-11-28 16:22:36 +0000209struct Convolution2dStride2NhwcSameFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100210{
Matteo Martincigh46315822018-11-28 16:22:36 +0000211 Convolution2dStride2NhwcSameFixture() : Convolution2dFixture("NHWC", "SAME", 2){}
surmeh01bceff2f2018-03-29 16:29:27 +0100212};
Matteo Martincigh46315822018-11-28 16:22:36 +0000213BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcSame, Convolution2dStride2NhwcSameFixture)
214{
215 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
216}
217
218struct Convolution2dStride2NchwSameFixture : Convolution2dFixture
219{
220 Convolution2dStride2NchwSameFixture() : Convolution2dFixture("NCHW", "SAME", 2){}
221};
222BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwSame, Convolution2dStride2NchwSameFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100223{
224 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {2, 4, 6.5, 8.5, 11, 13});
225}
226
227
Matteo Martincigh46315822018-11-28 16:22:36 +0000228struct Convolution2dStride2NhwcValidFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100229{
Matteo Martincigh46315822018-11-28 16:22:36 +0000230 Convolution2dStride2NhwcValidFixture() : Convolution2dFixture("NHWC", "VALID", 2){}
surmeh01bceff2f2018-03-29 16:29:27 +0100231};
Matteo Martincigh46315822018-11-28 16:22:36 +0000232BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NhwcValid, Convolution2dStride2NhwcValidFixture)
233{
234 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
235}
236
237struct Convolution2dStride2NchwValidFixture : Convolution2dFixture
238{
239 Convolution2dStride2NchwValidFixture() : Convolution2dFixture("NCHW", "VALID", 2){}
240};
241BOOST_FIXTURE_TEST_CASE(ParseConv2dStride2NchwValid, Convolution2dStride2NchwValidFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100242{
243 RunTest<4>({1, 2, 3, 4, 5, 6, 7, 8, 9}, {4, 10, 16});
244}
245
246
Matteo Martincigh46315822018-11-28 16:22:36 +0000247struct Convolution2dDilation1NhwcFixture : Convolution2dFixture
surmeh01bceff2f2018-03-29 16:29:27 +0100248{
Matteo Martincigh46315822018-11-28 16:22:36 +0000249 Convolution2dDilation1NhwcFixture() : Convolution2dFixture("NHWC", "SAME", 1, 1){}
surmeh01bceff2f2018-03-29 16:29:27 +0100250};
Matteo Martincigh46315822018-11-28 16:22:36 +0000251BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nhwc, Convolution2dDilation1NhwcFixture)
surmeh01bceff2f2018-03-29 16:29:27 +0100252{
253 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
254}
255
Matteo Martincigh46315822018-11-28 16:22:36 +0000256struct Convolution2dDilation1NchwFixture : Convolution2dFixture
257{
258 Convolution2dDilation1NchwFixture() : Convolution2dFixture("NCHW", "SAME", 1, 1){}
259};
260BOOST_FIXTURE_TEST_CASE(ParseConv2dDilation1Nchw, Convolution2dDilation1NchwFixture)
261{
262 RunTest<4>({1, 2, 3, 4, 5, 6}, {2, 4, 4, 6.5f, 10 , 8.5f});
263}
264
265
266BOOST_AUTO_TEST_CASE(ParseConv2dDilation2)
surmeh01bceff2f2018-03-29 16:29:27 +0100267{
268 const char* prototext = ""
269 "node {\n"
270 " name: \"graphInput\"\n"
271 " op: \"Placeholder\"\n"
272 " attr {\n"
273 " key: \"dtype\"\n"
274 " value {\n"
275 " type: DT_FLOAT\n"
276 " }\n"
277 " }\n"
278 " attr {\n"
279 " key: \"shape\"\n"
280 " value {\n"
281 " shape {\n"
282 " }\n"
283 " }\n"
284 " }\n"
285 "}\n"
286 "node {\n"
287 " name: \"Const_1\"\n"
288 " op: \"Const\"\n"
289 " attr {\n"
290 " key: \"dtype\"\n"
291 " value {\n"
292 " type: DT_FLOAT\n"
293 " }\n"
294 " }\n"
295 " attr {\n"
296 " key: \"value\"\n"
297 " value {\n"
298 " tensor {\n"
299 " dtype: DT_FLOAT\n"
300 " tensor_shape {\n"
301 " dim {\n"
302 " size: 1\n"
303 " }\n"
304 " dim {\n"
305 " size: 3\n"
306 " }\n"
307 " dim {\n"
308 " size: 1\n"
309 " }\n"
310 " dim {\n"
311 " size: 1\n"
312 " }\n"
313 " }\n"
314 " tensor_content: \"\\000\\000\\000?\\000\\000\\200?\\000\\000\\000?\"\n"
315 " }\n"
316 " }\n"
317 " }\n"
318 "}\n"
319 "node {\n"
320 " name: \"potato\"\n"
321 " op: \"Conv2D\"\n"
322 " input: \"graphInput\"\n"
323 " input: \"Const_1\"\n"
324 " attr {\n"
325 " key: \"T\"\n"
326 " value {\n"
327 " type: DT_FLOAT\n"
328 " }\n"
329 " }\n"
330 " attr {\n"
331 " key: \"data_format\"\n"
332 " value {\n"
333 " s: \"NHWC\"\n"
334 " }\n"
335 " }\n"
336 " attr {\n"
337 " key: \"padding\"\n"
338 " value {\n"
339 " s: \"SAME\"\n"
340 " }\n"
341 " }\n"
342 " attr {\n"
343 " key: \"strides\"\n"
344 " value {\n"
345 " list {\n"
346 " i: 1\n"
347 " i: 1\n"
348 " i: 1\n"
349 " i: 1\n"
350 " }\n"
351 " }\n"
352 " }\n"
353 " attr {\n"
354 " key: \"dilations\"\n"
355 " value {\n"
356 " list {\n"
357 " i: 1\n"
358 " i: 2\n"
359 " i: 2\n"
360 " i: 1\n"
361 " }\n"
362 " }\n"
363 " }\n"
364 " attr {\n"
365 " key: \"use_cudnn_on_gpu\"\n"
366 " value {\n"
367 " b: false\n"
368 " }\n"
369 " }\n"
370 "}\n";
371
372 std::map<std::string, armnn::TensorShape> inputShapes;
373 armnn::TensorShape tensorShape = { 1, 3, 3, 1 };
374 inputShapes["graphInput"] = tensorShape;
375 armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
Matteo Martincigh46315822018-11-28 16:22:36 +0000376 BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, inputShapes, { "potato" }), armnn::ParseException);
surmeh01bceff2f2018-03-29 16:29:27 +0100377}
378
379
380BOOST_AUTO_TEST_SUITE_END()