blob: 386d6cbbcb79b939de05f9c5cad184a0ed98c309 [file] [log] [blame]
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +00001//
2// Copyright © 2019 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <CommonTestUtils.hpp>

#include <armnnUtils/QuantizeHelper.hpp>
#include <ResolveType.hpp>

#include <stdexcept>
12
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +000013
14namespace
15{
16
17armnn::INetworkPtr CreateArgMinMaxNetwork(const armnn::TensorInfo& inputTensorInfo,
18 const armnn::TensorInfo& outputTensorInfo,
19 armnn::ArgMinMaxFunction function,
20 int axis)
21{
22 armnn::INetworkPtr network(armnn::INetwork::Create());
23
24 armnn::ArgMinMaxDescriptor descriptor;
25 descriptor.m_Function = function;
26 descriptor.m_Axis = axis;
27
28 armnn::IConnectableLayer* inputLayer = network->AddInputLayer(0, "Input");
29 armnn::IConnectableLayer* argMinMaxLayer = network->AddArgMinMaxLayer(descriptor, "ArgMinMax");
30 armnn::IConnectableLayer* outputLayer = network->AddOutputLayer(0, "Output");
31
32 Connect(inputLayer, argMinMaxLayer, inputTensorInfo, 0, 0);
33 Connect(argMinMaxLayer, outputLayer, outputTensorInfo, 0, 0);
34
35 return network;
36}
37
38template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
39void ArgMinMaxEndToEndImpl(const armnn::TensorShape& inputShape,
40 const armnn::TensorShape& outputShape,
41 const std::vector<float>& inputData,
42 const std::vector<int32_t>& expectedOutputData,
43 armnn::ArgMinMaxFunction function,
44 int axis,
45 const std::vector<armnn::BackendId>& backends)
46{
47 const float qScale = armnn::IsQuantizedType<T>() ? 2.0f : 1.0f;
48 const int32_t qOffset = armnn::IsQuantizedType<T>() ? 2 : 0;
49
Cathal Corbett5b8093c2021-10-22 11:12:07 +010050 armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType, qScale, qOffset, true);
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +000051 armnn::TensorInfo outputTensorInfo(outputShape, armnn::DataType::Signed32);
52
53 // quantize data
54 std::vector<T> qInputData = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
55
56 armnn::INetworkPtr network = CreateArgMinMaxNetwork(inputTensorInfo,
57 outputTensorInfo,
58 function,
59 axis);
60
61 EndToEndLayerTestImpl<ArmnnType, armnn::DataType::Signed32>(std::move(network),
62 { { 0, qInputData } },
63 { { 0, expectedOutputData } },
64 backends);
65}
66
67template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
68void ArgMaxEndToEndSimple(const std::vector<armnn::BackendId>& backends)
69{
70 const armnn::TensorShape inputShape{ 1, 1, 1, 5 };
71 const armnn::TensorShape outputShape{ 1, 1, 1 };
72
Francis Murtagh62cdb082019-11-11 16:53:13 +000073 std::vector<float> inputData({ 6.0f, 2.0f, 8.0f, 10.0f, 9.0f });
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +000074 std::vector<int32_t> expectedOutputData({ 3 });
75
76 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
77 outputShape,
78 inputData,
79 expectedOutputData,
80 armnn::ArgMinMaxFunction::Max,
81 -1,
82 backends);
83}
84
85template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
86void ArgMinEndToEndSimple(const std::vector<armnn::BackendId>& backends)
87{
88 const armnn::TensorShape inputShape{ 1, 1, 1, 5 };
89 const armnn::TensorShape outputShape{ 1, 1, 1 };
90
Francis Murtagh62cdb082019-11-11 16:53:13 +000091 std::vector<float> inputData({ 6.0f, 2.0f, 8.0f, 10.0f, 9.0f });
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +000092 std::vector<int32_t> expectedOutputData({ 1 });
93
94 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
95 outputShape,
96 inputData,
97 expectedOutputData,
98 armnn::ArgMinMaxFunction::Min,
99 3,
100 backends);
101}
102
103template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
104void ArgMaxAxis0EndToEnd(const std::vector<armnn::BackendId>& backends)
105{
106 const armnn::TensorShape inputShape{ 3, 2, 1, 4 };
107 const armnn::TensorShape outputShape{ 2, 1, 4 };
108
109 std::vector<float> inputData({ 1.0f, 2.0f, 3.0f, 4.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000110 8.0f, 7.0f, 6.0f, 6.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000111 100.0f, 20.0f, 300.0f, 40.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000112 500.0f, 476.0f, 450.0f, 426.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000113 50.0f, 60.0f, 70.0f, 80.0f,
114 10.0f, 200.0f, 30.0f, 400.0f });
115
116 std::vector<int32_t> expectedOutputData({ 1, 2, 1, 2,
117 1, 1, 1, 1 });
118
119 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
120 outputShape,
121 inputData,
122 expectedOutputData,
123 armnn::ArgMinMaxFunction::Max,
124 0,
125 backends);
126}
127
128template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
129void ArgMinAxis0EndToEnd(const std::vector<armnn::BackendId>& backends)
130{
131 const armnn::TensorShape inputShape{ 3, 2, 1, 4 };
132 const armnn::TensorShape outputShape{ 2, 1, 4 };
133
134 std::vector<float> inputData({ 1.0f, 2.0f, 3.0f, 4.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000135 8.0f, 7.0f, 6.0f, 6.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000136 100.0f, 20.0f, 300.0f, 40.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000137 500.0f, 476.0f, 450.0f, 426.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000138 50.0f, 60.0f, 70.0f, 80.0f,
139 10.0f, 200.0f, 30.0f, 400.0f });
140
141 std::vector<int32_t> expectedOutputData({ 0, 0, 0, 0,
142 0, 0, 0, 0 });
143
144 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
145 outputShape,
146 inputData,
147 expectedOutputData,
148 armnn::ArgMinMaxFunction::Min,
149 0,
150 backends);
151}
152
153template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
154void ArgMaxAxis1EndToEnd(const std::vector<armnn::BackendId>& backends)
155{
156 const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
157 const armnn::TensorShape outputShape{ 1, 2, 4 };
158
159 std::vector<float> inputData({ 1.0f, 2.0f, 3.0f, 4.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000160 8.0f, 7.0f, 6.0f, 6.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000161 100.0f, 20.0f, 300.0f, 40.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000162 500.0f, 476.0f, 450.0f, 426.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000163 50.0f, 60.0f, 70.0f, 80.0f,
164 10.0f, 200.0f, 30.0f, 400.0f });
165
166 std::vector<int32_t> expectedOutputData({ 1, 2, 1, 2,
167 1, 1, 1, 1 });
168
169 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
170 outputShape,
171 inputData,
172 expectedOutputData,
173 armnn::ArgMinMaxFunction::Max,
174 1,
175 backends);
176}
177
178template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
179void ArgMinAxis1EndToEnd(const std::vector<armnn::BackendId>& backends)
180{
181 const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
182 const armnn::TensorShape outputShape{ 1, 2, 4 };
183
184 std::vector<float> inputData({ 1.0f, 2.0f, 3.0f, 4.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000185 8.0f, 7.0f, 6.0f, 6.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000186 100.0f, 20.0f, 300.0f, 40.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000187 500.0f, 476.0f, 450.0f, 426.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000188 50.0f, 60.0f, 70.0f, 80.0f,
189 10.0f, 200.0f, 30.0f, 400.0f });
190
191 std::vector<int32_t> expectedOutputData({ 0, 0, 0, 0,
192 0, 0, 0, 0 });
193
194 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
195 outputShape,
196 inputData,
197 expectedOutputData,
198 armnn::ArgMinMaxFunction::Min,
199 1,
200 backends);
201}
202
203template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
204void ArgMaxAxis2EndToEnd(const std::vector<armnn::BackendId>& backends)
205{
206 const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
207 const armnn::TensorShape outputShape{ 1, 3, 4 };
208
209 std::vector<float> inputData({ 1.0f, 2.0f, 3.0f, 4.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000210 8.0f, 7.0f, 6.0f, 6.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000211 100.0f, 20.0f, 300.0f, 40.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000212 500.0f, 476.0f, 450.0f, 426.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000213 10.0f, 200.0f, 30.0f, 400.0f,
214 50.0f, 60.0f, 70.0f, 80.0f });
215
216 std::vector<int32_t> expectedOutputData({ 1, 1, 1, 1,
217 1, 1, 1, 1,
218 1, 0, 1, 0});
219
220 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
221 outputShape,
222 inputData,
223 expectedOutputData,
224 armnn::ArgMinMaxFunction::Max,
225 2,
226 backends);
227}
228
229template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
230void ArgMinAxis2EndToEnd(const std::vector<armnn::BackendId>& backends)
231{
232 const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
233 const armnn::TensorShape outputShape{ 1, 3, 4 };
234
235 std::vector<float> inputData({ 1.0f, 2.0f, 3.0f, 4.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000236 8.0f, 7.0f, 6.0f, 6.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000237 100.0f, 20.0f, 300.0f, 40.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000238 500.0f, 476.0f, 450.0f, 426.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000239 10.0f, 200.0f, 30.0f, 400.0f,
240 50.0f, 60.0f, 70.0f, 80.0f });
241
242 std::vector<int32_t> expectedOutputData({ 0, 0, 0, 0,
243 0, 0, 0, 0,
244 0, 1, 0, 1 });
245
246 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
247 outputShape,
248 inputData,
249 expectedOutputData,
250 armnn::ArgMinMaxFunction::Min,
251 2,
252 backends);
253}
254
255template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
256void ArgMaxAxis3EndToEnd(const std::vector<armnn::BackendId>& backends)
257{
258 const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
259 const armnn::TensorShape outputShape{ 1, 3, 2 };
260
Francis Murtagh62cdb082019-11-11 16:53:13 +0000261 std::vector<float> inputData({ 1.0f, 3.0f, 6.0f, 7.0f,
262 8.0f, 7.0f, 6.0f, 6.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000263 100.0f, 20.0f, 300.0f, 40.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000264 500.0f, 476.0f, 450.0f, 426.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000265 10.0f, 200.0f, 30.0f, 400.0f,
266 50.0f, 60.0f, 70.0f, 80.0f });
267
268 std::vector<int32_t> expectedOutputData({ 3, 0,
269 2, 0,
270 3, 3});
271
272 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
273 outputShape,
274 inputData,
275 expectedOutputData,
276 armnn::ArgMinMaxFunction::Max,
277 3,
278 backends);
279}
280
281template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
282void ArgMinAxis3EndToEnd(const std::vector<armnn::BackendId>& backends)
283{
284 const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
285 const armnn::TensorShape outputShape{ 1, 3, 2 };
286
Francis Murtagh62cdb082019-11-11 16:53:13 +0000287 std::vector<float> inputData({ 1.0f, 3.0f, 6.0f, 7.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000288 18.0f, 16.0f, 14.0f, 12.0f,
289 100.0f, 20.0f, 300.0f, 40.0f,
Francis Murtagh62cdb082019-11-11 16:53:13 +0000290 500.0f, 476.0f, 450.0f, 426.0f,
Narumol Prangnawaratd1f57732019-10-31 14:24:02 +0000291 10.0f, 200.0f, 30.0f, 400.0f,
292 50.0f, 60.0f, 70.0f, 80.0f });
293
294 std::vector<int32_t> expectedOutputData({ 0, 3,
295 1, 3,
296 0, 0 });
297
298 ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
299 outputShape,
300 inputData,
301 expectedOutputData,
302 armnn::ArgMinMaxFunction::Min,
303 3,
304 backends);
305}
306
307} // anonymous namespace