blob: 4434b0f0c701a1aebfe1190a28ccf4fea51789f5 [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
Nikhil Raj369d8fc2022-11-24 13:12:36 +00002// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "GatherTestImpl.hpp"
7
8#include <ResolveType.hpp>
9
Sadik Armagana097d2a2021-11-24 15:47:28 +000010#include <armnnTestUtils/TensorCopyUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000011#include <armnnTestUtils/WorkloadTestUtils.hpp>
Colm Donelanc42a9872022-02-02 16:35:09 +000012#include <armnnTestUtils/TensorHelpers.hpp>
Nikhil Raj369d8fc2022-11-24 13:12:36 +000013#include <utility>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010014
15namespace
16{
17
18template <armnn::DataType ArmnnType,
19 typename T = armnn::ResolveType<ArmnnType>,
20 size_t ParamsDim,
21 size_t IndicesDim,
22 size_t OutputDim>
23LayerTestResult<T, OutputDim> GatherTestImpl(
24 armnn::IWorkloadFactory& workloadFactory,
25 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +010026 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010027 const armnn::TensorInfo& paramsInfo,
28 const armnn::TensorInfo& indicesInfo,
29 const armnn::TensorInfo& outputInfo,
30 const std::vector<T>& paramsData,
31 const std::vector<int32_t>& indicesData,
Nikhil Raj369d8fc2022-11-24 13:12:36 +000032 const std::vector<T>& outputData,
33 armnn::GatherDescriptor descriptor= armnn::GatherDescriptor())
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010034{
Jan Eilers8eb25602020-03-09 12:13:48 +000035 IgnoreUnused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010036
Sadik Armagan483c8112021-06-01 09:24:52 +010037 std::vector<T> actualOutput(outputInfo.GetNumElements());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010038
Finn Williamsc43de6a2020-08-27 11:13:25 +010039 std::unique_ptr<armnn::ITensorHandle> paramsHandle = tensorHandleFactory.CreateTensorHandle(paramsInfo);
40 std::unique_ptr<armnn::ITensorHandle> indicesHandle = tensorHandleFactory.CreateTensorHandle(indicesInfo);
41 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010042
43 armnn::GatherQueueDescriptor data;
Nikhil Raj369d8fc2022-11-24 13:12:36 +000044 data.m_Parameters = std::move(descriptor);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010045 armnn::WorkloadInfo info;
46 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
47 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
48 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
49
Teresa Charlin611c7fb2022-01-07 09:47:29 +000050 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Gather, data, info);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010051
52 paramsHandle->Allocate();
53 indicesHandle->Allocate();
54 outputHandle->Allocate();
55
Sadik Armagan483c8112021-06-01 09:24:52 +010056 CopyDataToITensorHandle(paramsHandle.get(), paramsData.data());
57 CopyDataToITensorHandle(indicesHandle.get(), indicesData.data());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010058
59 workload->Execute();
60
Sadik Armagan483c8112021-06-01 09:24:52 +010061 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010062
Sadik Armagan483c8112021-06-01 09:24:52 +010063 return LayerTestResult<T, OutputDim>(actualOutput,
64 outputData,
65 outputHandle->GetShape(),
66 outputInfo.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010067}
68
Matthew Jackson9bff1442019-09-12 09:08:23 +010069template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
70struct GatherTestHelper
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010071{
Matthew Jackson9bff1442019-09-12 09:08:23 +010072 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
73 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +010074 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
75 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010076 {
Matthew Jackson9bff1442019-09-12 09:08:23 +010077 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
78 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
79 armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010080
Matthew Jackson9bff1442019-09-12 09:08:23 +010081 if (armnn::IsQuantizedType<T>())
82 {
83 paramsInfo.SetQuantizationScale(1.0f);
84 paramsInfo.SetQuantizationOffset(1);
85 outputInfo.SetQuantizationScale(1.0f);
86 outputInfo.SetQuantizationOffset(1);
87 }
88 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
89 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
90 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010091
Matthew Jackson9bff1442019-09-12 09:08:23 +010092 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
93 workloadFactory,
94 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +010095 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +010096 paramsInfo,
97 indicesInfo,
98 outputInfo,
99 params,
100 indices,
101 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100102 }
103
Nikhil Raj369d8fc2022-11-24 13:12:36 +0000104 static LayerTestResult<T, 1> Gather1dParamsAxisTestImpl(
105 armnn::IWorkloadFactory& workloadFactory,
106 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
107 const armnn::ITensorHandleFactory& tensorHandleFactory)
108 {
109 armnn::GatherDescriptor descriptor;
110 descriptor.m_Axis=1;
111 armnn::TensorInfo paramsInfo({ 4, 3 }, ArmnnType);
112 armnn::TensorInfo indicesInfo({ 2 }, armnn::DataType::Signed32);
113 armnn::TensorInfo outputInfo({ 4, 2 }, ArmnnType);
114
115 if (armnn::IsQuantizedType<T>())
116 {
117 paramsInfo.SetQuantizationScale(1.0f);
118 paramsInfo.SetQuantizationOffset(1);
119 outputInfo.SetQuantizationScale(1.0f);
120 outputInfo.SetQuantizationOffset(1);
121 }
122 const std::vector<T> params ={ 10, 11, 12,
123 110, 111, 112,
124 120, 121, 122,
125 130, 131, 132 };
126 const std::vector<int32_t> indices = std::vector<int32_t>({ 2, 1 });
127 const std::vector<T> expectedOutput = { 12, 11,
128 112, 111,
129 122, 121,
130 132, 131 } ;
131
132 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
133 workloadFactory,
134 memoryManager,
135 tensorHandleFactory,
136 paramsInfo,
137 indicesInfo,
138 outputInfo,
139 params,
140 indices,
141 expectedOutput,
142 descriptor);
143 }
144
Matthew Jackson9bff1442019-09-12 09:08:23 +0100145 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
146 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100147 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
148 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100149 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100150 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
151 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
152 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
153
154 if (armnn::IsQuantizedType<T>())
155 {
156 paramsInfo.SetQuantizationScale(1.0f);
157 paramsInfo.SetQuantizationOffset(1);
158 outputInfo.SetQuantizationScale(1.0f);
159 outputInfo.SetQuantizationOffset(1);
160 }
161
162 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
163 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
164 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
165
166 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(
167 workloadFactory,
168 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100169 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100170 paramsInfo,
171 indicesInfo,
172 outputInfo,
173 params,
174 indices,
175 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100176 }
177
Matthew Jackson9bff1442019-09-12 09:08:23 +0100178 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
179 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100180 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
181 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100182 {
Nikhil Raj369d8fc2022-11-24 13:12:36 +0000183 armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
Matthew Jackson9bff1442019-09-12 09:08:23 +0100184 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
185 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100186
Matthew Jackson9bff1442019-09-12 09:08:23 +0100187 if (armnn::IsQuantizedType<T>())
188 {
189 paramsInfo.SetQuantizationScale(1.0f);
190 paramsInfo.SetQuantizationOffset(1);
191 outputInfo.SetQuantizationScale(1.0f);
192 outputInfo.SetQuantizationOffset(1);
193 }
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100194
Matthew Jackson9bff1442019-09-12 09:08:23 +0100195 const std::vector<T> params =
196 {
197 1, 2, 3,
198 4, 5, 6,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100199
Matthew Jackson9bff1442019-09-12 09:08:23 +0100200 7, 8, 9,
201 10, 11, 12,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100202
Matthew Jackson9bff1442019-09-12 09:08:23 +0100203 13, 14, 15,
204 16, 17, 18
205 };
206
207 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
208
209 const std::vector<T> expectedOutput =
210 {
211 7, 8, 9,
212 10, 11, 12,
213 13, 14, 15,
214 16, 17, 18,
215 7, 8, 9,
216 10, 11, 12,
217
218 13, 14, 15,
219 16, 17, 18,
220 7, 8, 9,
221 10, 11, 12,
222 1, 2, 3,
223 4, 5, 6
224 };
225
226 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
227 workloadFactory,
228 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100229 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100230 paramsInfo,
231 indicesInfo,
232 outputInfo,
233 params,
234 indices,
235 expectedOutput);
236 }
Nikhil Raj369d8fc2022-11-24 13:12:36 +0000237
238 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesAxis1TestImpl(
239 armnn::IWorkloadFactory& workloadFactory,
240 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
241 const armnn::ITensorHandleFactory& tensorHandleFactory)
242 {
243 armnn::GatherDescriptor descriptor;
244 descriptor.m_Axis=1;
245 armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
246 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
247 armnn::TensorInfo outputInfo({ 3, 2, 3, 3 }, ArmnnType);
248
249 if (armnn::IsQuantizedType<T>())
250 {
251 paramsInfo.SetQuantizationScale(1.0f);
252 paramsInfo.SetQuantizationOffset(1);
253 outputInfo.SetQuantizationScale(1.0f);
254 outputInfo.SetQuantizationOffset(1);
255 }
256
257 const std::vector<T> params =
258 {
259 1, 2, 3,
260 4, 5, 6,
261
262 7, 8, 9,
263 10, 11, 12,
264
265 13, 14, 15,
266 16, 17, 18
267 };
268
269 const std::vector<int32_t> indices = { 1, 0, 1, 0, 1, 0 };
270
271 const std::vector<T> expectedOutput =
272 {
273 4, 5, 6,
274 1, 2, 3,
275 4, 5, 6,
276
277 1, 2, 3,
278 4, 5, 6,
279 1, 2, 3,
280
281 10, 11, 12,
282 7, 8, 9,
283 10, 11, 12,
284
285 7, 8, 9,
286 10, 11, 12,
287 7, 8, 9,
288
289 16, 17, 18,
290 13, 14, 15,
291 16, 17, 18,
292
293 13, 14, 15,
294 16, 17, 18,
295 13, 14, 15
296 };
297
298 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
299 workloadFactory,
300 memoryManager,
301 tensorHandleFactory,
302 paramsInfo,
303 indicesInfo,
304 outputInfo,
305 params,
306 indices,
307 expectedOutput,
308 descriptor);
309 }
310
311 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesAxis2TestImpl(
312 armnn::IWorkloadFactory& workloadFactory,
313 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
314 const armnn::ITensorHandleFactory& tensorHandleFactory)
315 {
316 armnn::GatherDescriptor descriptor;
317 descriptor.m_Axis=2;
318 armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
319 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
320 armnn::TensorInfo outputInfo({ 3, 2, 2, 3 }, ArmnnType);
321
322 if (armnn::IsQuantizedType<T>())
323 {
324 paramsInfo.SetQuantizationScale(1.0f);
325 paramsInfo.SetQuantizationOffset(1);
326 outputInfo.SetQuantizationScale(1.0f);
327 outputInfo.SetQuantizationOffset(1);
328 }
329
330 const std::vector<T> params =
331 {
332 1, 2, 3,
333 4, 5, 6,
334
335 7, 8, 9,
336 10, 11, 12,
337
338 13, 14, 15,
339 16, 17, 18
340 };
341
342 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
343
344 const std::vector<T> expectedOutput =
345 {
346 2, 3, 2,
347 3, 2, 1,
348
349 5, 6, 5,
350 6, 5, 4,
351
352 8, 9, 8,
353 9, 8, 7,
354
355 11, 12, 11,
356 12, 11, 10,
357
358 14, 15, 14,
359 15, 14, 13,
360
361 17, 18, 17,
362 18, 17, 16
363 };
364
365 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
366 workloadFactory,
367 memoryManager,
368 tensorHandleFactory,
369 paramsInfo,
370 indicesInfo,
371 outputInfo,
372 params,
373 indices,
374 expectedOutput,
375 descriptor);
376 }
Matthew Jackson9bff1442019-09-12 09:08:23 +0100377};
378
379template<typename T>
380struct GatherTestHelper<armnn::DataType::Float16, T>
381{
382 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
383 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100384 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
385 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100386 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100387 using namespace half_float::literal;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100388
Matthew Jackson9bff1442019-09-12 09:08:23 +0100389 armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::Float16);
390 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
391 armnn::TensorInfo outputInfo({ 4 }, armnn::DataType::Float16);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100392
Matthew Jackson9bff1442019-09-12 09:08:23 +0100393 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h });
394 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
395 const std::vector<T> expectedOutput = std::vector<T>({ 1._h, 3._h, 2._h, 6._h });
396
397 return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>(
398 workloadFactory,
399 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100400 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100401 paramsInfo,
402 indicesInfo,
403 outputInfo,
404 params,
405 indices,
406 expectedOutput);
407 }
408
409 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
410 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100411 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
412 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100413 {
414 using namespace half_float::literal;
415
416 armnn::TensorInfo paramsInfo({ 5, 2 }, armnn::DataType::Float16);
417 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
418 armnn::TensorInfo outputInfo({ 3, 2 }, armnn::DataType::Float16);
419
420 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h, 9._h, 10._h });
421
422 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
423 const std::vector<T> expectedOutput = std::vector<T>({ 3._h, 4._h, 7._h, 8._h, 9._h, 10._h });
424
425 return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>(
426 workloadFactory,
427 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100428 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100429 paramsInfo,
430 indicesInfo,
431 outputInfo,
432 params,
433 indices,
434 expectedOutput);
435 }
436
437 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
438 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100439 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
440 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100441 {
442 using namespace half_float::literal;
443
444 armnn::TensorInfo paramsInfo({ 3, 2, 3 }, armnn::DataType::Float16);
445 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
446 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, armnn::DataType::Float16);
447
448 const std::vector<T> params =
449 {
450 1._h, 2._h, 3._h,
451 4._h, 5._h, 6._h,
452
453 7._h, 8._h, 9._h,
454 10._h, 11._h, 12._h,
455
456 13._h, 14._h, 15._h,
457 16._h, 17._h, 18._h
458 };
459
460 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
461
462 const std::vector<T> expectedOutput =
463 {
464 7._h, 8._h, 9._h,
465 10._h, 11._h, 12._h,
466 13._h, 14._h, 15._h,
467 16._h, 17._h, 18._h,
468 7._h, 8._h, 9._h,
469 10._h, 11._h, 12._h,
470
471 13._h, 14._h, 15._h,
472 16._h, 17._h, 18._h,
473 7._h, 8._h, 9._h,
474 10._h, 11._h, 12._h,
475 1._h, 2._h, 3._h,
476 4._h, 5._h, 6._h
477 };
478
479 return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>(
480 workloadFactory,
481 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100482 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100483 paramsInfo,
484 indicesInfo,
485 outputInfo,
486 params,
487 indices,
488 expectedOutput);
489 }
490};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100491
492} // anonymous namespace
493
Matthew Jackson9bff1442019-09-12 09:08:23 +0100494LayerTestResult<float, 1> Gather1dParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100495 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100496 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
497 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100498{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100499 return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl(
500 workloadFactory, memoryManager, tensorHandleFactory);
Matthew Jackson9bff1442019-09-12 09:08:23 +0100501}
502
Nikhil Raj369d8fc2022-11-24 13:12:36 +0000503LayerTestResult<float, 1> Gather1dParamsAxisTest(
504 armnn::IWorkloadFactory& workloadFactory,
505 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
506 const armnn::ITensorHandleFactory& tensorHandleFactory)
507{
508 return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsAxisTestImpl(
509 workloadFactory, memoryManager, tensorHandleFactory);
510}
511
Matthew Jackson9bff1442019-09-12 09:08:23 +0100512LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test(
513 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100514 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
515 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100516{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100517 return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl(
518 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100519}
520
521LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test(
522 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100523 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
524 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100525{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100526 return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl(
527 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100528}
529
530LayerTestResult<int16_t, 1> Gather1dParamsInt16Test(
531 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100532 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
533 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100534{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100535 return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl(
536 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100537}
538
Teresa Charlin93492462020-05-29 13:08:59 +0100539LayerTestResult<int32_t, 1> Gather1dParamsInt32Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100540 armnn::IWorkloadFactory& workloadFactory,
541 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
542 const armnn::ITensorHandleFactory& tensorHandleFactory)
Teresa Charlin93492462020-05-29 13:08:59 +0100543{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100544 return GatherTestHelper<armnn::DataType::Signed32>::Gather1dParamsTestImpl(
545 workloadFactory, memoryManager, tensorHandleFactory);
Teresa Charlin93492462020-05-29 13:08:59 +0100546}
547
Matthew Jackson9bff1442019-09-12 09:08:23 +0100548LayerTestResult<float, 2> GatherMultiDimParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100549 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100550 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
551 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100552{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100553 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl(
554 workloadFactory, memoryManager, tensorHandleFactory);
Matthew Jackson9bff1442019-09-12 09:08:23 +0100555}
556
557LayerTestResult<armnn::Half, 2> GatherMultiDimParamsFloat16Test(
558 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100559 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
560 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100561{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100562 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl(
563 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100564}
565
566LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test(
567 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100568 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
569 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100570{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000571 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100572 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100573}
574
575LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100576 armnn::IWorkloadFactory& workloadFactory,
577 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
578 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100579{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000580 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100581 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100582}
583
Teresa Charlin93492462020-05-29 13:08:59 +0100584LayerTestResult<int32_t, 2> GatherMultiDimParamsInt32Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100585 armnn::IWorkloadFactory& workloadFactory,
586 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
587 const armnn::ITensorHandleFactory& tensorHandleFactory)
Teresa Charlin93492462020-05-29 13:08:59 +0100588{
589 return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100590 workloadFactory, memoryManager, tensorHandleFactory);
Teresa Charlin93492462020-05-29 13:08:59 +0100591}
592
Matthew Jackson9bff1442019-09-12 09:08:23 +0100593LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100594 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100595 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
596 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100597{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100598 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100599 workloadFactory, memoryManager, tensorHandleFactory);
Matthew Jackson9bff1442019-09-12 09:08:23 +0100600}
601
Nikhil Raj369d8fc2022-11-24 13:12:36 +0000602LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesAxis1Test(
603 armnn::IWorkloadFactory& workloadFactory,
604 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
605 const armnn::ITensorHandleFactory& tensorHandleFactory)
606{
607 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesAxis1TestImpl(
608 workloadFactory, memoryManager, tensorHandleFactory);
609}
610
611LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesAxis2Test(
612 armnn::IWorkloadFactory& workloadFactory,
613 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
614 const armnn::ITensorHandleFactory& tensorHandleFactory)
615{
616 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesAxis2TestImpl(
617 workloadFactory, memoryManager, tensorHandleFactory);
618}
619
Matthew Jackson9bff1442019-09-12 09:08:23 +0100620LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test(
621 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100622 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
623 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100624{
625 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100626 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100627}
628
629LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test(
630 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100631 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
632 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100633{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000634 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100635 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100636}
637
638LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100639 armnn::IWorkloadFactory& workloadFactory,
640 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
641 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100642{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000643 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100644 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100645}
Teresa Charlin93492462020-05-29 13:08:59 +0100646
647LayerTestResult<int32_t, 4> GatherMultiDimParamsMultiDimIndicesInt32Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100648 armnn::IWorkloadFactory& workloadFactory,
649 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
650 const armnn::ITensorHandleFactory& tensorHandleFactory)
Teresa Charlin93492462020-05-29 13:08:59 +0100651{
652 return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100653 workloadFactory, memoryManager, tensorHandleFactory);
Teresa Charlin93492462020-05-29 13:08:59 +0100654}