blob: 5b6d7efca002bbe7039bda0633abebd32840380a [file] [log] [blame]
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "StridedSliceTestHelper.hpp"
7
8#include <armnn_delegate.hpp>
9
10#include <flatbuffers/flatbuffers.h>
11
12#include <doctest/doctest.h>
13
14namespace armnnDelegate
15{
16
17void StridedSlice4DTest(std::vector<armnn::BackendId>& backends)
18{
19 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
20 std::vector<int32_t> outputShape { 1, 2, 3, 1 };
21 std::vector<int32_t> beginShape { 4 };
22 std::vector<int32_t> endShape { 4 };
23 std::vector<int32_t> strideShape { 4 };
24
25 std::vector<int32_t> beginData { 1, 0, 0, 0 };
26 std::vector<int32_t> endData { 2, 2, 3, 1 };
27 std::vector<int32_t> strideData { 1, 1, 1, 1 };
28 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
29 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
30 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
31 std::vector<float> outputData { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f };
32
33 StridedSliceTestImpl<float>(
34 backends,
35 inputData,
36 outputData,
37 beginData,
38 endData,
39 strideData,
40 inputShape,
41 beginShape,
42 endShape,
43 strideShape,
44 outputShape
45 );
46}
47
48void StridedSlice4DReverseTest(std::vector<armnn::BackendId>& backends)
49{
50 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
51 std::vector<int32_t> outputShape { 1, 2, 3, 1 };
52 std::vector<int32_t> beginShape { 4 };
53 std::vector<int32_t> endShape { 4 };
54 std::vector<int32_t> strideShape { 4 };
55
56 std::vector<int32_t> beginData { 1, -1, 0, 0 };
57 std::vector<int32_t> endData { 2, -3, 3, 1 };
58 std::vector<int32_t> strideData { 1, -1, 1, 1 };
59 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
60 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
61 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
62 std::vector<float> outputData { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f };
63
64 StridedSliceTestImpl<float>(
65 backends,
66 inputData,
67 outputData,
68 beginData,
69 endData,
70 strideData,
71 inputShape,
72 beginShape,
73 endShape,
74 strideShape,
75 outputShape
76 );
77}
78
79void StridedSliceSimpleStrideTest(std::vector<armnn::BackendId>& backends)
80{
81 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
82 std::vector<int32_t> outputShape { 2, 1, 2, 1 };
83 std::vector<int32_t> beginShape { 4 };
84 std::vector<int32_t> endShape { 4 };
85 std::vector<int32_t> strideShape { 4 };
86
87 std::vector<int32_t> beginData { 0, 0, 0, 0 };
88 std::vector<int32_t> endData { 3, 2, 3, 1 };
89 std::vector<int32_t> strideData { 2, 2, 2, 1 };
90 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
91 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
92 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
93 std::vector<float> outputData { 1.0f, 1.0f,
94 5.0f, 5.0f };
95
96 StridedSliceTestImpl<float>(
97 backends,
98 inputData,
99 outputData,
100 beginData,
101 endData,
102 strideData,
103 inputShape,
104 beginShape,
105 endShape,
106 strideShape,
107 outputShape
108 );
109}
110
111void StridedSliceSimpleRangeMaskTest(std::vector<armnn::BackendId>& backends)
112{
113 std::vector<int32_t> inputShape { 3, 2, 3, 1 };
114 std::vector<int32_t> outputShape { 3, 2, 3, 1 };
115 std::vector<int32_t> beginShape { 4 };
116 std::vector<int32_t> endShape { 4 };
117 std::vector<int32_t> strideShape { 4 };
118
119 std::vector<int32_t> beginData { 1, 1, 1, 1 };
120 std::vector<int32_t> endData { 1, 1, 1, 1 };
121 std::vector<int32_t> strideData { 1, 1, 1, 1 };
122
123 int beginMask = -1;
124 int endMask = -1;
125
126 std::vector<float> inputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
127 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
128 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
129 std::vector<float> outputData { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
130 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
131 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f };
132
133 StridedSliceTestImpl<float>(
134 backends,
135 inputData,
136 outputData,
137 beginData,
138 endData,
139 strideData,
140 inputShape,
141 beginShape,
142 endShape,
143 strideShape,
144 outputShape,
145 beginMask,
146 endMask
147 );
148}
149
150TEST_SUITE("StridedSlice_CpuRefTests")
151{
152
153TEST_CASE ("StridedSlice_4D_CpuRef_Test")
154{
155 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
156 StridedSlice4DTest(backends);
157}
158
159TEST_CASE ("StridedSlice_4D_Reverse_CpuRef_Test")
160{
161 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
162 StridedSlice4DReverseTest(backends);
163}
164
165TEST_CASE ("StridedSlice_SimpleStride_CpuRef_Test")
166{
167 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
168 StridedSliceSimpleStrideTest(backends);
169}
170
171TEST_CASE ("StridedSlice_SimpleRange_CpuRef_Test")
172{
173 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
174 StridedSliceSimpleRangeMaskTest(backends);
175}
176
177} // StridedSlice_CpuRefTests TestSuite
178
179
180
181TEST_SUITE("StridedSlice_CpuAccTests")
182{
183
184TEST_CASE ("StridedSlice_4D_CpuAcc_Test")
185{
186 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
187 StridedSlice4DTest(backends);
188}
189
190TEST_CASE ("StridedSlice_4D_Reverse_CpuAcc_Test")
191{
192 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
193 StridedSlice4DReverseTest(backends);
194}
195
196TEST_CASE ("StridedSlice_SimpleStride_CpuAcc_Test")
197{
198 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
199 StridedSliceSimpleStrideTest(backends);
200}
201
202TEST_CASE ("StridedSlice_SimpleRange_CpuAcc_Test")
203{
204 std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
205 StridedSliceSimpleRangeMaskTest(backends);
206}
207
208} // StridedSlice_CpuAccTests TestSuite
209
210
211
212TEST_SUITE("StridedSlice_GpuAccTests")
213{
214
215TEST_CASE ("StridedSlice_4D_GpuAcc_Test")
216{
217 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
218 StridedSlice4DTest(backends);
219}
220
221TEST_CASE ("StridedSlice_4D_Reverse_GpuAcc_Test")
222{
223 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
224 StridedSlice4DReverseTest(backends);
225}
226
227TEST_CASE ("StridedSlice_SimpleStride_GpuAcc_Test")
228{
229 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
230 StridedSliceSimpleStrideTest(backends);
231}
232
233TEST_CASE ("StridedSlice_SimpleRange_GpuAcc_Test")
234{
235 std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
236 StridedSliceSimpleRangeMaskTest(backends);
237}
238
239} // StridedSlice_GpuAccTests TestSuite
240
241} // namespace armnnDelegate