blob: 74c77c049ff8423711227d2c5dfa5ed6365c9901 [file] [log] [blame]
Bruno Goncalves451d95b2019-02-12 22:59:22 -02001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include <boost/test/unit_test.hpp>
7#include "ParserFlatbuffersFixture.hpp"
8#include "../TfLiteParser.hpp"
9
10#include <string>
11#include <iostream>
12
13BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15struct StridedSliceFixture : public ParserFlatbuffersFixture
16{
17 explicit StridedSliceFixture(const std::string & inputShape,
18 const std::string & outputShape,
19 const std::string & beginData,
20 const std::string & endData,
21 const std::string & stridesData,
22 int beginMask = 0,
23 int endMask = 0)
24 {
25 m_JsonString = R"(
26 {
27 "version": 3,
28 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ],
29 "subgraphs": [ {
30 "tensors": [
31 {
32 "shape": )" + inputShape + R"(,
33 "type": "FLOAT32",
34 "buffer": 0,
35 "name": "inputTensor",
36 "quantization": {
37 "min": [ 0.0 ],
38 "max": [ 255.0 ],
39 "scale": [ 1.0 ],
40 "zero_point": [ 0 ],
41 }
42 },
43 {
44 "shape": [ 4 ],
45 "type": "INT32",
46 "buffer": 1,
47 "name": "beginTensor",
48 "quantization": {
49 }
50 },
51 {
52 "shape": [ 4 ],
53 "type": "INT32",
54 "buffer": 2,
55 "name": "endTensor",
56 "quantization": {
57 }
58 },
59 {
60 "shape": [ 4 ],
61 "type": "INT32",
62 "buffer": 3,
63 "name": "stridesTensor",
64 "quantization": {
65 }
66 },
67 {
68 "shape": )" + outputShape + R"( ,
69 "type": "FLOAT32",
70 "buffer": 4,
71 "name": "outputTensor",
72 "quantization": {
73 "min": [ 0.0 ],
74 "max": [ 255.0 ],
75 "scale": [ 1.0 ],
76 "zero_point": [ 0 ],
77 }
78 }
79 ],
80 "inputs": [ 0, 1, 2, 3 ],
81 "outputs": [ 4 ],
82 "operators": [
83 {
84 "opcode_index": 0,
85 "inputs": [ 0, 1, 2, 3 ],
86 "outputs": [ 4 ],
87 "builtin_options_type": "StridedSliceOptions",
88 "builtin_options": {
89 "begin_mask": )" + std::to_string(beginMask) + R"(,
90 "end_mask": )" + std::to_string(endMask) + R"(
91 },
92 "custom_options_format": "FLEXBUFFERS"
93 }
94 ],
95 } ],
96 "buffers" : [
97 { },
98 { "data": )" + beginData + R"(, },
99 { "data": )" + endData + R"(, },
100 { "data": )" + stridesData + R"(, },
101 { }
102 ]
103 }
104 )";
105 Setup();
106 }
107};
108
109struct StridedSlice4DFixture : StridedSliceFixture
110{
111 StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
112 "[ 1, 2, 3, 1 ]", // outputShape
113 "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
114 "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
115 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]" // stridesData
116 ) {}
117};
118
119BOOST_FIXTURE_TEST_CASE(StridedSlice4D, StridedSlice4DFixture)
120{
121 RunTest<4, armnn::DataType::Float32>(
122 0,
123 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
124
125 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
126
127 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
128
129 {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
130}
131
132struct StridedSlice4DReverseFixture : StridedSliceFixture
133{
134 StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
135 "[ 1, 2, 3, 1 ]", // outputShape
136 "[ 1,0,0,0, "
137 "255,255,255,255, "
138 "0,0,0,0, "
139 "0,0,0,0 ]", // beginData [ 1 -1 0 0 ]
140 "[ 2,0,0,0, "
141 "253,255,255,255, "
142 "3,0,0,0, "
143 "1,0,0,0 ]", // endData [ 2 -3 3 1 ]
144 "[ 1,0,0,0, "
145 "255,255,255,255, "
146 "1,0,0,0, "
147 "1,0,0,0 ]" // stridesData [ 1 -1 1 1 ]
148 ) {}
149};
150
151BOOST_FIXTURE_TEST_CASE(StridedSlice4DReverse, StridedSlice4DReverseFixture)
152{
153 RunTest<4, armnn::DataType::Float32>(
154 0,
155 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
156
157 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
158
159 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
160
161 {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
162}
163
164struct StridedSliceSimpleStrideFixture : StridedSliceFixture
165{
166 StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
167 "[ 2, 1, 2, 1 ]", // outputShape
168 "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
169 "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
170 "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]" // stridesData
171 ) {}
172};
173
174BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleStride, StridedSliceSimpleStrideFixture)
175{
176 RunTest<4, armnn::DataType::Float32>(
177 0,
178 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
179
180 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
181
182 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
183
184 {{"outputTensor", { 1.0f, 1.0f,
185
186 5.0f, 5.0f }}});
187}
188
189struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
190{
191 StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
192 "[ 3, 2, 3, 1 ]", // outputShape
193 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // beginData
194 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // endData
195 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // stridesData
196 (1 << 4) - 1, // beginMask
197 (1 << 4) - 1 // endMask
198 ) {}
199};
200
201BOOST_FIXTURE_TEST_CASE(StridedSliceSimpleRangeMask, StridedSliceSimpleRangeMaskFixture)
202{
203 RunTest<4, armnn::DataType::Float32>(
204 0,
205 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
206
207 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
208
209 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
210
211 {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
212
213 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
214
215 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
216}
217
218BOOST_AUTO_TEST_SUITE_END()