blob: 68410559f7ddbae604e490705134e7cbffdd605e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "GatherTestImpl.hpp"
7
8#include <ResolveType.hpp>
9
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010010
11#include <backendsCommon/test/TensorCopyUtils.hpp>
12#include <backendsCommon/test/WorkloadTestUtils.hpp>
13
14#include <test/TensorHelpers.hpp>
15
16namespace
17{
18
19template <armnn::DataType ArmnnType,
20 typename T = armnn::ResolveType<ArmnnType>,
21 size_t ParamsDim,
22 size_t IndicesDim,
23 size_t OutputDim>
24LayerTestResult<T, OutputDim> GatherTestImpl(
25 armnn::IWorkloadFactory& workloadFactory,
26 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
27 const armnn::TensorInfo& paramsInfo,
28 const armnn::TensorInfo& indicesInfo,
29 const armnn::TensorInfo& outputInfo,
30 const std::vector<T>& paramsData,
31 const std::vector<int32_t>& indicesData,
32 const std::vector<T>& outputData)
33{
Derek Lambertic374ff02019-12-10 21:57:35 +000034 boost::ignore_unused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010035 auto params = MakeTensor<T, ParamsDim>(paramsInfo, paramsData);
36 auto indices = MakeTensor<int32_t, IndicesDim>(indicesInfo, indicesData);
37
38 LayerTestResult<T, OutputDim> result(outputInfo);
39 result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData);
40
41 std::unique_ptr<armnn::ITensorHandle> paramsHandle = workloadFactory.CreateTensorHandle(paramsInfo);
42 std::unique_ptr<armnn::ITensorHandle> indicesHandle = workloadFactory.CreateTensorHandle(indicesInfo);
43 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputInfo);
44
45 armnn::GatherQueueDescriptor data;
46 armnn::WorkloadInfo info;
47 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
48 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
49 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
50
51 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateGather(data, info);
52
53 paramsHandle->Allocate();
54 indicesHandle->Allocate();
55 outputHandle->Allocate();
56
57 CopyDataToITensorHandle(paramsHandle.get(), params.origin());
58 CopyDataToITensorHandle(indicesHandle.get(), indices.origin());
59
60 workload->Execute();
61
62 CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
63
64 return result;
65}
66
Matthew Jackson9bff1442019-09-12 09:08:23 +010067template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
68struct GatherTestHelper
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010069{
Matthew Jackson9bff1442019-09-12 09:08:23 +010070 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
71 armnn::IWorkloadFactory& workloadFactory,
72 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010073 {
Matthew Jackson9bff1442019-09-12 09:08:23 +010074 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
75 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
76 armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010077
Matthew Jackson9bff1442019-09-12 09:08:23 +010078 if (armnn::IsQuantizedType<T>())
79 {
80 paramsInfo.SetQuantizationScale(1.0f);
81 paramsInfo.SetQuantizationOffset(1);
82 outputInfo.SetQuantizationScale(1.0f);
83 outputInfo.SetQuantizationOffset(1);
84 }
85 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
86 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
87 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010088
Matthew Jackson9bff1442019-09-12 09:08:23 +010089 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
90 workloadFactory,
91 memoryManager,
92 paramsInfo,
93 indicesInfo,
94 outputInfo,
95 params,
96 indices,
97 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010098 }
99
Matthew Jackson9bff1442019-09-12 09:08:23 +0100100 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
101 armnn::IWorkloadFactory& workloadFactory,
102 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100103 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100104 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
105 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
106 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
107
108 if (armnn::IsQuantizedType<T>())
109 {
110 paramsInfo.SetQuantizationScale(1.0f);
111 paramsInfo.SetQuantizationOffset(1);
112 outputInfo.SetQuantizationScale(1.0f);
113 outputInfo.SetQuantizationOffset(1);
114 }
115
116 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
117 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
118 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
119
120 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(
121 workloadFactory,
122 memoryManager,
123 paramsInfo,
124 indicesInfo,
125 outputInfo,
126 params,
127 indices,
128 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100129 }
130
Matthew Jackson9bff1442019-09-12 09:08:23 +0100131 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
132 armnn::IWorkloadFactory& workloadFactory,
133 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100134 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100135 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
136 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
137 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100138
Matthew Jackson9bff1442019-09-12 09:08:23 +0100139 if (armnn::IsQuantizedType<T>())
140 {
141 paramsInfo.SetQuantizationScale(1.0f);
142 paramsInfo.SetQuantizationOffset(1);
143 outputInfo.SetQuantizationScale(1.0f);
144 outputInfo.SetQuantizationOffset(1);
145 }
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100146
Matthew Jackson9bff1442019-09-12 09:08:23 +0100147 const std::vector<T> params =
148 {
149 1, 2, 3,
150 4, 5, 6,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100151
Matthew Jackson9bff1442019-09-12 09:08:23 +0100152 7, 8, 9,
153 10, 11, 12,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100154
Matthew Jackson9bff1442019-09-12 09:08:23 +0100155 13, 14, 15,
156 16, 17, 18
157 };
158
159 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
160
161 const std::vector<T> expectedOutput =
162 {
163 7, 8, 9,
164 10, 11, 12,
165 13, 14, 15,
166 16, 17, 18,
167 7, 8, 9,
168 10, 11, 12,
169
170 13, 14, 15,
171 16, 17, 18,
172 7, 8, 9,
173 10, 11, 12,
174 1, 2, 3,
175 4, 5, 6
176 };
177
178 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
179 workloadFactory,
180 memoryManager,
181 paramsInfo,
182 indicesInfo,
183 outputInfo,
184 params,
185 indices,
186 expectedOutput);
187 }
188};
189
190template<typename T>
191struct GatherTestHelper<armnn::DataType::Float16, T>
192{
193 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
194 armnn::IWorkloadFactory& workloadFactory,
195 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100196 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100197 using namespace half_float::literal;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100198
Matthew Jackson9bff1442019-09-12 09:08:23 +0100199 armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::Float16);
200 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
201 armnn::TensorInfo outputInfo({ 4 }, armnn::DataType::Float16);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100202
Matthew Jackson9bff1442019-09-12 09:08:23 +0100203 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h });
204 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
205 const std::vector<T> expectedOutput = std::vector<T>({ 1._h, 3._h, 2._h, 6._h });
206
207 return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>(
208 workloadFactory,
209 memoryManager,
210 paramsInfo,
211 indicesInfo,
212 outputInfo,
213 params,
214 indices,
215 expectedOutput);
216 }
217
218 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
219 armnn::IWorkloadFactory& workloadFactory,
220 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
221 {
222 using namespace half_float::literal;
223
224 armnn::TensorInfo paramsInfo({ 5, 2 }, armnn::DataType::Float16);
225 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
226 armnn::TensorInfo outputInfo({ 3, 2 }, armnn::DataType::Float16);
227
228 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h, 9._h, 10._h });
229
230 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
231 const std::vector<T> expectedOutput = std::vector<T>({ 3._h, 4._h, 7._h, 8._h, 9._h, 10._h });
232
233 return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>(
234 workloadFactory,
235 memoryManager,
236 paramsInfo,
237 indicesInfo,
238 outputInfo,
239 params,
240 indices,
241 expectedOutput);
242 }
243
244 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
245 armnn::IWorkloadFactory& workloadFactory,
246 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
247 {
248 using namespace half_float::literal;
249
250 armnn::TensorInfo paramsInfo({ 3, 2, 3 }, armnn::DataType::Float16);
251 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
252 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, armnn::DataType::Float16);
253
254 const std::vector<T> params =
255 {
256 1._h, 2._h, 3._h,
257 4._h, 5._h, 6._h,
258
259 7._h, 8._h, 9._h,
260 10._h, 11._h, 12._h,
261
262 13._h, 14._h, 15._h,
263 16._h, 17._h, 18._h
264 };
265
266 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
267
268 const std::vector<T> expectedOutput =
269 {
270 7._h, 8._h, 9._h,
271 10._h, 11._h, 12._h,
272 13._h, 14._h, 15._h,
273 16._h, 17._h, 18._h,
274 7._h, 8._h, 9._h,
275 10._h, 11._h, 12._h,
276
277 13._h, 14._h, 15._h,
278 16._h, 17._h, 18._h,
279 7._h, 8._h, 9._h,
280 10._h, 11._h, 12._h,
281 1._h, 2._h, 3._h,
282 4._h, 5._h, 6._h
283 };
284
285 return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>(
286 workloadFactory,
287 memoryManager,
288 paramsInfo,
289 indicesInfo,
290 outputInfo,
291 params,
292 indices,
293 expectedOutput);
294 }
295};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100296
297} // anonymous namespace
298
Matthew Jackson9bff1442019-09-12 09:08:23 +0100299LayerTestResult<float, 1> Gather1dParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100300 armnn::IWorkloadFactory& workloadFactory,
301 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
302{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100303 return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
304}
305
306LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test(
307 armnn::IWorkloadFactory& workloadFactory,
308 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
309{
310 return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100311}
312
313LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test(
314 armnn::IWorkloadFactory& workloadFactory,
315 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
316{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000317 return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100318}
319
320LayerTestResult<int16_t, 1> Gather1dParamsInt16Test(
321 armnn::IWorkloadFactory& workloadFactory,
322 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
323{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000324 return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100325}
326
Matthew Jackson9bff1442019-09-12 09:08:23 +0100327LayerTestResult<float, 2> GatherMultiDimParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100328 armnn::IWorkloadFactory& workloadFactory,
329 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
330{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100331 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager);
332}
333
334LayerTestResult<armnn::Half, 2> GatherMultiDimParamsFloat16Test(
335 armnn::IWorkloadFactory& workloadFactory,
336 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
337{
338 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100339}
340
341LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test(
342 armnn::IWorkloadFactory& workloadFactory,
343 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
344{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000345 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsTestImpl(
Matthew Jackson9bff1442019-09-12 09:08:23 +0100346 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100347}
348
349LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test(
350 armnn::IWorkloadFactory& workloadFactory,
351 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
352{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000353 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsTestImpl(
Matthew Jackson9bff1442019-09-12 09:08:23 +0100354 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100355}
356
Matthew Jackson9bff1442019-09-12 09:08:23 +0100357LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100358 armnn::IWorkloadFactory& workloadFactory,
359 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
360{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100361 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
362 workloadFactory, memoryManager);
363}
364
365LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test(
366 armnn::IWorkloadFactory& workloadFactory,
367 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
368{
369 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
370 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100371}
372
373LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test(
374 armnn::IWorkloadFactory& workloadFactory,
375 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
376{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000377 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100378 workloadFactory, memoryManager);
379}
380
381LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test(
382 armnn::IWorkloadFactory& workloadFactory,
383 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
384{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000385 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100386 workloadFactory, memoryManager);
387}