blob: 7316b9f1ac31832cce593d0a58706b008bc4ac5f [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct ConcatOfConcatsFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
14 explicit ConcatOfConcatsFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1,
15 const armnn::TensorShape& inputShape2, const armnn::TensorShape& inputShape3,
16 unsigned int concatDim)
17 {
18 m_Prototext = R"(
19 node {
20 name: "graphInput0"
21 op: "Placeholder"
22 attr {
23 key: "dtype"
24 value {
25 type: DT_FLOAT
26 }
27 }
28 attr {
29 key: "shape"
30 value {
31 shape {
32 }
33 }
34 }
35 }
36 node {
37 name: "graphInput1"
38 op: "Placeholder"
39 attr {
40 key: "dtype"
41 value {
42 type: DT_FLOAT
43 }
44 }
45 attr {
46 key: "shape"
47 value {
48 shape {
49 }
50 }
51 }
52 }
53 node {
54 name: "graphInput2"
55 op: "Placeholder"
56 attr {
57 key: "dtype"
58 value {
59 type: DT_FLOAT
60 }
61 }
62 attr {
63 key: "shape"
64 value {
65 shape {
66 }
67 }
68 }
69 }
70 node {
71 name: "graphInput3"
72 op: "Placeholder"
73 attr {
74 key: "dtype"
75 value {
76 type: DT_FLOAT
77 }
78 }
79 attr {
80 key: "shape"
81 value {
82 shape {
83 }
84 }
85 }
86 }
87 node {
88 name: "Relu"
89 op: "Relu"
90 input: "graphInput0"
91 attr {
92 key: "T"
93 value {
94 type: DT_FLOAT
95 }
96 }
97 }
98 node {
99 name: "Relu_1"
100 op: "Relu"
101 input: "graphInput1"
102 attr {
103 key: "T"
104 value {
105 type: DT_FLOAT
106 }
107 }
108 }
109 node {
110 name: "Relu_2"
111 op: "Relu"
112 input: "graphInput2"
113 attr {
114 key: "T"
115 value {
116 type: DT_FLOAT
117 }
118 }
119 }
120 node {
121 name: "Relu_3"
122 op: "Relu"
123 input: "graphInput3"
124 attr {
125 key: "T"
126 value {
127 type: DT_FLOAT
128 }
129 }
130 }
131 node {
132 name: "concat/axis"
133 op: "Const"
134 attr {
135 key: "dtype"
136 value {
137 type: DT_INT32
138 }
139 }
140 attr {
141 key: "value"
142 value {
143 tensor {
144 dtype: DT_INT32
145 tensor_shape {
146 }
147 int_val: )";
148 m_Prototext += std::to_string(concatDim);
149 m_Prototext += R"(
150 }
151 }
152 }
153 }
154 node {
155 name: "concat"
156 op: "ConcatV2"
157 input: "Relu"
158 input: "Relu_1"
159 input: "concat/axis"
160 attr {
161 key: "N"
162 value {
163 i: 2
164 }
165 }
166 attr {
167 key: "T"
168 value {
169 type: DT_FLOAT
170 }
171 }
172 attr {
173 key: "Tidx"
174 value {
175 type: DT_INT32
176 }
177 }
178 }
179 node {
180 name: "concat_1/axis"
181 op: "Const"
182 attr {
183 key: "dtype"
184 value {
185 type: DT_INT32
186 }
187 }
188 attr {
189 key: "value"
190 value {
191 tensor {
192 dtype: DT_INT32
193 tensor_shape {
194 }
195 int_val: )";
196 m_Prototext += std::to_string(concatDim);
197 m_Prototext += R"(
198 }
199 }
200 }
201 }
202 node {
203 name: "concat_1"
204 op: "ConcatV2"
205 input: "Relu_2"
206 input: "Relu_3"
207 input: "concat_1/axis"
208 attr {
209 key: "N"
210 value {
211 i: 2
212 }
213 }
214 attr {
215 key: "T"
216 value {
217 type: DT_FLOAT
218 }
219 }
220 attr {
221 key: "Tidx"
222 value {
223 type: DT_INT32
224 }
225 }
226 }
227 node {
228 name: "concat_2/axis"
229 op: "Const"
230 attr {
231 key: "dtype"
232 value {
233 type: DT_INT32
234 }
235 }
236 attr {
237 key: "value"
238 value {
239 tensor {
240 dtype: DT_INT32
241 tensor_shape {
242 }
243 int_val: )";
244 m_Prototext += std::to_string(concatDim);
245 m_Prototext += R"(
246 }
247 }
248 }
249 }
250 node {
251 name: "concat_2"
252 op: "ConcatV2"
253 input: "concat"
254 input: "concat_1"
255 input: "concat_2/axis"
256 attr {
257 key: "N"
258 value {
259 i: 2
260 }
261 }
262 attr {
263 key: "T"
264 value {
265 type: DT_FLOAT
266 }
267 }
268 attr {
269 key: "Tidx"
270 value {
271 type: DT_INT32
272 }
273 }
274 }
275 )";
276
277 Setup({{ "graphInput0", inputShape0 },
278 { "graphInput1", inputShape1 },
279 { "graphInput2", inputShape2 },
280 { "graphInput3", inputShape3}}, {"concat_2"});
281 }
282};
283
284struct ConcatOfConcatsFixtureNCHW : ConcatOfConcatsFixture
285{
286 ConcatOfConcatsFixtureNCHW() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
287 { 1, 1, 2, 2 }, 1 ) {}
288};
289
290struct ConcatOfConcatsFixtureNHWC : ConcatOfConcatsFixture
291{
292 ConcatOfConcatsFixtureNHWC() : ConcatOfConcatsFixture({ 1, 1, 2, 2 }, { 1, 1, 2, 2 }, { 1, 1, 2, 2 },
293 { 1, 1, 2, 2 }, 3 ) {}
294};
295
296BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNCHW, ConcatOfConcatsFixtureNCHW)
297{
298 RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
299 {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
300 {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
301 {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
302 {{"concat_2", { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
303 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0 }}});
304}
305
306BOOST_FIXTURE_TEST_CASE(ParseConcatOfConcatsNHWC, ConcatOfConcatsFixtureNHWC)
307{
308 RunTest<4>({{"graphInput0", {0.0, 1.0, 2.0, 3.0}},
309 {"graphInput1", {4.0, 5.0, 6.0, 7.0}},
310 {"graphInput2", {8.0, 9.0, 10.0, 11.0}},
311 {"graphInput3", {12.0, 13.0, 14.0, 15.0}}},
312 {{"concat_2", { 0.0, 1.0, 4.0, 5.0, 8.0, 9.0, 12.0, 13.0,
313 2.0, 3.0, 6.0, 7.0, 10.0, 11.0, 14.0, 15.0 }}});
314}
315
316BOOST_AUTO_TEST_SUITE_END()