//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/ArmNN.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/TypesUtils.hpp>

#include <test/TensorHelpers.hpp>
#include "QuantizeHelper.hpp"

#include <backends/CpuTensorHandle.hpp>
#include <backends/WorkloadFactory.hpp>

#include <algorithm>

template<typename T>
LayerTestResult<T, 2> SimpleSoftmaxTestImpl(armnn::IWorkloadFactory& workloadFactory, float beta)
{
    using std::exp;

    armnn::TensorInfo inputTensorInfo;
    armnn::TensorInfo outputTensorInfo;

    unsigned int inputShape[] = { 2, 4 };

    inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>());
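    // Softmax outputs lie in [0, 1), so a scale of 1/256 with zero offset covers the
    // representable output range when T is a quantized 8-bit type.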
    float qScale = 1.f / 256.f;
    int qOffset = 0;
    inputTensorInfo.SetQuantizationScale(qScale);
    inputTensorInfo.SetQuantizationOffset(qOffset);

    outputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>());
    outputTensorInfo.SetQuantizationScale(qScale);
    outputTensorInfo.SetQuantizationOffset(qOffset);

    LayerTestResult<T, 2> ret(outputTensorInfo);

    // Each row is independently softmax'd.
    auto input = MakeTensor<T, 2>(inputTensorInfo, std::vector<T>(
        QuantizedVector<T>(qScale, qOffset, {
            0.f, 1.f, 0.f, 0.f,
            .5f, 0.f, 0.f, 0.f,
        })));

    std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);

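    // Beta scales the inputs before exponentiation: beta == 1 gives the standard
    // softmax, larger values sharpen the distribution, smaller values flatten it.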
    armnn::SoftmaxQueueDescriptor data;
    data.m_Parameters.m_Beta = beta;

    armnn::WorkloadInfo info;
    AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
    AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());

    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateSoftmax(data, info);

    inputHandle->Allocate();
    outputHandle->Allocate();
    CopyDataToITensorHandle(inputHandle.get(), &input[0][0]);

    workloadFactory.Finalize();
    workload->Execute();

    CopyDataFromITensorHandle(&ret.output[0][0], outputHandle.get());

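    // Compute the expected results by hand, using the numerically stable form of
    // softmax: each row's maximum (1.0 for the first row, 0.5 for the second) is
    // subtracted from every element before exponentiation, which leaves the final
    // quotients unchanged.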
    float x0[4] = { exp((0.f - 1.0f) * beta), exp((1.0f - 1.0f) * beta),
                    exp((0.0f - 1.0f) * beta), exp((0.0f - 1.0f) * beta) };
    float sum0 = x0[0] + x0[1] + x0[2] + x0[3];
    float x1[4] = { exp((0.5f - 0.5f) * beta), exp((0.0f - 0.5f) * beta),
                    exp((0.0f - 0.5f) * beta), exp((0.0f - 0.5f) * beta) };
    float sum1 = x1[0] + x1[1] + x1[2] + x1[3];

    ret.outputExpected = MakeTensor<T, 2>(outputTensorInfo, std::vector<T>(
        QuantizedVector<T>(qScale, qOffset, {
            x0[0] / sum0, x0[1] / sum0, x0[2] / sum0, x0[3] / sum0,
            x1[0] / sum1, x1[1] / sum1, x1[2] / sum1, x1[3] / sum1
        })));

    return ret;
}

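// Runs the same randomly generated input through the backend under test and a
// reference workload factory; the caller compares ret.output against
// ret.outputExpected to validate the backend's softmax implementation.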
template<typename T>
LayerTestResult<T, 2> CompareSoftmaxTestImpl(armnn::IWorkloadFactory& workloadFactory,
    armnn::IWorkloadFactory& refWorkloadFactory,
    float beta)
{
    const int batchSize = 20;
    const int channels = 30;

    armnn::TensorInfo inputTensorInfo;
    armnn::TensorInfo outputTensorInfo;

    unsigned int inputShape[] = { batchSize, channels };

    inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>());
    outputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>());
    float qScale = 1.f / 256.f;
    int qOffset = 0;
    inputTensorInfo.SetQuantizationScale(qScale);
    inputTensorInfo.SetQuantizationOffset(qOffset);
    outputTensorInfo.SetQuantizationScale(qScale);
    outputTensorInfo.SetQuantizationOffset(qOffset);

    LayerTestResult<T, 2> ret(outputTensorInfo);
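    // A fixed seed (0xF00D) makes the random input reproducible; the same data is
    // copied to both the tested and the reference workloads below.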
    auto input = MakeRandomTensor<T, 2>(inputTensorInfo, 0xF00D, 0.0f, 1.0f);

    std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);

    armnn::SoftmaxQueueDescriptor data;
    data.m_Parameters.m_Beta = beta;

    armnn::WorkloadInfo info;
    AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
    AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());

    std::unique_ptr<armnn::ITensorHandle> outputHandleRef = refWorkloadFactory.CreateTensorHandle(outputTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> inputHandleRef = refWorkloadFactory.CreateTensorHandle(inputTensorInfo);

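    // The reference workload reuses the same descriptor and workload info, with the
    // input and output rebound to the reference factory's tensor handles.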
    armnn::SoftmaxQueueDescriptor refData = data;
    armnn::WorkloadInfo refInfo = info;
    SetWorkloadInput(refData, refInfo, 0, inputTensorInfo, inputHandleRef.get());
    SetWorkloadOutput(refData, refInfo, 0, outputTensorInfo, outputHandleRef.get());

    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateSoftmax(data, info);
    std::unique_ptr<armnn::IWorkload> workloadRef = refWorkloadFactory.CreateSoftmax(refData, refInfo);

    outputHandleRef->Allocate();
    inputHandleRef->Allocate();

    inputHandle->Allocate();
    outputHandle->Allocate();

    CopyDataToITensorHandle(inputHandle.get(), &input[0][0]);
    CopyDataToITensorHandle(inputHandleRef.get(), &input[0][0]);

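    // Execute both workloads on identical data; ret.output holds the result from the
    // backend under test and ret.outputExpected the reference result.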
    workloadFactory.Finalize();
    workload->Execute();
    refWorkloadFactory.Finalize();
    workloadRef->Execute();

    CopyDataFromITensorHandle(&ret.output[0][0], outputHandle.get());
    CopyDataFromITensorHandle(&ret.outputExpected[0][0], outputHandleRef.get());

    return ret;
}