blob: 7fabff6c1c65842b6c104ab84abc6811a0eca636 [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "GatherTestImpl.hpp"
7
8#include <ResolveType.hpp>
9
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010010
11#include <backendsCommon/test/TensorCopyUtils.hpp>
12#include <backendsCommon/test/WorkloadTestUtils.hpp>
13
14#include <test/TensorHelpers.hpp>
15
16namespace
17{
18
19template <armnn::DataType ArmnnType,
20 typename T = armnn::ResolveType<ArmnnType>,
21 size_t ParamsDim,
22 size_t IndicesDim,
23 size_t OutputDim>
24LayerTestResult<T, OutputDim> GatherTestImpl(
25 armnn::IWorkloadFactory& workloadFactory,
26 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +010027 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010028 const armnn::TensorInfo& paramsInfo,
29 const armnn::TensorInfo& indicesInfo,
30 const armnn::TensorInfo& outputInfo,
31 const std::vector<T>& paramsData,
32 const std::vector<int32_t>& indicesData,
33 const std::vector<T>& outputData)
34{
Jan Eilers8eb25602020-03-09 12:13:48 +000035 IgnoreUnused(memoryManager);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010036 auto params = MakeTensor<T, ParamsDim>(paramsInfo, paramsData);
37 auto indices = MakeTensor<int32_t, IndicesDim>(indicesInfo, indicesData);
38
39 LayerTestResult<T, OutputDim> result(outputInfo);
40 result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData);
41
Finn Williamsc43de6a2020-08-27 11:13:25 +010042 std::unique_ptr<armnn::ITensorHandle> paramsHandle = tensorHandleFactory.CreateTensorHandle(paramsInfo);
43 std::unique_ptr<armnn::ITensorHandle> indicesHandle = tensorHandleFactory.CreateTensorHandle(indicesInfo);
44 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010045
46 armnn::GatherQueueDescriptor data;
47 armnn::WorkloadInfo info;
48 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
49 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
50 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
51
52 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateGather(data, info);
53
54 paramsHandle->Allocate();
55 indicesHandle->Allocate();
56 outputHandle->Allocate();
57
58 CopyDataToITensorHandle(paramsHandle.get(), params.origin());
59 CopyDataToITensorHandle(indicesHandle.get(), indices.origin());
60
61 workload->Execute();
62
63 CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
64
65 return result;
66}
67
Matthew Jackson9bff1442019-09-12 09:08:23 +010068template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
69struct GatherTestHelper
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010070{
Matthew Jackson9bff1442019-09-12 09:08:23 +010071 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
72 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +010073 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
74 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010075 {
Matthew Jackson9bff1442019-09-12 09:08:23 +010076 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
77 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
78 armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010079
Matthew Jackson9bff1442019-09-12 09:08:23 +010080 if (armnn::IsQuantizedType<T>())
81 {
82 paramsInfo.SetQuantizationScale(1.0f);
83 paramsInfo.SetQuantizationOffset(1);
84 outputInfo.SetQuantizationScale(1.0f);
85 outputInfo.SetQuantizationOffset(1);
86 }
87 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
88 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
89 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010090
Matthew Jackson9bff1442019-09-12 09:08:23 +010091 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
92 workloadFactory,
93 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +010094 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +010095 paramsInfo,
96 indicesInfo,
97 outputInfo,
98 params,
99 indices,
100 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100101 }
102
Matthew Jackson9bff1442019-09-12 09:08:23 +0100103 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
104 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100105 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
106 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100107 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100108 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
109 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
110 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
111
112 if (armnn::IsQuantizedType<T>())
113 {
114 paramsInfo.SetQuantizationScale(1.0f);
115 paramsInfo.SetQuantizationOffset(1);
116 outputInfo.SetQuantizationScale(1.0f);
117 outputInfo.SetQuantizationOffset(1);
118 }
119
120 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
121 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
122 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
123
124 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(
125 workloadFactory,
126 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100127 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100128 paramsInfo,
129 indicesInfo,
130 outputInfo,
131 params,
132 indices,
133 expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100134 }
135
Matthew Jackson9bff1442019-09-12 09:08:23 +0100136 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
137 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100138 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
139 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100140 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100141 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
142 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
143 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100144
Matthew Jackson9bff1442019-09-12 09:08:23 +0100145 if (armnn::IsQuantizedType<T>())
146 {
147 paramsInfo.SetQuantizationScale(1.0f);
148 paramsInfo.SetQuantizationOffset(1);
149 outputInfo.SetQuantizationScale(1.0f);
150 outputInfo.SetQuantizationOffset(1);
151 }
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100152
Matthew Jackson9bff1442019-09-12 09:08:23 +0100153 const std::vector<T> params =
154 {
155 1, 2, 3,
156 4, 5, 6,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100157
Matthew Jackson9bff1442019-09-12 09:08:23 +0100158 7, 8, 9,
159 10, 11, 12,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100160
Matthew Jackson9bff1442019-09-12 09:08:23 +0100161 13, 14, 15,
162 16, 17, 18
163 };
164
165 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
166
167 const std::vector<T> expectedOutput =
168 {
169 7, 8, 9,
170 10, 11, 12,
171 13, 14, 15,
172 16, 17, 18,
173 7, 8, 9,
174 10, 11, 12,
175
176 13, 14, 15,
177 16, 17, 18,
178 7, 8, 9,
179 10, 11, 12,
180 1, 2, 3,
181 4, 5, 6
182 };
183
184 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
185 workloadFactory,
186 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100187 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100188 paramsInfo,
189 indicesInfo,
190 outputInfo,
191 params,
192 indices,
193 expectedOutput);
194 }
195};
196
197template<typename T>
198struct GatherTestHelper<armnn::DataType::Float16, T>
199{
200 static LayerTestResult<T, 1> Gather1dParamsTestImpl(
201 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100202 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
203 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100204 {
Matthew Jackson9bff1442019-09-12 09:08:23 +0100205 using namespace half_float::literal;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100206
Matthew Jackson9bff1442019-09-12 09:08:23 +0100207 armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::Float16);
208 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
209 armnn::TensorInfo outputInfo({ 4 }, armnn::DataType::Float16);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100210
Matthew Jackson9bff1442019-09-12 09:08:23 +0100211 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h });
212 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
213 const std::vector<T> expectedOutput = std::vector<T>({ 1._h, 3._h, 2._h, 6._h });
214
215 return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>(
216 workloadFactory,
217 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100218 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100219 paramsInfo,
220 indicesInfo,
221 outputInfo,
222 params,
223 indices,
224 expectedOutput);
225 }
226
227 static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
228 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100229 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
230 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100231 {
232 using namespace half_float::literal;
233
234 armnn::TensorInfo paramsInfo({ 5, 2 }, armnn::DataType::Float16);
235 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
236 armnn::TensorInfo outputInfo({ 3, 2 }, armnn::DataType::Float16);
237
238 const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h, 9._h, 10._h });
239
240 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
241 const std::vector<T> expectedOutput = std::vector<T>({ 3._h, 4._h, 7._h, 8._h, 9._h, 10._h });
242
243 return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>(
244 workloadFactory,
245 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100246 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100247 paramsInfo,
248 indicesInfo,
249 outputInfo,
250 params,
251 indices,
252 expectedOutput);
253 }
254
255 static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
256 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100257 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
258 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100259 {
260 using namespace half_float::literal;
261
262 armnn::TensorInfo paramsInfo({ 3, 2, 3 }, armnn::DataType::Float16);
263 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
264 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, armnn::DataType::Float16);
265
266 const std::vector<T> params =
267 {
268 1._h, 2._h, 3._h,
269 4._h, 5._h, 6._h,
270
271 7._h, 8._h, 9._h,
272 10._h, 11._h, 12._h,
273
274 13._h, 14._h, 15._h,
275 16._h, 17._h, 18._h
276 };
277
278 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
279
280 const std::vector<T> expectedOutput =
281 {
282 7._h, 8._h, 9._h,
283 10._h, 11._h, 12._h,
284 13._h, 14._h, 15._h,
285 16._h, 17._h, 18._h,
286 7._h, 8._h, 9._h,
287 10._h, 11._h, 12._h,
288
289 13._h, 14._h, 15._h,
290 16._h, 17._h, 18._h,
291 7._h, 8._h, 9._h,
292 10._h, 11._h, 12._h,
293 1._h, 2._h, 3._h,
294 4._h, 5._h, 6._h
295 };
296
297 return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>(
298 workloadFactory,
299 memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100300 tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100301 paramsInfo,
302 indicesInfo,
303 outputInfo,
304 params,
305 indices,
306 expectedOutput);
307 }
308};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100309
310} // anonymous namespace
311
Matthew Jackson9bff1442019-09-12 09:08:23 +0100312LayerTestResult<float, 1> Gather1dParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100313 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100314 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
315 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100316{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100317 return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl(
318 workloadFactory, memoryManager, tensorHandleFactory);
Matthew Jackson9bff1442019-09-12 09:08:23 +0100319}
320
321LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test(
322 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100323 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
324 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100325{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100326 return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl(
327 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100328}
329
330LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test(
331 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100332 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
333 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100334{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100335 return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl(
336 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100337}
338
339LayerTestResult<int16_t, 1> Gather1dParamsInt16Test(
340 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100341 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
342 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100343{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100344 return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl(
345 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100346}
347
Teresa Charlin93492462020-05-29 13:08:59 +0100348LayerTestResult<int32_t, 1> Gather1dParamsInt32Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100349 armnn::IWorkloadFactory& workloadFactory,
350 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
351 const armnn::ITensorHandleFactory& tensorHandleFactory)
Teresa Charlin93492462020-05-29 13:08:59 +0100352{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100353 return GatherTestHelper<armnn::DataType::Signed32>::Gather1dParamsTestImpl(
354 workloadFactory, memoryManager, tensorHandleFactory);
Teresa Charlin93492462020-05-29 13:08:59 +0100355}
356
Matthew Jackson9bff1442019-09-12 09:08:23 +0100357LayerTestResult<float, 2> GatherMultiDimParamsFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100358 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100359 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
360 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100361{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100362 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl(
363 workloadFactory, memoryManager, tensorHandleFactory);
Matthew Jackson9bff1442019-09-12 09:08:23 +0100364}
365
366LayerTestResult<armnn::Half, 2> GatherMultiDimParamsFloat16Test(
367 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100368 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
369 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100370{
Finn Williamsc43de6a2020-08-27 11:13:25 +0100371 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl(
372 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100373}
374
375LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test(
376 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100377 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
378 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100379{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000380 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100381 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100382}
383
384LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100385 armnn::IWorkloadFactory& workloadFactory,
386 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
387 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100388{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000389 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100390 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100391}
392
Teresa Charlin93492462020-05-29 13:08:59 +0100393LayerTestResult<int32_t, 2> GatherMultiDimParamsInt32Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100394 armnn::IWorkloadFactory& workloadFactory,
395 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
396 const armnn::ITensorHandleFactory& tensorHandleFactory)
Teresa Charlin93492462020-05-29 13:08:59 +0100397{
398 return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100399 workloadFactory, memoryManager, tensorHandleFactory);
Teresa Charlin93492462020-05-29 13:08:59 +0100400}
401
Matthew Jackson9bff1442019-09-12 09:08:23 +0100402LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloat32Test(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100403 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100404 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
405 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100406{
Matthew Jackson9bff1442019-09-12 09:08:23 +0100407 return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100408 workloadFactory, memoryManager, tensorHandleFactory);
Matthew Jackson9bff1442019-09-12 09:08:23 +0100409}
410
411LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test(
412 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100413 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
414 const armnn::ITensorHandleFactory& tensorHandleFactory)
Matthew Jackson9bff1442019-09-12 09:08:23 +0100415{
416 return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100417 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100418}
419
420LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test(
421 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100422 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
423 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100424{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000425 return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100426 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100427}
428
429LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100430 armnn::IWorkloadFactory& workloadFactory,
431 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
432 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100433{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000434 return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100435 workloadFactory, memoryManager, tensorHandleFactory);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100436}
Teresa Charlin93492462020-05-29 13:08:59 +0100437
438LayerTestResult<int32_t, 4> GatherMultiDimParamsMultiDimIndicesInt32Test(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100439 armnn::IWorkloadFactory& workloadFactory,
440 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
441 const armnn::ITensorHandleFactory& tensorHandleFactory)
Teresa Charlin93492462020-05-29 13:08:59 +0100442{
443 return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
Finn Williamsc43de6a2020-08-27 11:13:25 +0100444 workloadFactory, memoryManager, tensorHandleFactory);
Teresa Charlin93492462020-05-29 13:08:59 +0100445}