blob: 1ccf51c7d2cfc50f60ecd927e2cddc9cac1b4993 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "GatherTestImpl.hpp"
7
8#include <ResolveType.hpp>
9
10#include <armnn/ArmNN.hpp>
11
12#include <backendsCommon/test/TensorCopyUtils.hpp>
13#include <backendsCommon/test/WorkloadTestUtils.hpp>
14
15#include <test/TensorHelpers.hpp>
16
17namespace
18{
19
20template <armnn::DataType ArmnnType,
21 typename T = armnn::ResolveType<ArmnnType>,
22 size_t ParamsDim,
23 size_t IndicesDim,
24 size_t OutputDim>
25LayerTestResult<T, OutputDim> GatherTestImpl(
26 armnn::IWorkloadFactory& workloadFactory,
27 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
28 const armnn::TensorInfo& paramsInfo,
29 const armnn::TensorInfo& indicesInfo,
30 const armnn::TensorInfo& outputInfo,
31 const std::vector<T>& paramsData,
32 const std::vector<int32_t>& indicesData,
33 const std::vector<T>& outputData)
34{
Derek Lambertic374ff02019-12-10 21:57:35 +000035 boost::ignore_unused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010036 auto params = MakeTensor<T, ParamsDim>(paramsInfo, paramsData);
37 auto indices = MakeTensor<int32_t, IndicesDim>(indicesInfo, indicesData);
38
39 LayerTestResult<T, OutputDim> result(outputInfo);
40 result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData);
41
42 std::unique_ptr<armnn::ITensorHandle> paramsHandle = workloadFactory.CreateTensorHandle(paramsInfo);
43 std::unique_ptr<armnn::ITensorHandle> indicesHandle = workloadFactory.CreateTensorHandle(indicesInfo);
44 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputInfo);
45
46 armnn::GatherQueueDescriptor data;
47 armnn::WorkloadInfo info;
48 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
49 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
50 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
51
52 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateGather(data, info);
53
54 paramsHandle->Allocate();
55 indicesHandle->Allocate();
56 outputHandle->Allocate();
57
58 CopyDataToITensorHandle(paramsHandle.get(), params.origin());
59 CopyDataToITensorHandle(indicesHandle.get(), indices.origin());
60
61 workload->Execute();
62
63 CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
64
65 return result;
66}
67
Matthew Jackson9bff1442019-09-12 09:08:23 +010068template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
69struct GatherTestHelper
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010070{
Matthew Jackson9bff1442019-09-12 09:08:23 +010071 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
72 armnn::IWorkloadFactory& workloadFactory,
73 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010074 {
Matthew Jackson9bff1442019-09-12 09:08:23 +010075 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
76 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
77 armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010078
Matthew Jackson9bff1442019-09-12 09:08:23 +010079 if (armnn::IsQuantizedType<T>())
80 {
81 paramsInfo.SetQuantizationScale(1.0f);
82 paramsInfo.SetQuantizationOffset(1);
83 outputInfo.SetQuantizationScale(1.0f);
84 outputInfo.SetQuantizationOffset(1);
85 }
86 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
87 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
88 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010089
Matthew Jackson9bff1442019-09-12 09:08:23 +010090 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
91 workloadFactory,
92 memoryManager,
93 paramsInfo,
94 indicesInfo,
95 outputInfo,
96 params,
97 indices,
98 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010099 }
100
Matthew Jackson9bff1442019-09-12 09:08:23 +0100101 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
102 armnn::IWorkloadFactory& workloadFactory,
103 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100104 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100105 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
106 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
107 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
108
109 if (armnn::IsQuantizedType<T>())
110 {
111 paramsInfo.SetQuantizationScale(1.0f);
112 paramsInfo.SetQuantizationOffset(1);
113 outputInfo.SetQuantizationScale(1.0f);
114 outputInfo.SetQuantizationOffset(1);
115 }
116
117 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
118 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
119 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
120
121 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(
122 workloadFactory,
123 memoryManager,
124 paramsInfo,
125 indicesInfo,
126 outputInfo,
127 params,
128 indices,
129 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100130 }
131
Matthew Jackson9bff1442019-09-12 09:08:23 +0100132 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
133 armnn::IWorkloadFactory& workloadFactory,
134 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100135 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100136 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
137 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
138 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100139
Matthew Jackson9bff1442019-09-12 09:08:23 +0100140 if (armnn::IsQuantizedType<T>())
141 {
142 paramsInfo.SetQuantizationScale(1.0f);
143 paramsInfo.SetQuantizationOffset(1);
144 outputInfo.SetQuantizationScale(1.0f);
145 outputInfo.SetQuantizationOffset(1);
146 }
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100147
Matthew Jackson9bff1442019-09-12 09:08:23 +0100148 const std::vector<T> params =
149 {
150 1, 2, 3,
151 4, 5, 6,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100152
Matthew Jackson9bff1442019-09-12 09:08:23 +0100153 7, 8, 9,
154 10, 11, 12,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100155
Matthew Jackson9bff1442019-09-12 09:08:23 +0100156 13, 14, 15,
157 16, 17, 18
158 };
159
160 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
161
162 const std::vector<T> expectedOutput =
163 {
164 7, 8, 9,
165 10, 11, 12,
166 13, 14, 15,
167 16, 17, 18,
168 7, 8, 9,
169 10, 11, 12,
170
171 13, 14, 15,
172 16, 17, 18,
173 7, 8, 9,
174 10, 11, 12,
175 1, 2, 3,
176 4, 5, 6
177 };
178
179 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
180 workloadFactory,
181 memoryManager,
182 paramsInfo,
183 indicesInfo,
184 outputInfo,
185 params,
186 indices,
187 expectedOutput);
188 }
189};
190
191template<typename T>
192struct GatherTestHelper<armnn::DataType::Float16, T>
193{
194 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
195 armnn::IWorkloadFactory& workloadFactory,
196 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100197 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100198 using namespace half_float::literal;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100199
Matthew Jackson9bff1442019-09-12 09:08:23 +0100200 armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::Float16);
201 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
202 armnn::TensorInfo outputInfo({ 4 }, armnn::DataType::Float16);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100203
Matthew Jackson9bff1442019-09-12 09:08:23 +0100204 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h });
205 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
206 const std::vector<T> expectedOutput = std::vector<T>({ 1._h, 3._h, 2._h, 6._h });
207
208 return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>(
209 workloadFactory,
210 memoryManager,
211 paramsInfo,
212 indicesInfo,
213 outputInfo,
214 params,
215 indices,
216 expectedOutput);
217 }
218
219 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
220 armnn::IWorkloadFactory& workloadFactory,
221 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
222 {
223 using namespace half_float::literal;
224
225 armnn::TensorInfo paramsInfo({ 5, 2 }, armnn::DataType::Float16);
226 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
227 armnn::TensorInfo outputInfo({ 3, 2 }, armnn::DataType::Float16);
228
229 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h, 9._h, 10._h });
230
231 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
232 const std::vector<T> expectedOutput = std::vector<T>({ 3._h, 4._h, 7._h, 8._h, 9._h, 10._h });
233
234 return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>(
235 workloadFactory,
236 memoryManager,
237 paramsInfo,
238 indicesInfo,
239 outputInfo,
240 params,
241 indices,
242 expectedOutput);
243 }
244
245 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
246 armnn::IWorkloadFactory& workloadFactory,
247 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
248 {
249 using namespace half_float::literal;
250
251 armnn::TensorInfo paramsInfo({ 3, 2, 3 }, armnn::DataType::Float16);
252 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
253 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, armnn::DataType::Float16);
254
255 const std::vector<T> params =
256 {
257 1._h, 2._h, 3._h,
258 4._h, 5._h, 6._h,
259
260 7._h, 8._h, 9._h,
261 10._h, 11._h, 12._h,
262
263 13._h, 14._h, 15._h,
264 16._h, 17._h, 18._h
265 };
266
267 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
268
269 const std::vector<T> expectedOutput =
270 {
271 7._h, 8._h, 9._h,
272 10._h, 11._h, 12._h,
273 13._h, 14._h, 15._h,
274 16._h, 17._h, 18._h,
275 7._h, 8._h, 9._h,
276 10._h, 11._h, 12._h,
277
278 13._h, 14._h, 15._h,
279 16._h, 17._h, 18._h,
280 7._h, 8._h, 9._h,
281 10._h, 11._h, 12._h,
282 1._h, 2._h, 3._h,
283 4._h, 5._h, 6._h
284 };
285
286 return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>(
287 workloadFactory,
288 memoryManager,
289 paramsInfo,
290 indicesInfo,
291 outputInfo,
292 params,
293 indices,
294 expectedOutput);
295 }
296};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100297
298} // anonymous namespace
299
Matthew Jackson9bff1442019-09-12 09:08:23 +0100300LayerTestResult<float, 1> Gather1dParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100301 armnn::IWorkloadFactory& workloadFactory,
302 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
303{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100304 return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
305}
306
307LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test(
308 armnn::IWorkloadFactory& workloadFactory,
309 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
310{
311 return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100312}
313
314LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test(
315 armnn::IWorkloadFactory& workloadFactory,
316 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
317{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100318 return GatherTestHelper<armnn::DataType::QuantisedAsymm8>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100319}
320
321LayerTestResult<int16_t, 1> Gather1dParamsInt16Test(
322 armnn::IWorkloadFactory& workloadFactory,
323 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
324{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100325 return GatherTestHelper<armnn::DataType::QuantisedSymm16>::Gather1dParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100326}
327
Matthew Jackson9bff1442019-09-12 09:08:23 +0100328LayerTestResult<float, 2> GatherMultiDimParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100329 armnn::IWorkloadFactory& workloadFactory,
330 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
331{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100332 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager);
333}
334
335LayerTestResult<armnn::Half, 2> GatherMultiDimParamsFloat16Test(
336 armnn::IWorkloadFactory& workloadFactory,
337 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
338{
339 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl(workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100340}
341
342LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test(
343 armnn::IWorkloadFactory& workloadFactory,
344 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
345{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100346 return GatherTestHelper<armnn::DataType::QuantisedAsymm8>::GatherMultiDimParamsTestImpl(
347 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100348}
349
350LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test(
351 armnn::IWorkloadFactory& workloadFactory,
352 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
353{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100354 return GatherTestHelper<armnn::DataType::QuantisedSymm16>::GatherMultiDimParamsTestImpl(
355 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100356}
357
Matthew Jackson9bff1442019-09-12 09:08:23 +0100358LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100359 armnn::IWorkloadFactory& workloadFactory,
360 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
361{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100362 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
363 workloadFactory, memoryManager);
364}
365
366LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test(
367 armnn::IWorkloadFactory& workloadFactory,
368 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
369{
370 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
371 workloadFactory, memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100372}
373
374LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test(
375 armnn::IWorkloadFactory& workloadFactory,
376 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
377{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100378 return GatherTestHelper<armnn::DataType::QuantisedAsymm8>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100379 workloadFactory, memoryManager);
380}
381
382LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test(
383 armnn::IWorkloadFactory& workloadFactory,
384 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
385{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100386 return GatherTestHelper<armnn::DataType::QuantisedSymm16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100387 workloadFactory, memoryManager);
388}