blob: 82de11ae210142b8c98c1709c7eef036dd20c4f5 [file] [log] [blame]
Finn Williams2605b232020-06-10 15:53:46 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RankTestImpl.hpp"
7
8#include <backendsCommon/test/DataTypeUtils.hpp>
9#include <backendsCommon/test/TensorCopyUtils.hpp>
10#include <backendsCommon/test/WorkloadTestUtils.hpp>
11
12#include <test/TensorHelpers.hpp>
13
14template<typename T, std::size_t n>
15LayerTestResult<int32_t, 1> RankTest(
16 armnn::TensorInfo inputTensorInfo,
17 boost::multi_array<T, n> input,
18 armnn::IWorkloadFactory& workloadFactory,
19 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
20{
21 IgnoreUnused(memoryManager);
22
23 const armnn::TensorShape outputShape{armnn::Dimensionality::Scalar};
24 armnn::TensorInfo outputTensorInfo(outputShape, armnn::DataType::Signed32);
25
26 LayerTestResult<int32_t , 1> ret(outputTensorInfo);
27 ret.outputExpected = MakeTensor<uint32_t, 1>(outputTensorInfo, { n });
28
29 std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
30 std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
31
32 armnn::RankQueueDescriptor data;
33 armnn::WorkloadInfo info;
34 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
35 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
36
37 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateRank(data, info);
38
39 inputHandle->Allocate();
40 outputHandle->Allocate();
41
42 CopyDataToITensorHandle(inputHandle.get(), input.origin());
43
44 workload->Execute();
45
46 CopyDataFromITensorHandle(&ret.output[0], outputHandle.get());
47
48 return ret;
49}
50
51template<armnn::DataType ArmnnType, typename T>
52LayerTestResult<int32_t, 1> RankDimSize1Test(
53 armnn::IWorkloadFactory& workloadFactory,
54 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
55{
56 armnn::TensorInfo inputTensorInfo({6}, ArmnnType, 1.0f, 0);
57 auto input = MakeTensor<T, 1>(inputTensorInfo, ConvertToDataType<ArmnnType>(
58 { -37.5f, -15.2f, -8.76f, -2.0f, -1.3f, -0.5f },
59 inputTensorInfo));
60
61 return RankTest<T, 1>(inputTensorInfo, input, workloadFactory, memoryManager);
62}
63
64template<armnn::DataType ArmnnType, typename T>
65LayerTestResult<int32_t, 1> RankDimSize2Test(
66 armnn::IWorkloadFactory& workloadFactory,
67 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
68{
69 armnn::TensorInfo inputTensorInfo({1, 3}, ArmnnType, 1.0f, 0);
70 auto input = MakeTensor<T, 2>(inputTensorInfo, ConvertToDataType<ArmnnType>(
71 { -37.5f, -15.2f, -8.76f },
72 inputTensorInfo));
73
74 return RankTest<T, 2>(inputTensorInfo, input, workloadFactory, memoryManager);
75}
76
77template<armnn::DataType ArmnnType, typename T>
78LayerTestResult<int32_t, 1> RankDimSize3Test(
79 armnn::IWorkloadFactory& workloadFactory,
80 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
81{
82 armnn::TensorInfo inputTensorInfo({1, 3, 2}, ArmnnType, 1.0f, 0);
83 auto input = MakeTensor<T, 3>(inputTensorInfo, ConvertToDataType<ArmnnType>(
84 { -37.5f, -15.2f, -8.76f, -2.0f, -1.5f, -1.3f},
85 inputTensorInfo));
86
87 return RankTest<T, 3>(inputTensorInfo, input, workloadFactory, memoryManager);
88}
89
90template<armnn::DataType ArmnnType, typename T>
91LayerTestResult<int32_t, 1> RankDimSize4Test(
92 armnn::IWorkloadFactory& workloadFactory,
93 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
94{
95 armnn::TensorInfo inputTensorInfo({1, 3, 2, 3}, ArmnnType, 1.0f, 0);
96 auto input = MakeTensor<T, 4>(inputTensorInfo, ConvertToDataType<ArmnnType>(
97 { -37.5f, -15.2f, -8.76f, -2.0f, -1.5f, -1.3f, -0.5f, -0.4f, 0.0f,
98 1.0f, 0.4f, 0.5f, 1.3f, 1.5f, 2.0f, 8.76f, 15.2f, 37.5f },
99 inputTensorInfo));
100
101 return RankTest<T, 4>(inputTensorInfo, input, workloadFactory, memoryManager);
102}
103
104template LayerTestResult<int32_t, 1>
105RankDimSize4Test<armnn::DataType::Float16>(
106 armnn::IWorkloadFactory& workloadFactory,
107 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
108
109template LayerTestResult<int32_t, 1>
110RankDimSize4Test<armnn::DataType::Float32>(
111 armnn::IWorkloadFactory& workloadFactory,
112 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
113
114template LayerTestResult<int32_t, 1>
115RankDimSize4Test<armnn::DataType::QAsymmU8>(
116 armnn::IWorkloadFactory& workloadFactory,
117 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
118
119template LayerTestResult<int32_t, 1>
120RankDimSize4Test<armnn::DataType::Signed32>(
121 armnn::IWorkloadFactory& workloadFactory,
122 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
123
124template LayerTestResult<int32_t, 1>
125RankDimSize4Test<armnn::DataType::QSymmS16>(
126 armnn::IWorkloadFactory& workloadFactory,
127 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
128
129template LayerTestResult<int32_t, 1>
130RankDimSize4Test<armnn::DataType::QSymmS8>(
131 armnn::IWorkloadFactory& workloadFactory,
132 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
133
134template LayerTestResult<int32_t, 1>
135RankDimSize4Test<armnn::DataType::QAsymmS8>(
136 armnn::IWorkloadFactory& workloadFactory,
137 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
138
139template LayerTestResult<int32_t, 1>
140RankDimSize4Test<armnn::DataType::BFloat16>(
141 armnn::IWorkloadFactory& workloadFactory,
142 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
143
144template LayerTestResult<int32_t, 1>
145RankDimSize3Test<armnn::DataType::Float16>(
146 armnn::IWorkloadFactory& workloadFactory,
147 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
148
149template LayerTestResult<int32_t, 1>
150RankDimSize3Test<armnn::DataType::Float32>(
151 armnn::IWorkloadFactory& workloadFactory,
152 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
153
154template LayerTestResult<int32_t, 1>
155RankDimSize3Test<armnn::DataType::QAsymmU8>(
156 armnn::IWorkloadFactory& workloadFactory,
157 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
158
159template LayerTestResult<int32_t, 1>
160RankDimSize3Test<armnn::DataType::Signed32>(
161 armnn::IWorkloadFactory& workloadFactory,
162 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
163
164template LayerTestResult<int32_t, 1>
165RankDimSize3Test<armnn::DataType::QSymmS16>(
166 armnn::IWorkloadFactory& workloadFactory,
167 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
168
169template LayerTestResult<int32_t, 1>
170RankDimSize3Test<armnn::DataType::QSymmS8>(
171 armnn::IWorkloadFactory& workloadFactory,
172 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
173
174template LayerTestResult<int32_t, 1>
175RankDimSize3Test<armnn::DataType::QAsymmS8>(
176 armnn::IWorkloadFactory& workloadFactory,
177 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
178
179template LayerTestResult<int32_t, 1>
180RankDimSize3Test<armnn::DataType::BFloat16>(
181 armnn::IWorkloadFactory& workloadFactory,
182 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
183
184template LayerTestResult<int32_t, 1>
185RankDimSize2Test<armnn::DataType::Float16>(
186 armnn::IWorkloadFactory& workloadFactory,
187 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
188
189template LayerTestResult<int32_t, 1>
190RankDimSize2Test<armnn::DataType::Float32>(
191 armnn::IWorkloadFactory& workloadFactory,
192 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
193
194template LayerTestResult<int32_t, 1>
195RankDimSize2Test<armnn::DataType::QAsymmU8>(
196 armnn::IWorkloadFactory& workloadFactory,
197 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
198
199template LayerTestResult<int32_t, 1>
200RankDimSize2Test<armnn::DataType::Signed32>(
201 armnn::IWorkloadFactory& workloadFactory,
202 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
203
204template LayerTestResult<int32_t, 1>
205RankDimSize2Test<armnn::DataType::QSymmS16>(
206 armnn::IWorkloadFactory& workloadFactory,
207 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
208
209template LayerTestResult<int32_t, 1>
210RankDimSize2Test<armnn::DataType::QSymmS8>(
211 armnn::IWorkloadFactory& workloadFactory,
212 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
213
214template LayerTestResult<int32_t, 1>
215RankDimSize2Test<armnn::DataType::QAsymmS8>(
216 armnn::IWorkloadFactory& workloadFactory,
217 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
218
219template LayerTestResult<int32_t, 1>
220RankDimSize2Test<armnn::DataType::BFloat16>(
221 armnn::IWorkloadFactory& workloadFactory,
222 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
223
224template LayerTestResult<int32_t, 1>
225RankDimSize1Test<armnn::DataType::Float16>(
226 armnn::IWorkloadFactory& workloadFactory,
227 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
228
229template LayerTestResult<int32_t, 1>
230RankDimSize1Test<armnn::DataType::Float32>(
231 armnn::IWorkloadFactory& workloadFactory,
232 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
233
234template LayerTestResult<int32_t, 1>
235RankDimSize1Test<armnn::DataType::QAsymmU8>(
236 armnn::IWorkloadFactory& workloadFactory,
237 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
238
239template LayerTestResult<int32_t, 1>
240RankDimSize1Test<armnn::DataType::Signed32>(
241 armnn::IWorkloadFactory& workloadFactory,
242 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
243
244template LayerTestResult<int32_t, 1>
245RankDimSize1Test<armnn::DataType::QSymmS16>(
246 armnn::IWorkloadFactory& workloadFactory,
247 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
248
249template LayerTestResult<int32_t, 1>
250RankDimSize1Test<armnn::DataType::QSymmS8>(
251 armnn::IWorkloadFactory& workloadFactory,
252 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
253
254template LayerTestResult<int32_t, 1>
255RankDimSize1Test<armnn::DataType::QAsymmS8>(
256 armnn::IWorkloadFactory& workloadFactory,
257 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
258
259template LayerTestResult<int32_t, 1>
260RankDimSize1Test<armnn::DataType::BFloat16>(
261 armnn::IWorkloadFactory& workloadFactory,
262 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);