blob: 56bc23cf9cd655782612f734ec1ea59e9076a416 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
telsoa01c577f2c2018-08-31 09:22:23 +01005
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01006#include "LstmTestImpl.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007
Aron Virginas-Tar48623a02019-10-22 10:00:28 +01008#include <QuantizeHelper.hpp>
9
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010011
James Conroy1f58f032021-04-27 17:13:27 +010012#include <backendsCommon/TensorHandle.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010013
Sadik Armagana097d2a2021-11-24 15:47:28 +000014#include <armnnTestUtils/TensorCopyUtils.hpp>
15#include <WorkloadTestUtils.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010016
17#include <reference/workloads/Decoders.hpp>
18#include <reference/workloads/Encoders.hpp>
19#include <reference/workloads/LstmUtils.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010020
Sadik Armagana097d2a2021-11-24 15:47:28 +000021#include <TensorHelpers.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010022
Sadik Armagan1625efc2021-06-10 18:24:34 +010023#include <doctest/doctest.h>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010024namespace
25{
Jan Eilers38e05bd2019-06-26 13:10:09 +010026
27template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
28void LstmUtilsVectorBatchVectorAddTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010029 std::vector<float>& vec,
30 std::vector<float>& batchVec,
Jan Eilers38e05bd2019-06-26 13:10:09 +010031 uint32_t vSize,
32 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +010033 std::vector<float>& expectedOutput,
34 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +010035{
36 float qScale = 0.0f;
37 int32_t qOffset = 0;
38 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
39
40 // Make encoder and decoder
41 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
42 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
43 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
44
45 VectorBatchVectorAdd(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder);
46
47 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +010048 auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
Sadik Armagan1625efc2021-06-10 18:24:34 +010049 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +010050
51 // check if iterator is back at start position
52 batchVecEncoder->Set(1.0f);
Sadik Armagan1625efc2021-06-10 18:24:34 +010053 CHECK(batchVec[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +010054}
55
56template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
57void LstmUtilsZeroVectorTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010058 std::vector<float>& input,
Jan Eilers38e05bd2019-06-26 13:10:09 +010059 uint32_t vSize,
Sadik Armagan483c8112021-06-01 09:24:52 +010060 std::vector<float>& expectedOutput,
61 armnn::TensorShape& expectedShape)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010062{
Jan Eilers38e05bd2019-06-26 13:10:09 +010063 float qScale = 0.0f;
64 int32_t qOffset = 0;
65
66 armnn::TensorInfo tensorInfo({vSize}, ArmnnType, qScale, qOffset );
67
68 // Make encoder for input
69 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
70
71 // call ZeroVector
72 ZeroVector(*outputEncoder, vSize);
73
74 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +010075 auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
Sadik Armagan1625efc2021-06-10 18:24:34 +010076 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +010077
78 // check if iterator is back at start position
79 outputEncoder->Set(1.0f);
Sadik Armagan1625efc2021-06-10 18:24:34 +010080 CHECK(input[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +010081
82}
83
Jan Eilers38e05bd2019-06-26 13:10:09 +010084template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
85void LstmUtilsMeanStddevNormalizationTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010086 std::vector<float>& input,
Jan Eilers38e05bd2019-06-26 13:10:09 +010087 uint32_t vSize,
88 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +010089 std::vector<float>& expectedOutput,
90 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +010091{
92 float qScale = 0.0f;
93 int32_t qOffset = 0;
94 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
95
96 // Make encoder and decoder for input
97 std::unique_ptr<armnn::Decoder<float>> inputDecoder = armnn::MakeDecoder<float>(tensorInfo, input.data());
98 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
99
100 MeanStddevNormalization(*inputDecoder, *outputEncoder, vSize, nBatch, 1e-8f);
101
102 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +0100103 auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100104 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100105
106 // check if iterator is back at start position
107 outputEncoder->Set(1.0f);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100108 CHECK(input[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100109}
110
111template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
112void LstmUtilsVectorBatchVectorCwiseProductTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +0100113 std::vector<float>& vec,
114 std::vector<float>& batchVec,
Jan Eilers38e05bd2019-06-26 13:10:09 +0100115 uint32_t vSize,
116 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +0100117 std::vector<float>& expectedOutput,
118 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100119{
120 float qScale = 0.0f;
121 int32_t qOffset = 0;
122 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
123
124 // Make encoder and decoder
125 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
126 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
127 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
128
129 VectorBatchVectorCwiseProduct(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder);
130
131 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +0100132 auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100133 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100134
135 // check if iterator is back at start position
136 batchVecEncoder->Set(1.0f);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100137 CHECK(batchVec[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100138}
139
140// Lstm Layer tests:
James Conroy9c3cae82019-08-01 16:01:48 +0100141// *********************************** //
// Builds and executes a single Lstm workload with CIFG, peephole and
// projection all disabled, using fixed weights/biases, and returns the
// produced output alongside the caller-supplied expected output.
//
// Shapes are derived from the arguments:
//   input:  [batchSize, inputSize]   (from inputShape)
//   output: [batchSize, outputSize]  (from outputExpectedShape)
// With no projection, numUnits == outputSize; the scratch buffer is
// [batchSize, numUnits * 4].
//
// qScale/qOffset parametrise the data tensors for quantized variants;
// constantDataType is the data type of the weight/bias tensors.
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T, 2>
LstmNoCifgNoPeepholeNoProjectionTestImpl(
    armnn::IWorkloadFactory& workloadFactory,
    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
    const armnn::ITensorHandleFactory& tensorHandleFactory,
    const std::vector<T>& input,
    const std::vector<T>& outputExpected,
    const armnn::TensorShape& inputShape,
    const armnn::TensorShape& outputExpectedShape,
    float qScale = 0.0f,
    int32_t qOffset = 0,
    armnn::DataType constantDataType = armnn::DataType::Float32)
{
    IgnoreUnused(memoryManager);
    unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]);
    unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
    unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
    // cellSize and outputSize have the same size when there is no projection.
    unsigned numUnits = outputSize;

    // Input-side tensor infos (data type ArmnnType, optionally quantized).
    armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset );
    armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
    armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);

    // Output-side tensor infos; scratch buffer is 4 * numUnits wide (no CIFG).
    armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
    armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
    armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
    armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);

    std::vector<T> inputVector;
    inputVector.assign(input.data(), input.data() + (batchSize * inputSize));

    // Initial cell/output state and scratch space start out zero-filled.
    std::vector<T> cellStateInVector(batchSize * numUnits, T());
    std::vector<T> outputStateInVector(batchSize * outputSize, T());
    std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
    std::vector<T> outputStateOutVector(batchSize * outputSize, T());
    std::vector<T> cellStateOutVector(batchSize * numUnits, T());

    std::vector<T> actualOutput(outputTensorInfo.GetNumElements());

    std::vector<T> outputVector;
    outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));

    // Create backend tensor handles for all inputs and outputs.
    std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);

    std::unique_ptr<armnn::ITensorHandle> scratchHandle =
            tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);

    armnn::LstmQueueDescriptor data;
    armnn::WorkloadInfo info;

    // Wire the workload: 3 inputs (input, output state, cell state) and
    // 4 outputs (scratch, output state, cell state, output) — order matters.
    AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
    AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
    AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());

    AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
    AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
    AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
    AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());

    // Constant-tensor infos, named after their element counts
    // (4 = numUnits, 8 = numUnits x 2, 16 = numUnits x 4).
    armnn::TensorInfo tensorInfo4({numUnits}, constantDataType , qScale, qOffset);
    armnn::TensorInfo tensorInfo8({numUnits, 2}, constantDataType, qScale, qOffset);
    armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset);

    // Fixed weight/bias values for this test configuration.
    std::vector<float> inputToInputWeights = {-0.45018822f, -0.02338299f, -0.0870589f,
                                              -0.34550029f, 0.04266912f, -0.15680569f,
                                              -0.34856534f, 0.43890524f};

    std::vector<float> inputToForgetWeights = { 0.09701663f, 0.20334584f, -0.50592935f,
                                                -0.31343272f, -0.40032279f, 0.44781327f,
                                                0.01387155f, -0.35593212f};

    std::vector<float> inputToCellWeights = { -0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
                                              -0.20583314f, 0.44344562f, 0.22077113f,
                                              -0.29909778f};

    std::vector<float> inputToOutputWeights = { -0.25065863f, -0.28290087f, 0.04613829f,
                                                0.40525138f, 0.44272184f, 0.03897077f,
                                                -0.1556896f, 0.19487578f};

    std::vector<float> recurrentToInputWeights = {-0.0063535f, -0.2042388f, 0.31454784f,
                                                  -0.35746509f, 0.28902304f, 0.08183324f,
                                                  -0.16555229f, 0.02286911f, -0.13566875f,
                                                  0.03034258f, 0.48091322f, -0.12528998f,
                                                  0.24077177f, -0.51332325f, -0.33502164f,
                                                  0.10629296f};

    std::vector<float> recurrentToForgetWeights = { -0.48684245f, -0.06655136f, 0.42224967f,
                                                    0.2112639f, 0.27654213f, 0.20864892f,
                                                    -0.07646349f, 0.45877004f, 0.00141793f,
                                                    -0.14609534f, 0.36447752f, 0.09196436f,
                                                    0.28053468f, 0.01560611f, -0.20127171f,
                                                    -0.01140004f};

    std::vector<float> recurrentToCellWeights = { -0.3407414f, 0.24443203f, -0.2078532f,
                                                  0.26320225f, 0.05695659f, -0.00123841f,
                                                  -0.4744786f, -0.35869038f, -0.06418842f,
                                                  -0.13502428f, -0.501764f, 0.22830659f,
                                                  -0.46367589f, 0.26016325f, -0.03894562f,
                                                  -0.16368064f};

    std::vector<float> recurrentToOutputWeights = { 0.43385774f, -0.17194885f, 0.2718237f,
                                                    0.09215671f, 0.24107647f, -0.39835793f,
                                                    0.18212086f, 0.01301402f, 0.48572797f,
                                                    -0.50656658f, 0.20047462f, -0.20607421f,
                                                    -0.51818722f, -0.15390486f, 0.0468148f,
                                                    0.39922136f};

    std::vector<float> cellToInputWeights = {0., 0., 0., 0.};

    std::vector<float> inputGateBias = {0., 0., 0., 0.};

    // Forget gate bias of 1.0 keeps the initial forget gate open.
    std::vector<float> forgetGateBias = {1., 1., 1., 1.};

    std::vector<float> cellBias = {0., 0., 0., 0.};

    std::vector<float> outputGateBias = {0., 0., 0., 0.};

    armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo8);
    armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo8);
    armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo8);
    armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo8);
    armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo16);
    armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo16);
    armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo16);
    armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo16);
    armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4);
    armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4);
    armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4);
    armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4);
    armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4);

    AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
    AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
    AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
    AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
    AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());

    // NOTE(review): cellToInputWeightsTensor is populated above but never
    // assigned to data.m_CellToInputWeights — presumably intentional since
    // peephole is disabled below; confirm against the workload validation.
    data.m_InputToInputWeights = &inputToInputWeightsTensor;
    data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    data.m_InputToCellWeights = &inputToCellWeightsTensor;
    data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
    data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
    data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
    data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
    data.m_InputGateBias = &inputGateBiasTensor;
    data.m_ForgetGateBias = &forgetGateBiasTensor;
    data.m_CellBias = &cellBiasTensor;
    data.m_OutputGateBias = &outputGateBiasTensor;

    // Flags to set test configuration
    data.m_Parameters.m_ActivationFunc = 4;
    data.m_Parameters.m_CifgEnabled = false;
    data.m_Parameters.m_PeepholeEnabled = false;
    data.m_Parameters.m_ProjectionEnabled = false;

    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
    inputHandle->Allocate();
    outputStateInHandle->Allocate();
    cellStateInHandle->Allocate();

    scratchHandle->Allocate();
    outputStateOutHandle->Allocate();
    cellStateOutHandle->Allocate();
    outputHandle->Allocate();

    // Upload input data and the zeroed initial states to the backend.
    CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
    CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
    CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());

    workload->Execute();

    CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());

    // Package actual vs expected output (plus shapes) for the caller to check.
    return LayerTestResult<T, 2>(actualOutput,
                                 outputVector,
                                 outputHandle->GetShape(),
                                 outputTensorInfo.GetShape());
}
340
Conor Kennedyb9971c92019-05-07 07:14:23 +0100341template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
342LayerTestResult<T, 2>
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000343LstmLayerNoCifgWithPeepholeWithProjectionTestImpl(armnn::IWorkloadFactory& workloadFactory,
344 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100345 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +0100346 const std::vector<T>& input,
347 const std::vector<T>& outputExpected,
Conor Kennedyb9971c92019-05-07 07:14:23 +0100348 float qScale = 0.0f,
349 int32_t qOffset = 0,
350 armnn::DataType constantDataType = armnn::DataType::Float32)
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000351{
Jan Eilers8eb25602020-03-09 12:13:48 +0000352 IgnoreUnused(memoryManager);
telsoa01c577f2c2018-08-31 09:22:23 +0100353 unsigned int batchSize = 2;
354 unsigned int outputSize = 16;
355 unsigned int inputSize = 5;
356 unsigned numUnits = 20;
357
Conor Kennedyb9971c92019-05-07 07:14:23 +0100358 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
359 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
360 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100361
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000362 // Scratch buffer size without CIFG [batchSize, numUnits * 4]
Conor Kennedyb9971c92019-05-07 07:14:23 +0100363 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
364 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
365 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
366 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100367
Rob Hughesbb46dde2020-05-20 15:27:37 +0100368 std::vector<T> inputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100369 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100370
Rob Hughesbb46dde2020-05-20 15:27:37 +0100371 std::vector<T> cellStateInVector(batchSize * numUnits, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100372 std::vector<T> outputStateInVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100373 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100374 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100375 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
Sadik Armagan483c8112021-06-01 09:24:52 +0100376
377 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
telsoa01c577f2c2018-08-31 09:22:23 +0100378
Rob Hughesbb46dde2020-05-20 15:27:37 +0100379 std::vector<T> outputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100380 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100381
Finn Williamsc43de6a2020-08-27 11:13:25 +0100382 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100383 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100384 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100385 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100386 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100387
Finn Williamsc43de6a2020-08-27 11:13:25 +0100388 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
389 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100390 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100391 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100392 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100393 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
394 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100395
396 armnn::LstmQueueDescriptor data;
397 armnn::WorkloadInfo info;
398
399 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
400 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
401 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
David Beckac42efd2018-09-26 17:41:13 +0100402
telsoa01c577f2c2018-08-31 09:22:23 +0100403 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
404 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
405 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
406 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
407
Conor Kennedyb9971c92019-05-07 07:14:23 +0100408 armnn::TensorInfo tensorInfo16({outputSize}, constantDataType, qScale, qOffset);
409 armnn::TensorInfo tensorInfo20({numUnits}, constantDataType, qScale, qOffset);
410 armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
411 armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, constantDataType, qScale, qOffset);
412 armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, constantDataType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100413
Sadik Armagan483c8112021-06-01 09:24:52 +0100414 std::vector<float> inputToInputWeights = {0.021393683f,0.06124551f, 0.046905167f,-0.014657677f,-0.03149463f,
415 0.09171803f, 0.14647801f,0.10797193f, -0.0057968358f,0.0019193048f,
416 -0.2726754f, 0.10154029f, -0.018539885f, 0.080349885f, -0.10262385f,
417 -0.022599787f,-0.09121155f, -0.008675967f, -0.045206103f,-0.0821282f,
418 -0.008045952f,0.015478081f, 0.055217247f, 0.038719587f, 0.044153627f,
419 -0.06453243f,0.05031825f, -0.046935108f, -0.008164439f, 0.014574226f,
420 -0.1671009f, -0.15519552f, -0.16819797f,-0.13971269f,-0.11953059f,
421 0.25005487f, -0.22790983f, 0.009855087f, -0.028140958f, -0.11200698f,
422 0.11295408f, -0.0035217577f, 0.054485075f, 0.05184695f, 0.064711206f,
423 0.10989193f, 0.11674786f, 0.03490607f, 0.07727357f, 0.11390585f,
424 -0.1863375f, -0.1034451f, -0.13945189f, -0.049401227f, -0.18767063f,
425 0.042483903f, 0.14233552f, 0.13832581f, 0.18350165f, 0.14545603f,
426 -0.028545704f,0.024939531f,0.050929718f,0.0076203286f,-0.0029723682f,
427 -0.042484224f, -0.11827596f, -0.09171104f, -0.10808628f,-0.16327988f,
428 -0.2273378f, -0.0993647f, -0.017155107f,0.0023917493f,0.049272764f,
429 0.0038534778f, 0.054764505f, 0.089753784f, 0.06947234f, 0.08014476f,
430 -0.04544234f, -0.0497073f,-0.07135631f, -0.048929106f,-0.004042012f,
431 -0.009284026f, 0.018042054f, 0.0036860977f,-0.07427302f, -0.11434604f,
432 -0.018995456f, 0.031487543f, 0.012834908f,0.019977754f,0.044256654f,
433 -0.39292613f, -0.18519334f, -0.11651281f,-0.06809892f, 0.011373677f };
telsoa01c577f2c2018-08-31 09:22:23 +0100434
Sadik Armagan483c8112021-06-01 09:24:52 +0100435 std::vector<float> inputToForgetWeights = {-0.0018401089f, -0.004852237f,0.03698424f, 0.014181704f,0.028273236f,
436 -0.016726194f, -0.05249759f,-0.10204261f, 0.00861066f,-0.040979505f,
437 -0.009899187f,0.01923892f,-0.028177269f, -0.08535103f,-0.14585495f,
438 0.10662567f,-0.01909731f,-0.017883534f,-0.0047269356f,-0.045103323f,
439 0.0030784295f,0.076784775f,0.07463696f, 0.094531395f,0.0814421f,
440 -0.12257899f, -0.033945758f,-0.031303465f, 0.045630626f,0.06843887f,
441 -0.13492945f, -0.012480007f,-0.0811829f, -0.07224499f,-0.09628791f,
442 0.045100946f,0.0012300825f, 0.013964662f, 0.099372394f,0.02543059f,
443 0.06958324f, 0.034257296f, 0.0482646f, 0.06267997f,0.052625068f,
444 0.12784666f, 0.07077897f, 0.025725935f, 0.04165009f,0.07241905f,
445 0.018668644f, -0.037377294f,-0.06277783f,-0.08833636f,-0.040120605f,
446 -0.011405586f,-0.007808335f,-0.010301386f,-0.005102167f,0.027717464f,
447 0.05483423f, 0.11449111f, 0.11289652f,0.10939839f, 0.13396506f,
448 -0.08402166f,-0.01901462f, -0.044678304f,-0.07720565f,0.014350063f,
449 -0.11757958f, -0.0652038f, -0.08185733f,-0.076754324f,-0.092614375f,
450 0.10405491f, 0.052960336f, 0.035755895f,0.035839386f,-0.012540553f,
451 0.036881298f, 0.02913376f, 0.03420159f,0.05448447f,-0.054523353f,
452 0.02582715f, 0.02327355f, -0.011857179f,-0.0011980024f,-0.034641717f,
453 -0.026125094f,-0.17582615f,-0.15923657f,-0.27486774f,-0.0006143371f,
454 0.0001771948f, -8.470171e-05f, 0.02651807f,0.045790765f,0.06956496f };
telsoa01c577f2c2018-08-31 09:22:23 +0100455
Sadik Armagan483c8112021-06-01 09:24:52 +0100456 std::vector<float> inputToCellWeights = { -0.04580283f, -0.09549462f, -0.032418985f, -0.06454633f,
457 -0.043528453f, 0.043018587f, -0.049152344f, -0.12418144f,
458 -0.078985475f, -0.07596889f, 0.019484362f, -0.11434962f,
459 -0.0074034138f, -0.06314844f, -0.092981495f, 0.0062155537f,
460 -0.025034338f, -0.0028890965f, 0.048929527f, 0.06235075f,
461 0.10665918f, -0.032036792f, -0.08505916f, -0.10843358f,
462 -0.13002433f, -0.036816437f, -0.02130134f, -0.016518239f,
463 0.0047691227f, -0.0025825808f, 0.066017866f, 0.029991534f,
464 -0.10652836f, -0.1037554f, -0.13056071f, -0.03266643f,
465 -0.033702414f, -0.006473424f, -0.04611692f, 0.014419339f,
466 -0.025174323f, 0.0396852f, 0.081777506f, 0.06157468f,
467 0.10210095f, -0.009658194f, 0.046511717f, 0.03603906f,
468 0.0069369148f, 0.015960095f, -0.06507666f, 0.09551598f,
469 0.053568836f, 0.06408714f, 0.12835667f, -0.008714329f,
470 -0.20211966f, -0.12093674f, 0.029450472f, 0.2849013f,
471 -0.029227901f, 0.1164364f, -0.08560263f, 0.09941786f,
472 -0.036999565f, -0.028842626f, -0.0033637602f, -0.017012902f,
473 -0.09720865f, -0.11193351f, -0.029155117f, -0.017936034f,
474 -0.009768936f, -0.04223324f, -0.036159635f, 0.06505112f,
475 -0.021742892f, -0.023377212f, -0.07221364f, -0.06430552f,
476 0.05453865f, 0.091149814f, 0.06387331f, 0.007518393f,
477 0.055960953f, 0.069779344f, 0.046411168f, 0.10509911f,
478 0.07463894f, 0.0075130584f, 0.012850982f, 0.04555431f,
479 0.056955688f, 0.06555285f, 0.050801456f, -0.009862683f,
480 0.00826772f, -0.026555609f, -0.0073611983f, -0.0014897042f };
telsoa01c577f2c2018-08-31 09:22:23 +0100481
Sadik Armagan483c8112021-06-01 09:24:52 +0100482 std::vector<float> inputToOutputWeights ={-0.0998932f, -0.07201956f, -0.052803773f,-0.15629593f,-0.15001918f,
483 -0.07650751f,0.02359855f, -0.075155355f, -0.08037709f, -0.15093534f,
484 0.029517552f, -0.04751393f, 0.010350531f,-0.02664851f, -0.016839722f,
485 -0.023121163f, 0.0077019283f, 0.012851257f, -0.05040649f,-0.0129761f,
486 -0.021737747f,-0.038305793f,-0.06870586f, -0.01481247f,-0.001285394f,
487 0.10124236f, 0.083122835f, 0.053313006f,-0.062235646f,-0.075637154f,
488 -0.027833903f, 0.029774971f, 0.1130802f, 0.09218906f, 0.09506135f,
489 -0.086665764f,-0.037162706f,-0.038880914f,-0.035832845f,-0.014481564f,
490 -0.09825003f,-0.12048569f,-0.097665586f,-0.05287633f, -0.0964047f,
491 -0.11366429f, 0.035777505f, 0.13568819f, 0.052451383f,0.050649304f,
492 0.05798951f, -0.021852335f,-0.099848844f,0.014740475f,-0.078897946f,
493 0.04974699f, 0.014160473f, 0.06973932f, 0.04964942f, 0.033364646f,
494 0.08190124f, 0.025535367f, 0.050893165f, 0.048514254f,0.06945813f,
495 -0.078907564f,-0.06707616f, -0.11844508f, -0.09986688f,-0.07509403f,
496 0.06263226f, 0.14925587f, 0.20188436f, 0.12098451f,0.14639415f,
497 0.0015017595f, -0.014267382f, -0.03417257f,0.012711468f,0.0028300495f,
498 -0.024758482f, -0.05098548f,-0.0821182f, 0.014225672f, 0.021544158f,
499 0.08949725f, 0.07505268f, -0.0020780868f, 0.04908258f,0.06476295f,
500 -0.022907063f,0.027562456f,0.040185735f, 0.019567577f,-0.015598739f,
501 -0.049097303f, -0.017121866f, -0.083368234f,-0.02332002f,-0.0840956f };
telsoa01c577f2c2018-08-31 09:22:23 +0100502
Sadik Armagan483c8112021-06-01 09:24:52 +0100503 std::vector<float> inputGateBias = {0.02234832f, 0.14757581f, 0.18176508f, 0.10380666f, 0.053110216f,
504 -0.06928846f, -0.13942584f, -0.11816189f, 0.19483899f, 0.03652339f,
505 -0.10250295f, 0.036714908f, -0.18426876f, 0.036065217f, 0.21810818f,
506 0.02383196f, -0.043370757f, 0.08690144f, -0.04444982f, 0.00030581196f };
telsoa01c577f2c2018-08-31 09:22:23 +0100507
Sadik Armagan483c8112021-06-01 09:24:52 +0100508 std::vector<float> forgetGateBias ={0.035185695f, -0.042891346f, -0.03032477f, 0.23027696f,
509 0.11098921f, 0.15378423f, 0.09263801f, 0.09790885f,
510 0.09508917f, 0.061199076f, 0.07665568f, -0.015443159f,
511 -0.03499149f, 0.046190713f, 0.08895977f, 0.10899629f,
512 0.40694186f, 0.06030037f, 0.012413437f, -0.06108739f };
telsoa01c577f2c2018-08-31 09:22:23 +0100513
Sadik Armagan483c8112021-06-01 09:24:52 +0100514 std::vector<float> cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f, 0.033463873f,
515 -0.1483596f, -0.10639995f, -0.091433935f, 0.058573797f,
516 -0.06809782f, -0.07889636f, -0.043246906f, -0.09829136f,
517 -0.4279842f, 0.034901652f, 0.18797937f, 0.0075234566f,
518 0.016178843f, 0.1749513f, 0.13975595f, 0.92058027f };
telsoa01c577f2c2018-08-31 09:22:23 +0100519
Sadik Armagan483c8112021-06-01 09:24:52 +0100520 std::vector<float> outputGateBias ={0.046159424f, -0.0012809046f, 0.03563469f, 0.12648113f, 0.027195795f,
521 0.35373217f, -0.018957434f, 0.008907322f, -0.0762701f, 0.12018895f,
522 0.04216877f, 0.0022856654f, 0.040952638f, 0.3147856f, 0.08225149f,
523 -0.057416286f, -0.14995944f, -0.008040261f, 0.13208859f, 0.029760877f};
telsoa01c577f2c2018-08-31 09:22:23 +0100524
Sadik Armagan483c8112021-06-01 09:24:52 +0100525 std::vector<float> recurrentToInputWeights = { -0.001374326f, -0.078856036f, 0.10672688f, 0.029162422f,
telsoa01c577f2c2018-08-31 09:22:23 +0100526 -0.11585556f, 0.02557986f, -0.13446963f, -0.035785314f,
527 -0.01244275f, 0.025961924f, -0.02337298f, -0.044228926f,
528 -0.055839065f, -0.046598054f, -0.010546039f, -0.06900766f,
529 0.027239809f, 0.022582639f, -0.013296484f, -0.05459212f,
530 0.08981f, -0.045407712f, 0.08682226f, -0.06867011f,
531 -0.14390695f, -0.02916037f, 0.000996957f, 0.091420636f,
532 0.14283475f, -0.07390571f, -0.06402044f, 0.062524505f,
533 -0.093129106f, 0.04860203f, -0.08364217f, -0.08119002f,
534 0.009352075f, 0.22920375f, 0.0016303885f, 0.11583097f,
535 -0.13732095f, 0.012405723f, -0.07551853f, 0.06343048f,
536 0.12162708f, -0.031923793f, -0.014335606f, 0.01790974f,
537 -0.10650317f, -0.0724401f, 0.08554849f, -0.05727212f,
538 0.06556731f, -0.042729504f, -0.043227166f, 0.011683251f,
539 -0.013082158f, -0.029302018f, -0.010899579f, -0.062036745f,
540 -0.022509435f, -0.00964907f, -0.01567329f, 0.04260106f,
541 -0.07787477f, -0.11576462f, 0.017356863f, 0.048673786f,
542 -0.017577527f, -0.05527947f, -0.082487635f, -0.040137455f,
543 -0.10820036f, -0.04666372f, 0.022746278f, -0.07851417f,
544 0.01068115f, 0.032956902f, 0.022433773f, 0.0026891115f,
545 0.08944216f, -0.0685835f, 0.010513544f, 0.07228705f,
546 0.02032331f, -0.059686817f, -0.0005566496f, -0.086984694f,
547 0.040414046f, -0.1380399f, 0.094208956f, -0.05722982f,
548 0.012092817f, -0.04989123f, -0.086576f, -0.003399834f,
549 -0.04696032f, -0.045747425f, 0.10091314f, 0.048676282f,
550 -0.029037097f, 0.031399418f, -0.0040285117f, 0.047237843f,
551 0.09504992f, 0.041799378f, -0.049185462f, -0.031518843f,
552 -0.10516937f, 0.026374253f, 0.10058866f, -0.0033195973f,
553 -0.041975245f, 0.0073591834f, 0.0033782164f, -0.004325073f,
554 -0.10167381f, 0.042500053f, -0.01447153f, 0.06464186f,
555 -0.017142897f, 0.03312627f, 0.009205989f, 0.024138335f,
556 -0.011337001f, 0.035530265f, -0.010912711f, 0.0706555f,
557 -0.005894094f, 0.051841937f, -0.1401738f, -0.02351249f,
558 0.0365468f, 0.07590991f, 0.08838724f, 0.021681072f,
559 -0.10086113f, 0.019608743f, -0.06195883f, 0.077335775f,
560 0.023646897f, -0.095322326f, 0.02233014f, 0.09756986f,
561 -0.048691444f, -0.009579111f, 0.07595467f, 0.11480546f,
562 -0.09801813f, 0.019894179f, 0.08502348f, 0.004032281f,
563 0.037211012f, 0.068537936f, -0.048005626f, -0.091520436f,
564 -0.028379958f, -0.01556313f, 0.06554592f, -0.045599163f,
565 -0.01672207f, -0.020169014f, -0.011877351f, -0.20212261f,
566 0.010889619f, 0.0047078193f, 0.038385306f, 0.08540671f,
567 -0.017140968f, -0.0035865551f, 0.016678626f, 0.005633034f,
568 0.015963363f, 0.00871737f, 0.060130805f, 0.028611384f,
569 0.10109069f, -0.015060172f, -0.07894427f, 0.06401885f,
570 0.011584063f, -0.024466386f, 0.0047652307f, -0.09041358f,
571 0.030737216f, -0.0046374933f, 0.14215417f, -0.11823516f,
572 0.019899689f, 0.006106124f, -0.027092824f, 0.0786356f,
573 0.05052217f, -0.058925f, -0.011402121f, -0.024987547f,
574 -0.0013661642f, -0.06832946f, -0.015667673f, -0.1083353f,
575 -0.00096863037f, -0.06988685f, -0.053350925f, -0.027275559f,
576 -0.033664223f, -0.07978348f, -0.025200296f, -0.017207067f,
577 -0.058403496f, -0.055697463f, 0.005798788f, 0.12965427f,
578 -0.062582195f, 0.0013350133f, -0.10482091f, 0.0379771f,
579 0.072521195f, -0.0029455067f, -0.13797039f, -0.03628521f,
580 0.013806405f, -0.017858358f, -0.01008298f, -0.07700066f,
581 -0.017081132f, 0.019358726f, 0.0027079724f, 0.004635139f,
582 0.062634714f, -0.02338735f, -0.039547626f, -0.02050681f,
583 0.03385117f, -0.083611414f, 0.002862572f, -0.09421313f,
584 0.058618143f, -0.08598433f, 0.00972939f, 0.023867095f,
585 -0.053934585f, -0.023203006f, 0.07452513f, -0.048767887f,
586 -0.07314807f, -0.056307215f, -0.10433547f, -0.06440842f,
587 0.04328182f, 0.04389765f, -0.020006588f, -0.09076438f,
588 -0.11652589f, -0.021705797f, 0.03345259f, -0.010329105f,
589 -0.025767034f, 0.013057034f, -0.07316461f, -0.10145612f,
590 0.06358255f, 0.18531723f, 0.07759293f, 0.12006465f,
591 0.1305557f, 0.058638252f, -0.03393652f, 0.09622831f,
592 -0.16253184f, -2.4580743e-06f, 0.079869635f, -0.070196845f,
593 -0.005644518f, 0.06857898f, -0.12598175f, -0.035084512f,
594 0.03156317f, -0.12794146f, -0.031963028f, 0.04692781f,
595 0.030070418f, 0.0071660685f, -0.095516115f, -0.004643372f,
596 0.040170413f, -0.062104587f, -0.0037324072f, 0.0554317f,
597 0.08184801f, -0.019164372f, 0.06791302f, 0.034257166f,
598 -0.10307039f, 0.021943003f, 0.046745934f, 0.0790918f,
599 -0.0265588f, -0.007824208f, 0.042546265f, -0.00977924f,
600 -0.0002440307f, -0.017384544f, -0.017990116f, 0.12252321f,
601 -0.014512694f, -0.08251313f, 0.08861942f, 0.13589665f,
602 0.026351685f, 0.012641483f, 0.07466548f, 0.044301085f,
603 -0.045414884f, -0.051112458f, 0.03444247f, -0.08502782f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100604 -0.04106223f, -0.028126027f, 0.028473156f, 0.10467447f };
telsoa01c577f2c2018-08-31 09:22:23 +0100605
Sadik Armagan483c8112021-06-01 09:24:52 +0100606 std::vector<float> recurrentToForgetWeights = {-0.057784554f, -0.026057621f, -0.068447545f, -0.022581743f,
telsoa01c577f2c2018-08-31 09:22:23 +0100607 0.14811787f, 0.10826372f, 0.09471067f, 0.03987225f,
608 -0.0039523416f, 0.00030638507f, 0.053185795f, 0.10572994f,
609 0.08414449f, -0.022036452f, -0.00066928595f, -0.09203576f,
610 0.032950465f, -0.10985798f, -0.023809856f, 0.0021431844f,
611 -0.02196096f, -0.00326074f, 0.00058621005f, -0.074678116f,
612 -0.06193199f, 0.055729095f, 0.03736828f, 0.020123724f,
613 0.061878487f, -0.04729229f, 0.034919553f, -0.07585433f,
614 -0.04421272f, -0.044019096f, 0.085488975f, 0.04058006f,
615 -0.06890133f, -0.030951202f, -0.024628663f, -0.07672815f,
616 0.034293607f, 0.08556707f, -0.05293577f, -0.033561368f,
617 -0.04899627f, 0.0241671f, 0.015736353f, -0.095442444f,
618 -0.029564252f, 0.016493602f, -0.035026584f, 0.022337519f,
619 -0.026871363f, 0.004780428f, 0.0077918363f, -0.03601621f,
620 0.016435321f, -0.03263031f, -0.09543275f, -0.047392778f,
621 0.013454138f, 0.028934088f, 0.01685226f, -0.086110644f,
622 -0.046250615f, -0.01847454f, 0.047608484f, 0.07339695f,
623 0.034546845f, -0.04881143f, 0.009128804f, -0.08802852f,
624 0.03761666f, 0.008096139f, -0.014454086f, 0.014361001f,
625 -0.023502491f, -0.0011840804f, -0.07607001f, 0.001856849f,
626 -0.06509276f, -0.006021153f, -0.08570962f, -0.1451793f,
627 0.060212336f, 0.055259194f, 0.06974018f, 0.049454916f,
628 -0.027794661f, -0.08077226f, -0.016179763f, 0.1169753f,
629 0.17213494f, -0.0056326236f, -0.053934924f, -0.0124349f,
630 -0.11520337f, 0.05409887f, 0.088759385f, 0.0019655675f,
631 0.0042065294f, 0.03881498f, 0.019844765f, 0.041858196f,
632 -0.05695512f, 0.047233116f, 0.038937137f, -0.06542224f,
633 0.014429736f, -0.09719407f, 0.13908425f, -0.05379757f,
634 0.012321099f, 0.082840554f, -0.029899208f, 0.044217527f,
635 0.059855383f, 0.07711018f, -0.045319796f, 0.0948846f,
636 -0.011724666f, -0.0033288454f, -0.033542685f, -0.04764985f,
637 -0.13873616f, 0.040668588f, 0.034832682f, -0.015319203f,
638 -0.018715994f, 0.046002675f, 0.0599172f, -0.043107376f,
639 0.0294216f, -0.002314414f, -0.022424703f, 0.0030315618f,
640 0.0014641669f, 0.0029166266f, -0.11878115f, 0.013738511f,
641 0.12375372f, -0.0006038222f, 0.029104086f, 0.087442465f,
642 0.052958444f, 0.07558703f, 0.04817258f, 0.044462286f,
643 -0.015213451f, -0.08783778f, -0.0561384f, -0.003008196f,
644 0.047060397f, -0.002058388f, 0.03429439f, -0.018839769f,
645 0.024734668f, 0.024614193f, -0.042046934f, 0.09597743f,
646 -0.0043254104f, 0.04320769f, 0.0064070094f, -0.0019131786f,
647 -0.02558259f, -0.022822596f, -0.023273505f, -0.02464396f,
648 -0.10991725f, -0.006240552f, 0.0074488563f, 0.024044557f,
649 0.04383914f, -0.046476185f, 0.028658995f, 0.060410924f,
650 0.050786525f, 0.009452605f, -0.0073054377f, -0.024810238f,
651 0.0052906186f, 0.0066939713f, -0.0020913032f, 0.014515517f,
652 0.015898481f, 0.021362653f, -0.030262267f, 0.016587038f,
653 -0.011442813f, 0.041154444f, -0.007631438f, -0.03423484f,
654 -0.010977775f, 0.036152758f, 0.0066366293f, 0.11915515f,
655 0.02318443f, -0.041350313f, 0.021485701f, -0.10906167f,
656 -0.028218046f, -0.00954771f, 0.020531068f, -0.11995105f,
657 -0.03672871f, 0.024019798f, 0.014255957f, -0.05221243f,
658 -0.00661567f, -0.04630967f, 0.033188973f, 0.10107534f,
659 -0.014027541f, 0.030796422f, -0.10270911f, -0.035999842f,
660 0.15443139f, 0.07684145f, 0.036571592f, -0.035900835f,
661 -0.0034699554f, 0.06209149f, 0.015920248f, -0.031122351f,
662 -0.03858649f, 0.01849943f, 0.13872518f, 0.01503974f,
663 0.069941424f, -0.06948533f, -0.0088794185f, 0.061282158f,
664 -0.047401894f, 0.03100163f, -0.041533746f, -0.10430945f,
665 0.044574402f, -0.01425562f, -0.024290353f, 0.034563623f,
666 0.05866852f, 0.023947537f, -0.09445152f, 0.035450947f,
667 0.02247216f, -0.0042998926f, 0.061146557f, -0.10250651f,
668 0.020881841f, -0.06747029f, 0.10062043f, -0.0023941975f,
669 0.03532124f, -0.016341697f, 0.09685456f, -0.016764693f,
670 0.051808182f, 0.05875331f, -0.04536488f, 0.001626336f,
671 -0.028892258f, -0.01048663f, -0.009793449f, -0.017093895f,
672 0.010987891f, 0.02357273f, -0.00010856845f, 0.0099760275f,
673 -0.001845119f, -0.03551521f, 0.0018358806f, 0.05763657f,
674 -0.01769146f, 0.040995963f, 0.02235177f, -0.060430344f,
675 0.11475477f, -0.023854522f, 0.10071741f, 0.0686208f,
676 -0.014250481f, 0.034261297f, 0.047418304f, 0.08562733f,
677 -0.030519066f, 0.0060542435f, 0.014653856f, -0.038836084f,
678 0.04096551f, 0.032249358f, -0.08355519f, -0.026823482f,
679 0.056386515f, -0.010401743f, -0.028396193f, 0.08507674f,
680 0.014410365f, 0.020995233f, 0.17040324f, 0.11511526f,
681 0.02459721f, 0.0066619175f, 0.025853224f, -0.023133837f,
682 -0.081302024f, 0.017264642f, -0.009585969f, 0.09491168f,
683 -0.051313367f, 0.054532815f, -0.014298593f, 0.10657464f,
684 0.007076659f, 0.10964551f, 0.0409152f, 0.008275321f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100685 -0.07283536f, 0.07937492f, 0.04192024f, -0.1075027f };
telsoa01c577f2c2018-08-31 09:22:23 +0100686
Sadik Armagan483c8112021-06-01 09:24:52 +0100687 std::vector<float> recurrentToCellWeights = { -0.037322544f, 0.018592842f, 0.0056175636f, -0.06253426f,
telsoa01c577f2c2018-08-31 09:22:23 +0100688 0.055647098f, -0.05713207f, -0.05626563f, 0.005559383f,
689 0.03375411f, -0.025757805f, -0.088049285f, 0.06017052f,
690 -0.06570978f, 0.007384076f, 0.035123326f, -0.07920549f,
691 0.053676967f, 0.044480428f, -0.07663568f, 0.0071805613f,
692 0.08089997f, 0.05143358f, 0.038261272f, 0.03339287f,
693 -0.027673481f, 0.044746667f, 0.028349208f, 0.020090483f,
694 -0.019443132f, -0.030755889f, -0.0040000007f, 0.04465846f,
695 -0.021585021f, 0.0031670958f, 0.0053199246f, -0.056117613f,
696 -0.10893326f, 0.076739706f, -0.08509834f, -0.027997585f,
697 0.037871376f, 0.01449768f, -0.09002357f, -0.06111149f,
698 -0.046195522f, 0.0422062f, -0.005683705f, -0.1253618f,
699 -0.012925729f, -0.04890792f, 0.06985068f, 0.037654128f,
700 0.03398274f, -0.004781977f, 0.007032333f, -0.031787455f,
701 0.010868644f, -0.031489216f, 0.09525667f, 0.013939797f,
702 0.0058680447f, 0.0167067f, 0.02668468f, -0.04797466f,
703 -0.048885044f, -0.12722108f, 0.035304096f, 0.06554885f,
704 0.00972396f, -0.039238118f, -0.05159735f, -0.11329045f,
705 0.1613692f, -0.03750952f, 0.06529313f, -0.071974665f,
706 -0.11769596f, 0.015524369f, -0.0013754242f, -0.12446318f,
707 0.02786344f, -0.014179351f, 0.005264273f, 0.14376344f,
708 0.015983658f, 0.03406988f, -0.06939408f, 0.040699873f,
709 0.02111075f, 0.09669095f, 0.041345075f, -0.08316494f,
710 -0.07684199f, -0.045768797f, 0.032298047f, -0.041805092f,
711 0.0119405f, 0.0061010392f, 0.12652606f, 0.0064572375f,
712 -0.024950314f, 0.11574242f, 0.04508852f, -0.04335324f,
713 0.06760663f, -0.027437469f, 0.07216407f, 0.06977076f,
714 -0.05438599f, 0.034033038f, -0.028602652f, 0.05346137f,
715 0.043184172f, -0.037189785f, 0.10420091f, 0.00882477f,
716 -0.054019816f, -0.074273005f, -0.030617684f, -0.0028467078f,
717 0.024302477f, -0.0038869337f, 0.005332455f, 0.0013399826f,
718 0.04361412f, -0.007001822f, 0.09631092f, -0.06702025f,
719 -0.042049985f, -0.035070654f, -0.04103342f, -0.10273396f,
720 0.0544271f, 0.037184782f, -0.13150354f, -0.0058036847f,
721 -0.008264958f, 0.042035464f, 0.05891794f, 0.029673764f,
722 0.0063542654f, 0.044788733f, 0.054816857f, 0.062257513f,
723 -0.00093483756f, 0.048938446f, -0.004952862f, -0.007730018f,
724 -0.04043371f, -0.017094059f, 0.07229206f, -0.023670016f,
725 -0.052195564f, -0.025616996f, -0.01520939f, 0.045104615f,
726 -0.007376126f, 0.003533447f, 0.006570588f, 0.056037236f,
727 0.12436656f, 0.051817212f, 0.028532185f, -0.08686856f,
728 0.11868599f, 0.07663395f, -0.07323171f, 0.03463402f,
729 -0.050708205f, -0.04458982f, -0.11590894f, 0.021273347f,
730 0.1251325f, -0.15313013f, -0.12224372f, 0.17228661f,
731 0.023029093f, 0.086124025f, 0.006445803f, -0.03496501f,
732 0.028332196f, 0.04449512f, -0.042436164f, -0.026587414f,
733 -0.006041347f, -0.09292539f, -0.05678812f, 0.03897832f,
734 0.09465633f, 0.008115513f, -0.02171956f, 0.08304309f,
735 0.071401566f, 0.019622514f, 0.032163795f, -0.004167056f,
736 0.02295182f, 0.030739572f, 0.056506045f, 0.004612461f,
737 0.06524936f, 0.059999723f, 0.046395954f, -0.0045512207f,
738 -0.1335546f, -0.030136576f, 0.11584653f, -0.014678886f,
739 0.0020118146f, -0.09688814f, -0.0790206f, 0.039770417f,
740 -0.0329582f, 0.07922767f, 0.029322514f, 0.026405897f,
741 0.04207835f, -0.07073373f, 0.063781224f, 0.0859677f,
742 -0.10925287f, -0.07011058f, 0.048005477f, 0.03438226f,
743 -0.09606514f, -0.006669445f, -0.043381985f, 0.04240257f,
744 -0.06955775f, -0.06769346f, 0.043903265f, -0.026784198f,
745 -0.017840602f, 0.024307009f, -0.040079936f, -0.019946516f,
746 0.045318738f, -0.12233574f, 0.026170589f, 0.0074471775f,
747 0.15978073f, 0.10185836f, 0.10298046f, -0.015476589f,
748 -0.039390966f, -0.072174534f, 0.0739445f, -0.1211869f,
749 -0.0347889f, -0.07943156f, 0.014809798f, -0.12412325f,
750 -0.0030663363f, 0.039695457f, 0.0647603f, -0.08291318f,
751 -0.018529687f, -0.004423833f, 0.0037507233f, 0.084633216f,
752 -0.01514876f, -0.056505352f, -0.012800942f, -0.06994386f,
753 0.012962922f, -0.031234352f, 0.07029052f, 0.016418684f,
754 0.03618972f, 0.055686004f, -0.08663945f, -0.017404709f,
755 -0.054761406f, 0.029065743f, 0.052404847f, 0.020238016f,
756 0.0048197987f, -0.0214882f, 0.07078733f, 0.013016777f,
757 0.06262858f, 0.009184685f, 0.020785125f, -0.043904778f,
758 -0.0270329f, -0.03299152f, -0.060088247f, -0.015162964f,
759 -0.001828936f, 0.12642565f, -0.056757294f, 0.013586685f,
760 0.09232601f, -0.035886683f, 0.06000002f, 0.05229691f,
761 -0.052580316f, -0.082029596f, -0.010794592f, 0.012947712f,
762 -0.036429964f, -0.085508935f, -0.13127148f, -0.017744139f,
763 0.031502828f, 0.036232427f, -0.031581745f, 0.023051167f,
764 -0.05325106f, -0.03421577f, 0.028793324f, -0.034633752f,
765 -0.009881397f, -0.043551125f, -0.018609839f, 0.0019097115f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100766 -0.008799762f, 0.056595087f, 0.0022273948f, 0.055752404f };
telsoa01c577f2c2018-08-31 09:22:23 +0100767
Sadik Armagan483c8112021-06-01 09:24:52 +0100768 std::vector<float> recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f,-0.045984812f, -0.01255415f,
769 -0.0026479573f,-0.08196161f,-0.054914974f,-0.0046604523f,
telsoa01c577f2c2018-08-31 09:22:23 +0100770 -0.029587349f, -0.044576716f, -0.07480124f, -0.082868785f,
771 0.023254942f, 0.027502948f, -0.0039728214f, -0.08683098f,
772 -0.08116779f, -0.014675607f, -0.037924774f, -0.023314456f,
773 -0.007401714f, -0.09255757f, 0.029460307f, -0.08829125f,
774 -0.005139627f, -0.08989442f, -0.0555066f, 0.13596267f,
775 -0.025062224f, -0.048351806f, -0.03850004f, 0.07266485f,
776 -0.022414139f, 0.05940088f, 0.075114764f, 0.09597592f,
777 -0.010211725f, -0.0049794707f, -0.011523867f, -0.025980417f,
778 0.072999895f, 0.11091378f, -0.081685916f, 0.014416728f,
779 0.043229222f, 0.034178585f, -0.07530371f, 0.035837382f,
780 -0.085607f, -0.007721233f, -0.03287832f, -0.043848954f,
781 -0.06404588f, -0.06632928f, -0.073643476f, 0.008214239f,
782 -0.045984086f, 0.039764922f, 0.03474462f, 0.060612556f,
783 -0.080590084f, 0.049127717f, 0.04151091f, -0.030063879f,
784 0.008801774f, -0.023021035f, -0.019558564f, 0.05158114f,
785 -0.010947698f, -0.011825728f, 0.0075720972f, 0.0699727f,
786 -0.0039981045f, 0.069350146f, 0.08799282f, 0.016156472f,
787 0.035502106f, 0.11695009f, 0.006217345f, 0.13392477f,
788 -0.037875112f, 0.025745004f, 0.08940699f, -0.00924166f,
789 0.0046702605f, -0.036598757f, -0.08811812f, 0.10522024f,
790 -0.032441203f, 0.008176899f, -0.04454919f, 0.07058152f,
791 0.0067963637f, 0.039206743f, 0.03259838f, 0.03725492f,
792 -0.09515802f, 0.013326398f, -0.052055415f, -0.025676316f,
793 0.03198509f, -0.015951829f, -0.058556724f, 0.036879618f,
794 0.043357447f, 0.028362012f, -0.05908629f, 0.0059240665f,
795 -0.04995891f, -0.019187413f,0.0276265f, -0.01628143f, 0.0025863599f,
796 0.08800015f, 0.035250366f, -0.022165963f, -0.07328642f,
797 -0.009415526f, -0.07455109f, 0.11690406f, 0.0363299f,
798 0.07411125f, 0.042103454f, -0.009660886f, 0.019076364f,
799 0.018299393f, -0.046004917f, 0.08891175f,0.0431396f, -0.026327137f,
800 -0.051502608f, 0.08979574f, -0.051670972f, 0.04940282f,
801 -0.07491107f, -0.021240504f, 0.022596184f, -0.034280192f,
802 0.060163025f, -0.058211457f, -0.051837247f, -0.01349775f,
803 -0.04639988f, -0.035936575f, -0.011681591f, 0.064818054f,
804 0.0073146066f, -0.021745546f, -0.043124277f, -0.06471268f,
805 -0.07053354f, -0.029321948f, -0.05330136f, 0.016933719f,
806 -0.053782392f, 0.13747959f, -0.1361751f, -0.11569455f,
807 0.0033329215f, 0.05693899f, -0.053219706f, 0.063698f,
808 0.07977434f, -0.07924483f, 0.06936997f, 0.0034815092f,
809 -0.007305279f, -0.037325785f, -0.07251102f, -0.033633437f,
810 -0.08677009f, 0.091591336f, -0.14165086f, 0.021752775f,
811 0.019683983f, 0.0011612234f, -0.058154266f, 0.049996935f,
812 0.0288841f, -0.0024567875f, -0.14345716f, 0.010955264f,-0.10234828f,
813 0.1183656f, -0.0010731248f, -0.023590032f,-0.072285876f,-0.0724771f,
814 -0.026382286f, -0.0014920527f, 0.042667855f, 0.0018776858f,
815 0.02986552f, 0.009814309f, 0.0733756f, 0.12289186f,
816 0.018043943f, -0.0458958f, 0.049412545f, 0.033632483f,
817 0.05495232f, 0.036686596f, -0.013781798f, -0.010036754f,
818 0.02576849f, -0.08307328f, 0.010112348f, 0.042521734f,
819 -0.05869831f, -0.071689695f, 0.03876447f, -0.13275425f, -0.0352966f,
820 -0.023077697f, 0.10285965f, 0.084736146f, 0.15568255f,
821 -0.00040734606f, 0.027835453f, -0.10292561f, -0.032401145f,
822 0.10053256f, -0.026142767f, -0.08271222f, -0.0030240538f,
823 -0.016368777f, 0.1070414f, 0.042672627f, 0.013456989f,
824 -0.0437609f, -0.022309763f, 0.11576483f, 0.04108048f,
825 0.061026827f, -0.0190714f, -0.0869359f, 0.037901703f, 0.0610107f,
826 0.07202949f, 0.01675338f, 0.086139716f, -0.08795751f,
827 -0.014898893f, -0.023771819f, -0.01965048f, 0.007955471f,
828 -0.043740474f, 0.03346837f, -0.10549954f, 0.090567775f,
829 0.042013682f, -0.03176985f, 0.12569028f, -0.02421228f,
830 -0.029526481f, 0.023851605f, 0.031539805f, 0.05292009f,
831 -0.02344001f, -0.07811758f, -0.08834428f, 0.10094801f,
832 0.16594367f, -0.06861939f, -0.021256343f, -0.041093912f,
833 -0.06669611f, 0.035498552f, 0.021757556f, -0.09302526f,
834 -0.015403468f, -0.06614931f, -0.051798206f, -0.013874718f,
835 0.03630673f, 0.010412845f, -0.08077351f, 0.046185967f,
836 0.0035662893f, 0.03541868f, -0.094149634f, -0.034814864f,
837 0.003128424f, -0.020674974f, -0.03944324f, -0.008110165f,
838 -0.11113267f, 0.08484226f, 0.043586485f, 0.040582247f,
839 0.0968012f, -0.065249965f, -0.028036479f, 0.0050708856f,
840 0.0017462453f, 0.0326779f, 0.041296225f, 0.09164146f,
841 -0.047743853f, -0.015952192f, -0.034451712f, 0.084197424f,
842 -0.05347844f, -0.11768019f, 0.085926116f, -0.08251791f,
843 -0.045081906f, 0.0948852f, 0.068401024f, 0.024856757f,
844 0.06978981f, -0.057309967f, -0.012775832f, -0.0032452994f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100845 0.01977615f, -0.041040014f, -0.024264973f,0.063464895f, 0.05431621f};
telsoa01c577f2c2018-08-31 09:22:23 +0100846
Sadik Armagan483c8112021-06-01 09:24:52 +0100847 std::vector<float> cellToInputWeights = {0.040369894f, 0.030746894f, 0.24704495f, 0.018586371f, -0.037586458f,
848 -0.15312155f, -0.11812848f, -0.11465643f, 0.20259799f, 0.11418174f,
849 -0.10116027f, -0.011334949f, 0.12411352f, -0.076769054f,-0.052169047f,
850 0.21198851f, -0.38871562f, -0.09061183f, -0.09683246f, -0.21929175f};
telsoa01c577f2c2018-08-31 09:22:23 +0100851
852
Sadik Armagan483c8112021-06-01 09:24:52 +0100853 std::vector<float> cellToForgetWeights = {-0.01998659f,-0.15568835f,-0.24248174f, -0.012770197f, 0.041331276f,
854 -0.072311886f, -0.052123554f,-0.0066330447f,-0.043891653f,0.036225766f,
855 -0.047248036f, 0.021479502f,0.033189066f, 0.11952997f, -0.020432774f,
856 0.64658105f, -0.06650122f, -0.03467612f, 0.095340036f, 0.23647355f};
telsoa01c577f2c2018-08-31 09:22:23 +0100857
Sadik Armagan483c8112021-06-01 09:24:52 +0100858 std::vector<float> cellToOutputWeights = { 0.08286371f, -0.08261836f, -0.51210177f, 0.002913762f, 0.17764764f,
859 -0.5495371f, -0.08460716f, -0.24552552f, 0.030037103f, 0.04123544f,
860 -0.11940523f, 0.007358328f, 0.1890978f, 0.4833202f, -0.34441817f,
861 0.36312827f, -0.26375428f, 0.1457655f, -0.19724406f, 0.15548733f};
telsoa01c577f2c2018-08-31 09:22:23 +0100862
Sadik Armagan483c8112021-06-01 09:24:52 +0100863 std::vector<float> projectionWeights={-0.009802181f, 0.09401916f, 0.0717386f, -0.13895074f, 0.09641832f,
864 0.060420845f, 0.08539281f, 0.054285463f, 0.061395317f, 0.034448683f,
865 -0.042991187f, 0.019801661f, -0.16840284f, -0.015726732f, -0.23041931f,
866 -0.024478018f, -0.10959692f, -0.013875541f, 0.18600968f, -0.061274476f,
867 0.0138165f, -0.08160894f, -0.07661644f, 0.032372914f, 0.16169067f,
868 0.22465782f, -0.03993472f, -0.004017731f, 0.08633481f, -0.28869787f,
869 0.08682067f, 0.17240396f, 0.014975425f, 0.056431185f, 0.031037588f,
870 0.16702051f, 0.0077946745f, 0.15140012f, 0.29405436f, 0.120285f,
871 -0.188994f, -0.027265169f, 0.043389652f, -0.022061434f, 0.014777949f,
872 -0.20203483f, 0.094781205f, 0.19100232f, 0.13987629f, -0.036132768f,
873 -0.06426278f, -0.05108664f, 0.13221376f, 0.009441198f, -0.16715929f,
874 0.15859416f, -0.040437475f, 0.050779544f, -0.022187516f, 0.012166504f,
875 0.027685808f, -0.07675938f, -0.0055694645f, -0.09444123f, 0.0046453946f,
876 0.050794356f, 0.10770313f, -0.20790008f, -0.07149004f, -0.11425117f,
877 0.008225835f, -0.035802525f, 0.14374903f, 0.15262283f, 0.048710253f,
878 0.1847461f, -0.007487823f, 0.11000021f, -0.09542012f, 0.22619456f,
879 -0.029149994f, 0.08527916f, 0.009043713f, 0.0042746216f, 0.016261552f,
880 0.022461696f, 0.12689082f, -0.043589946f, -0.12035478f, -0.08361797f,
881 -0.050666027f, -0.1248618f, -0.1275799f, -0.071875185f, 0.07377272f,
882 0.09944291f, -0.18897448f, -0.1593054f, -0.06526116f, -0.040107165f,
883 -0.004618631f, -0.067624845f, -0.007576253f, 0.10727444f, 0.041546922f,
884 -0.20424393f, 0.06907816f, 0.050412357f, 0.00724631f, 0.039827548f,
885 0.12449835f, 0.10747581f, 0.13708383f, 0.09134148f, -0.12617786f,
886 -0.06428341f, 0.09956831f, 0.1208086f, -0.14676677f, -0.0727722f,
887 0.1126304f, 0.010139365f, 0.015571211f, -0.038128063f, 0.022913318f,
888 -0.042050496f, 0.16842307f, -0.060597885f, 0.10531834f, -0.06411776f,
889 -0.07451711f, -0.03410368f, -0.13393489f, 0.06534304f, 0.003620307f,
890 0.04490757f, 0.05970546f, 0.05197996f, 0.02839995f, 0.10434969f,
891 -0.013699693f, -0.028353551f, -0.07260381f, 0.047201227f, -0.024575593f,
892 -0.036445823f, 0.07155557f, 0.009672501f, -0.02328883f, 0.009533515f,
893 -0.03606021f, -0.07421458f, -0.028082801f, -0.2678904f, -0.13221288f,
894 0.18419984f, -0.13012612f, -0.014588381f, -0.035059117f, -0.04824723f,
895 0.07830115f, -0.056184657f, 0.03277091f, 0.025466874f, 0.14494097f,
896 -0.12522776f, -0.098633975f, -0.10766018f, -0.08317623f, 0.08594209f,
897 0.07749552f, 0.039474737f, 0.1776665f, -0.07409566f, -0.0477268f,
898 0.29323658f, 0.10801441f, 0.1154011f, 0.013952499f, 0.10739139f,
899 0.10708251f, -0.051456142f, 0.0074137426f, -0.10430189f, 0.10034707f,
900 0.045594677f, 0.0635285f, -0.0715442f, -0.089667566f, -0.10811871f,
901 0.00026344223f, 0.08298446f, -0.009525053f, 0.006585689f, -0.24567553f,
902 -0.09450807f, 0.09648481f, 0.026996298f, -0.06419476f, -0.04752702f,
903 -0.11063944f, -0.23441927f, -0.17608605f, -0.052156363f, 0.067035615f,
904 0.19271925f, -0.0032889997f, -0.043264326f, 0.09663576f, -0.057112187f,
905 -0.10100678f, 0.0628376f, 0.04447668f, 0.017961001f, -0.10094388f,
906 -0.10190601f, 0.18335468f, 0.10494553f, -0.052095775f, -0.0026118709f,
907 0.10539724f, -0.04383912f, -0.042349473f, 0.08438151f, -0.1947263f,
908 0.02251204f, 0.11216432f, -0.10307853f, 0.17351969f, -0.039091777f,
909 0.08066188f, -0.00561982f, 0.12633002f, 0.11335965f, -0.0088127935f,
910 -0.019777594f, 0.06864014f, -0.059751723f, 0.016233567f, -0.06894641f,
911 -0.28651384f, -0.004228674f, 0.019708522f, -0.16305895f, -0.07468996f,
912 -0.0855457f, 0.099339016f, -0.07580735f, -0.13775392f, 0.08434318f,
913 0.08330512f, -0.12131499f, 0.031935584f, 0.09180414f, -0.08876437f,
914 -0.08049874f, 0.008753825f, 0.03498998f, 0.030215185f, 0.03907079f,
915 0.089751154f, 0.029194152f, -0.03337423f, -0.019092513f, 0.04331237f,
916 0.04299654f, -0.036394123f, -0.12915532f, 0.09793732f, 0.07512415f,
917 -0.11319543f, -0.032502122f, 0.15661901f, 0.07671967f, -0.005491124f,
918 -0.19379048f, -0.218606f, 0.21448623f, 0.017840758f, 0.1416943f,
919 -0.07051762f, 0.19488361f, 0.02664691f, -0.18104725f, -0.09334311f,
920 0.15026465f, -0.15493552f, -0.057762887f, -0.11604192f, -0.262013f,
921 -0.01391798f, 0.012185008f, 0.11156489f, -0.07483202f, 0.06693364f,
922 -0.26151478f, 0.046425626f, 0.036540434f, -0.16435726f, 0.17338543f,
923 -0.21401681f, -0.11385144f, -0.08283257f, -0.069031075f, 0.030635102f,
924 0.010969227f, 0.11109743f, 0.010919218f, 0.027526086f, 0.13519906f,
925 0.01891392f, -0.046839405f, -0.040167913f, 0.017953383f, -0.09700955f,
926 0.0061885654f, -0.07000971f, 0.026893595f, -0.038844477f, 0.14543656f};
telsoa01c577f2c2018-08-31 09:22:23 +0100927
928 std::vector<float> projectionBiasVector(outputSize, 0.f);
telsoa01c577f2c2018-08-31 09:22:23 +0100929
James Conroy1f58f032021-04-27 17:13:27 +0100930 armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo20x5);
931 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo20x5);
932 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo20x5);
933 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo20x5);
934 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo20x16);
935 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo20x16);
936 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo20x16);
937 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo20x16);
938 armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo20);
939 armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo20);
940 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo20);
941 armnn::ScopedTensorHandle cellBiasTensor(tensorInfo20);
942 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo20);
943 armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo20);
944 armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo20);
945 armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo16x20);
946 armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo16);
telsoa01c577f2c2018-08-31 09:22:23 +0100947
Sadik Armagan483c8112021-06-01 09:24:52 +0100948 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
949 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
950 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
951 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
952 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
953 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
954 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
955 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
956 AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
957 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
958 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
959 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
960 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
961 AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
962 AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
963 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
964 AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +0100965
966 data.m_InputToInputWeights = &inputToInputWeightsTensor;
967 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
968 data.m_InputToCellWeights = &inputToCellWeightsTensor;
969 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
970 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
971 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
972 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
973 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
974 data.m_CellToInputWeights = &cellToInputWeightsTensor;
975 data.m_InputGateBias = &inputGateBiasTensor;
976 data.m_ForgetGateBias = &forgetGateBiasTensor;
977 data.m_CellBias = &cellBiasTensor;
978 data.m_OutputGateBias = &outputGateBiasTensor;
979 data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
980 data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
981 data.m_ProjectionWeights = &projectionWeightsTensor;
982 data.m_ProjectionBias = &projectionBiasTensor;
983
984 // Flags to set test configuration
985 data.m_Parameters.m_ActivationFunc = 4;
986 data.m_Parameters.m_CifgEnabled = false;
987 data.m_Parameters.m_PeepholeEnabled = true;
988 data.m_Parameters.m_ProjectionEnabled = true;
989
telsoa01c577f2c2018-08-31 09:22:23 +0100990 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
991 inputHandle->Allocate();
992 outputStateInHandle->Allocate();
993 cellStateInHandle->Allocate();
994
995 scratchHandle->Allocate();
996 outputStateOutHandle->Allocate();
997 cellStateOutHandle->Allocate();
998 outputHandle->Allocate();
999
Sadik Armagan483c8112021-06-01 09:24:52 +01001000 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1001 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1002 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001003
telsoa01c577f2c2018-08-31 09:22:23 +01001004 workload->Execute();
1005
Sadik Armagan483c8112021-06-01 09:24:52 +01001006 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
telsoa01c577f2c2018-08-31 09:22:23 +01001007
Sadik Armagan483c8112021-06-01 09:24:52 +01001008 return LayerTestResult<T, 2>(actualOutput,
1009 outputVector,
1010 outputHandle->GetShape(),
1011 outputTensorInfo.GetShape());
telsoa01c577f2c2018-08-31 09:22:23 +01001012}
1013
Conor Kennedyb9971c92019-05-07 07:14:23 +01001014template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
1015LayerTestResult<T, 2> LstmLayerWithCifgWithPeepholeNoProjectionTestImpl(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +00001016 armnn::IWorkloadFactory& workloadFactory,
1017 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001018 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001019 const std::vector<T>& input,
1020 const std::vector<T>& outputExpected,
1021 const armnn::TensorShape& inputShape,
1022 const armnn::TensorShape& outputExpectedShape,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001023 float qScale = 0.0f,
1024 int32_t qOffset = 0,
1025 armnn::DataType constantDataType = armnn::DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001026{
Jan Eilers8eb25602020-03-09 12:13:48 +00001027 IgnoreUnused(memoryManager);
telsoa01c577f2c2018-08-31 09:22:23 +01001028 bool cifgEnabled = true;
1029 bool peepholeEnabled = true;
1030 bool projectionEnabled = false;
1031 // These are not the input and the output of Lstm yet
Sadik Armagan483c8112021-06-01 09:24:52 +01001032 unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]);
1033 unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
telsoa01c577f2c2018-08-31 09:22:23 +01001034
Sadik Armagan483c8112021-06-01 09:24:52 +01001035 unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
telsoa01c577f2c2018-08-31 09:22:23 +01001036
1037 const unsigned int cellSize = outputSize;
1038
1039 // Decide the shape of all input tensors
Conor Kennedyb9971c92019-05-07 07:14:23 +01001040 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset); // change to ArmnnType
1041 armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1042 armnn::TensorInfo cellStateInTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +01001043
Matteo Martincigha65b7ae2018-11-14 12:39:55 +00001044 unsigned int scratchBufferSize = cifgEnabled ? cellSize * 3 : cellSize * 4;
Conor Kennedyb9971c92019-05-07 07:14:23 +01001045 armnn::TensorInfo scratchBufferTensorInfo({batchSize, scratchBufferSize}, ArmnnType, qScale, qOffset);
1046 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1047 armnn::TensorInfo cellStateOutTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
1048 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +01001049
1050 // List of inputs
1051 std::vector<float> inputData;
1052 inputData.assign(input.data(), input.data() + batchSize*inputSize);
telsoa01c577f2c2018-08-31 09:22:23 +01001053
1054 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
telsoa01c577f2c2018-08-31 09:22:23 +01001055
1056 std::vector<float> cellStateInVector(batchSize * cellSize, 0.f);
telsoa01c577f2c2018-08-31 09:22:23 +01001057
1058 // Prepare all the weights in the descriptor for LSTM
1059 armnn::LstmQueueDescriptor data;
Conor Kennedyb9971c92019-05-07 07:14:23 +01001060 armnn::TensorInfo tensorInfoInput({cellSize, inputSize}, constantDataType, qScale, qOffset);
1061 armnn::TensorInfo tensorInfoOutput({cellSize, outputSize}, constantDataType, qScale, qOffset);
1062 armnn::TensorInfo tensorInfoNumUnits({cellSize}, constantDataType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +01001063
Sadik Armagan483c8112021-06-01 09:24:52 +01001064 std::vector<float> inputToCellWeights =
1065 {
1066 -0.49770179f, -0.27711356f, -0.09624726f, 0.05100781f,
1067 0.04717243f, 0.48944736f, -0.38535351f,
1068 -0.17212132f
1069 };
1070 std::vector<float> inputToForgetWeights =
1071 {
1072 -0.55291498f, -0.42866567f, 0.13056988f,
1073 -0.3633365f, -0.22755712f, 0.28253698f, 0.24407166f,
1074 0.33826375f
1075 };
1076 std::vector<float> inputToOutputWeights =
1077 {
1078 0.10725588f, -0.02335852f, -0.55932593f,
1079 -0.09426838f, -0.44257352f, 0.54939759f,
1080 0.01533556f, 0.42751634f
1081 };
1082 std::vector<float> cellBias = {0.f, 0.f, 0.f, 0.f};
1083 std::vector<float> forgetGateBias = {1.f, 1.f, 1.f, 1.f};
1084 std::vector<float> outputGateBias = {0.f, 0.f, 0.f, 0.f};
telsoa01c577f2c2018-08-31 09:22:23 +01001085
Sadik Armagan483c8112021-06-01 09:24:52 +01001086 std::vector<float> recurrentToCellWeights =
1087 {
1088 0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f, 0.42957711f,
1089 0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f, 0.20675004f,
1090 0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f, 0.44901288f,
1091 0.21193194f
1092 };
1093 std::vector<float> recurrentToForgetWeights =
1094 {
1095 -0.13832897f, -0.0515101f, -0.2359007f, -0.16661474f, -0.14340827f,
1096 0.36986142f, 0.23414481f, 0.55899f, 0.10798943f, -0.41174671f, 0.17751795f,
1097 -0.34484994f, -0.35874045f, -0.11352962f, 0.27268326f, 0.54058349f
1098 };
telsoa01c577f2c2018-08-31 09:22:23 +01001099
Sadik Armagan483c8112021-06-01 09:24:52 +01001100 std::vector<float> recurrentToOutputWeights =
1101 {
1102 0.41613156f, 0.42610586f, -0.16495961f, -0.5663873f, 0.30579174f, -0.05115908f,
1103 -0.33941799f, 0.23364776f, 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f,
1104 0.50248802f, 0.26114327f, -0.43736315f, 0.33149987f
1105 };
telsoa01c577f2c2018-08-31 09:22:23 +01001106
Sadik Armagan483c8112021-06-01 09:24:52 +01001107 std::vector<float> cellToForgetWeights = {0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f};
1108 std::vector<float> cellToOutputWeights = {-0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f};
telsoa01c577f2c2018-08-31 09:22:23 +01001109
James Conroy1f58f032021-04-27 17:13:27 +01001110 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfoInput);
1111 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfoInput);
1112 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfoInput);
telsoa01c577f2c2018-08-31 09:22:23 +01001113
James Conroy1f58f032021-04-27 17:13:27 +01001114 armnn::ScopedTensorHandle cellBiasTensor(tensorInfoNumUnits);
1115 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfoNumUnits);
1116 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfoNumUnits);
telsoa01c577f2c2018-08-31 09:22:23 +01001117
James Conroy1f58f032021-04-27 17:13:27 +01001118 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfoOutput);
1119 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfoOutput);
1120 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfoOutput);
telsoa01c577f2c2018-08-31 09:22:23 +01001121
James Conroy1f58f032021-04-27 17:13:27 +01001122 armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfoNumUnits);
1123 armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfoNumUnits);
telsoa01c577f2c2018-08-31 09:22:23 +01001124
Sadik Armagan483c8112021-06-01 09:24:52 +01001125 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1126 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1127 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001128
Sadik Armagan483c8112021-06-01 09:24:52 +01001129 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1130 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1131 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001132
Sadik Armagan483c8112021-06-01 09:24:52 +01001133 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1134 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1135 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001136
Sadik Armagan483c8112021-06-01 09:24:52 +01001137 AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
1138 AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001139
1140 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1141 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1142 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1143
1144 data.m_CellBias = &cellBiasTensor;
1145 data.m_ForgetGateBias = &forgetGateBiasTensor;
1146 data.m_OutputGateBias = &outputGateBiasTensor;
1147
1148 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1149 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1150 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1151
1152 data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
1153 data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
1154
1155 // other parameters for the descriptor
1156 data.m_Parameters.m_CifgEnabled = cifgEnabled;
1157 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
1158 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
1159
1160 data.m_Parameters.m_ActivationFunc = 4;
1161 data.m_Parameters.m_ClippingThresProj = 0.0;
1162 data.m_Parameters.m_ClippingThresCell = 0.0;
1163
telsoa01c577f2c2018-08-31 09:22:23 +01001164 // List of outputs
Rob Hughesbb46dde2020-05-20 15:27:37 +01001165 std::vector<T> scratchBufferVector(batchSize * scratchBufferSize, T());
Conor Kennedyb9971c92019-05-07 07:14:23 +01001166 LayerTestResult<T, 2> ret0(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001167
1168 // Output state for a certain time step
Rob Hughesbb46dde2020-05-20 15:27:37 +01001169 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
Conor Kennedyb9971c92019-05-07 07:14:23 +01001170 LayerTestResult<T, 2> ret1(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001171
1172 // Cell state for a certain time step
Rob Hughesbb46dde2020-05-20 15:27:37 +01001173 std::vector<T> cellStateOutVector(batchSize * cellSize, T());
Conor Kennedyb9971c92019-05-07 07:14:23 +01001174 LayerTestResult<T, 2> ret2(cellStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001175
1176 // Output for a certain time step
Rob Hughesbb46dde2020-05-20 15:27:37 +01001177 std::vector<T> outputData;
telsoa01c577f2c2018-08-31 09:22:23 +01001178 outputData.assign(outputExpected.data(), outputExpected.data() + batchSize*outputSize);
Conor Kennedyb9971c92019-05-07 07:14:23 +01001179 LayerTestResult<T, 2> ret3(outputTensorInfo);
Sadik Armagan483c8112021-06-01 09:24:52 +01001180 ret3.m_ExpectedData = outputData;
1181
1182 std::vector<T> actualScratchBufferOutput(scratchBufferTensorInfo.GetNumElements());
1183 std::vector<T> actualOutputStateOutput(outputStateOutTensorInfo.GetNumElements());
1184 std::vector<T> actualCellStateOutput(cellStateOutTensorInfo.GetNumElements());
1185 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
telsoa01c577f2c2018-08-31 09:22:23 +01001186
1187 // Prepare the inputs and outputs for the workload
1188 std::unique_ptr<armnn::ITensorHandle> inputHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001189 tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001190 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001191 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001192 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001193 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001194
1195 std::unique_ptr<armnn::ITensorHandle> scratchBufferHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001196 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001197 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001198 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001199 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001200 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001201 std::unique_ptr<armnn::ITensorHandle> outputHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001202 tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001203
1204 armnn::WorkloadInfo info;
1205 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1206 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1207 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1208
1209 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchBufferHandle.get());
1210 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1211 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1212 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1213
1214 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
1215
telsoa01c577f2c2018-08-31 09:22:23 +01001216 inputHandle->Allocate();
1217 outputStateInHandle->Allocate();
1218 cellStateInHandle->Allocate();
1219
1220 scratchBufferHandle->Allocate();
1221 outputStateOutHandle->Allocate();
1222 cellStateOutHandle->Allocate();
1223 outputHandle->Allocate();
1224
Sadik Armagan483c8112021-06-01 09:24:52 +01001225 CopyDataToITensorHandle(inputHandle.get(), inputData.data());
1226 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1227 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001228
Sadik Armagan483c8112021-06-01 09:24:52 +01001229 CopyDataToITensorHandle(scratchBufferHandle.get(), scratchBufferVector.data());
1230 CopyDataToITensorHandle(outputStateOutHandle.get(), outputStateOutVector.data());
1231 CopyDataToITensorHandle(cellStateOutHandle.get(), cellStateOutVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001232
telsoa01c577f2c2018-08-31 09:22:23 +01001233 workload->Execute();
1234
Sadik Armagan483c8112021-06-01 09:24:52 +01001235 CopyDataFromITensorHandle(actualScratchBufferOutput.data(), scratchBufferHandle.get());
1236 CopyDataFromITensorHandle(actualOutputStateOutput.data(), outputStateOutHandle.get());
1237 CopyDataFromITensorHandle(actualCellStateOutput.data(), cellStateOutHandle.get());
1238 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
1239
1240 ret0.m_ActualData = actualScratchBufferOutput;
1241 ret1.m_ActualData = actualOutputStateOutput;
1242 ret2.m_ActualData = actualCellStateOutput;
1243 ret3.m_ActualData = actualOutput;
telsoa01c577f2c2018-08-31 09:22:23 +01001244
1245 return ret3;
1246}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001247
Jan Eilers38e05bd2019-06-26 13:10:09 +01001248template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
1249LayerTestResult<T, 2>
1250LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl(armnn::IWorkloadFactory& workloadFactory,
1251 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001252 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001253 const std::vector<T>& input,
1254 const std::vector<T>& outputExpected,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001255 float qScale = 0.0f,
1256 int32_t qOffset = 0,
1257 armnn::DataType constantDataType = armnn::DataType::Float32)
1258{
Jan Eilers8eb25602020-03-09 12:13:48 +00001259 IgnoreUnused(memoryManager);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001260 unsigned int batchSize = 2;
1261 unsigned int outputSize = 3;
1262 unsigned int inputSize = 5;
1263 unsigned numUnits = 4;
1264
1265 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
1266 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
1267 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
1268
1269 // Scratch buffer size without CIFG [batchSize, numUnits * 4]
1270 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
1271 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
1272 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1273 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1274
Jan Eilers38e05bd2019-06-26 13:10:09 +01001275 std::vector<float> inputVector;
1276 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001277
1278 std::vector<float> cellStateInVector(batchSize * numUnits, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001279 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001280 std::vector<float> scratchBufferVector(batchSize * numUnits * 4, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001281 std::vector<float> outputStateOutVector(batchSize * outputSize, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001282 std::vector<float> cellStateOutVector(batchSize * numUnits, 0.f);
Sadik Armagan483c8112021-06-01 09:24:52 +01001283
1284 std::vector<float> actualOutput(outputTensorInfo.GetNumElements());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001285
1286 std::vector<float> outputVector;
1287 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001288
Finn Williamsc43de6a2020-08-27 11:13:25 +01001289 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001290 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001291 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001292 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001293 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001294
Finn Williamsc43de6a2020-08-27 11:13:25 +01001295 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
1296 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001297 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001298 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001299 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001300 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
1301 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001302
1303 armnn::LstmQueueDescriptor data;
1304 armnn::WorkloadInfo info;
1305
1306 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1307 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1308 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1309
1310 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
1311 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1312 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1313 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1314
1315 armnn::TensorInfo tensorInfo3({outputSize}, constantDataType, qScale, qOffset);
1316 armnn::TensorInfo tensorInfo4({numUnits}, constantDataType, qScale, qOffset);
1317 armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
1318 armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, constantDataType, qScale, qOffset);
1319 armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, constantDataType, qScale, qOffset);
1320
Sadik Armagan483c8112021-06-01 09:24:52 +01001321 std::vector<float> inputToInputWeights = {0.5f, 0.6f, 0.7f, -0.8f, -0.9f,
1322 0.1f, 0.2f, 0.3f, -0.4f, 0.5f,
1323 -0.8f, 0.7f, -0.6f, 0.5f, -0.4f,
1324 -0.5f, -0.4f, -0.3f, -0.2f, -0.1f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001325
Sadik Armagan483c8112021-06-01 09:24:52 +01001326 std::vector<float> inputToForgetWeights = { -0.6f, -0.1f, 0.3f, 0.2f, 0.9f,
1327 -0.5f, -0.2f, -0.4f, 0.3f, -0.8f,
1328 -0.4f, 0.3f, -0.5f, -0.4f, -0.6f,
1329 0.3f, -0.4f, -0.6f, -0.5f, -0.5f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001330
Sadik Armagan483c8112021-06-01 09:24:52 +01001331 std::vector<float> inputToCellWeights = {-0.4f, -0.3f, -0.2f, -0.1f, -0.5f,
1332 0.5f, -0.2f, -0.3f, -0.2f, -0.6f,
1333 0.6f, -0.1f, -0.4f, -0.3f, -0.7f,
1334 0.7f, -0.9f, -0.5f, 0.8f, 0.6f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001335
Sadik Armagan483c8112021-06-01 09:24:52 +01001336 std::vector<float> inputToOutputWeights = {-0.8f, -0.4f, -0.2f, -0.9f, -0.1f,
1337 -0.7f, 0.3f, -0.3f, -0.8f, -0.2f,
1338 0.6f, -0.2f, 0.4f, -0.7f, -0.3f,
1339 -0.5f, 0.1f, 0.5f, -0.6f, -0.4f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001340
Sadik Armagan483c8112021-06-01 09:24:52 +01001341 std::vector<float> inputGateBias = {0.03f, 0.15f, 0.22f, 0.38f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001342
Sadik Armagan483c8112021-06-01 09:24:52 +01001343 std::vector<float> forgetGateBias = {0.1f, -0.3f, -0.2f, 0.1f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001344
Sadik Armagan483c8112021-06-01 09:24:52 +01001345 std::vector<float> cellBias = {-0.05f, 0.72f, 0.25f, 0.08f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001346
Sadik Armagan483c8112021-06-01 09:24:52 +01001347 std::vector<float> outputGateBias = {0.05f, -0.01f, 0.2f, 0.1f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001348
Sadik Armagan483c8112021-06-01 09:24:52 +01001349 std::vector<float> recurrentToInputWeights ={-0.2f, -0.3f, 0.4f,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001350 0.1f, -0.5f, 0.9f,
1351 -0.2f, -0.3f, -0.7f,
Sadik Armagan483c8112021-06-01 09:24:52 +01001352 0.05f, -0.2f, -0.6f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001353
Sadik Armagan483c8112021-06-01 09:24:52 +01001354 std::vector<float> recurrentToCellWeights = {-0.3f, 0.2f, 0.1f,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001355 -0.3f, 0.8f, -0.08f,
1356 -0.2f, 0.3f, 0.8f,
Sadik Armagan483c8112021-06-01 09:24:52 +01001357 -0.6f, -0.1f, 0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001358
Sadik Armagan483c8112021-06-01 09:24:52 +01001359 std::vector<float> recurrentToForgetWeights = { -0.5f, -0.3f, -0.5f,
1360 -0.2f, 0.6f, 0.4f,
1361 0.9f, 0.3f, -0.1f,
1362 0.2f, 0.5f, 0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001363
Sadik Armagan483c8112021-06-01 09:24:52 +01001364 std::vector<float> recurrentToOutputWeights = { 0.3f, -0.1f, 0.1f,
1365 -0.2f, -0.5f, -0.7f,
1366 -0.2f, -0.6f, -0.1f,
1367 -0.4f, -0.7f, -0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001368
Sadik Armagan483c8112021-06-01 09:24:52 +01001369 std::vector<float> cellToInputWeights = {0.05f, 0.1f, 0.25f, 0.15f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001370
Sadik Armagan483c8112021-06-01 09:24:52 +01001371 std::vector<float> cellToForgetWeights = {-0.02f, -0.15f, -0.25f, -0.03f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001372
Sadik Armagan483c8112021-06-01 09:24:52 +01001373 std::vector<float> cellToOutputWeights = {0.1f, -0.1f, -0.5f, 0.05f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001374
Sadik Armagan483c8112021-06-01 09:24:52 +01001375 std::vector<float> projectionWeights = {-0.1f, 0.2f, 0.01f, -0.2f,
1376 0.1f, 0.5f, 0.3f, 0.08f,
1377 0.07f, 0.2f, -0.4f, 0.2f}; //{outputSize, numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001378
Sadik Armagan483c8112021-06-01 09:24:52 +01001379 std::vector<float> projectionBiasVector(outputSize, 0.f); //{outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001380
Sadik Armagan483c8112021-06-01 09:24:52 +01001381 std::vector<float> inputLayerNormWeights = {0.1f, 0.2f, 0.3f, 0.5f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001382
Sadik Armagan483c8112021-06-01 09:24:52 +01001383 std::vector<float> forgetLayerNormWeights = {0.2f, 0.2f, 0.4f, 0.3f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001384
Sadik Armagan483c8112021-06-01 09:24:52 +01001385 std::vector<float> cellLayerNormWeights = {0.7f, 0.2f, 0.3f, 0.8f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001386
Sadik Armagan483c8112021-06-01 09:24:52 +01001387 std::vector<float> outputLayerNormWeights = {0.6f, 0.2f, 0.2f, 0.5f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001388
1389
James Conroy1f58f032021-04-27 17:13:27 +01001390 armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo4x5);
1391 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo4x5);
1392 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo4x5);
1393 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo4x5);
1394 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo4x3);
1395 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo4x3);
1396 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo4x3);
1397 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo4x3);
1398 armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4);
1399 armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4);
1400 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4);
1401 armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4);
1402 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4);
1403 armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo4);
1404 armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo4);
1405 armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo3x4);
1406 armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo3);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001407
James Conroy1f58f032021-04-27 17:13:27 +01001408 armnn::ScopedTensorHandle inputLayerNormWeightsTensor(tensorInfo4);
1409 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(tensorInfo4);
1410 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(tensorInfo4);
1411 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(tensorInfo4);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001412
Sadik Armagan483c8112021-06-01 09:24:52 +01001413 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
1414 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1415 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1416 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
1417 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
1418 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1419 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1420 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
1421 AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
1422 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
1423 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1424 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1425 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
1426 AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
1427 AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
1428 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
1429 AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001430
Sadik Armagan483c8112021-06-01 09:24:52 +01001431 AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data());
1432 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
1433 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
1434 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001435
1436 data.m_InputToInputWeights = &inputToInputWeightsTensor;
1437 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1438 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1439 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1440 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
1441 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1442 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1443 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1444 data.m_CellToInputWeights = &cellToInputWeightsTensor;
1445 data.m_InputGateBias = &inputGateBiasTensor;
1446 data.m_ForgetGateBias = &forgetGateBiasTensor;
1447 data.m_CellBias = &cellBiasTensor;
1448 data.m_OutputGateBias = &outputGateBiasTensor;
1449 data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
1450 data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
1451 data.m_ProjectionWeights = &projectionWeightsTensor;
1452 data.m_ProjectionBias = &projectionBiasTensor;
1453
1454 data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
1455 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
1456 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
1457 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
1458
1459 // Flags to set test configuration
1460 data.m_Parameters.m_ActivationFunc = 4;
1461 data.m_Parameters.m_CifgEnabled = false;
1462 data.m_Parameters.m_PeepholeEnabled = true;
1463 data.m_Parameters.m_ProjectionEnabled = true;
1464 data.m_Parameters.m_LayerNormEnabled = true;
1465
1466
1467 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
1468 inputHandle->Allocate();
1469 outputStateInHandle->Allocate();
1470 cellStateInHandle->Allocate();
1471
1472 scratchHandle->Allocate();
1473 outputStateOutHandle->Allocate();
1474 cellStateOutHandle->Allocate();
1475 outputHandle->Allocate();
1476
Sadik Armagan483c8112021-06-01 09:24:52 +01001477 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1478 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1479 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001480
1481 workload->Execute();
1482
Sadik Armagan483c8112021-06-01 09:24:52 +01001483 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001484
Sadik Armagan483c8112021-06-01 09:24:52 +01001485 return LayerTestResult<T, 2>(actualOutput,
1486 outputVector,
1487 outputHandle->GetShape(),
1488 outputTensorInfo.GetShape());
James Conroy9c3cae82019-08-01 16:01:48 +01001489}
1490
// Test implementation for the QuantizedLstm workload (8-bit asymmetric quantized
// LSTM, Android NN QUANTIZED_16BIT_LSTM style: QAsymmU8 activations, QSymmS16
// cell state, Signed32 biases).
//
// The fixture is fixed-size: batch/input/output dimensions are taken from the
// caller-supplied shapes, but all weights, biases, initial states and
// quantization parameters below are hard-coded reference data. Only the input
// values and the expected output values vary per caller.
//
// Returns a LayerTestResult comparing the output produced by the workload
// against outputExpected.
LayerTestResult<uint8_t, 2> QuantizedLstmTestImpl(
    armnn::IWorkloadFactory& workloadFactory,
    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
    const armnn::ITensorHandleFactory& tensorHandleFactory,
    const std::vector<uint8_t>& input,
    const std::vector<uint8_t>& outputExpected,
    const armnn::TensorShape& inputShape,
    const armnn::TensorShape& outputExpectedShape)
{
    IgnoreUnused(memoryManager);
    // Dimensions are derived from the caller's shapes: [numBatches, inputSize]
    // for the input, [numBatches, outputSize] for the expected output.
    auto numBatches = armnn::numeric_cast<unsigned int>(inputShape[0]);
    auto inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
    auto outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);

    // Scale/Offset for input/output, cellState In/Out, weights, bias
    float inputOutputScale = 0.0078125f;
    int32_t inputOutputOffset = 128;

    float cellStateScale = 0.00048828125f;
    int32_t cellStateOffset = 0;

    float weightsScale = 0.00408021f;
    int32_t weightsOffset = 100;

    float biasScale = 3.1876640625e-05f;
    int32_t biasOffset = 0;

    // Input/Output tensor info
    armnn::TensorInfo inputInfo({numBatches , inputSize},
                                armnn::DataType::QAsymmU8,
                                inputOutputScale,
                                inputOutputOffset);

    // Cell state is 16-bit symmetric quantized, as required by QuantizedLstm.
    armnn::TensorInfo cellStateInfo({numBatches , outputSize},
                                    armnn::DataType::QSymmS16,
                                    cellStateScale,
                                    cellStateOffset);

    armnn::TensorInfo outputStateInfo({numBatches , outputSize},
                                      armnn::DataType::QAsymmU8,
                                      inputOutputScale,
                                      inputOutputOffset);

    // Input0: the caller-supplied input values, truncated to numBatches * inputSize.
    std::vector<uint8_t> inputVector;
    inputVector.assign(input.data(), input.data() + (numBatches * inputSize));

    // Input1: initial cell state (hard-coded reference data). // 13
    std::vector<int16_t> cellStateInVector = {876, 1034, 955, -909, 761, 1029, 796, -1036}; // 13
    // Input2: initial output state (hard-coded reference data). // 14
    std::vector<uint8_t> outputStateInVector = {136, 150, 140, 115, 135, 152, 138, 112}; // 14

    // Output0: reference cell-state-out values. NOTE(review): declared for
    // documentation only — this test does not read cellStateOutHandle back. // 0
    std::vector<int16_t> cellStateOutVector = {1485, 1177, 1373, -1023, 1019, 1355, 1097, -1235}; // 0

    // Output1: the caller-supplied expected output values. // 1
    std::vector<uint8_t> outputVector; // 1
    outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));

    // Destination buffer for the actual workload output.
    std::vector<uint8_t> actualOutput(outputStateInfo.GetNumElements());

    // Create tensor handles
    std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateInfo);
    std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateInfo);

    std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateInfo);
    std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);

    armnn::QuantizedLstmQueueDescriptor data;
    armnn::WorkloadInfo info;

    // Add inputs and outputs to workload.
    // Order matters: it defines the workload's input/output slot indices
    // (inputs: input, cellStateIn, outputStateIn; outputs: cellStateOut, output).
    AddInputToWorkload(data, info, inputInfo, inputHandle.get());
    AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
    AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());

    AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
    AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());

    // Weights and bias tensor and quantization info
    armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
                                       armnn::DataType::QAsymmU8,
                                       weightsScale,
                                       weightsOffset);

    armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
                                           armnn::DataType::QAsymmU8,
                                           weightsScale,
                                           weightsOffset);

    armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);

    // Weights and bias tensor data (hard-coded reference values).
    std::vector<uint8_t> inputToInputWeights = {146, 250, 235, 171, 10, 218, 171, 108};
    std::vector<uint8_t> inputToForgetWeights = {24, 50, 132, 179, 158, 110, 3, 169};
    std::vector<uint8_t> inputToCellWeights = {133, 34, 29, 49, 206, 109, 54, 183};
    std::vector<uint8_t> inputToOutputWeights = {195, 187, 11, 99, 109, 10, 218, 48};

    std::vector<uint8_t> recurrentToInputWeights =
            {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26};
    std::vector<uint8_t> recurrentToForgetWeights =
            {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253};
    std::vector<uint8_t> recurrentToCellWeights =
            {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216};
    std::vector<uint8_t> recurrentToOutputWeights =
            {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98};

    std::vector<int32_t> inputGateBias = {-7876, 13488, -726, 32839};
    std::vector<int32_t> forgetGateBias = {9206, -46884, -11693, -38724};
    std::vector<int32_t> cellBias = {39481, 48624, 48976, -21419};
    std::vector<int32_t> outputGateBias = {-58999, -17050, -41852, -40538};

    // ScopedTensorHandles own the constant (weight/bias) tensor storage for the
    // lifetime of the workload execution.
    armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo);
    armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
    armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
    armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);

    armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo);
    armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
    armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
    armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);

    armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo);
    armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
    armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
    armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);

    // Allocate and copy data
    AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());

    AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());

    AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
    AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
    AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
    AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());

    // Setup queue descriptor: point the workload at the constant tensors.
    data.m_InputToInputWeights = &inputToInputWeightsTensor;
    data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    data.m_InputToCellWeights = &inputToCellWeightsTensor;
    data.m_InputToOutputWeights = &inputToOutputWeightsTensor;

    data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
    data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
    data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;

    data.m_InputGateBias = &inputGateBiasTensor;
    data.m_ForgetGateBias = &forgetGateBiasTensor;
    data.m_CellBias = &cellBiasTensor;
    data.m_OutputGateBias = &outputGateBiasTensor;

    // Create workload and allocate tensor handles
    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQuantizedLstm(data, info);
    inputHandle->Allocate();
    outputStateInHandle->Allocate();
    cellStateInHandle->Allocate();

    cellStateOutHandle->Allocate();
    outputHandle->Allocate();

    CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
    CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
    CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());

    workload->Execute();

    CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());

    // Package actual vs. expected output (with shapes) for the caller to CHECK.
    return LayerTestResult<uint8_t, 2>(actualOutput,
                                       outputVector,
                                       outputHandle->GetShape(),
                                       outputStateInfo.GetShape());
}
1677
James Conroyb22a75e2020-06-08 14:53:10 +01001678// QLSTM: CIFG, LayerNorm
James Conroy4f1f8992020-04-29 20:01:10 +01001679LayerTestResult<int8_t, 2> QLstmTestImpl(
1680 armnn::IWorkloadFactory& workloadFactory,
1681 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001682 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001683 const std::vector<int8_t>& input,
1684 const std::vector<int8_t>& outputExpected)
James Conroy4f1f8992020-04-29 20:01:10 +01001685{
1686 IgnoreUnused(memoryManager);
1687 unsigned int numBatches = 2;
1688 unsigned int inputSize = 5;
1689 unsigned int outputSize = 4;
1690 unsigned int numUnits = 4;
1691
1692 bool cifgEnabled = true;
1693 bool peepholeEnabled = false;
1694 bool projectionEnabled = false;
1695 bool layerNormEnabled = true;
1696
1697 // Scale/Offset quantization info
1698 float inputScale = 0.0078125f;
1699 int32_t inputOffset = 0;
1700
1701 int32_t hiddenStateZeroPoint = 0;
1702 float hiddenStateScale = 0.007f;
1703
1704 // if (!projectionEnabled) outputScale == hiddenStateScale
1705 float outputScale = hiddenStateScale;
1706 int32_t outputOffset = hiddenStateZeroPoint;
1707
1708 float cellStateScale = 3.05176e-05f;
1709 int32_t cellStateOffset = 0;
1710
1711 float weightsScale = 0.00784314f;
1712 int32_t weightsOffset = 0;
1713
1714 float layerNormScale = 3.05182e-05f;
1715 int32_t layerNormOffset = 0;
1716
1717 float biasScale = layerNormScale / 1024;
1718 int32_t biasOffset = 0;
1719
1720 float inputIntermediateScale = 0.007059f;
1721 float forgetIntermediateScale = 0.007812f;
1722 float cellIntermediateScale = inputIntermediateScale;
1723 float outputIntermediateScale = forgetIntermediateScale;
1724
1725 float cellClip = 0.0f;
1726 float projectionClip = 0.0f;
1727
1728 // Input/Output tensor info
1729 armnn::TensorInfo inputInfo({numBatches , inputSize},
1730 armnn::DataType::QAsymmS8,
1731 inputScale,
1732 inputOffset);
1733
1734 armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1735 armnn::DataType::QSymmS16,
1736 cellStateScale,
1737 cellStateOffset);
1738
1739 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1740 armnn::DataType::QAsymmS8,
1741 outputScale,
1742 outputOffset);
1743
1744 LayerTestResult<int8_t, 2> ret(outputStateInfo);
1745
1746 // Input tensors
1747 std::vector<int8_t> inputVector;
1748 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroy4f1f8992020-04-29 20:01:10 +01001749
1750 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroy4f1f8992020-04-29 20:01:10 +01001751
Teresa Charlinbe727be2020-09-25 15:08:21 +01001752 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroy4f1f8992020-04-29 20:01:10 +01001753
1754 // Output tensors
Sadik Armagan483c8112021-06-01 09:24:52 +01001755 std::vector<int16_t> cellStateOutVector = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149};
James Conroy4f1f8992020-04-29 20:01:10 +01001756
1757 std::vector<int8_t> outputVector;
1758 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01001759
1760 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroy4f1f8992020-04-29 20:01:10 +01001761
1762 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01001763 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001764 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001765 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001766 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001767 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001768
Finn Williamsc43de6a2020-08-27 11:13:25 +01001769 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1770 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001771 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001772 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1773 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001774
1775 armnn::QLstmQueueDescriptor data;
1776 armnn::WorkloadInfo info;
1777
1778 // Add inputs and outputs to workload
1779 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1780 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1781 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1782
1783 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
1784 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1785 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
1786
1787 // Weights and bias tensor and quantization info
1788 armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
1789 armnn::DataType::QSymmS8,
1790 weightsScale,
1791 weightsOffset);
1792
1793 armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
1794 armnn::DataType::QSymmS8,
1795 weightsScale,
1796 weightsOffset);
1797
1798 armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
1799
1800 armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
1801
1802 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01001803 std::vector<int8_t> inputToForgetWeights =
1804 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
1805 std::vector<int8_t> inputToCellWeights =
1806 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
1807 std::vector<int8_t> inputToOutputWeights =
1808 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
James Conroy4f1f8992020-04-29 20:01:10 +01001809
Sadik Armagan483c8112021-06-01 09:24:52 +01001810 std::vector<int8_t> recurrentToForgetWeights =
1811 {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51};
1812 std::vector<int8_t> recurrentToCellWeights =
1813 {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64};
1814 std::vector<int8_t> recurrentToOutputWeights =
1815 {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38};
James Conroy4f1f8992020-04-29 20:01:10 +01001816
Sadik Armagan483c8112021-06-01 09:24:52 +01001817 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
1818 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
1819 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
James Conroy4f1f8992020-04-29 20:01:10 +01001820
Sadik Armagan483c8112021-06-01 09:24:52 +01001821 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
1822 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
1823 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
James Conroy4f1f8992020-04-29 20:01:10 +01001824
James Conroy1f58f032021-04-27 17:13:27 +01001825 // ScopedTensorHandles
1826 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
1827 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
1828 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001829
James Conroy1f58f032021-04-27 17:13:27 +01001830 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
1831 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
1832 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001833
James Conroy1f58f032021-04-27 17:13:27 +01001834 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
1835 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
1836 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001837
James Conroy1f58f032021-04-27 17:13:27 +01001838 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
1839 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
1840 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001841
1842 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01001843 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1844 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1845 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001846
Sadik Armagan483c8112021-06-01 09:24:52 +01001847 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1848 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1849 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001850
Sadik Armagan483c8112021-06-01 09:24:52 +01001851 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1852 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1853 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001854
Sadik Armagan483c8112021-06-01 09:24:52 +01001855 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
1856 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
1857 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001858
1859 // Setup queue descriptor
1860 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1861 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1862 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1863
1864 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1865 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1866 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1867
1868 data.m_ForgetGateBias = &forgetGateBiasTensor;
1869 data.m_CellBias = &cellBiasTensor;
1870 data.m_OutputGateBias = &outputGateBiasTensor;
1871
1872 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
1873 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
1874 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
1875
1876 data.m_Parameters.m_CifgEnabled = cifgEnabled;
1877 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
1878 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
1879 data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
1880
1881 data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
1882 data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
1883 data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
1884 data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
1885
1886 data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
1887 data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
1888
1889 data.m_Parameters.m_CellClip = cellClip;
1890 data.m_Parameters.m_ProjectionClip = projectionClip;
1891
1892 // Create workload and allocate tensor handles
1893 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
1894 inputHandle->Allocate();
1895 outputStateInHandle->Allocate();
1896 cellStateInHandle->Allocate();
1897
1898 outputStateOutHandle->Allocate();
1899 cellStateOutHandle->Allocate();
1900 outputHandle->Allocate();
1901
Sadik Armagan483c8112021-06-01 09:24:52 +01001902 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1903 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1904 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001905
1906 workload->Execute();
1907
Sadik Armagan483c8112021-06-01 09:24:52 +01001908 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroy4f1f8992020-04-29 20:01:10 +01001909
Sadik Armagan483c8112021-06-01 09:24:52 +01001910 return LayerTestResult<int8_t, 2>(actualOutput,
1911 outputVector,
1912 outputHandle->GetShape(),
1913 outputStateInfo.GetShape());
James Conroy4f1f8992020-04-29 20:01:10 +01001914}
1915
James Conroyb22a75e2020-06-08 14:53:10 +01001916// QLSTM: Projection, LayerNorm
1917LayerTestResult<int8_t, 2> QLstmTestImpl1(
1918 armnn::IWorkloadFactory& workloadFactory,
1919 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001920 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001921 const std::vector<int8_t>& input,
1922 const std::vector<int8_t>& outputExpected)
James Conroyb22a75e2020-06-08 14:53:10 +01001923{
1924 IgnoreUnused(memoryManager);
1925 unsigned int numBatches = 2;
1926 unsigned int inputSize = 5;
1927 unsigned int outputSize = 3;
1928 unsigned int numUnits = 4;
1929
1930 bool cifgEnabled = false;
1931 bool peepholeEnabled = false;
1932 bool projectionEnabled = true;
1933 bool layerNormEnabled = true;
1934
1935 // Scale/Offset quantization info
1936 float inputScale = 0.0078125f;
1937 int32_t inputOffset = 0;
1938
1939 int32_t hiddenStateZeroPoint = 0;
1940 float hiddenStateScale = 0.007f;
1941
1942 // if (!projectionEnabled) outputScale == hiddenStateScale
1943 float outputScale = 3.05176e-05f;
1944 int32_t outputOffset = 0;
1945
1946 float cellStateScale = 3.05176e-05f;
1947 int32_t cellStateOffset = 0;
1948
1949 float weightsScale = 0.00784314f;
1950 int32_t weightsOffset = 0;
1951
1952 float layerNormScale = 3.05182e-05f;
1953 int32_t layerNormOffset = 0;
1954
1955 float biasScale = layerNormScale / 1024;
1956 int32_t biasOffset = 0;
1957
1958 float projectionWeightsScale = 0.00392157f;
1959
1960 float inputIntermediateScale = 0.007059f;
1961 float forgetIntermediateScale = 0.007812f;
1962 float cellIntermediateScale = inputIntermediateScale;
1963 float outputIntermediateScale = forgetIntermediateScale;
1964
1965 float cellClip = 0.0f;
1966 float projectionClip = 0.0f;
1967
1968 // Input/Output tensor info
1969 armnn::TensorInfo inputInfo({numBatches , inputSize},
1970 armnn::DataType::QAsymmS8,
1971 inputScale,
1972 inputOffset);
1973
1974 armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1975 armnn::DataType::QSymmS16,
1976 cellStateScale,
1977 cellStateOffset);
1978
1979 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1980 armnn::DataType::QAsymmS8,
1981 outputScale,
1982 outputOffset);
1983
James Conroyb22a75e2020-06-08 14:53:10 +01001984 // Input tensors
1985 std::vector<int8_t> inputVector;
1986 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroyb22a75e2020-06-08 14:53:10 +01001987
1988 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroyb22a75e2020-06-08 14:53:10 +01001989
1990 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
James Conroyb22a75e2020-06-08 14:53:10 +01001991
1992 // Output tensors
1993 std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
James Conroyb22a75e2020-06-08 14:53:10 +01001994
1995 std::vector<int8_t> outputVector;
1996 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01001997
1998 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroyb22a75e2020-06-08 14:53:10 +01001999
2000 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01002001 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002002 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002003 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002004 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002005 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002006
Finn Williamsc43de6a2020-08-27 11:13:25 +01002007 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2008 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002009 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002010 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
2011 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002012
2013 armnn::QLstmQueueDescriptor data;
2014 armnn::WorkloadInfo info;
2015
2016 // Add inputs and outputs to workload
2017 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2018 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2019 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2020
2021 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2022 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2023 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2024
2025 // Weights and bias tensor and quantization info
2026 armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
2027 armnn::DataType::QSymmS8,
2028 weightsScale,
2029 weightsOffset);
2030
2031 armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
2032 armnn::DataType::QSymmS8,
2033 weightsScale,
2034 weightsOffset);
2035
2036 armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);
2037
2038 armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
2039
2040 armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
2041 armnn::DataType::QSymmS8,
2042 projectionWeightsScale,
2043 0);
2044
2045 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01002046 std::vector<int8_t> inputToInputWeights =
2047 {64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13};
2048 std::vector<int8_t> inputToForgetWeights =
2049 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2050 std::vector<int8_t> inputToCellWeights =
2051 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2052 std::vector<int8_t> inputToOutputWeights =
2053 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
James Conroyb22a75e2020-06-08 14:53:10 +01002054
Sadik Armagan483c8112021-06-01 09:24:52 +01002055 std::vector<int8_t> recurrentToInputWeights = {-25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77};
2056 std::vector<int8_t> recurrentToForgetWeights = {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2057 std::vector<int8_t> recurrentToCellWeights = {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2058 std::vector<int8_t> recurrentToOutputWeights = {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
James Conroyb22a75e2020-06-08 14:53:10 +01002059
Sadik Armagan483c8112021-06-01 09:24:52 +01002060 std::vector<int32_t> inputGateBias = {644245, 3221226, 4724464, 8160438};
2061 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2062 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
2063 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
James Conroyb22a75e2020-06-08 14:53:10 +01002064
Sadik Armagan483c8112021-06-01 09:24:52 +01002065 std::vector<int16_t> inputLayerNormWeights = {3277, 6553, 9830, 16384};
2066 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2067 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
2068 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
James Conroyb22a75e2020-06-08 14:53:10 +01002069
Sadik Armagan483c8112021-06-01 09:24:52 +01002070 std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
James Conroyb22a75e2020-06-08 14:53:10 +01002071
James Conroy1f58f032021-04-27 17:13:27 +01002072 // ScopedTensorHandles
2073 armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo);
2074 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
2075 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
2076 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002077
James Conroy1f58f032021-04-27 17:13:27 +01002078 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo);
2079 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
2080 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
2081 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002082
James Conroy1f58f032021-04-27 17:13:27 +01002083 armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo);
2084 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
2085 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
2086 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002087
James Conroy1f58f032021-04-27 17:13:27 +01002088 armnn::ScopedTensorHandle inputLayerNormWeightsTensor(layerNormWeightsInfo);
2089 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
2090 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
2091 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002092
James Conroy1f58f032021-04-27 17:13:27 +01002093 armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002094
2095 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01002096 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
2097 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
2098 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
2099 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002100
Sadik Armagan483c8112021-06-01 09:24:52 +01002101 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
2102 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
2103 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
2104 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002105
Sadik Armagan483c8112021-06-01 09:24:52 +01002106 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
2107 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
2108 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
2109 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002110
Sadik Armagan483c8112021-06-01 09:24:52 +01002111 AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data());
2112 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
2113 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
2114 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002115
Sadik Armagan483c8112021-06-01 09:24:52 +01002116 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002117
2118 // Setup queue descriptor
2119 data.m_InputToInputWeights = &inputToInputWeightsTensor;
2120 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
2121 data.m_InputToCellWeights = &inputToCellWeightsTensor;
2122 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
2123
2124 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
2125 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
2126 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
2127 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
2128
2129 data.m_InputGateBias = &inputGateBiasTensor;
2130 data.m_ForgetGateBias = &forgetGateBiasTensor;
2131 data.m_CellBias = &cellBiasTensor;
2132 data.m_OutputGateBias = &outputGateBiasTensor;
2133
2134 data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
2135 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
2136 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
2137 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
2138
2139 data.m_ProjectionWeights = &projectionWeightsTensor;
2140
2141 data.m_Parameters.m_CifgEnabled = cifgEnabled;
2142 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
2143 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
2144 data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
2145
2146 data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
2147 data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
2148 data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
2149 data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
2150
2151 data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
2152 data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
2153
2154 data.m_Parameters.m_CellClip = cellClip;
2155 data.m_Parameters.m_ProjectionClip = projectionClip;
2156
2157 // Create workload and allocate tensor handles
2158 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
2159 inputHandle->Allocate();
2160 outputStateInHandle->Allocate();
2161 cellStateInHandle->Allocate();
2162
2163 outputStateOutHandle->Allocate();
2164 cellStateOutHandle->Allocate();
2165 outputHandle->Allocate();
2166
Sadik Armagan483c8112021-06-01 09:24:52 +01002167 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
2168 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
2169 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002170
2171 workload->Execute();
2172
Sadik Armagan483c8112021-06-01 09:24:52 +01002173 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroyb22a75e2020-06-08 14:53:10 +01002174
Sadik Armagan483c8112021-06-01 09:24:52 +01002175 return LayerTestResult<int8_t, 2>(actualOutput,
2176 outputVector,
2177 outputHandle->GetShape(),
2178 outputStateInfo.GetShape());
James Conroyb22a75e2020-06-08 14:53:10 +01002179}
2180
// QLSTM: Projection, CIFG, LayerNorm
//
// Runs a single QLSTM step through the workload factory with CIFG, projection
// and layer normalisation enabled (no peephole). `input` must supply at least
// numBatches * inputSize int8 values and `outputExpected` at least
// numBatches * outputSize int8 values. Only the final output tensor is
// compared against `outputExpected`; the cell-state output is produced but
// not validated here.
LayerTestResult<int8_t, 2> QLstmTestImpl2(
    armnn::IWorkloadFactory& workloadFactory,
    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
    const armnn::ITensorHandleFactory& tensorHandleFactory,
    const std::vector<int8_t>& input,
    const std::vector<int8_t>& outputExpected)
{
    IgnoreUnused(memoryManager);
    unsigned int numBatches = 2;
    unsigned int inputSize = 5;
    unsigned int outputSize = 3;
    unsigned int numUnits = 4;

    // CIFG means no explicit input gate (it is derived as 1 - forget gate),
    // so no inputTo*/recurrentToInput weights or input-gate bias are supplied below.
    bool cifgEnabled = true;
    bool peepholeEnabled = false;
    bool projectionEnabled = true;
    bool layerNormEnabled = true;

    // Scale/Offset quantization info
    float inputScale = 0.0078125f;
    int32_t inputOffset = 0;

    int32_t hiddenStateZeroPoint = 0;
    float hiddenStateScale = 0.007f;

    // if (!projectionEnabled) outputScale == hiddenStateScale
    float outputScale = 3.05176e-05f;
    int32_t outputOffset = 0;

    float cellStateScale = 3.05176e-05f;
    int32_t cellStateOffset = 0;

    float weightsScale = 0.00784314f;
    int32_t weightsOffset = 0;

    float layerNormScale = 3.05182e-05f;
    int32_t layerNormOffset = 0;

    // Bias scale is tied to the layer-norm scale (layerNormScale / 1024),
    // matching the quantised LSTM reference convention.
    float biasScale = layerNormScale / 1024;
    int32_t biasOffset = 0;

    float projectionWeightsScale = 0.00392157f;

    float inputIntermediateScale = 0.007059f;
    float forgetIntermediateScale = 0.007812f;
    float cellIntermediateScale = inputIntermediateScale;
    float outputIntermediateScale = forgetIntermediateScale;

    // 0.0f disables clipping for both the cell state and the projection.
    float cellClip = 0.0f;
    float projectionClip = 0.0f;

    // Input/Output tensor info
    armnn::TensorInfo inputInfo({numBatches , inputSize},
                                armnn::DataType::QAsymmS8,
                                inputScale,
                                inputOffset);

    armnn::TensorInfo cellStateInfo({numBatches , numUnits},
                                    armnn::DataType::QSymmS16,
                                    cellStateScale,
                                    cellStateOffset);

    armnn::TensorInfo outputStateInfo({numBatches , outputSize},
                                      armnn::DataType::QAsymmS8,
                                      outputScale,
                                      outputOffset);

    // Input tensors
    std::vector<int8_t> inputVector;
    inputVector.assign(input.data(), input.data() + (numBatches * inputSize));

    // Both state inputs start zeroed (fresh LSTM state).
    std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};

    std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};

    // Output tensors
    // Reference cell-state values for this configuration (declared for
    // documentation purposes; not compared by this test).
    std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};

    std::vector<int8_t> outputVector;
    outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));

    std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());

    // Create tensor handles
    std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateInfo);
    std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateInfo);

    std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateInfo);
    std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);

    armnn::QLstmQueueDescriptor data;
    armnn::WorkloadInfo info;

    // Add inputs and outputs to workload
    // NOTE: the order of these calls defines the tensor indices the QLstm
    // workload reads: inputs are {input, outputStateIn, cellStateIn} and
    // outputs are {outputStateOut, cellStateOut, output}. Do not reorder.
    AddInputToWorkload(data, info, inputInfo, inputHandle.get());
    AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
    AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());

    AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
    AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
    AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());

    // Weights and bias tensor and quantization info
    armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
                                       armnn::DataType::QSymmS8,
                                       weightsScale,
                                       weightsOffset);

    armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
                                           armnn::DataType::QSymmS8,
                                           weightsScale,
                                           weightsOffset);

    armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);

    armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);

    armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
                                            armnn::DataType::QSymmS8,
                                            projectionWeightsScale,
                                            0);

    // Weights and bias tensor data
    std::vector<int8_t> inputToForgetWeights =
            {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
    std::vector<int8_t> inputToCellWeights =
            {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
    std::vector<int8_t> inputToOutputWeights =
            {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};

    std::vector<int8_t> recurrentToForgetWeights =
            {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
    std::vector<int8_t> recurrentToCellWeights =
            {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
    std::vector<int8_t> recurrentToOutputWeights =
            {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};

    std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
    std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
    std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};

    std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
    std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
    std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};

    std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};

    // ScopedTensorHandles
    // These own the constant weight/bias data for the lifetime of the workload;
    // the queue descriptor below stores raw pointers to them.
    armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
    armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
    armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);

    armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
    armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
    armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);

    armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
    armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
    armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);

    armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
    armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
    armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);

    armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo);

    // Allocate and copy data
    AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());

    AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());

    AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
    AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
    AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());

    AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
    AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
    AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());

    AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());

    // Setup queue descriptor
    data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    data.m_InputToCellWeights = &inputToCellWeightsTensor;
    data.m_InputToOutputWeights = &inputToOutputWeightsTensor;

    data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
    data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;

    data.m_ForgetGateBias = &forgetGateBiasTensor;
    data.m_CellBias = &cellBiasTensor;
    data.m_OutputGateBias = &outputGateBiasTensor;

    data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
    data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
    data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;

    data.m_ProjectionWeights = &projectionWeightsTensor;

    data.m_Parameters.m_CifgEnabled = cifgEnabled;
    data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
    data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
    data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;

    data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
    data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
    data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
    data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;

    data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
    data.m_Parameters.m_HiddenStateScale = hiddenStateScale;

    data.m_Parameters.m_CellClip = cellClip;
    data.m_Parameters.m_ProjectionClip = projectionClip;

    // Create workload and allocate tensor handles
    // Handles must be allocated before data is copied into them.
    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
    inputHandle->Allocate();
    outputStateInHandle->Allocate();
    cellStateInHandle->Allocate();

    outputStateOutHandle->Allocate();
    cellStateOutHandle->Allocate();
    outputHandle->Allocate();

    CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
    CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
    CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());

    workload->Execute();

    CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());

    return LayerTestResult<int8_t, 2>(actualOutput,
                                      outputVector,
                                      outputHandle->GetShape(),
                                      outputStateInfo.GetShape());
}
2431
James Conroy4f1f8992020-04-29 20:01:10 +01002432
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002433} // anonymous namespace
2434
2435#if defined(ARMNNREF_ENABLED)
2436
2437// The LSTM test units are run only for the reference backend at the moment
2438
2439void LstmUtilsZeroVectorTest()
2440{
2441 armnn::TensorInfo inputDesc({4}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002442 std::vector<float> input = {2., 3., 3., 4.};
2443 std::vector<float> expectedOutput = {0., 0., 0., 0.};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002444
Sadik Armagan483c8112021-06-01 09:24:52 +01002445 return LstmUtilsZeroVectorTestImpl<armnn::DataType::Float32>(input, 4, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002446}
2447
2448void LstmUtilsMeanStddevNormalizationNoneZeroInputTest()
2449{
2450 uint32_t batchSize = 2;
2451 uint32_t vecSize = 4;
2452 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002453 std::vector<float> input =
2454 { 0.1f, 0.2f, 0.3f, 0.4f, //batch 0
2455 0.9f, 1.0f, 1.1f, 1.2f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002456
Sadik Armagan483c8112021-06-01 09:24:52 +01002457 std::vector<float> expectedOutput =
2458 { -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f, //batch 0
2459 -1.34163153f, -0.447210163f, 0.447211236f, 1.3416326f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002460
2461 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002462 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002463}
2464
2465void LstmUtilsMeanStddevNormalizationAllZeroInputTest()
2466{
2467 uint32_t batchSize = 2;
2468 uint32_t vecSize = 4;
2469 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002470 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002471 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002472 0.0f, 0.0f, 0.0f, 0.0f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002473
Sadik Armagan483c8112021-06-01 09:24:52 +01002474 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002475 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002476 0.0f, 0.0f, 0.0f, 0.0f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002477
2478 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002479 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002480}
2481
2482void LstmUtilsMeanStddevNormalizationMixedZeroInputTest()
2483{
2484 uint32_t batchSize = 2;
2485 uint32_t vecSize = 4;
2486 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002487 std::vector<float> input =
2488 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
2489 0.1f, 0.2f, 0.3f, 0.4f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002490
Sadik Armagan483c8112021-06-01 09:24:52 +01002491 std::vector<float> expectedOutput =
2492 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
2493 -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002494
2495 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002496 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002497}
2498
2499void LstmUtilsVectorBatchVectorCwiseProductTest()
2500{
2501 uint32_t batchSize = 4;
2502 uint32_t vecSize = 29;
2503 armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002504 std::vector<float> vector =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002505 { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
2506 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002507 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002508
2509 armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002510 std::vector<float> batchVector =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002511 { /* batch 0 */
2512 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
2513 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
2514 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f,
2515 /* batch 1 */
2516 -1.1f, -2.2f, -3.3f, -4.4f, -5.5f, -6.6f, -7.7f, -8.8f, -9.9f, -10.1f,
2517 -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f, -19.19f, -20.2f,
2518 -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f, -27.27f, -28.28f, 0.0f,
2519 /* batch 2 */
2520 1.1f, -2.2f, 3.3f, -4.4f, 5.5f, -6.6f, 7.7f, -8.8f, 9.9f, -10.1f,
2521 11.11f, -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f, -20.2f,
2522 21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f, -28.28f, 0.0f,
2523 /* batch 3 */
2524 -1.1f, 2.2f, -3.3f, 4.4f, -5.5f, 6.6f, -7.7f, 8.8f, -9.9f, 10.1f,
2525 -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f, -19.19f, 20.2f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002526 -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f, -27.27f, 28.28f, 0.0f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002527
2528 // Expect output = input * output + output.
Sadik Armagan483c8112021-06-01 09:24:52 +01002529 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002530 { /* batch 0 */
2531 1.210000f, 4.840000f, 10.889999f, 19.360001f, 30.250000f, 43.559998f,
2532 59.289997f, 77.440002f, 98.009995f, 102.010010f, 123.432091f, 146.894394f,
2533 172.396896f, 199.939606f, 229.522491f, 261.145599f, 294.808899f, 330.512421f,
2534 368.256134f, 408.040039f, 449.864075f, 493.728363f, 539.632874f, 587.577576f,
2535 637.562500f, 689.587585f, 743.652954f, 799.758423f, 0.000000f,
2536 /* batch 1 */
2537 -1.210000f, -4.840000f, -10.889999f, -19.360001f, -30.250000f, -43.559998f,
2538 -59.289997f, -77.440002f, -98.009995f, -102.010010f, -123.432091f, -146.894394f,
2539 -172.396896f, -199.939606f, -229.522491f, -261.145599f, -294.808899f, -330.512421f,
2540 -368.256134f, -408.040039f, -449.864075f, -493.728363f, -539.632874f, -587.577576f,
2541 -637.562500f, -689.587585f, -743.652954f, -799.758423f, 0.000000f,
2542 /* batch 2 */
2543 1.210000f, -4.840000f, 10.889999f, -19.360001f, 30.250000f, -43.559998f,
2544 59.289997f, -77.440002f, 98.009995f, -102.010010f, 123.432091f, -146.894394f,
2545 172.396896f, -199.939606f, 229.522491f, -261.145599f, 294.808899f, -330.512421f,
2546 368.256134f, -408.040039f, 449.864075f, -493.728363f, 539.632874f, -587.577576f,
2547 637.562500f, -689.587585f, 743.652954f, -799.758423f, 0.000000f,
2548 /* batch 3 */
2549 -1.210000f, 4.840000f, -10.889999f, 19.360001f, -30.250000f, 43.559998f,
2550 -59.289997f, 77.440002f, -98.009995f, 102.010010f, -123.432091f, 146.894394f,
2551 -172.396896f, 199.939606f, -229.522491f, 261.145599f, -294.808899f, 330.512421f,
2552 -368.256134f, 408.040039f, -449.864075f, 493.728363f, -539.632874f, 587.577576f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002553 -637.562500f, 689.587585f, -743.652954f, 799.758423f, 0.000000f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002554
2555 return LstmUtilsVectorBatchVectorCwiseProductTestImpl<armnn::DataType::Float32>(vector, batchVector,
Sadik Armagan483c8112021-06-01 09:24:52 +01002556 vecSize, batchSize, expectedOutput, vecDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002557}
2558
2559void LstmUtilsVectorBatchVectorAddTest()
2560{
2561 uint32_t batchSize = 2;
2562 uint32_t vecSize = 3;
2563 armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002564 std::vector<float> vector = { 0.0f, -0.5f, 1.0f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002565
2566 armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002567 std::vector<float> batchVector =
2568 {
2569 1.0f, 2.0f, 3.0f, //batch 0
2570 4.0f, 5.0f, 6.0f //batch 1
2571 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002572
Sadik Armagan483c8112021-06-01 09:24:52 +01002573 std::vector<float> expectedOutput =
2574 {
2575 1.0f, 1.5f, 4.0f,
2576 4.0f, 4.5f, 7.0f
2577 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002578
2579 return LstmUtilsVectorBatchVectorAddTestImpl<armnn::DataType::Float32>(vector, batchVector,
Sadik Armagan483c8112021-06-01 09:24:52 +01002580 vecSize, batchSize, expectedOutput, batchVecDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002581}
2582
2583#endif
2584
2585LayerTestResult<float, 2> LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest(
2586 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002587 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2588 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002589{
2590 armnn::TensorInfo inputDesc({ 2, 2 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002591 std::vector<float> input = { 2., 3., 3., 4. };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002592
2593 armnn::TensorInfo outputDesc({ 2, 4 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002594 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002595 {-0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002596 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002597 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002598 workloadFactory, memoryManager, tensorHandleFactory,
2599 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002600}
2601
2602LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest(
2603 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002604 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2605 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002606{
2607 armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002608 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002609 {0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002610 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002611
2612 armnn::TensorInfo outputDesc({ 2, 16 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002613 std::vector<float> expectedOutput =
2614 {-0.00396806f, 0.029352f, -0.00279226f, 0.0159977f, -0.00835576f,
2615 -0.0211779f, 0.0283512f, -0.0114597f, 0.00907307f, -0.0244004f,
2616 -0.0152191f, -0.0259063f, 0.00914318f, 0.00415118f, 0.017147f,
2617 0.0134203f, -0.013869f, 0.0287268f, -0.00334693f, 0.00733398f, -0.0287926f,
2618 -0.0186926f, 0.0193662f, -0.0115437f, 0.00422612f, -0.0345232f,
2619 0.00223253f, -0.00957321f, 0.0210624f, 0.013331f, 0.0150954f, 0.02168f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002620 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<armnn::DataType::Float32>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002621 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002622}
2623
2624LayerTestResult<float, 2> LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest(
2625 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002626 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2627 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002628{
2629 armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002630 std::vector<float> input = {2., 3., 3., 4.};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002631
2632 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002633 std::vector<float> expectedOutput =
2634 {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f,
2635 -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002636
2637 return LstmNoCifgNoPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002638 workloadFactory, memoryManager, tensorHandleFactory,
2639 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002640}
2641
2642LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002643 armnn::IWorkloadFactory& workloadFactory,
2644 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2645 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002646{
2647 armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002648 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002649 {0.7f, 0.8f, 0.1f, 0.2f, 0.3f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002650 0.3f, 0.2f, 0.9f, 0.8f, 0.1f}; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002651
2652 armnn::TensorInfo outputDesc({ 2, 3 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002653 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002654 { 0.0244077f, 0.128027f, -0.00170918f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002655 -0.00692428f, 0.0848741f, 0.063445f}; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002656 return LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl<armnn::DataType::Float32>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002657 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002658}
2659
2660LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionTest(
2661 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002662 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2663 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002664{
2665 const float qScale = 1.0f;
2666 const int32_t qOffset = 0;
2667
Derek Lambertif90c56d2020-01-10 17:14:08 +00002668 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2669 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002670
2671 armnn::TensorInfo inputDesc({2, 2}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002672 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002673
2674 armnn::TensorInfo outputDesc({2, 4}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002675 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2676 {
2677 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2678 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2679 },
2680 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002681
2682 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002683 workloadFactory, memoryManager, tensorHandleFactory,
2684 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2685 qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002686
2687}
2688
2689LayerTestResult<int16_t, 2> LstmLayerInt16WithCifgWithPeepholeNoProjectionTest(
2690 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002691 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2692 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002693{
2694 const float qScale = 1.0f;
2695 const int32_t qOffset = 0;
2696
Derek Lambertif90c56d2020-01-10 17:14:08 +00002697 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2698 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002699
2700 armnn::TensorInfo inputDesc({ 2, 2 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002701 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002702
2703 armnn::TensorInfo outputDesc({ 2, 4 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002704 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2705 {
2706 -0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2707 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f
2708 },
2709 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002710
2711 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002712 workloadFactory, memoryManager, tensorHandleFactory,
2713 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2714 qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002715}
2716
2717LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgWithPeepholeWithProjectionTest(
2718 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002719 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2720 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002721{
2722 const float qScale = 2.0f;
2723 const int32_t qOffset = 0;
2724
Derek Lambertif90c56d2020-01-10 17:14:08 +00002725 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2726 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002727
2728 armnn::TensorInfo inputDesc({ 2, 5 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002729 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>(
2730 {
2731 0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2732 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f
2733 },
2734 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002735
2736 armnn::TensorInfo outputDesc({ 2, 16 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002737 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2738 {
2739 -0.00396806f, 0.02935200f, -0.00279226f, 0.01599770f,
2740 -0.00835576f, -0.02117790f, 0.02835120f, -0.01145970f,
2741 0.00907307f, -0.02440040f, -0.01521910f, -0.02590630f,
2742 0.00914318f, 0.00415118f, 0.01714700f, 0.01342030f,
2743 -0.01386900f, 0.02872680f, -0.00334693f, 0.00733398f,
2744 -0.02879260f, -0.01869260f, 0.01936620f, -0.01154370f,
2745 0.00422612f, -0.03452320f, 0.00223253f, -0.00957321f,
2746 0.02106240f, 0.01333100f, 0.01509540f, 0.02168000f
2747 },
2748 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002749
2750 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<datatype>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002751 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput, qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002752}
2753
2754LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16ConstantTest(
2755 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002756 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2757 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002758{
2759 const float qScale = 1.0f;
2760 const int32_t qOffset = 0;
2761
Derek Lambertif90c56d2020-01-10 17:14:08 +00002762 const armnn::DataType datatype = armnn::DataType::QSymmS16; // datatype & constants set to QSymm16
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002763
2764 armnn::TensorInfo inputDesc({2, 2}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002765 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002766
2767 armnn::TensorInfo outputDesc({2, 4}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002768 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2769 {
2770 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2771 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2772 },
2773 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002774
2775 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002776 workloadFactory, memoryManager, tensorHandleFactory,
2777 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2778 qScale, qOffset, datatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002779}
2780
2781//
2782// QuantizedLstm
2783//
2784
2785LayerTestResult<uint8_t, 2> QuantizedLstmTest(
2786 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002787 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2788 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002789{
Derek Lambertif90c56d2020-01-10 17:14:08 +00002790 armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::QAsymmU8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002791 std::vector<uint8_t> input = {166, 179, 50, 150};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002792
Derek Lambertif90c56d2020-01-10 17:14:08 +00002793 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmU8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002794 std::vector<uint8_t> expectedOutput = {140, 151, 146, 112, 136, 156, 142, 112 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002795
Sadik Armagan483c8112021-06-01 09:24:52 +01002796 return QuantizedLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory,
2797 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002798}
James Conroy4f1f8992020-04-29 20:01:10 +01002799
2800// QLSTM
2801LayerTestResult<int8_t, 2> QLstmTest(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002802 armnn::IWorkloadFactory& workloadFactory,
2803 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2804 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroy4f1f8992020-04-29 20:01:10 +01002805{
2806 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002807 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroy4f1f8992020-04-29 20:01:10 +01002808
2809 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002810 std::vector<int8_t> expectedOutput = {-15, 21, 14, 20, -15, 15, 5, 27};
James Conroy4f1f8992020-04-29 20:01:10 +01002811
Finn Williamsc43de6a2020-08-27 11:13:25 +01002812 return QLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroy4f1f8992020-04-29 20:01:10 +01002813}
James Conroyb22a75e2020-06-08 14:53:10 +01002814
2815LayerTestResult<int8_t, 2> QLstmTest1(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002816 armnn::IWorkloadFactory& workloadFactory,
2817 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2818 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroyb22a75e2020-06-08 14:53:10 +01002819{
2820 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002821 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroyb22a75e2020-06-08 14:53:10 +01002822
2823 armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002824 std::vector<int8_t> expectedOutput = {127, 127, -108, -67, 127, 127};
James Conroyb22a75e2020-06-08 14:53:10 +01002825
Finn Williamsc43de6a2020-08-27 11:13:25 +01002826 return QLstmTestImpl1(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroyb22a75e2020-06-08 14:53:10 +01002827}
2828
2829LayerTestResult<int8_t, 2> QLstmTest2(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002830 armnn::IWorkloadFactory& workloadFactory,
2831 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2832 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroyb22a75e2020-06-08 14:53:10 +01002833{
2834 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002835 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroyb22a75e2020-06-08 14:53:10 +01002836
2837 armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002838 std::vector<int8_t> expectedOutput = {127, 127, 127, -128, 127, 127};
James Conroyb22a75e2020-06-08 14:53:10 +01002839
Finn Williamsc43de6a2020-08-27 11:13:25 +01002840 return QLstmTestImpl2(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroyb22a75e2020-06-08 14:53:10 +01002841}