//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ParserFlatbuffersFixture.hpp"

#include <cstdint>
#include <initializer_list>
#include <string>

Sadik Armagan1625efc2021-06-10 18:24:34 +01009TEST_SUITE("TensorflowLiteParser_StridedSlice")
10{
Bruno Goncalves451d95b2019-02-12 22:59:22 -020011struct StridedSliceFixture : public ParserFlatbuffersFixture
12{
13 explicit StridedSliceFixture(const std::string & inputShape,
14 const std::string & outputShape,
15 const std::string & beginData,
16 const std::string & endData,
17 const std::string & stridesData,
18 int beginMask = 0,
19 int endMask = 0)
20 {
21 m_JsonString = R"(
22 {
23 "version": 3,
24 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ],
25 "subgraphs": [ {
26 "tensors": [
27 {
28 "shape": )" + inputShape + R"(,
29 "type": "FLOAT32",
30 "buffer": 0,
31 "name": "inputTensor",
32 "quantization": {
33 "min": [ 0.0 ],
34 "max": [ 255.0 ],
35 "scale": [ 1.0 ],
36 "zero_point": [ 0 ],
37 }
38 },
39 {
40 "shape": [ 4 ],
41 "type": "INT32",
42 "buffer": 1,
43 "name": "beginTensor",
44 "quantization": {
45 }
46 },
47 {
48 "shape": [ 4 ],
49 "type": "INT32",
50 "buffer": 2,
51 "name": "endTensor",
52 "quantization": {
53 }
54 },
55 {
56 "shape": [ 4 ],
57 "type": "INT32",
58 "buffer": 3,
59 "name": "stridesTensor",
60 "quantization": {
61 }
62 },
63 {
64 "shape": )" + outputShape + R"( ,
65 "type": "FLOAT32",
66 "buffer": 4,
67 "name": "outputTensor",
68 "quantization": {
69 "min": [ 0.0 ],
70 "max": [ 255.0 ],
71 "scale": [ 1.0 ],
72 "zero_point": [ 0 ],
73 }
74 }
75 ],
76 "inputs": [ 0, 1, 2, 3 ],
77 "outputs": [ 4 ],
78 "operators": [
79 {
80 "opcode_index": 0,
81 "inputs": [ 0, 1, 2, 3 ],
82 "outputs": [ 4 ],
83 "builtin_options_type": "StridedSliceOptions",
84 "builtin_options": {
85 "begin_mask": )" + std::to_string(beginMask) + R"(,
86 "end_mask": )" + std::to_string(endMask) + R"(
87 },
88 "custom_options_format": "FLEXBUFFERS"
89 }
90 ],
91 } ],
92 "buffers" : [
93 { },
94 { "data": )" + beginData + R"(, },
95 { "data": )" + endData + R"(, },
96 { "data": )" + stridesData + R"(, },
97 { }
98 ]
99 }
100 )";
101 Setup();
102 }
103};
104
105struct StridedSlice4DFixture : StridedSliceFixture
106{
107 StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
108 "[ 1, 2, 3, 1 ]", // outputShape
109 "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
110 "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
111 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]" // stridesData
112 ) {}
113};
114
Sadik Armagan1625efc2021-06-10 18:24:34 +0100115TEST_CASE_FIXTURE(StridedSlice4DFixture, "StridedSlice4D")
Bruno Goncalves451d95b2019-02-12 22:59:22 -0200116{
117 RunTest<4, armnn::DataType::Float32>(
118 0,
119 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
120
121 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
122
123 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
124
125 {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}});
126}
127
128struct StridedSlice4DReverseFixture : StridedSliceFixture
129{
130 StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
131 "[ 1, 2, 3, 1 ]", // outputShape
132 "[ 1,0,0,0, "
133 "255,255,255,255, "
134 "0,0,0,0, "
135 "0,0,0,0 ]", // beginData [ 1 -1 0 0 ]
136 "[ 2,0,0,0, "
137 "253,255,255,255, "
138 "3,0,0,0, "
139 "1,0,0,0 ]", // endData [ 2 -3 3 1 ]
140 "[ 1,0,0,0, "
141 "255,255,255,255, "
142 "1,0,0,0, "
143 "1,0,0,0 ]" // stridesData [ 1 -1 1 1 ]
144 ) {}
145};
146
Sadik Armagan1625efc2021-06-10 18:24:34 +0100147TEST_CASE_FIXTURE(StridedSlice4DReverseFixture, "StridedSlice4DReverse")
Bruno Goncalves451d95b2019-02-12 22:59:22 -0200148{
149 RunTest<4, armnn::DataType::Float32>(
150 0,
151 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
152
153 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
154
155 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
156
157 {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}});
158}
159
160struct StridedSliceSimpleStrideFixture : StridedSliceFixture
161{
162 StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
163 "[ 2, 1, 2, 1 ]", // outputShape
164 "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData
165 "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData
166 "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]" // stridesData
167 ) {}
168};
169
Sadik Armagan1625efc2021-06-10 18:24:34 +0100170TEST_CASE_FIXTURE(StridedSliceSimpleStrideFixture, "StridedSliceSimpleStride")
Bruno Goncalves451d95b2019-02-12 22:59:22 -0200171{
172 RunTest<4, armnn::DataType::Float32>(
173 0,
174 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
175
176 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
177
178 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
179
180 {{"outputTensor", { 1.0f, 1.0f,
181
182 5.0f, 5.0f }}});
183}
184
185struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture
186{
187 StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape
188 "[ 3, 2, 3, 1 ]", // outputShape
189 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // beginData
190 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // endData
191 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // stridesData
192 (1 << 4) - 1, // beginMask
193 (1 << 4) - 1 // endMask
194 ) {}
195};
196
Sadik Armagan1625efc2021-06-10 18:24:34 +0100197TEST_CASE_FIXTURE(StridedSliceSimpleRangeMaskFixture, "StridedSliceSimpleRangeMask")
Bruno Goncalves451d95b2019-02-12 22:59:22 -0200198{
199 RunTest<4, armnn::DataType::Float32>(
200 0,
201 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
202
203 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
204
205 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}},
206
207 {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
208
209 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
210
211 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}});
212}
213
Sadik Armagan1625efc2021-06-10 18:24:34 +0100214}