blob: 0118f5425731d49c214347a8aaabb9d66248f840 [file] [log] [blame]
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "GatherTestImpl.hpp"
7
8#include <ResolveType.hpp>
9
10#include <armnn/ArmNN.hpp>
11
12#include <backendsCommon/test/TensorCopyUtils.hpp>
13#include <backendsCommon/test/WorkloadTestUtils.hpp>
14
15#include <test/TensorHelpers.hpp>
16
17namespace
18{
19
20template <armnn::DataType ArmnnType,
21 typename T = armnn::ResolveType<ArmnnType>,
22 size_t ParamsDim,
23 size_t IndicesDim,
24 size_t OutputDim>
25LayerTestResult<T, OutputDim> GatherTestImpl(
26 armnn::IWorkloadFactory& workloadFactory,
27 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
28 const armnn::TensorInfo& paramsInfo,
29 const armnn::TensorInfo& indicesInfo,
30 const armnn::TensorInfo& outputInfo,
31 const std::vector<T>& paramsData,
32 const std::vector<int32_t>& indicesData,
33 const std::vector<T>& outputData)
34{
35 auto params = MakeTensor<T, ParamsDim>(paramsInfo, paramsData);
36 auto indices = MakeTensor<int32_t, IndicesDim>(indicesInfo, indicesData);
37
38 LayerTestResult<T, OutputDim> result(outputInfo);
39 result.outputExpected = MakeTensor<T, OutputDim>(outputInfo, outputData);
40
41 std::unique_ptr<armnn::ITensorHandle> paramsHandle = workloadFactory.CreateTensorHandle(paramsInfo);
42 std::unique_ptr<armnn::ITensorHandle> indicesHandle = workloadFactory.CreateTensorHandle(indicesInfo);
43 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputInfo);
44
45 armnn::GatherQueueDescriptor data;
46 armnn::WorkloadInfo info;
47 AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
48 AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
49 AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
50
51 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateGather(data, info);
52
53 paramsHandle->Allocate();
54 indicesHandle->Allocate();
55 outputHandle->Allocate();
56
57 CopyDataToITensorHandle(paramsHandle.get(), params.origin());
58 CopyDataToITensorHandle(indicesHandle.get(), indices.origin());
59
60 workload->Execute();
61
62 CopyDataFromITensorHandle(result.output.origin(), outputHandle.get());
63
64 return result;
65}
66
67template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
68LayerTestResult<T, 1> Gather1dParamsTestImpl(armnn::IWorkloadFactory& workloadFactory,
69 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
70{
71 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
72 armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
73 armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
74
75 if (armnn::IsQuantizedType<T>())
76 {
77 paramsInfo.SetQuantizationScale(1.0f);
78 paramsInfo.SetQuantizationOffset(1);
79 outputInfo.SetQuantizationScale(1.0f);
80 outputInfo.SetQuantizationOffset(1);
81 }
82 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
83 const std::vector<int32_t> indices = std::vector<int32_t>({ 0, 2, 1, 5 });
84 const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
85
86 return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
87 workloadFactory,
88 memoryManager,
89 paramsInfo,
90 indicesInfo,
91 outputInfo,
92 params,
93 indices,
94 expectedOutput);
95}
96
97template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
98LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
99 armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
100{
101 armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
102 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
103 armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
104
105 if (armnn::IsQuantizedType<T>())
106 {
107 paramsInfo.SetQuantizationScale(1.0f);
108 paramsInfo.SetQuantizationOffset(1);
109 outputInfo.SetQuantizationScale(1.0f);
110 outputInfo.SetQuantizationOffset(1);
111 }
112
113 const std::vector<T> params = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
114 const std::vector<int32_t> indices = std::vector<int32_t>({ 1, 3, 4 });
115 const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
116
117 return GatherTestImpl<ArmnnType, T, 2, 1, 2>(
118 workloadFactory,
119 memoryManager,
120 paramsInfo,
121 indicesInfo,
122 outputInfo,
123 params,
124 indices,
125 expectedOutput);
126}
127
128template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
129LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
130 armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
131{
132 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
133 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
134 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
135
136 if (armnn::IsQuantizedType<T>())
137 {
138 paramsInfo.SetQuantizationScale(1.0f);
139 paramsInfo.SetQuantizationOffset(1);
140 outputInfo.SetQuantizationScale(1.0f);
141 outputInfo.SetQuantizationOffset(1);
142 }
143
144 const std::vector<T> params =
145 {
146 1, 2, 3,
147 4, 5, 6,
148
149 7, 8, 9,
150 10, 11, 12,
151
152 13, 14, 15,
153 16, 17, 18
154 };
155
156 const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
157
158 const std::vector<T> expectedOutput =
159 {
160 7, 8, 9,
161 10, 11, 12,
162 13, 14, 15,
163 16, 17, 18,
164 7, 8, 9,
165 10, 11, 12,
166
167 13, 14, 15,
168 16, 17, 18,
169 7, 8, 9,
170 10, 11, 12,
171 1, 2, 3,
172 4, 5, 6
173 };
174
175 return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
176 workloadFactory,
177 memoryManager,
178 paramsInfo,
179 indicesInfo,
180 outputInfo,
181 params,
182 indices,
183 expectedOutput);
184}
185
186} // anonymous namespace
187
188LayerTestResult<float, 1> Gather1dParamsFloatTest(
189 armnn::IWorkloadFactory& workloadFactory,
190 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
191{
192 return Gather1dParamsTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager);
193}
194
195LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test(
196 armnn::IWorkloadFactory& workloadFactory,
197 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
198{
199 return Gather1dParamsTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager);
200}
201
202LayerTestResult<int16_t, 1> Gather1dParamsInt16Test(
203 armnn::IWorkloadFactory& workloadFactory,
204 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
205{
206 return Gather1dParamsTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager);
207}
208
209LayerTestResult<float, 2> GatherMultiDimParamsFloatTest(
210 armnn::IWorkloadFactory& workloadFactory,
211 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
212{
213 return GatherMultiDimParamsTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager);
214}
215
216LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test(
217 armnn::IWorkloadFactory& workloadFactory,
218 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
219{
220 return GatherMultiDimParamsTestImpl<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager);
221}
222
223LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test(
224 armnn::IWorkloadFactory& workloadFactory,
225 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
226{
227 return GatherMultiDimParamsTestImpl<armnn::DataType::QuantisedSymm16>(workloadFactory, memoryManager);
228}
229
230LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloatTest(
231 armnn::IWorkloadFactory& workloadFactory,
232 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
233{
234 return GatherMultiDimParamsMultiDimIndicesTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager);
235}
236
237LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test(
238 armnn::IWorkloadFactory& workloadFactory,
239 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
240{
241 return GatherMultiDimParamsMultiDimIndicesTestImpl<armnn::DataType::QuantisedAsymm8>(
242 workloadFactory, memoryManager);
243}
244
245LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test(
246 armnn::IWorkloadFactory& workloadFactory,
247 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
248{
249 return GatherMultiDimParamsMultiDimIndicesTestImpl<armnn::DataType::QuantisedSymm16>(
250 workloadFactory, memoryManager);
251}