blob: 2eaaeb5c9b41c508708a89b3c169f393ee460fce [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
telsoa01c577f2c2018-08-31 09:22:23 +01005
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01006#include "LstmTestImpl.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007
Colm Donelanc42a9872022-02-02 16:35:09 +00008#include <armnnUtils/QuantizeHelper.hpp>
Aron Virginas-Tar48623a02019-10-22 10:00:28 +01009
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010011
Colm Donelan0c479742021-12-10 12:43:54 +000012#include <armnn/backends/TensorHandle.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010013
Sadik Armagana097d2a2021-11-24 15:47:28 +000014#include <armnnTestUtils/TensorCopyUtils.hpp>
Colm Donelan0c479742021-12-10 12:43:54 +000015#include <armnnTestUtils/WorkloadTestUtils.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010016
17#include <reference/workloads/Decoders.hpp>
18#include <reference/workloads/Encoders.hpp>
19#include <reference/workloads/LstmUtils.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010020
Colm Donelanc42a9872022-02-02 16:35:09 +000021#include <armnnTestUtils/TensorHelpers.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010022
Sadik Armagan1625efc2021-06-10 18:24:34 +010023#include <doctest/doctest.h>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010024namespace
25{
Jan Eilers38e05bd2019-06-26 13:10:09 +010026
27template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
28void LstmUtilsVectorBatchVectorAddTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010029 std::vector<float>& vec,
30 std::vector<float>& batchVec,
Jan Eilers38e05bd2019-06-26 13:10:09 +010031 uint32_t vSize,
32 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +010033 std::vector<float>& expectedOutput,
34 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +010035{
36 float qScale = 0.0f;
37 int32_t qOffset = 0;
38 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
39
40 // Make encoder and decoder
41 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
42 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
43 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
44
45 VectorBatchVectorAdd(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder);
46
47 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +010048 auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
Sadik Armagan1625efc2021-06-10 18:24:34 +010049 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +010050
51 // check if iterator is back at start position
52 batchVecEncoder->Set(1.0f);
Sadik Armagan1625efc2021-06-10 18:24:34 +010053 CHECK(batchVec[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +010054}
55
56template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
57void LstmUtilsZeroVectorTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010058 std::vector<float>& input,
Jan Eilers38e05bd2019-06-26 13:10:09 +010059 uint32_t vSize,
Sadik Armagan483c8112021-06-01 09:24:52 +010060 std::vector<float>& expectedOutput,
61 armnn::TensorShape& expectedShape)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010062{
Jan Eilers38e05bd2019-06-26 13:10:09 +010063 float qScale = 0.0f;
64 int32_t qOffset = 0;
65
66 armnn::TensorInfo tensorInfo({vSize}, ArmnnType, qScale, qOffset );
67
68 // Make encoder for input
69 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
70
71 // call ZeroVector
72 ZeroVector(*outputEncoder, vSize);
73
74 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +010075 auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
Sadik Armagan1625efc2021-06-10 18:24:34 +010076 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +010077
78 // check if iterator is back at start position
79 outputEncoder->Set(1.0f);
Sadik Armagan1625efc2021-06-10 18:24:34 +010080 CHECK(input[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +010081
82}
83
Jan Eilers38e05bd2019-06-26 13:10:09 +010084template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
85void LstmUtilsMeanStddevNormalizationTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010086 std::vector<float>& input,
Jan Eilers38e05bd2019-06-26 13:10:09 +010087 uint32_t vSize,
88 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +010089 std::vector<float>& expectedOutput,
90 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +010091{
92 float qScale = 0.0f;
93 int32_t qOffset = 0;
94 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
95
96 // Make encoder and decoder for input
97 std::unique_ptr<armnn::Decoder<float>> inputDecoder = armnn::MakeDecoder<float>(tensorInfo, input.data());
98 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
99
100 MeanStddevNormalization(*inputDecoder, *outputEncoder, vSize, nBatch, 1e-8f);
101
102 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +0100103 auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100104 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100105
106 // check if iterator is back at start position
107 outputEncoder->Set(1.0f);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100108 CHECK(input[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100109}
110
111template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
112void LstmUtilsVectorBatchVectorCwiseProductTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +0100113 std::vector<float>& vec,
114 std::vector<float>& batchVec,
Jan Eilers38e05bd2019-06-26 13:10:09 +0100115 uint32_t vSize,
116 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +0100117 std::vector<float>& expectedOutput,
118 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100119{
120 float qScale = 0.0f;
121 int32_t qOffset = 0;
122 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
123
124 // Make encoder and decoder
125 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
126 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
127 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
128
129 VectorBatchVectorCwiseProduct(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder);
130
131 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +0100132 auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100133 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100134
135 // check if iterator is back at start position
136 batchVecEncoder->Set(1.0f);
Sadik Armagan1625efc2021-06-10 18:24:34 +0100137 CHECK(batchVec[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100138}
139
140// Lstm Layer tests:
James Conroy9c3cae82019-08-01 16:01:48 +0100141// *********************************** //
Conor Kennedyb9971c92019-05-07 07:14:23 +0100142template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
143LayerTestResult<T, 2>
144LstmNoCifgNoPeepholeNoProjectionTestImpl(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000145 armnn::IWorkloadFactory& workloadFactory,
146 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100147 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +0100148 const std::vector<T>& input,
149 const std::vector<T>& outputExpected,
150 const armnn::TensorShape& inputShape,
151 const armnn::TensorShape& outputExpectedShape,
Conor Kennedyb9971c92019-05-07 07:14:23 +0100152 float qScale = 0.0f,
153 int32_t qOffset = 0,
154 armnn::DataType constantDataType = armnn::DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +0100155{
Jan Eilers8eb25602020-03-09 12:13:48 +0000156 IgnoreUnused(memoryManager);
Sadik Armagan483c8112021-06-01 09:24:52 +0100157 unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]);
158 unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
159 unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
telsoa01c577f2c2018-08-31 09:22:23 +0100160 // cellSize and outputSize have the same size when there is no projection.
161 unsigned numUnits = outputSize;
162
Conor Kennedyb9971c92019-05-07 07:14:23 +0100163 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset );
164 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
165 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100166
Conor Kennedyb9971c92019-05-07 07:14:23 +0100167 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
168 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
169 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
170 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100171
Rob Hughesbb46dde2020-05-20 15:27:37 +0100172 std::vector<T> inputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100173 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100174
Rob Hughesbb46dde2020-05-20 15:27:37 +0100175 std::vector<T> cellStateInVector(batchSize * numUnits, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100176 std::vector<T> outputStateInVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100177 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100178 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100179 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
Sadik Armagan483c8112021-06-01 09:24:52 +0100180
181 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
telsoa01c577f2c2018-08-31 09:22:23 +0100182
Rob Hughesbb46dde2020-05-20 15:27:37 +0100183 std::vector<T> outputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100184 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100185
Finn Williamsc43de6a2020-08-27 11:13:25 +0100186 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100187 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100188 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100189 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100190 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100191
Finn Williamsc43de6a2020-08-27 11:13:25 +0100192 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
193 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100194 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100195 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100196 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100197 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
198 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100199
200 armnn::LstmQueueDescriptor data;
201 armnn::WorkloadInfo info;
202
203 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
204 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
205 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
206
207 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
208 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
209 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
210 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
211
Conor Kennedyb9971c92019-05-07 07:14:23 +0100212 armnn::TensorInfo tensorInfo4({numUnits}, constantDataType , qScale, qOffset);
213 armnn::TensorInfo tensorInfo8({numUnits, 2}, constantDataType, qScale, qOffset);
214 armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100215
Sadik Armagan483c8112021-06-01 09:24:52 +0100216 std::vector<float> inputToInputWeights = {-0.45018822f, -0.02338299f, -0.0870589f,
217 -0.34550029f, 0.04266912f, -0.15680569f,
218 -0.34856534f, 0.43890524f};
telsoa01c577f2c2018-08-31 09:22:23 +0100219
Sadik Armagan483c8112021-06-01 09:24:52 +0100220 std::vector<float> inputToForgetWeights = { 0.09701663f, 0.20334584f, -0.50592935f,
221 -0.31343272f, -0.40032279f, 0.44781327f,
222 0.01387155f, -0.35593212f};
telsoa01c577f2c2018-08-31 09:22:23 +0100223
Sadik Armagan483c8112021-06-01 09:24:52 +0100224 std::vector<float> inputToCellWeights = { -0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
225 -0.20583314f, 0.44344562f, 0.22077113f,
226 -0.29909778f};
telsoa01c577f2c2018-08-31 09:22:23 +0100227
Sadik Armagan483c8112021-06-01 09:24:52 +0100228 std::vector<float> inputToOutputWeights = { -0.25065863f, -0.28290087f, 0.04613829f,
229 0.40525138f, 0.44272184f, 0.03897077f,
230 -0.1556896f, 0.19487578f};
telsoa01c577f2c2018-08-31 09:22:23 +0100231
Sadik Armagan483c8112021-06-01 09:24:52 +0100232 std::vector<float> recurrentToInputWeights = {-0.0063535f, -0.2042388f, 0.31454784f,
233 -0.35746509f, 0.28902304f, 0.08183324f,
234 -0.16555229f, 0.02286911f, -0.13566875f,
235 0.03034258f, 0.48091322f, -0.12528998f,
236 0.24077177f, -0.51332325f, -0.33502164f,
237 0.10629296f};
telsoa01c577f2c2018-08-31 09:22:23 +0100238
Sadik Armagan483c8112021-06-01 09:24:52 +0100239 std::vector<float> recurrentToForgetWeights = { -0.48684245f, -0.06655136f, 0.42224967f,
240 0.2112639f, 0.27654213f, 0.20864892f,
241 -0.07646349f, 0.45877004f, 0.00141793f,
242 -0.14609534f, 0.36447752f, 0.09196436f,
243 0.28053468f, 0.01560611f, -0.20127171f,
244 -0.01140004f};
telsoa01c577f2c2018-08-31 09:22:23 +0100245
Sadik Armagan483c8112021-06-01 09:24:52 +0100246 std::vector<float> recurrentToCellWeights = { -0.3407414f, 0.24443203f, -0.2078532f,
247 0.26320225f, 0.05695659f, -0.00123841f,
248 -0.4744786f, -0.35869038f, -0.06418842f,
249 -0.13502428f, -0.501764f, 0.22830659f,
250 -0.46367589f, 0.26016325f, -0.03894562f,
251 -0.16368064f};
telsoa01c577f2c2018-08-31 09:22:23 +0100252
Sadik Armagan483c8112021-06-01 09:24:52 +0100253 std::vector<float> recurrentToOutputWeights = { 0.43385774f, -0.17194885f, 0.2718237f,
254 0.09215671f, 0.24107647f, -0.39835793f,
255 0.18212086f, 0.01301402f, 0.48572797f,
256 -0.50656658f, 0.20047462f, -0.20607421f,
257 -0.51818722f, -0.15390486f, 0.0468148f,
258 0.39922136f};
telsoa01c577f2c2018-08-31 09:22:23 +0100259
Sadik Armagan483c8112021-06-01 09:24:52 +0100260 std::vector<float> cellToInputWeights = {0., 0., 0., 0.};
telsoa01c577f2c2018-08-31 09:22:23 +0100261
Sadik Armagan483c8112021-06-01 09:24:52 +0100262 std::vector<float> inputGateBias = {0., 0., 0., 0.};
telsoa01c577f2c2018-08-31 09:22:23 +0100263
Sadik Armagan483c8112021-06-01 09:24:52 +0100264 std::vector<float> forgetGateBias = {1., 1., 1., 1.};
telsoa01c577f2c2018-08-31 09:22:23 +0100265
Sadik Armagan483c8112021-06-01 09:24:52 +0100266 std::vector<float> cellBias = {0., 0., 0., 0.};
telsoa01c577f2c2018-08-31 09:22:23 +0100267
Sadik Armagan483c8112021-06-01 09:24:52 +0100268 std::vector<float> outputGateBias = {0., 0., 0., 0.};
telsoa01c577f2c2018-08-31 09:22:23 +0100269
James Conroy1f58f032021-04-27 17:13:27 +0100270 armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo8);
271 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo8);
272 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo8);
273 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo8);
274 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo16);
275 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo16);
276 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo16);
277 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo16);
278 armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4);
279 armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4);
280 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4);
281 armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4);
282 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4);
telsoa01c577f2c2018-08-31 09:22:23 +0100283
Sadik Armagan483c8112021-06-01 09:24:52 +0100284 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
285 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
286 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
287 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
288 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
289 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
290 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
291 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
292 AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
293 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
294 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
295 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
296 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
telsoa01c577f2c2018-08-31 09:22:23 +0100297
298 data.m_InputToInputWeights = &inputToInputWeightsTensor;
299 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
300 data.m_InputToCellWeights = &inputToCellWeightsTensor;
301 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
302 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
303 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
304 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
305 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100306 data.m_InputGateBias = &inputGateBiasTensor;
307 data.m_ForgetGateBias = &forgetGateBiasTensor;
308 data.m_CellBias = &cellBiasTensor;
309 data.m_OutputGateBias = &outputGateBiasTensor;
310
telsoa01c577f2c2018-08-31 09:22:23 +0100311 // Flags to set test configuration
312 data.m_Parameters.m_ActivationFunc = 4;
313 data.m_Parameters.m_CifgEnabled = false;
314 data.m_Parameters.m_PeepholeEnabled = false;
315 data.m_Parameters.m_ProjectionEnabled = false;
316
Teresa Charlin611c7fb2022-01-07 09:47:29 +0000317 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
telsoa01c577f2c2018-08-31 09:22:23 +0100318 inputHandle->Allocate();
319 outputStateInHandle->Allocate();
320 cellStateInHandle->Allocate();
321
322 scratchHandle->Allocate();
323 outputStateOutHandle->Allocate();
324 cellStateOutHandle->Allocate();
325 outputHandle->Allocate();
326
Sadik Armagan483c8112021-06-01 09:24:52 +0100327 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
328 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
329 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +0100330
telsoa01c577f2c2018-08-31 09:22:23 +0100331 workload->Execute();
332
Sadik Armagan483c8112021-06-01 09:24:52 +0100333 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
telsoa01c577f2c2018-08-31 09:22:23 +0100334
Sadik Armagan483c8112021-06-01 09:24:52 +0100335 return LayerTestResult<T, 2>(actualOutput,
336 outputVector,
337 outputHandle->GetShape(),
338 outputTensorInfo.GetShape());
telsoa01c577f2c2018-08-31 09:22:23 +0100339}
340
Conor Kennedyb9971c92019-05-07 07:14:23 +0100341template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
342LayerTestResult<T, 2>
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000343LstmLayerNoCifgWithPeepholeWithProjectionTestImpl(armnn::IWorkloadFactory& workloadFactory,
344 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100345 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +0100346 const std::vector<T>& input,
347 const std::vector<T>& outputExpected,
Conor Kennedyb9971c92019-05-07 07:14:23 +0100348 float qScale = 0.0f,
349 int32_t qOffset = 0,
350 armnn::DataType constantDataType = armnn::DataType::Float32)
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000351{
Jan Eilers8eb25602020-03-09 12:13:48 +0000352 IgnoreUnused(memoryManager);
telsoa01c577f2c2018-08-31 09:22:23 +0100353 unsigned int batchSize = 2;
354 unsigned int outputSize = 16;
355 unsigned int inputSize = 5;
356 unsigned numUnits = 20;
357
Conor Kennedyb9971c92019-05-07 07:14:23 +0100358 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
359 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
360 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100361
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000362 // Scratch buffer size without CIFG [batchSize, numUnits * 4]
Conor Kennedyb9971c92019-05-07 07:14:23 +0100363 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
364 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
365 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
366 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100367
Rob Hughesbb46dde2020-05-20 15:27:37 +0100368 std::vector<T> inputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100369 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100370
Rob Hughesbb46dde2020-05-20 15:27:37 +0100371 std::vector<T> cellStateInVector(batchSize * numUnits, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100372 std::vector<T> outputStateInVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100373 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100374 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100375 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
Sadik Armagan483c8112021-06-01 09:24:52 +0100376
377 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
telsoa01c577f2c2018-08-31 09:22:23 +0100378
Rob Hughesbb46dde2020-05-20 15:27:37 +0100379 std::vector<T> outputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100380 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100381
Finn Williamsc43de6a2020-08-27 11:13:25 +0100382 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100383 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100384 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100385 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100386 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100387
Finn Williamsc43de6a2020-08-27 11:13:25 +0100388 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
389 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100390 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100391 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100392 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100393 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
394 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100395
396 armnn::LstmQueueDescriptor data;
397 armnn::WorkloadInfo info;
398
399 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
400 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
401 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
David Beckac42efd2018-09-26 17:41:13 +0100402
telsoa01c577f2c2018-08-31 09:22:23 +0100403 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
404 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
405 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
406 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
407
Conor Kennedyb9971c92019-05-07 07:14:23 +0100408 armnn::TensorInfo tensorInfo16({outputSize}, constantDataType, qScale, qOffset);
409 armnn::TensorInfo tensorInfo20({numUnits}, constantDataType, qScale, qOffset);
410 armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
411 armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, constantDataType, qScale, qOffset);
412 armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, constantDataType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100413
Sadik Armagan483c8112021-06-01 09:24:52 +0100414 std::vector<float> inputToInputWeights = {0.021393683f,0.06124551f, 0.046905167f,-0.014657677f,-0.03149463f,
415 0.09171803f, 0.14647801f,0.10797193f, -0.0057968358f,0.0019193048f,
416 -0.2726754f, 0.10154029f, -0.018539885f, 0.080349885f, -0.10262385f,
417 -0.022599787f,-0.09121155f, -0.008675967f, -0.045206103f,-0.0821282f,
418 -0.008045952f,0.015478081f, 0.055217247f, 0.038719587f, 0.044153627f,
419 -0.06453243f,0.05031825f, -0.046935108f, -0.008164439f, 0.014574226f,
420 -0.1671009f, -0.15519552f, -0.16819797f,-0.13971269f,-0.11953059f,
421 0.25005487f, -0.22790983f, 0.009855087f, -0.028140958f, -0.11200698f,
422 0.11295408f, -0.0035217577f, 0.054485075f, 0.05184695f, 0.064711206f,
423 0.10989193f, 0.11674786f, 0.03490607f, 0.07727357f, 0.11390585f,
424 -0.1863375f, -0.1034451f, -0.13945189f, -0.049401227f, -0.18767063f,
425 0.042483903f, 0.14233552f, 0.13832581f, 0.18350165f, 0.14545603f,
426 -0.028545704f,0.024939531f,0.050929718f,0.0076203286f,-0.0029723682f,
427 -0.042484224f, -0.11827596f, -0.09171104f, -0.10808628f,-0.16327988f,
428 -0.2273378f, -0.0993647f, -0.017155107f,0.0023917493f,0.049272764f,
429 0.0038534778f, 0.054764505f, 0.089753784f, 0.06947234f, 0.08014476f,
430 -0.04544234f, -0.0497073f,-0.07135631f, -0.048929106f,-0.004042012f,
431 -0.009284026f, 0.018042054f, 0.0036860977f,-0.07427302f, -0.11434604f,
432 -0.018995456f, 0.031487543f, 0.012834908f,0.019977754f,0.044256654f,
433 -0.39292613f, -0.18519334f, -0.11651281f,-0.06809892f, 0.011373677f };
telsoa01c577f2c2018-08-31 09:22:23 +0100434
Sadik Armagan483c8112021-06-01 09:24:52 +0100435 std::vector<float> inputToForgetWeights = {-0.0018401089f, -0.004852237f,0.03698424f, 0.014181704f,0.028273236f,
436 -0.016726194f, -0.05249759f,-0.10204261f, 0.00861066f,-0.040979505f,
437 -0.009899187f,0.01923892f,-0.028177269f, -0.08535103f,-0.14585495f,
438 0.10662567f,-0.01909731f,-0.017883534f,-0.0047269356f,-0.045103323f,
439 0.0030784295f,0.076784775f,0.07463696f, 0.094531395f,0.0814421f,
440 -0.12257899f, -0.033945758f,-0.031303465f, 0.045630626f,0.06843887f,
441 -0.13492945f, -0.012480007f,-0.0811829f, -0.07224499f,-0.09628791f,
442 0.045100946f,0.0012300825f, 0.013964662f, 0.099372394f,0.02543059f,
443 0.06958324f, 0.034257296f, 0.0482646f, 0.06267997f,0.052625068f,
444 0.12784666f, 0.07077897f, 0.025725935f, 0.04165009f,0.07241905f,
445 0.018668644f, -0.037377294f,-0.06277783f,-0.08833636f,-0.040120605f,
446 -0.011405586f,-0.007808335f,-0.010301386f,-0.005102167f,0.027717464f,
447 0.05483423f, 0.11449111f, 0.11289652f,0.10939839f, 0.13396506f,
448 -0.08402166f,-0.01901462f, -0.044678304f,-0.07720565f,0.014350063f,
449 -0.11757958f, -0.0652038f, -0.08185733f,-0.076754324f,-0.092614375f,
450 0.10405491f, 0.052960336f, 0.035755895f,0.035839386f,-0.012540553f,
451 0.036881298f, 0.02913376f, 0.03420159f,0.05448447f,-0.054523353f,
452 0.02582715f, 0.02327355f, -0.011857179f,-0.0011980024f,-0.034641717f,
453 -0.026125094f,-0.17582615f,-0.15923657f,-0.27486774f,-0.0006143371f,
454 0.0001771948f, -8.470171e-05f, 0.02651807f,0.045790765f,0.06956496f };
telsoa01c577f2c2018-08-31 09:22:23 +0100455
Sadik Armagan483c8112021-06-01 09:24:52 +0100456 std::vector<float> inputToCellWeights = { -0.04580283f, -0.09549462f, -0.032418985f, -0.06454633f,
457 -0.043528453f, 0.043018587f, -0.049152344f, -0.12418144f,
458 -0.078985475f, -0.07596889f, 0.019484362f, -0.11434962f,
459 -0.0074034138f, -0.06314844f, -0.092981495f, 0.0062155537f,
460 -0.025034338f, -0.0028890965f, 0.048929527f, 0.06235075f,
461 0.10665918f, -0.032036792f, -0.08505916f, -0.10843358f,
462 -0.13002433f, -0.036816437f, -0.02130134f, -0.016518239f,
463 0.0047691227f, -0.0025825808f, 0.066017866f, 0.029991534f,
464 -0.10652836f, -0.1037554f, -0.13056071f, -0.03266643f,
465 -0.033702414f, -0.006473424f, -0.04611692f, 0.014419339f,
466 -0.025174323f, 0.0396852f, 0.081777506f, 0.06157468f,
467 0.10210095f, -0.009658194f, 0.046511717f, 0.03603906f,
468 0.0069369148f, 0.015960095f, -0.06507666f, 0.09551598f,
469 0.053568836f, 0.06408714f, 0.12835667f, -0.008714329f,
470 -0.20211966f, -0.12093674f, 0.029450472f, 0.2849013f,
471 -0.029227901f, 0.1164364f, -0.08560263f, 0.09941786f,
472 -0.036999565f, -0.028842626f, -0.0033637602f, -0.017012902f,
473 -0.09720865f, -0.11193351f, -0.029155117f, -0.017936034f,
474 -0.009768936f, -0.04223324f, -0.036159635f, 0.06505112f,
475 -0.021742892f, -0.023377212f, -0.07221364f, -0.06430552f,
476 0.05453865f, 0.091149814f, 0.06387331f, 0.007518393f,
477 0.055960953f, 0.069779344f, 0.046411168f, 0.10509911f,
478 0.07463894f, 0.0075130584f, 0.012850982f, 0.04555431f,
479 0.056955688f, 0.06555285f, 0.050801456f, -0.009862683f,
480 0.00826772f, -0.026555609f, -0.0073611983f, -0.0014897042f };
telsoa01c577f2c2018-08-31 09:22:23 +0100481
Sadik Armagan483c8112021-06-01 09:24:52 +0100482 std::vector<float> inputToOutputWeights ={-0.0998932f, -0.07201956f, -0.052803773f,-0.15629593f,-0.15001918f,
483 -0.07650751f,0.02359855f, -0.075155355f, -0.08037709f, -0.15093534f,
484 0.029517552f, -0.04751393f, 0.010350531f,-0.02664851f, -0.016839722f,
485 -0.023121163f, 0.0077019283f, 0.012851257f, -0.05040649f,-0.0129761f,
486 -0.021737747f,-0.038305793f,-0.06870586f, -0.01481247f,-0.001285394f,
487 0.10124236f, 0.083122835f, 0.053313006f,-0.062235646f,-0.075637154f,
488 -0.027833903f, 0.029774971f, 0.1130802f, 0.09218906f, 0.09506135f,
489 -0.086665764f,-0.037162706f,-0.038880914f,-0.035832845f,-0.014481564f,
490 -0.09825003f,-0.12048569f,-0.097665586f,-0.05287633f, -0.0964047f,
491 -0.11366429f, 0.035777505f, 0.13568819f, 0.052451383f,0.050649304f,
492 0.05798951f, -0.021852335f,-0.099848844f,0.014740475f,-0.078897946f,
493 0.04974699f, 0.014160473f, 0.06973932f, 0.04964942f, 0.033364646f,
494 0.08190124f, 0.025535367f, 0.050893165f, 0.048514254f,0.06945813f,
495 -0.078907564f,-0.06707616f, -0.11844508f, -0.09986688f,-0.07509403f,
496 0.06263226f, 0.14925587f, 0.20188436f, 0.12098451f,0.14639415f,
497 0.0015017595f, -0.014267382f, -0.03417257f,0.012711468f,0.0028300495f,
498 -0.024758482f, -0.05098548f,-0.0821182f, 0.014225672f, 0.021544158f,
499 0.08949725f, 0.07505268f, -0.0020780868f, 0.04908258f,0.06476295f,
500 -0.022907063f,0.027562456f,0.040185735f, 0.019567577f,-0.015598739f,
501 -0.049097303f, -0.017121866f, -0.083368234f,-0.02332002f,-0.0840956f };
telsoa01c577f2c2018-08-31 09:22:23 +0100502
Sadik Armagan483c8112021-06-01 09:24:52 +0100503 std::vector<float> inputGateBias = {0.02234832f, 0.14757581f, 0.18176508f, 0.10380666f, 0.053110216f,
504 -0.06928846f, -0.13942584f, -0.11816189f, 0.19483899f, 0.03652339f,
505 -0.10250295f, 0.036714908f, -0.18426876f, 0.036065217f, 0.21810818f,
506 0.02383196f, -0.043370757f, 0.08690144f, -0.04444982f, 0.00030581196f };
telsoa01c577f2c2018-08-31 09:22:23 +0100507
Sadik Armagan483c8112021-06-01 09:24:52 +0100508 std::vector<float> forgetGateBias ={0.035185695f, -0.042891346f, -0.03032477f, 0.23027696f,
509 0.11098921f, 0.15378423f, 0.09263801f, 0.09790885f,
510 0.09508917f, 0.061199076f, 0.07665568f, -0.015443159f,
511 -0.03499149f, 0.046190713f, 0.08895977f, 0.10899629f,
512 0.40694186f, 0.06030037f, 0.012413437f, -0.06108739f };
telsoa01c577f2c2018-08-31 09:22:23 +0100513
Sadik Armagan483c8112021-06-01 09:24:52 +0100514 std::vector<float> cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f, 0.033463873f,
515 -0.1483596f, -0.10639995f, -0.091433935f, 0.058573797f,
516 -0.06809782f, -0.07889636f, -0.043246906f, -0.09829136f,
517 -0.4279842f, 0.034901652f, 0.18797937f, 0.0075234566f,
518 0.016178843f, 0.1749513f, 0.13975595f, 0.92058027f };
telsoa01c577f2c2018-08-31 09:22:23 +0100519
Sadik Armagan483c8112021-06-01 09:24:52 +0100520 std::vector<float> outputGateBias ={0.046159424f, -0.0012809046f, 0.03563469f, 0.12648113f, 0.027195795f,
521 0.35373217f, -0.018957434f, 0.008907322f, -0.0762701f, 0.12018895f,
522 0.04216877f, 0.0022856654f, 0.040952638f, 0.3147856f, 0.08225149f,
523 -0.057416286f, -0.14995944f, -0.008040261f, 0.13208859f, 0.029760877f};
telsoa01c577f2c2018-08-31 09:22:23 +0100524
Sadik Armagan483c8112021-06-01 09:24:52 +0100525 std::vector<float> recurrentToInputWeights = { -0.001374326f, -0.078856036f, 0.10672688f, 0.029162422f,
telsoa01c577f2c2018-08-31 09:22:23 +0100526 -0.11585556f, 0.02557986f, -0.13446963f, -0.035785314f,
527 -0.01244275f, 0.025961924f, -0.02337298f, -0.044228926f,
528 -0.055839065f, -0.046598054f, -0.010546039f, -0.06900766f,
529 0.027239809f, 0.022582639f, -0.013296484f, -0.05459212f,
530 0.08981f, -0.045407712f, 0.08682226f, -0.06867011f,
531 -0.14390695f, -0.02916037f, 0.000996957f, 0.091420636f,
532 0.14283475f, -0.07390571f, -0.06402044f, 0.062524505f,
533 -0.093129106f, 0.04860203f, -0.08364217f, -0.08119002f,
534 0.009352075f, 0.22920375f, 0.0016303885f, 0.11583097f,
535 -0.13732095f, 0.012405723f, -0.07551853f, 0.06343048f,
536 0.12162708f, -0.031923793f, -0.014335606f, 0.01790974f,
537 -0.10650317f, -0.0724401f, 0.08554849f, -0.05727212f,
538 0.06556731f, -0.042729504f, -0.043227166f, 0.011683251f,
539 -0.013082158f, -0.029302018f, -0.010899579f, -0.062036745f,
540 -0.022509435f, -0.00964907f, -0.01567329f, 0.04260106f,
541 -0.07787477f, -0.11576462f, 0.017356863f, 0.048673786f,
542 -0.017577527f, -0.05527947f, -0.082487635f, -0.040137455f,
543 -0.10820036f, -0.04666372f, 0.022746278f, -0.07851417f,
544 0.01068115f, 0.032956902f, 0.022433773f, 0.0026891115f,
545 0.08944216f, -0.0685835f, 0.010513544f, 0.07228705f,
546 0.02032331f, -0.059686817f, -0.0005566496f, -0.086984694f,
547 0.040414046f, -0.1380399f, 0.094208956f, -0.05722982f,
548 0.012092817f, -0.04989123f, -0.086576f, -0.003399834f,
549 -0.04696032f, -0.045747425f, 0.10091314f, 0.048676282f,
550 -0.029037097f, 0.031399418f, -0.0040285117f, 0.047237843f,
551 0.09504992f, 0.041799378f, -0.049185462f, -0.031518843f,
552 -0.10516937f, 0.026374253f, 0.10058866f, -0.0033195973f,
553 -0.041975245f, 0.0073591834f, 0.0033782164f, -0.004325073f,
554 -0.10167381f, 0.042500053f, -0.01447153f, 0.06464186f,
555 -0.017142897f, 0.03312627f, 0.009205989f, 0.024138335f,
556 -0.011337001f, 0.035530265f, -0.010912711f, 0.0706555f,
557 -0.005894094f, 0.051841937f, -0.1401738f, -0.02351249f,
558 0.0365468f, 0.07590991f, 0.08838724f, 0.021681072f,
559 -0.10086113f, 0.019608743f, -0.06195883f, 0.077335775f,
560 0.023646897f, -0.095322326f, 0.02233014f, 0.09756986f,
561 -0.048691444f, -0.009579111f, 0.07595467f, 0.11480546f,
562 -0.09801813f, 0.019894179f, 0.08502348f, 0.004032281f,
563 0.037211012f, 0.068537936f, -0.048005626f, -0.091520436f,
564 -0.028379958f, -0.01556313f, 0.06554592f, -0.045599163f,
565 -0.01672207f, -0.020169014f, -0.011877351f, -0.20212261f,
566 0.010889619f, 0.0047078193f, 0.038385306f, 0.08540671f,
567 -0.017140968f, -0.0035865551f, 0.016678626f, 0.005633034f,
568 0.015963363f, 0.00871737f, 0.060130805f, 0.028611384f,
569 0.10109069f, -0.015060172f, -0.07894427f, 0.06401885f,
570 0.011584063f, -0.024466386f, 0.0047652307f, -0.09041358f,
571 0.030737216f, -0.0046374933f, 0.14215417f, -0.11823516f,
572 0.019899689f, 0.006106124f, -0.027092824f, 0.0786356f,
573 0.05052217f, -0.058925f, -0.011402121f, -0.024987547f,
574 -0.0013661642f, -0.06832946f, -0.015667673f, -0.1083353f,
575 -0.00096863037f, -0.06988685f, -0.053350925f, -0.027275559f,
576 -0.033664223f, -0.07978348f, -0.025200296f, -0.017207067f,
577 -0.058403496f, -0.055697463f, 0.005798788f, 0.12965427f,
578 -0.062582195f, 0.0013350133f, -0.10482091f, 0.0379771f,
579 0.072521195f, -0.0029455067f, -0.13797039f, -0.03628521f,
580 0.013806405f, -0.017858358f, -0.01008298f, -0.07700066f,
581 -0.017081132f, 0.019358726f, 0.0027079724f, 0.004635139f,
582 0.062634714f, -0.02338735f, -0.039547626f, -0.02050681f,
583 0.03385117f, -0.083611414f, 0.002862572f, -0.09421313f,
584 0.058618143f, -0.08598433f, 0.00972939f, 0.023867095f,
585 -0.053934585f, -0.023203006f, 0.07452513f, -0.048767887f,
586 -0.07314807f, -0.056307215f, -0.10433547f, -0.06440842f,
587 0.04328182f, 0.04389765f, -0.020006588f, -0.09076438f,
588 -0.11652589f, -0.021705797f, 0.03345259f, -0.010329105f,
589 -0.025767034f, 0.013057034f, -0.07316461f, -0.10145612f,
590 0.06358255f, 0.18531723f, 0.07759293f, 0.12006465f,
591 0.1305557f, 0.058638252f, -0.03393652f, 0.09622831f,
592 -0.16253184f, -2.4580743e-06f, 0.079869635f, -0.070196845f,
593 -0.005644518f, 0.06857898f, -0.12598175f, -0.035084512f,
594 0.03156317f, -0.12794146f, -0.031963028f, 0.04692781f,
595 0.030070418f, 0.0071660685f, -0.095516115f, -0.004643372f,
596 0.040170413f, -0.062104587f, -0.0037324072f, 0.0554317f,
597 0.08184801f, -0.019164372f, 0.06791302f, 0.034257166f,
598 -0.10307039f, 0.021943003f, 0.046745934f, 0.0790918f,
599 -0.0265588f, -0.007824208f, 0.042546265f, -0.00977924f,
600 -0.0002440307f, -0.017384544f, -0.017990116f, 0.12252321f,
601 -0.014512694f, -0.08251313f, 0.08861942f, 0.13589665f,
602 0.026351685f, 0.012641483f, 0.07466548f, 0.044301085f,
603 -0.045414884f, -0.051112458f, 0.03444247f, -0.08502782f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100604 -0.04106223f, -0.028126027f, 0.028473156f, 0.10467447f };
telsoa01c577f2c2018-08-31 09:22:23 +0100605
Sadik Armagan483c8112021-06-01 09:24:52 +0100606 std::vector<float> recurrentToForgetWeights = {-0.057784554f, -0.026057621f, -0.068447545f, -0.022581743f,
telsoa01c577f2c2018-08-31 09:22:23 +0100607 0.14811787f, 0.10826372f, 0.09471067f, 0.03987225f,
608 -0.0039523416f, 0.00030638507f, 0.053185795f, 0.10572994f,
609 0.08414449f, -0.022036452f, -0.00066928595f, -0.09203576f,
610 0.032950465f, -0.10985798f, -0.023809856f, 0.0021431844f,
611 -0.02196096f, -0.00326074f, 0.00058621005f, -0.074678116f,
612 -0.06193199f, 0.055729095f, 0.03736828f, 0.020123724f,
613 0.061878487f, -0.04729229f, 0.034919553f, -0.07585433f,
614 -0.04421272f, -0.044019096f, 0.085488975f, 0.04058006f,
615 -0.06890133f, -0.030951202f, -0.024628663f, -0.07672815f,
616 0.034293607f, 0.08556707f, -0.05293577f, -0.033561368f,
617 -0.04899627f, 0.0241671f, 0.015736353f, -0.095442444f,
618 -0.029564252f, 0.016493602f, -0.035026584f, 0.022337519f,
619 -0.026871363f, 0.004780428f, 0.0077918363f, -0.03601621f,
620 0.016435321f, -0.03263031f, -0.09543275f, -0.047392778f,
621 0.013454138f, 0.028934088f, 0.01685226f, -0.086110644f,
622 -0.046250615f, -0.01847454f, 0.047608484f, 0.07339695f,
623 0.034546845f, -0.04881143f, 0.009128804f, -0.08802852f,
624 0.03761666f, 0.008096139f, -0.014454086f, 0.014361001f,
625 -0.023502491f, -0.0011840804f, -0.07607001f, 0.001856849f,
626 -0.06509276f, -0.006021153f, -0.08570962f, -0.1451793f,
627 0.060212336f, 0.055259194f, 0.06974018f, 0.049454916f,
628 -0.027794661f, -0.08077226f, -0.016179763f, 0.1169753f,
629 0.17213494f, -0.0056326236f, -0.053934924f, -0.0124349f,
630 -0.11520337f, 0.05409887f, 0.088759385f, 0.0019655675f,
631 0.0042065294f, 0.03881498f, 0.019844765f, 0.041858196f,
632 -0.05695512f, 0.047233116f, 0.038937137f, -0.06542224f,
633 0.014429736f, -0.09719407f, 0.13908425f, -0.05379757f,
634 0.012321099f, 0.082840554f, -0.029899208f, 0.044217527f,
635 0.059855383f, 0.07711018f, -0.045319796f, 0.0948846f,
636 -0.011724666f, -0.0033288454f, -0.033542685f, -0.04764985f,
637 -0.13873616f, 0.040668588f, 0.034832682f, -0.015319203f,
638 -0.018715994f, 0.046002675f, 0.0599172f, -0.043107376f,
639 0.0294216f, -0.002314414f, -0.022424703f, 0.0030315618f,
640 0.0014641669f, 0.0029166266f, -0.11878115f, 0.013738511f,
641 0.12375372f, -0.0006038222f, 0.029104086f, 0.087442465f,
642 0.052958444f, 0.07558703f, 0.04817258f, 0.044462286f,
643 -0.015213451f, -0.08783778f, -0.0561384f, -0.003008196f,
644 0.047060397f, -0.002058388f, 0.03429439f, -0.018839769f,
645 0.024734668f, 0.024614193f, -0.042046934f, 0.09597743f,
646 -0.0043254104f, 0.04320769f, 0.0064070094f, -0.0019131786f,
647 -0.02558259f, -0.022822596f, -0.023273505f, -0.02464396f,
648 -0.10991725f, -0.006240552f, 0.0074488563f, 0.024044557f,
649 0.04383914f, -0.046476185f, 0.028658995f, 0.060410924f,
650 0.050786525f, 0.009452605f, -0.0073054377f, -0.024810238f,
651 0.0052906186f, 0.0066939713f, -0.0020913032f, 0.014515517f,
652 0.015898481f, 0.021362653f, -0.030262267f, 0.016587038f,
653 -0.011442813f, 0.041154444f, -0.007631438f, -0.03423484f,
654 -0.010977775f, 0.036152758f, 0.0066366293f, 0.11915515f,
655 0.02318443f, -0.041350313f, 0.021485701f, -0.10906167f,
656 -0.028218046f, -0.00954771f, 0.020531068f, -0.11995105f,
657 -0.03672871f, 0.024019798f, 0.014255957f, -0.05221243f,
658 -0.00661567f, -0.04630967f, 0.033188973f, 0.10107534f,
659 -0.014027541f, 0.030796422f, -0.10270911f, -0.035999842f,
660 0.15443139f, 0.07684145f, 0.036571592f, -0.035900835f,
661 -0.0034699554f, 0.06209149f, 0.015920248f, -0.031122351f,
662 -0.03858649f, 0.01849943f, 0.13872518f, 0.01503974f,
663 0.069941424f, -0.06948533f, -0.0088794185f, 0.061282158f,
664 -0.047401894f, 0.03100163f, -0.041533746f, -0.10430945f,
665 0.044574402f, -0.01425562f, -0.024290353f, 0.034563623f,
666 0.05866852f, 0.023947537f, -0.09445152f, 0.035450947f,
667 0.02247216f, -0.0042998926f, 0.061146557f, -0.10250651f,
668 0.020881841f, -0.06747029f, 0.10062043f, -0.0023941975f,
669 0.03532124f, -0.016341697f, 0.09685456f, -0.016764693f,
670 0.051808182f, 0.05875331f, -0.04536488f, 0.001626336f,
671 -0.028892258f, -0.01048663f, -0.009793449f, -0.017093895f,
672 0.010987891f, 0.02357273f, -0.00010856845f, 0.0099760275f,
673 -0.001845119f, -0.03551521f, 0.0018358806f, 0.05763657f,
674 -0.01769146f, 0.040995963f, 0.02235177f, -0.060430344f,
675 0.11475477f, -0.023854522f, 0.10071741f, 0.0686208f,
676 -0.014250481f, 0.034261297f, 0.047418304f, 0.08562733f,
677 -0.030519066f, 0.0060542435f, 0.014653856f, -0.038836084f,
678 0.04096551f, 0.032249358f, -0.08355519f, -0.026823482f,
679 0.056386515f, -0.010401743f, -0.028396193f, 0.08507674f,
680 0.014410365f, 0.020995233f, 0.17040324f, 0.11511526f,
681 0.02459721f, 0.0066619175f, 0.025853224f, -0.023133837f,
682 -0.081302024f, 0.017264642f, -0.009585969f, 0.09491168f,
683 -0.051313367f, 0.054532815f, -0.014298593f, 0.10657464f,
684 0.007076659f, 0.10964551f, 0.0409152f, 0.008275321f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100685 -0.07283536f, 0.07937492f, 0.04192024f, -0.1075027f };
telsoa01c577f2c2018-08-31 09:22:23 +0100686
Sadik Armagan483c8112021-06-01 09:24:52 +0100687 std::vector<float> recurrentToCellWeights = { -0.037322544f, 0.018592842f, 0.0056175636f, -0.06253426f,
telsoa01c577f2c2018-08-31 09:22:23 +0100688 0.055647098f, -0.05713207f, -0.05626563f, 0.005559383f,
689 0.03375411f, -0.025757805f, -0.088049285f, 0.06017052f,
690 -0.06570978f, 0.007384076f, 0.035123326f, -0.07920549f,
691 0.053676967f, 0.044480428f, -0.07663568f, 0.0071805613f,
692 0.08089997f, 0.05143358f, 0.038261272f, 0.03339287f,
693 -0.027673481f, 0.044746667f, 0.028349208f, 0.020090483f,
694 -0.019443132f, -0.030755889f, -0.0040000007f, 0.04465846f,
695 -0.021585021f, 0.0031670958f, 0.0053199246f, -0.056117613f,
696 -0.10893326f, 0.076739706f, -0.08509834f, -0.027997585f,
697 0.037871376f, 0.01449768f, -0.09002357f, -0.06111149f,
698 -0.046195522f, 0.0422062f, -0.005683705f, -0.1253618f,
699 -0.012925729f, -0.04890792f, 0.06985068f, 0.037654128f,
700 0.03398274f, -0.004781977f, 0.007032333f, -0.031787455f,
701 0.010868644f, -0.031489216f, 0.09525667f, 0.013939797f,
702 0.0058680447f, 0.0167067f, 0.02668468f, -0.04797466f,
703 -0.048885044f, -0.12722108f, 0.035304096f, 0.06554885f,
704 0.00972396f, -0.039238118f, -0.05159735f, -0.11329045f,
705 0.1613692f, -0.03750952f, 0.06529313f, -0.071974665f,
706 -0.11769596f, 0.015524369f, -0.0013754242f, -0.12446318f,
707 0.02786344f, -0.014179351f, 0.005264273f, 0.14376344f,
708 0.015983658f, 0.03406988f, -0.06939408f, 0.040699873f,
709 0.02111075f, 0.09669095f, 0.041345075f, -0.08316494f,
710 -0.07684199f, -0.045768797f, 0.032298047f, -0.041805092f,
711 0.0119405f, 0.0061010392f, 0.12652606f, 0.0064572375f,
712 -0.024950314f, 0.11574242f, 0.04508852f, -0.04335324f,
713 0.06760663f, -0.027437469f, 0.07216407f, 0.06977076f,
714 -0.05438599f, 0.034033038f, -0.028602652f, 0.05346137f,
715 0.043184172f, -0.037189785f, 0.10420091f, 0.00882477f,
716 -0.054019816f, -0.074273005f, -0.030617684f, -0.0028467078f,
717 0.024302477f, -0.0038869337f, 0.005332455f, 0.0013399826f,
718 0.04361412f, -0.007001822f, 0.09631092f, -0.06702025f,
719 -0.042049985f, -0.035070654f, -0.04103342f, -0.10273396f,
720 0.0544271f, 0.037184782f, -0.13150354f, -0.0058036847f,
721 -0.008264958f, 0.042035464f, 0.05891794f, 0.029673764f,
722 0.0063542654f, 0.044788733f, 0.054816857f, 0.062257513f,
723 -0.00093483756f, 0.048938446f, -0.004952862f, -0.007730018f,
724 -0.04043371f, -0.017094059f, 0.07229206f, -0.023670016f,
725 -0.052195564f, -0.025616996f, -0.01520939f, 0.045104615f,
726 -0.007376126f, 0.003533447f, 0.006570588f, 0.056037236f,
727 0.12436656f, 0.051817212f, 0.028532185f, -0.08686856f,
728 0.11868599f, 0.07663395f, -0.07323171f, 0.03463402f,
729 -0.050708205f, -0.04458982f, -0.11590894f, 0.021273347f,
730 0.1251325f, -0.15313013f, -0.12224372f, 0.17228661f,
731 0.023029093f, 0.086124025f, 0.006445803f, -0.03496501f,
732 0.028332196f, 0.04449512f, -0.042436164f, -0.026587414f,
733 -0.006041347f, -0.09292539f, -0.05678812f, 0.03897832f,
734 0.09465633f, 0.008115513f, -0.02171956f, 0.08304309f,
735 0.071401566f, 0.019622514f, 0.032163795f, -0.004167056f,
736 0.02295182f, 0.030739572f, 0.056506045f, 0.004612461f,
737 0.06524936f, 0.059999723f, 0.046395954f, -0.0045512207f,
738 -0.1335546f, -0.030136576f, 0.11584653f, -0.014678886f,
739 0.0020118146f, -0.09688814f, -0.0790206f, 0.039770417f,
740 -0.0329582f, 0.07922767f, 0.029322514f, 0.026405897f,
741 0.04207835f, -0.07073373f, 0.063781224f, 0.0859677f,
742 -0.10925287f, -0.07011058f, 0.048005477f, 0.03438226f,
743 -0.09606514f, -0.006669445f, -0.043381985f, 0.04240257f,
744 -0.06955775f, -0.06769346f, 0.043903265f, -0.026784198f,
745 -0.017840602f, 0.024307009f, -0.040079936f, -0.019946516f,
746 0.045318738f, -0.12233574f, 0.026170589f, 0.0074471775f,
747 0.15978073f, 0.10185836f, 0.10298046f, -0.015476589f,
748 -0.039390966f, -0.072174534f, 0.0739445f, -0.1211869f,
749 -0.0347889f, -0.07943156f, 0.014809798f, -0.12412325f,
750 -0.0030663363f, 0.039695457f, 0.0647603f, -0.08291318f,
751 -0.018529687f, -0.004423833f, 0.0037507233f, 0.084633216f,
752 -0.01514876f, -0.056505352f, -0.012800942f, -0.06994386f,
753 0.012962922f, -0.031234352f, 0.07029052f, 0.016418684f,
754 0.03618972f, 0.055686004f, -0.08663945f, -0.017404709f,
755 -0.054761406f, 0.029065743f, 0.052404847f, 0.020238016f,
756 0.0048197987f, -0.0214882f, 0.07078733f, 0.013016777f,
757 0.06262858f, 0.009184685f, 0.020785125f, -0.043904778f,
758 -0.0270329f, -0.03299152f, -0.060088247f, -0.015162964f,
759 -0.001828936f, 0.12642565f, -0.056757294f, 0.013586685f,
760 0.09232601f, -0.035886683f, 0.06000002f, 0.05229691f,
761 -0.052580316f, -0.082029596f, -0.010794592f, 0.012947712f,
762 -0.036429964f, -0.085508935f, -0.13127148f, -0.017744139f,
763 0.031502828f, 0.036232427f, -0.031581745f, 0.023051167f,
764 -0.05325106f, -0.03421577f, 0.028793324f, -0.034633752f,
765 -0.009881397f, -0.043551125f, -0.018609839f, 0.0019097115f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100766 -0.008799762f, 0.056595087f, 0.0022273948f, 0.055752404f };
telsoa01c577f2c2018-08-31 09:22:23 +0100767
Sadik Armagan483c8112021-06-01 09:24:52 +0100768 std::vector<float> recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f,-0.045984812f, -0.01255415f,
769 -0.0026479573f,-0.08196161f,-0.054914974f,-0.0046604523f,
telsoa01c577f2c2018-08-31 09:22:23 +0100770 -0.029587349f, -0.044576716f, -0.07480124f, -0.082868785f,
771 0.023254942f, 0.027502948f, -0.0039728214f, -0.08683098f,
772 -0.08116779f, -0.014675607f, -0.037924774f, -0.023314456f,
773 -0.007401714f, -0.09255757f, 0.029460307f, -0.08829125f,
774 -0.005139627f, -0.08989442f, -0.0555066f, 0.13596267f,
775 -0.025062224f, -0.048351806f, -0.03850004f, 0.07266485f,
776 -0.022414139f, 0.05940088f, 0.075114764f, 0.09597592f,
777 -0.010211725f, -0.0049794707f, -0.011523867f, -0.025980417f,
778 0.072999895f, 0.11091378f, -0.081685916f, 0.014416728f,
779 0.043229222f, 0.034178585f, -0.07530371f, 0.035837382f,
780 -0.085607f, -0.007721233f, -0.03287832f, -0.043848954f,
781 -0.06404588f, -0.06632928f, -0.073643476f, 0.008214239f,
782 -0.045984086f, 0.039764922f, 0.03474462f, 0.060612556f,
783 -0.080590084f, 0.049127717f, 0.04151091f, -0.030063879f,
784 0.008801774f, -0.023021035f, -0.019558564f, 0.05158114f,
785 -0.010947698f, -0.011825728f, 0.0075720972f, 0.0699727f,
786 -0.0039981045f, 0.069350146f, 0.08799282f, 0.016156472f,
787 0.035502106f, 0.11695009f, 0.006217345f, 0.13392477f,
788 -0.037875112f, 0.025745004f, 0.08940699f, -0.00924166f,
789 0.0046702605f, -0.036598757f, -0.08811812f, 0.10522024f,
790 -0.032441203f, 0.008176899f, -0.04454919f, 0.07058152f,
791 0.0067963637f, 0.039206743f, 0.03259838f, 0.03725492f,
792 -0.09515802f, 0.013326398f, -0.052055415f, -0.025676316f,
793 0.03198509f, -0.015951829f, -0.058556724f, 0.036879618f,
794 0.043357447f, 0.028362012f, -0.05908629f, 0.0059240665f,
795 -0.04995891f, -0.019187413f,0.0276265f, -0.01628143f, 0.0025863599f,
796 0.08800015f, 0.035250366f, -0.022165963f, -0.07328642f,
797 -0.009415526f, -0.07455109f, 0.11690406f, 0.0363299f,
798 0.07411125f, 0.042103454f, -0.009660886f, 0.019076364f,
799 0.018299393f, -0.046004917f, 0.08891175f,0.0431396f, -0.026327137f,
800 -0.051502608f, 0.08979574f, -0.051670972f, 0.04940282f,
801 -0.07491107f, -0.021240504f, 0.022596184f, -0.034280192f,
802 0.060163025f, -0.058211457f, -0.051837247f, -0.01349775f,
803 -0.04639988f, -0.035936575f, -0.011681591f, 0.064818054f,
804 0.0073146066f, -0.021745546f, -0.043124277f, -0.06471268f,
805 -0.07053354f, -0.029321948f, -0.05330136f, 0.016933719f,
806 -0.053782392f, 0.13747959f, -0.1361751f, -0.11569455f,
807 0.0033329215f, 0.05693899f, -0.053219706f, 0.063698f,
808 0.07977434f, -0.07924483f, 0.06936997f, 0.0034815092f,
809 -0.007305279f, -0.037325785f, -0.07251102f, -0.033633437f,
810 -0.08677009f, 0.091591336f, -0.14165086f, 0.021752775f,
811 0.019683983f, 0.0011612234f, -0.058154266f, 0.049996935f,
812 0.0288841f, -0.0024567875f, -0.14345716f, 0.010955264f,-0.10234828f,
813 0.1183656f, -0.0010731248f, -0.023590032f,-0.072285876f,-0.0724771f,
814 -0.026382286f, -0.0014920527f, 0.042667855f, 0.0018776858f,
815 0.02986552f, 0.009814309f, 0.0733756f, 0.12289186f,
816 0.018043943f, -0.0458958f, 0.049412545f, 0.033632483f,
817 0.05495232f, 0.036686596f, -0.013781798f, -0.010036754f,
818 0.02576849f, -0.08307328f, 0.010112348f, 0.042521734f,
819 -0.05869831f, -0.071689695f, 0.03876447f, -0.13275425f, -0.0352966f,
820 -0.023077697f, 0.10285965f, 0.084736146f, 0.15568255f,
821 -0.00040734606f, 0.027835453f, -0.10292561f, -0.032401145f,
822 0.10053256f, -0.026142767f, -0.08271222f, -0.0030240538f,
823 -0.016368777f, 0.1070414f, 0.042672627f, 0.013456989f,
824 -0.0437609f, -0.022309763f, 0.11576483f, 0.04108048f,
825 0.061026827f, -0.0190714f, -0.0869359f, 0.037901703f, 0.0610107f,
826 0.07202949f, 0.01675338f, 0.086139716f, -0.08795751f,
827 -0.014898893f, -0.023771819f, -0.01965048f, 0.007955471f,
828 -0.043740474f, 0.03346837f, -0.10549954f, 0.090567775f,
829 0.042013682f, -0.03176985f, 0.12569028f, -0.02421228f,
830 -0.029526481f, 0.023851605f, 0.031539805f, 0.05292009f,
831 -0.02344001f, -0.07811758f, -0.08834428f, 0.10094801f,
832 0.16594367f, -0.06861939f, -0.021256343f, -0.041093912f,
833 -0.06669611f, 0.035498552f, 0.021757556f, -0.09302526f,
834 -0.015403468f, -0.06614931f, -0.051798206f, -0.013874718f,
835 0.03630673f, 0.010412845f, -0.08077351f, 0.046185967f,
836 0.0035662893f, 0.03541868f, -0.094149634f, -0.034814864f,
837 0.003128424f, -0.020674974f, -0.03944324f, -0.008110165f,
838 -0.11113267f, 0.08484226f, 0.043586485f, 0.040582247f,
839 0.0968012f, -0.065249965f, -0.028036479f, 0.0050708856f,
840 0.0017462453f, 0.0326779f, 0.041296225f, 0.09164146f,
841 -0.047743853f, -0.015952192f, -0.034451712f, 0.084197424f,
842 -0.05347844f, -0.11768019f, 0.085926116f, -0.08251791f,
843 -0.045081906f, 0.0948852f, 0.068401024f, 0.024856757f,
844 0.06978981f, -0.057309967f, -0.012775832f, -0.0032452994f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100845 0.01977615f, -0.041040014f, -0.024264973f,0.063464895f, 0.05431621f};
telsoa01c577f2c2018-08-31 09:22:23 +0100846
Sadik Armagan483c8112021-06-01 09:24:52 +0100847 std::vector<float> cellToInputWeights = {0.040369894f, 0.030746894f, 0.24704495f, 0.018586371f, -0.037586458f,
848 -0.15312155f, -0.11812848f, -0.11465643f, 0.20259799f, 0.11418174f,
849 -0.10116027f, -0.011334949f, 0.12411352f, -0.076769054f,-0.052169047f,
850 0.21198851f, -0.38871562f, -0.09061183f, -0.09683246f, -0.21929175f};
telsoa01c577f2c2018-08-31 09:22:23 +0100851
852
Sadik Armagan483c8112021-06-01 09:24:52 +0100853 std::vector<float> cellToForgetWeights = {-0.01998659f,-0.15568835f,-0.24248174f, -0.012770197f, 0.041331276f,
854 -0.072311886f, -0.052123554f,-0.0066330447f,-0.043891653f,0.036225766f,
855 -0.047248036f, 0.021479502f,0.033189066f, 0.11952997f, -0.020432774f,
856 0.64658105f, -0.06650122f, -0.03467612f, 0.095340036f, 0.23647355f};
telsoa01c577f2c2018-08-31 09:22:23 +0100857
Sadik Armagan483c8112021-06-01 09:24:52 +0100858 std::vector<float> cellToOutputWeights = { 0.08286371f, -0.08261836f, -0.51210177f, 0.002913762f, 0.17764764f,
859 -0.5495371f, -0.08460716f, -0.24552552f, 0.030037103f, 0.04123544f,
860 -0.11940523f, 0.007358328f, 0.1890978f, 0.4833202f, -0.34441817f,
861 0.36312827f, -0.26375428f, 0.1457655f, -0.19724406f, 0.15548733f};
telsoa01c577f2c2018-08-31 09:22:23 +0100862
Sadik Armagan483c8112021-06-01 09:24:52 +0100863 std::vector<float> projectionWeights={-0.009802181f, 0.09401916f, 0.0717386f, -0.13895074f, 0.09641832f,
864 0.060420845f, 0.08539281f, 0.054285463f, 0.061395317f, 0.034448683f,
865 -0.042991187f, 0.019801661f, -0.16840284f, -0.015726732f, -0.23041931f,
866 -0.024478018f, -0.10959692f, -0.013875541f, 0.18600968f, -0.061274476f,
867 0.0138165f, -0.08160894f, -0.07661644f, 0.032372914f, 0.16169067f,
868 0.22465782f, -0.03993472f, -0.004017731f, 0.08633481f, -0.28869787f,
869 0.08682067f, 0.17240396f, 0.014975425f, 0.056431185f, 0.031037588f,
870 0.16702051f, 0.0077946745f, 0.15140012f, 0.29405436f, 0.120285f,
871 -0.188994f, -0.027265169f, 0.043389652f, -0.022061434f, 0.014777949f,
872 -0.20203483f, 0.094781205f, 0.19100232f, 0.13987629f, -0.036132768f,
873 -0.06426278f, -0.05108664f, 0.13221376f, 0.009441198f, -0.16715929f,
874 0.15859416f, -0.040437475f, 0.050779544f, -0.022187516f, 0.012166504f,
875 0.027685808f, -0.07675938f, -0.0055694645f, -0.09444123f, 0.0046453946f,
876 0.050794356f, 0.10770313f, -0.20790008f, -0.07149004f, -0.11425117f,
877 0.008225835f, -0.035802525f, 0.14374903f, 0.15262283f, 0.048710253f,
878 0.1847461f, -0.007487823f, 0.11000021f, -0.09542012f, 0.22619456f,
879 -0.029149994f, 0.08527916f, 0.009043713f, 0.0042746216f, 0.016261552f,
880 0.022461696f, 0.12689082f, -0.043589946f, -0.12035478f, -0.08361797f,
881 -0.050666027f, -0.1248618f, -0.1275799f, -0.071875185f, 0.07377272f,
882 0.09944291f, -0.18897448f, -0.1593054f, -0.06526116f, -0.040107165f,
883 -0.004618631f, -0.067624845f, -0.007576253f, 0.10727444f, 0.041546922f,
884 -0.20424393f, 0.06907816f, 0.050412357f, 0.00724631f, 0.039827548f,
885 0.12449835f, 0.10747581f, 0.13708383f, 0.09134148f, -0.12617786f,
886 -0.06428341f, 0.09956831f, 0.1208086f, -0.14676677f, -0.0727722f,
887 0.1126304f, 0.010139365f, 0.015571211f, -0.038128063f, 0.022913318f,
888 -0.042050496f, 0.16842307f, -0.060597885f, 0.10531834f, -0.06411776f,
889 -0.07451711f, -0.03410368f, -0.13393489f, 0.06534304f, 0.003620307f,
890 0.04490757f, 0.05970546f, 0.05197996f, 0.02839995f, 0.10434969f,
891 -0.013699693f, -0.028353551f, -0.07260381f, 0.047201227f, -0.024575593f,
892 -0.036445823f, 0.07155557f, 0.009672501f, -0.02328883f, 0.009533515f,
893 -0.03606021f, -0.07421458f, -0.028082801f, -0.2678904f, -0.13221288f,
894 0.18419984f, -0.13012612f, -0.014588381f, -0.035059117f, -0.04824723f,
895 0.07830115f, -0.056184657f, 0.03277091f, 0.025466874f, 0.14494097f,
896 -0.12522776f, -0.098633975f, -0.10766018f, -0.08317623f, 0.08594209f,
897 0.07749552f, 0.039474737f, 0.1776665f, -0.07409566f, -0.0477268f,
898 0.29323658f, 0.10801441f, 0.1154011f, 0.013952499f, 0.10739139f,
899 0.10708251f, -0.051456142f, 0.0074137426f, -0.10430189f, 0.10034707f,
900 0.045594677f, 0.0635285f, -0.0715442f, -0.089667566f, -0.10811871f,
901 0.00026344223f, 0.08298446f, -0.009525053f, 0.006585689f, -0.24567553f,
902 -0.09450807f, 0.09648481f, 0.026996298f, -0.06419476f, -0.04752702f,
903 -0.11063944f, -0.23441927f, -0.17608605f, -0.052156363f, 0.067035615f,
904 0.19271925f, -0.0032889997f, -0.043264326f, 0.09663576f, -0.057112187f,
905 -0.10100678f, 0.0628376f, 0.04447668f, 0.017961001f, -0.10094388f,
906 -0.10190601f, 0.18335468f, 0.10494553f, -0.052095775f, -0.0026118709f,
907 0.10539724f, -0.04383912f, -0.042349473f, 0.08438151f, -0.1947263f,
908 0.02251204f, 0.11216432f, -0.10307853f, 0.17351969f, -0.039091777f,
909 0.08066188f, -0.00561982f, 0.12633002f, 0.11335965f, -0.0088127935f,
910 -0.019777594f, 0.06864014f, -0.059751723f, 0.016233567f, -0.06894641f,
911 -0.28651384f, -0.004228674f, 0.019708522f, -0.16305895f, -0.07468996f,
912 -0.0855457f, 0.099339016f, -0.07580735f, -0.13775392f, 0.08434318f,
913 0.08330512f, -0.12131499f, 0.031935584f, 0.09180414f, -0.08876437f,
914 -0.08049874f, 0.008753825f, 0.03498998f, 0.030215185f, 0.03907079f,
915 0.089751154f, 0.029194152f, -0.03337423f, -0.019092513f, 0.04331237f,
916 0.04299654f, -0.036394123f, -0.12915532f, 0.09793732f, 0.07512415f,
917 -0.11319543f, -0.032502122f, 0.15661901f, 0.07671967f, -0.005491124f,
918 -0.19379048f, -0.218606f, 0.21448623f, 0.017840758f, 0.1416943f,
919 -0.07051762f, 0.19488361f, 0.02664691f, -0.18104725f, -0.09334311f,
920 0.15026465f, -0.15493552f, -0.057762887f, -0.11604192f, -0.262013f,
921 -0.01391798f, 0.012185008f, 0.11156489f, -0.07483202f, 0.06693364f,
922 -0.26151478f, 0.046425626f, 0.036540434f, -0.16435726f, 0.17338543f,
923 -0.21401681f, -0.11385144f, -0.08283257f, -0.069031075f, 0.030635102f,
924 0.010969227f, 0.11109743f, 0.010919218f, 0.027526086f, 0.13519906f,
925 0.01891392f, -0.046839405f, -0.040167913f, 0.017953383f, -0.09700955f,
926 0.0061885654f, -0.07000971f, 0.026893595f, -0.038844477f, 0.14543656f};
telsoa01c577f2c2018-08-31 09:22:23 +0100927
928 std::vector<float> projectionBiasVector(outputSize, 0.f);
telsoa01c577f2c2018-08-31 09:22:23 +0100929
James Conroy1f58f032021-04-27 17:13:27 +0100930 armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo20x5);
931 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo20x5);
932 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo20x5);
933 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo20x5);
934 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo20x16);
935 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo20x16);
936 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo20x16);
937 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo20x16);
938 armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo20);
939 armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo20);
940 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo20);
941 armnn::ScopedTensorHandle cellBiasTensor(tensorInfo20);
942 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo20);
943 armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo20);
944 armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo20);
945 armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo16x20);
946 armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo16);
telsoa01c577f2c2018-08-31 09:22:23 +0100947
Sadik Armagan483c8112021-06-01 09:24:52 +0100948 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
949 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
950 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
951 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
952 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
953 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
954 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
955 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
956 AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
957 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
958 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
959 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
960 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
961 AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
962 AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
963 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
964 AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +0100965
966 data.m_InputToInputWeights = &inputToInputWeightsTensor;
967 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
968 data.m_InputToCellWeights = &inputToCellWeightsTensor;
969 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
970 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
971 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
972 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
973 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
974 data.m_CellToInputWeights = &cellToInputWeightsTensor;
975 data.m_InputGateBias = &inputGateBiasTensor;
976 data.m_ForgetGateBias = &forgetGateBiasTensor;
977 data.m_CellBias = &cellBiasTensor;
978 data.m_OutputGateBias = &outputGateBiasTensor;
979 data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
980 data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
981 data.m_ProjectionWeights = &projectionWeightsTensor;
982 data.m_ProjectionBias = &projectionBiasTensor;
983
984 // Flags to set test configuration
985 data.m_Parameters.m_ActivationFunc = 4;
986 data.m_Parameters.m_CifgEnabled = false;
987 data.m_Parameters.m_PeepholeEnabled = true;
988 data.m_Parameters.m_ProjectionEnabled = true;
989
Teresa Charlin611c7fb2022-01-07 09:47:29 +0000990 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
telsoa01c577f2c2018-08-31 09:22:23 +0100991 inputHandle->Allocate();
992 outputStateInHandle->Allocate();
993 cellStateInHandle->Allocate();
994
995 scratchHandle->Allocate();
996 outputStateOutHandle->Allocate();
997 cellStateOutHandle->Allocate();
998 outputHandle->Allocate();
999
Sadik Armagan483c8112021-06-01 09:24:52 +01001000 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1001 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1002 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001003
telsoa01c577f2c2018-08-31 09:22:23 +01001004 workload->Execute();
1005
Sadik Armagan483c8112021-06-01 09:24:52 +01001006 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
telsoa01c577f2c2018-08-31 09:22:23 +01001007
Sadik Armagan483c8112021-06-01 09:24:52 +01001008 return LayerTestResult<T, 2>(actualOutput,
1009 outputVector,
1010 outputHandle->GetShape(),
1011 outputTensorInfo.GetShape());
telsoa01c577f2c2018-08-31 09:22:23 +01001012}
1013
Conor Kennedyb9971c92019-05-07 07:14:23 +01001014template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
1015LayerTestResult<T, 2> LstmLayerWithCifgWithPeepholeNoProjectionTestImpl(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +00001016 armnn::IWorkloadFactory& workloadFactory,
1017 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001018 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001019 const std::vector<T>& input,
1020 const std::vector<T>& outputExpected,
1021 const armnn::TensorShape& inputShape,
1022 const armnn::TensorShape& outputExpectedShape,
Conor Kennedyb9971c92019-05-07 07:14:23 +01001023 float qScale = 0.0f,
1024 int32_t qOffset = 0,
1025 armnn::DataType constantDataType = armnn::DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +01001026{
Jan Eilers8eb25602020-03-09 12:13:48 +00001027 IgnoreUnused(memoryManager);
telsoa01c577f2c2018-08-31 09:22:23 +01001028 bool cifgEnabled = true;
1029 bool peepholeEnabled = true;
1030 bool projectionEnabled = false;
1031 // These are not the input and the output of Lstm yet
Sadik Armagan483c8112021-06-01 09:24:52 +01001032 unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]);
1033 unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
telsoa01c577f2c2018-08-31 09:22:23 +01001034
Sadik Armagan483c8112021-06-01 09:24:52 +01001035 unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
telsoa01c577f2c2018-08-31 09:22:23 +01001036
1037 const unsigned int cellSize = outputSize;
1038
1039 // Decide the shape of all input tensors
Conor Kennedyb9971c92019-05-07 07:14:23 +01001040 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset); // change to ArmnnType
1041 armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1042 armnn::TensorInfo cellStateInTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +01001043
Matteo Martincigha65b7ae2018-11-14 12:39:55 +00001044 unsigned int scratchBufferSize = cifgEnabled ? cellSize * 3 : cellSize * 4;
Conor Kennedyb9971c92019-05-07 07:14:23 +01001045 armnn::TensorInfo scratchBufferTensorInfo({batchSize, scratchBufferSize}, ArmnnType, qScale, qOffset);
1046 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1047 armnn::TensorInfo cellStateOutTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
1048 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +01001049
1050 // List of inputs
1051 std::vector<float> inputData;
1052 inputData.assign(input.data(), input.data() + batchSize*inputSize);
telsoa01c577f2c2018-08-31 09:22:23 +01001053
1054 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
telsoa01c577f2c2018-08-31 09:22:23 +01001055
1056 std::vector<float> cellStateInVector(batchSize * cellSize, 0.f);
telsoa01c577f2c2018-08-31 09:22:23 +01001057
1058 // Prepare all the weights in the descriptor for LSTM
1059 armnn::LstmQueueDescriptor data;
Conor Kennedyb9971c92019-05-07 07:14:23 +01001060 armnn::TensorInfo tensorInfoInput({cellSize, inputSize}, constantDataType, qScale, qOffset);
1061 armnn::TensorInfo tensorInfoOutput({cellSize, outputSize}, constantDataType, qScale, qOffset);
1062 armnn::TensorInfo tensorInfoNumUnits({cellSize}, constantDataType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +01001063
Sadik Armagan483c8112021-06-01 09:24:52 +01001064 std::vector<float> inputToCellWeights =
1065 {
1066 -0.49770179f, -0.27711356f, -0.09624726f, 0.05100781f,
1067 0.04717243f, 0.48944736f, -0.38535351f,
1068 -0.17212132f
1069 };
1070 std::vector<float> inputToForgetWeights =
1071 {
1072 -0.55291498f, -0.42866567f, 0.13056988f,
1073 -0.3633365f, -0.22755712f, 0.28253698f, 0.24407166f,
1074 0.33826375f
1075 };
1076 std::vector<float> inputToOutputWeights =
1077 {
1078 0.10725588f, -0.02335852f, -0.55932593f,
1079 -0.09426838f, -0.44257352f, 0.54939759f,
1080 0.01533556f, 0.42751634f
1081 };
1082 std::vector<float> cellBias = {0.f, 0.f, 0.f, 0.f};
1083 std::vector<float> forgetGateBias = {1.f, 1.f, 1.f, 1.f};
1084 std::vector<float> outputGateBias = {0.f, 0.f, 0.f, 0.f};
telsoa01c577f2c2018-08-31 09:22:23 +01001085
Sadik Armagan483c8112021-06-01 09:24:52 +01001086 std::vector<float> recurrentToCellWeights =
1087 {
1088 0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f, 0.42957711f,
1089 0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f, 0.20675004f,
1090 0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f, 0.44901288f,
1091 0.21193194f
1092 };
1093 std::vector<float> recurrentToForgetWeights =
1094 {
1095 -0.13832897f, -0.0515101f, -0.2359007f, -0.16661474f, -0.14340827f,
1096 0.36986142f, 0.23414481f, 0.55899f, 0.10798943f, -0.41174671f, 0.17751795f,
1097 -0.34484994f, -0.35874045f, -0.11352962f, 0.27268326f, 0.54058349f
1098 };
telsoa01c577f2c2018-08-31 09:22:23 +01001099
Sadik Armagan483c8112021-06-01 09:24:52 +01001100 std::vector<float> recurrentToOutputWeights =
1101 {
1102 0.41613156f, 0.42610586f, -0.16495961f, -0.5663873f, 0.30579174f, -0.05115908f,
1103 -0.33941799f, 0.23364776f, 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f,
1104 0.50248802f, 0.26114327f, -0.43736315f, 0.33149987f
1105 };
telsoa01c577f2c2018-08-31 09:22:23 +01001106
Sadik Armagan483c8112021-06-01 09:24:52 +01001107 std::vector<float> cellToForgetWeights = {0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f};
1108 std::vector<float> cellToOutputWeights = {-0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f};
telsoa01c577f2c2018-08-31 09:22:23 +01001109
James Conroy1f58f032021-04-27 17:13:27 +01001110 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfoInput);
1111 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfoInput);
1112 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfoInput);
telsoa01c577f2c2018-08-31 09:22:23 +01001113
James Conroy1f58f032021-04-27 17:13:27 +01001114 armnn::ScopedTensorHandle cellBiasTensor(tensorInfoNumUnits);
1115 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfoNumUnits);
1116 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfoNumUnits);
telsoa01c577f2c2018-08-31 09:22:23 +01001117
James Conroy1f58f032021-04-27 17:13:27 +01001118 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfoOutput);
1119 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfoOutput);
1120 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfoOutput);
telsoa01c577f2c2018-08-31 09:22:23 +01001121
James Conroy1f58f032021-04-27 17:13:27 +01001122 armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfoNumUnits);
1123 armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfoNumUnits);
telsoa01c577f2c2018-08-31 09:22:23 +01001124
Sadik Armagan483c8112021-06-01 09:24:52 +01001125 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1126 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1127 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001128
Sadik Armagan483c8112021-06-01 09:24:52 +01001129 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1130 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1131 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001132
Sadik Armagan483c8112021-06-01 09:24:52 +01001133 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1134 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1135 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001136
Sadik Armagan483c8112021-06-01 09:24:52 +01001137 AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
1138 AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001139
1140 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1141 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1142 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1143
1144 data.m_CellBias = &cellBiasTensor;
1145 data.m_ForgetGateBias = &forgetGateBiasTensor;
1146 data.m_OutputGateBias = &outputGateBiasTensor;
1147
1148 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1149 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1150 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1151
1152 data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
1153 data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
1154
1155 // other parameters for the descriptor
1156 data.m_Parameters.m_CifgEnabled = cifgEnabled;
1157 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
1158 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
1159
1160 data.m_Parameters.m_ActivationFunc = 4;
1161 data.m_Parameters.m_ClippingThresProj = 0.0;
1162 data.m_Parameters.m_ClippingThresCell = 0.0;
1163
telsoa01c577f2c2018-08-31 09:22:23 +01001164 // List of outputs
Rob Hughesbb46dde2020-05-20 15:27:37 +01001165 std::vector<T> scratchBufferVector(batchSize * scratchBufferSize, T());
Conor Kennedyb9971c92019-05-07 07:14:23 +01001166 LayerTestResult<T, 2> ret0(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001167
1168 // Output state for a certain time step
Rob Hughesbb46dde2020-05-20 15:27:37 +01001169 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
Conor Kennedyb9971c92019-05-07 07:14:23 +01001170 LayerTestResult<T, 2> ret1(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001171
1172 // Cell state for a certain time step
Rob Hughesbb46dde2020-05-20 15:27:37 +01001173 std::vector<T> cellStateOutVector(batchSize * cellSize, T());
Conor Kennedyb9971c92019-05-07 07:14:23 +01001174 LayerTestResult<T, 2> ret2(cellStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001175
1176 // Output for a certain time step
Rob Hughesbb46dde2020-05-20 15:27:37 +01001177 std::vector<T> outputData;
telsoa01c577f2c2018-08-31 09:22:23 +01001178 outputData.assign(outputExpected.data(), outputExpected.data() + batchSize*outputSize);
Conor Kennedyb9971c92019-05-07 07:14:23 +01001179 LayerTestResult<T, 2> ret3(outputTensorInfo);
Sadik Armagan483c8112021-06-01 09:24:52 +01001180 ret3.m_ExpectedData = outputData;
1181
1182 std::vector<T> actualScratchBufferOutput(scratchBufferTensorInfo.GetNumElements());
1183 std::vector<T> actualOutputStateOutput(outputStateOutTensorInfo.GetNumElements());
1184 std::vector<T> actualCellStateOutput(cellStateOutTensorInfo.GetNumElements());
1185 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
telsoa01c577f2c2018-08-31 09:22:23 +01001186
1187 // Prepare the inputs and outputs for the workload
1188 std::unique_ptr<armnn::ITensorHandle> inputHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001189 tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001190 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001191 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001192 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001193 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001194
1195 std::unique_ptr<armnn::ITensorHandle> scratchBufferHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001196 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001197 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001198 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001199 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001200 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001201 std::unique_ptr<armnn::ITensorHandle> outputHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001202 tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +01001203
1204 armnn::WorkloadInfo info;
1205 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1206 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1207 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1208
1209 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchBufferHandle.get());
1210 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1211 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1212 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1213
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001214 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
telsoa01c577f2c2018-08-31 09:22:23 +01001215
telsoa01c577f2c2018-08-31 09:22:23 +01001216 inputHandle->Allocate();
1217 outputStateInHandle->Allocate();
1218 cellStateInHandle->Allocate();
1219
1220 scratchBufferHandle->Allocate();
1221 outputStateOutHandle->Allocate();
1222 cellStateOutHandle->Allocate();
1223 outputHandle->Allocate();
1224
Sadik Armagan483c8112021-06-01 09:24:52 +01001225 CopyDataToITensorHandle(inputHandle.get(), inputData.data());
1226 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1227 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001228
Sadik Armagan483c8112021-06-01 09:24:52 +01001229 CopyDataToITensorHandle(scratchBufferHandle.get(), scratchBufferVector.data());
1230 CopyDataToITensorHandle(outputStateOutHandle.get(), outputStateOutVector.data());
1231 CopyDataToITensorHandle(cellStateOutHandle.get(), cellStateOutVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001232
telsoa01c577f2c2018-08-31 09:22:23 +01001233 workload->Execute();
1234
Sadik Armagan483c8112021-06-01 09:24:52 +01001235 CopyDataFromITensorHandle(actualScratchBufferOutput.data(), scratchBufferHandle.get());
1236 CopyDataFromITensorHandle(actualOutputStateOutput.data(), outputStateOutHandle.get());
1237 CopyDataFromITensorHandle(actualCellStateOutput.data(), cellStateOutHandle.get());
1238 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
1239
1240 ret0.m_ActualData = actualScratchBufferOutput;
1241 ret1.m_ActualData = actualOutputStateOutput;
1242 ret2.m_ActualData = actualCellStateOutput;
1243 ret3.m_ActualData = actualOutput;
telsoa01c577f2c2018-08-31 09:22:23 +01001244
1245 return ret3;
1246}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001247
Jan Eilers38e05bd2019-06-26 13:10:09 +01001248template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
1249LayerTestResult<T, 2>
1250LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl(armnn::IWorkloadFactory& workloadFactory,
1251 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001252 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001253 const std::vector<T>& input,
1254 const std::vector<T>& outputExpected,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001255 float qScale = 0.0f,
1256 int32_t qOffset = 0,
1257 armnn::DataType constantDataType = armnn::DataType::Float32)
1258{
Jan Eilers8eb25602020-03-09 12:13:48 +00001259 IgnoreUnused(memoryManager);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001260 unsigned int batchSize = 2;
1261 unsigned int outputSize = 3;
1262 unsigned int inputSize = 5;
1263 unsigned numUnits = 4;
1264
1265 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
1266 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
1267 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
1268
1269 // Scratch buffer size without CIFG [batchSize, numUnits * 4]
1270 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
1271 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
1272 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1273 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1274
Jan Eilers38e05bd2019-06-26 13:10:09 +01001275 std::vector<float> inputVector;
1276 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001277
1278 std::vector<float> cellStateInVector(batchSize * numUnits, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001279 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001280 std::vector<float> scratchBufferVector(batchSize * numUnits * 4, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001281 std::vector<float> outputStateOutVector(batchSize * outputSize, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001282 std::vector<float> cellStateOutVector(batchSize * numUnits, 0.f);
Sadik Armagan483c8112021-06-01 09:24:52 +01001283
1284 std::vector<float> actualOutput(outputTensorInfo.GetNumElements());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001285
1286 std::vector<float> outputVector;
1287 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001288
Finn Williamsc43de6a2020-08-27 11:13:25 +01001289 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001290 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001291 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001292 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001293 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001294
Finn Williamsc43de6a2020-08-27 11:13:25 +01001295 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
1296 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001297 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001298 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001299 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001300 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
1301 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001302
1303 armnn::LstmQueueDescriptor data;
1304 armnn::WorkloadInfo info;
1305
1306 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1307 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1308 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1309
1310 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
1311 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1312 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1313 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1314
1315 armnn::TensorInfo tensorInfo3({outputSize}, constantDataType, qScale, qOffset);
1316 armnn::TensorInfo tensorInfo4({numUnits}, constantDataType, qScale, qOffset);
1317 armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
1318 armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, constantDataType, qScale, qOffset);
1319 armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, constantDataType, qScale, qOffset);
1320
Sadik Armagan483c8112021-06-01 09:24:52 +01001321 std::vector<float> inputToInputWeights = {0.5f, 0.6f, 0.7f, -0.8f, -0.9f,
1322 0.1f, 0.2f, 0.3f, -0.4f, 0.5f,
1323 -0.8f, 0.7f, -0.6f, 0.5f, -0.4f,
1324 -0.5f, -0.4f, -0.3f, -0.2f, -0.1f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001325
Sadik Armagan483c8112021-06-01 09:24:52 +01001326 std::vector<float> inputToForgetWeights = { -0.6f, -0.1f, 0.3f, 0.2f, 0.9f,
1327 -0.5f, -0.2f, -0.4f, 0.3f, -0.8f,
1328 -0.4f, 0.3f, -0.5f, -0.4f, -0.6f,
1329 0.3f, -0.4f, -0.6f, -0.5f, -0.5f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001330
Sadik Armagan483c8112021-06-01 09:24:52 +01001331 std::vector<float> inputToCellWeights = {-0.4f, -0.3f, -0.2f, -0.1f, -0.5f,
1332 0.5f, -0.2f, -0.3f, -0.2f, -0.6f,
1333 0.6f, -0.1f, -0.4f, -0.3f, -0.7f,
1334 0.7f, -0.9f, -0.5f, 0.8f, 0.6f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001335
Sadik Armagan483c8112021-06-01 09:24:52 +01001336 std::vector<float> inputToOutputWeights = {-0.8f, -0.4f, -0.2f, -0.9f, -0.1f,
1337 -0.7f, 0.3f, -0.3f, -0.8f, -0.2f,
1338 0.6f, -0.2f, 0.4f, -0.7f, -0.3f,
1339 -0.5f, 0.1f, 0.5f, -0.6f, -0.4f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001340
Sadik Armagan483c8112021-06-01 09:24:52 +01001341 std::vector<float> inputGateBias = {0.03f, 0.15f, 0.22f, 0.38f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001342
Sadik Armagan483c8112021-06-01 09:24:52 +01001343 std::vector<float> forgetGateBias = {0.1f, -0.3f, -0.2f, 0.1f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001344
Sadik Armagan483c8112021-06-01 09:24:52 +01001345 std::vector<float> cellBias = {-0.05f, 0.72f, 0.25f, 0.08f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001346
Sadik Armagan483c8112021-06-01 09:24:52 +01001347 std::vector<float> outputGateBias = {0.05f, -0.01f, 0.2f, 0.1f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001348
Sadik Armagan483c8112021-06-01 09:24:52 +01001349 std::vector<float> recurrentToInputWeights ={-0.2f, -0.3f, 0.4f,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001350 0.1f, -0.5f, 0.9f,
1351 -0.2f, -0.3f, -0.7f,
Sadik Armagan483c8112021-06-01 09:24:52 +01001352 0.05f, -0.2f, -0.6f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001353
Sadik Armagan483c8112021-06-01 09:24:52 +01001354 std::vector<float> recurrentToCellWeights = {-0.3f, 0.2f, 0.1f,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001355 -0.3f, 0.8f, -0.08f,
1356 -0.2f, 0.3f, 0.8f,
Sadik Armagan483c8112021-06-01 09:24:52 +01001357 -0.6f, -0.1f, 0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001358
Sadik Armagan483c8112021-06-01 09:24:52 +01001359 std::vector<float> recurrentToForgetWeights = { -0.5f, -0.3f, -0.5f,
1360 -0.2f, 0.6f, 0.4f,
1361 0.9f, 0.3f, -0.1f,
1362 0.2f, 0.5f, 0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001363
Sadik Armagan483c8112021-06-01 09:24:52 +01001364 std::vector<float> recurrentToOutputWeights = { 0.3f, -0.1f, 0.1f,
1365 -0.2f, -0.5f, -0.7f,
1366 -0.2f, -0.6f, -0.1f,
1367 -0.4f, -0.7f, -0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001368
Sadik Armagan483c8112021-06-01 09:24:52 +01001369 std::vector<float> cellToInputWeights = {0.05f, 0.1f, 0.25f, 0.15f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001370
Sadik Armagan483c8112021-06-01 09:24:52 +01001371 std::vector<float> cellToForgetWeights = {-0.02f, -0.15f, -0.25f, -0.03f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001372
Sadik Armagan483c8112021-06-01 09:24:52 +01001373 std::vector<float> cellToOutputWeights = {0.1f, -0.1f, -0.5f, 0.05f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001374
Sadik Armagan483c8112021-06-01 09:24:52 +01001375 std::vector<float> projectionWeights = {-0.1f, 0.2f, 0.01f, -0.2f,
1376 0.1f, 0.5f, 0.3f, 0.08f,
1377 0.07f, 0.2f, -0.4f, 0.2f}; //{outputSize, numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001378
Sadik Armagan483c8112021-06-01 09:24:52 +01001379 std::vector<float> projectionBiasVector(outputSize, 0.f); //{outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001380
Sadik Armagan483c8112021-06-01 09:24:52 +01001381 std::vector<float> inputLayerNormWeights = {0.1f, 0.2f, 0.3f, 0.5f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001382
Sadik Armagan483c8112021-06-01 09:24:52 +01001383 std::vector<float> forgetLayerNormWeights = {0.2f, 0.2f, 0.4f, 0.3f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001384
Sadik Armagan483c8112021-06-01 09:24:52 +01001385 std::vector<float> cellLayerNormWeights = {0.7f, 0.2f, 0.3f, 0.8f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001386
Sadik Armagan483c8112021-06-01 09:24:52 +01001387 std::vector<float> outputLayerNormWeights = {0.6f, 0.2f, 0.2f, 0.5f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001388
1389
James Conroy1f58f032021-04-27 17:13:27 +01001390 armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo4x5);
1391 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo4x5);
1392 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo4x5);
1393 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo4x5);
1394 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo4x3);
1395 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo4x3);
1396 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo4x3);
1397 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo4x3);
1398 armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4);
1399 armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4);
1400 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4);
1401 armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4);
1402 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4);
1403 armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo4);
1404 armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo4);
1405 armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo3x4);
1406 armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo3);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001407
James Conroy1f58f032021-04-27 17:13:27 +01001408 armnn::ScopedTensorHandle inputLayerNormWeightsTensor(tensorInfo4);
1409 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(tensorInfo4);
1410 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(tensorInfo4);
1411 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(tensorInfo4);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001412
Sadik Armagan483c8112021-06-01 09:24:52 +01001413 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
1414 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1415 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1416 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
1417 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
1418 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1419 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1420 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
1421 AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
1422 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
1423 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1424 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1425 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
1426 AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
1427 AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
1428 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
1429 AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001430
Sadik Armagan483c8112021-06-01 09:24:52 +01001431 AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data());
1432 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
1433 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
1434 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001435
1436 data.m_InputToInputWeights = &inputToInputWeightsTensor;
1437 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1438 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1439 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1440 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
1441 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1442 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1443 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1444 data.m_CellToInputWeights = &cellToInputWeightsTensor;
1445 data.m_InputGateBias = &inputGateBiasTensor;
1446 data.m_ForgetGateBias = &forgetGateBiasTensor;
1447 data.m_CellBias = &cellBiasTensor;
1448 data.m_OutputGateBias = &outputGateBiasTensor;
1449 data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
1450 data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
1451 data.m_ProjectionWeights = &projectionWeightsTensor;
1452 data.m_ProjectionBias = &projectionBiasTensor;
1453
1454 data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
1455 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
1456 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
1457 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
1458
1459 // Flags to set test configuration
1460 data.m_Parameters.m_ActivationFunc = 4;
1461 data.m_Parameters.m_CifgEnabled = false;
1462 data.m_Parameters.m_PeepholeEnabled = true;
1463 data.m_Parameters.m_ProjectionEnabled = true;
1464 data.m_Parameters.m_LayerNormEnabled = true;
1465
1466
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001467 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Lstm, data, info);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001468 inputHandle->Allocate();
1469 outputStateInHandle->Allocate();
1470 cellStateInHandle->Allocate();
1471
1472 scratchHandle->Allocate();
1473 outputStateOutHandle->Allocate();
1474 cellStateOutHandle->Allocate();
1475 outputHandle->Allocate();
1476
Sadik Armagan483c8112021-06-01 09:24:52 +01001477 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1478 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1479 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001480
1481 workload->Execute();
1482
Sadik Armagan483c8112021-06-01 09:24:52 +01001483 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001484
Sadik Armagan483c8112021-06-01 09:24:52 +01001485 return LayerTestResult<T, 2>(actualOutput,
1486 outputVector,
1487 outputHandle->GetShape(),
1488 outputTensorInfo.GetShape());
James Conroy9c3cae82019-08-01 16:01:48 +01001489}
1490
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001491LayerTestResult<uint8_t, 2> QuantizedLstmTestImpl(
1492 armnn::IWorkloadFactory& workloadFactory,
1493 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001494 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001495 const std::vector<uint8_t>& input,
1496 const std::vector<uint8_t>& outputExpected,
1497 const armnn::TensorShape& inputShape,
1498 const armnn::TensorShape& outputExpectedShape)
James Conroy9c3cae82019-08-01 16:01:48 +01001499{
Jan Eilers8eb25602020-03-09 12:13:48 +00001500 IgnoreUnused(memoryManager);
Sadik Armagan483c8112021-06-01 09:24:52 +01001501 auto numBatches = armnn::numeric_cast<unsigned int>(inputShape[0]);
1502 auto inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
1503 auto outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
James Conroy9c3cae82019-08-01 16:01:48 +01001504
1505 // Scale/Offset for input/output, cellState In/Out, weights, bias
1506 float inputOutputScale = 0.0078125f;
1507 int32_t inputOutputOffset = 128;
1508
1509 float cellStateScale = 0.00048828125f;
1510 int32_t cellStateOffset = 0;
1511
1512 float weightsScale = 0.00408021f;
1513 int32_t weightsOffset = 100;
1514
1515 float biasScale = 3.1876640625e-05f;
1516 int32_t biasOffset = 0;
1517
1518 // Input/Output tensor info
1519 armnn::TensorInfo inputInfo({numBatches , inputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001520 armnn::DataType::QAsymmU8,
James Conroy9c3cae82019-08-01 16:01:48 +01001521 inputOutputScale,
1522 inputOutputOffset);
1523
1524 armnn::TensorInfo cellStateInfo({numBatches , outputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001525 armnn::DataType::QSymmS16,
James Conroy9c3cae82019-08-01 16:01:48 +01001526 cellStateScale,
1527 cellStateOffset);
1528
1529 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001530 armnn::DataType::QAsymmU8,
James Conroy9c3cae82019-08-01 16:01:48 +01001531 inputOutputScale,
1532 inputOutputOffset);
1533
James Conroy9c3cae82019-08-01 16:01:48 +01001534 // Input0
1535 std::vector<uint8_t> inputVector;
1536 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroy9c3cae82019-08-01 16:01:48 +01001537
1538 // Input1
1539 std::vector<int16_t> cellStateInVector = {876, 1034, 955, -909, 761, 1029, 796, -1036}; // 13
James Conroy9c3cae82019-08-01 16:01:48 +01001540 // Input2
1541 std::vector<uint8_t> outputStateInVector = {136, 150, 140, 115, 135, 152, 138, 112}; // 14
James Conroy9c3cae82019-08-01 16:01:48 +01001542
1543 // Output0
1544 std::vector<int16_t> cellStateOutVector = {1485, 1177, 1373, -1023, 1019, 1355, 1097, -1235}; // 0
James Conroy9c3cae82019-08-01 16:01:48 +01001545
1546 // Output1
1547 std::vector<uint8_t> outputVector; // 1
1548 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01001549
1550 std::vector<uint8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroy9c3cae82019-08-01 16:01:48 +01001551
1552 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01001553 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001554 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001555 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001556 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001557 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001558
1559 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001560 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1561 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001562
1563 armnn::QuantizedLstmQueueDescriptor data;
1564 armnn::WorkloadInfo info;
1565
1566 // Add inputs and outputs to workload
1567 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1568 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1569 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1570
1571 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1572 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
1573
1574 // Weights and bias tensor and quantization info
1575 armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001576 armnn::DataType::QAsymmU8,
James Conroy9c3cae82019-08-01 16:01:48 +01001577 weightsScale,
1578 weightsOffset);
1579
1580 armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001581 armnn::DataType::QAsymmU8,
James Conroy9c3cae82019-08-01 16:01:48 +01001582 weightsScale,
1583 weightsOffset);
1584
1585 armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
1586
1587 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01001588 std::vector<uint8_t> inputToInputWeights = {146, 250, 235, 171, 10, 218, 171, 108};
1589 std::vector<uint8_t> inputToForgetWeights = {24, 50, 132, 179, 158, 110, 3, 169};
1590 std::vector<uint8_t> inputToCellWeights = {133, 34, 29, 49, 206, 109, 54, 183};
1591 std::vector<uint8_t> inputToOutputWeights = {195, 187, 11, 99, 109, 10, 218, 48};
James Conroy9c3cae82019-08-01 16:01:48 +01001592
Sadik Armagan483c8112021-06-01 09:24:52 +01001593 std::vector<uint8_t> recurrentToInputWeights =
1594 {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26};
1595 std::vector<uint8_t> recurrentToForgetWeights =
1596 {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253};
1597 std::vector<uint8_t> recurrentToCellWeights =
1598 {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216};
1599 std::vector<uint8_t> recurrentToOutputWeights =
1600 {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98};
James Conroy9c3cae82019-08-01 16:01:48 +01001601
Sadik Armagan483c8112021-06-01 09:24:52 +01001602 std::vector<int32_t> inputGateBias = {-7876, 13488, -726, 32839};
1603 std::vector<int32_t> forgetGateBias = {9206, -46884, -11693, -38724};
1604 std::vector<int32_t> cellBias = {39481, 48624, 48976, -21419};
1605 std::vector<int32_t> outputGateBias = {-58999, -17050, -41852, -40538};
James Conroy9c3cae82019-08-01 16:01:48 +01001606
James Conroy1f58f032021-04-27 17:13:27 +01001607 // ScopedTensorHandles
1608 armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo);
1609 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
1610 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
1611 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001612
James Conroy1f58f032021-04-27 17:13:27 +01001613 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo);
1614 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
1615 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
1616 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001617
James Conroy1f58f032021-04-27 17:13:27 +01001618 armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo);
1619 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
1620 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
1621 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001622
1623 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01001624 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
1625 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1626 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1627 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroy9c3cae82019-08-01 16:01:48 +01001628
Sadik Armagan483c8112021-06-01 09:24:52 +01001629 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
1630 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1631 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1632 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroy9c3cae82019-08-01 16:01:48 +01001633
Sadik Armagan483c8112021-06-01 09:24:52 +01001634 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
1635 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1636 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1637 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroy9c3cae82019-08-01 16:01:48 +01001638
1639 // Setup queue descriptor
1640 data.m_InputToInputWeights = &inputToInputWeightsTensor;
1641 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1642 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1643 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1644
1645 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
1646 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1647 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1648 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1649
1650 data.m_InputGateBias = &inputGateBiasTensor;
1651 data.m_ForgetGateBias = &forgetGateBiasTensor;
1652 data.m_CellBias = &cellBiasTensor;
1653 data.m_OutputGateBias = &outputGateBiasTensor;
1654
1655 // Create workload and allocate tensor handles
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001656 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QuantizedLstm,
1657 data,
1658 info);
James Conroy9c3cae82019-08-01 16:01:48 +01001659 inputHandle->Allocate();
1660 outputStateInHandle->Allocate();
1661 cellStateInHandle->Allocate();
1662
1663 cellStateOutHandle->Allocate();
1664 outputHandle->Allocate();
1665
Sadik Armagan483c8112021-06-01 09:24:52 +01001666 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1667 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1668 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroy9c3cae82019-08-01 16:01:48 +01001669
1670 workload->Execute();
1671
Sadik Armagan483c8112021-06-01 09:24:52 +01001672 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroy9c3cae82019-08-01 16:01:48 +01001673
Sadik Armagan483c8112021-06-01 09:24:52 +01001674 return LayerTestResult<uint8_t, 2>(actualOutput,
1675 outputVector,
1676 outputHandle->GetShape(),
1677 outputStateInfo.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001678}
1679
James Conroyb22a75e2020-06-08 14:53:10 +01001680// QLSTM: CIFG, LayerNorm
James Conroy4f1f8992020-04-29 20:01:10 +01001681LayerTestResult<int8_t, 2> QLstmTestImpl(
1682 armnn::IWorkloadFactory& workloadFactory,
1683 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001684 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001685 const std::vector<int8_t>& input,
1686 const std::vector<int8_t>& outputExpected)
James Conroy4f1f8992020-04-29 20:01:10 +01001687{
1688 IgnoreUnused(memoryManager);
1689 unsigned int numBatches = 2;
1690 unsigned int inputSize = 5;
1691 unsigned int outputSize = 4;
1692 unsigned int numUnits = 4;
1693
1694 bool cifgEnabled = true;
1695 bool peepholeEnabled = false;
1696 bool projectionEnabled = false;
1697 bool layerNormEnabled = true;
1698
1699 // Scale/Offset quantization info
1700 float inputScale = 0.0078125f;
1701 int32_t inputOffset = 0;
1702
1703 int32_t hiddenStateZeroPoint = 0;
1704 float hiddenStateScale = 0.007f;
1705
1706 // if (!projectionEnabled) outputScale == hiddenStateScale
1707 float outputScale = hiddenStateScale;
1708 int32_t outputOffset = hiddenStateZeroPoint;
1709
1710 float cellStateScale = 3.05176e-05f;
1711 int32_t cellStateOffset = 0;
1712
1713 float weightsScale = 0.00784314f;
1714 int32_t weightsOffset = 0;
1715
1716 float layerNormScale = 3.05182e-05f;
1717 int32_t layerNormOffset = 0;
1718
1719 float biasScale = layerNormScale / 1024;
1720 int32_t biasOffset = 0;
1721
1722 float inputIntermediateScale = 0.007059f;
1723 float forgetIntermediateScale = 0.007812f;
1724 float cellIntermediateScale = inputIntermediateScale;
1725 float outputIntermediateScale = forgetIntermediateScale;
1726
1727 float cellClip = 0.0f;
1728 float projectionClip = 0.0f;
1729
1730 // Input/Output tensor info
1731 armnn::TensorInfo inputInfo({numBatches , inputSize},
1732 armnn::DataType::QAsymmS8,
1733 inputScale,
1734 inputOffset);
1735
1736 armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1737 armnn::DataType::QSymmS16,
1738 cellStateScale,
1739 cellStateOffset);
1740
1741 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1742 armnn::DataType::QAsymmS8,
1743 outputScale,
1744 outputOffset);
1745
1746 LayerTestResult<int8_t, 2> ret(outputStateInfo);
1747
1748 // Input tensors
1749 std::vector<int8_t> inputVector;
1750 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroy4f1f8992020-04-29 20:01:10 +01001751
1752 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroy4f1f8992020-04-29 20:01:10 +01001753
Teresa Charlinbe727be2020-09-25 15:08:21 +01001754 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroy4f1f8992020-04-29 20:01:10 +01001755
1756 // Output tensors
Sadik Armagan483c8112021-06-01 09:24:52 +01001757 std::vector<int16_t> cellStateOutVector = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149};
James Conroy4f1f8992020-04-29 20:01:10 +01001758
1759 std::vector<int8_t> outputVector;
1760 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01001761
1762 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroy4f1f8992020-04-29 20:01:10 +01001763
1764 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01001765 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001766 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001767 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001768 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001769 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001770
Finn Williamsc43de6a2020-08-27 11:13:25 +01001771 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1772 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001773 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001774 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1775 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001776
1777 armnn::QLstmQueueDescriptor data;
1778 armnn::WorkloadInfo info;
1779
1780 // Add inputs and outputs to workload
1781 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1782 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1783 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1784
1785 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
1786 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1787 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
1788
1789 // Weights and bias tensor and quantization info
1790 armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
1791 armnn::DataType::QSymmS8,
1792 weightsScale,
1793 weightsOffset);
1794
1795 armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
1796 armnn::DataType::QSymmS8,
1797 weightsScale,
1798 weightsOffset);
1799
1800 armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
1801
1802 armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
1803
1804 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01001805 std::vector<int8_t> inputToForgetWeights =
1806 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
1807 std::vector<int8_t> inputToCellWeights =
1808 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
1809 std::vector<int8_t> inputToOutputWeights =
1810 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
James Conroy4f1f8992020-04-29 20:01:10 +01001811
Sadik Armagan483c8112021-06-01 09:24:52 +01001812 std::vector<int8_t> recurrentToForgetWeights =
1813 {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51};
1814 std::vector<int8_t> recurrentToCellWeights =
1815 {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64};
1816 std::vector<int8_t> recurrentToOutputWeights =
1817 {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38};
James Conroy4f1f8992020-04-29 20:01:10 +01001818
Sadik Armagan483c8112021-06-01 09:24:52 +01001819 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
1820 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
1821 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
James Conroy4f1f8992020-04-29 20:01:10 +01001822
Sadik Armagan483c8112021-06-01 09:24:52 +01001823 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
1824 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
1825 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
James Conroy4f1f8992020-04-29 20:01:10 +01001826
James Conroy1f58f032021-04-27 17:13:27 +01001827 // ScopedTensorHandles
1828 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
1829 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
1830 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001831
James Conroy1f58f032021-04-27 17:13:27 +01001832 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
1833 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
1834 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001835
James Conroy1f58f032021-04-27 17:13:27 +01001836 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
1837 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
1838 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001839
James Conroy1f58f032021-04-27 17:13:27 +01001840 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
1841 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
1842 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001843
1844 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01001845 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1846 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1847 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001848
Sadik Armagan483c8112021-06-01 09:24:52 +01001849 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1850 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1851 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001852
Sadik Armagan483c8112021-06-01 09:24:52 +01001853 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1854 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1855 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001856
Sadik Armagan483c8112021-06-01 09:24:52 +01001857 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
1858 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
1859 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001860
1861 // Setup queue descriptor
1862 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1863 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1864 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1865
1866 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1867 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1868 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1869
1870 data.m_ForgetGateBias = &forgetGateBiasTensor;
1871 data.m_CellBias = &cellBiasTensor;
1872 data.m_OutputGateBias = &outputGateBiasTensor;
1873
1874 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
1875 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
1876 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
1877
1878 data.m_Parameters.m_CifgEnabled = cifgEnabled;
1879 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
1880 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
1881 data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
1882
1883 data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
1884 data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
1885 data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
1886 data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
1887
1888 data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
1889 data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
1890
1891 data.m_Parameters.m_CellClip = cellClip;
1892 data.m_Parameters.m_ProjectionClip = projectionClip;
1893
1894 // Create workload and allocate tensor handles
Teresa Charlin611c7fb2022-01-07 09:47:29 +00001895 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
James Conroy4f1f8992020-04-29 20:01:10 +01001896 inputHandle->Allocate();
1897 outputStateInHandle->Allocate();
1898 cellStateInHandle->Allocate();
1899
1900 outputStateOutHandle->Allocate();
1901 cellStateOutHandle->Allocate();
1902 outputHandle->Allocate();
1903
Sadik Armagan483c8112021-06-01 09:24:52 +01001904 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1905 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1906 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001907
1908 workload->Execute();
1909
Sadik Armagan483c8112021-06-01 09:24:52 +01001910 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroy4f1f8992020-04-29 20:01:10 +01001911
Sadik Armagan483c8112021-06-01 09:24:52 +01001912 return LayerTestResult<int8_t, 2>(actualOutput,
1913 outputVector,
1914 outputHandle->GetShape(),
1915 outputStateInfo.GetShape());
James Conroy4f1f8992020-04-29 20:01:10 +01001916}
1917
James Conroyb22a75e2020-06-08 14:53:10 +01001918// QLSTM: Projection, LayerNorm
1919LayerTestResult<int8_t, 2> QLstmTestImpl1(
1920 armnn::IWorkloadFactory& workloadFactory,
1921 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001922 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001923 const std::vector<int8_t>& input,
1924 const std::vector<int8_t>& outputExpected)
James Conroyb22a75e2020-06-08 14:53:10 +01001925{
1926 IgnoreUnused(memoryManager);
1927 unsigned int numBatches = 2;
1928 unsigned int inputSize = 5;
1929 unsigned int outputSize = 3;
1930 unsigned int numUnits = 4;
1931
1932 bool cifgEnabled = false;
1933 bool peepholeEnabled = false;
1934 bool projectionEnabled = true;
1935 bool layerNormEnabled = true;
1936
1937 // Scale/Offset quantization info
1938 float inputScale = 0.0078125f;
1939 int32_t inputOffset = 0;
1940
1941 int32_t hiddenStateZeroPoint = 0;
1942 float hiddenStateScale = 0.007f;
1943
1944 // if (!projectionEnabled) outputScale == hiddenStateScale
1945 float outputScale = 3.05176e-05f;
1946 int32_t outputOffset = 0;
1947
1948 float cellStateScale = 3.05176e-05f;
1949 int32_t cellStateOffset = 0;
1950
1951 float weightsScale = 0.00784314f;
1952 int32_t weightsOffset = 0;
1953
1954 float layerNormScale = 3.05182e-05f;
1955 int32_t layerNormOffset = 0;
1956
1957 float biasScale = layerNormScale / 1024;
1958 int32_t biasOffset = 0;
1959
1960 float projectionWeightsScale = 0.00392157f;
1961
1962 float inputIntermediateScale = 0.007059f;
1963 float forgetIntermediateScale = 0.007812f;
1964 float cellIntermediateScale = inputIntermediateScale;
1965 float outputIntermediateScale = forgetIntermediateScale;
1966
1967 float cellClip = 0.0f;
1968 float projectionClip = 0.0f;
1969
1970 // Input/Output tensor info
1971 armnn::TensorInfo inputInfo({numBatches , inputSize},
1972 armnn::DataType::QAsymmS8,
1973 inputScale,
1974 inputOffset);
1975
1976 armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1977 armnn::DataType::QSymmS16,
1978 cellStateScale,
1979 cellStateOffset);
1980
1981 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1982 armnn::DataType::QAsymmS8,
1983 outputScale,
1984 outputOffset);
1985
James Conroyb22a75e2020-06-08 14:53:10 +01001986 // Input tensors
1987 std::vector<int8_t> inputVector;
1988 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroyb22a75e2020-06-08 14:53:10 +01001989
1990 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroyb22a75e2020-06-08 14:53:10 +01001991
1992 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
James Conroyb22a75e2020-06-08 14:53:10 +01001993
1994 // Output tensors
1995 std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
James Conroyb22a75e2020-06-08 14:53:10 +01001996
1997 std::vector<int8_t> outputVector;
1998 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01001999
2000 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroyb22a75e2020-06-08 14:53:10 +01002001
2002 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01002003 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002004 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002005 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002006 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002007 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002008
Finn Williamsc43de6a2020-08-27 11:13:25 +01002009 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2010 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002011 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002012 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
2013 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002014
2015 armnn::QLstmQueueDescriptor data;
2016 armnn::WorkloadInfo info;
2017
2018 // Add inputs and outputs to workload
2019 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2020 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2021 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2022
2023 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2024 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2025 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2026
2027 // Weights and bias tensor and quantization info
2028 armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
2029 armnn::DataType::QSymmS8,
2030 weightsScale,
2031 weightsOffset);
2032
2033 armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
2034 armnn::DataType::QSymmS8,
2035 weightsScale,
2036 weightsOffset);
2037
2038 armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);
2039
2040 armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
2041
2042 armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
2043 armnn::DataType::QSymmS8,
2044 projectionWeightsScale,
2045 0);
2046
2047 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01002048 std::vector<int8_t> inputToInputWeights =
2049 {64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13};
2050 std::vector<int8_t> inputToForgetWeights =
2051 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2052 std::vector<int8_t> inputToCellWeights =
2053 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2054 std::vector<int8_t> inputToOutputWeights =
2055 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
James Conroyb22a75e2020-06-08 14:53:10 +01002056
Sadik Armagan483c8112021-06-01 09:24:52 +01002057 std::vector<int8_t> recurrentToInputWeights = {-25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77};
2058 std::vector<int8_t> recurrentToForgetWeights = {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2059 std::vector<int8_t> recurrentToCellWeights = {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2060 std::vector<int8_t> recurrentToOutputWeights = {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
James Conroyb22a75e2020-06-08 14:53:10 +01002061
Sadik Armagan483c8112021-06-01 09:24:52 +01002062 std::vector<int32_t> inputGateBias = {644245, 3221226, 4724464, 8160438};
2063 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2064 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
2065 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
James Conroyb22a75e2020-06-08 14:53:10 +01002066
Sadik Armagan483c8112021-06-01 09:24:52 +01002067 std::vector<int16_t> inputLayerNormWeights = {3277, 6553, 9830, 16384};
2068 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2069 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
2070 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
James Conroyb22a75e2020-06-08 14:53:10 +01002071
Sadik Armagan483c8112021-06-01 09:24:52 +01002072 std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
James Conroyb22a75e2020-06-08 14:53:10 +01002073
James Conroy1f58f032021-04-27 17:13:27 +01002074 // ScopedTensorHandles
2075 armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo);
2076 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
2077 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
2078 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002079
James Conroy1f58f032021-04-27 17:13:27 +01002080 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo);
2081 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
2082 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
2083 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002084
James Conroy1f58f032021-04-27 17:13:27 +01002085 armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo);
2086 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
2087 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
2088 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002089
James Conroy1f58f032021-04-27 17:13:27 +01002090 armnn::ScopedTensorHandle inputLayerNormWeightsTensor(layerNormWeightsInfo);
2091 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
2092 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
2093 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002094
James Conroy1f58f032021-04-27 17:13:27 +01002095 armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002096
2097 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01002098 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
2099 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
2100 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
2101 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002102
Sadik Armagan483c8112021-06-01 09:24:52 +01002103 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
2104 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
2105 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
2106 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002107
Sadik Armagan483c8112021-06-01 09:24:52 +01002108 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
2109 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
2110 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
2111 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002112
Sadik Armagan483c8112021-06-01 09:24:52 +01002113 AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data());
2114 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
2115 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
2116 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002117
Sadik Armagan483c8112021-06-01 09:24:52 +01002118 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002119
2120 // Setup queue descriptor
2121 data.m_InputToInputWeights = &inputToInputWeightsTensor;
2122 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
2123 data.m_InputToCellWeights = &inputToCellWeightsTensor;
2124 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
2125
2126 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
2127 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
2128 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
2129 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
2130
2131 data.m_InputGateBias = &inputGateBiasTensor;
2132 data.m_ForgetGateBias = &forgetGateBiasTensor;
2133 data.m_CellBias = &cellBiasTensor;
2134 data.m_OutputGateBias = &outputGateBiasTensor;
2135
2136 data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
2137 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
2138 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
2139 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
2140
2141 data.m_ProjectionWeights = &projectionWeightsTensor;
2142
2143 data.m_Parameters.m_CifgEnabled = cifgEnabled;
2144 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
2145 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
2146 data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
2147
2148 data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
2149 data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
2150 data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
2151 data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
2152
2153 data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
2154 data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
2155
2156 data.m_Parameters.m_CellClip = cellClip;
2157 data.m_Parameters.m_ProjectionClip = projectionClip;
2158
2159 // Create workload and allocate tensor handles
Teresa Charlin611c7fb2022-01-07 09:47:29 +00002160 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
James Conroyb22a75e2020-06-08 14:53:10 +01002161 inputHandle->Allocate();
2162 outputStateInHandle->Allocate();
2163 cellStateInHandle->Allocate();
2164
2165 outputStateOutHandle->Allocate();
2166 cellStateOutHandle->Allocate();
2167 outputHandle->Allocate();
2168
Sadik Armagan483c8112021-06-01 09:24:52 +01002169 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
2170 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
2171 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002172
2173 workload->Execute();
2174
Sadik Armagan483c8112021-06-01 09:24:52 +01002175 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroyb22a75e2020-06-08 14:53:10 +01002176
Sadik Armagan483c8112021-06-01 09:24:52 +01002177 return LayerTestResult<int8_t, 2>(actualOutput,
2178 outputVector,
2179 outputHandle->GetShape(),
2180 outputStateInfo.GetShape());
James Conroyb22a75e2020-06-08 14:53:10 +01002181}
2182
2183// QLSTM: Projection, CIFG, LayerNorm
2184LayerTestResult<int8_t, 2> QLstmTestImpl2(
2185 armnn::IWorkloadFactory& workloadFactory,
2186 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002187 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01002188 const std::vector<int8_t>& input,
2189 const std::vector<int8_t>& outputExpected)
James Conroyb22a75e2020-06-08 14:53:10 +01002190{
2191 IgnoreUnused(memoryManager);
2192 unsigned int numBatches = 2;
2193 unsigned int inputSize = 5;
2194 unsigned int outputSize = 3;
2195 unsigned int numUnits = 4;
2196
2197 bool cifgEnabled = true;
2198 bool peepholeEnabled = false;
2199 bool projectionEnabled = true;
2200 bool layerNormEnabled = true;
2201
2202 // Scale/Offset quantization info
2203 float inputScale = 0.0078125f;
2204 int32_t inputOffset = 0;
2205
2206 int32_t hiddenStateZeroPoint = 0;
2207 float hiddenStateScale = 0.007f;
2208
2209 // if (!projectionEnabled) outputScale == hiddenStateScale
2210 float outputScale = 3.05176e-05f;
2211 int32_t outputOffset = 0;
2212
2213 float cellStateScale = 3.05176e-05f;
2214 int32_t cellStateOffset = 0;
2215
2216 float weightsScale = 0.00784314f;
2217 int32_t weightsOffset = 0;
2218
2219 float layerNormScale = 3.05182e-05f;
2220 int32_t layerNormOffset = 0;
2221
2222 float biasScale = layerNormScale / 1024;
2223 int32_t biasOffset = 0;
2224
2225 float projectionWeightsScale = 0.00392157f;
2226
2227 float inputIntermediateScale = 0.007059f;
2228 float forgetIntermediateScale = 0.007812f;
2229 float cellIntermediateScale = inputIntermediateScale;
2230 float outputIntermediateScale = forgetIntermediateScale;
2231
2232 float cellClip = 0.0f;
2233 float projectionClip = 0.0f;
2234
2235 // Input/Output tensor info
2236 armnn::TensorInfo inputInfo({numBatches , inputSize},
2237 armnn::DataType::QAsymmS8,
2238 inputScale,
2239 inputOffset);
2240
2241 armnn::TensorInfo cellStateInfo({numBatches , numUnits},
2242 armnn::DataType::QSymmS16,
2243 cellStateScale,
2244 cellStateOffset);
2245
2246 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
2247 armnn::DataType::QAsymmS8,
2248 outputScale,
2249 outputOffset);
2250
James Conroyb22a75e2020-06-08 14:53:10 +01002251 // Input tensors
2252 std::vector<int8_t> inputVector;
2253 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroyb22a75e2020-06-08 14:53:10 +01002254
2255 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroyb22a75e2020-06-08 14:53:10 +01002256
2257 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
James Conroyb22a75e2020-06-08 14:53:10 +01002258
2259 // Output tensors
Sadik Armagan483c8112021-06-01 09:24:52 +01002260 std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
James Conroyb22a75e2020-06-08 14:53:10 +01002261
2262 std::vector<int8_t> outputVector;
2263 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01002264
2265 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroyb22a75e2020-06-08 14:53:10 +01002266
2267 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01002268 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002269 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002270 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002271 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002272 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002273
Finn Williamsc43de6a2020-08-27 11:13:25 +01002274 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2275 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002276 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002277 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
2278 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002279
2280 armnn::QLstmQueueDescriptor data;
2281 armnn::WorkloadInfo info;
2282
2283 // Add inputs and outputs to workload
2284 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2285 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2286 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2287
2288 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2289 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2290 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2291
2292 // Weights and bias tensor and quantization info
2293 armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
2294 armnn::DataType::QSymmS8,
2295 weightsScale,
2296 weightsOffset);
2297
2298 armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
2299 armnn::DataType::QSymmS8,
2300 weightsScale,
2301 weightsOffset);
2302
2303 armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);
2304
2305 armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
2306
2307 armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
2308 armnn::DataType::QSymmS8,
2309 projectionWeightsScale,
2310 0);
2311
2312 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01002313 std::vector<int8_t> inputToForgetWeights =
2314 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2315 std::vector<int8_t> inputToCellWeights =
2316 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2317 std::vector<int8_t> inputToOutputWeights =
2318 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
James Conroyb22a75e2020-06-08 14:53:10 +01002319
Sadik Armagan483c8112021-06-01 09:24:52 +01002320 std::vector<int8_t> recurrentToForgetWeights =
2321 {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2322 std::vector<int8_t> recurrentToCellWeights =
2323 {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2324 std::vector<int8_t> recurrentToOutputWeights =
2325 {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
James Conroyb22a75e2020-06-08 14:53:10 +01002326
Sadik Armagan483c8112021-06-01 09:24:52 +01002327 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2328 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
2329 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
James Conroyb22a75e2020-06-08 14:53:10 +01002330
Sadik Armagan483c8112021-06-01 09:24:52 +01002331 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2332 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
2333 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
James Conroyb22a75e2020-06-08 14:53:10 +01002334
Sadik Armagan483c8112021-06-01 09:24:52 +01002335 std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
James Conroyb22a75e2020-06-08 14:53:10 +01002336
James Conroy1f58f032021-04-27 17:13:27 +01002337 // ScopedTensorHandles
2338 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
2339 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
2340 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002341
James Conroy1f58f032021-04-27 17:13:27 +01002342 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
2343 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
2344 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002345
James Conroy1f58f032021-04-27 17:13:27 +01002346 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
2347 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
2348 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002349
James Conroy1f58f032021-04-27 17:13:27 +01002350 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
2351 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
2352 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002353
James Conroy1f58f032021-04-27 17:13:27 +01002354 armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002355
2356 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01002357 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
2358 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
2359 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002360
Sadik Armagan483c8112021-06-01 09:24:52 +01002361 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
2362 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
2363 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002364
Sadik Armagan483c8112021-06-01 09:24:52 +01002365 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
2366 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
2367 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002368
Sadik Armagan483c8112021-06-01 09:24:52 +01002369 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
2370 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
2371 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002372
Sadik Armagan483c8112021-06-01 09:24:52 +01002373 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002374
2375 // Setup queue descriptor
2376 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
2377 data.m_InputToCellWeights = &inputToCellWeightsTensor;
2378 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
2379
2380 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
2381 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
2382 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
2383
2384 data.m_ForgetGateBias = &forgetGateBiasTensor;
2385 data.m_CellBias = &cellBiasTensor;
2386 data.m_OutputGateBias = &outputGateBiasTensor;
2387
2388 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
2389 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
2390 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
2391
2392 data.m_ProjectionWeights = &projectionWeightsTensor;
2393
2394 data.m_Parameters.m_CifgEnabled = cifgEnabled;
2395 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
2396 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
2397 data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
2398
2399 data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
2400 data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
2401 data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
2402 data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
2403
2404 data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
2405 data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
2406
2407 data.m_Parameters.m_CellClip = cellClip;
2408 data.m_Parameters.m_ProjectionClip = projectionClip;
2409
2410 // Create workload and allocate tensor handles
Teresa Charlin611c7fb2022-01-07 09:47:29 +00002411 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::QLstm, data, info);
James Conroyb22a75e2020-06-08 14:53:10 +01002412 inputHandle->Allocate();
2413 outputStateInHandle->Allocate();
2414 cellStateInHandle->Allocate();
2415
2416 outputStateOutHandle->Allocate();
2417 cellStateOutHandle->Allocate();
2418 outputHandle->Allocate();
2419
Sadik Armagan483c8112021-06-01 09:24:52 +01002420 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
2421 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
2422 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002423
2424 workload->Execute();
2425
Sadik Armagan483c8112021-06-01 09:24:52 +01002426 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroyb22a75e2020-06-08 14:53:10 +01002427
Sadik Armagan483c8112021-06-01 09:24:52 +01002428 return LayerTestResult<int8_t, 2>(actualOutput,
2429 outputVector,
2430 outputHandle->GetShape(),
2431 outputStateInfo.GetShape());
James Conroyb22a75e2020-06-08 14:53:10 +01002432}
2433
James Conroy4f1f8992020-04-29 20:01:10 +01002434
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002435} // anonymous namespace
2436
2437#if defined(ARMNNREF_ENABLED)
2438
2439// The LSTM test units are run only for the reference backend at the moment
2440
2441void LstmUtilsZeroVectorTest()
2442{
2443 armnn::TensorInfo inputDesc({4}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002444 std::vector<float> input = {2., 3., 3., 4.};
2445 std::vector<float> expectedOutput = {0., 0., 0., 0.};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002446
Sadik Armagan483c8112021-06-01 09:24:52 +01002447 return LstmUtilsZeroVectorTestImpl<armnn::DataType::Float32>(input, 4, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002448}
2449
2450void LstmUtilsMeanStddevNormalizationNoneZeroInputTest()
2451{
2452 uint32_t batchSize = 2;
2453 uint32_t vecSize = 4;
2454 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002455 std::vector<float> input =
2456 { 0.1f, 0.2f, 0.3f, 0.4f, //batch 0
2457 0.9f, 1.0f, 1.1f, 1.2f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002458
Sadik Armagan483c8112021-06-01 09:24:52 +01002459 std::vector<float> expectedOutput =
2460 { -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f, //batch 0
2461 -1.34163153f, -0.447210163f, 0.447211236f, 1.3416326f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002462
2463 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002464 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002465}
2466
2467void LstmUtilsMeanStddevNormalizationAllZeroInputTest()
2468{
2469 uint32_t batchSize = 2;
2470 uint32_t vecSize = 4;
2471 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002472 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002473 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002474 0.0f, 0.0f, 0.0f, 0.0f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002475
Sadik Armagan483c8112021-06-01 09:24:52 +01002476 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002477 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002478 0.0f, 0.0f, 0.0f, 0.0f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002479
2480 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002481 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002482}
2483
2484void LstmUtilsMeanStddevNormalizationMixedZeroInputTest()
2485{
2486 uint32_t batchSize = 2;
2487 uint32_t vecSize = 4;
2488 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002489 std::vector<float> input =
2490 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
2491 0.1f, 0.2f, 0.3f, 0.4f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002492
Sadik Armagan483c8112021-06-01 09:24:52 +01002493 std::vector<float> expectedOutput =
2494 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
2495 -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002496
2497 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002498 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002499}
2500
2501void LstmUtilsVectorBatchVectorCwiseProductTest()
2502{
2503 uint32_t batchSize = 4;
2504 uint32_t vecSize = 29;
2505 armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002506 std::vector<float> vector =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002507 { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
2508 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002509 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002510
2511 armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002512 std::vector<float> batchVector =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002513 { /* batch 0 */
2514 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
2515 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
2516 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f,
2517 /* batch 1 */
2518 -1.1f, -2.2f, -3.3f, -4.4f, -5.5f, -6.6f, -7.7f, -8.8f, -9.9f, -10.1f,
2519 -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f, -19.19f, -20.2f,
2520 -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f, -27.27f, -28.28f, 0.0f,
2521 /* batch 2 */
2522 1.1f, -2.2f, 3.3f, -4.4f, 5.5f, -6.6f, 7.7f, -8.8f, 9.9f, -10.1f,
2523 11.11f, -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f, -20.2f,
2524 21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f, -28.28f, 0.0f,
2525 /* batch 3 */
2526 -1.1f, 2.2f, -3.3f, 4.4f, -5.5f, 6.6f, -7.7f, 8.8f, -9.9f, 10.1f,
2527 -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f, -19.19f, 20.2f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002528 -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f, -27.27f, 28.28f, 0.0f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002529
2530 // Expect output = input * output + output.
Sadik Armagan483c8112021-06-01 09:24:52 +01002531 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002532 { /* batch 0 */
2533 1.210000f, 4.840000f, 10.889999f, 19.360001f, 30.250000f, 43.559998f,
2534 59.289997f, 77.440002f, 98.009995f, 102.010010f, 123.432091f, 146.894394f,
2535 172.396896f, 199.939606f, 229.522491f, 261.145599f, 294.808899f, 330.512421f,
2536 368.256134f, 408.040039f, 449.864075f, 493.728363f, 539.632874f, 587.577576f,
2537 637.562500f, 689.587585f, 743.652954f, 799.758423f, 0.000000f,
2538 /* batch 1 */
2539 -1.210000f, -4.840000f, -10.889999f, -19.360001f, -30.250000f, -43.559998f,
2540 -59.289997f, -77.440002f, -98.009995f, -102.010010f, -123.432091f, -146.894394f,
2541 -172.396896f, -199.939606f, -229.522491f, -261.145599f, -294.808899f, -330.512421f,
2542 -368.256134f, -408.040039f, -449.864075f, -493.728363f, -539.632874f, -587.577576f,
2543 -637.562500f, -689.587585f, -743.652954f, -799.758423f, 0.000000f,
2544 /* batch 2 */
2545 1.210000f, -4.840000f, 10.889999f, -19.360001f, 30.250000f, -43.559998f,
2546 59.289997f, -77.440002f, 98.009995f, -102.010010f, 123.432091f, -146.894394f,
2547 172.396896f, -199.939606f, 229.522491f, -261.145599f, 294.808899f, -330.512421f,
2548 368.256134f, -408.040039f, 449.864075f, -493.728363f, 539.632874f, -587.577576f,
2549 637.562500f, -689.587585f, 743.652954f, -799.758423f, 0.000000f,
2550 /* batch 3 */
2551 -1.210000f, 4.840000f, -10.889999f, 19.360001f, -30.250000f, 43.559998f,
2552 -59.289997f, 77.440002f, -98.009995f, 102.010010f, -123.432091f, 146.894394f,
2553 -172.396896f, 199.939606f, -229.522491f, 261.145599f, -294.808899f, 330.512421f,
2554 -368.256134f, 408.040039f, -449.864075f, 493.728363f, -539.632874f, 587.577576f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002555 -637.562500f, 689.587585f, -743.652954f, 799.758423f, 0.000000f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002556
2557 return LstmUtilsVectorBatchVectorCwiseProductTestImpl<armnn::DataType::Float32>(vector, batchVector,
Sadik Armagan483c8112021-06-01 09:24:52 +01002558 vecSize, batchSize, expectedOutput, vecDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002559}
2560
2561void LstmUtilsVectorBatchVectorAddTest()
2562{
2563 uint32_t batchSize = 2;
2564 uint32_t vecSize = 3;
2565 armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002566 std::vector<float> vector = { 0.0f, -0.5f, 1.0f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002567
2568 armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002569 std::vector<float> batchVector =
2570 {
2571 1.0f, 2.0f, 3.0f, //batch 0
2572 4.0f, 5.0f, 6.0f //batch 1
2573 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002574
Sadik Armagan483c8112021-06-01 09:24:52 +01002575 std::vector<float> expectedOutput =
2576 {
2577 1.0f, 1.5f, 4.0f,
2578 4.0f, 4.5f, 7.0f
2579 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002580
2581 return LstmUtilsVectorBatchVectorAddTestImpl<armnn::DataType::Float32>(vector, batchVector,
Sadik Armagan483c8112021-06-01 09:24:52 +01002582 vecSize, batchSize, expectedOutput, batchVecDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002583}
2584
2585#endif
2586
2587LayerTestResult<float, 2> LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest(
2588 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002589 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2590 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002591{
2592 armnn::TensorInfo inputDesc({ 2, 2 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002593 std::vector<float> input = { 2., 3., 3., 4. };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002594
2595 armnn::TensorInfo outputDesc({ 2, 4 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002596 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002597 {-0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002598 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002599 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002600 workloadFactory, memoryManager, tensorHandleFactory,
2601 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002602}
2603
2604LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest(
2605 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002606 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2607 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002608{
2609 armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002610 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002611 {0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002612 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002613
2614 armnn::TensorInfo outputDesc({ 2, 16 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002615 std::vector<float> expectedOutput =
2616 {-0.00396806f, 0.029352f, -0.00279226f, 0.0159977f, -0.00835576f,
2617 -0.0211779f, 0.0283512f, -0.0114597f, 0.00907307f, -0.0244004f,
2618 -0.0152191f, -0.0259063f, 0.00914318f, 0.00415118f, 0.017147f,
2619 0.0134203f, -0.013869f, 0.0287268f, -0.00334693f, 0.00733398f, -0.0287926f,
2620 -0.0186926f, 0.0193662f, -0.0115437f, 0.00422612f, -0.0345232f,
2621 0.00223253f, -0.00957321f, 0.0210624f, 0.013331f, 0.0150954f, 0.02168f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002622 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<armnn::DataType::Float32>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002623 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002624}
2625
2626LayerTestResult<float, 2> LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest(
2627 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002628 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2629 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002630{
2631 armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002632 std::vector<float> input = {2., 3., 3., 4.};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002633
2634 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002635 std::vector<float> expectedOutput =
2636 {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f,
2637 -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002638
2639 return LstmNoCifgNoPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002640 workloadFactory, memoryManager, tensorHandleFactory,
2641 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002642}
2643
2644LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002645 armnn::IWorkloadFactory& workloadFactory,
2646 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2647 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002648{
2649 armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002650 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002651 {0.7f, 0.8f, 0.1f, 0.2f, 0.3f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002652 0.3f, 0.2f, 0.9f, 0.8f, 0.1f}; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002653
2654 armnn::TensorInfo outputDesc({ 2, 3 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002655 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002656 { 0.0244077f, 0.128027f, -0.00170918f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002657 -0.00692428f, 0.0848741f, 0.063445f}; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002658 return LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl<armnn::DataType::Float32>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002659 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002660}
2661
2662LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionTest(
2663 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002664 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2665 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002666{
2667 const float qScale = 1.0f;
2668 const int32_t qOffset = 0;
2669
Derek Lambertif90c56d2020-01-10 17:14:08 +00002670 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2671 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002672
2673 armnn::TensorInfo inputDesc({2, 2}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002674 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002675
2676 armnn::TensorInfo outputDesc({2, 4}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002677 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2678 {
2679 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2680 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2681 },
2682 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002683
2684 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002685 workloadFactory, memoryManager, tensorHandleFactory,
2686 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2687 qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002688
2689}
2690
2691LayerTestResult<int16_t, 2> LstmLayerInt16WithCifgWithPeepholeNoProjectionTest(
2692 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002693 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2694 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002695{
2696 const float qScale = 1.0f;
2697 const int32_t qOffset = 0;
2698
Derek Lambertif90c56d2020-01-10 17:14:08 +00002699 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2700 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002701
2702 armnn::TensorInfo inputDesc({ 2, 2 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002703 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002704
2705 armnn::TensorInfo outputDesc({ 2, 4 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002706 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2707 {
2708 -0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2709 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f
2710 },
2711 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002712
2713 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002714 workloadFactory, memoryManager, tensorHandleFactory,
2715 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2716 qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002717}
2718
2719LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgWithPeepholeWithProjectionTest(
2720 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002721 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2722 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002723{
2724 const float qScale = 2.0f;
2725 const int32_t qOffset = 0;
2726
Derek Lambertif90c56d2020-01-10 17:14:08 +00002727 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2728 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002729
2730 armnn::TensorInfo inputDesc({ 2, 5 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002731 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>(
2732 {
2733 0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2734 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f
2735 },
2736 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002737
2738 armnn::TensorInfo outputDesc({ 2, 16 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002739 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2740 {
2741 -0.00396806f, 0.02935200f, -0.00279226f, 0.01599770f,
2742 -0.00835576f, -0.02117790f, 0.02835120f, -0.01145970f,
2743 0.00907307f, -0.02440040f, -0.01521910f, -0.02590630f,
2744 0.00914318f, 0.00415118f, 0.01714700f, 0.01342030f,
2745 -0.01386900f, 0.02872680f, -0.00334693f, 0.00733398f,
2746 -0.02879260f, -0.01869260f, 0.01936620f, -0.01154370f,
2747 0.00422612f, -0.03452320f, 0.00223253f, -0.00957321f,
2748 0.02106240f, 0.01333100f, 0.01509540f, 0.02168000f
2749 },
2750 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002751
2752 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<datatype>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002753 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput, qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002754}
2755
2756LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16ConstantTest(
2757 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002758 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2759 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002760{
2761 const float qScale = 1.0f;
2762 const int32_t qOffset = 0;
2763
Derek Lambertif90c56d2020-01-10 17:14:08 +00002764 const armnn::DataType datatype = armnn::DataType::QSymmS16; // datatype & constants set to QSymm16
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002765
2766 armnn::TensorInfo inputDesc({2, 2}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002767 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002768
2769 armnn::TensorInfo outputDesc({2, 4}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002770 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2771 {
2772 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2773 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2774 },
2775 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002776
2777 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002778 workloadFactory, memoryManager, tensorHandleFactory,
2779 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2780 qScale, qOffset, datatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002781}
2782
2783//
2784// QuantizedLstm
2785//
2786
2787LayerTestResult<uint8_t, 2> QuantizedLstmTest(
2788 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002789 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2790 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002791{
Derek Lambertif90c56d2020-01-10 17:14:08 +00002792 armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::QAsymmU8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002793 std::vector<uint8_t> input = {166, 179, 50, 150};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002794
Derek Lambertif90c56d2020-01-10 17:14:08 +00002795 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmU8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002796 std::vector<uint8_t> expectedOutput = {140, 151, 146, 112, 136, 156, 142, 112 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002797
Sadik Armagan483c8112021-06-01 09:24:52 +01002798 return QuantizedLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory,
2799 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002800}
James Conroy4f1f8992020-04-29 20:01:10 +01002801
2802// QLSTM
2803LayerTestResult<int8_t, 2> QLstmTest(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002804 armnn::IWorkloadFactory& workloadFactory,
2805 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2806 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroy4f1f8992020-04-29 20:01:10 +01002807{
2808 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002809 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroy4f1f8992020-04-29 20:01:10 +01002810
2811 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002812 std::vector<int8_t> expectedOutput = {-15, 21, 14, 20, -15, 15, 5, 27};
James Conroy4f1f8992020-04-29 20:01:10 +01002813
Finn Williamsc43de6a2020-08-27 11:13:25 +01002814 return QLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroy4f1f8992020-04-29 20:01:10 +01002815}
James Conroyb22a75e2020-06-08 14:53:10 +01002816
2817LayerTestResult<int8_t, 2> QLstmTest1(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002818 armnn::IWorkloadFactory& workloadFactory,
2819 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2820 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroyb22a75e2020-06-08 14:53:10 +01002821{
2822 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002823 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroyb22a75e2020-06-08 14:53:10 +01002824
2825 armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002826 std::vector<int8_t> expectedOutput = {127, 127, -108, -67, 127, 127};
James Conroyb22a75e2020-06-08 14:53:10 +01002827
Finn Williamsc43de6a2020-08-27 11:13:25 +01002828 return QLstmTestImpl1(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroyb22a75e2020-06-08 14:53:10 +01002829}
2830
2831LayerTestResult<int8_t, 2> QLstmTest2(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002832 armnn::IWorkloadFactory& workloadFactory,
2833 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2834 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroyb22a75e2020-06-08 14:53:10 +01002835{
2836 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002837 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroyb22a75e2020-06-08 14:53:10 +01002838
2839 armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002840 std::vector<int8_t> expectedOutput = {127, 127, 127, -128, 127, 127};
James Conroyb22a75e2020-06-08 14:53:10 +01002841
Finn Williamsc43de6a2020-08-27 11:13:25 +01002842 return QLstmTestImpl2(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroyb22a75e2020-06-08 14:53:10 +01002843}