blob: 91427a6420c20b9eb2542917c4f2869524c89bb6 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10#include <string>
Bruno Goncalves451d95b2019-02-12 22:59:22 -020011
12BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
13
14struct StridedSliceFixture : public ParserFlatbuffersFixture
15{
16 explicit StridedSliceFixture(const std::string & inputShape,
17 const std::string & outputShape,
18 const std::string & beginData,
19 const std::string & endData,
20 const std::string & stridesData,
21 int beginMask = 0,
22 int endMask = 0)
23 {
24 m_JsonString = R"(
25 {
26 "version": 3,
27 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ],
28 "subgraphs": [ {
29 "tensors": [
30 {
31 "shape": )" + inputShape + R"(,
32 "type": "FLOAT32",
33 "buffer": 0,
34 "name": "inputTensor",
35 "quantization": {
36 "min": [ 0.0 ],
37 "max": [ 255.0 ],
38 "scale": [ 1.0 ],
39 "zero_point": [ 0 ],
40 }
41 },
42 {
43 "shape": [ 4 ],
44 "type": "INT32",
45 "buffer": 1,
46 "name": "beginTensor",
47 "quantization": {
48 }
49 },
50 {
51 "shape": [ 4 ],
52 "type": "INT32",
53 "buffer": 2,
54 "name": "endTensor",
55 "quantization": {
56 }
57 },
58 {
59 "shape": [ 4 ],
60 "type": "INT32",
61 "buffer": 3,
62 "name": "stridesTensor",
63 "quantization": {
64 }
65 },
66 {
67 "shape": )" + outputShape + R"( ,
68 "type": "FLOAT32",
69 "buffer": 4,
70 "name": "outputTensor",
71 "quantization": {
72 "min": [ 0.0 ],
73 "max": [ 255.0 ],
74 "scale": [ 1.0 ],
75 "zero_point": [ 0 ],
76 }
77 }
78 ],
79 "inputs": [ 0, 1, 2, 3 ],
80 "outputs": [ 4 ],
81 "operators": [
82 {
83 "opcode_index": 0,
84 "inputs": [ 0, 1, 2, 3 ],
85 "outputs": [ 4 ],
86 "builtin_options_type": "StridedSliceOptions",
87 "builtin_options": {
88 "begin_mask": )" + std::to_string(beginMask) + R"(,
89 "end_mask": )" + std::to_string(endMask) + R"(
90 },
91 "custom_options_format": "FLEXBUFFERS"
92 }
93 ],
94 } ],
95 "buffers" : [
96 { },
97 { "data": )" + beginData + R"(, },
98 { "data": )" + endData + R"(, },
99 { "data": )" + stridesData + R"(, },
100 { }
101 ]
102 }
103 )";
104 Setup();
105 }
106};
107
108struct StridedSlice4DFixture : StridedSliceFixture
109{
110 StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
111 "[ 1, 2, 3, 1 ]", // outputShape
112 "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
113 "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
114 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]" // stridesData
115 ) {}
116};
117
118BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
119{
120 RunTest<4, armnn::DataType::Float32>(
121 0,
122 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
123
124 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
125
126 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
127
128 {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
129}
130
131struct StridedSlice4DReverseFixture : StridedSliceFixture
132{
133 StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
134 "[ 1, 2, 3, 1 ]", // outputShape
135 "[ 1,0,0,0, "
136 "255,255,255,255, "
137 "0,0,0,0, "
138 "0,0,0,0 ]", // beginData [ 1 -1 0 0 ]
139 "[ 2,0,0,0, "
140 "253,255,255,255, "
141 "3,0,0,0, "
142 "1,0,0,0 ]", // endData [ 2 -3 3 1 ]
143 "[ 1,0,0,0, "
144 "255,255,255,255, "
145 "1,0,0,0, "
146 "1,0,0,0 ]" // stridesData [ 1 -1 1 1 ]
147 ) {}
148};
149
150BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
151{
152 RunTest<4, armnn::DataType::Float32>(
153 0,
154 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
155
156 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
157
158 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
159
160 {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
161}
162
163struct StridedSliceSimpleStrideFixture : StridedSliceFixture
164{
165 StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
166 "[ 2, 1, 2, 1 ]", // outputShape
167 "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
168 "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
169 "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]" // stridesData
170 ) {}
171};
172
173BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
174{
175 RunTest<4, armnn::DataType::Float32>(
176 0,
177 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
178
179 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
180
181 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
182
183 {{"outputTensor", { 1.0f, 1.0f,
184
185 5.0f, 5.0f }}});
186}
187
188struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
189{
190 StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
191 "[ 3, 2, 3, 1 ]", // outputShape
192 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // beginData
193 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // endData
194 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // stridesData
195 (1 << 4) - 1, // beginMask
196 (1 << 4) - 1 // endMask
197 ) {}
198};
199
200BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
201{
202 RunTest<4, armnn::DataType::Float32>(
203 0,
204 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
205
206 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
207
208 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
209
210 {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
211
212 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
213
214 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
215}
216
217BOOST_AUTO_TEST_SUITE_END()