blob: 11003a2e97bbf6e43ab6683aab528f976a78ee97 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
Teresa Charlinfbf0e5b2020-08-17 01:01:06 +01002// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
telsoa01c577f2c2018-08-31 09:22:23 +01005
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01006#include "LstmTestImpl.hpp"
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00007
Aron Virginas-Tar48623a02019-10-22 10:00:28 +01008#include <QuantizeHelper.hpp>
9
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010011
James Conroy1f58f032021-04-27 17:13:27 +010012#include <backendsCommon/TensorHandle.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010013
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010014#include <backendsCommon/test/TensorCopyUtils.hpp>
15#include <backendsCommon/test/WorkloadTestUtils.hpp>
16
17#include <reference/workloads/Decoders.hpp>
18#include <reference/workloads/Encoders.hpp>
19#include <reference/workloads/LstmUtils.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010020
David Beckac42efd2018-09-26 17:41:13 +010021#include <test/TensorHelpers.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010022
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010023namespace
24{
Jan Eilers38e05bd2019-06-26 13:10:09 +010025
26template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
27void LstmUtilsVectorBatchVectorAddTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010028 std::vector<float>& vec,
29 std::vector<float>& batchVec,
Jan Eilers38e05bd2019-06-26 13:10:09 +010030 uint32_t vSize,
31 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +010032 std::vector<float>& expectedOutput,
33 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +010034{
35 float qScale = 0.0f;
36 int32_t qOffset = 0;
37 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
38
39 // Make encoder and decoder
40 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
41 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
42 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
43
44 VectorBatchVectorAdd(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder);
45
46 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +010047 auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
Colm Donelan25ab3a82021-05-17 13:01:52 +010048 BOOST_TEST(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +010049
50 // check if iterator is back at start position
51 batchVecEncoder->Set(1.0f);
Sadik Armagan483c8112021-06-01 09:24:52 +010052 BOOST_TEST(batchVec[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +010053}
54
55template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
56void LstmUtilsZeroVectorTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010057 std::vector<float>& input,
Jan Eilers38e05bd2019-06-26 13:10:09 +010058 uint32_t vSize,
Sadik Armagan483c8112021-06-01 09:24:52 +010059 std::vector<float>& expectedOutput,
60 armnn::TensorShape& expectedShape)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010061{
Jan Eilers38e05bd2019-06-26 13:10:09 +010062 float qScale = 0.0f;
63 int32_t qOffset = 0;
64
65 armnn::TensorInfo tensorInfo({vSize}, ArmnnType, qScale, qOffset );
66
67 // Make encoder for input
68 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
69
70 // call ZeroVector
71 ZeroVector(*outputEncoder, vSize);
72
73 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +010074 auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
Colm Donelan25ab3a82021-05-17 13:01:52 +010075 BOOST_TEST(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +010076
77 // check if iterator is back at start position
78 outputEncoder->Set(1.0f);
79 BOOST_TEST(input[0] == 1.0f);
80
81}
82
Jan Eilers38e05bd2019-06-26 13:10:09 +010083template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
84void LstmUtilsMeanStddevNormalizationTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +010085 std::vector<float>& input,
Jan Eilers38e05bd2019-06-26 13:10:09 +010086 uint32_t vSize,
87 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +010088 std::vector<float>& expectedOutput,
89 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +010090{
91 float qScale = 0.0f;
92 int32_t qOffset = 0;
93 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
94
95 // Make encoder and decoder for input
96 std::unique_ptr<armnn::Decoder<float>> inputDecoder = armnn::MakeDecoder<float>(tensorInfo, input.data());
97 std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data());
98
99 MeanStddevNormalization(*inputDecoder, *outputEncoder, vSize, nBatch, 1e-8f);
100
101 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +0100102 auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape);
Colm Donelan25ab3a82021-05-17 13:01:52 +0100103 BOOST_TEST(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100104
105 // check if iterator is back at start position
106 outputEncoder->Set(1.0f);
Sadik Armagan483c8112021-06-01 09:24:52 +0100107 BOOST_TEST(input[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100108}
109
110template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
111void LstmUtilsVectorBatchVectorCwiseProductTestImpl(
Sadik Armagan483c8112021-06-01 09:24:52 +0100112 std::vector<float>& vec,
113 std::vector<float>& batchVec,
Jan Eilers38e05bd2019-06-26 13:10:09 +0100114 uint32_t vSize,
115 uint32_t nBatch,
Sadik Armagan483c8112021-06-01 09:24:52 +0100116 std::vector<float>& expectedOutput,
117 armnn::TensorShape& expectedShape)
Jan Eilers38e05bd2019-06-26 13:10:09 +0100118{
119 float qScale = 0.0f;
120 int32_t qOffset = 0;
121 armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset );
122
123 // Make encoder and decoder
124 std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data());
125 std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data());
126 std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data());
127
128 VectorBatchVectorCwiseProduct(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder);
129
130 // check shape and compare values
Sadik Armagan483c8112021-06-01 09:24:52 +0100131 auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape);
Colm Donelan25ab3a82021-05-17 13:01:52 +0100132 BOOST_TEST(result.m_Result, result.m_Message.str());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100133
134 // check if iterator is back at start position
135 batchVecEncoder->Set(1.0f);
Sadik Armagan483c8112021-06-01 09:24:52 +0100136 BOOST_TEST(batchVec[0] == 1.0f);
Jan Eilers38e05bd2019-06-26 13:10:09 +0100137}
138
139// Lstm Layer tests:
James Conroy9c3cae82019-08-01 16:01:48 +0100140// *********************************** //
Conor Kennedyb9971c92019-05-07 07:14:23 +0100141template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
142LayerTestResult<T, 2>
143LstmNoCifgNoPeepholeNoProjectionTestImpl(
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000144 armnn::IWorkloadFactory& workloadFactory,
145 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100146 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +0100147 const std::vector<T>& input,
148 const std::vector<T>& outputExpected,
149 const armnn::TensorShape& inputShape,
150 const armnn::TensorShape& outputExpectedShape,
Conor Kennedyb9971c92019-05-07 07:14:23 +0100151 float qScale = 0.0f,
152 int32_t qOffset = 0,
153 armnn::DataType constantDataType = armnn::DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +0100154{
Jan Eilers8eb25602020-03-09 12:13:48 +0000155 IgnoreUnused(memoryManager);
Sadik Armagan483c8112021-06-01 09:24:52 +0100156 unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]);
157 unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
158 unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
telsoa01c577f2c2018-08-31 09:22:23 +0100159 // cellSize and outputSize have the same size when there is no projection.
160 unsigned numUnits = outputSize;
161
Conor Kennedyb9971c92019-05-07 07:14:23 +0100162 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset );
163 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
164 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100165
Conor Kennedyb9971c92019-05-07 07:14:23 +0100166 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
167 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
168 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
169 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100170
Rob Hughesbb46dde2020-05-20 15:27:37 +0100171 std::vector<T> inputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100172 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100173
Rob Hughesbb46dde2020-05-20 15:27:37 +0100174 std::vector<T> cellStateInVector(batchSize * numUnits, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100175 std::vector<T> outputStateInVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100176 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100177 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100178 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
Sadik Armagan483c8112021-06-01 09:24:52 +0100179
180 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
telsoa01c577f2c2018-08-31 09:22:23 +0100181
Rob Hughesbb46dde2020-05-20 15:27:37 +0100182 std::vector<T> outputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100183 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100184
Finn Williamsc43de6a2020-08-27 11:13:25 +0100185 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100186 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100187 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100188 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100189 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100190
Finn Williamsc43de6a2020-08-27 11:13:25 +0100191 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
192 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100193 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100194 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100195 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100196 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
197 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100198
199 armnn::LstmQueueDescriptor data;
200 armnn::WorkloadInfo info;
201
202 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
203 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
204 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
205
206 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
207 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
208 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
209 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
210
Conor Kennedyb9971c92019-05-07 07:14:23 +0100211 armnn::TensorInfo tensorInfo4({numUnits}, constantDataType , qScale, qOffset);
212 armnn::TensorInfo tensorInfo8({numUnits, 2}, constantDataType, qScale, qOffset);
213 armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100214
Sadik Armagan483c8112021-06-01 09:24:52 +0100215 std::vector<float> inputToInputWeights = {-0.45018822f, -0.02338299f, -0.0870589f,
216 -0.34550029f, 0.04266912f, -0.15680569f,
217 -0.34856534f, 0.43890524f};
telsoa01c577f2c2018-08-31 09:22:23 +0100218
Sadik Armagan483c8112021-06-01 09:24:52 +0100219 std::vector<float> inputToForgetWeights = { 0.09701663f, 0.20334584f, -0.50592935f,
220 -0.31343272f, -0.40032279f, 0.44781327f,
221 0.01387155f, -0.35593212f};
telsoa01c577f2c2018-08-31 09:22:23 +0100222
Sadik Armagan483c8112021-06-01 09:24:52 +0100223 std::vector<float> inputToCellWeights = { -0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f,
224 -0.20583314f, 0.44344562f, 0.22077113f,
225 -0.29909778f};
telsoa01c577f2c2018-08-31 09:22:23 +0100226
Sadik Armagan483c8112021-06-01 09:24:52 +0100227 std::vector<float> inputToOutputWeights = { -0.25065863f, -0.28290087f, 0.04613829f,
228 0.40525138f, 0.44272184f, 0.03897077f,
229 -0.1556896f, 0.19487578f};
telsoa01c577f2c2018-08-31 09:22:23 +0100230
Sadik Armagan483c8112021-06-01 09:24:52 +0100231 std::vector<float> recurrentToInputWeights = {-0.0063535f, -0.2042388f, 0.31454784f,
232 -0.35746509f, 0.28902304f, 0.08183324f,
233 -0.16555229f, 0.02286911f, -0.13566875f,
234 0.03034258f, 0.48091322f, -0.12528998f,
235 0.24077177f, -0.51332325f, -0.33502164f,
236 0.10629296f};
telsoa01c577f2c2018-08-31 09:22:23 +0100237
Sadik Armagan483c8112021-06-01 09:24:52 +0100238 std::vector<float> recurrentToForgetWeights = { -0.48684245f, -0.06655136f, 0.42224967f,
239 0.2112639f, 0.27654213f, 0.20864892f,
240 -0.07646349f, 0.45877004f, 0.00141793f,
241 -0.14609534f, 0.36447752f, 0.09196436f,
242 0.28053468f, 0.01560611f, -0.20127171f,
243 -0.01140004f};
telsoa01c577f2c2018-08-31 09:22:23 +0100244
Sadik Armagan483c8112021-06-01 09:24:52 +0100245 std::vector<float> recurrentToCellWeights = { -0.3407414f, 0.24443203f, -0.2078532f,
246 0.26320225f, 0.05695659f, -0.00123841f,
247 -0.4744786f, -0.35869038f, -0.06418842f,
248 -0.13502428f, -0.501764f, 0.22830659f,
249 -0.46367589f, 0.26016325f, -0.03894562f,
250 -0.16368064f};
telsoa01c577f2c2018-08-31 09:22:23 +0100251
Sadik Armagan483c8112021-06-01 09:24:52 +0100252 std::vector<float> recurrentToOutputWeights = { 0.43385774f, -0.17194885f, 0.2718237f,
253 0.09215671f, 0.24107647f, -0.39835793f,
254 0.18212086f, 0.01301402f, 0.48572797f,
255 -0.50656658f, 0.20047462f, -0.20607421f,
256 -0.51818722f, -0.15390486f, 0.0468148f,
257 0.39922136f};
telsoa01c577f2c2018-08-31 09:22:23 +0100258
Sadik Armagan483c8112021-06-01 09:24:52 +0100259 std::vector<float> cellToInputWeights = {0., 0., 0., 0.};
telsoa01c577f2c2018-08-31 09:22:23 +0100260
Sadik Armagan483c8112021-06-01 09:24:52 +0100261 std::vector<float> inputGateBias = {0., 0., 0., 0.};
telsoa01c577f2c2018-08-31 09:22:23 +0100262
Sadik Armagan483c8112021-06-01 09:24:52 +0100263 std::vector<float> forgetGateBias = {1., 1., 1., 1.};
telsoa01c577f2c2018-08-31 09:22:23 +0100264
Sadik Armagan483c8112021-06-01 09:24:52 +0100265 std::vector<float> cellBias = {0., 0., 0., 0.};
telsoa01c577f2c2018-08-31 09:22:23 +0100266
Sadik Armagan483c8112021-06-01 09:24:52 +0100267 std::vector<float> outputGateBias = {0., 0., 0., 0.};
telsoa01c577f2c2018-08-31 09:22:23 +0100268
James Conroy1f58f032021-04-27 17:13:27 +0100269 armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo8);
270 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo8);
271 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo8);
272 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo8);
273 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo16);
274 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo16);
275 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo16);
276 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo16);
277 armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4);
278 armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4);
279 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4);
280 armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4);
281 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4);
telsoa01c577f2c2018-08-31 09:22:23 +0100282
Sadik Armagan483c8112021-06-01 09:24:52 +0100283 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
284 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
285 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
286 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
287 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
288 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
289 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
290 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
291 AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
292 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
293 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
294 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
295 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
telsoa01c577f2c2018-08-31 09:22:23 +0100296
297 data.m_InputToInputWeights = &inputToInputWeightsTensor;
298 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
299 data.m_InputToCellWeights = &inputToCellWeightsTensor;
300 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
301 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
302 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
303 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
304 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
telsoa01c577f2c2018-08-31 09:22:23 +0100305 data.m_InputGateBias = &inputGateBiasTensor;
306 data.m_ForgetGateBias = &forgetGateBiasTensor;
307 data.m_CellBias = &cellBiasTensor;
308 data.m_OutputGateBias = &outputGateBiasTensor;
309
telsoa01c577f2c2018-08-31 09:22:23 +0100310 // Flags to set test configuration
311 data.m_Parameters.m_ActivationFunc = 4;
312 data.m_Parameters.m_CifgEnabled = false;
313 data.m_Parameters.m_PeepholeEnabled = false;
314 data.m_Parameters.m_ProjectionEnabled = false;
315
telsoa01c577f2c2018-08-31 09:22:23 +0100316 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
317 inputHandle->Allocate();
318 outputStateInHandle->Allocate();
319 cellStateInHandle->Allocate();
320
321 scratchHandle->Allocate();
322 outputStateOutHandle->Allocate();
323 cellStateOutHandle->Allocate();
324 outputHandle->Allocate();
325
Sadik Armagan483c8112021-06-01 09:24:52 +0100326 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
327 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
328 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +0100329
telsoa01c577f2c2018-08-31 09:22:23 +0100330 workload->Execute();
331
Sadik Armagan483c8112021-06-01 09:24:52 +0100332 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
telsoa01c577f2c2018-08-31 09:22:23 +0100333
Sadik Armagan483c8112021-06-01 09:24:52 +0100334 return LayerTestResult<T, 2>(actualOutput,
335 outputVector,
336 outputHandle->GetShape(),
337 outputTensorInfo.GetShape());
telsoa01c577f2c2018-08-31 09:22:23 +0100338}
339
Conor Kennedyb9971c92019-05-07 07:14:23 +0100340template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
341LayerTestResult<T, 2>
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000342LstmLayerNoCifgWithPeepholeWithProjectionTestImpl(armnn::IWorkloadFactory& workloadFactory,
343 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +0100344 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +0100345 const std::vector<T>& input,
346 const std::vector<T>& outputExpected,
Conor Kennedyb9971c92019-05-07 07:14:23 +0100347 float qScale = 0.0f,
348 int32_t qOffset = 0,
349 armnn::DataType constantDataType = armnn::DataType::Float32)
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +0000350{
Jan Eilers8eb25602020-03-09 12:13:48 +0000351 IgnoreUnused(memoryManager);
telsoa01c577f2c2018-08-31 09:22:23 +0100352 unsigned int batchSize = 2;
353 unsigned int outputSize = 16;
354 unsigned int inputSize = 5;
355 unsigned numUnits = 20;
356
Conor Kennedyb9971c92019-05-07 07:14:23 +0100357 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
358 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
359 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100360
Matteo Martincigha65b7ae2018-11-14 12:39:55 +0000361 // Scratch buffer size without CIFG [batchSize, numUnits * 4]
Conor Kennedyb9971c92019-05-07 07:14:23 +0100362 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
363 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
364 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
365 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100366
Rob Hughesbb46dde2020-05-20 15:27:37 +0100367 std::vector<T> inputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100368 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100369
Rob Hughesbb46dde2020-05-20 15:27:37 +0100370 std::vector<T> cellStateInVector(batchSize * numUnits, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100371 std::vector<T> outputStateInVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100372 std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100373 std::vector<T> outputStateOutVector(batchSize * outputSize, T());
Rob Hughesbb46dde2020-05-20 15:27:37 +0100374 std::vector<T> cellStateOutVector(batchSize * numUnits, T());
Sadik Armagan483c8112021-06-01 09:24:52 +0100375
376 std::vector<T> actualOutput(outputTensorInfo.GetNumElements());
telsoa01c577f2c2018-08-31 09:22:23 +0100377
Rob Hughesbb46dde2020-05-20 15:27:37 +0100378 std::vector<T> outputVector;
telsoa01c577f2c2018-08-31 09:22:23 +0100379 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
telsoa01c577f2c2018-08-31 09:22:23 +0100380
Finn Williamsc43de6a2020-08-27 11:13:25 +0100381 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100382 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100383 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100384 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100385 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100386
Finn Williamsc43de6a2020-08-27 11:13:25 +0100387 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
388 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100389 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100390 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100391 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +0100392 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
393 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
telsoa01c577f2c2018-08-31 09:22:23 +0100394
395 armnn::LstmQueueDescriptor data;
396 armnn::WorkloadInfo info;
397
398 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
399 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
400 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
David Beckac42efd2018-09-26 17:41:13 +0100401
telsoa01c577f2c2018-08-31 09:22:23 +0100402 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
403 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
404 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
405 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
406
Conor Kennedyb9971c92019-05-07 07:14:23 +0100407 armnn::TensorInfo tensorInfo16({outputSize}, constantDataType, qScale, qOffset);
408 armnn::TensorInfo tensorInfo20({numUnits}, constantDataType, qScale, qOffset);
409 armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
410 armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, constantDataType, qScale, qOffset);
411 armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, constantDataType, qScale, qOffset);
telsoa01c577f2c2018-08-31 09:22:23 +0100412
Sadik Armagan483c8112021-06-01 09:24:52 +0100413 std::vector<float> inputToInputWeights = {0.021393683f,0.06124551f, 0.046905167f,-0.014657677f,-0.03149463f,
414 0.09171803f, 0.14647801f,0.10797193f, -0.0057968358f,0.0019193048f,
415 -0.2726754f, 0.10154029f, -0.018539885f, 0.080349885f, -0.10262385f,
416 -0.022599787f,-0.09121155f, -0.008675967f, -0.045206103f,-0.0821282f,
417 -0.008045952f,0.015478081f, 0.055217247f, 0.038719587f, 0.044153627f,
418 -0.06453243f,0.05031825f, -0.046935108f, -0.008164439f, 0.014574226f,
419 -0.1671009f, -0.15519552f, -0.16819797f,-0.13971269f,-0.11953059f,
420 0.25005487f, -0.22790983f, 0.009855087f, -0.028140958f, -0.11200698f,
421 0.11295408f, -0.0035217577f, 0.054485075f, 0.05184695f, 0.064711206f,
422 0.10989193f, 0.11674786f, 0.03490607f, 0.07727357f, 0.11390585f,
423 -0.1863375f, -0.1034451f, -0.13945189f, -0.049401227f, -0.18767063f,
424 0.042483903f, 0.14233552f, 0.13832581f, 0.18350165f, 0.14545603f,
425 -0.028545704f,0.024939531f,0.050929718f,0.0076203286f,-0.0029723682f,
426 -0.042484224f, -0.11827596f, -0.09171104f, -0.10808628f,-0.16327988f,
427 -0.2273378f, -0.0993647f, -0.017155107f,0.0023917493f,0.049272764f,
428 0.0038534778f, 0.054764505f, 0.089753784f, 0.06947234f, 0.08014476f,
429 -0.04544234f, -0.0497073f,-0.07135631f, -0.048929106f,-0.004042012f,
430 -0.009284026f, 0.018042054f, 0.0036860977f,-0.07427302f, -0.11434604f,
431 -0.018995456f, 0.031487543f, 0.012834908f,0.019977754f,0.044256654f,
432 -0.39292613f, -0.18519334f, -0.11651281f,-0.06809892f, 0.011373677f };
telsoa01c577f2c2018-08-31 09:22:23 +0100433
Sadik Armagan483c8112021-06-01 09:24:52 +0100434 std::vector<float> inputToForgetWeights = {-0.0018401089f, -0.004852237f,0.03698424f, 0.014181704f,0.028273236f,
435 -0.016726194f, -0.05249759f,-0.10204261f, 0.00861066f,-0.040979505f,
436 -0.009899187f,0.01923892f,-0.028177269f, -0.08535103f,-0.14585495f,
437 0.10662567f,-0.01909731f,-0.017883534f,-0.0047269356f,-0.045103323f,
438 0.0030784295f,0.076784775f,0.07463696f, 0.094531395f,0.0814421f,
439 -0.12257899f, -0.033945758f,-0.031303465f, 0.045630626f,0.06843887f,
440 -0.13492945f, -0.012480007f,-0.0811829f, -0.07224499f,-0.09628791f,
441 0.045100946f,0.0012300825f, 0.013964662f, 0.099372394f,0.02543059f,
442 0.06958324f, 0.034257296f, 0.0482646f, 0.06267997f,0.052625068f,
443 0.12784666f, 0.07077897f, 0.025725935f, 0.04165009f,0.07241905f,
444 0.018668644f, -0.037377294f,-0.06277783f,-0.08833636f,-0.040120605f,
445 -0.011405586f,-0.007808335f,-0.010301386f,-0.005102167f,0.027717464f,
446 0.05483423f, 0.11449111f, 0.11289652f,0.10939839f, 0.13396506f,
447 -0.08402166f,-0.01901462f, -0.044678304f,-0.07720565f,0.014350063f,
448 -0.11757958f, -0.0652038f, -0.08185733f,-0.076754324f,-0.092614375f,
449 0.10405491f, 0.052960336f, 0.035755895f,0.035839386f,-0.012540553f,
450 0.036881298f, 0.02913376f, 0.03420159f,0.05448447f,-0.054523353f,
451 0.02582715f, 0.02327355f, -0.011857179f,-0.0011980024f,-0.034641717f,
452 -0.026125094f,-0.17582615f,-0.15923657f,-0.27486774f,-0.0006143371f,
453 0.0001771948f, -8.470171e-05f, 0.02651807f,0.045790765f,0.06956496f };
telsoa01c577f2c2018-08-31 09:22:23 +0100454
Sadik Armagan483c8112021-06-01 09:24:52 +0100455 std::vector<float> inputToCellWeights = { -0.04580283f, -0.09549462f, -0.032418985f, -0.06454633f,
456 -0.043528453f, 0.043018587f, -0.049152344f, -0.12418144f,
457 -0.078985475f, -0.07596889f, 0.019484362f, -0.11434962f,
458 -0.0074034138f, -0.06314844f, -0.092981495f, 0.0062155537f,
459 -0.025034338f, -0.0028890965f, 0.048929527f, 0.06235075f,
460 0.10665918f, -0.032036792f, -0.08505916f, -0.10843358f,
461 -0.13002433f, -0.036816437f, -0.02130134f, -0.016518239f,
462 0.0047691227f, -0.0025825808f, 0.066017866f, 0.029991534f,
463 -0.10652836f, -0.1037554f, -0.13056071f, -0.03266643f,
464 -0.033702414f, -0.006473424f, -0.04611692f, 0.014419339f,
465 -0.025174323f, 0.0396852f, 0.081777506f, 0.06157468f,
466 0.10210095f, -0.009658194f, 0.046511717f, 0.03603906f,
467 0.0069369148f, 0.015960095f, -0.06507666f, 0.09551598f,
468 0.053568836f, 0.06408714f, 0.12835667f, -0.008714329f,
469 -0.20211966f, -0.12093674f, 0.029450472f, 0.2849013f,
470 -0.029227901f, 0.1164364f, -0.08560263f, 0.09941786f,
471 -0.036999565f, -0.028842626f, -0.0033637602f, -0.017012902f,
472 -0.09720865f, -0.11193351f, -0.029155117f, -0.017936034f,
473 -0.009768936f, -0.04223324f, -0.036159635f, 0.06505112f,
474 -0.021742892f, -0.023377212f, -0.07221364f, -0.06430552f,
475 0.05453865f, 0.091149814f, 0.06387331f, 0.007518393f,
476 0.055960953f, 0.069779344f, 0.046411168f, 0.10509911f,
477 0.07463894f, 0.0075130584f, 0.012850982f, 0.04555431f,
478 0.056955688f, 0.06555285f, 0.050801456f, -0.009862683f,
479 0.00826772f, -0.026555609f, -0.0073611983f, -0.0014897042f };
telsoa01c577f2c2018-08-31 09:22:23 +0100480
Sadik Armagan483c8112021-06-01 09:24:52 +0100481 std::vector<float> inputToOutputWeights ={-0.0998932f, -0.07201956f, -0.052803773f,-0.15629593f,-0.15001918f,
482 -0.07650751f,0.02359855f, -0.075155355f, -0.08037709f, -0.15093534f,
483 0.029517552f, -0.04751393f, 0.010350531f,-0.02664851f, -0.016839722f,
484 -0.023121163f, 0.0077019283f, 0.012851257f, -0.05040649f,-0.0129761f,
485 -0.021737747f,-0.038305793f,-0.06870586f, -0.01481247f,-0.001285394f,
486 0.10124236f, 0.083122835f, 0.053313006f,-0.062235646f,-0.075637154f,
487 -0.027833903f, 0.029774971f, 0.1130802f, 0.09218906f, 0.09506135f,
488 -0.086665764f,-0.037162706f,-0.038880914f,-0.035832845f,-0.014481564f,
489 -0.09825003f,-0.12048569f,-0.097665586f,-0.05287633f, -0.0964047f,
490 -0.11366429f, 0.035777505f, 0.13568819f, 0.052451383f,0.050649304f,
491 0.05798951f, -0.021852335f,-0.099848844f,0.014740475f,-0.078897946f,
492 0.04974699f, 0.014160473f, 0.06973932f, 0.04964942f, 0.033364646f,
493 0.08190124f, 0.025535367f, 0.050893165f, 0.048514254f,0.06945813f,
494 -0.078907564f,-0.06707616f, -0.11844508f, -0.09986688f,-0.07509403f,
495 0.06263226f, 0.14925587f, 0.20188436f, 0.12098451f,0.14639415f,
496 0.0015017595f, -0.014267382f, -0.03417257f,0.012711468f,0.0028300495f,
497 -0.024758482f, -0.05098548f,-0.0821182f, 0.014225672f, 0.021544158f,
498 0.08949725f, 0.07505268f, -0.0020780868f, 0.04908258f,0.06476295f,
499 -0.022907063f,0.027562456f,0.040185735f, 0.019567577f,-0.015598739f,
500 -0.049097303f, -0.017121866f, -0.083368234f,-0.02332002f,-0.0840956f };
telsoa01c577f2c2018-08-31 09:22:23 +0100501
Sadik Armagan483c8112021-06-01 09:24:52 +0100502 std::vector<float> inputGateBias = {0.02234832f, 0.14757581f, 0.18176508f, 0.10380666f, 0.053110216f,
503 -0.06928846f, -0.13942584f, -0.11816189f, 0.19483899f, 0.03652339f,
504 -0.10250295f, 0.036714908f, -0.18426876f, 0.036065217f, 0.21810818f,
505 0.02383196f, -0.043370757f, 0.08690144f, -0.04444982f, 0.00030581196f };
telsoa01c577f2c2018-08-31 09:22:23 +0100506
Sadik Armagan483c8112021-06-01 09:24:52 +0100507 std::vector<float> forgetGateBias ={0.035185695f, -0.042891346f, -0.03032477f, 0.23027696f,
508 0.11098921f, 0.15378423f, 0.09263801f, 0.09790885f,
509 0.09508917f, 0.061199076f, 0.07665568f, -0.015443159f,
510 -0.03499149f, 0.046190713f, 0.08895977f, 0.10899629f,
511 0.40694186f, 0.06030037f, 0.012413437f, -0.06108739f };
telsoa01c577f2c2018-08-31 09:22:23 +0100512
Sadik Armagan483c8112021-06-01 09:24:52 +0100513 std::vector<float> cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f, 0.033463873f,
514 -0.1483596f, -0.10639995f, -0.091433935f, 0.058573797f,
515 -0.06809782f, -0.07889636f, -0.043246906f, -0.09829136f,
516 -0.4279842f, 0.034901652f, 0.18797937f, 0.0075234566f,
517 0.016178843f, 0.1749513f, 0.13975595f, 0.92058027f };
telsoa01c577f2c2018-08-31 09:22:23 +0100518
Sadik Armagan483c8112021-06-01 09:24:52 +0100519 std::vector<float> outputGateBias ={0.046159424f, -0.0012809046f, 0.03563469f, 0.12648113f, 0.027195795f,
520 0.35373217f, -0.018957434f, 0.008907322f, -0.0762701f, 0.12018895f,
521 0.04216877f, 0.0022856654f, 0.040952638f, 0.3147856f, 0.08225149f,
522 -0.057416286f, -0.14995944f, -0.008040261f, 0.13208859f, 0.029760877f};
telsoa01c577f2c2018-08-31 09:22:23 +0100523
Sadik Armagan483c8112021-06-01 09:24:52 +0100524 std::vector<float> recurrentToInputWeights = { -0.001374326f, -0.078856036f, 0.10672688f, 0.029162422f,
telsoa01c577f2c2018-08-31 09:22:23 +0100525 -0.11585556f, 0.02557986f, -0.13446963f, -0.035785314f,
526 -0.01244275f, 0.025961924f, -0.02337298f, -0.044228926f,
527 -0.055839065f, -0.046598054f, -0.010546039f, -0.06900766f,
528 0.027239809f, 0.022582639f, -0.013296484f, -0.05459212f,
529 0.08981f, -0.045407712f, 0.08682226f, -0.06867011f,
530 -0.14390695f, -0.02916037f, 0.000996957f, 0.091420636f,
531 0.14283475f, -0.07390571f, -0.06402044f, 0.062524505f,
532 -0.093129106f, 0.04860203f, -0.08364217f, -0.08119002f,
533 0.009352075f, 0.22920375f, 0.0016303885f, 0.11583097f,
534 -0.13732095f, 0.012405723f, -0.07551853f, 0.06343048f,
535 0.12162708f, -0.031923793f, -0.014335606f, 0.01790974f,
536 -0.10650317f, -0.0724401f, 0.08554849f, -0.05727212f,
537 0.06556731f, -0.042729504f, -0.043227166f, 0.011683251f,
538 -0.013082158f, -0.029302018f, -0.010899579f, -0.062036745f,
539 -0.022509435f, -0.00964907f, -0.01567329f, 0.04260106f,
540 -0.07787477f, -0.11576462f, 0.017356863f, 0.048673786f,
541 -0.017577527f, -0.05527947f, -0.082487635f, -0.040137455f,
542 -0.10820036f, -0.04666372f, 0.022746278f, -0.07851417f,
543 0.01068115f, 0.032956902f, 0.022433773f, 0.0026891115f,
544 0.08944216f, -0.0685835f, 0.010513544f, 0.07228705f,
545 0.02032331f, -0.059686817f, -0.0005566496f, -0.086984694f,
546 0.040414046f, -0.1380399f, 0.094208956f, -0.05722982f,
547 0.012092817f, -0.04989123f, -0.086576f, -0.003399834f,
548 -0.04696032f, -0.045747425f, 0.10091314f, 0.048676282f,
549 -0.029037097f, 0.031399418f, -0.0040285117f, 0.047237843f,
550 0.09504992f, 0.041799378f, -0.049185462f, -0.031518843f,
551 -0.10516937f, 0.026374253f, 0.10058866f, -0.0033195973f,
552 -0.041975245f, 0.0073591834f, 0.0033782164f, -0.004325073f,
553 -0.10167381f, 0.042500053f, -0.01447153f, 0.06464186f,
554 -0.017142897f, 0.03312627f, 0.009205989f, 0.024138335f,
555 -0.011337001f, 0.035530265f, -0.010912711f, 0.0706555f,
556 -0.005894094f, 0.051841937f, -0.1401738f, -0.02351249f,
557 0.0365468f, 0.07590991f, 0.08838724f, 0.021681072f,
558 -0.10086113f, 0.019608743f, -0.06195883f, 0.077335775f,
559 0.023646897f, -0.095322326f, 0.02233014f, 0.09756986f,
560 -0.048691444f, -0.009579111f, 0.07595467f, 0.11480546f,
561 -0.09801813f, 0.019894179f, 0.08502348f, 0.004032281f,
562 0.037211012f, 0.068537936f, -0.048005626f, -0.091520436f,
563 -0.028379958f, -0.01556313f, 0.06554592f, -0.045599163f,
564 -0.01672207f, -0.020169014f, -0.011877351f, -0.20212261f,
565 0.010889619f, 0.0047078193f, 0.038385306f, 0.08540671f,
566 -0.017140968f, -0.0035865551f, 0.016678626f, 0.005633034f,
567 0.015963363f, 0.00871737f, 0.060130805f, 0.028611384f,
568 0.10109069f, -0.015060172f, -0.07894427f, 0.06401885f,
569 0.011584063f, -0.024466386f, 0.0047652307f, -0.09041358f,
570 0.030737216f, -0.0046374933f, 0.14215417f, -0.11823516f,
571 0.019899689f, 0.006106124f, -0.027092824f, 0.0786356f,
572 0.05052217f, -0.058925f, -0.011402121f, -0.024987547f,
573 -0.0013661642f, -0.06832946f, -0.015667673f, -0.1083353f,
574 -0.00096863037f, -0.06988685f, -0.053350925f, -0.027275559f,
575 -0.033664223f, -0.07978348f, -0.025200296f, -0.017207067f,
576 -0.058403496f, -0.055697463f, 0.005798788f, 0.12965427f,
577 -0.062582195f, 0.0013350133f, -0.10482091f, 0.0379771f,
578 0.072521195f, -0.0029455067f, -0.13797039f, -0.03628521f,
579 0.013806405f, -0.017858358f, -0.01008298f, -0.07700066f,
580 -0.017081132f, 0.019358726f, 0.0027079724f, 0.004635139f,
581 0.062634714f, -0.02338735f, -0.039547626f, -0.02050681f,
582 0.03385117f, -0.083611414f, 0.002862572f, -0.09421313f,
583 0.058618143f, -0.08598433f, 0.00972939f, 0.023867095f,
584 -0.053934585f, -0.023203006f, 0.07452513f, -0.048767887f,
585 -0.07314807f, -0.056307215f, -0.10433547f, -0.06440842f,
586 0.04328182f, 0.04389765f, -0.020006588f, -0.09076438f,
587 -0.11652589f, -0.021705797f, 0.03345259f, -0.010329105f,
588 -0.025767034f, 0.013057034f, -0.07316461f, -0.10145612f,
589 0.06358255f, 0.18531723f, 0.07759293f, 0.12006465f,
590 0.1305557f, 0.058638252f, -0.03393652f, 0.09622831f,
591 -0.16253184f, -2.4580743e-06f, 0.079869635f, -0.070196845f,
592 -0.005644518f, 0.06857898f, -0.12598175f, -0.035084512f,
593 0.03156317f, -0.12794146f, -0.031963028f, 0.04692781f,
594 0.030070418f, 0.0071660685f, -0.095516115f, -0.004643372f,
595 0.040170413f, -0.062104587f, -0.0037324072f, 0.0554317f,
596 0.08184801f, -0.019164372f, 0.06791302f, 0.034257166f,
597 -0.10307039f, 0.021943003f, 0.046745934f, 0.0790918f,
598 -0.0265588f, -0.007824208f, 0.042546265f, -0.00977924f,
599 -0.0002440307f, -0.017384544f, -0.017990116f, 0.12252321f,
600 -0.014512694f, -0.08251313f, 0.08861942f, 0.13589665f,
601 0.026351685f, 0.012641483f, 0.07466548f, 0.044301085f,
602 -0.045414884f, -0.051112458f, 0.03444247f, -0.08502782f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100603 -0.04106223f, -0.028126027f, 0.028473156f, 0.10467447f };
telsoa01c577f2c2018-08-31 09:22:23 +0100604
Sadik Armagan483c8112021-06-01 09:24:52 +0100605 std::vector<float> recurrentToForgetWeights = {-0.057784554f, -0.026057621f, -0.068447545f, -0.022581743f,
telsoa01c577f2c2018-08-31 09:22:23 +0100606 0.14811787f, 0.10826372f, 0.09471067f, 0.03987225f,
607 -0.0039523416f, 0.00030638507f, 0.053185795f, 0.10572994f,
608 0.08414449f, -0.022036452f, -0.00066928595f, -0.09203576f,
609 0.032950465f, -0.10985798f, -0.023809856f, 0.0021431844f,
610 -0.02196096f, -0.00326074f, 0.00058621005f, -0.074678116f,
611 -0.06193199f, 0.055729095f, 0.03736828f, 0.020123724f,
612 0.061878487f, -0.04729229f, 0.034919553f, -0.07585433f,
613 -0.04421272f, -0.044019096f, 0.085488975f, 0.04058006f,
614 -0.06890133f, -0.030951202f, -0.024628663f, -0.07672815f,
615 0.034293607f, 0.08556707f, -0.05293577f, -0.033561368f,
616 -0.04899627f, 0.0241671f, 0.015736353f, -0.095442444f,
617 -0.029564252f, 0.016493602f, -0.035026584f, 0.022337519f,
618 -0.026871363f, 0.004780428f, 0.0077918363f, -0.03601621f,
619 0.016435321f, -0.03263031f, -0.09543275f, -0.047392778f,
620 0.013454138f, 0.028934088f, 0.01685226f, -0.086110644f,
621 -0.046250615f, -0.01847454f, 0.047608484f, 0.07339695f,
622 0.034546845f, -0.04881143f, 0.009128804f, -0.08802852f,
623 0.03761666f, 0.008096139f, -0.014454086f, 0.014361001f,
624 -0.023502491f, -0.0011840804f, -0.07607001f, 0.001856849f,
625 -0.06509276f, -0.006021153f, -0.08570962f, -0.1451793f,
626 0.060212336f, 0.055259194f, 0.06974018f, 0.049454916f,
627 -0.027794661f, -0.08077226f, -0.016179763f, 0.1169753f,
628 0.17213494f, -0.0056326236f, -0.053934924f, -0.0124349f,
629 -0.11520337f, 0.05409887f, 0.088759385f, 0.0019655675f,
630 0.0042065294f, 0.03881498f, 0.019844765f, 0.041858196f,
631 -0.05695512f, 0.047233116f, 0.038937137f, -0.06542224f,
632 0.014429736f, -0.09719407f, 0.13908425f, -0.05379757f,
633 0.012321099f, 0.082840554f, -0.029899208f, 0.044217527f,
634 0.059855383f, 0.07711018f, -0.045319796f, 0.0948846f,
635 -0.011724666f, -0.0033288454f, -0.033542685f, -0.04764985f,
636 -0.13873616f, 0.040668588f, 0.034832682f, -0.015319203f,
637 -0.018715994f, 0.046002675f, 0.0599172f, -0.043107376f,
638 0.0294216f, -0.002314414f, -0.022424703f, 0.0030315618f,
639 0.0014641669f, 0.0029166266f, -0.11878115f, 0.013738511f,
640 0.12375372f, -0.0006038222f, 0.029104086f, 0.087442465f,
641 0.052958444f, 0.07558703f, 0.04817258f, 0.044462286f,
642 -0.015213451f, -0.08783778f, -0.0561384f, -0.003008196f,
643 0.047060397f, -0.002058388f, 0.03429439f, -0.018839769f,
644 0.024734668f, 0.024614193f, -0.042046934f, 0.09597743f,
645 -0.0043254104f, 0.04320769f, 0.0064070094f, -0.0019131786f,
646 -0.02558259f, -0.022822596f, -0.023273505f, -0.02464396f,
647 -0.10991725f, -0.006240552f, 0.0074488563f, 0.024044557f,
648 0.04383914f, -0.046476185f, 0.028658995f, 0.060410924f,
649 0.050786525f, 0.009452605f, -0.0073054377f, -0.024810238f,
650 0.0052906186f, 0.0066939713f, -0.0020913032f, 0.014515517f,
651 0.015898481f, 0.021362653f, -0.030262267f, 0.016587038f,
652 -0.011442813f, 0.041154444f, -0.007631438f, -0.03423484f,
653 -0.010977775f, 0.036152758f, 0.0066366293f, 0.11915515f,
654 0.02318443f, -0.041350313f, 0.021485701f, -0.10906167f,
655 -0.028218046f, -0.00954771f, 0.020531068f, -0.11995105f,
656 -0.03672871f, 0.024019798f, 0.014255957f, -0.05221243f,
657 -0.00661567f, -0.04630967f, 0.033188973f, 0.10107534f,
658 -0.014027541f, 0.030796422f, -0.10270911f, -0.035999842f,
659 0.15443139f, 0.07684145f, 0.036571592f, -0.035900835f,
660 -0.0034699554f, 0.06209149f, 0.015920248f, -0.031122351f,
661 -0.03858649f, 0.01849943f, 0.13872518f, 0.01503974f,
662 0.069941424f, -0.06948533f, -0.0088794185f, 0.061282158f,
663 -0.047401894f, 0.03100163f, -0.041533746f, -0.10430945f,
664 0.044574402f, -0.01425562f, -0.024290353f, 0.034563623f,
665 0.05866852f, 0.023947537f, -0.09445152f, 0.035450947f,
666 0.02247216f, -0.0042998926f, 0.061146557f, -0.10250651f,
667 0.020881841f, -0.06747029f, 0.10062043f, -0.0023941975f,
668 0.03532124f, -0.016341697f, 0.09685456f, -0.016764693f,
669 0.051808182f, 0.05875331f, -0.04536488f, 0.001626336f,
670 -0.028892258f, -0.01048663f, -0.009793449f, -0.017093895f,
671 0.010987891f, 0.02357273f, -0.00010856845f, 0.0099760275f,
672 -0.001845119f, -0.03551521f, 0.0018358806f, 0.05763657f,
673 -0.01769146f, 0.040995963f, 0.02235177f, -0.060430344f,
674 0.11475477f, -0.023854522f, 0.10071741f, 0.0686208f,
675 -0.014250481f, 0.034261297f, 0.047418304f, 0.08562733f,
676 -0.030519066f, 0.0060542435f, 0.014653856f, -0.038836084f,
677 0.04096551f, 0.032249358f, -0.08355519f, -0.026823482f,
678 0.056386515f, -0.010401743f, -0.028396193f, 0.08507674f,
679 0.014410365f, 0.020995233f, 0.17040324f, 0.11511526f,
680 0.02459721f, 0.0066619175f, 0.025853224f, -0.023133837f,
681 -0.081302024f, 0.017264642f, -0.009585969f, 0.09491168f,
682 -0.051313367f, 0.054532815f, -0.014298593f, 0.10657464f,
683 0.007076659f, 0.10964551f, 0.0409152f, 0.008275321f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100684 -0.07283536f, 0.07937492f, 0.04192024f, -0.1075027f };
telsoa01c577f2c2018-08-31 09:22:23 +0100685
Sadik Armagan483c8112021-06-01 09:24:52 +0100686 std::vector<float> recurrentToCellWeights = { -0.037322544f, 0.018592842f, 0.0056175636f, -0.06253426f,
telsoa01c577f2c2018-08-31 09:22:23 +0100687 0.055647098f, -0.05713207f, -0.05626563f, 0.005559383f,
688 0.03375411f, -0.025757805f, -0.088049285f, 0.06017052f,
689 -0.06570978f, 0.007384076f, 0.035123326f, -0.07920549f,
690 0.053676967f, 0.044480428f, -0.07663568f, 0.0071805613f,
691 0.08089997f, 0.05143358f, 0.038261272f, 0.03339287f,
692 -0.027673481f, 0.044746667f, 0.028349208f, 0.020090483f,
693 -0.019443132f, -0.030755889f, -0.0040000007f, 0.04465846f,
694 -0.021585021f, 0.0031670958f, 0.0053199246f, -0.056117613f,
695 -0.10893326f, 0.076739706f, -0.08509834f, -0.027997585f,
696 0.037871376f, 0.01449768f, -0.09002357f, -0.06111149f,
697 -0.046195522f, 0.0422062f, -0.005683705f, -0.1253618f,
698 -0.012925729f, -0.04890792f, 0.06985068f, 0.037654128f,
699 0.03398274f, -0.004781977f, 0.007032333f, -0.031787455f,
700 0.010868644f, -0.031489216f, 0.09525667f, 0.013939797f,
701 0.0058680447f, 0.0167067f, 0.02668468f, -0.04797466f,
702 -0.048885044f, -0.12722108f, 0.035304096f, 0.06554885f,
703 0.00972396f, -0.039238118f, -0.05159735f, -0.11329045f,
704 0.1613692f, -0.03750952f, 0.06529313f, -0.071974665f,
705 -0.11769596f, 0.015524369f, -0.0013754242f, -0.12446318f,
706 0.02786344f, -0.014179351f, 0.005264273f, 0.14376344f,
707 0.015983658f, 0.03406988f, -0.06939408f, 0.040699873f,
708 0.02111075f, 0.09669095f, 0.041345075f, -0.08316494f,
709 -0.07684199f, -0.045768797f, 0.032298047f, -0.041805092f,
710 0.0119405f, 0.0061010392f, 0.12652606f, 0.0064572375f,
711 -0.024950314f, 0.11574242f, 0.04508852f, -0.04335324f,
712 0.06760663f, -0.027437469f, 0.07216407f, 0.06977076f,
713 -0.05438599f, 0.034033038f, -0.028602652f, 0.05346137f,
714 0.043184172f, -0.037189785f, 0.10420091f, 0.00882477f,
715 -0.054019816f, -0.074273005f, -0.030617684f, -0.0028467078f,
716 0.024302477f, -0.0038869337f, 0.005332455f, 0.0013399826f,
717 0.04361412f, -0.007001822f, 0.09631092f, -0.06702025f,
718 -0.042049985f, -0.035070654f, -0.04103342f, -0.10273396f,
719 0.0544271f, 0.037184782f, -0.13150354f, -0.0058036847f,
720 -0.008264958f, 0.042035464f, 0.05891794f, 0.029673764f,
721 0.0063542654f, 0.044788733f, 0.054816857f, 0.062257513f,
722 -0.00093483756f, 0.048938446f, -0.004952862f, -0.007730018f,
723 -0.04043371f, -0.017094059f, 0.07229206f, -0.023670016f,
724 -0.052195564f, -0.025616996f, -0.01520939f, 0.045104615f,
725 -0.007376126f, 0.003533447f, 0.006570588f, 0.056037236f,
726 0.12436656f, 0.051817212f, 0.028532185f, -0.08686856f,
727 0.11868599f, 0.07663395f, -0.07323171f, 0.03463402f,
728 -0.050708205f, -0.04458982f, -0.11590894f, 0.021273347f,
729 0.1251325f, -0.15313013f, -0.12224372f, 0.17228661f,
730 0.023029093f, 0.086124025f, 0.006445803f, -0.03496501f,
731 0.028332196f, 0.04449512f, -0.042436164f, -0.026587414f,
732 -0.006041347f, -0.09292539f, -0.05678812f, 0.03897832f,
733 0.09465633f, 0.008115513f, -0.02171956f, 0.08304309f,
734 0.071401566f, 0.019622514f, 0.032163795f, -0.004167056f,
735 0.02295182f, 0.030739572f, 0.056506045f, 0.004612461f,
736 0.06524936f, 0.059999723f, 0.046395954f, -0.0045512207f,
737 -0.1335546f, -0.030136576f, 0.11584653f, -0.014678886f,
738 0.0020118146f, -0.09688814f, -0.0790206f, 0.039770417f,
739 -0.0329582f, 0.07922767f, 0.029322514f, 0.026405897f,
740 0.04207835f, -0.07073373f, 0.063781224f, 0.0859677f,
741 -0.10925287f, -0.07011058f, 0.048005477f, 0.03438226f,
742 -0.09606514f, -0.006669445f, -0.043381985f, 0.04240257f,
743 -0.06955775f, -0.06769346f, 0.043903265f, -0.026784198f,
744 -0.017840602f, 0.024307009f, -0.040079936f, -0.019946516f,
745 0.045318738f, -0.12233574f, 0.026170589f, 0.0074471775f,
746 0.15978073f, 0.10185836f, 0.10298046f, -0.015476589f,
747 -0.039390966f, -0.072174534f, 0.0739445f, -0.1211869f,
748 -0.0347889f, -0.07943156f, 0.014809798f, -0.12412325f,
749 -0.0030663363f, 0.039695457f, 0.0647603f, -0.08291318f,
750 -0.018529687f, -0.004423833f, 0.0037507233f, 0.084633216f,
751 -0.01514876f, -0.056505352f, -0.012800942f, -0.06994386f,
752 0.012962922f, -0.031234352f, 0.07029052f, 0.016418684f,
753 0.03618972f, 0.055686004f, -0.08663945f, -0.017404709f,
754 -0.054761406f, 0.029065743f, 0.052404847f, 0.020238016f,
755 0.0048197987f, -0.0214882f, 0.07078733f, 0.013016777f,
756 0.06262858f, 0.009184685f, 0.020785125f, -0.043904778f,
757 -0.0270329f, -0.03299152f, -0.060088247f, -0.015162964f,
758 -0.001828936f, 0.12642565f, -0.056757294f, 0.013586685f,
759 0.09232601f, -0.035886683f, 0.06000002f, 0.05229691f,
760 -0.052580316f, -0.082029596f, -0.010794592f, 0.012947712f,
761 -0.036429964f, -0.085508935f, -0.13127148f, -0.017744139f,
762 0.031502828f, 0.036232427f, -0.031581745f, 0.023051167f,
763 -0.05325106f, -0.03421577f, 0.028793324f, -0.034633752f,
764 -0.009881397f, -0.043551125f, -0.018609839f, 0.0019097115f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100765 -0.008799762f, 0.056595087f, 0.0022273948f, 0.055752404f };
telsoa01c577f2c2018-08-31 09:22:23 +0100766
Sadik Armagan483c8112021-06-01 09:24:52 +0100767 std::vector<float> recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f,-0.045984812f, -0.01255415f,
768 -0.0026479573f,-0.08196161f,-0.054914974f,-0.0046604523f,
telsoa01c577f2c2018-08-31 09:22:23 +0100769 -0.029587349f, -0.044576716f, -0.07480124f, -0.082868785f,
770 0.023254942f, 0.027502948f, -0.0039728214f, -0.08683098f,
771 -0.08116779f, -0.014675607f, -0.037924774f, -0.023314456f,
772 -0.007401714f, -0.09255757f, 0.029460307f, -0.08829125f,
773 -0.005139627f, -0.08989442f, -0.0555066f, 0.13596267f,
774 -0.025062224f, -0.048351806f, -0.03850004f, 0.07266485f,
775 -0.022414139f, 0.05940088f, 0.075114764f, 0.09597592f,
776 -0.010211725f, -0.0049794707f, -0.011523867f, -0.025980417f,
777 0.072999895f, 0.11091378f, -0.081685916f, 0.014416728f,
778 0.043229222f, 0.034178585f, -0.07530371f, 0.035837382f,
779 -0.085607f, -0.007721233f, -0.03287832f, -0.043848954f,
780 -0.06404588f, -0.06632928f, -0.073643476f, 0.008214239f,
781 -0.045984086f, 0.039764922f, 0.03474462f, 0.060612556f,
782 -0.080590084f, 0.049127717f, 0.04151091f, -0.030063879f,
783 0.008801774f, -0.023021035f, -0.019558564f, 0.05158114f,
784 -0.010947698f, -0.011825728f, 0.0075720972f, 0.0699727f,
785 -0.0039981045f, 0.069350146f, 0.08799282f, 0.016156472f,
786 0.035502106f, 0.11695009f, 0.006217345f, 0.13392477f,
787 -0.037875112f, 0.025745004f, 0.08940699f, -0.00924166f,
788 0.0046702605f, -0.036598757f, -0.08811812f, 0.10522024f,
789 -0.032441203f, 0.008176899f, -0.04454919f, 0.07058152f,
790 0.0067963637f, 0.039206743f, 0.03259838f, 0.03725492f,
791 -0.09515802f, 0.013326398f, -0.052055415f, -0.025676316f,
792 0.03198509f, -0.015951829f, -0.058556724f, 0.036879618f,
793 0.043357447f, 0.028362012f, -0.05908629f, 0.0059240665f,
794 -0.04995891f, -0.019187413f,0.0276265f, -0.01628143f, 0.0025863599f,
795 0.08800015f, 0.035250366f, -0.022165963f, -0.07328642f,
796 -0.009415526f, -0.07455109f, 0.11690406f, 0.0363299f,
797 0.07411125f, 0.042103454f, -0.009660886f, 0.019076364f,
798 0.018299393f, -0.046004917f, 0.08891175f,0.0431396f, -0.026327137f,
799 -0.051502608f, 0.08979574f, -0.051670972f, 0.04940282f,
800 -0.07491107f, -0.021240504f, 0.022596184f, -0.034280192f,
801 0.060163025f, -0.058211457f, -0.051837247f, -0.01349775f,
802 -0.04639988f, -0.035936575f, -0.011681591f, 0.064818054f,
803 0.0073146066f, -0.021745546f, -0.043124277f, -0.06471268f,
804 -0.07053354f, -0.029321948f, -0.05330136f, 0.016933719f,
805 -0.053782392f, 0.13747959f, -0.1361751f, -0.11569455f,
806 0.0033329215f, 0.05693899f, -0.053219706f, 0.063698f,
807 0.07977434f, -0.07924483f, 0.06936997f, 0.0034815092f,
808 -0.007305279f, -0.037325785f, -0.07251102f, -0.033633437f,
809 -0.08677009f, 0.091591336f, -0.14165086f, 0.021752775f,
810 0.019683983f, 0.0011612234f, -0.058154266f, 0.049996935f,
811 0.0288841f, -0.0024567875f, -0.14345716f, 0.010955264f,-0.10234828f,
812 0.1183656f, -0.0010731248f, -0.023590032f,-0.072285876f,-0.0724771f,
813 -0.026382286f, -0.0014920527f, 0.042667855f, 0.0018776858f,
814 0.02986552f, 0.009814309f, 0.0733756f, 0.12289186f,
815 0.018043943f, -0.0458958f, 0.049412545f, 0.033632483f,
816 0.05495232f, 0.036686596f, -0.013781798f, -0.010036754f,
817 0.02576849f, -0.08307328f, 0.010112348f, 0.042521734f,
818 -0.05869831f, -0.071689695f, 0.03876447f, -0.13275425f, -0.0352966f,
819 -0.023077697f, 0.10285965f, 0.084736146f, 0.15568255f,
820 -0.00040734606f, 0.027835453f, -0.10292561f, -0.032401145f,
821 0.10053256f, -0.026142767f, -0.08271222f, -0.0030240538f,
822 -0.016368777f, 0.1070414f, 0.042672627f, 0.013456989f,
823 -0.0437609f, -0.022309763f, 0.11576483f, 0.04108048f,
824 0.061026827f, -0.0190714f, -0.0869359f, 0.037901703f, 0.0610107f,
825 0.07202949f, 0.01675338f, 0.086139716f, -0.08795751f,
826 -0.014898893f, -0.023771819f, -0.01965048f, 0.007955471f,
827 -0.043740474f, 0.03346837f, -0.10549954f, 0.090567775f,
828 0.042013682f, -0.03176985f, 0.12569028f, -0.02421228f,
829 -0.029526481f, 0.023851605f, 0.031539805f, 0.05292009f,
830 -0.02344001f, -0.07811758f, -0.08834428f, 0.10094801f,
831 0.16594367f, -0.06861939f, -0.021256343f, -0.041093912f,
832 -0.06669611f, 0.035498552f, 0.021757556f, -0.09302526f,
833 -0.015403468f, -0.06614931f, -0.051798206f, -0.013874718f,
834 0.03630673f, 0.010412845f, -0.08077351f, 0.046185967f,
835 0.0035662893f, 0.03541868f, -0.094149634f, -0.034814864f,
836 0.003128424f, -0.020674974f, -0.03944324f, -0.008110165f,
837 -0.11113267f, 0.08484226f, 0.043586485f, 0.040582247f,
838 0.0968012f, -0.065249965f, -0.028036479f, 0.0050708856f,
839 0.0017462453f, 0.0326779f, 0.041296225f, 0.09164146f,
840 -0.047743853f, -0.015952192f, -0.034451712f, 0.084197424f,
841 -0.05347844f, -0.11768019f, 0.085926116f, -0.08251791f,
842 -0.045081906f, 0.0948852f, 0.068401024f, 0.024856757f,
843 0.06978981f, -0.057309967f, -0.012775832f, -0.0032452994f,
Sadik Armagan483c8112021-06-01 09:24:52 +0100844 0.01977615f, -0.041040014f, -0.024264973f,0.063464895f, 0.05431621f};
telsoa01c577f2c2018-08-31 09:22:23 +0100845
Sadik Armagan483c8112021-06-01 09:24:52 +0100846 std::vector<float> cellToInputWeights = {0.040369894f, 0.030746894f, 0.24704495f, 0.018586371f, -0.037586458f,
847 -0.15312155f, -0.11812848f, -0.11465643f, 0.20259799f, 0.11418174f,
848 -0.10116027f, -0.011334949f, 0.12411352f, -0.076769054f,-0.052169047f,
849 0.21198851f, -0.38871562f, -0.09061183f, -0.09683246f, -0.21929175f};
telsoa01c577f2c2018-08-31 09:22:23 +0100850
851
Sadik Armagan483c8112021-06-01 09:24:52 +0100852 std::vector<float> cellToForgetWeights = {-0.01998659f,-0.15568835f,-0.24248174f, -0.012770197f, 0.041331276f,
853 -0.072311886f, -0.052123554f,-0.0066330447f,-0.043891653f,0.036225766f,
854 -0.047248036f, 0.021479502f,0.033189066f, 0.11952997f, -0.020432774f,
855 0.64658105f, -0.06650122f, -0.03467612f, 0.095340036f, 0.23647355f};
telsoa01c577f2c2018-08-31 09:22:23 +0100856
Sadik Armagan483c8112021-06-01 09:24:52 +0100857 std::vector<float> cellToOutputWeights = { 0.08286371f, -0.08261836f, -0.51210177f, 0.002913762f, 0.17764764f,
858 -0.5495371f, -0.08460716f, -0.24552552f, 0.030037103f, 0.04123544f,
859 -0.11940523f, 0.007358328f, 0.1890978f, 0.4833202f, -0.34441817f,
860 0.36312827f, -0.26375428f, 0.1457655f, -0.19724406f, 0.15548733f};
telsoa01c577f2c2018-08-31 09:22:23 +0100861
Sadik Armagan483c8112021-06-01 09:24:52 +0100862 std::vector<float> projectionWeights={-0.009802181f, 0.09401916f, 0.0717386f, -0.13895074f, 0.09641832f,
863 0.060420845f, 0.08539281f, 0.054285463f, 0.061395317f, 0.034448683f,
864 -0.042991187f, 0.019801661f, -0.16840284f, -0.015726732f, -0.23041931f,
865 -0.024478018f, -0.10959692f, -0.013875541f, 0.18600968f, -0.061274476f,
866 0.0138165f, -0.08160894f, -0.07661644f, 0.032372914f, 0.16169067f,
867 0.22465782f, -0.03993472f, -0.004017731f, 0.08633481f, -0.28869787f,
868 0.08682067f, 0.17240396f, 0.014975425f, 0.056431185f, 0.031037588f,
869 0.16702051f, 0.0077946745f, 0.15140012f, 0.29405436f, 0.120285f,
870 -0.188994f, -0.027265169f, 0.043389652f, -0.022061434f, 0.014777949f,
871 -0.20203483f, 0.094781205f, 0.19100232f, 0.13987629f, -0.036132768f,
872 -0.06426278f, -0.05108664f, 0.13221376f, 0.009441198f, -0.16715929f,
873 0.15859416f, -0.040437475f, 0.050779544f, -0.022187516f, 0.012166504f,
874 0.027685808f, -0.07675938f, -0.0055694645f, -0.09444123f, 0.0046453946f,
875 0.050794356f, 0.10770313f, -0.20790008f, -0.07149004f, -0.11425117f,
876 0.008225835f, -0.035802525f, 0.14374903f, 0.15262283f, 0.048710253f,
877 0.1847461f, -0.007487823f, 0.11000021f, -0.09542012f, 0.22619456f,
878 -0.029149994f, 0.08527916f, 0.009043713f, 0.0042746216f, 0.016261552f,
879 0.022461696f, 0.12689082f, -0.043589946f, -0.12035478f, -0.08361797f,
880 -0.050666027f, -0.1248618f, -0.1275799f, -0.071875185f, 0.07377272f,
881 0.09944291f, -0.18897448f, -0.1593054f, -0.06526116f, -0.040107165f,
882 -0.004618631f, -0.067624845f, -0.007576253f, 0.10727444f, 0.041546922f,
883 -0.20424393f, 0.06907816f, 0.050412357f, 0.00724631f, 0.039827548f,
884 0.12449835f, 0.10747581f, 0.13708383f, 0.09134148f, -0.12617786f,
885 -0.06428341f, 0.09956831f, 0.1208086f, -0.14676677f, -0.0727722f,
886 0.1126304f, 0.010139365f, 0.015571211f, -0.038128063f, 0.022913318f,
887 -0.042050496f, 0.16842307f, -0.060597885f, 0.10531834f, -0.06411776f,
888 -0.07451711f, -0.03410368f, -0.13393489f, 0.06534304f, 0.003620307f,
889 0.04490757f, 0.05970546f, 0.05197996f, 0.02839995f, 0.10434969f,
890 -0.013699693f, -0.028353551f, -0.07260381f, 0.047201227f, -0.024575593f,
891 -0.036445823f, 0.07155557f, 0.009672501f, -0.02328883f, 0.009533515f,
892 -0.03606021f, -0.07421458f, -0.028082801f, -0.2678904f, -0.13221288f,
893 0.18419984f, -0.13012612f, -0.014588381f, -0.035059117f, -0.04824723f,
894 0.07830115f, -0.056184657f, 0.03277091f, 0.025466874f, 0.14494097f,
895 -0.12522776f, -0.098633975f, -0.10766018f, -0.08317623f, 0.08594209f,
896 0.07749552f, 0.039474737f, 0.1776665f, -0.07409566f, -0.0477268f,
897 0.29323658f, 0.10801441f, 0.1154011f, 0.013952499f, 0.10739139f,
898 0.10708251f, -0.051456142f, 0.0074137426f, -0.10430189f, 0.10034707f,
899 0.045594677f, 0.0635285f, -0.0715442f, -0.089667566f, -0.10811871f,
900 0.00026344223f, 0.08298446f, -0.009525053f, 0.006585689f, -0.24567553f,
901 -0.09450807f, 0.09648481f, 0.026996298f, -0.06419476f, -0.04752702f,
902 -0.11063944f, -0.23441927f, -0.17608605f, -0.052156363f, 0.067035615f,
903 0.19271925f, -0.0032889997f, -0.043264326f, 0.09663576f, -0.057112187f,
904 -0.10100678f, 0.0628376f, 0.04447668f, 0.017961001f, -0.10094388f,
905 -0.10190601f, 0.18335468f, 0.10494553f, -0.052095775f, -0.0026118709f,
906 0.10539724f, -0.04383912f, -0.042349473f, 0.08438151f, -0.1947263f,
907 0.02251204f, 0.11216432f, -0.10307853f, 0.17351969f, -0.039091777f,
908 0.08066188f, -0.00561982f, 0.12633002f, 0.11335965f, -0.0088127935f,
909 -0.019777594f, 0.06864014f, -0.059751723f, 0.016233567f, -0.06894641f,
910 -0.28651384f, -0.004228674f, 0.019708522f, -0.16305895f, -0.07468996f,
911 -0.0855457f, 0.099339016f, -0.07580735f, -0.13775392f, 0.08434318f,
912 0.08330512f, -0.12131499f, 0.031935584f, 0.09180414f, -0.08876437f,
913 -0.08049874f, 0.008753825f, 0.03498998f, 0.030215185f, 0.03907079f,
914 0.089751154f, 0.029194152f, -0.03337423f, -0.019092513f, 0.04331237f,
915 0.04299654f, -0.036394123f, -0.12915532f, 0.09793732f, 0.07512415f,
916 -0.11319543f, -0.032502122f, 0.15661901f, 0.07671967f, -0.005491124f,
917 -0.19379048f, -0.218606f, 0.21448623f, 0.017840758f, 0.1416943f,
918 -0.07051762f, 0.19488361f, 0.02664691f, -0.18104725f, -0.09334311f,
919 0.15026465f, -0.15493552f, -0.057762887f, -0.11604192f, -0.262013f,
920 -0.01391798f, 0.012185008f, 0.11156489f, -0.07483202f, 0.06693364f,
921 -0.26151478f, 0.046425626f, 0.036540434f, -0.16435726f, 0.17338543f,
922 -0.21401681f, -0.11385144f, -0.08283257f, -0.069031075f, 0.030635102f,
923 0.010969227f, 0.11109743f, 0.010919218f, 0.027526086f, 0.13519906f,
924 0.01891392f, -0.046839405f, -0.040167913f, 0.017953383f, -0.09700955f,
925 0.0061885654f, -0.07000971f, 0.026893595f, -0.038844477f, 0.14543656f};
telsoa01c577f2c2018-08-31 09:22:23 +0100926
927 std::vector<float> projectionBiasVector(outputSize, 0.f);
telsoa01c577f2c2018-08-31 09:22:23 +0100928
James Conroy1f58f032021-04-27 17:13:27 +0100929 armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo20x5);
930 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo20x5);
931 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo20x5);
932 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo20x5);
933 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo20x16);
934 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo20x16);
935 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo20x16);
936 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo20x16);
937 armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo20);
938 armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo20);
939 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo20);
940 armnn::ScopedTensorHandle cellBiasTensor(tensorInfo20);
941 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo20);
942 armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo20);
943 armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo20);
944 armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo16x20);
945 armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo16);
telsoa01c577f2c2018-08-31 09:22:23 +0100946
Sadik Armagan483c8112021-06-01 09:24:52 +0100947 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
948 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
949 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
950 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
951 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
952 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
953 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
954 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
955 AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
956 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
957 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
958 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
959 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
960 AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
961 AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
962 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
963 AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +0100964
965 data.m_InputToInputWeights = &inputToInputWeightsTensor;
966 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
967 data.m_InputToCellWeights = &inputToCellWeightsTensor;
968 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
969 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
970 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
971 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
972 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
973 data.m_CellToInputWeights = &cellToInputWeightsTensor;
974 data.m_InputGateBias = &inputGateBiasTensor;
975 data.m_ForgetGateBias = &forgetGateBiasTensor;
976 data.m_CellBias = &cellBiasTensor;
977 data.m_OutputGateBias = &outputGateBiasTensor;
978 data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
979 data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
980 data.m_ProjectionWeights = &projectionWeightsTensor;
981 data.m_ProjectionBias = &projectionBiasTensor;
982
983 // Flags to set test configuration
984 data.m_Parameters.m_ActivationFunc = 4;
985 data.m_Parameters.m_CifgEnabled = false;
986 data.m_Parameters.m_PeepholeEnabled = true;
987 data.m_Parameters.m_ProjectionEnabled = true;
988
telsoa01c577f2c2018-08-31 09:22:23 +0100989 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
990 inputHandle->Allocate();
991 outputStateInHandle->Allocate();
992 cellStateInHandle->Allocate();
993
994 scratchHandle->Allocate();
995 outputStateOutHandle->Allocate();
996 cellStateOutHandle->Allocate();
997 outputHandle->Allocate();
998
Sadik Armagan483c8112021-06-01 09:24:52 +0100999 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1000 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1001 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
telsoa01c577f2c2018-08-31 09:22:23 +01001002
telsoa01c577f2c2018-08-31 09:22:23 +01001003 workload->Execute();
1004
Sadik Armagan483c8112021-06-01 09:24:52 +01001005 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
telsoa01c577f2c2018-08-31 09:22:23 +01001006
Sadik Armagan483c8112021-06-01 09:24:52 +01001007 return LayerTestResult<T, 2>(actualOutput,
1008 outputVector,
1009 outputHandle->GetShape(),
1010 outputTensorInfo.GetShape());
telsoa01c577f2c2018-08-31 09:22:23 +01001011}
1012
// Builds and executes a single LSTM workload configured with:
//   - CIFG (Coupled Input-Forget Gate) ENABLED  -> no input-gate weights/biases are supplied,
//   - peephole connections ENABLED              -> cell-to-forget/output weights are supplied,
//   - projection DISABLED                       -> cellSize == outputSize, no projection weights.
// The workload is run once on zero-initialised input/cell state and the produced
// output tensor is compared against 'outputExpected'.
//
// Template parameters:
//   ArmnnType - armnn::DataType of the LSTM input/state/output tensors.
//   T         - C++ element type resolved from ArmnnType.
//
// Parameters:
//   workloadFactory     - backend factory used to create the LSTM workload.
//   memoryManager       - unused here (kept for signature parity with sibling tests).
//   tensorHandleFactory - factory for all input/output tensor handles.
//   input               - flattened input values, shape given by 'inputShape' ({batch, inputSize}).
//   outputExpected      - flattened expected output, shape given by 'outputExpectedShape'.
//   qScale/qOffset      - quantisation parameters applied to every TensorInfo (0/0 for float).
//   constantDataType    - data type used for the constant weight/bias tensors
//                         (may differ from ArmnnType, e.g. Float32 weights with Float16 activations).
//
// Returns: LayerTestResult comparing actual vs expected LSTM output (final output tensor only).
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T, 2> LstmLayerWithCifgWithPeepholeNoProjectionTestImpl(
        armnn::IWorkloadFactory& workloadFactory,
        const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
        const armnn::ITensorHandleFactory& tensorHandleFactory,
        const std::vector<T>& input,
        const std::vector<T>& outputExpected,
        const armnn::TensorShape& inputShape,
        const armnn::TensorShape& outputExpectedShape,
        float qScale = 0.0f,
        int32_t qOffset = 0,
        armnn::DataType constantDataType = armnn::DataType::Float32)
{
    IgnoreUnused(memoryManager);
    // Test configuration: CIFG + peephole, no projection.
    bool cifgEnabled = true;
    bool peepholeEnabled = true;
    bool projectionEnabled = false;
    // These are not the input and the output of Lstm yet
    unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]);
    unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);

    unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);

    // With projection disabled the cell state size equals the output size.
    const unsigned int cellSize = outputSize;

    // Decide the shape of all input tensors
    armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset); // change to ArmnnType
    armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
    armnn::TensorInfo cellStateInTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);

    // CIFG removes the input gate, so the scratch buffer holds 3 gate buffers
    // per cell instead of 4.
    unsigned int scratchBufferSize = cifgEnabled ? cellSize * 3 : cellSize * 4;
    armnn::TensorInfo scratchBufferTensorInfo({batchSize, scratchBufferSize}, ArmnnType, qScale, qOffset);
    armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
    armnn::TensorInfo cellStateOutTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset);
    armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);

    // List of inputs
    std::vector<float> inputData;
    inputData.assign(input.data(), input.data() + batchSize*inputSize);

    // Initial recurrent state is all zeros.
    std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);

    std::vector<float> cellStateInVector(batchSize * cellSize, 0.f);

    // Prepare all the weights in the descriptor for LSTM
    armnn::LstmQueueDescriptor data;
    // Constant-tensor shapes: {cellSize, inputSize} for input weights,
    // {cellSize, outputSize} for recurrent weights, {cellSize} for biases/peepholes.
    armnn::TensorInfo tensorInfoInput({cellSize, inputSize}, constantDataType, qScale, qOffset);
    armnn::TensorInfo tensorInfoOutput({cellSize, outputSize}, constantDataType, qScale, qOffset);
    armnn::TensorInfo tensorInfoNumUnits({cellSize}, constantDataType, qScale, qOffset);

    // Hard-coded reference weights/biases. Note: no input-gate weights are
    // defined because CIFG is enabled.
    std::vector<float> inputToCellWeights =
    {
        -0.49770179f, -0.27711356f, -0.09624726f, 0.05100781f,
        0.04717243f, 0.48944736f, -0.38535351f,
        -0.17212132f
    };
    std::vector<float> inputToForgetWeights =
    {
        -0.55291498f, -0.42866567f, 0.13056988f,
        -0.3633365f, -0.22755712f, 0.28253698f, 0.24407166f,
        0.33826375f
    };
    std::vector<float> inputToOutputWeights =
    {
        0.10725588f, -0.02335852f, -0.55932593f,
        -0.09426838f, -0.44257352f, 0.54939759f,
        0.01533556f, 0.42751634f
    };
    std::vector<float> cellBias = {0.f, 0.f, 0.f, 0.f};
    // Forget gate bias of 1.0 is the conventional LSTM initialisation.
    std::vector<float> forgetGateBias = {1.f, 1.f, 1.f, 1.f};
    std::vector<float> outputGateBias = {0.f, 0.f, 0.f, 0.f};

    std::vector<float> recurrentToCellWeights =
    {
        0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f, 0.42957711f,
        0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f, 0.20675004f,
        0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f, 0.44901288f,
        0.21193194f
    };
    std::vector<float> recurrentToForgetWeights =
    {
        -0.13832897f, -0.0515101f, -0.2359007f, -0.16661474f, -0.14340827f,
        0.36986142f, 0.23414481f, 0.55899f, 0.10798943f, -0.41174671f, 0.17751795f,
        -0.34484994f, -0.35874045f, -0.11352962f, 0.27268326f, 0.54058349f
    };

    std::vector<float> recurrentToOutputWeights =
    {
        0.41613156f, 0.42610586f, -0.16495961f, -0.5663873f, 0.30579174f, -0.05115908f,
        -0.33941799f, 0.23364776f, 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f,
        0.50248802f, 0.26114327f, -0.43736315f, 0.33149987f
    };

    // Peephole weights (cell state feeds the forget and output gates).
    std::vector<float> cellToForgetWeights = {0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f};
    std::vector<float> cellToOutputWeights = {-0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f};

    // ScopedTensorHandles own the constant tensor storage. They are locals, so
    // they must stay alive until workload->Execute() has run: the descriptor
    // below stores raw pointers to them.
    armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfoInput);
    armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfoInput);
    armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfoInput);

    armnn::ScopedTensorHandle cellBiasTensor(tensorInfoNumUnits);
    armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfoNumUnits);
    armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfoNumUnits);

    armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfoOutput);
    armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfoOutput);
    armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfoOutput);

    armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfoNumUnits);
    armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfoNumUnits);

    // Allocate backing memory and copy the constant data into each handle.
    AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());

    AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
    AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
    AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());

    AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());

    AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());

    // Wire the constant tensors into the queue descriptor. The input-gate
    // members (m_InputToInputWeights etc.) are deliberately left unset: CIFG.
    data.m_InputToCellWeights = &inputToCellWeightsTensor;
    data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    data.m_InputToOutputWeights = &inputToOutputWeightsTensor;

    data.m_CellBias = &cellBiasTensor;
    data.m_ForgetGateBias = &forgetGateBiasTensor;
    data.m_OutputGateBias = &outputGateBiasTensor;

    data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
    data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;

    data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
    data.m_CellToOutputWeights = &cellToOutputWeightsTensor;

    // other parameters for the descriptor
    data.m_Parameters.m_CifgEnabled = cifgEnabled;
    data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
    data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;

    // Activation function id 4 (tanh per ArmNN's LSTM activation mapping —
    // confirm against the LstmDescriptor documentation). Clipping thresholds
    // of 0.0 disable cell/projection clipping.
    data.m_Parameters.m_ActivationFunc = 4;
    data.m_Parameters.m_ClippingThresProj = 0.0;
    data.m_Parameters.m_ClippingThresCell = 0.0;

    // List of outputs
    std::vector<T> scratchBufferVector(batchSize * scratchBufferSize, T());
    LayerTestResult<T, 2> ret0(scratchBufferTensorInfo);

    // Output state for a certain time step
    std::vector<T> outputStateOutVector(batchSize * outputSize, T());
    LayerTestResult<T, 2> ret1(outputStateOutTensorInfo);

    // Cell state for a certain time step
    std::vector<T> cellStateOutVector(batchSize * cellSize, T());
    LayerTestResult<T, 2> ret2(cellStateOutTensorInfo);

    // Output for a certain time step
    std::vector<T> outputData;
    outputData.assign(outputExpected.data(), outputExpected.data() + batchSize*outputSize);
    // ret3 (the final output tensor) is the only result actually returned;
    // ret0/ret1/ret2 are populated below but discarded.
    LayerTestResult<T, 2> ret3(outputTensorInfo);
    ret3.m_ExpectedData = outputData;

    std::vector<T> actualScratchBufferOutput(scratchBufferTensorInfo.GetNumElements());
    std::vector<T> actualOutputStateOutput(outputStateOutTensorInfo.GetNumElements());
    std::vector<T> actualCellStateOutput(cellStateOutTensorInfo.GetNumElements());
    std::vector<T> actualOutput(outputTensorInfo.GetNumElements());

    // Prepare the inputs and outputs for the workload
    std::unique_ptr<armnn::ITensorHandle> inputHandle =
            tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);

    std::unique_ptr<armnn::ITensorHandle> scratchBufferHandle =
            tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
    std::unique_ptr<armnn::ITensorHandle> outputHandle =
            tensorHandleFactory.CreateTensorHandle(outputTensorInfo);

    // Registration order fixes the workload's input/output slot indices:
    // inputs  = {input, outputStateIn, cellStateIn},
    // outputs = {scratchBuffer, outputStateOut, cellStateOut, output}.
    armnn::WorkloadInfo info;
    AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
    AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
    AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());

    AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchBufferHandle.get());
    AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
    AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
    AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());

    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);

    // Handles must be allocated before any data is copied into them.
    inputHandle->Allocate();
    outputStateInHandle->Allocate();
    cellStateInHandle->Allocate();

    scratchBufferHandle->Allocate();
    outputStateOutHandle->Allocate();
    cellStateOutHandle->Allocate();
    outputHandle->Allocate();

    CopyDataToITensorHandle(inputHandle.get(), inputData.data());
    CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
    CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());

    // Pre-fill the output handles with zero-initialised buffers.
    CopyDataToITensorHandle(scratchBufferHandle.get(), scratchBufferVector.data());
    CopyDataToITensorHandle(outputStateOutHandle.get(), outputStateOutVector.data());
    CopyDataToITensorHandle(cellStateOutHandle.get(), cellStateOutVector.data());

    // Run the LSTM for a single time step.
    workload->Execute();

    // Read back every output tensor produced by the workload.
    CopyDataFromITensorHandle(actualScratchBufferOutput.data(), scratchBufferHandle.get());
    CopyDataFromITensorHandle(actualOutputStateOutput.data(), outputStateOutHandle.get());
    CopyDataFromITensorHandle(actualCellStateOutput.data(), cellStateOutHandle.get());
    CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());

    ret0.m_ActualData = actualScratchBufferOutput;
    ret1.m_ActualData = actualOutputStateOutput;
    ret2.m_ActualData = actualCellStateOutput;
    ret3.m_ActualData = actualOutput;

    // Only the final output comparison is returned to the caller.
    return ret3;
}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001246
Jan Eilers38e05bd2019-06-26 13:10:09 +01001247template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
1248LayerTestResult<T, 2>
1249LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl(armnn::IWorkloadFactory& workloadFactory,
1250 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001251 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001252 const std::vector<T>& input,
1253 const std::vector<T>& outputExpected,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001254 float qScale = 0.0f,
1255 int32_t qOffset = 0,
1256 armnn::DataType constantDataType = armnn::DataType::Float32)
1257{
Jan Eilers8eb25602020-03-09 12:13:48 +00001258 IgnoreUnused(memoryManager);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001259 unsigned int batchSize = 2;
1260 unsigned int outputSize = 3;
1261 unsigned int inputSize = 5;
1262 unsigned numUnits = 4;
1263
1264 armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset);
1265 armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset);
1266 armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset);
1267
1268 // Scratch buffer size without CIFG [batchSize, numUnits * 4]
1269 armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset);
1270 armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset);
1271 armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1272 armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset);
1273
Jan Eilers38e05bd2019-06-26 13:10:09 +01001274 std::vector<float> inputVector;
1275 inputVector.assign(input.data(), input.data() + (batchSize * inputSize));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001276
1277 std::vector<float> cellStateInVector(batchSize * numUnits, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001278 std::vector<float> outputStateInVector(batchSize * outputSize, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001279 std::vector<float> scratchBufferVector(batchSize * numUnits * 4, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001280 std::vector<float> outputStateOutVector(batchSize * outputSize, 0.f);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001281 std::vector<float> cellStateOutVector(batchSize * numUnits, 0.f);
Sadik Armagan483c8112021-06-01 09:24:52 +01001282
1283 std::vector<float> actualOutput(outputTensorInfo.GetNumElements());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001284
1285 std::vector<float> outputVector;
1286 outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize));
Jan Eilers38e05bd2019-06-26 13:10:09 +01001287
Finn Williamsc43de6a2020-08-27 11:13:25 +01001288 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001289 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001290 tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001291 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001292 tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001293
Finn Williamsc43de6a2020-08-27 11:13:25 +01001294 std::unique_ptr<armnn::ITensorHandle> scratchHandle =
1295 tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001296 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001297 tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001298 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001299 tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo);
1300 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001301
1302 armnn::LstmQueueDescriptor data;
1303 armnn::WorkloadInfo info;
1304
1305 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
1306 AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get());
1307 AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get());
1308
1309 AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get());
1310 AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get());
1311 AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get());
1312 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
1313
1314 armnn::TensorInfo tensorInfo3({outputSize}, constantDataType, qScale, qOffset);
1315 armnn::TensorInfo tensorInfo4({numUnits}, constantDataType, qScale, qOffset);
1316 armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, constantDataType, qScale, qOffset);
1317 armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, constantDataType, qScale, qOffset);
1318 armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, constantDataType, qScale, qOffset);
1319
Sadik Armagan483c8112021-06-01 09:24:52 +01001320 std::vector<float> inputToInputWeights = {0.5f, 0.6f, 0.7f, -0.8f, -0.9f,
1321 0.1f, 0.2f, 0.3f, -0.4f, 0.5f,
1322 -0.8f, 0.7f, -0.6f, 0.5f, -0.4f,
1323 -0.5f, -0.4f, -0.3f, -0.2f, -0.1f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001324
Sadik Armagan483c8112021-06-01 09:24:52 +01001325 std::vector<float> inputToForgetWeights = { -0.6f, -0.1f, 0.3f, 0.2f, 0.9f,
1326 -0.5f, -0.2f, -0.4f, 0.3f, -0.8f,
1327 -0.4f, 0.3f, -0.5f, -0.4f, -0.6f,
1328 0.3f, -0.4f, -0.6f, -0.5f, -0.5f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001329
Sadik Armagan483c8112021-06-01 09:24:52 +01001330 std::vector<float> inputToCellWeights = {-0.4f, -0.3f, -0.2f, -0.1f, -0.5f,
1331 0.5f, -0.2f, -0.3f, -0.2f, -0.6f,
1332 0.6f, -0.1f, -0.4f, -0.3f, -0.7f,
1333 0.7f, -0.9f, -0.5f, 0.8f, 0.6f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001334
Sadik Armagan483c8112021-06-01 09:24:52 +01001335 std::vector<float> inputToOutputWeights = {-0.8f, -0.4f, -0.2f, -0.9f, -0.1f,
1336 -0.7f, 0.3f, -0.3f, -0.8f, -0.2f,
1337 0.6f, -0.2f, 0.4f, -0.7f, -0.3f,
1338 -0.5f, 0.1f, 0.5f, -0.6f, -0.4f}; //{numUnits, inputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001339
Sadik Armagan483c8112021-06-01 09:24:52 +01001340 std::vector<float> inputGateBias = {0.03f, 0.15f, 0.22f, 0.38f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001341
Sadik Armagan483c8112021-06-01 09:24:52 +01001342 std::vector<float> forgetGateBias = {0.1f, -0.3f, -0.2f, 0.1f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001343
Sadik Armagan483c8112021-06-01 09:24:52 +01001344 std::vector<float> cellBias = {-0.05f, 0.72f, 0.25f, 0.08f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001345
Sadik Armagan483c8112021-06-01 09:24:52 +01001346 std::vector<float> outputGateBias = {0.05f, -0.01f, 0.2f, 0.1f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001347
Sadik Armagan483c8112021-06-01 09:24:52 +01001348 std::vector<float> recurrentToInputWeights ={-0.2f, -0.3f, 0.4f,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001349 0.1f, -0.5f, 0.9f,
1350 -0.2f, -0.3f, -0.7f,
Sadik Armagan483c8112021-06-01 09:24:52 +01001351 0.05f, -0.2f, -0.6f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001352
Sadik Armagan483c8112021-06-01 09:24:52 +01001353 std::vector<float> recurrentToCellWeights = {-0.3f, 0.2f, 0.1f,
Jan Eilers38e05bd2019-06-26 13:10:09 +01001354 -0.3f, 0.8f, -0.08f,
1355 -0.2f, 0.3f, 0.8f,
Sadik Armagan483c8112021-06-01 09:24:52 +01001356 -0.6f, -0.1f, 0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001357
Sadik Armagan483c8112021-06-01 09:24:52 +01001358 std::vector<float> recurrentToForgetWeights = { -0.5f, -0.3f, -0.5f,
1359 -0.2f, 0.6f, 0.4f,
1360 0.9f, 0.3f, -0.1f,
1361 0.2f, 0.5f, 0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001362
Sadik Armagan483c8112021-06-01 09:24:52 +01001363 std::vector<float> recurrentToOutputWeights = { 0.3f, -0.1f, 0.1f,
1364 -0.2f, -0.5f, -0.7f,
1365 -0.2f, -0.6f, -0.1f,
1366 -0.4f, -0.7f, -0.2f}; //{numUnits, outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001367
Sadik Armagan483c8112021-06-01 09:24:52 +01001368 std::vector<float> cellToInputWeights = {0.05f, 0.1f, 0.25f, 0.15f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001369
Sadik Armagan483c8112021-06-01 09:24:52 +01001370 std::vector<float> cellToForgetWeights = {-0.02f, -0.15f, -0.25f, -0.03f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001371
Sadik Armagan483c8112021-06-01 09:24:52 +01001372 std::vector<float> cellToOutputWeights = {0.1f, -0.1f, -0.5f, 0.05f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001373
Sadik Armagan483c8112021-06-01 09:24:52 +01001374 std::vector<float> projectionWeights = {-0.1f, 0.2f, 0.01f, -0.2f,
1375 0.1f, 0.5f, 0.3f, 0.08f,
1376 0.07f, 0.2f, -0.4f, 0.2f}; //{outputSize, numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001377
Sadik Armagan483c8112021-06-01 09:24:52 +01001378 std::vector<float> projectionBiasVector(outputSize, 0.f); //{outputSize}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001379
Sadik Armagan483c8112021-06-01 09:24:52 +01001380 std::vector<float> inputLayerNormWeights = {0.1f, 0.2f, 0.3f, 0.5f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001381
Sadik Armagan483c8112021-06-01 09:24:52 +01001382 std::vector<float> forgetLayerNormWeights = {0.2f, 0.2f, 0.4f, 0.3f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001383
Sadik Armagan483c8112021-06-01 09:24:52 +01001384 std::vector<float> cellLayerNormWeights = {0.7f, 0.2f, 0.3f, 0.8f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001385
Sadik Armagan483c8112021-06-01 09:24:52 +01001386 std::vector<float> outputLayerNormWeights = {0.6f, 0.2f, 0.2f, 0.5f}; //{numUnits}
Jan Eilers38e05bd2019-06-26 13:10:09 +01001387
1388
James Conroy1f58f032021-04-27 17:13:27 +01001389 armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo4x5);
1390 armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo4x5);
1391 armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo4x5);
1392 armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo4x5);
1393 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo4x3);
1394 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo4x3);
1395 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo4x3);
1396 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo4x3);
1397 armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4);
1398 armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4);
1399 armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4);
1400 armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4);
1401 armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4);
1402 armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo4);
1403 armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo4);
1404 armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo3x4);
1405 armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo3);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001406
James Conroy1f58f032021-04-27 17:13:27 +01001407 armnn::ScopedTensorHandle inputLayerNormWeightsTensor(tensorInfo4);
1408 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(tensorInfo4);
1409 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(tensorInfo4);
1410 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(tensorInfo4);
Jan Eilers38e05bd2019-06-26 13:10:09 +01001411
Sadik Armagan483c8112021-06-01 09:24:52 +01001412 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
1413 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1414 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1415 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
1416 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
1417 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1418 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1419 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
1420 AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data());
1421 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
1422 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1423 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1424 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
1425 AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data());
1426 AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data());
1427 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
1428 AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001429
Sadik Armagan483c8112021-06-01 09:24:52 +01001430 AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data());
1431 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
1432 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
1433 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001434
1435 data.m_InputToInputWeights = &inputToInputWeightsTensor;
1436 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1437 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1438 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1439 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
1440 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1441 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1442 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1443 data.m_CellToInputWeights = &cellToInputWeightsTensor;
1444 data.m_InputGateBias = &inputGateBiasTensor;
1445 data.m_ForgetGateBias = &forgetGateBiasTensor;
1446 data.m_CellBias = &cellBiasTensor;
1447 data.m_OutputGateBias = &outputGateBiasTensor;
1448 data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
1449 data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
1450 data.m_ProjectionWeights = &projectionWeightsTensor;
1451 data.m_ProjectionBias = &projectionBiasTensor;
1452
1453 data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
1454 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
1455 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
1456 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
1457
1458 // Flags to set test configuration
1459 data.m_Parameters.m_ActivationFunc = 4;
1460 data.m_Parameters.m_CifgEnabled = false;
1461 data.m_Parameters.m_PeepholeEnabled = true;
1462 data.m_Parameters.m_ProjectionEnabled = true;
1463 data.m_Parameters.m_LayerNormEnabled = true;
1464
1465
1466 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info);
1467 inputHandle->Allocate();
1468 outputStateInHandle->Allocate();
1469 cellStateInHandle->Allocate();
1470
1471 scratchHandle->Allocate();
1472 outputStateOutHandle->Allocate();
1473 cellStateOutHandle->Allocate();
1474 outputHandle->Allocate();
1475
Sadik Armagan483c8112021-06-01 09:24:52 +01001476 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1477 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1478 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001479
1480 workload->Execute();
1481
Sadik Armagan483c8112021-06-01 09:24:52 +01001482 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
Jan Eilers38e05bd2019-06-26 13:10:09 +01001483
Sadik Armagan483c8112021-06-01 09:24:52 +01001484 return LayerTestResult<T, 2>(actualOutput,
1485 outputVector,
1486 outputHandle->GetShape(),
1487 outputTensorInfo.GetShape());
James Conroy9c3cae82019-08-01 16:01:48 +01001488}
1489
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001490LayerTestResult<uint8_t, 2> QuantizedLstmTestImpl(
1491 armnn::IWorkloadFactory& workloadFactory,
1492 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001493 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001494 const std::vector<uint8_t>& input,
1495 const std::vector<uint8_t>& outputExpected,
1496 const armnn::TensorShape& inputShape,
1497 const armnn::TensorShape& outputExpectedShape)
James Conroy9c3cae82019-08-01 16:01:48 +01001498{
Jan Eilers8eb25602020-03-09 12:13:48 +00001499 IgnoreUnused(memoryManager);
Sadik Armagan483c8112021-06-01 09:24:52 +01001500 auto numBatches = armnn::numeric_cast<unsigned int>(inputShape[0]);
1501 auto inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]);
1502 auto outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]);
James Conroy9c3cae82019-08-01 16:01:48 +01001503
1504 // Scale/Offset for input/output, cellState In/Out, weights, bias
1505 float inputOutputScale = 0.0078125f;
1506 int32_t inputOutputOffset = 128;
1507
1508 float cellStateScale = 0.00048828125f;
1509 int32_t cellStateOffset = 0;
1510
1511 float weightsScale = 0.00408021f;
1512 int32_t weightsOffset = 100;
1513
1514 float biasScale = 3.1876640625e-05f;
1515 int32_t biasOffset = 0;
1516
1517 // Input/Output tensor info
1518 armnn::TensorInfo inputInfo({numBatches , inputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001519 armnn::DataType::QAsymmU8,
James Conroy9c3cae82019-08-01 16:01:48 +01001520 inputOutputScale,
1521 inputOutputOffset);
1522
1523 armnn::TensorInfo cellStateInfo({numBatches , outputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001524 armnn::DataType::QSymmS16,
James Conroy9c3cae82019-08-01 16:01:48 +01001525 cellStateScale,
1526 cellStateOffset);
1527
1528 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001529 armnn::DataType::QAsymmU8,
James Conroy9c3cae82019-08-01 16:01:48 +01001530 inputOutputScale,
1531 inputOutputOffset);
1532
James Conroy9c3cae82019-08-01 16:01:48 +01001533 // Input0
1534 std::vector<uint8_t> inputVector;
1535 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroy9c3cae82019-08-01 16:01:48 +01001536
1537 // Input1
1538 std::vector<int16_t> cellStateInVector = {876, 1034, 955, -909, 761, 1029, 796, -1036}; // 13
James Conroy9c3cae82019-08-01 16:01:48 +01001539 // Input2
1540 std::vector<uint8_t> outputStateInVector = {136, 150, 140, 115, 135, 152, 138, 112}; // 14
James Conroy9c3cae82019-08-01 16:01:48 +01001541
1542 // Output0
1543 std::vector<int16_t> cellStateOutVector = {1485, 1177, 1373, -1023, 1019, 1355, 1097, -1235}; // 0
James Conroy9c3cae82019-08-01 16:01:48 +01001544
1545 // Output1
1546 std::vector<uint8_t> outputVector; // 1
1547 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01001548
1549 std::vector<uint8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroy9c3cae82019-08-01 16:01:48 +01001550
1551 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01001552 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001553 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001554 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001555 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001556 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001557
1558 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001559 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1560 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001561
1562 armnn::QuantizedLstmQueueDescriptor data;
1563 armnn::WorkloadInfo info;
1564
1565 // Add inputs and outputs to workload
1566 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1567 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1568 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1569
1570 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1571 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
1572
1573 // Weights and bias tensor and quantization info
1574 armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001575 armnn::DataType::QAsymmU8,
James Conroy9c3cae82019-08-01 16:01:48 +01001576 weightsScale,
1577 weightsOffset);
1578
1579 armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
Derek Lambertif90c56d2020-01-10 17:14:08 +00001580 armnn::DataType::QAsymmU8,
James Conroy9c3cae82019-08-01 16:01:48 +01001581 weightsScale,
1582 weightsOffset);
1583
1584 armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
1585
1586 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01001587 std::vector<uint8_t> inputToInputWeights = {146, 250, 235, 171, 10, 218, 171, 108};
1588 std::vector<uint8_t> inputToForgetWeights = {24, 50, 132, 179, 158, 110, 3, 169};
1589 std::vector<uint8_t> inputToCellWeights = {133, 34, 29, 49, 206, 109, 54, 183};
1590 std::vector<uint8_t> inputToOutputWeights = {195, 187, 11, 99, 109, 10, 218, 48};
James Conroy9c3cae82019-08-01 16:01:48 +01001591
Sadik Armagan483c8112021-06-01 09:24:52 +01001592 std::vector<uint8_t> recurrentToInputWeights =
1593 {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26};
1594 std::vector<uint8_t> recurrentToForgetWeights =
1595 {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253};
1596 std::vector<uint8_t> recurrentToCellWeights =
1597 {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216};
1598 std::vector<uint8_t> recurrentToOutputWeights =
1599 {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98};
James Conroy9c3cae82019-08-01 16:01:48 +01001600
Sadik Armagan483c8112021-06-01 09:24:52 +01001601 std::vector<int32_t> inputGateBias = {-7876, 13488, -726, 32839};
1602 std::vector<int32_t> forgetGateBias = {9206, -46884, -11693, -38724};
1603 std::vector<int32_t> cellBias = {39481, 48624, 48976, -21419};
1604 std::vector<int32_t> outputGateBias = {-58999, -17050, -41852, -40538};
James Conroy9c3cae82019-08-01 16:01:48 +01001605
James Conroy1f58f032021-04-27 17:13:27 +01001606 // ScopedTensorHandles
1607 armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo);
1608 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
1609 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
1610 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001611
James Conroy1f58f032021-04-27 17:13:27 +01001612 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo);
1613 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
1614 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
1615 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001616
James Conroy1f58f032021-04-27 17:13:27 +01001617 armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo);
1618 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
1619 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
1620 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroy9c3cae82019-08-01 16:01:48 +01001621
1622 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01001623 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
1624 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1625 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1626 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroy9c3cae82019-08-01 16:01:48 +01001627
Sadik Armagan483c8112021-06-01 09:24:52 +01001628 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
1629 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1630 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1631 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroy9c3cae82019-08-01 16:01:48 +01001632
Sadik Armagan483c8112021-06-01 09:24:52 +01001633 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
1634 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1635 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1636 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroy9c3cae82019-08-01 16:01:48 +01001637
1638 // Setup queue descriptor
1639 data.m_InputToInputWeights = &inputToInputWeightsTensor;
1640 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1641 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1642 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1643
1644 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
1645 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1646 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1647 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1648
1649 data.m_InputGateBias = &inputGateBiasTensor;
1650 data.m_ForgetGateBias = &forgetGateBiasTensor;
1651 data.m_CellBias = &cellBiasTensor;
1652 data.m_OutputGateBias = &outputGateBiasTensor;
1653
1654 // Create workload and allocate tensor handles
1655 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQuantizedLstm(data, info);
1656 inputHandle->Allocate();
1657 outputStateInHandle->Allocate();
1658 cellStateInHandle->Allocate();
1659
1660 cellStateOutHandle->Allocate();
1661 outputHandle->Allocate();
1662
Sadik Armagan483c8112021-06-01 09:24:52 +01001663 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1664 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1665 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroy9c3cae82019-08-01 16:01:48 +01001666
1667 workload->Execute();
1668
Sadik Armagan483c8112021-06-01 09:24:52 +01001669 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroy9c3cae82019-08-01 16:01:48 +01001670
Sadik Armagan483c8112021-06-01 09:24:52 +01001671 return LayerTestResult<uint8_t, 2>(actualOutput,
1672 outputVector,
1673 outputHandle->GetShape(),
1674 outputStateInfo.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01001675}
1676
James Conroyb22a75e2020-06-08 14:53:10 +01001677// QLSTM: CIFG, LayerNorm
James Conroy4f1f8992020-04-29 20:01:10 +01001678LayerTestResult<int8_t, 2> QLstmTestImpl(
1679 armnn::IWorkloadFactory& workloadFactory,
1680 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001681 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001682 const std::vector<int8_t>& input,
1683 const std::vector<int8_t>& outputExpected)
James Conroy4f1f8992020-04-29 20:01:10 +01001684{
1685 IgnoreUnused(memoryManager);
1686 unsigned int numBatches = 2;
1687 unsigned int inputSize = 5;
1688 unsigned int outputSize = 4;
1689 unsigned int numUnits = 4;
1690
1691 bool cifgEnabled = true;
1692 bool peepholeEnabled = false;
1693 bool projectionEnabled = false;
1694 bool layerNormEnabled = true;
1695
1696 // Scale/Offset quantization info
1697 float inputScale = 0.0078125f;
1698 int32_t inputOffset = 0;
1699
1700 int32_t hiddenStateZeroPoint = 0;
1701 float hiddenStateScale = 0.007f;
1702
1703 // if (!projectionEnabled) outputScale == hiddenStateScale
1704 float outputScale = hiddenStateScale;
1705 int32_t outputOffset = hiddenStateZeroPoint;
1706
1707 float cellStateScale = 3.05176e-05f;
1708 int32_t cellStateOffset = 0;
1709
1710 float weightsScale = 0.00784314f;
1711 int32_t weightsOffset = 0;
1712
1713 float layerNormScale = 3.05182e-05f;
1714 int32_t layerNormOffset = 0;
1715
1716 float biasScale = layerNormScale / 1024;
1717 int32_t biasOffset = 0;
1718
1719 float inputIntermediateScale = 0.007059f;
1720 float forgetIntermediateScale = 0.007812f;
1721 float cellIntermediateScale = inputIntermediateScale;
1722 float outputIntermediateScale = forgetIntermediateScale;
1723
1724 float cellClip = 0.0f;
1725 float projectionClip = 0.0f;
1726
1727 // Input/Output tensor info
1728 armnn::TensorInfo inputInfo({numBatches , inputSize},
1729 armnn::DataType::QAsymmS8,
1730 inputScale,
1731 inputOffset);
1732
1733 armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1734 armnn::DataType::QSymmS16,
1735 cellStateScale,
1736 cellStateOffset);
1737
1738 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1739 armnn::DataType::QAsymmS8,
1740 outputScale,
1741 outputOffset);
1742
1743 LayerTestResult<int8_t, 2> ret(outputStateInfo);
1744
1745 // Input tensors
1746 std::vector<int8_t> inputVector;
1747 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroy4f1f8992020-04-29 20:01:10 +01001748
1749 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroy4f1f8992020-04-29 20:01:10 +01001750
Teresa Charlinbe727be2020-09-25 15:08:21 +01001751 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroy4f1f8992020-04-29 20:01:10 +01001752
1753 // Output tensors
Sadik Armagan483c8112021-06-01 09:24:52 +01001754 std::vector<int16_t> cellStateOutVector = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149};
James Conroy4f1f8992020-04-29 20:01:10 +01001755
1756 std::vector<int8_t> outputVector;
1757 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01001758
1759 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroy4f1f8992020-04-29 20:01:10 +01001760
1761 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01001762 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001763 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001764 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001765 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001766 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001767
Finn Williamsc43de6a2020-08-27 11:13:25 +01001768 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
1769 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001770 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01001771 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
1772 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001773
1774 armnn::QLstmQueueDescriptor data;
1775 armnn::WorkloadInfo info;
1776
1777 // Add inputs and outputs to workload
1778 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
1779 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
1780 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
1781
1782 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
1783 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
1784 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
1785
1786 // Weights and bias tensor and quantization info
1787 armnn::TensorInfo inputWeightsInfo({outputSize, inputSize},
1788 armnn::DataType::QSymmS8,
1789 weightsScale,
1790 weightsOffset);
1791
1792 armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize},
1793 armnn::DataType::QSymmS8,
1794 weightsScale,
1795 weightsOffset);
1796
1797 armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset);
1798
1799 armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
1800
1801 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01001802 std::vector<int8_t> inputToForgetWeights =
1803 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
1804 std::vector<int8_t> inputToCellWeights =
1805 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
1806 std::vector<int8_t> inputToOutputWeights =
1807 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
James Conroy4f1f8992020-04-29 20:01:10 +01001808
Sadik Armagan483c8112021-06-01 09:24:52 +01001809 std::vector<int8_t> recurrentToForgetWeights =
1810 {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51};
1811 std::vector<int8_t> recurrentToCellWeights =
1812 {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64};
1813 std::vector<int8_t> recurrentToOutputWeights =
1814 {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38};
James Conroy4f1f8992020-04-29 20:01:10 +01001815
Sadik Armagan483c8112021-06-01 09:24:52 +01001816 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
1817 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
1818 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
James Conroy4f1f8992020-04-29 20:01:10 +01001819
Sadik Armagan483c8112021-06-01 09:24:52 +01001820 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
1821 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
1822 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
James Conroy4f1f8992020-04-29 20:01:10 +01001823
James Conroy1f58f032021-04-27 17:13:27 +01001824 // ScopedTensorHandles
1825 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
1826 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
1827 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001828
James Conroy1f58f032021-04-27 17:13:27 +01001829 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
1830 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
1831 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001832
James Conroy1f58f032021-04-27 17:13:27 +01001833 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
1834 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
1835 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001836
James Conroy1f58f032021-04-27 17:13:27 +01001837 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
1838 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
1839 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
James Conroy4f1f8992020-04-29 20:01:10 +01001840
1841 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01001842 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
1843 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
1844 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001845
Sadik Armagan483c8112021-06-01 09:24:52 +01001846 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
1847 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
1848 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001849
Sadik Armagan483c8112021-06-01 09:24:52 +01001850 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
1851 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
1852 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001853
Sadik Armagan483c8112021-06-01 09:24:52 +01001854 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
1855 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
1856 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001857
1858 // Setup queue descriptor
1859 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
1860 data.m_InputToCellWeights = &inputToCellWeightsTensor;
1861 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
1862
1863 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
1864 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
1865 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
1866
1867 data.m_ForgetGateBias = &forgetGateBiasTensor;
1868 data.m_CellBias = &cellBiasTensor;
1869 data.m_OutputGateBias = &outputGateBiasTensor;
1870
1871 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
1872 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
1873 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
1874
1875 data.m_Parameters.m_CifgEnabled = cifgEnabled;
1876 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
1877 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
1878 data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
1879
1880 data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
1881 data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
1882 data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
1883 data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
1884
1885 data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
1886 data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
1887
1888 data.m_Parameters.m_CellClip = cellClip;
1889 data.m_Parameters.m_ProjectionClip = projectionClip;
1890
1891 // Create workload and allocate tensor handles
1892 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
1893 inputHandle->Allocate();
1894 outputStateInHandle->Allocate();
1895 cellStateInHandle->Allocate();
1896
1897 outputStateOutHandle->Allocate();
1898 cellStateOutHandle->Allocate();
1899 outputHandle->Allocate();
1900
Sadik Armagan483c8112021-06-01 09:24:52 +01001901 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
1902 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
1903 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroy4f1f8992020-04-29 20:01:10 +01001904
1905 workload->Execute();
1906
Sadik Armagan483c8112021-06-01 09:24:52 +01001907 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroy4f1f8992020-04-29 20:01:10 +01001908
Sadik Armagan483c8112021-06-01 09:24:52 +01001909 return LayerTestResult<int8_t, 2>(actualOutput,
1910 outputVector,
1911 outputHandle->GetShape(),
1912 outputStateInfo.GetShape());
James Conroy4f1f8992020-04-29 20:01:10 +01001913}
1914
James Conroyb22a75e2020-06-08 14:53:10 +01001915// QLSTM: Projection, LayerNorm
1916LayerTestResult<int8_t, 2> QLstmTestImpl1(
1917 armnn::IWorkloadFactory& workloadFactory,
1918 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
Finn Williamsc43de6a2020-08-27 11:13:25 +01001919 const armnn::ITensorHandleFactory& tensorHandleFactory,
Sadik Armagan483c8112021-06-01 09:24:52 +01001920 const std::vector<int8_t>& input,
1921 const std::vector<int8_t>& outputExpected)
James Conroyb22a75e2020-06-08 14:53:10 +01001922{
1923 IgnoreUnused(memoryManager);
1924 unsigned int numBatches = 2;
1925 unsigned int inputSize = 5;
1926 unsigned int outputSize = 3;
1927 unsigned int numUnits = 4;
1928
1929 bool cifgEnabled = false;
1930 bool peepholeEnabled = false;
1931 bool projectionEnabled = true;
1932 bool layerNormEnabled = true;
1933
1934 // Scale/Offset quantization info
1935 float inputScale = 0.0078125f;
1936 int32_t inputOffset = 0;
1937
1938 int32_t hiddenStateZeroPoint = 0;
1939 float hiddenStateScale = 0.007f;
1940
1941 // if (!projectionEnabled) outputScale == hiddenStateScale
1942 float outputScale = 3.05176e-05f;
1943 int32_t outputOffset = 0;
1944
1945 float cellStateScale = 3.05176e-05f;
1946 int32_t cellStateOffset = 0;
1947
1948 float weightsScale = 0.00784314f;
1949 int32_t weightsOffset = 0;
1950
1951 float layerNormScale = 3.05182e-05f;
1952 int32_t layerNormOffset = 0;
1953
1954 float biasScale = layerNormScale / 1024;
1955 int32_t biasOffset = 0;
1956
1957 float projectionWeightsScale = 0.00392157f;
1958
1959 float inputIntermediateScale = 0.007059f;
1960 float forgetIntermediateScale = 0.007812f;
1961 float cellIntermediateScale = inputIntermediateScale;
1962 float outputIntermediateScale = forgetIntermediateScale;
1963
1964 float cellClip = 0.0f;
1965 float projectionClip = 0.0f;
1966
1967 // Input/Output tensor info
1968 armnn::TensorInfo inputInfo({numBatches , inputSize},
1969 armnn::DataType::QAsymmS8,
1970 inputScale,
1971 inputOffset);
1972
1973 armnn::TensorInfo cellStateInfo({numBatches , numUnits},
1974 armnn::DataType::QSymmS16,
1975 cellStateScale,
1976 cellStateOffset);
1977
1978 armnn::TensorInfo outputStateInfo({numBatches , outputSize},
1979 armnn::DataType::QAsymmS8,
1980 outputScale,
1981 outputOffset);
1982
James Conroyb22a75e2020-06-08 14:53:10 +01001983 // Input tensors
1984 std::vector<int8_t> inputVector;
1985 inputVector.assign(input.data(), input.data() + (numBatches * inputSize));
James Conroyb22a75e2020-06-08 14:53:10 +01001986
1987 std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};
James Conroyb22a75e2020-06-08 14:53:10 +01001988
1989 std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};
James Conroyb22a75e2020-06-08 14:53:10 +01001990
1991 // Output tensors
1992 std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};
James Conroyb22a75e2020-06-08 14:53:10 +01001993
1994 std::vector<int8_t> outputVector;
1995 outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));
Sadik Armagan483c8112021-06-01 09:24:52 +01001996
1997 std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());
James Conroyb22a75e2020-06-08 14:53:10 +01001998
1999 // Create tensor handles
Finn Williamsc43de6a2020-08-27 11:13:25 +01002000 std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002001 std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002002 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002003 std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002004 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002005
Finn Williamsc43de6a2020-08-27 11:13:25 +01002006 std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
2007 tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002008 std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
Finn Williamsc43de6a2020-08-27 11:13:25 +01002009 tensorHandleFactory.CreateTensorHandle(cellStateInfo);
2010 std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002011
2012 armnn::QLstmQueueDescriptor data;
2013 armnn::WorkloadInfo info;
2014
2015 // Add inputs and outputs to workload
2016 AddInputToWorkload(data, info, inputInfo, inputHandle.get());
2017 AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
2018 AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());
2019
2020 AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
2021 AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
2022 AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());
2023
2024 // Weights and bias tensor and quantization info
2025 armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
2026 armnn::DataType::QSymmS8,
2027 weightsScale,
2028 weightsOffset);
2029
2030 armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
2031 armnn::DataType::QSymmS8,
2032 weightsScale,
2033 weightsOffset);
2034
2035 armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);
2036
2037 armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);
2038
2039 armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
2040 armnn::DataType::QSymmS8,
2041 projectionWeightsScale,
2042 0);
2043
2044 // Weights and bias tensor data
Sadik Armagan483c8112021-06-01 09:24:52 +01002045 std::vector<int8_t> inputToInputWeights =
2046 {64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13};
2047 std::vector<int8_t> inputToForgetWeights =
2048 {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
2049 std::vector<int8_t> inputToCellWeights =
2050 {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
2051 std::vector<int8_t> inputToOutputWeights =
2052 {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};
James Conroyb22a75e2020-06-08 14:53:10 +01002053
Sadik Armagan483c8112021-06-01 09:24:52 +01002054 std::vector<int8_t> recurrentToInputWeights = {-25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77};
2055 std::vector<int8_t> recurrentToForgetWeights = {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
2056 std::vector<int8_t> recurrentToCellWeights = {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
2057 std::vector<int8_t> recurrentToOutputWeights = {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};
James Conroyb22a75e2020-06-08 14:53:10 +01002058
Sadik Armagan483c8112021-06-01 09:24:52 +01002059 std::vector<int32_t> inputGateBias = {644245, 3221226, 4724464, 8160438};
2060 std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
2061 std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
2062 std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};
James Conroyb22a75e2020-06-08 14:53:10 +01002063
Sadik Armagan483c8112021-06-01 09:24:52 +01002064 std::vector<int16_t> inputLayerNormWeights = {3277, 6553, 9830, 16384};
2065 std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
2066 std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
2067 std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};
James Conroyb22a75e2020-06-08 14:53:10 +01002068
Sadik Armagan483c8112021-06-01 09:24:52 +01002069 std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};
James Conroyb22a75e2020-06-08 14:53:10 +01002070
James Conroy1f58f032021-04-27 17:13:27 +01002071 // ScopedTensorHandles
2072 armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo);
2073 armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
2074 armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
2075 armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002076
James Conroy1f58f032021-04-27 17:13:27 +01002077 armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo);
2078 armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
2079 armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
2080 armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002081
James Conroy1f58f032021-04-27 17:13:27 +01002082 armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo);
2083 armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
2084 armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
2085 armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002086
James Conroy1f58f032021-04-27 17:13:27 +01002087 armnn::ScopedTensorHandle inputLayerNormWeightsTensor(layerNormWeightsInfo);
2088 armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
2089 armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
2090 armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002091
James Conroy1f58f032021-04-27 17:13:27 +01002092 armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo);
James Conroyb22a75e2020-06-08 14:53:10 +01002093
2094 // Allocate and copy data
Sadik Armagan483c8112021-06-01 09:24:52 +01002095 AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data());
2096 AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
2097 AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
2098 AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002099
Sadik Armagan483c8112021-06-01 09:24:52 +01002100 AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data());
2101 AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
2102 AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
2103 AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002104
Sadik Armagan483c8112021-06-01 09:24:52 +01002105 AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data());
2106 AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
2107 AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
2108 AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002109
Sadik Armagan483c8112021-06-01 09:24:52 +01002110 AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data());
2111 AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
2112 AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
2113 AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002114
Sadik Armagan483c8112021-06-01 09:24:52 +01002115 AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002116
2117 // Setup queue descriptor
2118 data.m_InputToInputWeights = &inputToInputWeightsTensor;
2119 data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
2120 data.m_InputToCellWeights = &inputToCellWeightsTensor;
2121 data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
2122
2123 data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
2124 data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
2125 data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
2126 data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
2127
2128 data.m_InputGateBias = &inputGateBiasTensor;
2129 data.m_ForgetGateBias = &forgetGateBiasTensor;
2130 data.m_CellBias = &cellBiasTensor;
2131 data.m_OutputGateBias = &outputGateBiasTensor;
2132
2133 data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
2134 data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
2135 data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
2136 data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
2137
2138 data.m_ProjectionWeights = &projectionWeightsTensor;
2139
2140 data.m_Parameters.m_CifgEnabled = cifgEnabled;
2141 data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
2142 data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
2143 data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;
2144
2145 data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
2146 data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
2147 data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
2148 data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;
2149
2150 data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
2151 data.m_Parameters.m_HiddenStateScale = hiddenStateScale;
2152
2153 data.m_Parameters.m_CellClip = cellClip;
2154 data.m_Parameters.m_ProjectionClip = projectionClip;
2155
2156 // Create workload and allocate tensor handles
2157 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
2158 inputHandle->Allocate();
2159 outputStateInHandle->Allocate();
2160 cellStateInHandle->Allocate();
2161
2162 outputStateOutHandle->Allocate();
2163 cellStateOutHandle->Allocate();
2164 outputHandle->Allocate();
2165
Sadik Armagan483c8112021-06-01 09:24:52 +01002166 CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
2167 CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
2168 CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());
James Conroyb22a75e2020-06-08 14:53:10 +01002169
2170 workload->Execute();
2171
Sadik Armagan483c8112021-06-01 09:24:52 +01002172 CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
James Conroyb22a75e2020-06-08 14:53:10 +01002173
Sadik Armagan483c8112021-06-01 09:24:52 +01002174 return LayerTestResult<int8_t, 2>(actualOutput,
2175 outputVector,
2176 outputHandle->GetShape(),
2177 outputStateInfo.GetShape());
James Conroyb22a75e2020-06-08 14:53:10 +01002178}
2179
// QLSTM: Projection, CIFG, LayerNorm
//
// Builds and executes a single QLSTM (quantized 8-bit LSTM) workload with
// CIFG (coupled input/forget gate - so no input-gate weights/bias/norm
// tensors), layer normalisation and a projection layer enabled, then copies
// back the projected output for comparison against outputExpected.
// Only the 'output' tensor is validated by the returned LayerTestResult;
// outputStateOut and cellStateOut are produced but not checked here.
//
// input:          numBatches * inputSize QAsymmS8 input values.
// outputExpected: numBatches * outputSize QAsymmS8 expected output values.
LayerTestResult<int8_t, 2> QLstmTestImpl2(
    armnn::IWorkloadFactory& workloadFactory,
    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
    const armnn::ITensorHandleFactory& tensorHandleFactory,
    const std::vector<int8_t>& input,
    const std::vector<int8_t>& outputExpected)
{
    IgnoreUnused(memoryManager);

    // Network dimensions.
    unsigned int numBatches = 2;
    unsigned int inputSize = 5;
    unsigned int outputSize = 3;
    unsigned int numUnits = 4;

    // Descriptor flags: CIFG + projection + layer norm, no peephole.
    bool cifgEnabled = true;
    bool peepholeEnabled = false;
    bool projectionEnabled = true;
    bool layerNormEnabled = true;

    // Scale/Offset quantization info
    float inputScale = 0.0078125f;
    int32_t inputOffset = 0;

    int32_t hiddenStateZeroPoint = 0;
    float hiddenStateScale = 0.007f;

    // if (!projectionEnabled) outputScale == hiddenStateScale
    float outputScale = 3.05176e-05f;
    int32_t outputOffset = 0;

    float cellStateScale = 3.05176e-05f;
    int32_t cellStateOffset = 0;

    float weightsScale = 0.00784314f;
    int32_t weightsOffset = 0;

    float layerNormScale = 3.05182e-05f;
    int32_t layerNormOffset = 0;

    // Bias scale is derived from the layer-norm scale (biasScale = layerNormScale / 1024).
    float biasScale = layerNormScale / 1024;
    int32_t biasOffset = 0;

    float projectionWeightsScale = 0.00392157f;

    // Gate intermediate tensor scales (cell reuses input, output reuses forget).
    float inputIntermediateScale = 0.007059f;
    float forgetIntermediateScale = 0.007812f;
    float cellIntermediateScale = inputIntermediateScale;
    float outputIntermediateScale = forgetIntermediateScale;

    // 0.0f means clipping is disabled.
    float cellClip = 0.0f;
    float projectionClip = 0.0f;

    // Input/Output tensor info
    armnn::TensorInfo inputInfo({numBatches , inputSize},
                                armnn::DataType::QAsymmS8,
                                inputScale,
                                inputOffset);

    armnn::TensorInfo cellStateInfo({numBatches , numUnits},
                                    armnn::DataType::QSymmS16,
                                    cellStateScale,
                                    cellStateOffset);

    armnn::TensorInfo outputStateInfo({numBatches , outputSize},
                                      armnn::DataType::QAsymmS8,
                                      outputScale,
                                      outputOffset);

    // Input tensors: states start from zero, input copied from the caller.
    std::vector<int8_t> inputVector;
    inputVector.assign(input.data(), input.data() + (numBatches * inputSize));

    std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0};

    std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0};

    // Output tensors
    // NOTE(review): cellStateOutVector records the expected final cell state but is
    // never compared against in this function - only 'output' is validated.
    std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939};

    std::vector<int8_t> outputVector;
    outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize));

    std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements());

    // Create tensor handles
    std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateInHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateInfo);
    std::unique_ptr<armnn::ITensorHandle> outputStateInHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateInfo);

    std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(outputStateInfo);
    std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle =
            tensorHandleFactory.CreateTensorHandle(cellStateInfo);
    std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo);

    armnn::QLstmQueueDescriptor data;
    armnn::WorkloadInfo info;

    // Add inputs and outputs to workload
    // (order matters: input, outputStateIn, cellStateIn / outputStateOut, cellStateOut, output)
    AddInputToWorkload(data, info, inputInfo, inputHandle.get());
    AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get());
    AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get());

    AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get());
    AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get());
    AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get());

    // Weights and bias tensor and quantization info
    armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
                                       armnn::DataType::QSymmS8,
                                       weightsScale,
                                       weightsOffset);

    armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
                                           armnn::DataType::QSymmS8,
                                           weightsScale,
                                           weightsOffset);

    armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset);

    armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset);

    armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
                                            armnn::DataType::QSymmS8,
                                            projectionWeightsScale,
                                            0);

    // Weights and bias tensor data
    // CIFG is enabled, so no inputToInput/recurrentToInput weights, no input
    // gate bias and no input layer-norm weights are supplied.
    std::vector<int8_t> inputToForgetWeights =
            {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64};
    std::vector<int8_t> inputToCellWeights =
            {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77};
    std::vector<int8_t> inputToOutputWeights =
            {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51};

    std::vector<int8_t> recurrentToForgetWeights =
            {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25};
    std::vector<int8_t> recurrentToCellWeights =
            {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25};
    std::vector<int8_t> recurrentToOutputWeights =
            {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25};

    std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484};
    std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987};
    std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484};

    std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830};
    std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214};
    std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384};

    std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51};

    // ScopedTensorHandles own the constant (weight/bias) tensor storage for the
    // lifetime of the workload execution below.
    armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo);
    armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo);
    armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo);

    armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo);
    armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo);
    armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo);

    armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo);
    armnn::ScopedTensorHandle cellBiasTensor(biasInfo);
    armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo);

    armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo);
    armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo);
    armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo);

    armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo);

    // Allocate and copy data
    AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data());

    AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data());
    AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data());

    AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data());
    AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data());
    AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data());

    AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data());
    AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data());
    AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data());

    AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data());

    // Setup queue descriptor
    data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
    data.m_InputToCellWeights = &inputToCellWeightsTensor;
    data.m_InputToOutputWeights = &inputToOutputWeightsTensor;

    data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
    data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
    data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;

    data.m_ForgetGateBias = &forgetGateBiasTensor;
    data.m_CellBias = &cellBiasTensor;
    data.m_OutputGateBias = &outputGateBiasTensor;

    data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
    data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
    data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;

    data.m_ProjectionWeights = &projectionWeightsTensor;

    data.m_Parameters.m_CifgEnabled = cifgEnabled;
    data.m_Parameters.m_PeepholeEnabled = peepholeEnabled;
    data.m_Parameters.m_ProjectionEnabled = projectionEnabled;
    data.m_Parameters.m_LayerNormEnabled = layerNormEnabled;

    data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale;
    data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale;
    data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale;
    data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale;

    data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint;
    data.m_Parameters.m_HiddenStateScale = hiddenStateScale;

    data.m_Parameters.m_CellClip = cellClip;
    data.m_Parameters.m_ProjectionClip = projectionClip;

    // Create workload and allocate tensor handles
    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info);
    inputHandle->Allocate();
    outputStateInHandle->Allocate();
    cellStateInHandle->Allocate();

    outputStateOutHandle->Allocate();
    cellStateOutHandle->Allocate();
    outputHandle->Allocate();

    CopyDataToITensorHandle(inputHandle.get(), inputVector.data());
    CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data());
    CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data());

    workload->Execute();

    CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());

    // Package actual vs expected output (with shapes) for the caller to compare.
    return LayerTestResult<int8_t, 2>(actualOutput,
                                      outputVector,
                                      outputHandle->GetShape(),
                                      outputStateInfo.GetShape());
}
2430
James Conroy4f1f8992020-04-29 20:01:10 +01002431
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002432} // anonymous namespace
2433
2434#if defined(ARMNNREF_ENABLED)
2435
2436// The LSTM test units are run only for the reference backend at the moment
2437
2438void LstmUtilsZeroVectorTest()
2439{
2440 armnn::TensorInfo inputDesc({4}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002441 std::vector<float> input = {2., 3., 3., 4.};
2442 std::vector<float> expectedOutput = {0., 0., 0., 0.};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002443
Sadik Armagan483c8112021-06-01 09:24:52 +01002444 return LstmUtilsZeroVectorTestImpl<armnn::DataType::Float32>(input, 4, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002445}
2446
2447void LstmUtilsMeanStddevNormalizationNoneZeroInputTest()
2448{
2449 uint32_t batchSize = 2;
2450 uint32_t vecSize = 4;
2451 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002452 std::vector<float> input =
2453 { 0.1f, 0.2f, 0.3f, 0.4f, //batch 0
2454 0.9f, 1.0f, 1.1f, 1.2f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002455
Sadik Armagan483c8112021-06-01 09:24:52 +01002456 std::vector<float> expectedOutput =
2457 { -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f, //batch 0
2458 -1.34163153f, -0.447210163f, 0.447211236f, 1.3416326f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002459
2460 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002461 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002462}
2463
2464void LstmUtilsMeanStddevNormalizationAllZeroInputTest()
2465{
2466 uint32_t batchSize = 2;
2467 uint32_t vecSize = 4;
2468 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002469 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002470 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002471 0.0f, 0.0f, 0.0f, 0.0f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002472
Sadik Armagan483c8112021-06-01 09:24:52 +01002473 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002474 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002475 0.0f, 0.0f, 0.0f, 0.0f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002476
2477 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002478 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002479}
2480
2481void LstmUtilsMeanStddevNormalizationMixedZeroInputTest()
2482{
2483 uint32_t batchSize = 2;
2484 uint32_t vecSize = 4;
2485 armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002486 std::vector<float> input =
2487 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
2488 0.1f, 0.2f, 0.3f, 0.4f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002489
Sadik Armagan483c8112021-06-01 09:24:52 +01002490 std::vector<float> expectedOutput =
2491 { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0
2492 -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f }; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002493
2494 return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input,
Sadik Armagan483c8112021-06-01 09:24:52 +01002495 vecSize, batchSize, expectedOutput, inputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002496}
2497
// Checks LstmUtils' VectorBatchVectorCwiseProduct: every row of a
// [batchSize, vecSize] batch vector is multiplied element-wise by a shared
// length-vecSize vector (no accumulation).
void LstmUtilsVectorBatchVectorCwiseProductTest()
{
    uint32_t batchSize = 4;
    uint32_t vecSize = 29;
    armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32);
    std::vector<float> vector =
    { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
      11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
      21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f};

    // NOTE(review): batchVecDesc is declared but never used - only
    // vecDesc.GetShape() is passed to the Impl below; confirm that is intended.
    armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32);
    std::vector<float> batchVector =
    { /* batch 0 */
        1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f,
        11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f,
        21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f,
      /* batch 1 */
        -1.1f, -2.2f, -3.3f, -4.4f, -5.5f, -6.6f, -7.7f, -8.8f, -9.9f, -10.1f,
        -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f, -19.19f, -20.2f,
        -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f, -27.27f, -28.28f, 0.0f,
      /* batch 2 */
        1.1f, -2.2f, 3.3f, -4.4f, 5.5f, -6.6f, 7.7f, -8.8f, 9.9f, -10.1f,
        11.11f, -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f, -20.2f,
        21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f, -28.28f, 0.0f,
      /* batch 3 */
        -1.1f, 2.2f, -3.3f, 4.4f, -5.5f, 6.6f, -7.7f, 8.8f, -9.9f, 10.1f,
        -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f, -19.19f, 20.2f,
        -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f, -27.27f, 28.28f, 0.0f};

    // Expected output = vector * batchVector, element-wise per batch row
    // (pure cwise product, no "+ output" accumulation: e.g. 1.1f * 1.1f == 1.21f).
    std::vector<float> expectedOutput =
    { /* batch 0 */
        1.210000f, 4.840000f, 10.889999f, 19.360001f, 30.250000f, 43.559998f,
        59.289997f, 77.440002f, 98.009995f, 102.010010f, 123.432091f, 146.894394f,
        172.396896f, 199.939606f, 229.522491f, 261.145599f, 294.808899f, 330.512421f,
        368.256134f, 408.040039f, 449.864075f, 493.728363f, 539.632874f, 587.577576f,
        637.562500f, 689.587585f, 743.652954f, 799.758423f, 0.000000f,
      /* batch 1 */
        -1.210000f, -4.840000f, -10.889999f, -19.360001f, -30.250000f, -43.559998f,
        -59.289997f, -77.440002f, -98.009995f, -102.010010f, -123.432091f, -146.894394f,
        -172.396896f, -199.939606f, -229.522491f, -261.145599f, -294.808899f, -330.512421f,
        -368.256134f, -408.040039f, -449.864075f, -493.728363f, -539.632874f, -587.577576f,
        -637.562500f, -689.587585f, -743.652954f, -799.758423f, 0.000000f,
      /* batch 2 */
        1.210000f, -4.840000f, 10.889999f, -19.360001f, 30.250000f, -43.559998f,
        59.289997f, -77.440002f, 98.009995f, -102.010010f, 123.432091f, -146.894394f,
        172.396896f, -199.939606f, 229.522491f, -261.145599f, 294.808899f, -330.512421f,
        368.256134f, -408.040039f, 449.864075f, -493.728363f, 539.632874f, -587.577576f,
        637.562500f, -689.587585f, 743.652954f, -799.758423f, 0.000000f,
      /* batch 3 */
        -1.210000f, 4.840000f, -10.889999f, 19.360001f, -30.250000f, 43.559998f,
        -59.289997f, 77.440002f, -98.009995f, 102.010010f, -123.432091f, 146.894394f,
        -172.396896f, 199.939606f, -229.522491f, 261.145599f, -294.808899f, 330.512421f,
        -368.256134f, 408.040039f, -449.864075f, 493.728363f, -539.632874f, 587.577576f,
        -637.562500f, 689.587585f, -743.652954f, 799.758423f, 0.000000f};

    return LstmUtilsVectorBatchVectorCwiseProductTestImpl<armnn::DataType::Float32>(vector, batchVector,
            vecSize, batchSize, expectedOutput, vecDesc.GetShape());
}
2557
2558void LstmUtilsVectorBatchVectorAddTest()
2559{
2560 uint32_t batchSize = 2;
2561 uint32_t vecSize = 3;
2562 armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002563 std::vector<float> vector = { 0.0f, -0.5f, 1.0f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002564
2565 armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002566 std::vector<float> batchVector =
2567 {
2568 1.0f, 2.0f, 3.0f, //batch 0
2569 4.0f, 5.0f, 6.0f //batch 1
2570 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002571
Sadik Armagan483c8112021-06-01 09:24:52 +01002572 std::vector<float> expectedOutput =
2573 {
2574 1.0f, 1.5f, 4.0f,
2575 4.0f, 4.5f, 7.0f
2576 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002577
2578 return LstmUtilsVectorBatchVectorAddTestImpl<armnn::DataType::Float32>(vector, batchVector,
Sadik Armagan483c8112021-06-01 09:24:52 +01002579 vecSize, batchSize, expectedOutput, batchVecDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002580}
2581
2582#endif
2583
2584LayerTestResult<float, 2> LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest(
2585 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002586 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2587 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002588{
2589 armnn::TensorInfo inputDesc({ 2, 2 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002590 std::vector<float> input = { 2., 3., 3., 4. };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002591
2592 armnn::TensorInfo outputDesc({ 2, 4 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002593 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002594 {-0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002595 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002596 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002597 workloadFactory, memoryManager, tensorHandleFactory,
2598 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002599}
2600
2601LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest(
2602 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002603 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2604 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002605{
2606 armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002607 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002608 {0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
Sadik Armagan483c8112021-06-01 09:24:52 +01002609 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002610
2611 armnn::TensorInfo outputDesc({ 2, 16 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002612 std::vector<float> expectedOutput =
2613 {-0.00396806f, 0.029352f, -0.00279226f, 0.0159977f, -0.00835576f,
2614 -0.0211779f, 0.0283512f, -0.0114597f, 0.00907307f, -0.0244004f,
2615 -0.0152191f, -0.0259063f, 0.00914318f, 0.00415118f, 0.017147f,
2616 0.0134203f, -0.013869f, 0.0287268f, -0.00334693f, 0.00733398f, -0.0287926f,
2617 -0.0186926f, 0.0193662f, -0.0115437f, 0.00422612f, -0.0345232f,
2618 0.00223253f, -0.00957321f, 0.0210624f, 0.013331f, 0.0150954f, 0.02168f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002619 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<armnn::DataType::Float32>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002620 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002621}
2622
2623LayerTestResult<float, 2> LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest(
2624 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002625 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2626 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002627{
2628 armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002629 std::vector<float> input = {2., 3., 3., 4.};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002630
2631 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002632 std::vector<float> expectedOutput =
2633 {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f,
2634 -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002635
2636 return LstmNoCifgNoPeepholeNoProjectionTestImpl<armnn::DataType::Float32>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002637 workloadFactory, memoryManager, tensorHandleFactory,
2638 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002639}
2640
2641LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002642 armnn::IWorkloadFactory& workloadFactory,
2643 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2644 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002645{
2646 armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002647 std::vector<float> input =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002648 {0.7f, 0.8f, 0.1f, 0.2f, 0.3f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002649 0.3f, 0.2f, 0.9f, 0.8f, 0.1f}; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002650
2651 armnn::TensorInfo outputDesc({ 2, 3 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +01002652 std::vector<float> expectedOutput =
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002653 { 0.0244077f, 0.128027f, -0.00170918f, //batch 0
Sadik Armagan483c8112021-06-01 09:24:52 +01002654 -0.00692428f, 0.0848741f, 0.063445f}; //batch 1
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002655 return LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl<armnn::DataType::Float32>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002656 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002657}
2658
2659LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionTest(
2660 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002661 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2662 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002663{
2664 const float qScale = 1.0f;
2665 const int32_t qOffset = 0;
2666
Derek Lambertif90c56d2020-01-10 17:14:08 +00002667 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2668 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002669
2670 armnn::TensorInfo inputDesc({2, 2}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002671 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002672
2673 armnn::TensorInfo outputDesc({2, 4}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002674 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2675 {
2676 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2677 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2678 },
2679 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002680
2681 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002682 workloadFactory, memoryManager, tensorHandleFactory,
2683 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2684 qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002685
2686}
2687
2688LayerTestResult<int16_t, 2> LstmLayerInt16WithCifgWithPeepholeNoProjectionTest(
2689 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002690 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2691 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002692{
2693 const float qScale = 1.0f;
2694 const int32_t qOffset = 0;
2695
Derek Lambertif90c56d2020-01-10 17:14:08 +00002696 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2697 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002698
2699 armnn::TensorInfo inputDesc({ 2, 2 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002700 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002701
2702 armnn::TensorInfo outputDesc({ 2, 4 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002703 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2704 {
2705 -0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f,
2706 -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f
2707 },
2708 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002709
2710 return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002711 workloadFactory, memoryManager, tensorHandleFactory,
2712 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2713 qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002714}
2715
2716LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgWithPeepholeWithProjectionTest(
2717 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002718 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2719 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002720{
2721 const float qScale = 2.0f;
2722 const int32_t qOffset = 0;
2723
Derek Lambertif90c56d2020-01-10 17:14:08 +00002724 const armnn::DataType datatype = armnn::DataType::QSymmS16;
2725 const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8;
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002726
2727 armnn::TensorInfo inputDesc({ 2, 5 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002728 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>(
2729 {
2730 0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f,
2731 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f
2732 },
2733 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002734
2735 armnn::TensorInfo outputDesc({ 2, 16 }, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002736 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2737 {
2738 -0.00396806f, 0.02935200f, -0.00279226f, 0.01599770f,
2739 -0.00835576f, -0.02117790f, 0.02835120f, -0.01145970f,
2740 0.00907307f, -0.02440040f, -0.01521910f, -0.02590630f,
2741 0.00914318f, 0.00415118f, 0.01714700f, 0.01342030f,
2742 -0.01386900f, 0.02872680f, -0.00334693f, 0.00733398f,
2743 -0.02879260f, -0.01869260f, 0.01936620f, -0.01154370f,
2744 0.00422612f, -0.03452320f, 0.00223253f, -0.00957321f,
2745 0.02106240f, 0.01333100f, 0.01509540f, 0.02168000f
2746 },
2747 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002748
2749 return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<datatype>(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002750 workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput, qScale, qOffset, constantDatatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002751}
2752
2753LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16ConstantTest(
2754 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002755 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2756 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002757{
2758 const float qScale = 1.0f;
2759 const int32_t qOffset = 0;
2760
Derek Lambertif90c56d2020-01-10 17:14:08 +00002761 const armnn::DataType datatype = armnn::DataType::QSymmS16; // datatype & constants set to QSymm16
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002762
2763 armnn::TensorInfo inputDesc({2, 2}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002764 std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002765
2766 armnn::TensorInfo outputDesc({2, 4}, datatype);
Sadik Armagan483c8112021-06-01 09:24:52 +01002767 std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>(
2768 {
2769 -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f,
2770 -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f
2771 },
2772 qScale, qOffset);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002773
2774 return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>(
Sadik Armagan483c8112021-06-01 09:24:52 +01002775 workloadFactory, memoryManager, tensorHandleFactory,
2776 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(),
2777 qScale, qOffset, datatype);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002778}
2779
2780//
2781// QuantizedLstm
2782//
2783
2784LayerTestResult<uint8_t, 2> QuantizedLstmTest(
2785 armnn::IWorkloadFactory& workloadFactory,
Finn Williamsc43de6a2020-08-27 11:13:25 +01002786 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2787 const armnn::ITensorHandleFactory& tensorHandleFactory)
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002788{
Derek Lambertif90c56d2020-01-10 17:14:08 +00002789 armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::QAsymmU8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002790 std::vector<uint8_t> input = {166, 179, 50, 150};
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002791
Derek Lambertif90c56d2020-01-10 17:14:08 +00002792 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmU8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002793 std::vector<uint8_t> expectedOutput = {140, 151, 146, 112, 136, 156, 142, 112 };
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002794
Sadik Armagan483c8112021-06-01 09:24:52 +01002795 return QuantizedLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory,
2796 input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape());
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +01002797}
James Conroy4f1f8992020-04-29 20:01:10 +01002798
2799// QLSTM
2800LayerTestResult<int8_t, 2> QLstmTest(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002801 armnn::IWorkloadFactory& workloadFactory,
2802 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2803 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroy4f1f8992020-04-29 20:01:10 +01002804{
2805 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002806 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroy4f1f8992020-04-29 20:01:10 +01002807
2808 armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002809 std::vector<int8_t> expectedOutput = {-15, 21, 14, 20, -15, 15, 5, 27};
James Conroy4f1f8992020-04-29 20:01:10 +01002810
Finn Williamsc43de6a2020-08-27 11:13:25 +01002811 return QLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroy4f1f8992020-04-29 20:01:10 +01002812}
James Conroyb22a75e2020-06-08 14:53:10 +01002813
2814LayerTestResult<int8_t, 2> QLstmTest1(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002815 armnn::IWorkloadFactory& workloadFactory,
2816 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2817 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroyb22a75e2020-06-08 14:53:10 +01002818{
2819 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002820 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroyb22a75e2020-06-08 14:53:10 +01002821
2822 armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002823 std::vector<int8_t> expectedOutput = {127, 127, -108, -67, 127, 127};
James Conroyb22a75e2020-06-08 14:53:10 +01002824
Finn Williamsc43de6a2020-08-27 11:13:25 +01002825 return QLstmTestImpl1(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroyb22a75e2020-06-08 14:53:10 +01002826}
2827
2828LayerTestResult<int8_t, 2> QLstmTest2(
Finn Williamsc43de6a2020-08-27 11:13:25 +01002829 armnn::IWorkloadFactory& workloadFactory,
2830 const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
2831 const armnn::ITensorHandleFactory& tensorHandleFactory)
James Conroyb22a75e2020-06-08 14:53:10 +01002832{
2833 armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002834 std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64};
James Conroyb22a75e2020-06-08 14:53:10 +01002835
2836 armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8);
Sadik Armagan483c8112021-06-01 09:24:52 +01002837 std::vector<int8_t> expectedOutput = {127, 127, 127, -128, 127, 127};
James Conroyb22a75e2020-06-08 14:53:10 +01002838
Finn Williamsc43de6a2020-08-27 11:13:25 +01002839 return QLstmTestImpl2(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput);
James Conroyb22a75e2020-06-08 14:53:10 +01002840}