blob: 89faf756791c65d02489c952d450786635265dbe [file] [log] [blame]
Georgios Pinitas5e90aab2020-02-14 14:46:51 +00001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "armnnTfParser/ITfParser.hpp"
7
8#include "ParserPrototxtFixture.hpp"
9#include <PrototxtConversions.hpp>
10
11#include <boost/test/unit_test.hpp>
12
13BOOST_AUTO_TEST_SUITE(TensorflowParser)
14
15namespace {
16// helper for setting the dimensions in prototxt
17void shapeHelper(const armnn::TensorShape& shape, std::string& text){
18 for(u_int i = 0; i < shape.GetNumDimensions(); ++i) {
19 text.append(R"(dim {
20 size: )");
21 text.append(std::to_string(shape[i]));
22 text.append(R"(
23 })");
24 }
25}
26
27// helper for converting from integer to octal representation
28void octalHelper(const std::vector<int>& content, std::string& text){
29 for (unsigned int i = 0; i < content.size(); ++i)
30 {
31 text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(content[i])));
32 }
33}
34} // namespace
35
36struct StridedSliceFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
37{
38 StridedSliceFixture(const armnn::TensorShape& inputShape,
39 const std::vector<int>& beginData,
40 const std::vector<int>& endData,
41 const std::vector<int>& stridesData,
42 int beginMask = 0,
43 int endMask = 0,
44 int ellipsisMask = 0,
45 int newAxisMask = 0,
46 int shrinkAxisMask = 0)
47 {
48 m_Prototext = R"(
49 node {
50 name: "input"
51 op: "Placeholder"
52 attr {
53 key: "dtype"
54 value {
55 type: DT_FLOAT
56 }
57 }
58 attr {
59 key: "shape"
60 value {
61 shape {)";
62 shapeHelper(inputShape, m_Prototext);
63 m_Prototext.append(R"(
64 }
65 }
66 }
67 }
68 node {
69 name: "begin"
70 op: "Const"
71 attr {
72 key: "dtype"
73 value {
74 type: DT_INT32
75 }
76 }
77 attr {
78 key: "value"
79 value {
80 tensor {
81 dtype: DT_INT32
82 tensor_shape {
83 dim {
84 size: )");
85 m_Prototext += std::to_string(beginData.size());
86 m_Prototext.append(R"(
87 }
88 }
89 tensor_content: ")");
90 octalHelper(beginData, m_Prototext);
91 m_Prototext.append(R"("
92 }
93 }
94 }
95 }
96 node {
97 name: "end"
98 op: "Const"
99 attr {
100 key: "dtype"
101 value {
102 type: DT_INT32
103 }
104 }
105 attr {
106 key: "value"
107 value {
108 tensor {
109 dtype: DT_INT32
110 tensor_shape {
111 dim {
112 size: )");
113 m_Prototext += std::to_string(endData.size());
114 m_Prototext.append(R"(
115 }
116 }
117 tensor_content: ")");
118 octalHelper(endData, m_Prototext);
119 m_Prototext.append(R"("
120 }
121 }
122 }
123 }
124 node {
125 name: "strides"
126 op: "Const"
127 attr {
128 key: "dtype"
129 value {
130 type: DT_INT32
131 }
132 }
133 attr {
134 key: "value"
135 value {
136 tensor {
137 dtype: DT_INT32
138 tensor_shape {
139 dim {
140 size: )");
141 m_Prototext += std::to_string(stridesData.size());
142 m_Prototext.append(R"(
143 }
144 }
145 tensor_content: ")");
146 octalHelper(stridesData, m_Prototext);
147 m_Prototext.append(R"("
148 }
149 }
150 }
151 }
152 node {
153 name: "output"
154 op: "StridedSlice"
155 input: "input"
156 input: "begin"
157 input: "end"
158 input: "strides"
159 attr {
160 key: "begin_mask"
161 value {
162 i: )");
163 m_Prototext += std::to_string(beginMask);
164 m_Prototext.append(R"(
165 }
166 }
167 attr {
168 key: "end_mask"
169 value {
170 i: )");
171 m_Prototext += std::to_string(endMask);
172 m_Prototext.append(R"(
173 }
174 }
175 attr {
176 key: "ellipsis_mask"
177 value {
178 i: )");
179 m_Prototext += std::to_string(ellipsisMask);
180 m_Prototext.append(R"(
181 }
182 }
183 attr {
184 key: "new_axis_mask"
185 value {
186 i: )");
187 m_Prototext += std::to_string(newAxisMask);
188 m_Prototext.append(R"(
189 }
190 }
191 attr {
192 key: "shrink_axis_mask"
193 value {
194 i: )");
195 m_Prototext += std::to_string(shrinkAxisMask);
196 m_Prototext.append(R"(
197 }
198 }
199 })");
200
201 Setup({ { "input", inputShape } }, { "output" });
202 }
203};
204
205struct StridedSlice4DFixture : StridedSliceFixture
206{
207 StridedSlice4DFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
208 { 1, 0, 0, 0 }, // beginData
209 { 2, 2, 3, 1 }, // endData
210 { 1, 1, 1, 1 } // stridesData
211 ) {}
212};
213
214BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
215{
216 RunTest<4>(
217 {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
218 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
219 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
220 {{"output", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
221}
222
223struct StridedSlice4DReverseFixture : StridedSliceFixture
224{
225
226 StridedSlice4DReverseFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
227 { 1, -1, 0, 0 }, // beginData
228 { 2, -3, 3, 1 }, // endData
229 { 1, -1, 1, 1 } // stridesData
230 ) {}
231};
232
233BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
234{
235 RunTest<4>(
236 {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
237 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
238 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
239 {{"output", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
240}
241
242struct StridedSliceSimpleStrideFixture : StridedSliceFixture
243{
244 StridedSliceSimpleStrideFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
245 { 0, 0, 0, 0 }, // beginData
246 { 3, 2, 3, 1 }, // endData
247 { 2, 2, 2, 1 } // stridesData
248 ) {}
249};
250
251BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
252{
253 RunTest<4>(
254 {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
255 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
256 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
257 {{"output", { 1.0f, 1.0f,
258 5.0f, 5.0f }}});
259}
260
261struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
262{
263 StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture({ 3, 2, 3, 1 }, // inputShape
264 { 1, 1, 1, 1 }, // beginData
265 { 1, 1, 1, 1 }, // endData
266 { 1, 1, 1, 1 }, // stridesData
267 (1 << 4) - 1, // beginMask
268 (1 << 4) - 1 // endMask
269 ) {}
270};
271
272BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
273{
274 RunTest<4>(
275 {{"input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
276 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
277 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
278 {{"output", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
279 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
280 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
281}
282
283BOOST_AUTO_TEST_SUITE_END()