blob: 63375f0f2f5e41684c497c79e5157f9776e9c97e [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include "SoftmaxTestImpl.hpp"

#include <armnnUtils/QuantizeHelper.hpp>
#include <ResolveType.hpp>

#include <armnn/Exceptions.hpp>
#include <armnn/backends/TensorHandle.hpp>

#include <armnnTestUtils/TensorCopyUtils.hpp>
#include <armnnTestUtils/WorkloadTestUtils.hpp>

#include <armnnTestUtils/TensorHelpers.hpp>

#include <algorithm>
#include <cmath>
#include <string>
20
21namespace
22{
23
24struct Simple3dSoftmaxOutputData
25{
26 const std::vector<float> outputData =
27 {
28 0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
29 0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f
30 };
31
32 const armnn::TensorShape inputShape{ 1, 8, 1 };
33
34 const std::vector<float> inputData =
35 {
36 0.0f, 1.0f, 0.0f, 0.0f,
37 0.5f, 0.0f, 0.0f, 0.0f,
38 };
39};
40
41struct Simple4dSoftmaxData
42{
43 const armnn::TensorShape inputShape{ 1, 8, 1, 1 };
44
45 const std::vector<float> outputData =
46 {
47 0.0964599f, 0.26220518f, 0.0964599f, 0.0964599f,
48 0.15903549f, 0.0964599f, 0.0964599f, 0.0964599f
49 };
50
51 const std::vector<float> inputData =
52 {
53 0.0f, 1.0f, 0.0f, 0.0f,
54 0.5f, 0.0f, 0.0f, 0.0f
55 };
56};
57
58template<armnn::DataType ArmnnType, std::size_t n, typename T = armnn::ResolveType<ArmnnType>>
59LayerTestResult<T, n> SimpleSoftmaxBaseTestImpl(
60 armnn::IWorkloadFactory& workloadFactory,
61 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +010062 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010063 float beta,
64 const armnn::TensorShape& inputShape,
65 const std::vector<float>& outputData,
66 const std::vector<float>& inputData,
David Monahan9b14bfc2020-06-30 15:57:56 +010067 int axis = -1)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010068{
69 using std::exp;
70
71 const float qScale = 1.f / 256.f;
72 const int qOffset = 0;
73
74 armnn::TensorInfo inputTensorInfo;
75 armnn::TensorInfo outputTensorInfo;
76
77 inputTensorInfo = armnn::TensorInfo(inputShape, ArmnnType);
78 inputTensorInfo.SetQuantizationScale(qScale);
79 inputTensorInfo.SetQuantizationOffset(qOffset);
80
81 outputTensorInfo = armnn::TensorInfo(inputShape, ArmnnType);
82 outputTensorInfo.SetQuantizationScale(qScale);
83 outputTensorInfo.SetQuantizationOffset(qOffset);
84
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010085 // Each row is independently softmax'd.
Sadik Armagan483c8112021-06-01 09:24:52 +010086 std::vector<T> input = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
87 std::vector<T> expectedOutput = armnnUtils::QuantizedVector<T>(outputData, qScale, qOffset);
88 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010089
Sadik Armagan89de3b42020-08-28 10:38:53 +010090 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
91 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010092
93 armnn::SoftmaxQueueDescriptor data;
94 data.m_Parameters.m_Beta = beta;
95 data.m_Parameters.m_Axis = axis;
96
97 armnn::WorkloadInfo info;
98 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
99 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
100
Teresa Charlin611c7fb2022-01-07 09:47:29 +0000101 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Softmax, data, info);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100102
103 inputHandle->Allocate();
104 outputHandle->Allocate();
Sadik Armagan483c8112021-06-01 09:24:52 +0100105 CopyDataToITensorHandle(inputHandle.get(), input.data());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100106
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100107 ARMNN_ASSERT(workload);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100108
109 ExecuteWorkload(*workload, memoryManager);
110
Sadik Armagan483c8112021-06-01 09:24:52 +0100111 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100112
Sadik Armagan483c8112021-06-01 09:24:52 +0100113 return LayerTestResult<T, n>(actualOutput,
114 expectedOutput,
115 outputHandle->GetShape(),
116 outputTensorInfo.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100117}
118
119template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
120LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
121 armnn::IWorkloadFactory& workloadFactory,
122 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100123 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100124 float beta)
125{
126 using std::exp;
127 const armnn::TensorShape inputShape{ 2, 4 };
128
129 float x0[4] = { exp((0.f - 1.0f) * beta), exp((1.0f - 1.0f) * beta),
130 exp((0.0f - 1.0f) * beta), exp((0.0f - 1.0f) * beta) };
131 float sum0 = x0[0] + x0[1] + x0[2] + x0[3];
132 float x1[4] = { exp((0.5f - 0.5f) * beta), exp((0.0f - 0.5f) * beta),
133 exp((0.0f - 0.5f) * beta), exp((0.0f - 0.5f) * beta) };
134 float sum1 = x1[0] + x1[1] + x1[2] + x1[3];
135
136 const std::vector<float> outputData = { x0[0] / sum0, x0[1] / sum0, x0[2] / sum0, x0[3] / sum0,
137 x1[0] / sum1, x1[1] / sum1, x1[2] / sum1, x1[3] / sum1 };
138
139 const std::vector<float> inputData =
140 {
141 0.f, 1.f, 0.f, 0.f,
142 .5f, 0.f, 0.f, 0.f,
143 };
144
Sadik Armagan56785c72020-08-27 12:57:20 +0100145 return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100146 inputShape, outputData, inputData);
147}
148
149template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
150LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
151 armnn::IWorkloadFactory& workloadFactory,
152 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100153 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100154 float beta,
155 int axis)
156{
157 armnn::TensorShape inputShape;
158 std::vector<float> inputData;
159 std::vector<float> outputData;
160 switch (axis)
161 {
162 case -2:
163 case 0:
164 {
165 inputShape = {5, 2};
166
167 inputData =
168 {
169 17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
170 };
171
172 outputData =
173 {
174 0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
175 0.087144312427294f,
176 0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
177 7.246299848982885e-08f
178 };
179 break;
180 }
181 case -1:
182 case 1:
183 {
184 inputShape = {2, 5};
185
186 inputData =
187 {
188 17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
189 };
190
191 outputData =
192 {
193 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
194 7.246299848982885e-08f,
195 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
196 7.246299848982885e-08f
197 };
198 break;
199 }
200 }
Sadik Armagan56785c72020-08-27 12:57:20 +0100201 return SimpleSoftmaxBaseTestImpl<ArmnnType, 2>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100202 inputShape, outputData, inputData, axis);
203}
204
205template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
206LayerTestResult<T, 3> Simple3dSoftmaxTestImpl(
207 armnn::IWorkloadFactory& workloadFactory,
208 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100209 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100210 float beta,
211 const armnn::TensorShape& inputShape,
212 const std::vector<float>& outputData,
213 const std::vector<float>& inputData,
214 int axis = 1)
215{
Sadik Armagan56785c72020-08-27 12:57:20 +0100216 return SimpleSoftmaxBaseTestImpl<ArmnnType, 3>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100217 inputShape, outputData, inputData, axis);
218}
219
220template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
221LayerTestResult<T, 4> Simple4dSoftmaxTestImpl(
222 armnn::IWorkloadFactory& workloadFactory,
223 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100224 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100225 float beta,
226 const armnn::TensorShape& inputShape,
227 const std::vector<float>& outputData,
228 const std::vector<float>& inputData,
229 int axis = 1)
230{
231
Sadik Armagan56785c72020-08-27 12:57:20 +0100232 return SimpleSoftmaxBaseTestImpl<ArmnnType, 4>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100233 inputShape, outputData, inputData, axis);
234}
235
236template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
237LayerTestResult<T, 2> CompareSoftmaxTestImpl(
238 armnn::IWorkloadFactory& workloadFactory,
239 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
240 armnn::IWorkloadFactory& refWorkloadFactory,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100241 const armnn::ITensorHandleFactory& tensorHandleFactory,
242 const armnn::ITensorHandleFactory& refTensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100243 float beta)
244{
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100245 const int batchSize = 20;
246 const int channels = 30;
247
248 armnn::TensorInfo inputTensorInfo;
249 armnn::TensorInfo outputTensorInfo;
250
251 unsigned int inputShape[] = { batchSize, channels };
252
253 inputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType);
254 outputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType);
255 float qScale = 1.f / 256.f;
256 int qOffset = 0;
257 inputTensorInfo.SetQuantizationScale(qScale);
258 inputTensorInfo.SetQuantizationOffset(qOffset);
259 outputTensorInfo.SetQuantizationScale(qScale);
260 outputTensorInfo.SetQuantizationOffset(qOffset);
261
Sadik Armagan483c8112021-06-01 09:24:52 +0100262 auto input = MakeRandomTensor<T>(inputTensorInfo, 0xF00D, 0.0f, 1.0f);
263 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
264 std::vector<T> expectedOutput(outputTensorInfo.GetNumElements());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100265
Sadik Armagan89de3b42020-08-28 10:38:53 +0100266 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
267 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100268
269 armnn::SoftmaxQueueDescriptor data;
270 data.m_Parameters.m_Beta = beta;
271
272 armnn::WorkloadInfo info;
273 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
274 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
275
Sadik Armagan56785c72020-08-27 12:57:20 +0100276 std::unique_ptr<armnn::ITensorHandle> outputHandleRef =
Sadik Armagan89de3b42020-08-28 10:38:53 +0100277 refTensorHandleFactory.CreateTensorHandle(outputTensorInfo);
Sadik Armagan56785c72020-08-27 12:57:20 +0100278 std::unique_ptr<armnn::ITensorHandle> inputHandleRef =
Sadik Armagan89de3b42020-08-28 10:38:53 +0100279 refTensorHandleFactory.CreateTensorHandle(inputTensorInfo);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100280
281 armnn::SoftmaxQueueDescriptor refData = data;
282 armnn::WorkloadInfo refInfo = info;
283 SetWorkloadInput(refData, refInfo, 0, inputTensorInfo, inputHandleRef.get());
284 SetWorkloadOutput(refData, refInfo, 0, outputTensorInfo, outputHandleRef.get());
285
Teresa Charlin611c7fb2022-01-07 09:47:29 +0000286 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Softmax, data, info);
287 std::unique_ptr<armnn::IWorkload> workloadRef = refWorkloadFactory.CreateWorkload(armnn::LayerType::Softmax,
288 refData,
289 refInfo);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100290
291 outputHandleRef->Allocate();
292 inputHandleRef->Allocate();
293
294 inputHandle->Allocate();
295 outputHandle->Allocate();
296
Sadik Armagan483c8112021-06-01 09:24:52 +0100297 CopyDataToITensorHandle(inputHandle.get(), input.data());
298 CopyDataToITensorHandle(inputHandleRef.get(), input.data());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100299
300 ExecuteWorkload(*workload, memoryManager);
301
302 workloadRef->Execute();
303
Sadik Armagan483c8112021-06-01 09:24:52 +0100304 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
305 CopyDataFromITensorHandle(expectedOutput.data(), outputHandleRef.get());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100306
Sadik Armagan483c8112021-06-01 09:24:52 +0100307 return LayerTestResult<T, 2>(actualOutput,
308 expectedOutput,
309 outputHandle->GetShape(),
310 outputTensorInfo.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100311}
312
313} // anonymous namespace
314
315LayerTestResult<float,2> SimpleSoftmaxTest(
316 armnn::IWorkloadFactory& workloadFactory,
317 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100318 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100319 float beta)
320{
Sadik Armagan56785c72020-08-27 12:57:20 +0100321 return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory, beta);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100322}
323
324LayerTestResult<float,2> SimpleAxisSoftmaxTest(
325 armnn::IWorkloadFactory& workloadFactory,
326 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100327 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100328 float beta,
329 int axis)
330{
Sadik Armagan56785c72020-08-27 12:57:20 +0100331 return SimpleSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager,
332 tensorHandleFactory, beta, axis);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100333}
334
335LayerTestResult<float,3> Simple3dSoftmaxTest(
336 armnn::IWorkloadFactory& workloadFactory,
337 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100338 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100339 float beta)
340{
341 Simple3dSoftmaxOutputData data;
Sadik Armagan56785c72020-08-27 12:57:20 +0100342 return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100343 data.inputShape, data.outputData, data.inputData);
344}
345
346LayerTestResult<float,3> Simple3dAxisSoftmaxTest(
347 armnn::IWorkloadFactory& workloadFactory,
348 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100349 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100350 float beta,
351 int axis)
352{
353 armnn::TensorShape inputShape;
354 std::vector<float> inputData;
355 std::vector<float> outputData;
356 switch (axis)
357 {
358 case -3:
359 case 0:
360 {
361 inputShape = {5, 2, 2};
362
363 inputData =
364 {
365 17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
366
367 15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
368 };
369
370 outputData =
371 {
372 0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
373 0.236882800924671f,
374 0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
375 0.087144312427294f,
376
377 0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
378 0.032058600957022f,
379 0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
380 7.246299848982885e-08f
381 };
382 break;
383 }
384 case -2:
385 case 1:
386 {
387 inputShape = {2, 5, 2};
388
389 inputData =
390 {
391 17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
392
393 17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
394 };
395
396 outputData =
397 {
398 0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
399 0.087144312427294f,
400 0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
401 7.246299848982885e-08f,
402
403 0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
404 0.087144312427294f,
405 0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
406 7.246299848982885e-08f
407 };
408 break;
409 }
410 case -1:
411 case 2:
412 {
413 inputShape = {2, 2, 5};
414
415 inputData =
416 {
417 17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
418 17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
419 };
420
421 outputData =
422 {
423 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
424 7.246299848982885e-08f,
425 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
426 7.246299848982885e-08f,
427
428 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
429 7.246299848982885e-08f,
430 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
431 7.246299848982885e-08f
432 };
433 break;
434 }
435 }
436
Sadik Armagan56785c72020-08-27 12:57:20 +0100437 return Simple3dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100438 inputShape, outputData, inputData, axis);
439}
440
441LayerTestResult<float,4> Simple4dSoftmaxTest(
442 armnn::IWorkloadFactory& workloadFactory,
443 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100444 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100445 float beta)
446{
447 Simple4dSoftmaxData data;
Sadik Armagan56785c72020-08-27 12:57:20 +0100448 return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(workloadFactory, memoryManager, tensorHandleFactory,
449 beta, data.inputShape, data.outputData, data.inputData);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100450}
451
452LayerTestResult<float,4> Simple4dAxisSoftmaxTest(
453 armnn::IWorkloadFactory& workloadFactory,
454 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100455 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100456 float beta,
457 int axis)
458{
459 armnn::TensorShape inputShape;
460 std::vector<float> inputData;
461 std::vector<float> outputData;
462 switch (axis)
463 {
464 case -4:
465 case 0:
466 {
467 inputShape = {5, 2, 2, 2};
468
469 inputData =
470 {
471 17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f,
472 16.0f, -2.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f, 15.0f, -3.0f,
473 15.0f, -3.0f, 15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 14.0f, -4.0f,
474 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f, 1.0f, -17.0f
475 };
476
477 outputData =
478 {
479 0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
480 0.643914213228014f,
481 0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.236882800924671f,
482 0.236882800924671f,
483 0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.236882800924671f,
484 0.236882800924671f,
485 0.236882800924671f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
486 0.087144312427294f,
487
488 0.087144312427294f, 0.087144312427294f, 0.087144312427294f, 0.087144312427294f,
489 0.032058600957022f,
490 0.032058600957022f, 0.032058600957022f, 0.032058600957022f, 0.032058600957022f,
491 0.032058600957022f,
492 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f,
493 7.246299848982885e-08f,
494 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
495 7.246299848982885e-08f, 7.246299848982885e-08f
496 };
497 break;
498 }
499 case -3:
500 case 1:
501 {
502 inputShape = {2, 5, 2, 2};
503
504 inputData =
505 {
506 17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
507 15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f,
508 17.0f, -1.0f, 17.0f, -1.0f, 16.0f, -2.0f, 16.0f, -2.0f, 15.0f, -3.0f,
509 15.0f, -3.0f, 14.0f, -4.0f, 14.0f, -4.0f, 1.0f, -17.0f, 1.0f, -17.0f
510 };
511
512 outputData =
513 {
514 0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
515 0.236882800924671f,
516 0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
517 0.087144312427294f,
518 0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
519 0.032058600957022f,
520 0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
521 7.246299848982885e-08f,
522
523
524 0.643914213228014f, 0.643914213228014f, 0.643914213228014f, 0.643914213228014f,
525 0.236882800924671f,
526 0.236882800924671f, 0.236882800924671f, 0.236882800924671f, 0.087144312427294f,
527 0.087144312427294f,
528 0.087144312427294f, 0.087144312427294f, 0.032058600957022f, 0.032058600957022f,
529 0.032058600957022f,
530 0.032058600957022f, 7.246299848982885e-08f, 7.246299848982885e-08f, 7.246299848982885e-08f,
531 7.246299848982885e-08f
532 };
533 break;
534 }
535 case -2:
536 case 2:
537 {
538 inputShape = {2, 2, 5, 2};
539
540 inputData =
541 {
542 17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
543 17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
544 17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f,
545 17.0f, -1.0f, 16.0f, -2.0f, 15.0f, -3.0f, 14.0f, -4.0f, 1.0f, -17.0f
546 };
547
548 outputData =
549 {
550 0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
551 0.087144312427294f,
552 0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
553 7.246299848982885e-08f,
554 0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
555 0.087144312427294f,
556 0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
557 7.246299848982885e-08f,
558
559 0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
560 0.087144312427294f,
561 0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
562 7.246299848982885e-08f,
563 0.643914213228014f, 0.643914213228014f, 0.236882800924671f, 0.236882800924671f,
564 0.087144312427294f,
565 0.087144312427294f, 0.032058600957022f, 0.032058600957022f, 7.246299848982885e-08f,
566 7.246299848982885e-08f
567 };
568 break;
569 }
570 case -1:
571 case 3:
572 {
573 inputShape = {2, 2, 2, 5};
574
575 inputData =
576 {
577 17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
578 17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
579 17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f,
580 17.0f, 16.0f, 15.0f, 14.0f, 1.0f, -1.0f, -2.0f, -3.0f, -4.0f, -17.0f
581 };
582
583 outputData =
584 {
585 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
586 7.246299848982885e-08f,
587 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
588 7.246299848982885e-08f,
589 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
590 7.246299848982885e-08f,
591 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
592 7.246299848982885e-08f,
593
594 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
595 7.246299848982885e-08f,
596 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
597 7.246299848982885e-08f,
598 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
599 7.246299848982885e-08f,
600 0.643914213228014f, 0.236882800924671f, 0.087144312427294f, 0.032058600957022f,
601 7.246299848982885e-08f
602 };
603 break;
604 }
605 }
606
607 return Simple4dSoftmaxTestImpl<armnn::DataType::Float32>(
608 workloadFactory,
609 memoryManager,
Sadik Armagan56785c72020-08-27 12:57:20 +0100610 tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100611 beta,
612 inputShape,
613 outputData,
614 inputData,
615 axis);
616}
617
618LayerTestResult<uint8_t,2> SimpleSoftmaxUint8Test(
619 armnn::IWorkloadFactory& workloadFactory,
620 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100621 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100622 float beta)
623{
Sadik Armagan56785c72020-08-27 12:57:20 +0100624 return SimpleSoftmaxTestImpl<armnn::DataType::QAsymmU8>(workloadFactory, memoryManager, tensorHandleFactory, beta);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100625}
626
627LayerTestResult<uint8_t,3> Simple3dSoftmaxUint8Test(
628 armnn::IWorkloadFactory& workloadFactory,
629 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100630 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100631 float beta)
632{
633 Simple3dSoftmaxOutputData data;
Derek Lambertif90c56d2020-01-10 17:14:08 +0000634 return Simple3dSoftmaxTestImpl<armnn::DataType::QAsymmU8>(
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100635 workloadFactory,
636 memoryManager,
Sadik Armagan56785c72020-08-27 12:57:20 +0100637 tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100638 beta,
639 data.inputShape,
640 data.outputData,
641 data.inputData);
642}
643
644LayerTestResult<uint8_t,4> Simple4dSoftmaxUint8Test(
645 armnn::IWorkloadFactory& workloadFactory,
646 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100647 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100648 float beta)
649{
650 Simple4dSoftmaxData data;
651
Sadik Armagan56785c72020-08-27 12:57:20 +0100652 return Simple4dSoftmaxTestImpl<armnn::DataType::QAsymmU8>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100653 data.inputShape, data.outputData, data.inputData);
654}
655
Matthew Jackson9bff1442019-09-12 09:08:23 +0100656LayerTestResult<armnn::Half,2> SimpleSoftmaxFloat16Test(
657 armnn::IWorkloadFactory& workloadFactory,
658 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100659 const armnn::ITensorHandleFactory& tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100660 float beta)
661{
Sadik Armagan56785c72020-08-27 12:57:20 +0100662 return SimpleSoftmaxTestImpl<armnn::DataType::Float16>(workloadFactory, memoryManager, tensorHandleFactory, beta);
Matthew Jackson9bff1442019-09-12 09:08:23 +0100663}
664
665LayerTestResult<armnn::Half,3> Simple3dSoftmaxFloat16Test(
666 armnn::IWorkloadFactory& workloadFactory,
667 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100668 const armnn::ITensorHandleFactory& tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100669 float beta)
670{
671 Simple3dSoftmaxOutputData data;
Sadik Armagan56785c72020-08-27 12:57:20 +0100672 return Simple3dSoftmaxTestImpl<armnn::DataType::Float16>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100673 data.inputShape, data.outputData, data.inputData);
674}
675
676LayerTestResult<armnn::Half,4> Simple4dSoftmaxFloat16Test(
677 armnn::IWorkloadFactory& workloadFactory,
678 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100679 const armnn::ITensorHandleFactory& tensorHandleFactory,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100680 float beta)
681{
682 Simple4dSoftmaxData data;
Sadik Armagan56785c72020-08-27 12:57:20 +0100683 return Simple4dSoftmaxTestImpl<armnn::DataType::Float16>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Matthew Jackson9bff1442019-09-12 09:08:23 +0100684 data.inputShape, data.outputData, data.inputData);
685}
686
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100687LayerTestResult<int16_t,2> SimpleSoftmaxUint16Test(
688 armnn::IWorkloadFactory& workloadFactory,
689 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100690 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100691 float beta)
692{
Sadik Armagan56785c72020-08-27 12:57:20 +0100693 return SimpleSoftmaxTestImpl<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, tensorHandleFactory, beta);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100694}
695
696LayerTestResult<int16_t,3> Simple3dSoftmaxUint16Test(
697 armnn::IWorkloadFactory& workloadFactory,
698 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100699 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100700 float beta)
701{
702 Simple3dSoftmaxOutputData data;
Sadik Armagan56785c72020-08-27 12:57:20 +0100703 return Simple3dSoftmaxTestImpl<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100704 data.inputShape, data.outputData, data.inputData);
705}
706
707LayerTestResult<int16_t,4> Simple4dSoftmaxUint16Test(
708 armnn::IWorkloadFactory& workloadFactory,
709 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100710 const armnn::ITensorHandleFactory& tensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100711 float beta)
712{
713 Simple4dSoftmaxData data;
714
Sadik Armagan56785c72020-08-27 12:57:20 +0100715 return Simple4dSoftmaxTestImpl<armnn::DataType::QSymmS16>(workloadFactory, memoryManager, tensorHandleFactory, beta,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100716 data.inputShape, data.outputData, data.inputData);
717}
718
719LayerTestResult<float,2> CompareSoftmaxTest(
720 armnn::IWorkloadFactory& workloadFactory,
721 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
722 armnn::IWorkloadFactory& refWorkloadFactory,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100723 const armnn::ITensorHandleFactory& tensorHandleFactory,
724 const armnn::ITensorHandleFactory& refTensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100725 float beta)
726{
727 return CompareSoftmaxTestImpl<armnn::DataType::Float32>(
Sadik Armagan56785c72020-08-27 12:57:20 +0100728 workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, beta);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100729}
730
731LayerTestResult<uint8_t,2> CompareSoftmaxUint8Test(
732 armnn::IWorkloadFactory& workloadFactory,
733 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
734 armnn::IWorkloadFactory& refWorkloadFactory,
Sadik Armagan89de3b42020-08-28 10:38:53 +0100735 const armnn::ITensorHandleFactory& tensorHandleFactory,
736 const armnn::ITensorHandleFactory& refTensorHandleFactory,
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100737 float beta)
738{
Derek Lambertif90c56d2020-01-10 17:14:08 +0000739 return CompareSoftmaxTestImpl<armnn::DataType::QAsymmU8>(
Sadik Armagan56785c72020-08-27 12:57:20 +0100740 workloadFactory, memoryManager, refWorkloadFactory, tensorHandleFactory, refTensorHandleFactory, beta);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100741}