blob: bee5f6aaf9fa3d922f77e00526180ac7503fca01 [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "GatherTestImpl.hpp"
7
8#include <ResolveType.hpp>
9
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010010
11#include <backendsCommon/test/TensorCopyUtils.hpp>
12#include <backendsCommon/test/WorkloadTestUtils.hpp>
13
14#include <test/TensorHelpers.hpp>
15
16namespace
17{
18
19template <armnn::DataType ArmnnType,
20 typename T = armnn::ResolveType<ArmnnType>,
21 size_t ParamsDim,
22 size_t IndicesDim,
23 size_t OutputDim>
24LayerTestResult<T, OutputDim> GatherTestImpl(
25 armnn::IWorkloadFactory& workloadFactory,
26 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
27 const armnn::TensorInfo& paramsInfo,
28 const armnn::TensorInfo& indicesInfo,
29 const armnn::TensorInfo& outputInfo,
30 const std::vector<T>& paramsData,
31 const std::vector<int32_t>& indicesData,
32 const std::vector<T>& outputData)
33{
Jan Eilers8eb25602020-03-09 12:13:48 +000034 IgnoreUnused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010035 auto params = MakeTensor<T, ParamsDim>(paramsInfo, paramsData);
36 auto indices = MakeTensor<int32_t, IndicesDim>(indicesInfo, indicesData);
37
38 LayerTestResult<T, OutputDim> result(outputInfo);
39 result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData);
40
41 std::unique_ptr<armnn::ITensorHandle> paramsHandle = workloadFactory.CreateTensorHandle(paramsInfo);
42 std::unique_ptr<armnn::ITensorHandle> indicesHandle = workloadFactory.CreateTensorHandle(indicesInfo);
43 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputInfo);
44
45 armnn::GatherQueueDescriptor data;
46 armnn::WorkloadInfo info;
47 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
48 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
49 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
50
51 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateGather(data, info);
52
53 paramsHandle->Allocate();
54 indicesHandle->Allocate();
55 outputHandle->Allocate();
56
57 CopyDataToITensorHandle(paramsHandle.get(), params.origin());
58 CopyDataToITensorHandle(indicesHandle.get(), indices.origin());
59
60 workload->Execute();
61
62 CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
63
64 return result;
65}
66
Matthew Jackson9bff1442019-09-12 09:08:23 +010067template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
68struct GatherTestHelper
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010069{
Matthew Jackson9bff1442019-09-12 09:08:23 +010070 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
71 armnn::IWorkloadFactory& workloadFactory,
72 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010073 {
Matthew Jackson9bff1442019-09-12 09:08:23 +010074 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
75 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
76 armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010077
Matthew Jackson9bff1442019-09-12 09:08:23 +010078 if (armnn::IsQuantizedType<T>())
79 {
80 paramsInfo.SetQuantizationScale(1.0f);
81 paramsInfo.SetQuantizationOffset(1);
82 outputInfo.SetQuantizationScale(1.0f);
83 outputInfo.SetQuantizationOffset(1);
84 }
85 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
86 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
87 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010088
Matthew Jackson9bff1442019-09-12 09:08:23 +010089 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
90 workloadFactory,
91 memoryManager,
92 paramsInfo,
93 indicesInfo,
94 outputInfo,
95 params,
96 indices,
97 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010098 }
99
Matthew Jackson9bff1442019-09-12 09:08:23 +0100100 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
101 armnn::IWorkloadFactory& workloadFactory,
102 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100103 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100104 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
105 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
106 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
107
108 if (armnn::IsQuantizedType<T>())
109 {
110 paramsInfo.SetQuantizationScale(1.0f);
111 paramsInfo.SetQuantizationOffset(1);
112 outputInfo.SetQuantizationScale(1.0f);
113 outputInfo.SetQuantizationOffset(1);
114 }
115
116 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
117 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
118 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
119
120 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(
121 workloadFactory,
122 memoryManager,
123 paramsInfo,
124 indicesInfo,
125 outputInfo,
126 params,
127 indices,
128 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100129 }
130
Matthew Jackson9bff1442019-09-12 09:08:23 +0100131 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
132 armnn::IWorkloadFactory& workloadFactory,
133 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100134 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100135 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
136 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
137 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100138
Matthew Jackson9bff1442019-09-12 09:08:23 +0100139 if (armnn::IsQuantizedType<T>())
140 {
141 paramsInfo.SetQuantizationScale(1.0f);
142 paramsInfo.SetQuantizationOffset(1);
143 outputInfo.SetQuantizationScale(1.0f);
144 outputInfo.SetQuantizationOffset(1);
145 }
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100146
Matthew Jackson9bff1442019-09-12 09:08:23 +0100147 const std::vector<T> params =
148 {
149 1, 2, 3,
150 4, 5, 6,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100151
Matthew Jackson9bff1442019-09-12 09:08:23 +0100152 7, 8, 9,
153 10, 11, 12,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100154
Matthew Jackson9bff1442019-09-12 09:08:23 +0100155 13, 14, 15,
156 16, 17, 18
157 };
158
159 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
160
161 const std::vector<T> expectedOutput =
162 {
163 7, 8, 9,
164 10, 11, 12,
165 13, 14, 15,
166 16, 17, 18,
167 7, 8, 9,
168 10, 11, 12,
169
170 13, 14, 15,
171 16, 17, 18,
172 7, 8, 9,
173 10, 11, 12,
174 1, 2, 3,
175 4, 5, 6
176 };
177
178 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
179 workloadFactory,
180 memoryManager,
181 paramsInfo,
182 indicesInfo,
183 outputInfo,
184 params,
185 indices,
186 expectedOutput);
187 }
188};
189
190template<typename T>
191struct GatherTestHelper<armnn::DataType::Float16, T>
192{
193 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
194 armnn::IWorkloadFactory& workloadFactory,
195 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100196 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100197 using namespace half_float::literal;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100198
Matthew Jackson9bff1442019-09-12 09:08:23 +0100199 armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::Float16);
200 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
201 armnn::TensorInfo outputInfo({ 4 }, armnn::DataType::Float16);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100202
Matthew Jackson9bff1442019-09-12 09:08:23 +0100203 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h });
204 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
205 const std::vector<T> expectedOutput = std::vector<T>({ 1._h, 3._h, 2._h, 6._h });
206
207 return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>(
208 workloadFactory,
209 memoryManager,
210 paramsInfo,
211 indicesInfo,
212 outputInfo,
213 params,
214 indices,
215 expectedOutput);
216 }
217
218 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
219 armnn::IWorkloadFactory& workloadFactory,
220 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
221 {
222 using namespace half_float::literal;
223
224 armnn::TensorInfo paramsInfo({ 5, 2 }, armnn::DataType::Float16);
225 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
226 armnn::TensorInfo outputInfo({ 3, 2 }, armnn::DataType::Float16);
227
228 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h, 9._h, 10._h });
229
230 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
231 const std::vector<T> expectedOutput = std::vector<T>({ 3._h, 4._h, 7._h, 8._h, 9._h, 10._h });
232
233 return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>(
234 workloadFactory,
235 memoryManager,
236 paramsInfo,
237 indicesInfo,
238 outputInfo,
239 params,
240 indices,
241 expectedOutput);
242 }
243
244 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
245 armnn::IWorkloadFactory& workloadFactory,
246 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
247 {
248 using namespace half_float::literal;
249
250 armnn::TensorInfo paramsInfo({ 3, 2, 3 }, armnn::DataType::Float16);
251 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
252 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, armnn::DataType::Float16);
253
254 const std::vector<T> params =
255 {
256 1._h, 2._h, 3._h,
257 4._h, 5._h, 6._h,
258
259 7._h, 8._h, 9._h,
260 10._h, 11._h, 12._h,
261
262 13._h, 14._h, 15._h,
263 16._h, 17._h, 18._h
264 };
265
266 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
267
268 const std::vector<T> expectedOutput =
269 {
270 7._h, 8._h, 9._h,
271 10._h, 11._h, 12._h,
272 13._h, 14._h, 15._h,
273 16._h, 17._h, 18._h,
274 7._h, 8._h, 9._h,
275 10._h, 11._h, 12._h,
276
277 13._h, 14._h, 15._h,
278 16._h, 17._h, 18._h,
279 7._h, 8._h, 9._h,
280 10._h, 11._h, 12._h,
281 1._h, 2._h, 3._h,
282 4._h, 5._h, 6._h
283 };
284
285 return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>(
286 workloadFactory,
287 memoryManager,
288 paramsInfo,
289 indicesInfo,
290 outputInfo,
291 params,
292 indices,
293 expectedOutput);
294 }
295};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100296
297} // anonymous namespace
298
Matthew Jackson9bff1442019-09-12 09:08:23 +0100299LayerTestResult<float, 1> Gather1dParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100300 armnn::IWorkloadFactory& workloadFactory,
301 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
302{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100303 return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
304}
305
306LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test(
307 armnn::IWorkloadFactory& workloadFactory,
308 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
309{
310 return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100311}
312
313LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test(
314 armnn::IWorkloadFactory& workloadFactory,
315 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
316{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000317 return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100318}
319
320LayerTestResult<int16_t, 1> Gather1dParamsInt16Test(
321 armnn::IWorkloadFactory& workloadFactory,
322 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
323{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000324 return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100325}
326
Teresa Charlin93492462020-05-29 13:08:59 +0100327LayerTestResult<int32_t, 1> Gather1dParamsInt32Test(
328 armnn::IWorkloadFactory& workloadFactory,
329 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
330{
331 return GatherTestHelper<armnn::DataType::Signed32>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
332}
333
Matthew Jackson9bff1442019-09-12 09:08:23 +0100334LayerTestResult<float, 2> GatherMultiDimParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100335 armnn::IWorkloadFactory& workloadFactory,
336 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
337{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100338 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager);
339}
340
341LayerTestResult<armnn::Half, 2> GatherMultiDimParamsFloat16Test(
342 armnn::IWorkloadFactory& workloadFactory,
343 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
344{
345 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100346}
347
348LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test(
349 armnn::IWorkloadFactory& workloadFactory,
350 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
351{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000352 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsTestImpl(
Matthew Jackson9bff1442019-09-12 09:08:23 +0100353 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100354}
355
356LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test(
357 armnn::IWorkloadFactory& workloadFactory,
358 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
359{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000360 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsTestImpl(
Matthew Jackson9bff1442019-09-12 09:08:23 +0100361 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100362}
363
Teresa Charlin93492462020-05-29 13:08:59 +0100364LayerTestResult<int32_t, 2> GatherMultiDimParamsInt32Test(
365 armnn::IWorkloadFactory& workloadFactory,
366 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
367{
368 return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsTestImpl(
369 workloadFactory, memoryManager);
370}
371
Matthew Jackson9bff1442019-09-12 09:08:23 +0100372LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100373 armnn::IWorkloadFactory& workloadFactory,
374 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
375{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100376 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
377 workloadFactory, memoryManager);
378}
379
380LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test(
381 armnn::IWorkloadFactory& workloadFactory,
382 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
383{
384 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
385 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100386}
387
388LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test(
389 armnn::IWorkloadFactory& workloadFactory,
390 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
391{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000392 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100393 workloadFactory, memoryManager);
394}
395
396LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test(
397 armnn::IWorkloadFactory& workloadFactory,
398 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
399{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000400 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100401 workloadFactory, memoryManager);
402}
Teresa Charlin93492462020-05-29 13:08:59 +0100403
404LayerTestResult<int32_t, 4> GatherMultiDimParamsMultiDimIndicesInt32Test(
405 armnn::IWorkloadFactory& workloadFactory,
406 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
407{
408 return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
409 workloadFactory, memoryManager);
410}