blob: 57a30c6f333e6157ed7a415e58996b5ff9e3ed76 [file] [log] [blame]
Teresa Charlinb2d3ec52022-04-12 22:07:09 +01001//
2// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "GatherNdTestImpl.hpp"
7
8#include <DataTypeUtils.hpp>
9#include <armnnTestUtils/TensorCopyUtils.hpp>
10#include <armnnTestUtils/WorkloadTestUtils.hpp>
11
12namespace
13{
14
15template<armnn::DataType ArmnnType,
16 typename T = armnn::ResolveType<ArmnnType>,
17 size_t ParamsDim,
18 size_t IndicesDim,
19 size_t OutputDim>
20LayerTestResult<T, OutputDim> GatherNdTestImpl(
21 armnn::IWorkloadFactory &workloadFactory,
22 const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager,
23 const armnn::ITensorHandleFactory &tensorHandleFactory,
24 const armnn::TensorInfo &paramsInfo,
25 const armnn::TensorInfo &indicesInfo,
26 const armnn::TensorInfo &outputInfo,
27 const std::vector<T> &paramsData,
28 const std::vector<int32_t> &indicesData,
29 const std::vector<T> &outputData)
30{
31 IgnoreUnused(memoryManager);
32
33 std::vector<T> actualOutput(outputInfo.GetNumElements());
34
35 std::unique_ptr<armnn::ITensorHandle> paramsHandle = tensorHandleFactory.CreateTensorHandle(paramsInfo);
36 std::unique_ptr<armnn::ITensorHandle> indicesHandle = tensorHandleFactory.CreateTensorHandle(indicesInfo);
37 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
38
39 armnn::GatherNdQueueDescriptor data;
40 armnn::WorkloadInfo info;
41 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
42 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
43 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
44
45 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::GatherNd,
46 data,
47 info);
48
49 paramsHandle->Allocate();
50 indicesHandle->Allocate();
51 outputHandle->Allocate();
52
53 CopyDataToITensorHandle(paramsHandle.get(), paramsData.data());
54 CopyDataToITensorHandle(indicesHandle.get(), indicesData.data());
55
56 workload->Execute();
57
58 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
59
60 return LayerTestResult<T, OutputDim>(actualOutput,
61 outputData,
62 outputHandle->GetShape(),
63 outputInfo.GetShape());
64}
65} // anonymous namespace
66
67template<armnn::DataType ArmnnType, typename T>
68LayerTestResult<T, 2> SimpleGatherNd2dTest(
69 armnn::IWorkloadFactory& workloadFactory,
70 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
71 const armnn::ITensorHandleFactory& tensorHandleFactory)
72{
73 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
74 armnn::TensorInfo indicesInfo({ 3, 1 }, armnn::DataType::Signed32);
75 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
76 if (armnn::IsQuantizedType<T>())
77 {
78 paramsInfo.SetQuantizationScale(1.0f);
79 paramsInfo.SetQuantizationOffset(1);
80 outputInfo.SetQuantizationScale(1.0f);
81 outputInfo.SetQuantizationOffset(1);
82 }
83 const std::vector<T> params = ConvertToDataType<ArmnnType>(
84 { 1, 2,
85 3, 4,
86 5, 6,
87 7, 8,
88 9, 10},
89 paramsInfo);
90 const std::vector<int32_t> indices = ConvertToDataType<armnn::DataType::Signed32>(
91 { 1, 0, 4},
92 indicesInfo);
93 const std::vector<T> expectedOutput = ConvertToDataType<ArmnnType>(
94 { 3, 4,
95 1, 2,
96 9, 10},
97 outputInfo);
98 return GatherNdTestImpl<ArmnnType, T, 2, 2, 2>(
99 workloadFactory,
100 memoryManager,
101 tensorHandleFactory,
102 paramsInfo,
103 indicesInfo,
104 outputInfo,
105 params,
106 indices,
107 expectedOutput);
108}
109
110template<armnn::DataType ArmnnType, typename T>
111LayerTestResult<T, 3> SimpleGatherNd3dTest(
112 armnn::IWorkloadFactory& workloadFactory,
113 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
114 const armnn::ITensorHandleFactory& tensorHandleFactory)
115{
116 armnn::TensorInfo paramsInfo({ 2, 3, 8, 4 }, ArmnnType);
117 armnn::TensorInfo indicesInfo({ 2, 2 }, armnn::DataType::Signed32);
118 armnn::TensorInfo outputInfo({ 2, 8, 4 }, ArmnnType);
119
120 if (armnn::IsQuantizedType<T>())
121 {
122 paramsInfo.SetQuantizationScale(1.0f);
123 paramsInfo.SetQuantizationOffset(0);
124 outputInfo.SetQuantizationScale(1.0f);
125 outputInfo.SetQuantizationOffset(0);
126 }
127 const std::vector<T> params = ConvertToDataType<ArmnnType>(
128 { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
129 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
130
131 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
132 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
133
134 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
135 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
136
137 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
138 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
139
140 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
141 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
142
143 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
144 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191 },
145 paramsInfo);
146
147 const std::vector<int32_t> indices = ConvertToDataType<armnn::DataType::Signed32>(
148 { 1, 2, 1, 1},
149 indicesInfo);
150
151 const std::vector<T> expectedOutput = ConvertToDataType<ArmnnType>(
152 { 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
153 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
154
155 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
156 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159},
157 outputInfo);
158
159 return GatherNdTestImpl<ArmnnType, T, 4, 2, 3>(
160 workloadFactory,
161 memoryManager,
162 tensorHandleFactory,
163 paramsInfo,
164 indicesInfo,
165 outputInfo,
166 params,
167 indices,
168 expectedOutput);
169}
170
171template<armnn::DataType ArmnnType, typename T>
172LayerTestResult<T, 4> SimpleGatherNd4dTest(
173 armnn::IWorkloadFactory& workloadFactory,
174 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
175 const armnn::ITensorHandleFactory& tensorHandleFactory)
176{
177 armnn::TensorInfo paramsInfo({ 5, 5, 2 }, ArmnnType);
178 armnn::TensorInfo indicesInfo({ 2, 2, 3, 2 }, armnn::DataType::Signed32);
179 armnn::TensorInfo outputInfo({ 2, 2, 3, 2 }, ArmnnType);
180
181 if (armnn::IsQuantizedType<T>())
182 {
183 paramsInfo.SetQuantizationScale(1.0f);
184 paramsInfo.SetQuantizationOffset(0);
185 outputInfo.SetQuantizationScale(1.0f);
186 outputInfo.SetQuantizationOffset(0);
187 }
188 const std::vector<T> params = ConvertToDataType<ArmnnType>(
189 { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
190 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
191 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
192 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
193 40, 41, 42, 43, 44, 45, 46, 47, 48, 49 },
194 paramsInfo);
195
196 const std::vector<int32_t> indices = ConvertToDataType<armnn::DataType::Signed32>(
197 { 0, 0,
198 3, 3,
199 4, 4,
200
201 0, 0,
202 1, 1,
203 2, 2,
204
205 4, 4,
206 3, 3,
207 0, 0,
208
209 2, 2,
210 1, 1,
211 0, 0 },
212 indicesInfo);
213
214 const std::vector<T> expectedOutput = ConvertToDataType<ArmnnType>(
215 { 0, 1,
216 36, 37,
217 48, 49,
218
219 0, 1,
220 12, 13,
221 24, 25,
222
223 48, 49,
224 36, 37,
225 0, 1,
226
227 24, 25,
228 12, 13,
229 0, 1 },
230 outputInfo);
231
232 return GatherNdTestImpl<ArmnnType, T, 3, 4, 4>(
233 workloadFactory,
234 memoryManager,
235 tensorHandleFactory,
236 paramsInfo,
237 indicesInfo,
238 outputInfo,
239 params,
240 indices,
241 expectedOutput);
242}
243
244//
245// Explicit template specializations
246//
247
248template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
249SimpleGatherNd2dTest<armnn::DataType::Float32>(
250 armnn::IWorkloadFactory& workloadFactory,
251 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
252 const armnn::ITensorHandleFactory& tensorHandleFactory);
253
254template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
255SimpleGatherNd3dTest<armnn::DataType::Float32>(
256 armnn::IWorkloadFactory& workloadFactory,
257 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
258 const armnn::ITensorHandleFactory& tensorHandleFactory);
259
260template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
261SimpleGatherNd4dTest<armnn::DataType::Float32>(
262 armnn::IWorkloadFactory& workloadFactory,
263 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
264 const armnn::ITensorHandleFactory& tensorHandleFactory);
265
266template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
267SimpleGatherNd2dTest<armnn::DataType::QAsymmS8>(
268 armnn::IWorkloadFactory& workloadFactory,
269 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
270 const armnn::ITensorHandleFactory& tensorHandleFactory);
271
272template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
273SimpleGatherNd3dTest<armnn::DataType::QAsymmS8>(
274 armnn::IWorkloadFactory& workloadFactory,
275 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
276 const armnn::ITensorHandleFactory& tensorHandleFactory);
277
278template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
279SimpleGatherNd4dTest<armnn::DataType::QAsymmS8>(
280 armnn::IWorkloadFactory& workloadFactory,
281 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
282 const armnn::ITensorHandleFactory& tensorHandleFactory);
283
284template LayerTestResult<armnn::ResolveType<armnn::DataType::Signed32>, 2>
285SimpleGatherNd2dTest<armnn::DataType::Signed32>(
286 armnn::IWorkloadFactory& workloadFactory,
287 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
288 const armnn::ITensorHandleFactory& tensorHandleFactory);
289
290template LayerTestResult<armnn::ResolveType<armnn::DataType::Signed32>, 3>
291SimpleGatherNd3dTest<armnn::DataType::Signed32>(
292 armnn::IWorkloadFactory& workloadFactory,
293 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
294 const armnn::ITensorHandleFactory& tensorHandleFactory);
295
296template LayerTestResult<armnn::ResolveType<armnn::DataType::Signed32>, 4>
297SimpleGatherNd4dTest<armnn::DataType::Signed32>(
298 armnn::IWorkloadFactory& workloadFactory,
299 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
300 const armnn::ITensorHandleFactory& tensorHandleFactory);