blob: 5c06d8c876c582d0f2562d636eb11d75328756a3 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>

#include "armnnTfParser/ITfParser.hpp"

#include "ParserPrototxtFixture.hpp"

#include <locale>
#include <sstream>
#include <string>

12BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
// Tests that a Const node in TensorFlow can be converted to a ConstLayer in armnn (as opposed to most
// Const nodes, which are used as weight inputs for convolutions etc. and are therefore not converted to
// armnn ConstLayers).
telsoa01c577f2c2018-08-31 09:22:23 +010017struct ConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010018{
19 ConstantFixture()
20 {
telsoa01c577f2c2018-08-31 09:22:23 +010021 // Input = tf.placeholder(tf.float32, name = "input")
22 // Const = tf.constant([17], tf.float32, [1])
23 // Output = tf.add(input, const, name = "output")
surmeh01bceff2f2018-03-29 16:29:27 +010024 m_Prototext =
25 R"(
26node {
27 name: "input"
28 op: "Placeholder"
29 attr {
30 key: "dtype"
31 value {
32 type: DT_FLOAT
33 }
34 }
35 attr {
36 key: "shape"
37 value {
38 shape {
39 unknown_rank: true
40 }
41 }
42 }
43}
44node {
45 name: "Const"
46 op: "Const"
47 attr {
48 key: "dtype"
49 value {
50 type: DT_FLOAT
51 }
52 }
53 attr {
54 key: "value"
55 value {
56 tensor {
57 dtype: DT_FLOAT
58 tensor_shape {
59 dim {
60 size: 1
61 }
62 }
63 float_val: 17.0
64 }
65 }
66 }
67}
68node {
69 name: "output"
70 op: "Add"
71 input: "input"
72 input: "Const"
73 attr {
74 key: "T"
75 value {
76 type: DT_FLOAT
77 }
78 }
79}
80 )";
81 SetupSingleInputSingleOutput({ 1 }, "input", "output");
82 }
83};
84
85BOOST_FIXTURE_TEST_CASE(Constant, ConstantFixture)
86{
87 RunTest<1>({1}, {18});
88}
89
90
// Tests that a single Const node in TensorFlow can be used twice by a dependent node. This should result in only
// a single armnn ConstLayer being created.
telsoa01c577f2c2018-08-31 09:22:23 +010093struct ConstantReusedFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +010094{
95 ConstantReusedFixture()
96 {
telsoa01c577f2c2018-08-31 09:22:23 +010097 // Const = tf.constant([17], tf.float32, [1])
98 // Output = tf.add(const, const, name = "output")
surmeh01bceff2f2018-03-29 16:29:27 +010099 m_Prototext =
100 R"(
101node {
102 name: "Const"
103 op: "Const"
104 attr {
105 key: "dtype"
106 value {
107 type: DT_FLOAT
108 }
109 }
110 attr {
111 key: "value"
112 value {
113 tensor {
114 dtype: DT_FLOAT
115 tensor_shape {
116 dim {
117 size: 1
118 }
119 }
120 float_val: 17.0
121 }
122 }
123 }
124}
125node {
126 name: "output"
127 op: "Add"
128 input: "Const"
129 input: "Const"
130 attr {
131 key: "T"
132 value {
133 type: DT_FLOAT
134 }
135 }
136}
137 )";
138 Setup({}, { "output" });
139 }
140};
141
142BOOST_FIXTURE_TEST_CASE(ConstantReused, ConstantReusedFixture)
143{
144 RunTest<1>({}, { { "output", { 34 } } });
145}
146
147template <int ListSize>
telsoa01c577f2c2018-08-31 09:22:23 +0100148struct ConstantValueListFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +0100149{
150 ConstantValueListFixture()
151 {
152 m_Prototext =
153 R"(
154node {
155 name: "output"
156 op: "Const"
157 attr {
158 key: "dtype"
159 value {
160 type: DT_FLOAT
161 }
162 }
163 attr {
164 key: "value"
165 value {
166 tensor {
167 dtype: DT_FLOAT
168 tensor_shape {
169 dim {
170 size: 2
171 }
172 dim {
173 size: 3
174 }
175 })";
176
177 double value = 0.75;
178 for (int i = 0; i < ListSize; i++, value += 0.25)
179 {
180 m_Prototext += std::string("float_val : ") + std::to_string(value) + "\n";
181 }
182
telsoa01c577f2c2018-08-31 09:22:23 +0100183 m_Prototext +=
surmeh01bceff2f2018-03-29 16:29:27 +0100184 R"(
185 }
186 }
187 }
188}
189 )";
190 Setup({}, { "output" });
191 }
192};
193
194using ConstantSingleValueListFixture = ConstantValueListFixture<1>;
195using ConstantMultipleValueListFixture = ConstantValueListFixture<4>;
196using ConstantMaxValueListFixture = ConstantValueListFixture<6>;
197
198BOOST_FIXTURE_TEST_CASE(ConstantSingleValueList, ConstantSingleValueListFixture)
199{
200 RunTest<2>({}, { { "output", { 0.75f, 0.75f, 0.75f, 0.75f, 0.75f, 0.75f } } });
201}
202BOOST_FIXTURE_TEST_CASE(ConstantMultipleValueList, ConstantMultipleValueListFixture)
203{
204 RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.5f, 1.5f, 1.5f } } });
205}
206BOOST_FIXTURE_TEST_CASE(ConstantMaxValueList, ConstantMaxValueListFixture)
207{
208 RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.50f, 1.75f, 2.f } } });
209}
210
211template <bool WithShape, bool WithContent, bool WithValueList>
telsoa01c577f2c2018-08-31 09:22:23 +0100212struct ConstantCreateFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
surmeh01bceff2f2018-03-29 16:29:27 +0100213{
214 ConstantCreateFixture()
215 {
216 m_Prototext =
217 R"(
218node {
219 name: "output"
220 op: "Const"
221 attr {
222 key: "dtype"
223 value {
224 type: DT_FLOAT
225 }
226 }
227 attr {
228 key: "value"
229 value {
230 tensor {
231 dtype: DT_FLOAT
232 )";
233
234 if (WithShape)
235 {
236 m_Prototext +=
237 R"(
238tensor_shape {
239 dim {
240 size: 2
241 }
242 dim {
243 size: 2
244 }
245}
246 )";
247 }
248 else
249 {
250 m_Prototext +=
251 R"(
252tensor_shape {
253}
254 )";
255 }
256
257 if (WithContent)
258 {
259 m_Prototext +=
260 R"(
261tensor_content: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
262 )";
263 }
264
265 if (WithValueList)
266 {
267 m_Prototext +=
268 R"(
269float_val: 1.0
270float_val: 1.0
271float_val: 1.0
272float_val: 1.0
273float_val: 1.0
274 )";
275 }
276
277 m_Prototext +=
278 R"(
279 }
280 }
281 }
282}
283 )";
284 }
285};
286
287using ConstantCreateNoValueListFixture = ConstantCreateFixture<true, false, true>;
288using ConstantCreateNoValueList2Fixture = ConstantCreateFixture<true, false, false>;
289using ConstantCreateNoContentFixture = ConstantCreateFixture<true, true, false>;
290using ConstantCreateNoContent2Fixture = ConstantCreateFixture<true, false, false>;
291using ConstantCreateNoShapeFixture = ConstantCreateFixture<false, false, false>;
292using ConstantCreateNoShape2Fixture = ConstantCreateFixture<false, true, false>;
293using ConstantCreateNoShape3Fixture = ConstantCreateFixture<false, false, true>;
294
295BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList, ConstantCreateNoValueListFixture)
296{
297 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
298}
299BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList2, ConstantCreateNoValueList2Fixture)
300{
301 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
302}
303BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidContent, ConstantCreateNoContentFixture)
304{
305 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
306}
307BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidShape, ConstantCreateNoShapeFixture)
308{
309 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
310}
311BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape2, ConstantCreateNoShape2Fixture)
312{
313 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
314}
315BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape3, ConstantCreateNoShape3Fixture)
316{
317 Setup({}, { "output" });
318 RunTest<1>({}, { { "output", { 1.f, 1.f, 1.f, 1.f, 1.f } } });
319}
320
321BOOST_AUTO_TEST_SUITE_END()