blob: bd0584936e3dd3683b3dc16a694014e1cd002cab [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "SliceTestHelper.hpp"
7
8#include <armnn_delegate.hpp>
9
10#include <flatbuffers/flatbuffers.h>
11#include <tensorflow/lite/schema/schema_generated.h>
12
13#include <doctest/doctest.h>
14
15namespace armnnDelegate
16{
17
18void StridedSlice4DTest(std::vector<armnn::BackendId>& backends)
19{
20 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
21 std::vector<int32_t> outputShape { 1, 2, 3, 1 };
22 std::vector<int32_t> beginShape { 4 };
23 std::vector<int32_t> endShape { 4 };
24 std::vector<int32_t> strideShape { 4 };
25
26 std::vector<int32_t> beginData { 1, 0, 0, 0 };
27 std::vector<int32_t> endData { 2, 2, 3, 1 };
28 std::vector<int32_t> strideData { 1, 1, 1, 1 };
29 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
30 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
31 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
32 std::vector<float> outputData { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };
33
34 StridedSliceTestImpl<float>(
35 backends,
36 inputData,
37 outputData,
38 beginData,
39 endData,
40 strideData,
41 inputShape,
42 beginShape,
43 endShape,
44 strideShape,
45 outputShape
46 );
47}
48
49void StridedSlice4DReverseTest(std::vector<armnn::BackendId>& backends)
50{
51 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
52 std::vector<int32_t> outputShape { 1, 2, 3, 1 };
53 std::vector<int32_t> beginShape { 4 };
54 std::vector<int32_t> endShape { 4 };
55 std::vector<int32_t> strideShape { 4 };
56
57 std::vector<int32_t> beginData { 1, -1, 0, 0 };
58 std::vector<int32_t> endData { 2, -3, 3, 1 };
59 std::vector<int32_t> strideData { 1, -1, 1, 1 };
60 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
61 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
62 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
63 std::vector<float> outputData { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f };
64
65 StridedSliceTestImpl<float>(
66 backends,
67 inputData,
68 outputData,
69 beginData,
70 endData,
71 strideData,
72 inputShape,
73 beginShape,
74 endShape,
75 strideShape,
76 outputShape
77 );
78}
79
80void StridedSliceSimpleStrideTest(std::vector<armnn::BackendId>& backends)
81{
82 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
83 std::vector<int32_t> outputShape { 2, 1, 2, 1 };
84 std::vector<int32_t> beginShape { 4 };
85 std::vector<int32_t> endShape { 4 };
86 std::vector<int32_t> strideShape { 4 };
87
88 std::vector<int32_t> beginData { 0, 0, 0, 0 };
89 std::vector<int32_t> endData { 3, 2, 3, 1 };
90 std::vector<int32_t> strideData { 2, 2, 2, 1 };
91 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
92 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
93 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
94 std::vector<float> outputData { 1.0f, 1.0f,
95 5.0f, 5.0f };
96
97 StridedSliceTestImpl<float>(
98 backends,
99 inputData,
100 outputData,
101 beginData,
102 endData,
103 strideData,
104 inputShape,
105 beginShape,
106 endShape,
107 strideShape,
108 outputShape
109 );
110}
111
112void StridedSliceSimpleRangeMaskTest(std::vector<armnn::BackendId>& backends)
113{
114 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
115 std::vector<int32_t> outputShape { 3, 2, 3, 1 };
116 std::vector<int32_t> beginShape { 4 };
117 std::vector<int32_t> endShape { 4 };
118 std::vector<int32_t> strideShape { 4 };
119
120 std::vector<int32_t> beginData { 1, 1, 1, 1 };
121 std::vector<int32_t> endData { 1, 1, 1, 1 };
122 std::vector<int32_t> strideData { 1, 1, 1, 1 };
123
124 int beginMask = -1;
125 int endMask = -1;
126
127 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
128 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
129 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
130 std::vector<float> outputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
131 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
132 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
133
134 StridedSliceTestImpl<float>(
135 backends,
136 inputData,
137 outputData,
138 beginData,
139 endData,
140 strideData,
141 inputShape,
142 beginShape,
143 endShape,
144 strideShape,
145 outputShape,
146 beginMask,
147 endMask
148 );
149}
150
151
152TEST_SUITE("StridedSlice_CpuRefTests")
153{
154
155TEST_CASE ("StridedSlice_4D_CpuRef_Test")
156{
157 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
158 StridedSlice4DTest(backends);
159}
160
161TEST_CASE ("StridedSlice_4D_Reverse_CpuRef_Test")
162{
163 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
164 StridedSlice4DReverseTest(backends);
165}
166
167TEST_CASE ("StridedSlice_SimpleStride_CpuRef_Test")
168{
169 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
170 StridedSliceSimpleStrideTest(backends);
171}
172
173TEST_CASE ("StridedSlice_SimpleRange_CpuRef_Test")
174{
175 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
176 StridedSliceSimpleRangeMaskTest(backends);
177}
178
179} // StridedSlice_CpuRefTests TestSuite
180
181
182
183TEST_SUITE("StridedSlice_CpuAccTests")
184{
185
186TEST_CASE ("StridedSlice_4D_CpuAcc_Test")
187{
188 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
189 StridedSlice4DTest(backends);
190}
191
192TEST_CASE ("StridedSlice_4D_Reverse_CpuAcc_Test")
193{
194 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
195 StridedSlice4DReverseTest(backends);
196}
197
198TEST_CASE ("StridedSlice_SimpleStride_CpuAcc_Test")
199{
200 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
201 StridedSliceSimpleStrideTest(backends);
202}
203
204TEST_CASE ("StridedSlice_SimpleRange_CpuAcc_Test")
205{
206 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
207 StridedSliceSimpleRangeMaskTest(backends);
208}
209
210} // StridedSlice_CpuAccTests TestSuite
211
212
213
214TEST_SUITE("StridedSlice_GpuAccTests")
215{
216
217TEST_CASE ("StridedSlice_4D_GpuAcc_Test")
218{
219 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
220 StridedSlice4DTest(backends);
221}
222
223TEST_CASE ("StridedSlice_4D_Reverse_GpuAcc_Test")
224{
225 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
226 StridedSlice4DReverseTest(backends);
227}
228
229TEST_CASE ("StridedSlice_SimpleStride_GpuAcc_Test")
230{
231 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
232 StridedSliceSimpleStrideTest(backends);
233}
234
235TEST_CASE ("StridedSlice_SimpleRange_GpuAcc_Test")
236{
237 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
238 StridedSliceSimpleRangeMaskTest(backends);
239}
240
241} // StridedSlice_GpuAccTests TestSuite
242
243} // namespace armnnDelegate