blob: ad95641cd14b65904c26270adfc0e529b3ad47ec [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "armnnTfParser/ITfParser.hpp"
8#include "ParserPrototxtFixture.hpp"
9
10BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12struct ExpandDimsFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
13{
14 ExpandDimsFixture(const std::string& expandDim)
15 {
16 m_Prototext =
17 "node { \n"
18 " name: \"graphInput\" \n"
19 " op: \"Placeholder\" \n"
20 " attr { \n"
21 " key: \"dtype\" \n"
22 " value { \n"
23 " type: DT_FLOAT \n"
24 " } \n"
25 " } \n"
26 " attr { \n"
27 " key: \"shape\" \n"
28 " value { \n"
29 " shape { \n"
30 " } \n"
31 " } \n"
32 " } \n"
33 " } \n"
34 "node { \n"
35 " name: \"ExpandDims\" \n"
36 " op: \"ExpandDims\" \n"
37 " input: \"graphInput\" \n"
38 " attr { \n"
39 " key: \"T\" \n"
40 " value { \n"
41 " type: DT_FLOAT \n"
42 " } \n"
43 " } \n"
44 " attr { \n"
45 " key: \"Tdim\" \n"
46 " value { \n";
47 m_Prototext += "i:" + expandDim;
48 m_Prototext +=
49 " } \n"
50 " } \n"
51 "} \n";
52
53 SetupSingleInputSingleOutput({ 2, 3, 5 }, "graphInput", "ExpandDims");
54 }
55};
56
57struct ExpandZeroDim : ExpandDimsFixture
58{
59 ExpandZeroDim() : ExpandDimsFixture("0") {}
60};
61
62struct ExpandTwoDim : ExpandDimsFixture
63{
64 ExpandTwoDim() : ExpandDimsFixture("2") {}
65};
66
67struct ExpandThreeDim : ExpandDimsFixture
68{
69 ExpandThreeDim() : ExpandDimsFixture("3") {}
70};
71
72struct ExpandMinusOneDim : ExpandDimsFixture
73{
74 ExpandMinusOneDim() : ExpandDimsFixture("-1") {}
75};
76
77struct ExpandMinusThreeDim : ExpandDimsFixture
78{
79 ExpandMinusThreeDim() : ExpandDimsFixture("-3") {}
80};
81
82BOOST_FIXTURE_TEST_CASE(ParseExpandZeroDim, ExpandZeroDim)
83{
84 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
85 armnn::TensorShape({1, 2, 3, 5})));
86}
87
88BOOST_FIXTURE_TEST_CASE(ParseExpandTwoDim, ExpandTwoDim)
89{
90 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
91 armnn::TensorShape({2, 3, 1, 5})));
92}
93
94BOOST_FIXTURE_TEST_CASE(ParseExpandThreeDim, ExpandThreeDim)
95{
96 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
97 armnn::TensorShape({2, 3, 5, 1})));
98}
99
100BOOST_FIXTURE_TEST_CASE(ParseExpandMinusOneDim, ExpandMinusOneDim)
101{
102 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
103 armnn::TensorShape({2, 3, 5, 1})));
104}
105
106BOOST_FIXTURE_TEST_CASE(ParseExpandMinusThreeDim, ExpandMinusThreeDim)
107{
108 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
109 armnn::TensorShape({2, 1, 3, 5})));
110}
111
Jan Eilers1f3b49b2020-09-08 08:57:40 +0100112struct ExpandDimsAsInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
113{
114 ExpandDimsAsInputFixture(const std::string& expandDim,
115 const bool wrongDataType = false,
116 const std::string& numElements = "1")
117 {
118 std::string dataType = (wrongDataType) ? "DT_FLOAT" : "DT_INT32";
119 std::string val = (wrongDataType) ? ("float_val: " + expandDim + ".0") : ("int_val: "+ expandDim);
120
121 m_Prototext = R"(
122 node {
123 name: "a"
124 op: "Placeholder"
125 attr {
126 key: "dtype"
127 value {
128 type: DT_FLOAT
129 }
130 }
131 attr {
132 key: "shape"
133 value {
134 shape {
135 dim {
136 size: 1
137 }
138 dim {
139 size: 4
140 }
141 }
142 }
143 }
144 }
145 node {
146 name: "b"
147 op: "Const"
148 attr {
149 key: "dtype"
150 value {
151 type: )" + dataType + R"(
152 }
153 }
154 attr {
155 key: "value"
156 value {
157 tensor {
158 dtype: )" + dataType + R"(
159 tensor_shape {
160 dim {
161 size: )" + numElements + R"(
162 }
163 }
164 )" + val + R"(
165 }
166 }
167 }
168 }
169 node {
170 name: "ExpandDims"
171 op: "ExpandDims"
172 input: "a"
173 input: "b"
174 attr {
175 key: "T"
176 value {
177 type: DT_FLOAT
178 }
179 }
180 attr {
181 key: "Tdim"
182 value {
183 type: DT_INT32
184 }
185 }
186 }
187 versions {
188 producer: 134
189 })";
190 }
191};
192
193struct ExpandDimAsInput : ExpandDimsAsInputFixture
194{
195 ExpandDimAsInput() : ExpandDimsAsInputFixture("0")
196 {
197 Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" });
198 }
199};
200
201
202BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInput, ExpandDimAsInput)
203{
204 // Axis parameter that describes which axis/dim should be expanded is passed as a second input
205 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("ExpandDims").second.GetShape() ==
206 armnn::TensorShape({1, 1, 4})));
207}
208
209struct ExpandDimAsInputWrongDataType : ExpandDimsAsInputFixture
210{
211 ExpandDimAsInputWrongDataType() : ExpandDimsAsInputFixture("0", true, "1") {}
212};
213
214BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongDataType, ExpandDimAsInputWrongDataType)
215{
216 // Axis parameter that describes which axis/dim should be expanded is passed as a second input
217 // Axis parameter is of wrong data type (float instead of int32)
218 BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
219}
220
221struct ExpandDimAsInputWrongShape : ExpandDimsAsInputFixture
222{
223 ExpandDimAsInputWrongShape() : ExpandDimsAsInputFixture("0", false, "2") {}
224};
225
226BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsInputWrongShape, ExpandDimAsInputWrongShape)
227{
228 // Axis parameter that describes which axis/dim should be expanded is passed as a second input
229 // Axis parameter is of wrong shape
230 BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }), armnn::ParseException);
231}
232
233struct ExpandDimsAsNotConstInputFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
234{
235 ExpandDimsAsNotConstInputFixture()
236 {
237 m_Prototext = R"(
238 node {
239 name: "a"
240 op: "Placeholder"
241 attr {
242 key: "dtype"
243 value {
244 type: DT_FLOAT
245 }
246 }
247 attr {
248 key: "shape"
249 value {
250 shape {
251 dim {
252 size: 1
253 }
254 dim {
255 size: 4
256 }
257 }
258 }
259 }
260 }
261 node {
262 name: "b"
263 op: "Placeholder"
264 attr {
265 key: "dtype"
266 value {
267 type: DT_INT32
268 }
269 }
270 attr {
271 key: "shape"
272 value {
273 shape {
274 dim {
275 size: 1
276 }
277 }
278 }
279 }
280 }
281 node {
282 name: "ExpandDims"
283 op: "ExpandDims"
284 input: "a"
285 input: "b"
286 attr {
287 key: "T"
288 value {
289 type: DT_FLOAT
290 }
291 }
292 attr {
293 key: "Tdim"
294 value {
295 type: DT_INT32
296 }
297 }
298 }
299 versions {
300 producer: 134
301 })";
302 }
303};
304
305BOOST_FIXTURE_TEST_CASE(ParseExpandDimAsNotConstInput, ExpandDimsAsNotConstInputFixture)
306{
307 // Axis parameter that describes which axis/dim should be expanded is passed as a second input.
308 // But is not a constant tensor --> not supported
309 BOOST_REQUIRE_THROW(Setup({{"a", {1,4}} ,{"b",{1,1}}}, { "ExpandDims" }),
310 armnn::ParseException);
311}
312
Conor Kennedyc2130a02018-12-05 11:05:54 +0000313BOOST_AUTO_TEST_SUITE_END()