telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1 | // |
Teresa Charlin | fbf0e5b | 2020-08-17 01:01:06 +0100 | [diff] [blame] | 2 | // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 4 | // |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 5 | |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 6 | #include "LstmTestImpl.hpp" |
Aron Virginas-Tar | c9cc804 | 2018-11-01 16:15:57 +0000 | [diff] [blame] | 7 | |
Aron Virginas-Tar | 48623a0 | 2019-10-22 10:00:28 +0100 | [diff] [blame] | 8 | #include <QuantizeHelper.hpp> |
| 9 | |
Matthew Sloyan | 171214c | 2020-09-09 09:07:37 +0100 | [diff] [blame] | 10 | #include <armnn/utility/NumericCast.hpp> |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 11 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 12 | #include <backendsCommon/TensorHandle.hpp> |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 13 | |
Sadik Armagan | a097d2a | 2021-11-24 15:47:28 +0000 | [diff] [blame] | 14 | #include <armnnTestUtils/TensorCopyUtils.hpp> |
| 15 | #include <WorkloadTestUtils.hpp> |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 16 | |
| 17 | #include <reference/workloads/Decoders.hpp> |
| 18 | #include <reference/workloads/Encoders.hpp> |
| 19 | #include <reference/workloads/LstmUtils.hpp> |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 20 | |
Sadik Armagan | a097d2a | 2021-11-24 15:47:28 +0000 | [diff] [blame] | 21 | #include <TensorHelpers.hpp> |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 22 | |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 23 | #include <doctest/doctest.h> |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 24 | namespace |
| 25 | { |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 26 | |
| 27 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 28 | void LstmUtilsVectorBatchVectorAddTestImpl( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 29 | std::vector<float>& vec, |
| 30 | std::vector<float>& batchVec, |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 31 | uint32_t vSize, |
| 32 | uint32_t nBatch, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 33 | std::vector<float>& expectedOutput, |
| 34 | armnn::TensorShape& expectedShape) |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 35 | { |
| 36 | float qScale = 0.0f; |
| 37 | int32_t qOffset = 0; |
| 38 | armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset ); |
| 39 | |
| 40 | // Make encoder and decoder |
| 41 | std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data()); |
| 42 | std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data()); |
| 43 | std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data()); |
| 44 | |
| 45 | VectorBatchVectorAdd(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder); |
| 46 | |
| 47 | // check shape and compare values |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 48 | auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 49 | CHECK_MESSAGE(result.m_Result, result.m_Message.str()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 50 | |
| 51 | // check if iterator is back at start position |
| 52 | batchVecEncoder->Set(1.0f); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 53 | CHECK(batchVec[0] == 1.0f); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 54 | } |
| 55 | |
| 56 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 57 | void LstmUtilsZeroVectorTestImpl( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 58 | std::vector<float>& input, |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 59 | uint32_t vSize, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 60 | std::vector<float>& expectedOutput, |
| 61 | armnn::TensorShape& expectedShape) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 62 | { |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 63 | float qScale = 0.0f; |
| 64 | int32_t qOffset = 0; |
| 65 | |
| 66 | armnn::TensorInfo tensorInfo({vSize}, ArmnnType, qScale, qOffset ); |
| 67 | |
| 68 | // Make encoder for input |
| 69 | std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data()); |
| 70 | |
| 71 | // call ZeroVector |
| 72 | ZeroVector(*outputEncoder, vSize); |
| 73 | |
| 74 | // check shape and compare values |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 75 | auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 76 | CHECK_MESSAGE(result.m_Result, result.m_Message.str()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 77 | |
| 78 | // check if iterator is back at start position |
| 79 | outputEncoder->Set(1.0f); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 80 | CHECK(input[0] == 1.0f); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 81 | |
| 82 | } |
| 83 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 84 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 85 | void LstmUtilsMeanStddevNormalizationTestImpl( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 86 | std::vector<float>& input, |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 87 | uint32_t vSize, |
| 88 | uint32_t nBatch, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 89 | std::vector<float>& expectedOutput, |
| 90 | armnn::TensorShape& expectedShape) |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 91 | { |
| 92 | float qScale = 0.0f; |
| 93 | int32_t qOffset = 0; |
| 94 | armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset ); |
| 95 | |
| 96 | // Make encoder and decoder for input |
| 97 | std::unique_ptr<armnn::Decoder<float>> inputDecoder = armnn::MakeDecoder<float>(tensorInfo, input.data()); |
| 98 | std::unique_ptr<armnn::Encoder<float>> outputEncoder = armnn::MakeEncoder<float>(tensorInfo, input.data()); |
| 99 | |
| 100 | MeanStddevNormalization(*inputDecoder, *outputEncoder, vSize, nBatch, 1e-8f); |
| 101 | |
| 102 | // check shape and compare values |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 103 | auto result = CompareTensors(input, expectedOutput, expectedShape, expectedShape); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 104 | CHECK_MESSAGE(result.m_Result, result.m_Message.str()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 105 | |
| 106 | // check if iterator is back at start position |
| 107 | outputEncoder->Set(1.0f); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 108 | CHECK(input[0] == 1.0f); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 109 | } |
| 110 | |
| 111 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 112 | void LstmUtilsVectorBatchVectorCwiseProductTestImpl( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 113 | std::vector<float>& vec, |
| 114 | std::vector<float>& batchVec, |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 115 | uint32_t vSize, |
| 116 | uint32_t nBatch, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 117 | std::vector<float>& expectedOutput, |
| 118 | armnn::TensorShape& expectedShape) |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 119 | { |
| 120 | float qScale = 0.0f; |
| 121 | int32_t qOffset = 0; |
| 122 | armnn::TensorInfo tensorInfo({nBatch, vSize}, ArmnnType, qScale, qOffset ); |
| 123 | |
| 124 | // Make encoder and decoder |
| 125 | std::unique_ptr<armnn::Decoder<float>> vecDecoder = armnn::MakeDecoder<float>(tensorInfo, vec.data()); |
| 126 | std::unique_ptr<armnn::Decoder<float>> batchVecDecoder = armnn::MakeDecoder<float>(tensorInfo, batchVec.data()); |
| 127 | std::unique_ptr<armnn::Encoder<float>> batchVecEncoder = armnn::MakeEncoder<float>(tensorInfo, batchVec.data()); |
| 128 | |
| 129 | VectorBatchVectorCwiseProduct(*vecDecoder, vSize, *batchVecDecoder, nBatch, *batchVecEncoder); |
| 130 | |
| 131 | // check shape and compare values |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 132 | auto result = CompareTensors(batchVec, expectedOutput, expectedShape, expectedShape); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 133 | CHECK_MESSAGE(result.m_Result, result.m_Message.str()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 134 | |
| 135 | // check if iterator is back at start position |
| 136 | batchVecEncoder->Set(1.0f); |
Sadik Armagan | 1625efc | 2021-06-10 18:24:34 +0100 | [diff] [blame] | 137 | CHECK(batchVec[0] == 1.0f); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 138 | } |
| 139 | |
| 140 | // Lstm Layer tests: |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 141 | // *********************************** // |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 142 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 143 | LayerTestResult<T, 2> |
| 144 | LstmNoCifgNoPeepholeNoProjectionTestImpl( |
Aron Virginas-Tar | 5caf907 | 2018-11-14 18:35:18 +0000 | [diff] [blame] | 145 | armnn::IWorkloadFactory& workloadFactory, |
| 146 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 147 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 148 | const std::vector<T>& input, |
| 149 | const std::vector<T>& outputExpected, |
| 150 | const armnn::TensorShape& inputShape, |
| 151 | const armnn::TensorShape& outputExpectedShape, |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 152 | float qScale = 0.0f, |
| 153 | int32_t qOffset = 0, |
| 154 | armnn::DataType constantDataType = armnn::DataType::Float32) |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 155 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 156 | IgnoreUnused(memoryManager); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 157 | unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]); |
| 158 | unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]); |
| 159 | unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 160 | // cellSize and outputSize have the same size when there is no projection. |
| 161 | unsigned numUnits = outputSize; |
| 162 | |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 163 | armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset ); |
| 164 | armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset); |
| 165 | armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 166 | |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 167 | armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset); |
| 168 | armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset); |
| 169 | armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
| 170 | armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 171 | |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 172 | std::vector<T> inputVector; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 173 | inputVector.assign(input.data(), input.data() + (batchSize * inputSize)); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 174 | |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 175 | std::vector<T> cellStateInVector(batchSize * numUnits, T()); |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 176 | std::vector<T> outputStateInVector(batchSize * outputSize, T()); |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 177 | std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T()); |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 178 | std::vector<T> outputStateOutVector(batchSize * outputSize, T()); |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 179 | std::vector<T> cellStateOutVector(batchSize * numUnits, T()); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 180 | |
| 181 | std::vector<T> actualOutput(outputTensorInfo.GetNumElements()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 182 | |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 183 | std::vector<T> outputVector; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 184 | outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize)); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 185 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 186 | std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 187 | std::unique_ptr<armnn::ITensorHandle> cellStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 188 | tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 189 | std::unique_ptr<armnn::ITensorHandle> outputStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 190 | tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 191 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 192 | std::unique_ptr<armnn::ITensorHandle> scratchHandle = |
| 193 | tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 194 | std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 195 | tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 196 | std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 197 | tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo); |
| 198 | std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 199 | |
| 200 | armnn::LstmQueueDescriptor data; |
| 201 | armnn::WorkloadInfo info; |
| 202 | |
| 203 | AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); |
| 204 | AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); |
| 205 | AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); |
| 206 | |
| 207 | AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get()); |
| 208 | AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get()); |
| 209 | AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get()); |
| 210 | AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); |
| 211 | |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 212 | armnn::TensorInfo tensorInfo4({numUnits}, constantDataType , qScale, qOffset); |
| 213 | armnn::TensorInfo tensorInfo8({numUnits, 2}, constantDataType, qScale, qOffset); |
| 214 | armnn::TensorInfo tensorInfo16({numUnits, 4}, constantDataType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 215 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 216 | std::vector<float> inputToInputWeights = {-0.45018822f, -0.02338299f, -0.0870589f, |
| 217 | -0.34550029f, 0.04266912f, -0.15680569f, |
| 218 | -0.34856534f, 0.43890524f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 219 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 220 | std::vector<float> inputToForgetWeights = { 0.09701663f, 0.20334584f, -0.50592935f, |
| 221 | -0.31343272f, -0.40032279f, 0.44781327f, |
| 222 | 0.01387155f, -0.35593212f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 223 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 224 | std::vector<float> inputToCellWeights = { -0.50013041f, 0.1370284f, 0.11810488f, 0.2013163f, |
| 225 | -0.20583314f, 0.44344562f, 0.22077113f, |
| 226 | -0.29909778f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 227 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 228 | std::vector<float> inputToOutputWeights = { -0.25065863f, -0.28290087f, 0.04613829f, |
| 229 | 0.40525138f, 0.44272184f, 0.03897077f, |
| 230 | -0.1556896f, 0.19487578f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 231 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 232 | std::vector<float> recurrentToInputWeights = {-0.0063535f, -0.2042388f, 0.31454784f, |
| 233 | -0.35746509f, 0.28902304f, 0.08183324f, |
| 234 | -0.16555229f, 0.02286911f, -0.13566875f, |
| 235 | 0.03034258f, 0.48091322f, -0.12528998f, |
| 236 | 0.24077177f, -0.51332325f, -0.33502164f, |
| 237 | 0.10629296f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 238 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 239 | std::vector<float> recurrentToForgetWeights = { -0.48684245f, -0.06655136f, 0.42224967f, |
| 240 | 0.2112639f, 0.27654213f, 0.20864892f, |
| 241 | -0.07646349f, 0.45877004f, 0.00141793f, |
| 242 | -0.14609534f, 0.36447752f, 0.09196436f, |
| 243 | 0.28053468f, 0.01560611f, -0.20127171f, |
| 244 | -0.01140004f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 245 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 246 | std::vector<float> recurrentToCellWeights = { -0.3407414f, 0.24443203f, -0.2078532f, |
| 247 | 0.26320225f, 0.05695659f, -0.00123841f, |
| 248 | -0.4744786f, -0.35869038f, -0.06418842f, |
| 249 | -0.13502428f, -0.501764f, 0.22830659f, |
| 250 | -0.46367589f, 0.26016325f, -0.03894562f, |
| 251 | -0.16368064f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 252 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 253 | std::vector<float> recurrentToOutputWeights = { 0.43385774f, -0.17194885f, 0.2718237f, |
| 254 | 0.09215671f, 0.24107647f, -0.39835793f, |
| 255 | 0.18212086f, 0.01301402f, 0.48572797f, |
| 256 | -0.50656658f, 0.20047462f, -0.20607421f, |
| 257 | -0.51818722f, -0.15390486f, 0.0468148f, |
| 258 | 0.39922136f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 259 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 260 | std::vector<float> cellToInputWeights = {0., 0., 0., 0.}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 261 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 262 | std::vector<float> inputGateBias = {0., 0., 0., 0.}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 263 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 264 | std::vector<float> forgetGateBias = {1., 1., 1., 1.}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 265 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 266 | std::vector<float> cellBias = {0., 0., 0., 0.}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 267 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 268 | std::vector<float> outputGateBias = {0., 0., 0., 0.}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 269 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 270 | armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo8); |
| 271 | armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo8); |
| 272 | armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo8); |
| 273 | armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo8); |
| 274 | armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo16); |
| 275 | armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo16); |
| 276 | armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo16); |
| 277 | armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo16); |
| 278 | armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4); |
| 279 | armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4); |
| 280 | armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4); |
| 281 | armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4); |
| 282 | armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 283 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 284 | AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); |
| 285 | AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); |
| 286 | AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); |
| 287 | AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); |
| 288 | AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data()); |
| 289 | AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); |
| 290 | AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); |
| 291 | AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); |
| 292 | AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data()); |
| 293 | AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); |
| 294 | AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); |
| 295 | AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); |
| 296 | AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 297 | |
| 298 | data.m_InputToInputWeights = &inputToInputWeightsTensor; |
| 299 | data.m_InputToForgetWeights = &inputToForgetWeightsTensor; |
| 300 | data.m_InputToCellWeights = &inputToCellWeightsTensor; |
| 301 | data.m_InputToOutputWeights = &inputToOutputWeightsTensor; |
| 302 | data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; |
| 303 | data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; |
| 304 | data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; |
| 305 | data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 306 | data.m_InputGateBias = &inputGateBiasTensor; |
| 307 | data.m_ForgetGateBias = &forgetGateBiasTensor; |
| 308 | data.m_CellBias = &cellBiasTensor; |
| 309 | data.m_OutputGateBias = &outputGateBiasTensor; |
| 310 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 311 | // Flags to set test configuration |
| 312 | data.m_Parameters.m_ActivationFunc = 4; |
| 313 | data.m_Parameters.m_CifgEnabled = false; |
| 314 | data.m_Parameters.m_PeepholeEnabled = false; |
| 315 | data.m_Parameters.m_ProjectionEnabled = false; |
| 316 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 317 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info); |
| 318 | inputHandle->Allocate(); |
| 319 | outputStateInHandle->Allocate(); |
| 320 | cellStateInHandle->Allocate(); |
| 321 | |
| 322 | scratchHandle->Allocate(); |
| 323 | outputStateOutHandle->Allocate(); |
| 324 | cellStateOutHandle->Allocate(); |
| 325 | outputHandle->Allocate(); |
| 326 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 327 | CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); |
| 328 | CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); |
| 329 | CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 330 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 331 | workload->Execute(); |
| 332 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 333 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 334 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 335 | return LayerTestResult<T, 2>(actualOutput, |
| 336 | outputVector, |
| 337 | outputHandle->GetShape(), |
| 338 | outputTensorInfo.GetShape()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 339 | } |
| 340 | |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 341 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 342 | LayerTestResult<T, 2> |
Matteo Martincigh | a65b7ae | 2018-11-14 12:39:55 +0000 | [diff] [blame] | 343 | LstmLayerNoCifgWithPeepholeWithProjectionTestImpl(armnn::IWorkloadFactory& workloadFactory, |
| 344 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 345 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 346 | const std::vector<T>& input, |
| 347 | const std::vector<T>& outputExpected, |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 348 | float qScale = 0.0f, |
| 349 | int32_t qOffset = 0, |
| 350 | armnn::DataType constantDataType = armnn::DataType::Float32) |
Aron Virginas-Tar | 5caf907 | 2018-11-14 18:35:18 +0000 | [diff] [blame] | 351 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 352 | IgnoreUnused(memoryManager); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 353 | unsigned int batchSize = 2; |
| 354 | unsigned int outputSize = 16; |
| 355 | unsigned int inputSize = 5; |
| 356 | unsigned numUnits = 20; |
| 357 | |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 358 | armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset); |
| 359 | armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset); |
| 360 | armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 361 | |
Matteo Martincigh | a65b7ae | 2018-11-14 12:39:55 +0000 | [diff] [blame] | 362 | // Scratch buffer size without CIFG [batchSize, numUnits * 4] |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 363 | armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset); |
| 364 | armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset); |
| 365 | armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
| 366 | armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 367 | |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 368 | std::vector<T> inputVector; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 369 | inputVector.assign(input.data(), input.data() + (batchSize * inputSize)); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 370 | |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 371 | std::vector<T> cellStateInVector(batchSize * numUnits, T()); |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 372 | std::vector<T> outputStateInVector(batchSize * outputSize, T()); |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 373 | std::vector<T> scratchBufferVector(batchSize * numUnits * 4, T()); |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 374 | std::vector<T> outputStateOutVector(batchSize * outputSize, T()); |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 375 | std::vector<T> cellStateOutVector(batchSize * numUnits, T()); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 376 | |
| 377 | std::vector<T> actualOutput(outputTensorInfo.GetNumElements()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 378 | |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 379 | std::vector<T> outputVector; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 380 | outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize)); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 381 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 382 | std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 383 | std::unique_ptr<armnn::ITensorHandle> cellStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 384 | tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 385 | std::unique_ptr<armnn::ITensorHandle> outputStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 386 | tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 387 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 388 | std::unique_ptr<armnn::ITensorHandle> scratchHandle = |
| 389 | tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 390 | std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 391 | tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 392 | std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 393 | tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo); |
| 394 | std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 395 | |
| 396 | armnn::LstmQueueDescriptor data; |
| 397 | armnn::WorkloadInfo info; |
| 398 | |
| 399 | AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); |
| 400 | AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); |
| 401 | AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); |
David Beck | ac42efd | 2018-09-26 17:41:13 +0100 | [diff] [blame] | 402 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 403 | AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get()); |
| 404 | AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get()); |
| 405 | AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get()); |
| 406 | AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); |
| 407 | |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 408 | armnn::TensorInfo tensorInfo16({outputSize}, constantDataType, qScale, qOffset); |
| 409 | armnn::TensorInfo tensorInfo20({numUnits}, constantDataType, qScale, qOffset); |
| 410 | armnn::TensorInfo tensorInfo20x5({numUnits, inputSize}, constantDataType, qScale, qOffset); |
| 411 | armnn::TensorInfo tensorInfo20x16({numUnits, outputSize}, constantDataType, qScale, qOffset); |
| 412 | armnn::TensorInfo tensorInfo16x20({outputSize, numUnits}, constantDataType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 413 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 414 | std::vector<float> inputToInputWeights = {0.021393683f,0.06124551f, 0.046905167f,-0.014657677f,-0.03149463f, |
| 415 | 0.09171803f, 0.14647801f,0.10797193f, -0.0057968358f,0.0019193048f, |
| 416 | -0.2726754f, 0.10154029f, -0.018539885f, 0.080349885f, -0.10262385f, |
| 417 | -0.022599787f,-0.09121155f, -0.008675967f, -0.045206103f,-0.0821282f, |
| 418 | -0.008045952f,0.015478081f, 0.055217247f, 0.038719587f, 0.044153627f, |
| 419 | -0.06453243f,0.05031825f, -0.046935108f, -0.008164439f, 0.014574226f, |
| 420 | -0.1671009f, -0.15519552f, -0.16819797f,-0.13971269f,-0.11953059f, |
| 421 | 0.25005487f, -0.22790983f, 0.009855087f, -0.028140958f, -0.11200698f, |
| 422 | 0.11295408f, -0.0035217577f, 0.054485075f, 0.05184695f, 0.064711206f, |
| 423 | 0.10989193f, 0.11674786f, 0.03490607f, 0.07727357f, 0.11390585f, |
| 424 | -0.1863375f, -0.1034451f, -0.13945189f, -0.049401227f, -0.18767063f, |
| 425 | 0.042483903f, 0.14233552f, 0.13832581f, 0.18350165f, 0.14545603f, |
| 426 | -0.028545704f,0.024939531f,0.050929718f,0.0076203286f,-0.0029723682f, |
| 427 | -0.042484224f, -0.11827596f, -0.09171104f, -0.10808628f,-0.16327988f, |
| 428 | -0.2273378f, -0.0993647f, -0.017155107f,0.0023917493f,0.049272764f, |
| 429 | 0.0038534778f, 0.054764505f, 0.089753784f, 0.06947234f, 0.08014476f, |
| 430 | -0.04544234f, -0.0497073f,-0.07135631f, -0.048929106f,-0.004042012f, |
| 431 | -0.009284026f, 0.018042054f, 0.0036860977f,-0.07427302f, -0.11434604f, |
| 432 | -0.018995456f, 0.031487543f, 0.012834908f,0.019977754f,0.044256654f, |
| 433 | -0.39292613f, -0.18519334f, -0.11651281f,-0.06809892f, 0.011373677f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 434 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 435 | std::vector<float> inputToForgetWeights = {-0.0018401089f, -0.004852237f,0.03698424f, 0.014181704f,0.028273236f, |
| 436 | -0.016726194f, -0.05249759f,-0.10204261f, 0.00861066f,-0.040979505f, |
| 437 | -0.009899187f,0.01923892f,-0.028177269f, -0.08535103f,-0.14585495f, |
| 438 | 0.10662567f,-0.01909731f,-0.017883534f,-0.0047269356f,-0.045103323f, |
| 439 | 0.0030784295f,0.076784775f,0.07463696f, 0.094531395f,0.0814421f, |
| 440 | -0.12257899f, -0.033945758f,-0.031303465f, 0.045630626f,0.06843887f, |
| 441 | -0.13492945f, -0.012480007f,-0.0811829f, -0.07224499f,-0.09628791f, |
| 442 | 0.045100946f,0.0012300825f, 0.013964662f, 0.099372394f,0.02543059f, |
| 443 | 0.06958324f, 0.034257296f, 0.0482646f, 0.06267997f,0.052625068f, |
| 444 | 0.12784666f, 0.07077897f, 0.025725935f, 0.04165009f,0.07241905f, |
| 445 | 0.018668644f, -0.037377294f,-0.06277783f,-0.08833636f,-0.040120605f, |
| 446 | -0.011405586f,-0.007808335f,-0.010301386f,-0.005102167f,0.027717464f, |
| 447 | 0.05483423f, 0.11449111f, 0.11289652f,0.10939839f, 0.13396506f, |
| 448 | -0.08402166f,-0.01901462f, -0.044678304f,-0.07720565f,0.014350063f, |
| 449 | -0.11757958f, -0.0652038f, -0.08185733f,-0.076754324f,-0.092614375f, |
| 450 | 0.10405491f, 0.052960336f, 0.035755895f,0.035839386f,-0.012540553f, |
| 451 | 0.036881298f, 0.02913376f, 0.03420159f,0.05448447f,-0.054523353f, |
| 452 | 0.02582715f, 0.02327355f, -0.011857179f,-0.0011980024f,-0.034641717f, |
| 453 | -0.026125094f,-0.17582615f,-0.15923657f,-0.27486774f,-0.0006143371f, |
| 454 | 0.0001771948f, -8.470171e-05f, 0.02651807f,0.045790765f,0.06956496f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 455 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 456 | std::vector<float> inputToCellWeights = { -0.04580283f, -0.09549462f, -0.032418985f, -0.06454633f, |
| 457 | -0.043528453f, 0.043018587f, -0.049152344f, -0.12418144f, |
| 458 | -0.078985475f, -0.07596889f, 0.019484362f, -0.11434962f, |
| 459 | -0.0074034138f, -0.06314844f, -0.092981495f, 0.0062155537f, |
| 460 | -0.025034338f, -0.0028890965f, 0.048929527f, 0.06235075f, |
| 461 | 0.10665918f, -0.032036792f, -0.08505916f, -0.10843358f, |
| 462 | -0.13002433f, -0.036816437f, -0.02130134f, -0.016518239f, |
| 463 | 0.0047691227f, -0.0025825808f, 0.066017866f, 0.029991534f, |
| 464 | -0.10652836f, -0.1037554f, -0.13056071f, -0.03266643f, |
| 465 | -0.033702414f, -0.006473424f, -0.04611692f, 0.014419339f, |
| 466 | -0.025174323f, 0.0396852f, 0.081777506f, 0.06157468f, |
| 467 | 0.10210095f, -0.009658194f, 0.046511717f, 0.03603906f, |
| 468 | 0.0069369148f, 0.015960095f, -0.06507666f, 0.09551598f, |
| 469 | 0.053568836f, 0.06408714f, 0.12835667f, -0.008714329f, |
| 470 | -0.20211966f, -0.12093674f, 0.029450472f, 0.2849013f, |
| 471 | -0.029227901f, 0.1164364f, -0.08560263f, 0.09941786f, |
| 472 | -0.036999565f, -0.028842626f, -0.0033637602f, -0.017012902f, |
| 473 | -0.09720865f, -0.11193351f, -0.029155117f, -0.017936034f, |
| 474 | -0.009768936f, -0.04223324f, -0.036159635f, 0.06505112f, |
| 475 | -0.021742892f, -0.023377212f, -0.07221364f, -0.06430552f, |
| 476 | 0.05453865f, 0.091149814f, 0.06387331f, 0.007518393f, |
| 477 | 0.055960953f, 0.069779344f, 0.046411168f, 0.10509911f, |
| 478 | 0.07463894f, 0.0075130584f, 0.012850982f, 0.04555431f, |
| 479 | 0.056955688f, 0.06555285f, 0.050801456f, -0.009862683f, |
| 480 | 0.00826772f, -0.026555609f, -0.0073611983f, -0.0014897042f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 481 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 482 | std::vector<float> inputToOutputWeights ={-0.0998932f, -0.07201956f, -0.052803773f,-0.15629593f,-0.15001918f, |
| 483 | -0.07650751f,0.02359855f, -0.075155355f, -0.08037709f, -0.15093534f, |
| 484 | 0.029517552f, -0.04751393f, 0.010350531f,-0.02664851f, -0.016839722f, |
| 485 | -0.023121163f, 0.0077019283f, 0.012851257f, -0.05040649f,-0.0129761f, |
| 486 | -0.021737747f,-0.038305793f,-0.06870586f, -0.01481247f,-0.001285394f, |
| 487 | 0.10124236f, 0.083122835f, 0.053313006f,-0.062235646f,-0.075637154f, |
| 488 | -0.027833903f, 0.029774971f, 0.1130802f, 0.09218906f, 0.09506135f, |
| 489 | -0.086665764f,-0.037162706f,-0.038880914f,-0.035832845f,-0.014481564f, |
| 490 | -0.09825003f,-0.12048569f,-0.097665586f,-0.05287633f, -0.0964047f, |
| 491 | -0.11366429f, 0.035777505f, 0.13568819f, 0.052451383f,0.050649304f, |
| 492 | 0.05798951f, -0.021852335f,-0.099848844f,0.014740475f,-0.078897946f, |
| 493 | 0.04974699f, 0.014160473f, 0.06973932f, 0.04964942f, 0.033364646f, |
| 494 | 0.08190124f, 0.025535367f, 0.050893165f, 0.048514254f,0.06945813f, |
| 495 | -0.078907564f,-0.06707616f, -0.11844508f, -0.09986688f,-0.07509403f, |
| 496 | 0.06263226f, 0.14925587f, 0.20188436f, 0.12098451f,0.14639415f, |
| 497 | 0.0015017595f, -0.014267382f, -0.03417257f,0.012711468f,0.0028300495f, |
| 498 | -0.024758482f, -0.05098548f,-0.0821182f, 0.014225672f, 0.021544158f, |
| 499 | 0.08949725f, 0.07505268f, -0.0020780868f, 0.04908258f,0.06476295f, |
| 500 | -0.022907063f,0.027562456f,0.040185735f, 0.019567577f,-0.015598739f, |
| 501 | -0.049097303f, -0.017121866f, -0.083368234f,-0.02332002f,-0.0840956f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 502 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 503 | std::vector<float> inputGateBias = {0.02234832f, 0.14757581f, 0.18176508f, 0.10380666f, 0.053110216f, |
| 504 | -0.06928846f, -0.13942584f, -0.11816189f, 0.19483899f, 0.03652339f, |
| 505 | -0.10250295f, 0.036714908f, -0.18426876f, 0.036065217f, 0.21810818f, |
| 506 | 0.02383196f, -0.043370757f, 0.08690144f, -0.04444982f, 0.00030581196f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 507 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 508 | std::vector<float> forgetGateBias ={0.035185695f, -0.042891346f, -0.03032477f, 0.23027696f, |
| 509 | 0.11098921f, 0.15378423f, 0.09263801f, 0.09790885f, |
| 510 | 0.09508917f, 0.061199076f, 0.07665568f, -0.015443159f, |
| 511 | -0.03499149f, 0.046190713f, 0.08895977f, 0.10899629f, |
| 512 | 0.40694186f, 0.06030037f, 0.012413437f, -0.06108739f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 513 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 514 | std::vector<float> cellBias = { -0.024379363f, 0.0055531194f, 0.23377132f, 0.033463873f, |
| 515 | -0.1483596f, -0.10639995f, -0.091433935f, 0.058573797f, |
| 516 | -0.06809782f, -0.07889636f, -0.043246906f, -0.09829136f, |
| 517 | -0.4279842f, 0.034901652f, 0.18797937f, 0.0075234566f, |
| 518 | 0.016178843f, 0.1749513f, 0.13975595f, 0.92058027f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 519 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 520 | std::vector<float> outputGateBias ={0.046159424f, -0.0012809046f, 0.03563469f, 0.12648113f, 0.027195795f, |
| 521 | 0.35373217f, -0.018957434f, 0.008907322f, -0.0762701f, 0.12018895f, |
| 522 | 0.04216877f, 0.0022856654f, 0.040952638f, 0.3147856f, 0.08225149f, |
| 523 | -0.057416286f, -0.14995944f, -0.008040261f, 0.13208859f, 0.029760877f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 524 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 525 | std::vector<float> recurrentToInputWeights = { -0.001374326f, -0.078856036f, 0.10672688f, 0.029162422f, |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 526 | -0.11585556f, 0.02557986f, -0.13446963f, -0.035785314f, |
| 527 | -0.01244275f, 0.025961924f, -0.02337298f, -0.044228926f, |
| 528 | -0.055839065f, -0.046598054f, -0.010546039f, -0.06900766f, |
| 529 | 0.027239809f, 0.022582639f, -0.013296484f, -0.05459212f, |
| 530 | 0.08981f, -0.045407712f, 0.08682226f, -0.06867011f, |
| 531 | -0.14390695f, -0.02916037f, 0.000996957f, 0.091420636f, |
| 532 | 0.14283475f, -0.07390571f, -0.06402044f, 0.062524505f, |
| 533 | -0.093129106f, 0.04860203f, -0.08364217f, -0.08119002f, |
| 534 | 0.009352075f, 0.22920375f, 0.0016303885f, 0.11583097f, |
| 535 | -0.13732095f, 0.012405723f, -0.07551853f, 0.06343048f, |
| 536 | 0.12162708f, -0.031923793f, -0.014335606f, 0.01790974f, |
| 537 | -0.10650317f, -0.0724401f, 0.08554849f, -0.05727212f, |
| 538 | 0.06556731f, -0.042729504f, -0.043227166f, 0.011683251f, |
| 539 | -0.013082158f, -0.029302018f, -0.010899579f, -0.062036745f, |
| 540 | -0.022509435f, -0.00964907f, -0.01567329f, 0.04260106f, |
| 541 | -0.07787477f, -0.11576462f, 0.017356863f, 0.048673786f, |
| 542 | -0.017577527f, -0.05527947f, -0.082487635f, -0.040137455f, |
| 543 | -0.10820036f, -0.04666372f, 0.022746278f, -0.07851417f, |
| 544 | 0.01068115f, 0.032956902f, 0.022433773f, 0.0026891115f, |
| 545 | 0.08944216f, -0.0685835f, 0.010513544f, 0.07228705f, |
| 546 | 0.02032331f, -0.059686817f, -0.0005566496f, -0.086984694f, |
| 547 | 0.040414046f, -0.1380399f, 0.094208956f, -0.05722982f, |
| 548 | 0.012092817f, -0.04989123f, -0.086576f, -0.003399834f, |
| 549 | -0.04696032f, -0.045747425f, 0.10091314f, 0.048676282f, |
| 550 | -0.029037097f, 0.031399418f, -0.0040285117f, 0.047237843f, |
| 551 | 0.09504992f, 0.041799378f, -0.049185462f, -0.031518843f, |
| 552 | -0.10516937f, 0.026374253f, 0.10058866f, -0.0033195973f, |
| 553 | -0.041975245f, 0.0073591834f, 0.0033782164f, -0.004325073f, |
| 554 | -0.10167381f, 0.042500053f, -0.01447153f, 0.06464186f, |
| 555 | -0.017142897f, 0.03312627f, 0.009205989f, 0.024138335f, |
| 556 | -0.011337001f, 0.035530265f, -0.010912711f, 0.0706555f, |
| 557 | -0.005894094f, 0.051841937f, -0.1401738f, -0.02351249f, |
| 558 | 0.0365468f, 0.07590991f, 0.08838724f, 0.021681072f, |
| 559 | -0.10086113f, 0.019608743f, -0.06195883f, 0.077335775f, |
| 560 | 0.023646897f, -0.095322326f, 0.02233014f, 0.09756986f, |
| 561 | -0.048691444f, -0.009579111f, 0.07595467f, 0.11480546f, |
| 562 | -0.09801813f, 0.019894179f, 0.08502348f, 0.004032281f, |
| 563 | 0.037211012f, 0.068537936f, -0.048005626f, -0.091520436f, |
| 564 | -0.028379958f, -0.01556313f, 0.06554592f, -0.045599163f, |
| 565 | -0.01672207f, -0.020169014f, -0.011877351f, -0.20212261f, |
| 566 | 0.010889619f, 0.0047078193f, 0.038385306f, 0.08540671f, |
| 567 | -0.017140968f, -0.0035865551f, 0.016678626f, 0.005633034f, |
| 568 | 0.015963363f, 0.00871737f, 0.060130805f, 0.028611384f, |
| 569 | 0.10109069f, -0.015060172f, -0.07894427f, 0.06401885f, |
| 570 | 0.011584063f, -0.024466386f, 0.0047652307f, -0.09041358f, |
| 571 | 0.030737216f, -0.0046374933f, 0.14215417f, -0.11823516f, |
| 572 | 0.019899689f, 0.006106124f, -0.027092824f, 0.0786356f, |
| 573 | 0.05052217f, -0.058925f, -0.011402121f, -0.024987547f, |
| 574 | -0.0013661642f, -0.06832946f, -0.015667673f, -0.1083353f, |
| 575 | -0.00096863037f, -0.06988685f, -0.053350925f, -0.027275559f, |
| 576 | -0.033664223f, -0.07978348f, -0.025200296f, -0.017207067f, |
| 577 | -0.058403496f, -0.055697463f, 0.005798788f, 0.12965427f, |
| 578 | -0.062582195f, 0.0013350133f, -0.10482091f, 0.0379771f, |
| 579 | 0.072521195f, -0.0029455067f, -0.13797039f, -0.03628521f, |
| 580 | 0.013806405f, -0.017858358f, -0.01008298f, -0.07700066f, |
| 581 | -0.017081132f, 0.019358726f, 0.0027079724f, 0.004635139f, |
| 582 | 0.062634714f, -0.02338735f, -0.039547626f, -0.02050681f, |
| 583 | 0.03385117f, -0.083611414f, 0.002862572f, -0.09421313f, |
| 584 | 0.058618143f, -0.08598433f, 0.00972939f, 0.023867095f, |
| 585 | -0.053934585f, -0.023203006f, 0.07452513f, -0.048767887f, |
| 586 | -0.07314807f, -0.056307215f, -0.10433547f, -0.06440842f, |
| 587 | 0.04328182f, 0.04389765f, -0.020006588f, -0.09076438f, |
| 588 | -0.11652589f, -0.021705797f, 0.03345259f, -0.010329105f, |
| 589 | -0.025767034f, 0.013057034f, -0.07316461f, -0.10145612f, |
| 590 | 0.06358255f, 0.18531723f, 0.07759293f, 0.12006465f, |
| 591 | 0.1305557f, 0.058638252f, -0.03393652f, 0.09622831f, |
| 592 | -0.16253184f, -2.4580743e-06f, 0.079869635f, -0.070196845f, |
| 593 | -0.005644518f, 0.06857898f, -0.12598175f, -0.035084512f, |
| 594 | 0.03156317f, -0.12794146f, -0.031963028f, 0.04692781f, |
| 595 | 0.030070418f, 0.0071660685f, -0.095516115f, -0.004643372f, |
| 596 | 0.040170413f, -0.062104587f, -0.0037324072f, 0.0554317f, |
| 597 | 0.08184801f, -0.019164372f, 0.06791302f, 0.034257166f, |
| 598 | -0.10307039f, 0.021943003f, 0.046745934f, 0.0790918f, |
| 599 | -0.0265588f, -0.007824208f, 0.042546265f, -0.00977924f, |
| 600 | -0.0002440307f, -0.017384544f, -0.017990116f, 0.12252321f, |
| 601 | -0.014512694f, -0.08251313f, 0.08861942f, 0.13589665f, |
| 602 | 0.026351685f, 0.012641483f, 0.07466548f, 0.044301085f, |
| 603 | -0.045414884f, -0.051112458f, 0.03444247f, -0.08502782f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 604 | -0.04106223f, -0.028126027f, 0.028473156f, 0.10467447f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 605 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 606 | std::vector<float> recurrentToForgetWeights = {-0.057784554f, -0.026057621f, -0.068447545f, -0.022581743f, |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 607 | 0.14811787f, 0.10826372f, 0.09471067f, 0.03987225f, |
| 608 | -0.0039523416f, 0.00030638507f, 0.053185795f, 0.10572994f, |
| 609 | 0.08414449f, -0.022036452f, -0.00066928595f, -0.09203576f, |
| 610 | 0.032950465f, -0.10985798f, -0.023809856f, 0.0021431844f, |
| 611 | -0.02196096f, -0.00326074f, 0.00058621005f, -0.074678116f, |
| 612 | -0.06193199f, 0.055729095f, 0.03736828f, 0.020123724f, |
| 613 | 0.061878487f, -0.04729229f, 0.034919553f, -0.07585433f, |
| 614 | -0.04421272f, -0.044019096f, 0.085488975f, 0.04058006f, |
| 615 | -0.06890133f, -0.030951202f, -0.024628663f, -0.07672815f, |
| 616 | 0.034293607f, 0.08556707f, -0.05293577f, -0.033561368f, |
| 617 | -0.04899627f, 0.0241671f, 0.015736353f, -0.095442444f, |
| 618 | -0.029564252f, 0.016493602f, -0.035026584f, 0.022337519f, |
| 619 | -0.026871363f, 0.004780428f, 0.0077918363f, -0.03601621f, |
| 620 | 0.016435321f, -0.03263031f, -0.09543275f, -0.047392778f, |
| 621 | 0.013454138f, 0.028934088f, 0.01685226f, -0.086110644f, |
| 622 | -0.046250615f, -0.01847454f, 0.047608484f, 0.07339695f, |
| 623 | 0.034546845f, -0.04881143f, 0.009128804f, -0.08802852f, |
| 624 | 0.03761666f, 0.008096139f, -0.014454086f, 0.014361001f, |
| 625 | -0.023502491f, -0.0011840804f, -0.07607001f, 0.001856849f, |
| 626 | -0.06509276f, -0.006021153f, -0.08570962f, -0.1451793f, |
| 627 | 0.060212336f, 0.055259194f, 0.06974018f, 0.049454916f, |
| 628 | -0.027794661f, -0.08077226f, -0.016179763f, 0.1169753f, |
| 629 | 0.17213494f, -0.0056326236f, -0.053934924f, -0.0124349f, |
| 630 | -0.11520337f, 0.05409887f, 0.088759385f, 0.0019655675f, |
| 631 | 0.0042065294f, 0.03881498f, 0.019844765f, 0.041858196f, |
| 632 | -0.05695512f, 0.047233116f, 0.038937137f, -0.06542224f, |
| 633 | 0.014429736f, -0.09719407f, 0.13908425f, -0.05379757f, |
| 634 | 0.012321099f, 0.082840554f, -0.029899208f, 0.044217527f, |
| 635 | 0.059855383f, 0.07711018f, -0.045319796f, 0.0948846f, |
| 636 | -0.011724666f, -0.0033288454f, -0.033542685f, -0.04764985f, |
| 637 | -0.13873616f, 0.040668588f, 0.034832682f, -0.015319203f, |
| 638 | -0.018715994f, 0.046002675f, 0.0599172f, -0.043107376f, |
| 639 | 0.0294216f, -0.002314414f, -0.022424703f, 0.0030315618f, |
| 640 | 0.0014641669f, 0.0029166266f, -0.11878115f, 0.013738511f, |
| 641 | 0.12375372f, -0.0006038222f, 0.029104086f, 0.087442465f, |
| 642 | 0.052958444f, 0.07558703f, 0.04817258f, 0.044462286f, |
| 643 | -0.015213451f, -0.08783778f, -0.0561384f, -0.003008196f, |
| 644 | 0.047060397f, -0.002058388f, 0.03429439f, -0.018839769f, |
| 645 | 0.024734668f, 0.024614193f, -0.042046934f, 0.09597743f, |
| 646 | -0.0043254104f, 0.04320769f, 0.0064070094f, -0.0019131786f, |
| 647 | -0.02558259f, -0.022822596f, -0.023273505f, -0.02464396f, |
| 648 | -0.10991725f, -0.006240552f, 0.0074488563f, 0.024044557f, |
| 649 | 0.04383914f, -0.046476185f, 0.028658995f, 0.060410924f, |
| 650 | 0.050786525f, 0.009452605f, -0.0073054377f, -0.024810238f, |
| 651 | 0.0052906186f, 0.0066939713f, -0.0020913032f, 0.014515517f, |
| 652 | 0.015898481f, 0.021362653f, -0.030262267f, 0.016587038f, |
| 653 | -0.011442813f, 0.041154444f, -0.007631438f, -0.03423484f, |
| 654 | -0.010977775f, 0.036152758f, 0.0066366293f, 0.11915515f, |
| 655 | 0.02318443f, -0.041350313f, 0.021485701f, -0.10906167f, |
| 656 | -0.028218046f, -0.00954771f, 0.020531068f, -0.11995105f, |
| 657 | -0.03672871f, 0.024019798f, 0.014255957f, -0.05221243f, |
| 658 | -0.00661567f, -0.04630967f, 0.033188973f, 0.10107534f, |
| 659 | -0.014027541f, 0.030796422f, -0.10270911f, -0.035999842f, |
| 660 | 0.15443139f, 0.07684145f, 0.036571592f, -0.035900835f, |
| 661 | -0.0034699554f, 0.06209149f, 0.015920248f, -0.031122351f, |
| 662 | -0.03858649f, 0.01849943f, 0.13872518f, 0.01503974f, |
| 663 | 0.069941424f, -0.06948533f, -0.0088794185f, 0.061282158f, |
| 664 | -0.047401894f, 0.03100163f, -0.041533746f, -0.10430945f, |
| 665 | 0.044574402f, -0.01425562f, -0.024290353f, 0.034563623f, |
| 666 | 0.05866852f, 0.023947537f, -0.09445152f, 0.035450947f, |
| 667 | 0.02247216f, -0.0042998926f, 0.061146557f, -0.10250651f, |
| 668 | 0.020881841f, -0.06747029f, 0.10062043f, -0.0023941975f, |
| 669 | 0.03532124f, -0.016341697f, 0.09685456f, -0.016764693f, |
| 670 | 0.051808182f, 0.05875331f, -0.04536488f, 0.001626336f, |
| 671 | -0.028892258f, -0.01048663f, -0.009793449f, -0.017093895f, |
| 672 | 0.010987891f, 0.02357273f, -0.00010856845f, 0.0099760275f, |
| 673 | -0.001845119f, -0.03551521f, 0.0018358806f, 0.05763657f, |
| 674 | -0.01769146f, 0.040995963f, 0.02235177f, -0.060430344f, |
| 675 | 0.11475477f, -0.023854522f, 0.10071741f, 0.0686208f, |
| 676 | -0.014250481f, 0.034261297f, 0.047418304f, 0.08562733f, |
| 677 | -0.030519066f, 0.0060542435f, 0.014653856f, -0.038836084f, |
| 678 | 0.04096551f, 0.032249358f, -0.08355519f, -0.026823482f, |
| 679 | 0.056386515f, -0.010401743f, -0.028396193f, 0.08507674f, |
| 680 | 0.014410365f, 0.020995233f, 0.17040324f, 0.11511526f, |
| 681 | 0.02459721f, 0.0066619175f, 0.025853224f, -0.023133837f, |
| 682 | -0.081302024f, 0.017264642f, -0.009585969f, 0.09491168f, |
| 683 | -0.051313367f, 0.054532815f, -0.014298593f, 0.10657464f, |
| 684 | 0.007076659f, 0.10964551f, 0.0409152f, 0.008275321f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 685 | -0.07283536f, 0.07937492f, 0.04192024f, -0.1075027f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 686 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 687 | std::vector<float> recurrentToCellWeights = { -0.037322544f, 0.018592842f, 0.0056175636f, -0.06253426f, |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 688 | 0.055647098f, -0.05713207f, -0.05626563f, 0.005559383f, |
| 689 | 0.03375411f, -0.025757805f, -0.088049285f, 0.06017052f, |
| 690 | -0.06570978f, 0.007384076f, 0.035123326f, -0.07920549f, |
| 691 | 0.053676967f, 0.044480428f, -0.07663568f, 0.0071805613f, |
| 692 | 0.08089997f, 0.05143358f, 0.038261272f, 0.03339287f, |
| 693 | -0.027673481f, 0.044746667f, 0.028349208f, 0.020090483f, |
| 694 | -0.019443132f, -0.030755889f, -0.0040000007f, 0.04465846f, |
| 695 | -0.021585021f, 0.0031670958f, 0.0053199246f, -0.056117613f, |
| 696 | -0.10893326f, 0.076739706f, -0.08509834f, -0.027997585f, |
| 697 | 0.037871376f, 0.01449768f, -0.09002357f, -0.06111149f, |
| 698 | -0.046195522f, 0.0422062f, -0.005683705f, -0.1253618f, |
| 699 | -0.012925729f, -0.04890792f, 0.06985068f, 0.037654128f, |
| 700 | 0.03398274f, -0.004781977f, 0.007032333f, -0.031787455f, |
| 701 | 0.010868644f, -0.031489216f, 0.09525667f, 0.013939797f, |
| 702 | 0.0058680447f, 0.0167067f, 0.02668468f, -0.04797466f, |
| 703 | -0.048885044f, -0.12722108f, 0.035304096f, 0.06554885f, |
| 704 | 0.00972396f, -0.039238118f, -0.05159735f, -0.11329045f, |
| 705 | 0.1613692f, -0.03750952f, 0.06529313f, -0.071974665f, |
| 706 | -0.11769596f, 0.015524369f, -0.0013754242f, -0.12446318f, |
| 707 | 0.02786344f, -0.014179351f, 0.005264273f, 0.14376344f, |
| 708 | 0.015983658f, 0.03406988f, -0.06939408f, 0.040699873f, |
| 709 | 0.02111075f, 0.09669095f, 0.041345075f, -0.08316494f, |
| 710 | -0.07684199f, -0.045768797f, 0.032298047f, -0.041805092f, |
| 711 | 0.0119405f, 0.0061010392f, 0.12652606f, 0.0064572375f, |
| 712 | -0.024950314f, 0.11574242f, 0.04508852f, -0.04335324f, |
| 713 | 0.06760663f, -0.027437469f, 0.07216407f, 0.06977076f, |
| 714 | -0.05438599f, 0.034033038f, -0.028602652f, 0.05346137f, |
| 715 | 0.043184172f, -0.037189785f, 0.10420091f, 0.00882477f, |
| 716 | -0.054019816f, -0.074273005f, -0.030617684f, -0.0028467078f, |
| 717 | 0.024302477f, -0.0038869337f, 0.005332455f, 0.0013399826f, |
| 718 | 0.04361412f, -0.007001822f, 0.09631092f, -0.06702025f, |
| 719 | -0.042049985f, -0.035070654f, -0.04103342f, -0.10273396f, |
| 720 | 0.0544271f, 0.037184782f, -0.13150354f, -0.0058036847f, |
| 721 | -0.008264958f, 0.042035464f, 0.05891794f, 0.029673764f, |
| 722 | 0.0063542654f, 0.044788733f, 0.054816857f, 0.062257513f, |
| 723 | -0.00093483756f, 0.048938446f, -0.004952862f, -0.007730018f, |
| 724 | -0.04043371f, -0.017094059f, 0.07229206f, -0.023670016f, |
| 725 | -0.052195564f, -0.025616996f, -0.01520939f, 0.045104615f, |
| 726 | -0.007376126f, 0.003533447f, 0.006570588f, 0.056037236f, |
| 727 | 0.12436656f, 0.051817212f, 0.028532185f, -0.08686856f, |
| 728 | 0.11868599f, 0.07663395f, -0.07323171f, 0.03463402f, |
| 729 | -0.050708205f, -0.04458982f, -0.11590894f, 0.021273347f, |
| 730 | 0.1251325f, -0.15313013f, -0.12224372f, 0.17228661f, |
| 731 | 0.023029093f, 0.086124025f, 0.006445803f, -0.03496501f, |
| 732 | 0.028332196f, 0.04449512f, -0.042436164f, -0.026587414f, |
| 733 | -0.006041347f, -0.09292539f, -0.05678812f, 0.03897832f, |
| 734 | 0.09465633f, 0.008115513f, -0.02171956f, 0.08304309f, |
| 735 | 0.071401566f, 0.019622514f, 0.032163795f, -0.004167056f, |
| 736 | 0.02295182f, 0.030739572f, 0.056506045f, 0.004612461f, |
| 737 | 0.06524936f, 0.059999723f, 0.046395954f, -0.0045512207f, |
| 738 | -0.1335546f, -0.030136576f, 0.11584653f, -0.014678886f, |
| 739 | 0.0020118146f, -0.09688814f, -0.0790206f, 0.039770417f, |
| 740 | -0.0329582f, 0.07922767f, 0.029322514f, 0.026405897f, |
| 741 | 0.04207835f, -0.07073373f, 0.063781224f, 0.0859677f, |
| 742 | -0.10925287f, -0.07011058f, 0.048005477f, 0.03438226f, |
| 743 | -0.09606514f, -0.006669445f, -0.043381985f, 0.04240257f, |
| 744 | -0.06955775f, -0.06769346f, 0.043903265f, -0.026784198f, |
| 745 | -0.017840602f, 0.024307009f, -0.040079936f, -0.019946516f, |
| 746 | 0.045318738f, -0.12233574f, 0.026170589f, 0.0074471775f, |
| 747 | 0.15978073f, 0.10185836f, 0.10298046f, -0.015476589f, |
| 748 | -0.039390966f, -0.072174534f, 0.0739445f, -0.1211869f, |
| 749 | -0.0347889f, -0.07943156f, 0.014809798f, -0.12412325f, |
| 750 | -0.0030663363f, 0.039695457f, 0.0647603f, -0.08291318f, |
| 751 | -0.018529687f, -0.004423833f, 0.0037507233f, 0.084633216f, |
| 752 | -0.01514876f, -0.056505352f, -0.012800942f, -0.06994386f, |
| 753 | 0.012962922f, -0.031234352f, 0.07029052f, 0.016418684f, |
| 754 | 0.03618972f, 0.055686004f, -0.08663945f, -0.017404709f, |
| 755 | -0.054761406f, 0.029065743f, 0.052404847f, 0.020238016f, |
| 756 | 0.0048197987f, -0.0214882f, 0.07078733f, 0.013016777f, |
| 757 | 0.06262858f, 0.009184685f, 0.020785125f, -0.043904778f, |
| 758 | -0.0270329f, -0.03299152f, -0.060088247f, -0.015162964f, |
| 759 | -0.001828936f, 0.12642565f, -0.056757294f, 0.013586685f, |
| 760 | 0.09232601f, -0.035886683f, 0.06000002f, 0.05229691f, |
| 761 | -0.052580316f, -0.082029596f, -0.010794592f, 0.012947712f, |
| 762 | -0.036429964f, -0.085508935f, -0.13127148f, -0.017744139f, |
| 763 | 0.031502828f, 0.036232427f, -0.031581745f, 0.023051167f, |
| 764 | -0.05325106f, -0.03421577f, 0.028793324f, -0.034633752f, |
| 765 | -0.009881397f, -0.043551125f, -0.018609839f, 0.0019097115f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 766 | -0.008799762f, 0.056595087f, 0.0022273948f, 0.055752404f }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 767 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 768 | std::vector<float> recurrentToOutputWeights = { 0.025825322f, -0.05813119f, 0.09495884f,-0.045984812f, -0.01255415f, |
| 769 | -0.0026479573f,-0.08196161f,-0.054914974f,-0.0046604523f, |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 770 | -0.029587349f, -0.044576716f, -0.07480124f, -0.082868785f, |
| 771 | 0.023254942f, 0.027502948f, -0.0039728214f, -0.08683098f, |
| 772 | -0.08116779f, -0.014675607f, -0.037924774f, -0.023314456f, |
| 773 | -0.007401714f, -0.09255757f, 0.029460307f, -0.08829125f, |
| 774 | -0.005139627f, -0.08989442f, -0.0555066f, 0.13596267f, |
| 775 | -0.025062224f, -0.048351806f, -0.03850004f, 0.07266485f, |
| 776 | -0.022414139f, 0.05940088f, 0.075114764f, 0.09597592f, |
| 777 | -0.010211725f, -0.0049794707f, -0.011523867f, -0.025980417f, |
| 778 | 0.072999895f, 0.11091378f, -0.081685916f, 0.014416728f, |
| 779 | 0.043229222f, 0.034178585f, -0.07530371f, 0.035837382f, |
| 780 | -0.085607f, -0.007721233f, -0.03287832f, -0.043848954f, |
| 781 | -0.06404588f, -0.06632928f, -0.073643476f, 0.008214239f, |
| 782 | -0.045984086f, 0.039764922f, 0.03474462f, 0.060612556f, |
| 783 | -0.080590084f, 0.049127717f, 0.04151091f, -0.030063879f, |
| 784 | 0.008801774f, -0.023021035f, -0.019558564f, 0.05158114f, |
| 785 | -0.010947698f, -0.011825728f, 0.0075720972f, 0.0699727f, |
| 786 | -0.0039981045f, 0.069350146f, 0.08799282f, 0.016156472f, |
| 787 | 0.035502106f, 0.11695009f, 0.006217345f, 0.13392477f, |
| 788 | -0.037875112f, 0.025745004f, 0.08940699f, -0.00924166f, |
| 789 | 0.0046702605f, -0.036598757f, -0.08811812f, 0.10522024f, |
| 790 | -0.032441203f, 0.008176899f, -0.04454919f, 0.07058152f, |
| 791 | 0.0067963637f, 0.039206743f, 0.03259838f, 0.03725492f, |
| 792 | -0.09515802f, 0.013326398f, -0.052055415f, -0.025676316f, |
| 793 | 0.03198509f, -0.015951829f, -0.058556724f, 0.036879618f, |
| 794 | 0.043357447f, 0.028362012f, -0.05908629f, 0.0059240665f, |
| 795 | -0.04995891f, -0.019187413f,0.0276265f, -0.01628143f, 0.0025863599f, |
| 796 | 0.08800015f, 0.035250366f, -0.022165963f, -0.07328642f, |
| 797 | -0.009415526f, -0.07455109f, 0.11690406f, 0.0363299f, |
| 798 | 0.07411125f, 0.042103454f, -0.009660886f, 0.019076364f, |
| 799 | 0.018299393f, -0.046004917f, 0.08891175f,0.0431396f, -0.026327137f, |
| 800 | -0.051502608f, 0.08979574f, -0.051670972f, 0.04940282f, |
| 801 | -0.07491107f, -0.021240504f, 0.022596184f, -0.034280192f, |
| 802 | 0.060163025f, -0.058211457f, -0.051837247f, -0.01349775f, |
| 803 | -0.04639988f, -0.035936575f, -0.011681591f, 0.064818054f, |
| 804 | 0.0073146066f, -0.021745546f, -0.043124277f, -0.06471268f, |
| 805 | -0.07053354f, -0.029321948f, -0.05330136f, 0.016933719f, |
| 806 | -0.053782392f, 0.13747959f, -0.1361751f, -0.11569455f, |
| 807 | 0.0033329215f, 0.05693899f, -0.053219706f, 0.063698f, |
| 808 | 0.07977434f, -0.07924483f, 0.06936997f, 0.0034815092f, |
| 809 | -0.007305279f, -0.037325785f, -0.07251102f, -0.033633437f, |
| 810 | -0.08677009f, 0.091591336f, -0.14165086f, 0.021752775f, |
| 811 | 0.019683983f, 0.0011612234f, -0.058154266f, 0.049996935f, |
| 812 | 0.0288841f, -0.0024567875f, -0.14345716f, 0.010955264f,-0.10234828f, |
| 813 | 0.1183656f, -0.0010731248f, -0.023590032f,-0.072285876f,-0.0724771f, |
| 814 | -0.026382286f, -0.0014920527f, 0.042667855f, 0.0018776858f, |
| 815 | 0.02986552f, 0.009814309f, 0.0733756f, 0.12289186f, |
| 816 | 0.018043943f, -0.0458958f, 0.049412545f, 0.033632483f, |
| 817 | 0.05495232f, 0.036686596f, -0.013781798f, -0.010036754f, |
| 818 | 0.02576849f, -0.08307328f, 0.010112348f, 0.042521734f, |
| 819 | -0.05869831f, -0.071689695f, 0.03876447f, -0.13275425f, -0.0352966f, |
| 820 | -0.023077697f, 0.10285965f, 0.084736146f, 0.15568255f, |
| 821 | -0.00040734606f, 0.027835453f, -0.10292561f, -0.032401145f, |
| 822 | 0.10053256f, -0.026142767f, -0.08271222f, -0.0030240538f, |
| 823 | -0.016368777f, 0.1070414f, 0.042672627f, 0.013456989f, |
| 824 | -0.0437609f, -0.022309763f, 0.11576483f, 0.04108048f, |
| 825 | 0.061026827f, -0.0190714f, -0.0869359f, 0.037901703f, 0.0610107f, |
| 826 | 0.07202949f, 0.01675338f, 0.086139716f, -0.08795751f, |
| 827 | -0.014898893f, -0.023771819f, -0.01965048f, 0.007955471f, |
| 828 | -0.043740474f, 0.03346837f, -0.10549954f, 0.090567775f, |
| 829 | 0.042013682f, -0.03176985f, 0.12569028f, -0.02421228f, |
| 830 | -0.029526481f, 0.023851605f, 0.031539805f, 0.05292009f, |
| 831 | -0.02344001f, -0.07811758f, -0.08834428f, 0.10094801f, |
| 832 | 0.16594367f, -0.06861939f, -0.021256343f, -0.041093912f, |
| 833 | -0.06669611f, 0.035498552f, 0.021757556f, -0.09302526f, |
| 834 | -0.015403468f, -0.06614931f, -0.051798206f, -0.013874718f, |
| 835 | 0.03630673f, 0.010412845f, -0.08077351f, 0.046185967f, |
| 836 | 0.0035662893f, 0.03541868f, -0.094149634f, -0.034814864f, |
| 837 | 0.003128424f, -0.020674974f, -0.03944324f, -0.008110165f, |
| 838 | -0.11113267f, 0.08484226f, 0.043586485f, 0.040582247f, |
| 839 | 0.0968012f, -0.065249965f, -0.028036479f, 0.0050708856f, |
| 840 | 0.0017462453f, 0.0326779f, 0.041296225f, 0.09164146f, |
| 841 | -0.047743853f, -0.015952192f, -0.034451712f, 0.084197424f, |
| 842 | -0.05347844f, -0.11768019f, 0.085926116f, -0.08251791f, |
| 843 | -0.045081906f, 0.0948852f, 0.068401024f, 0.024856757f, |
| 844 | 0.06978981f, -0.057309967f, -0.012775832f, -0.0032452994f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 845 | 0.01977615f, -0.041040014f, -0.024264973f,0.063464895f, 0.05431621f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 846 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 847 | std::vector<float> cellToInputWeights = {0.040369894f, 0.030746894f, 0.24704495f, 0.018586371f, -0.037586458f, |
| 848 | -0.15312155f, -0.11812848f, -0.11465643f, 0.20259799f, 0.11418174f, |
| 849 | -0.10116027f, -0.011334949f, 0.12411352f, -0.076769054f,-0.052169047f, |
| 850 | 0.21198851f, -0.38871562f, -0.09061183f, -0.09683246f, -0.21929175f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 851 | |
| 852 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 853 | std::vector<float> cellToForgetWeights = {-0.01998659f,-0.15568835f,-0.24248174f, -0.012770197f, 0.041331276f, |
| 854 | -0.072311886f, -0.052123554f,-0.0066330447f,-0.043891653f,0.036225766f, |
| 855 | -0.047248036f, 0.021479502f,0.033189066f, 0.11952997f, -0.020432774f, |
| 856 | 0.64658105f, -0.06650122f, -0.03467612f, 0.095340036f, 0.23647355f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 857 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 858 | std::vector<float> cellToOutputWeights = { 0.08286371f, -0.08261836f, -0.51210177f, 0.002913762f, 0.17764764f, |
| 859 | -0.5495371f, -0.08460716f, -0.24552552f, 0.030037103f, 0.04123544f, |
| 860 | -0.11940523f, 0.007358328f, 0.1890978f, 0.4833202f, -0.34441817f, |
| 861 | 0.36312827f, -0.26375428f, 0.1457655f, -0.19724406f, 0.15548733f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 862 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 863 | std::vector<float> projectionWeights={-0.009802181f, 0.09401916f, 0.0717386f, -0.13895074f, 0.09641832f, |
| 864 | 0.060420845f, 0.08539281f, 0.054285463f, 0.061395317f, 0.034448683f, |
| 865 | -0.042991187f, 0.019801661f, -0.16840284f, -0.015726732f, -0.23041931f, |
| 866 | -0.024478018f, -0.10959692f, -0.013875541f, 0.18600968f, -0.061274476f, |
| 867 | 0.0138165f, -0.08160894f, -0.07661644f, 0.032372914f, 0.16169067f, |
| 868 | 0.22465782f, -0.03993472f, -0.004017731f, 0.08633481f, -0.28869787f, |
| 869 | 0.08682067f, 0.17240396f, 0.014975425f, 0.056431185f, 0.031037588f, |
| 870 | 0.16702051f, 0.0077946745f, 0.15140012f, 0.29405436f, 0.120285f, |
| 871 | -0.188994f, -0.027265169f, 0.043389652f, -0.022061434f, 0.014777949f, |
| 872 | -0.20203483f, 0.094781205f, 0.19100232f, 0.13987629f, -0.036132768f, |
| 873 | -0.06426278f, -0.05108664f, 0.13221376f, 0.009441198f, -0.16715929f, |
| 874 | 0.15859416f, -0.040437475f, 0.050779544f, -0.022187516f, 0.012166504f, |
| 875 | 0.027685808f, -0.07675938f, -0.0055694645f, -0.09444123f, 0.0046453946f, |
| 876 | 0.050794356f, 0.10770313f, -0.20790008f, -0.07149004f, -0.11425117f, |
| 877 | 0.008225835f, -0.035802525f, 0.14374903f, 0.15262283f, 0.048710253f, |
| 878 | 0.1847461f, -0.007487823f, 0.11000021f, -0.09542012f, 0.22619456f, |
| 879 | -0.029149994f, 0.08527916f, 0.009043713f, 0.0042746216f, 0.016261552f, |
| 880 | 0.022461696f, 0.12689082f, -0.043589946f, -0.12035478f, -0.08361797f, |
| 881 | -0.050666027f, -0.1248618f, -0.1275799f, -0.071875185f, 0.07377272f, |
| 882 | 0.09944291f, -0.18897448f, -0.1593054f, -0.06526116f, -0.040107165f, |
| 883 | -0.004618631f, -0.067624845f, -0.007576253f, 0.10727444f, 0.041546922f, |
| 884 | -0.20424393f, 0.06907816f, 0.050412357f, 0.00724631f, 0.039827548f, |
| 885 | 0.12449835f, 0.10747581f, 0.13708383f, 0.09134148f, -0.12617786f, |
| 886 | -0.06428341f, 0.09956831f, 0.1208086f, -0.14676677f, -0.0727722f, |
| 887 | 0.1126304f, 0.010139365f, 0.015571211f, -0.038128063f, 0.022913318f, |
| 888 | -0.042050496f, 0.16842307f, -0.060597885f, 0.10531834f, -0.06411776f, |
| 889 | -0.07451711f, -0.03410368f, -0.13393489f, 0.06534304f, 0.003620307f, |
| 890 | 0.04490757f, 0.05970546f, 0.05197996f, 0.02839995f, 0.10434969f, |
| 891 | -0.013699693f, -0.028353551f, -0.07260381f, 0.047201227f, -0.024575593f, |
| 892 | -0.036445823f, 0.07155557f, 0.009672501f, -0.02328883f, 0.009533515f, |
| 893 | -0.03606021f, -0.07421458f, -0.028082801f, -0.2678904f, -0.13221288f, |
| 894 | 0.18419984f, -0.13012612f, -0.014588381f, -0.035059117f, -0.04824723f, |
| 895 | 0.07830115f, -0.056184657f, 0.03277091f, 0.025466874f, 0.14494097f, |
| 896 | -0.12522776f, -0.098633975f, -0.10766018f, -0.08317623f, 0.08594209f, |
| 897 | 0.07749552f, 0.039474737f, 0.1776665f, -0.07409566f, -0.0477268f, |
| 898 | 0.29323658f, 0.10801441f, 0.1154011f, 0.013952499f, 0.10739139f, |
| 899 | 0.10708251f, -0.051456142f, 0.0074137426f, -0.10430189f, 0.10034707f, |
| 900 | 0.045594677f, 0.0635285f, -0.0715442f, -0.089667566f, -0.10811871f, |
| 901 | 0.00026344223f, 0.08298446f, -0.009525053f, 0.006585689f, -0.24567553f, |
| 902 | -0.09450807f, 0.09648481f, 0.026996298f, -0.06419476f, -0.04752702f, |
| 903 | -0.11063944f, -0.23441927f, -0.17608605f, -0.052156363f, 0.067035615f, |
| 904 | 0.19271925f, -0.0032889997f, -0.043264326f, 0.09663576f, -0.057112187f, |
| 905 | -0.10100678f, 0.0628376f, 0.04447668f, 0.017961001f, -0.10094388f, |
| 906 | -0.10190601f, 0.18335468f, 0.10494553f, -0.052095775f, -0.0026118709f, |
| 907 | 0.10539724f, -0.04383912f, -0.042349473f, 0.08438151f, -0.1947263f, |
| 908 | 0.02251204f, 0.11216432f, -0.10307853f, 0.17351969f, -0.039091777f, |
| 909 | 0.08066188f, -0.00561982f, 0.12633002f, 0.11335965f, -0.0088127935f, |
| 910 | -0.019777594f, 0.06864014f, -0.059751723f, 0.016233567f, -0.06894641f, |
| 911 | -0.28651384f, -0.004228674f, 0.019708522f, -0.16305895f, -0.07468996f, |
| 912 | -0.0855457f, 0.099339016f, -0.07580735f, -0.13775392f, 0.08434318f, |
| 913 | 0.08330512f, -0.12131499f, 0.031935584f, 0.09180414f, -0.08876437f, |
| 914 | -0.08049874f, 0.008753825f, 0.03498998f, 0.030215185f, 0.03907079f, |
| 915 | 0.089751154f, 0.029194152f, -0.03337423f, -0.019092513f, 0.04331237f, |
| 916 | 0.04299654f, -0.036394123f, -0.12915532f, 0.09793732f, 0.07512415f, |
| 917 | -0.11319543f, -0.032502122f, 0.15661901f, 0.07671967f, -0.005491124f, |
| 918 | -0.19379048f, -0.218606f, 0.21448623f, 0.017840758f, 0.1416943f, |
| 919 | -0.07051762f, 0.19488361f, 0.02664691f, -0.18104725f, -0.09334311f, |
| 920 | 0.15026465f, -0.15493552f, -0.057762887f, -0.11604192f, -0.262013f, |
| 921 | -0.01391798f, 0.012185008f, 0.11156489f, -0.07483202f, 0.06693364f, |
| 922 | -0.26151478f, 0.046425626f, 0.036540434f, -0.16435726f, 0.17338543f, |
| 923 | -0.21401681f, -0.11385144f, -0.08283257f, -0.069031075f, 0.030635102f, |
| 924 | 0.010969227f, 0.11109743f, 0.010919218f, 0.027526086f, 0.13519906f, |
| 925 | 0.01891392f, -0.046839405f, -0.040167913f, 0.017953383f, -0.09700955f, |
| 926 | 0.0061885654f, -0.07000971f, 0.026893595f, -0.038844477f, 0.14543656f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 927 | |
| 928 | std::vector<float> projectionBiasVector(outputSize, 0.f); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 929 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 930 | armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo20x5); |
| 931 | armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo20x5); |
| 932 | armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo20x5); |
| 933 | armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo20x5); |
| 934 | armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo20x16); |
| 935 | armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo20x16); |
| 936 | armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo20x16); |
| 937 | armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo20x16); |
| 938 | armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo20); |
| 939 | armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo20); |
| 940 | armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo20); |
| 941 | armnn::ScopedTensorHandle cellBiasTensor(tensorInfo20); |
| 942 | armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo20); |
| 943 | armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo20); |
| 944 | armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo20); |
| 945 | armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo16x20); |
| 946 | armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo16); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 947 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 948 | AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); |
| 949 | AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); |
| 950 | AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); |
| 951 | AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); |
| 952 | AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data()); |
| 953 | AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); |
| 954 | AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); |
| 955 | AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); |
| 956 | AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data()); |
| 957 | AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); |
| 958 | AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); |
| 959 | AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); |
| 960 | AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); |
| 961 | AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data()); |
| 962 | AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data()); |
| 963 | AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data()); |
| 964 | AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 965 | |
| 966 | data.m_InputToInputWeights = &inputToInputWeightsTensor; |
| 967 | data.m_InputToForgetWeights = &inputToForgetWeightsTensor; |
| 968 | data.m_InputToCellWeights = &inputToCellWeightsTensor; |
| 969 | data.m_InputToOutputWeights = &inputToOutputWeightsTensor; |
| 970 | data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; |
| 971 | data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; |
| 972 | data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; |
| 973 | data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; |
| 974 | data.m_CellToInputWeights = &cellToInputWeightsTensor; |
| 975 | data.m_InputGateBias = &inputGateBiasTensor; |
| 976 | data.m_ForgetGateBias = &forgetGateBiasTensor; |
| 977 | data.m_CellBias = &cellBiasTensor; |
| 978 | data.m_OutputGateBias = &outputGateBiasTensor; |
| 979 | data.m_CellToForgetWeights = &cellToForgetWeightsTensor; |
| 980 | data.m_CellToOutputWeights = &cellToOutputWeightsTensor; |
| 981 | data.m_ProjectionWeights = &projectionWeightsTensor; |
| 982 | data.m_ProjectionBias = &projectionBiasTensor; |
| 983 | |
| 984 | // Flags to set test configuration |
| 985 | data.m_Parameters.m_ActivationFunc = 4; |
| 986 | data.m_Parameters.m_CifgEnabled = false; |
| 987 | data.m_Parameters.m_PeepholeEnabled = true; |
| 988 | data.m_Parameters.m_ProjectionEnabled = true; |
| 989 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 990 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info); |
| 991 | inputHandle->Allocate(); |
| 992 | outputStateInHandle->Allocate(); |
| 993 | cellStateInHandle->Allocate(); |
| 994 | |
| 995 | scratchHandle->Allocate(); |
| 996 | outputStateOutHandle->Allocate(); |
| 997 | cellStateOutHandle->Allocate(); |
| 998 | outputHandle->Allocate(); |
| 999 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1000 | CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); |
| 1001 | CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); |
| 1002 | CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1003 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1004 | workload->Execute(); |
| 1005 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1006 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1007 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1008 | return LayerTestResult<T, 2>(actualOutput, |
| 1009 | outputVector, |
| 1010 | outputHandle->GetShape(), |
| 1011 | outputTensorInfo.GetShape()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1012 | } |
| 1013 | |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1014 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 1015 | LayerTestResult<T, 2> LstmLayerWithCifgWithPeepholeNoProjectionTestImpl( |
Aron Virginas-Tar | 5caf907 | 2018-11-14 18:35:18 +0000 | [diff] [blame] | 1016 | armnn::IWorkloadFactory& workloadFactory, |
| 1017 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1018 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1019 | const std::vector<T>& input, |
| 1020 | const std::vector<T>& outputExpected, |
| 1021 | const armnn::TensorShape& inputShape, |
| 1022 | const armnn::TensorShape& outputExpectedShape, |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1023 | float qScale = 0.0f, |
| 1024 | int32_t qOffset = 0, |
| 1025 | armnn::DataType constantDataType = armnn::DataType::Float32) |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1026 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 1027 | IgnoreUnused(memoryManager); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1028 | bool cifgEnabled = true; |
| 1029 | bool peepholeEnabled = true; |
| 1030 | bool projectionEnabled = false; |
| 1031 | // These are not the input and the output of Lstm yet |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1032 | unsigned int batchSize = armnn::numeric_cast<unsigned int>(inputShape[0]); |
| 1033 | unsigned int inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1034 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1035 | unsigned int outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1036 | |
| 1037 | const unsigned int cellSize = outputSize; |
| 1038 | |
| 1039 | // Decide the shape of all input tensors |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1040 | armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset); // change to ArmnnType |
| 1041 | armnn::TensorInfo outputStateInTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
| 1042 | armnn::TensorInfo cellStateInTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1043 | |
Matteo Martincigh | a65b7ae | 2018-11-14 12:39:55 +0000 | [diff] [blame] | 1044 | unsigned int scratchBufferSize = cifgEnabled ? cellSize * 3 : cellSize * 4; |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1045 | armnn::TensorInfo scratchBufferTensorInfo({batchSize, scratchBufferSize}, ArmnnType, qScale, qOffset); |
| 1046 | armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
| 1047 | armnn::TensorInfo cellStateOutTensorInfo({batchSize, cellSize}, ArmnnType, qScale, qOffset); |
| 1048 | armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1049 | |
| 1050 | // List of inputs |
| 1051 | std::vector<float> inputData; |
| 1052 | inputData.assign(input.data(), input.data() + batchSize*inputSize); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1053 | |
| 1054 | std::vector<float> outputStateInVector(batchSize * outputSize, 0.f); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1055 | |
| 1056 | std::vector<float> cellStateInVector(batchSize * cellSize, 0.f); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1057 | |
| 1058 | // Prepare all the weights in the descriptor for LSTM |
| 1059 | armnn::LstmQueueDescriptor data; |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1060 | armnn::TensorInfo tensorInfoInput({cellSize, inputSize}, constantDataType, qScale, qOffset); |
| 1061 | armnn::TensorInfo tensorInfoOutput({cellSize, outputSize}, constantDataType, qScale, qOffset); |
| 1062 | armnn::TensorInfo tensorInfoNumUnits({cellSize}, constantDataType, qScale, qOffset); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1063 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1064 | std::vector<float> inputToCellWeights = |
| 1065 | { |
| 1066 | -0.49770179f, -0.27711356f, -0.09624726f, 0.05100781f, |
| 1067 | 0.04717243f, 0.48944736f, -0.38535351f, |
| 1068 | -0.17212132f |
| 1069 | }; |
| 1070 | std::vector<float> inputToForgetWeights = |
| 1071 | { |
| 1072 | -0.55291498f, -0.42866567f, 0.13056988f, |
| 1073 | -0.3633365f, -0.22755712f, 0.28253698f, 0.24407166f, |
| 1074 | 0.33826375f |
| 1075 | }; |
| 1076 | std::vector<float> inputToOutputWeights = |
| 1077 | { |
| 1078 | 0.10725588f, -0.02335852f, -0.55932593f, |
| 1079 | -0.09426838f, -0.44257352f, 0.54939759f, |
| 1080 | 0.01533556f, 0.42751634f |
| 1081 | }; |
| 1082 | std::vector<float> cellBias = {0.f, 0.f, 0.f, 0.f}; |
| 1083 | std::vector<float> forgetGateBias = {1.f, 1.f, 1.f, 1.f}; |
| 1084 | std::vector<float> outputGateBias = {0.f, 0.f, 0.f, 0.f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1085 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1086 | std::vector<float> recurrentToCellWeights = |
| 1087 | { |
| 1088 | 0.54066205f, -0.32668582f, -0.43562764f, -0.56094903f, 0.42957711f, |
| 1089 | 0.01841056f, -0.32764608f, -0.33027974f, -0.10826075f, 0.20675004f, |
| 1090 | 0.19069612f, -0.03026325f, -0.54532051f, 0.33003211f, 0.44901288f, |
| 1091 | 0.21193194f |
| 1092 | }; |
| 1093 | std::vector<float> recurrentToForgetWeights = |
| 1094 | { |
| 1095 | -0.13832897f, -0.0515101f, -0.2359007f, -0.16661474f, -0.14340827f, |
| 1096 | 0.36986142f, 0.23414481f, 0.55899f, 0.10798943f, -0.41174671f, 0.17751795f, |
| 1097 | -0.34484994f, -0.35874045f, -0.11352962f, 0.27268326f, 0.54058349f |
| 1098 | }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1099 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1100 | std::vector<float> recurrentToOutputWeights = |
| 1101 | { |
| 1102 | 0.41613156f, 0.42610586f, -0.16495961f, -0.5663873f, 0.30579174f, -0.05115908f, |
| 1103 | -0.33941799f, 0.23364776f, 0.11178309f, 0.09481031f, -0.26424935f, 0.46261835f, |
| 1104 | 0.50248802f, 0.26114327f, -0.43736315f, 0.33149987f |
| 1105 | }; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1106 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1107 | std::vector<float> cellToForgetWeights = {0.47485286f, -0.51955009f, -0.24458408f, 0.31544167f}; |
| 1108 | std::vector<float> cellToOutputWeights = {-0.17135078f, 0.82760304f, 0.85573703f, -0.77109635f}; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1109 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1110 | armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfoInput); |
| 1111 | armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfoInput); |
| 1112 | armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfoInput); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1113 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1114 | armnn::ScopedTensorHandle cellBiasTensor(tensorInfoNumUnits); |
| 1115 | armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfoNumUnits); |
| 1116 | armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfoNumUnits); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1117 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1118 | armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfoOutput); |
| 1119 | armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfoOutput); |
| 1120 | armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfoOutput); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1121 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1122 | armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfoNumUnits); |
| 1123 | armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfoNumUnits); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1124 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1125 | AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); |
| 1126 | AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); |
| 1127 | AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1128 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1129 | AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); |
| 1130 | AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); |
| 1131 | AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1132 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1133 | AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); |
| 1134 | AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); |
| 1135 | AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1136 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1137 | AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data()); |
| 1138 | AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1139 | |
| 1140 | data.m_InputToCellWeights = &inputToCellWeightsTensor; |
| 1141 | data.m_InputToForgetWeights = &inputToForgetWeightsTensor; |
| 1142 | data.m_InputToOutputWeights = &inputToOutputWeightsTensor; |
| 1143 | |
| 1144 | data.m_CellBias = &cellBiasTensor; |
| 1145 | data.m_ForgetGateBias = &forgetGateBiasTensor; |
| 1146 | data.m_OutputGateBias = &outputGateBiasTensor; |
| 1147 | |
| 1148 | data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; |
| 1149 | data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; |
| 1150 | data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; |
| 1151 | |
| 1152 | data.m_CellToForgetWeights = &cellToForgetWeightsTensor; |
| 1153 | data.m_CellToOutputWeights = &cellToOutputWeightsTensor; |
| 1154 | |
| 1155 | // other parameters for the descriptor |
| 1156 | data.m_Parameters.m_CifgEnabled = cifgEnabled; |
| 1157 | data.m_Parameters.m_ProjectionEnabled = projectionEnabled; |
| 1158 | data.m_Parameters.m_PeepholeEnabled = peepholeEnabled; |
| 1159 | |
| 1160 | data.m_Parameters.m_ActivationFunc = 4; |
| 1161 | data.m_Parameters.m_ClippingThresProj = 0.0; |
| 1162 | data.m_Parameters.m_ClippingThresCell = 0.0; |
| 1163 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1164 | // List of outputs |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 1165 | std::vector<T> scratchBufferVector(batchSize * scratchBufferSize, T()); |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1166 | LayerTestResult<T, 2> ret0(scratchBufferTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1167 | |
| 1168 | // Output state for a certain time step |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 1169 | std::vector<T> outputStateOutVector(batchSize * outputSize, T()); |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1170 | LayerTestResult<T, 2> ret1(outputStateOutTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1171 | |
| 1172 | // Cell state for a certain time step |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 1173 | std::vector<T> cellStateOutVector(batchSize * cellSize, T()); |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1174 | LayerTestResult<T, 2> ret2(cellStateOutTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1175 | |
| 1176 | // Output for a certain time step |
Rob Hughes | bb46dde | 2020-05-20 15:27:37 +0100 | [diff] [blame] | 1177 | std::vector<T> outputData; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1178 | outputData.assign(outputExpected.data(), outputExpected.data() + batchSize*outputSize); |
Conor Kennedy | b9971c9 | 2019-05-07 07:14:23 +0100 | [diff] [blame] | 1179 | LayerTestResult<T, 2> ret3(outputTensorInfo); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1180 | ret3.m_ExpectedData = outputData; |
| 1181 | |
| 1182 | std::vector<T> actualScratchBufferOutput(scratchBufferTensorInfo.GetNumElements()); |
| 1183 | std::vector<T> actualOutputStateOutput(outputStateOutTensorInfo.GetNumElements()); |
| 1184 | std::vector<T> actualCellStateOutput(cellStateOutTensorInfo.GetNumElements()); |
| 1185 | std::vector<T> actualOutput(outputTensorInfo.GetNumElements()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1186 | |
| 1187 | // Prepare the inputs and outputs for the workload |
| 1188 | std::unique_ptr<armnn::ITensorHandle> inputHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1189 | tensorHandleFactory.CreateTensorHandle(inputTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1190 | std::unique_ptr<armnn::ITensorHandle> outputStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1191 | tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1192 | std::unique_ptr<armnn::ITensorHandle> cellStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1193 | tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1194 | |
| 1195 | std::unique_ptr<armnn::ITensorHandle> scratchBufferHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1196 | tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1197 | std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1198 | tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1199 | std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1200 | tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1201 | std::unique_ptr<armnn::ITensorHandle> outputHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1202 | tensorHandleFactory.CreateTensorHandle(outputTensorInfo); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1203 | |
| 1204 | armnn::WorkloadInfo info; |
| 1205 | AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); |
| 1206 | AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); |
| 1207 | AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); |
| 1208 | |
| 1209 | AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchBufferHandle.get()); |
| 1210 | AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get()); |
| 1211 | AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get()); |
| 1212 | AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); |
| 1213 | |
| 1214 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info); |
| 1215 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1216 | inputHandle->Allocate(); |
| 1217 | outputStateInHandle->Allocate(); |
| 1218 | cellStateInHandle->Allocate(); |
| 1219 | |
| 1220 | scratchBufferHandle->Allocate(); |
| 1221 | outputStateOutHandle->Allocate(); |
| 1222 | cellStateOutHandle->Allocate(); |
| 1223 | outputHandle->Allocate(); |
| 1224 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1225 | CopyDataToITensorHandle(inputHandle.get(), inputData.data()); |
| 1226 | CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); |
| 1227 | CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1228 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1229 | CopyDataToITensorHandle(scratchBufferHandle.get(), scratchBufferVector.data()); |
| 1230 | CopyDataToITensorHandle(outputStateOutHandle.get(), outputStateOutVector.data()); |
| 1231 | CopyDataToITensorHandle(cellStateOutHandle.get(), cellStateOutVector.data()); |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1232 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1233 | workload->Execute(); |
| 1234 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1235 | CopyDataFromITensorHandle(actualScratchBufferOutput.data(), scratchBufferHandle.get()); |
| 1236 | CopyDataFromITensorHandle(actualOutputStateOutput.data(), outputStateOutHandle.get()); |
| 1237 | CopyDataFromITensorHandle(actualCellStateOutput.data(), cellStateOutHandle.get()); |
| 1238 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
| 1239 | |
| 1240 | ret0.m_ActualData = actualScratchBufferOutput; |
| 1241 | ret1.m_ActualData = actualOutputStateOutput; |
| 1242 | ret2.m_ActualData = actualCellStateOutput; |
| 1243 | ret3.m_ActualData = actualOutput; |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 1244 | |
| 1245 | return ret3; |
| 1246 | } |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1247 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1248 | template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> |
| 1249 | LayerTestResult<T, 2> |
| 1250 | LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl(armnn::IWorkloadFactory& workloadFactory, |
| 1251 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1252 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1253 | const std::vector<T>& input, |
| 1254 | const std::vector<T>& outputExpected, |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1255 | float qScale = 0.0f, |
| 1256 | int32_t qOffset = 0, |
| 1257 | armnn::DataType constantDataType = armnn::DataType::Float32) |
| 1258 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 1259 | IgnoreUnused(memoryManager); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1260 | unsigned int batchSize = 2; |
| 1261 | unsigned int outputSize = 3; |
| 1262 | unsigned int inputSize = 5; |
| 1263 | unsigned numUnits = 4; |
| 1264 | |
| 1265 | armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, ArmnnType, qScale, qOffset); |
| 1266 | armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, ArmnnType, qScale, qOffset); |
| 1267 | armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, ArmnnType, qScale, qOffset); |
| 1268 | |
| 1269 | // Scratch buffer size without CIFG [batchSize, numUnits * 4] |
| 1270 | armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, ArmnnType, qScale, qOffset); |
| 1271 | armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, ArmnnType, qScale, qOffset); |
| 1272 | armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
| 1273 | armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, ArmnnType, qScale, qOffset); |
| 1274 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1275 | std::vector<float> inputVector; |
| 1276 | inputVector.assign(input.data(), input.data() + (batchSize * inputSize)); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1277 | |
| 1278 | std::vector<float> cellStateInVector(batchSize * numUnits, 0.f); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1279 | std::vector<float> outputStateInVector(batchSize * outputSize, 0.f); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1280 | std::vector<float> scratchBufferVector(batchSize * numUnits * 4, 0.f); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1281 | std::vector<float> outputStateOutVector(batchSize * outputSize, 0.f); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1282 | std::vector<float> cellStateOutVector(batchSize * numUnits, 0.f); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1283 | |
| 1284 | std::vector<float> actualOutput(outputTensorInfo.GetNumElements()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1285 | |
| 1286 | std::vector<float> outputVector; |
| 1287 | outputVector.assign(outputExpected.data(), outputExpected.data() + (batchSize * outputSize)); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1288 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1289 | std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputTensorInfo); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1290 | std::unique_ptr<armnn::ITensorHandle> cellStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1291 | tensorHandleFactory.CreateTensorHandle(cellStateInTensorInfo); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1292 | std::unique_ptr<armnn::ITensorHandle> outputStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1293 | tensorHandleFactory.CreateTensorHandle(outputStateInTensorInfo); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1294 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1295 | std::unique_ptr<armnn::ITensorHandle> scratchHandle = |
| 1296 | tensorHandleFactory.CreateTensorHandle(scratchBufferTensorInfo); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1297 | std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1298 | tensorHandleFactory.CreateTensorHandle(outputStateOutTensorInfo); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1299 | std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1300 | tensorHandleFactory.CreateTensorHandle(cellStateOutTensorInfo); |
| 1301 | std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputTensorInfo); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1302 | |
| 1303 | armnn::LstmQueueDescriptor data; |
| 1304 | armnn::WorkloadInfo info; |
| 1305 | |
| 1306 | AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); |
| 1307 | AddInputToWorkload(data, info, outputStateInTensorInfo, outputStateInHandle.get()); |
| 1308 | AddInputToWorkload(data, info, cellStateInTensorInfo, cellStateInHandle.get()); |
| 1309 | |
| 1310 | AddOutputToWorkload(data, info, scratchBufferTensorInfo, scratchHandle.get()); |
| 1311 | AddOutputToWorkload(data, info, outputStateOutTensorInfo, outputStateOutHandle.get()); |
| 1312 | AddOutputToWorkload(data, info, cellStateOutTensorInfo, cellStateOutHandle.get()); |
| 1313 | AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); |
| 1314 | |
| 1315 | armnn::TensorInfo tensorInfo3({outputSize}, constantDataType, qScale, qOffset); |
| 1316 | armnn::TensorInfo tensorInfo4({numUnits}, constantDataType, qScale, qOffset); |
| 1317 | armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, constantDataType, qScale, qOffset); |
| 1318 | armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, constantDataType, qScale, qOffset); |
| 1319 | armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, constantDataType, qScale, qOffset); |
| 1320 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1321 | std::vector<float> inputToInputWeights = {0.5f, 0.6f, 0.7f, -0.8f, -0.9f, |
| 1322 | 0.1f, 0.2f, 0.3f, -0.4f, 0.5f, |
| 1323 | -0.8f, 0.7f, -0.6f, 0.5f, -0.4f, |
| 1324 | -0.5f, -0.4f, -0.3f, -0.2f, -0.1f}; //{numUnits, inputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1325 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1326 | std::vector<float> inputToForgetWeights = { -0.6f, -0.1f, 0.3f, 0.2f, 0.9f, |
| 1327 | -0.5f, -0.2f, -0.4f, 0.3f, -0.8f, |
| 1328 | -0.4f, 0.3f, -0.5f, -0.4f, -0.6f, |
| 1329 | 0.3f, -0.4f, -0.6f, -0.5f, -0.5f}; //{numUnits, inputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1330 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1331 | std::vector<float> inputToCellWeights = {-0.4f, -0.3f, -0.2f, -0.1f, -0.5f, |
| 1332 | 0.5f, -0.2f, -0.3f, -0.2f, -0.6f, |
| 1333 | 0.6f, -0.1f, -0.4f, -0.3f, -0.7f, |
| 1334 | 0.7f, -0.9f, -0.5f, 0.8f, 0.6f}; //{numUnits, inputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1335 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1336 | std::vector<float> inputToOutputWeights = {-0.8f, -0.4f, -0.2f, -0.9f, -0.1f, |
| 1337 | -0.7f, 0.3f, -0.3f, -0.8f, -0.2f, |
| 1338 | 0.6f, -0.2f, 0.4f, -0.7f, -0.3f, |
| 1339 | -0.5f, 0.1f, 0.5f, -0.6f, -0.4f}; //{numUnits, inputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1340 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1341 | std::vector<float> inputGateBias = {0.03f, 0.15f, 0.22f, 0.38f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1342 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1343 | std::vector<float> forgetGateBias = {0.1f, -0.3f, -0.2f, 0.1f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1344 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1345 | std::vector<float> cellBias = {-0.05f, 0.72f, 0.25f, 0.08f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1346 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1347 | std::vector<float> outputGateBias = {0.05f, -0.01f, 0.2f, 0.1f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1348 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1349 | std::vector<float> recurrentToInputWeights ={-0.2f, -0.3f, 0.4f, |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1350 | 0.1f, -0.5f, 0.9f, |
| 1351 | -0.2f, -0.3f, -0.7f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1352 | 0.05f, -0.2f, -0.6f}; //{numUnits, outputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1353 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1354 | std::vector<float> recurrentToCellWeights = {-0.3f, 0.2f, 0.1f, |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1355 | -0.3f, 0.8f, -0.08f, |
| 1356 | -0.2f, 0.3f, 0.8f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1357 | -0.6f, -0.1f, 0.2f}; //{numUnits, outputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1358 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1359 | std::vector<float> recurrentToForgetWeights = { -0.5f, -0.3f, -0.5f, |
| 1360 | -0.2f, 0.6f, 0.4f, |
| 1361 | 0.9f, 0.3f, -0.1f, |
| 1362 | 0.2f, 0.5f, 0.2f}; //{numUnits, outputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1363 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1364 | std::vector<float> recurrentToOutputWeights = { 0.3f, -0.1f, 0.1f, |
| 1365 | -0.2f, -0.5f, -0.7f, |
| 1366 | -0.2f, -0.6f, -0.1f, |
| 1367 | -0.4f, -0.7f, -0.2f}; //{numUnits, outputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1368 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1369 | std::vector<float> cellToInputWeights = {0.05f, 0.1f, 0.25f, 0.15f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1370 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1371 | std::vector<float> cellToForgetWeights = {-0.02f, -0.15f, -0.25f, -0.03f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1372 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1373 | std::vector<float> cellToOutputWeights = {0.1f, -0.1f, -0.5f, 0.05f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1374 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1375 | std::vector<float> projectionWeights = {-0.1f, 0.2f, 0.01f, -0.2f, |
| 1376 | 0.1f, 0.5f, 0.3f, 0.08f, |
| 1377 | 0.07f, 0.2f, -0.4f, 0.2f}; //{outputSize, numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1378 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1379 | std::vector<float> projectionBiasVector(outputSize, 0.f); //{outputSize} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1380 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1381 | std::vector<float> inputLayerNormWeights = {0.1f, 0.2f, 0.3f, 0.5f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1382 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1383 | std::vector<float> forgetLayerNormWeights = {0.2f, 0.2f, 0.4f, 0.3f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1384 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1385 | std::vector<float> cellLayerNormWeights = {0.7f, 0.2f, 0.3f, 0.8f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1386 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1387 | std::vector<float> outputLayerNormWeights = {0.6f, 0.2f, 0.2f, 0.5f}; //{numUnits} |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1388 | |
| 1389 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1390 | armnn::ScopedTensorHandle inputToInputWeightsTensor(tensorInfo4x5); |
| 1391 | armnn::ScopedTensorHandle inputToForgetWeightsTensor(tensorInfo4x5); |
| 1392 | armnn::ScopedTensorHandle inputToCellWeightsTensor(tensorInfo4x5); |
| 1393 | armnn::ScopedTensorHandle inputToOutputWeightsTensor(tensorInfo4x5); |
| 1394 | armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(tensorInfo4x3); |
| 1395 | armnn::ScopedTensorHandle recurrentToInputWeightsTensor(tensorInfo4x3); |
| 1396 | armnn::ScopedTensorHandle recurrentToCellWeightsTensor(tensorInfo4x3); |
| 1397 | armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(tensorInfo4x3); |
| 1398 | armnn::ScopedTensorHandle cellToInputWeightsTensor(tensorInfo4); |
| 1399 | armnn::ScopedTensorHandle inputGateBiasTensor(tensorInfo4); |
| 1400 | armnn::ScopedTensorHandle forgetGateBiasTensor(tensorInfo4); |
| 1401 | armnn::ScopedTensorHandle cellBiasTensor(tensorInfo4); |
| 1402 | armnn::ScopedTensorHandle outputGateBiasTensor(tensorInfo4); |
| 1403 | armnn::ScopedTensorHandle cellToForgetWeightsTensor(tensorInfo4); |
| 1404 | armnn::ScopedTensorHandle cellToOutputWeightsTensor(tensorInfo4); |
| 1405 | armnn::ScopedTensorHandle projectionWeightsTensor(tensorInfo3x4); |
| 1406 | armnn::ScopedTensorHandle projectionBiasTensor(tensorInfo3); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1407 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1408 | armnn::ScopedTensorHandle inputLayerNormWeightsTensor(tensorInfo4); |
| 1409 | armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(tensorInfo4); |
| 1410 | armnn::ScopedTensorHandle cellLayerNormWeightsTensor(tensorInfo4); |
| 1411 | armnn::ScopedTensorHandle outputLayerNormWeightsTensor(tensorInfo4); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1412 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1413 | AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); |
| 1414 | AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); |
| 1415 | AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); |
| 1416 | AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); |
| 1417 | AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data()); |
| 1418 | AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); |
| 1419 | AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); |
| 1420 | AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); |
| 1421 | AllocateAndCopyDataToITensorHandle(&cellToInputWeightsTensor, cellToInputWeights.data()); |
| 1422 | AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); |
| 1423 | AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); |
| 1424 | AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); |
| 1425 | AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); |
| 1426 | AllocateAndCopyDataToITensorHandle(&cellToForgetWeightsTensor, cellToForgetWeights.data()); |
| 1427 | AllocateAndCopyDataToITensorHandle(&cellToOutputWeightsTensor, cellToOutputWeights.data()); |
| 1428 | AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data()); |
| 1429 | AllocateAndCopyDataToITensorHandle(&projectionBiasTensor, projectionBiasVector.data()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1430 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1431 | AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data()); |
| 1432 | AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data()); |
| 1433 | AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data()); |
| 1434 | AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1435 | |
| 1436 | data.m_InputToInputWeights = &inputToInputWeightsTensor; |
| 1437 | data.m_InputToForgetWeights = &inputToForgetWeightsTensor; |
| 1438 | data.m_InputToCellWeights = &inputToCellWeightsTensor; |
| 1439 | data.m_InputToOutputWeights = &inputToOutputWeightsTensor; |
| 1440 | data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; |
| 1441 | data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; |
| 1442 | data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; |
| 1443 | data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; |
| 1444 | data.m_CellToInputWeights = &cellToInputWeightsTensor; |
| 1445 | data.m_InputGateBias = &inputGateBiasTensor; |
| 1446 | data.m_ForgetGateBias = &forgetGateBiasTensor; |
| 1447 | data.m_CellBias = &cellBiasTensor; |
| 1448 | data.m_OutputGateBias = &outputGateBiasTensor; |
| 1449 | data.m_CellToForgetWeights = &cellToForgetWeightsTensor; |
| 1450 | data.m_CellToOutputWeights = &cellToOutputWeightsTensor; |
| 1451 | data.m_ProjectionWeights = &projectionWeightsTensor; |
| 1452 | data.m_ProjectionBias = &projectionBiasTensor; |
| 1453 | |
| 1454 | data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor; |
| 1455 | data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor; |
| 1456 | data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor; |
| 1457 | data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor; |
| 1458 | |
| 1459 | // Flags to set test configuration |
| 1460 | data.m_Parameters.m_ActivationFunc = 4; |
| 1461 | data.m_Parameters.m_CifgEnabled = false; |
| 1462 | data.m_Parameters.m_PeepholeEnabled = true; |
| 1463 | data.m_Parameters.m_ProjectionEnabled = true; |
| 1464 | data.m_Parameters.m_LayerNormEnabled = true; |
| 1465 | |
| 1466 | |
| 1467 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateLstm(data, info); |
| 1468 | inputHandle->Allocate(); |
| 1469 | outputStateInHandle->Allocate(); |
| 1470 | cellStateInHandle->Allocate(); |
| 1471 | |
| 1472 | scratchHandle->Allocate(); |
| 1473 | outputStateOutHandle->Allocate(); |
| 1474 | cellStateOutHandle->Allocate(); |
| 1475 | outputHandle->Allocate(); |
| 1476 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1477 | CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); |
| 1478 | CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); |
| 1479 | CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1480 | |
| 1481 | workload->Execute(); |
| 1482 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1483 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 1484 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1485 | return LayerTestResult<T, 2>(actualOutput, |
| 1486 | outputVector, |
| 1487 | outputHandle->GetShape(), |
| 1488 | outputTensorInfo.GetShape()); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1489 | } |
| 1490 | |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 1491 | LayerTestResult<uint8_t, 2> QuantizedLstmTestImpl( |
| 1492 | armnn::IWorkloadFactory& workloadFactory, |
| 1493 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1494 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1495 | const std::vector<uint8_t>& input, |
| 1496 | const std::vector<uint8_t>& outputExpected, |
| 1497 | const armnn::TensorShape& inputShape, |
| 1498 | const armnn::TensorShape& outputExpectedShape) |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1499 | { |
Jan Eilers | 8eb2560 | 2020-03-09 12:13:48 +0000 | [diff] [blame] | 1500 | IgnoreUnused(memoryManager); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1501 | auto numBatches = armnn::numeric_cast<unsigned int>(inputShape[0]); |
| 1502 | auto inputSize = armnn::numeric_cast<unsigned int>(inputShape[1]); |
| 1503 | auto outputSize = armnn::numeric_cast<unsigned int>(outputExpectedShape[1]); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1504 | |
| 1505 | // Scale/Offset for input/output, cellState In/Out, weights, bias |
| 1506 | float inputOutputScale = 0.0078125f; |
| 1507 | int32_t inputOutputOffset = 128; |
| 1508 | |
| 1509 | float cellStateScale = 0.00048828125f; |
| 1510 | int32_t cellStateOffset = 0; |
| 1511 | |
| 1512 | float weightsScale = 0.00408021f; |
| 1513 | int32_t weightsOffset = 100; |
| 1514 | |
| 1515 | float biasScale = 3.1876640625e-05f; |
| 1516 | int32_t biasOffset = 0; |
| 1517 | |
| 1518 | // Input/Output tensor info |
| 1519 | armnn::TensorInfo inputInfo({numBatches , inputSize}, |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 1520 | armnn::DataType::QAsymmU8, |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1521 | inputOutputScale, |
| 1522 | inputOutputOffset); |
| 1523 | |
| 1524 | armnn::TensorInfo cellStateInfo({numBatches , outputSize}, |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 1525 | armnn::DataType::QSymmS16, |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1526 | cellStateScale, |
| 1527 | cellStateOffset); |
| 1528 | |
| 1529 | armnn::TensorInfo outputStateInfo({numBatches , outputSize}, |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 1530 | armnn::DataType::QAsymmU8, |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1531 | inputOutputScale, |
| 1532 | inputOutputOffset); |
| 1533 | |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1534 | // Input0 |
| 1535 | std::vector<uint8_t> inputVector; |
| 1536 | inputVector.assign(input.data(), input.data() + (numBatches * inputSize)); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1537 | |
| 1538 | // Input1 |
| 1539 | std::vector<int16_t> cellStateInVector = {876, 1034, 955, -909, 761, 1029, 796, -1036}; // 13 |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1540 | // Input2 |
| 1541 | std::vector<uint8_t> outputStateInVector = {136, 150, 140, 115, 135, 152, 138, 112}; // 14 |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1542 | |
| 1543 | // Output0 |
| 1544 | std::vector<int16_t> cellStateOutVector = {1485, 1177, 1373, -1023, 1019, 1355, 1097, -1235}; // 0 |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1545 | |
| 1546 | // Output1 |
| 1547 | std::vector<uint8_t> outputVector; // 1 |
| 1548 | outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize)); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1549 | |
| 1550 | std::vector<uint8_t> actualOutput(outputStateInfo.GetNumElements()); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1551 | |
| 1552 | // Create tensor handles |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1553 | std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1554 | std::unique_ptr<armnn::ITensorHandle> cellStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1555 | tensorHandleFactory.CreateTensorHandle(cellStateInfo); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1556 | std::unique_ptr<armnn::ITensorHandle> outputStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1557 | tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1558 | |
| 1559 | std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1560 | tensorHandleFactory.CreateTensorHandle(cellStateInfo); |
| 1561 | std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1562 | |
| 1563 | armnn::QuantizedLstmQueueDescriptor data; |
| 1564 | armnn::WorkloadInfo info; |
| 1565 | |
| 1566 | // Add inputs and outputs to workload |
| 1567 | AddInputToWorkload(data, info, inputInfo, inputHandle.get()); |
| 1568 | AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get()); |
| 1569 | AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get()); |
| 1570 | |
| 1571 | AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get()); |
| 1572 | AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get()); |
| 1573 | |
| 1574 | // Weights and bias tensor and quantization info |
| 1575 | armnn::TensorInfo inputWeightsInfo({outputSize, inputSize}, |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 1576 | armnn::DataType::QAsymmU8, |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1577 | weightsScale, |
| 1578 | weightsOffset); |
| 1579 | |
| 1580 | armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize}, |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 1581 | armnn::DataType::QAsymmU8, |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1582 | weightsScale, |
| 1583 | weightsOffset); |
| 1584 | |
| 1585 | armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset); |
| 1586 | |
| 1587 | // Weights and bias tensor data |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1588 | std::vector<uint8_t> inputToInputWeights = {146, 250, 235, 171, 10, 218, 171, 108}; |
| 1589 | std::vector<uint8_t> inputToForgetWeights = {24, 50, 132, 179, 158, 110, 3, 169}; |
| 1590 | std::vector<uint8_t> inputToCellWeights = {133, 34, 29, 49, 206, 109, 54, 183}; |
| 1591 | std::vector<uint8_t> inputToOutputWeights = {195, 187, 11, 99, 109, 10, 218, 48}; |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1592 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1593 | std::vector<uint8_t> recurrentToInputWeights = |
| 1594 | {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26}; |
| 1595 | std::vector<uint8_t> recurrentToForgetWeights = |
| 1596 | {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253}; |
| 1597 | std::vector<uint8_t> recurrentToCellWeights = |
| 1598 | {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216}; |
| 1599 | std::vector<uint8_t> recurrentToOutputWeights = |
| 1600 | {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98}; |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1601 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1602 | std::vector<int32_t> inputGateBias = {-7876, 13488, -726, 32839}; |
| 1603 | std::vector<int32_t> forgetGateBias = {9206, -46884, -11693, -38724}; |
| 1604 | std::vector<int32_t> cellBias = {39481, 48624, 48976, -21419}; |
| 1605 | std::vector<int32_t> outputGateBias = {-58999, -17050, -41852, -40538}; |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1606 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1607 | // ScopedTensorHandles |
| 1608 | armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo); |
| 1609 | armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo); |
| 1610 | armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo); |
| 1611 | armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1612 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1613 | armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo); |
| 1614 | armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo); |
| 1615 | armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo); |
| 1616 | armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1617 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1618 | armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo); |
| 1619 | armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo); |
| 1620 | armnn::ScopedTensorHandle cellBiasTensor(biasInfo); |
| 1621 | armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1622 | |
| 1623 | // Allocate and copy data |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1624 | AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); |
| 1625 | AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); |
| 1626 | AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); |
| 1627 | AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1628 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1629 | AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data()); |
| 1630 | AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); |
| 1631 | AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); |
| 1632 | AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1633 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1634 | AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); |
| 1635 | AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); |
| 1636 | AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); |
| 1637 | AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1638 | |
| 1639 | // Setup queue descriptor |
| 1640 | data.m_InputToInputWeights = &inputToInputWeightsTensor; |
| 1641 | data.m_InputToForgetWeights = &inputToForgetWeightsTensor; |
| 1642 | data.m_InputToCellWeights = &inputToCellWeightsTensor; |
| 1643 | data.m_InputToOutputWeights = &inputToOutputWeightsTensor; |
| 1644 | |
| 1645 | data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; |
| 1646 | data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; |
| 1647 | data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; |
| 1648 | data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; |
| 1649 | |
| 1650 | data.m_InputGateBias = &inputGateBiasTensor; |
| 1651 | data.m_ForgetGateBias = &forgetGateBiasTensor; |
| 1652 | data.m_CellBias = &cellBiasTensor; |
| 1653 | data.m_OutputGateBias = &outputGateBiasTensor; |
| 1654 | |
| 1655 | // Create workload and allocate tensor handles |
| 1656 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQuantizedLstm(data, info); |
| 1657 | inputHandle->Allocate(); |
| 1658 | outputStateInHandle->Allocate(); |
| 1659 | cellStateInHandle->Allocate(); |
| 1660 | |
| 1661 | cellStateOutHandle->Allocate(); |
| 1662 | outputHandle->Allocate(); |
| 1663 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1664 | CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); |
| 1665 | CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); |
| 1666 | CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1667 | |
| 1668 | workload->Execute(); |
| 1669 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1670 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
James Conroy | 9c3cae8 | 2019-08-01 16:01:48 +0100 | [diff] [blame] | 1671 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1672 | return LayerTestResult<uint8_t, 2>(actualOutput, |
| 1673 | outputVector, |
| 1674 | outputHandle->GetShape(), |
| 1675 | outputStateInfo.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 1676 | } |
| 1677 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1678 | // QLSTM: CIFG, LayerNorm |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1679 | LayerTestResult<int8_t, 2> QLstmTestImpl( |
| 1680 | armnn::IWorkloadFactory& workloadFactory, |
| 1681 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1682 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1683 | const std::vector<int8_t>& input, |
| 1684 | const std::vector<int8_t>& outputExpected) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1685 | { |
| 1686 | IgnoreUnused(memoryManager); |
| 1687 | unsigned int numBatches = 2; |
| 1688 | unsigned int inputSize = 5; |
| 1689 | unsigned int outputSize = 4; |
| 1690 | unsigned int numUnits = 4; |
| 1691 | |
| 1692 | bool cifgEnabled = true; |
| 1693 | bool peepholeEnabled = false; |
| 1694 | bool projectionEnabled = false; |
| 1695 | bool layerNormEnabled = true; |
| 1696 | |
| 1697 | // Scale/Offset quantization info |
| 1698 | float inputScale = 0.0078125f; |
| 1699 | int32_t inputOffset = 0; |
| 1700 | |
| 1701 | int32_t hiddenStateZeroPoint = 0; |
| 1702 | float hiddenStateScale = 0.007f; |
| 1703 | |
| 1704 | // if (!projectionEnabled) outputScale == hiddenStateScale |
| 1705 | float outputScale = hiddenStateScale; |
| 1706 | int32_t outputOffset = hiddenStateZeroPoint; |
| 1707 | |
| 1708 | float cellStateScale = 3.05176e-05f; |
| 1709 | int32_t cellStateOffset = 0; |
| 1710 | |
| 1711 | float weightsScale = 0.00784314f; |
| 1712 | int32_t weightsOffset = 0; |
| 1713 | |
| 1714 | float layerNormScale = 3.05182e-05f; |
| 1715 | int32_t layerNormOffset = 0; |
| 1716 | |
| 1717 | float biasScale = layerNormScale / 1024; |
| 1718 | int32_t biasOffset = 0; |
| 1719 | |
| 1720 | float inputIntermediateScale = 0.007059f; |
| 1721 | float forgetIntermediateScale = 0.007812f; |
| 1722 | float cellIntermediateScale = inputIntermediateScale; |
| 1723 | float outputIntermediateScale = forgetIntermediateScale; |
| 1724 | |
| 1725 | float cellClip = 0.0f; |
| 1726 | float projectionClip = 0.0f; |
| 1727 | |
| 1728 | // Input/Output tensor info |
| 1729 | armnn::TensorInfo inputInfo({numBatches , inputSize}, |
| 1730 | armnn::DataType::QAsymmS8, |
| 1731 | inputScale, |
| 1732 | inputOffset); |
| 1733 | |
| 1734 | armnn::TensorInfo cellStateInfo({numBatches , numUnits}, |
| 1735 | armnn::DataType::QSymmS16, |
| 1736 | cellStateScale, |
| 1737 | cellStateOffset); |
| 1738 | |
| 1739 | armnn::TensorInfo outputStateInfo({numBatches , outputSize}, |
| 1740 | armnn::DataType::QAsymmS8, |
| 1741 | outputScale, |
| 1742 | outputOffset); |
| 1743 | |
| 1744 | LayerTestResult<int8_t, 2> ret(outputStateInfo); |
| 1745 | |
| 1746 | // Input tensors |
| 1747 | std::vector<int8_t> inputVector; |
| 1748 | inputVector.assign(input.data(), input.data() + (numBatches * inputSize)); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1749 | |
| 1750 | std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1751 | |
Teresa Charlin | be727be | 2020-09-25 15:08:21 +0100 | [diff] [blame] | 1752 | std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0, 0, 0}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1753 | |
| 1754 | // Output tensors |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1755 | std::vector<int16_t> cellStateOutVector = {-11692, 9960, 5491, 8861, -9422, 7726, 2056, 13149}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1756 | |
| 1757 | std::vector<int8_t> outputVector; |
| 1758 | outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize)); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1759 | |
| 1760 | std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1761 | |
| 1762 | // Create tensor handles |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1763 | std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1764 | std::unique_ptr<armnn::ITensorHandle> cellStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1765 | tensorHandleFactory.CreateTensorHandle(cellStateInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1766 | std::unique_ptr<armnn::ITensorHandle> outputStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1767 | tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1768 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1769 | std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle = |
| 1770 | tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1771 | std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1772 | tensorHandleFactory.CreateTensorHandle(cellStateInfo); |
| 1773 | std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1774 | |
| 1775 | armnn::QLstmQueueDescriptor data; |
| 1776 | armnn::WorkloadInfo info; |
| 1777 | |
| 1778 | // Add inputs and outputs to workload |
| 1779 | AddInputToWorkload(data, info, inputInfo, inputHandle.get()); |
| 1780 | AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get()); |
| 1781 | AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get()); |
| 1782 | |
| 1783 | AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get()); |
| 1784 | AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get()); |
| 1785 | AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get()); |
| 1786 | |
| 1787 | // Weights and bias tensor and quantization info |
| 1788 | armnn::TensorInfo inputWeightsInfo({outputSize, inputSize}, |
| 1789 | armnn::DataType::QSymmS8, |
| 1790 | weightsScale, |
| 1791 | weightsOffset); |
| 1792 | |
| 1793 | armnn::TensorInfo recurrentWeightsInfo({outputSize, outputSize}, |
| 1794 | armnn::DataType::QSymmS8, |
| 1795 | weightsScale, |
| 1796 | weightsOffset); |
| 1797 | |
| 1798 | armnn::TensorInfo biasInfo({outputSize}, armnn::DataType::Signed32, biasScale, biasOffset); |
| 1799 | |
| 1800 | armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset); |
| 1801 | |
| 1802 | // Weights and bias tensor data |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1803 | std::vector<int8_t> inputToForgetWeights = |
| 1804 | {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64}; |
| 1805 | std::vector<int8_t> inputToCellWeights = |
| 1806 | {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77}; |
| 1807 | std::vector<int8_t> inputToOutputWeights = |
| 1808 | {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1809 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1810 | std::vector<int8_t> recurrentToForgetWeights = |
| 1811 | {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25, 25, 38, -13, 51}; |
| 1812 | std::vector<int8_t> recurrentToCellWeights = |
| 1813 | {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25, 38, -13, 25, 64}; |
| 1814 | std::vector<int8_t> recurrentToOutputWeights = |
| 1815 | {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25, 13, 64, 25, -38}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1816 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1817 | std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484}; |
| 1818 | std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987}; |
| 1819 | std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1820 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1821 | std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830}; |
| 1822 | std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214}; |
| 1823 | std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1824 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1825 | // ScopedTensorHandles |
| 1826 | armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo); |
| 1827 | armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo); |
| 1828 | armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1829 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1830 | armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo); |
| 1831 | armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo); |
| 1832 | armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1833 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1834 | armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo); |
| 1835 | armnn::ScopedTensorHandle cellBiasTensor(biasInfo); |
| 1836 | armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1837 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 1838 | armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo); |
| 1839 | armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo); |
| 1840 | armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1841 | |
| 1842 | // Allocate and copy data |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1843 | AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); |
| 1844 | AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); |
| 1845 | AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1846 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1847 | AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); |
| 1848 | AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); |
| 1849 | AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1850 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1851 | AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); |
| 1852 | AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); |
| 1853 | AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1854 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1855 | AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data()); |
| 1856 | AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data()); |
| 1857 | AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1858 | |
| 1859 | // Setup queue descriptor |
| 1860 | data.m_InputToForgetWeights = &inputToForgetWeightsTensor; |
| 1861 | data.m_InputToCellWeights = &inputToCellWeightsTensor; |
| 1862 | data.m_InputToOutputWeights = &inputToOutputWeightsTensor; |
| 1863 | |
| 1864 | data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; |
| 1865 | data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; |
| 1866 | data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; |
| 1867 | |
| 1868 | data.m_ForgetGateBias = &forgetGateBiasTensor; |
| 1869 | data.m_CellBias = &cellBiasTensor; |
| 1870 | data.m_OutputGateBias = &outputGateBiasTensor; |
| 1871 | |
| 1872 | data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor; |
| 1873 | data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor; |
| 1874 | data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor; |
| 1875 | |
| 1876 | data.m_Parameters.m_CifgEnabled = cifgEnabled; |
| 1877 | data.m_Parameters.m_PeepholeEnabled = peepholeEnabled; |
| 1878 | data.m_Parameters.m_ProjectionEnabled = projectionEnabled; |
| 1879 | data.m_Parameters.m_LayerNormEnabled = layerNormEnabled; |
| 1880 | |
| 1881 | data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale; |
| 1882 | data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale; |
| 1883 | data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale; |
| 1884 | data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale; |
| 1885 | |
| 1886 | data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint; |
| 1887 | data.m_Parameters.m_HiddenStateScale = hiddenStateScale; |
| 1888 | |
| 1889 | data.m_Parameters.m_CellClip = cellClip; |
| 1890 | data.m_Parameters.m_ProjectionClip = projectionClip; |
| 1891 | |
| 1892 | // Create workload and allocate tensor handles |
| 1893 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info); |
| 1894 | inputHandle->Allocate(); |
| 1895 | outputStateInHandle->Allocate(); |
| 1896 | cellStateInHandle->Allocate(); |
| 1897 | |
| 1898 | outputStateOutHandle->Allocate(); |
| 1899 | cellStateOutHandle->Allocate(); |
| 1900 | outputHandle->Allocate(); |
| 1901 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1902 | CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); |
| 1903 | CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); |
| 1904 | CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1905 | |
| 1906 | workload->Execute(); |
| 1907 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1908 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1909 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1910 | return LayerTestResult<int8_t, 2>(actualOutput, |
| 1911 | outputVector, |
| 1912 | outputHandle->GetShape(), |
| 1913 | outputStateInfo.GetShape()); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 1914 | } |
| 1915 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1916 | // QLSTM: Projection, LayerNorm |
| 1917 | LayerTestResult<int8_t, 2> QLstmTestImpl1( |
| 1918 | armnn::IWorkloadFactory& workloadFactory, |
| 1919 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 1920 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1921 | const std::vector<int8_t>& input, |
| 1922 | const std::vector<int8_t>& outputExpected) |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1923 | { |
| 1924 | IgnoreUnused(memoryManager); |
| 1925 | unsigned int numBatches = 2; |
| 1926 | unsigned int inputSize = 5; |
| 1927 | unsigned int outputSize = 3; |
| 1928 | unsigned int numUnits = 4; |
| 1929 | |
| 1930 | bool cifgEnabled = false; |
| 1931 | bool peepholeEnabled = false; |
| 1932 | bool projectionEnabled = true; |
| 1933 | bool layerNormEnabled = true; |
| 1934 | |
| 1935 | // Scale/Offset quantization info |
| 1936 | float inputScale = 0.0078125f; |
| 1937 | int32_t inputOffset = 0; |
| 1938 | |
| 1939 | int32_t hiddenStateZeroPoint = 0; |
| 1940 | float hiddenStateScale = 0.007f; |
| 1941 | |
| 1942 | // if (!projectionEnabled) outputScale == hiddenStateScale |
| 1943 | float outputScale = 3.05176e-05f; |
| 1944 | int32_t outputOffset = 0; |
| 1945 | |
| 1946 | float cellStateScale = 3.05176e-05f; |
| 1947 | int32_t cellStateOffset = 0; |
| 1948 | |
| 1949 | float weightsScale = 0.00784314f; |
| 1950 | int32_t weightsOffset = 0; |
| 1951 | |
| 1952 | float layerNormScale = 3.05182e-05f; |
| 1953 | int32_t layerNormOffset = 0; |
| 1954 | |
| 1955 | float biasScale = layerNormScale / 1024; |
| 1956 | int32_t biasOffset = 0; |
| 1957 | |
| 1958 | float projectionWeightsScale = 0.00392157f; |
| 1959 | |
| 1960 | float inputIntermediateScale = 0.007059f; |
| 1961 | float forgetIntermediateScale = 0.007812f; |
| 1962 | float cellIntermediateScale = inputIntermediateScale; |
| 1963 | float outputIntermediateScale = forgetIntermediateScale; |
| 1964 | |
| 1965 | float cellClip = 0.0f; |
| 1966 | float projectionClip = 0.0f; |
| 1967 | |
| 1968 | // Input/Output tensor info |
| 1969 | armnn::TensorInfo inputInfo({numBatches , inputSize}, |
| 1970 | armnn::DataType::QAsymmS8, |
| 1971 | inputScale, |
| 1972 | inputOffset); |
| 1973 | |
| 1974 | armnn::TensorInfo cellStateInfo({numBatches , numUnits}, |
| 1975 | armnn::DataType::QSymmS16, |
| 1976 | cellStateScale, |
| 1977 | cellStateOffset); |
| 1978 | |
| 1979 | armnn::TensorInfo outputStateInfo({numBatches , outputSize}, |
| 1980 | armnn::DataType::QAsymmS8, |
| 1981 | outputScale, |
| 1982 | outputOffset); |
| 1983 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1984 | // Input tensors |
| 1985 | std::vector<int8_t> inputVector; |
| 1986 | inputVector.assign(input.data(), input.data() + (numBatches * inputSize)); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1987 | |
| 1988 | std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1989 | |
| 1990 | std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1991 | |
| 1992 | // Output tensors |
| 1993 | std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1994 | |
| 1995 | std::vector<int8_t> outputVector; |
| 1996 | outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize)); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 1997 | |
| 1998 | std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 1999 | |
| 2000 | // Create tensor handles |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2001 | std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2002 | std::unique_ptr<armnn::ITensorHandle> cellStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2003 | tensorHandleFactory.CreateTensorHandle(cellStateInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2004 | std::unique_ptr<armnn::ITensorHandle> outputStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2005 | tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2006 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2007 | std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle = |
| 2008 | tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2009 | std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2010 | tensorHandleFactory.CreateTensorHandle(cellStateInfo); |
| 2011 | std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2012 | |
| 2013 | armnn::QLstmQueueDescriptor data; |
| 2014 | armnn::WorkloadInfo info; |
| 2015 | |
| 2016 | // Add inputs and outputs to workload |
| 2017 | AddInputToWorkload(data, info, inputInfo, inputHandle.get()); |
| 2018 | AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get()); |
| 2019 | AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get()); |
| 2020 | |
| 2021 | AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get()); |
| 2022 | AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get()); |
| 2023 | AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get()); |
| 2024 | |
| 2025 | // Weights and bias tensor and quantization info |
| 2026 | armnn::TensorInfo inputWeightsInfo({numUnits, inputSize}, |
| 2027 | armnn::DataType::QSymmS8, |
| 2028 | weightsScale, |
| 2029 | weightsOffset); |
| 2030 | |
| 2031 | armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize}, |
| 2032 | armnn::DataType::QSymmS8, |
| 2033 | weightsScale, |
| 2034 | weightsOffset); |
| 2035 | |
| 2036 | armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset); |
| 2037 | |
| 2038 | armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset); |
| 2039 | |
| 2040 | armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits}, |
| 2041 | armnn::DataType::QSymmS8, |
| 2042 | projectionWeightsScale, |
| 2043 | 0); |
| 2044 | |
| 2045 | // Weights and bias tensor data |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2046 | std::vector<int8_t> inputToInputWeights = |
| 2047 | {64, 77, 89, -102, -115, 13, 25, 38, -51, 64, -102, 89, -77, 64, -51, -64, -51, -38, -25, -13}; |
| 2048 | std::vector<int8_t> inputToForgetWeights = |
| 2049 | {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64}; |
| 2050 | std::vector<int8_t> inputToCellWeights = |
| 2051 | {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77}; |
| 2052 | std::vector<int8_t> inputToOutputWeights = |
| 2053 | {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2054 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2055 | std::vector<int8_t> recurrentToInputWeights = {-25, -38, 51, 13, -64, 115, -25, -38, -89, 6, -25, -77}; |
| 2056 | std::vector<int8_t> recurrentToForgetWeights = {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25}; |
| 2057 | std::vector<int8_t> recurrentToCellWeights = {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25}; |
| 2058 | std::vector<int8_t> recurrentToOutputWeights = {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2059 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2060 | std::vector<int32_t> inputGateBias = {644245, 3221226, 4724464, 8160438}; |
| 2061 | std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484}; |
| 2062 | std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987}; |
| 2063 | std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2064 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2065 | std::vector<int16_t> inputLayerNormWeights = {3277, 6553, 9830, 16384}; |
| 2066 | std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830}; |
| 2067 | std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214}; |
| 2068 | std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2069 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2070 | std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2071 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2072 | // ScopedTensorHandles |
| 2073 | armnn::ScopedTensorHandle inputToInputWeightsTensor(inputWeightsInfo); |
| 2074 | armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo); |
| 2075 | armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo); |
| 2076 | armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2077 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2078 | armnn::ScopedTensorHandle recurrentToInputWeightsTensor(recurrentWeightsInfo); |
| 2079 | armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo); |
| 2080 | armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo); |
| 2081 | armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2082 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2083 | armnn::ScopedTensorHandle inputGateBiasTensor(biasInfo); |
| 2084 | armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo); |
| 2085 | armnn::ScopedTensorHandle cellBiasTensor(biasInfo); |
| 2086 | armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2087 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2088 | armnn::ScopedTensorHandle inputLayerNormWeightsTensor(layerNormWeightsInfo); |
| 2089 | armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo); |
| 2090 | armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo); |
| 2091 | armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2092 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2093 | armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2094 | |
| 2095 | // Allocate and copy data |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2096 | AllocateAndCopyDataToITensorHandle(&inputToInputWeightsTensor, inputToInputWeights.data()); |
| 2097 | AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); |
| 2098 | AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); |
| 2099 | AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2100 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2101 | AllocateAndCopyDataToITensorHandle(&recurrentToInputWeightsTensor, recurrentToInputWeights.data()); |
| 2102 | AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); |
| 2103 | AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); |
| 2104 | AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2105 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2106 | AllocateAndCopyDataToITensorHandle(&inputGateBiasTensor, inputGateBias.data()); |
| 2107 | AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); |
| 2108 | AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); |
| 2109 | AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2110 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2111 | AllocateAndCopyDataToITensorHandle(&inputLayerNormWeightsTensor, inputLayerNormWeights.data()); |
| 2112 | AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data()); |
| 2113 | AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data()); |
| 2114 | AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2115 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2116 | AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2117 | |
| 2118 | // Setup queue descriptor |
| 2119 | data.m_InputToInputWeights = &inputToInputWeightsTensor; |
| 2120 | data.m_InputToForgetWeights = &inputToForgetWeightsTensor; |
| 2121 | data.m_InputToCellWeights = &inputToCellWeightsTensor; |
| 2122 | data.m_InputToOutputWeights = &inputToOutputWeightsTensor; |
| 2123 | |
| 2124 | data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; |
| 2125 | data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; |
| 2126 | data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; |
| 2127 | data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; |
| 2128 | |
| 2129 | data.m_InputGateBias = &inputGateBiasTensor; |
| 2130 | data.m_ForgetGateBias = &forgetGateBiasTensor; |
| 2131 | data.m_CellBias = &cellBiasTensor; |
| 2132 | data.m_OutputGateBias = &outputGateBiasTensor; |
| 2133 | |
| 2134 | data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor; |
| 2135 | data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor; |
| 2136 | data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor; |
| 2137 | data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor; |
| 2138 | |
| 2139 | data.m_ProjectionWeights = &projectionWeightsTensor; |
| 2140 | |
| 2141 | data.m_Parameters.m_CifgEnabled = cifgEnabled; |
| 2142 | data.m_Parameters.m_PeepholeEnabled = peepholeEnabled; |
| 2143 | data.m_Parameters.m_ProjectionEnabled = projectionEnabled; |
| 2144 | data.m_Parameters.m_LayerNormEnabled = layerNormEnabled; |
| 2145 | |
| 2146 | data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale; |
| 2147 | data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale; |
| 2148 | data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale; |
| 2149 | data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale; |
| 2150 | |
| 2151 | data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint; |
| 2152 | data.m_Parameters.m_HiddenStateScale = hiddenStateScale; |
| 2153 | |
| 2154 | data.m_Parameters.m_CellClip = cellClip; |
| 2155 | data.m_Parameters.m_ProjectionClip = projectionClip; |
| 2156 | |
| 2157 | // Create workload and allocate tensor handles |
| 2158 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info); |
| 2159 | inputHandle->Allocate(); |
| 2160 | outputStateInHandle->Allocate(); |
| 2161 | cellStateInHandle->Allocate(); |
| 2162 | |
| 2163 | outputStateOutHandle->Allocate(); |
| 2164 | cellStateOutHandle->Allocate(); |
| 2165 | outputHandle->Allocate(); |
| 2166 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2167 | CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); |
| 2168 | CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); |
| 2169 | CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2170 | |
| 2171 | workload->Execute(); |
| 2172 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2173 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2174 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2175 | return LayerTestResult<int8_t, 2>(actualOutput, |
| 2176 | outputVector, |
| 2177 | outputHandle->GetShape(), |
| 2178 | outputStateInfo.GetShape()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2179 | } |
| 2180 | |
| 2181 | // QLSTM: Projection, CIFG, LayerNorm |
| 2182 | LayerTestResult<int8_t, 2> QLstmTestImpl2( |
| 2183 | armnn::IWorkloadFactory& workloadFactory, |
| 2184 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2185 | const armnn::ITensorHandleFactory& tensorHandleFactory, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2186 | const std::vector<int8_t>& input, |
| 2187 | const std::vector<int8_t>& outputExpected) |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2188 | { |
| 2189 | IgnoreUnused(memoryManager); |
| 2190 | unsigned int numBatches = 2; |
| 2191 | unsigned int inputSize = 5; |
| 2192 | unsigned int outputSize = 3; |
| 2193 | unsigned int numUnits = 4; |
| 2194 | |
| 2195 | bool cifgEnabled = true; |
| 2196 | bool peepholeEnabled = false; |
| 2197 | bool projectionEnabled = true; |
| 2198 | bool layerNormEnabled = true; |
| 2199 | |
| 2200 | // Scale/Offset quantization info |
| 2201 | float inputScale = 0.0078125f; |
| 2202 | int32_t inputOffset = 0; |
| 2203 | |
| 2204 | int32_t hiddenStateZeroPoint = 0; |
| 2205 | float hiddenStateScale = 0.007f; |
| 2206 | |
| 2207 | // if (!projectionEnabled) outputScale == hiddenStateScale |
| 2208 | float outputScale = 3.05176e-05f; |
| 2209 | int32_t outputOffset = 0; |
| 2210 | |
| 2211 | float cellStateScale = 3.05176e-05f; |
| 2212 | int32_t cellStateOffset = 0; |
| 2213 | |
| 2214 | float weightsScale = 0.00784314f; |
| 2215 | int32_t weightsOffset = 0; |
| 2216 | |
| 2217 | float layerNormScale = 3.05182e-05f; |
| 2218 | int32_t layerNormOffset = 0; |
| 2219 | |
| 2220 | float biasScale = layerNormScale / 1024; |
| 2221 | int32_t biasOffset = 0; |
| 2222 | |
| 2223 | float projectionWeightsScale = 0.00392157f; |
| 2224 | |
| 2225 | float inputIntermediateScale = 0.007059f; |
| 2226 | float forgetIntermediateScale = 0.007812f; |
| 2227 | float cellIntermediateScale = inputIntermediateScale; |
| 2228 | float outputIntermediateScale = forgetIntermediateScale; |
| 2229 | |
| 2230 | float cellClip = 0.0f; |
| 2231 | float projectionClip = 0.0f; |
| 2232 | |
| 2233 | // Input/Output tensor info |
| 2234 | armnn::TensorInfo inputInfo({numBatches , inputSize}, |
| 2235 | armnn::DataType::QAsymmS8, |
| 2236 | inputScale, |
| 2237 | inputOffset); |
| 2238 | |
| 2239 | armnn::TensorInfo cellStateInfo({numBatches , numUnits}, |
| 2240 | armnn::DataType::QSymmS16, |
| 2241 | cellStateScale, |
| 2242 | cellStateOffset); |
| 2243 | |
| 2244 | armnn::TensorInfo outputStateInfo({numBatches , outputSize}, |
| 2245 | armnn::DataType::QAsymmS8, |
| 2246 | outputScale, |
| 2247 | outputOffset); |
| 2248 | |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2249 | // Input tensors |
| 2250 | std::vector<int8_t> inputVector; |
| 2251 | inputVector.assign(input.data(), input.data() + (numBatches * inputSize)); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2252 | |
| 2253 | std::vector<int16_t> cellStateInVector = {0, 0, 0, 0, 0, 0, 0, 0}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2254 | |
| 2255 | std::vector<int8_t> outputStateInVector = {0, 0, 0, 0, 0, 0}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2256 | |
| 2257 | // Output tensors |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2258 | std::vector<int16_t> cellStateOutVector = {-14650, 8939, 5771, 6715, -11843, 7847, 1508, 12939}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2259 | |
| 2260 | std::vector<int8_t> outputVector; |
| 2261 | outputVector.assign(outputExpected.data(), outputExpected.data() + (numBatches * outputSize)); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2262 | |
| 2263 | std::vector<int8_t> actualOutput(outputStateInfo.GetNumElements()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2264 | |
| 2265 | // Create tensor handles |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2266 | std::unique_ptr<armnn::ITensorHandle> inputHandle = tensorHandleFactory.CreateTensorHandle(inputInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2267 | std::unique_ptr<armnn::ITensorHandle> cellStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2268 | tensorHandleFactory.CreateTensorHandle(cellStateInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2269 | std::unique_ptr<armnn::ITensorHandle> outputStateInHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2270 | tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2271 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2272 | std::unique_ptr<armnn::ITensorHandle> outputStateOutHandle = |
| 2273 | tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2274 | std::unique_ptr<armnn::ITensorHandle> cellStateOutHandle = |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2275 | tensorHandleFactory.CreateTensorHandle(cellStateInfo); |
| 2276 | std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputStateInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2277 | |
| 2278 | armnn::QLstmQueueDescriptor data; |
| 2279 | armnn::WorkloadInfo info; |
| 2280 | |
| 2281 | // Add inputs and outputs to workload |
| 2282 | AddInputToWorkload(data, info, inputInfo, inputHandle.get()); |
| 2283 | AddInputToWorkload(data, info, outputStateInfo, outputStateInHandle.get()); |
| 2284 | AddInputToWorkload(data, info, cellStateInfo, cellStateInHandle.get()); |
| 2285 | |
| 2286 | AddOutputToWorkload(data, info, outputStateInfo, outputStateOutHandle.get()); |
| 2287 | AddOutputToWorkload(data, info, cellStateInfo, cellStateOutHandle.get()); |
| 2288 | AddOutputToWorkload(data, info, outputStateInfo, outputHandle.get()); |
| 2289 | |
| 2290 | // Weights and bias tensor and quantization info |
| 2291 | armnn::TensorInfo inputWeightsInfo({numUnits, inputSize}, |
| 2292 | armnn::DataType::QSymmS8, |
| 2293 | weightsScale, |
| 2294 | weightsOffset); |
| 2295 | |
| 2296 | armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize}, |
| 2297 | armnn::DataType::QSymmS8, |
| 2298 | weightsScale, |
| 2299 | weightsOffset); |
| 2300 | |
| 2301 | armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset); |
| 2302 | |
| 2303 | armnn::TensorInfo layerNormWeightsInfo({numUnits}, armnn::DataType::QSymmS16, layerNormScale, layerNormOffset); |
| 2304 | |
| 2305 | armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits}, |
| 2306 | armnn::DataType::QSymmS8, |
| 2307 | projectionWeightsScale, |
| 2308 | 0); |
| 2309 | |
| 2310 | // Weights and bias tensor data |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2311 | std::vector<int8_t> inputToForgetWeights = |
| 2312 | {-77, -13, 38, 25, 115, -64, -25, -51, 38, -102, -51, 38, -64, -51, -77, 38, -51, -77, -64, -64}; |
| 2313 | std::vector<int8_t> inputToCellWeights = |
| 2314 | {-51, -38, -25, -13, -64, 64, -25, -38, -25, -77, 77, -13, -51, -38, -89, 89, -115, -64, 102, 77}; |
| 2315 | std::vector<int8_t> inputToOutputWeights = |
| 2316 | {-102, -51, -25, -115, -13, -89, 38, -38, -102, -25, 77, -25, 51, -89, -38, -64, 13, 64, -77, -51}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2317 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2318 | std::vector<int8_t> recurrentToForgetWeights = |
| 2319 | {-64, -38, -64, -25, 77, 51, 115, 38, -13, 25, 64, 25}; |
| 2320 | std::vector<int8_t> recurrentToCellWeights = |
| 2321 | {-38, 25, 13, -38, 102, -10, -25, 38, 102, -77, -13, 25}; |
| 2322 | std::vector<int8_t> recurrentToOutputWeights = |
| 2323 | {38, -13, 13, -25, -64, -89, -25, -77, -13, -51, -89, -25}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2324 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2325 | std::vector<int32_t> forgetGateBias = {2147484, -6442451, -4294968, 2147484}; |
| 2326 | std::vector<int32_t> cellBias = {-1073742, 15461883, 5368709, 1717987}; |
| 2327 | std::vector<int32_t> outputGateBias = {1073742, -214748, 4294968, 2147484}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2328 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2329 | std::vector<int16_t> forgetLayerNormWeights = {6553, 6553, 13107, 9830}; |
| 2330 | std::vector<int16_t> cellLayerNormWeights = {22937, 6553, 9830, 26214}; |
| 2331 | std::vector<int16_t> outputLayerNormWeights = {19660, 6553, 6553, 16384}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2332 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2333 | std::vector<int8_t> projectionWeights = {-25, 51, 3, -51, 25, 127, 77, 20, 18, 51, -102, 51}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2334 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2335 | // ScopedTensorHandles |
| 2336 | armnn::ScopedTensorHandle inputToForgetWeightsTensor(inputWeightsInfo); |
| 2337 | armnn::ScopedTensorHandle inputToCellWeightsTensor(inputWeightsInfo); |
| 2338 | armnn::ScopedTensorHandle inputToOutputWeightsTensor(inputWeightsInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2339 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2340 | armnn::ScopedTensorHandle recurrentToForgetWeightsTensor(recurrentWeightsInfo); |
| 2341 | armnn::ScopedTensorHandle recurrentToCellWeightsTensor(recurrentWeightsInfo); |
| 2342 | armnn::ScopedTensorHandle recurrentToOutputWeightsTensor(recurrentWeightsInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2343 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2344 | armnn::ScopedTensorHandle forgetGateBiasTensor(biasInfo); |
| 2345 | armnn::ScopedTensorHandle cellBiasTensor(biasInfo); |
| 2346 | armnn::ScopedTensorHandle outputGateBiasTensor(biasInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2347 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2348 | armnn::ScopedTensorHandle forgetLayerNormWeightsTensor(layerNormWeightsInfo); |
| 2349 | armnn::ScopedTensorHandle cellLayerNormWeightsTensor(layerNormWeightsInfo); |
| 2350 | armnn::ScopedTensorHandle outputLayerNormWeightsTensor(layerNormWeightsInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2351 | |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 2352 | armnn::ScopedTensorHandle projectionWeightsTensor(projectionWeightsInfo); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2353 | |
| 2354 | // Allocate and copy data |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2355 | AllocateAndCopyDataToITensorHandle(&inputToForgetWeightsTensor, inputToForgetWeights.data()); |
| 2356 | AllocateAndCopyDataToITensorHandle(&inputToCellWeightsTensor, inputToCellWeights.data()); |
| 2357 | AllocateAndCopyDataToITensorHandle(&inputToOutputWeightsTensor, inputToOutputWeights.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2358 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2359 | AllocateAndCopyDataToITensorHandle(&recurrentToForgetWeightsTensor, recurrentToForgetWeights.data()); |
| 2360 | AllocateAndCopyDataToITensorHandle(&recurrentToCellWeightsTensor, recurrentToCellWeights.data()); |
| 2361 | AllocateAndCopyDataToITensorHandle(&recurrentToOutputWeightsTensor, recurrentToOutputWeights.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2362 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2363 | AllocateAndCopyDataToITensorHandle(&forgetGateBiasTensor, forgetGateBias.data()); |
| 2364 | AllocateAndCopyDataToITensorHandle(&cellBiasTensor, cellBias.data()); |
| 2365 | AllocateAndCopyDataToITensorHandle(&outputGateBiasTensor, outputGateBias.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2366 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2367 | AllocateAndCopyDataToITensorHandle(&forgetLayerNormWeightsTensor, forgetLayerNormWeights.data()); |
| 2368 | AllocateAndCopyDataToITensorHandle(&cellLayerNormWeightsTensor, cellLayerNormWeights.data()); |
| 2369 | AllocateAndCopyDataToITensorHandle(&outputLayerNormWeightsTensor, outputLayerNormWeights.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2370 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2371 | AllocateAndCopyDataToITensorHandle(&projectionWeightsTensor, projectionWeights.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2372 | |
| 2373 | // Setup queue descriptor |
| 2374 | data.m_InputToForgetWeights = &inputToForgetWeightsTensor; |
| 2375 | data.m_InputToCellWeights = &inputToCellWeightsTensor; |
| 2376 | data.m_InputToOutputWeights = &inputToOutputWeightsTensor; |
| 2377 | |
| 2378 | data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; |
| 2379 | data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; |
| 2380 | data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; |
| 2381 | |
| 2382 | data.m_ForgetGateBias = &forgetGateBiasTensor; |
| 2383 | data.m_CellBias = &cellBiasTensor; |
| 2384 | data.m_OutputGateBias = &outputGateBiasTensor; |
| 2385 | |
| 2386 | data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor; |
| 2387 | data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor; |
| 2388 | data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor; |
| 2389 | |
| 2390 | data.m_ProjectionWeights = &projectionWeightsTensor; |
| 2391 | |
| 2392 | data.m_Parameters.m_CifgEnabled = cifgEnabled; |
| 2393 | data.m_Parameters.m_PeepholeEnabled = peepholeEnabled; |
| 2394 | data.m_Parameters.m_ProjectionEnabled = projectionEnabled; |
| 2395 | data.m_Parameters.m_LayerNormEnabled = layerNormEnabled; |
| 2396 | |
| 2397 | data.m_Parameters.m_InputIntermediateScale = inputIntermediateScale; |
| 2398 | data.m_Parameters.m_ForgetIntermediateScale = forgetIntermediateScale; |
| 2399 | data.m_Parameters.m_CellIntermediateScale = cellIntermediateScale; |
| 2400 | data.m_Parameters.m_OutputIntermediateScale = outputIntermediateScale; |
| 2401 | |
| 2402 | data.m_Parameters.m_HiddenStateZeroPoint = hiddenStateZeroPoint; |
| 2403 | data.m_Parameters.m_HiddenStateScale = hiddenStateScale; |
| 2404 | |
| 2405 | data.m_Parameters.m_CellClip = cellClip; |
| 2406 | data.m_Parameters.m_ProjectionClip = projectionClip; |
| 2407 | |
| 2408 | // Create workload and allocate tensor handles |
| 2409 | std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateQLstm(data, info); |
| 2410 | inputHandle->Allocate(); |
| 2411 | outputStateInHandle->Allocate(); |
| 2412 | cellStateInHandle->Allocate(); |
| 2413 | |
| 2414 | outputStateOutHandle->Allocate(); |
| 2415 | cellStateOutHandle->Allocate(); |
| 2416 | outputHandle->Allocate(); |
| 2417 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2418 | CopyDataToITensorHandle(inputHandle.get(), inputVector.data()); |
| 2419 | CopyDataToITensorHandle(outputStateInHandle.get(), outputStateInVector.data()); |
| 2420 | CopyDataToITensorHandle(cellStateInHandle.get(), cellStateInVector.data()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2421 | |
| 2422 | workload->Execute(); |
| 2423 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2424 | CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2425 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2426 | return LayerTestResult<int8_t, 2>(actualOutput, |
| 2427 | outputVector, |
| 2428 | outputHandle->GetShape(), |
| 2429 | outputStateInfo.GetShape()); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2430 | } |
| 2431 | |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 2432 | |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2433 | } // anonymous namespace |
| 2434 | |
| 2435 | #if defined(ARMNNREF_ENABLED) |
| 2436 | |
| 2437 | // The LSTM test units are run only for the reference backend at the moment |
| 2438 | |
| 2439 | void LstmUtilsZeroVectorTest() |
| 2440 | { |
| 2441 | armnn::TensorInfo inputDesc({4}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2442 | std::vector<float> input = {2., 3., 3., 4.}; |
| 2443 | std::vector<float> expectedOutput = {0., 0., 0., 0.}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2444 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2445 | return LstmUtilsZeroVectorTestImpl<armnn::DataType::Float32>(input, 4, expectedOutput, inputDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2446 | } |
| 2447 | |
| 2448 | void LstmUtilsMeanStddevNormalizationNoneZeroInputTest() |
| 2449 | { |
| 2450 | uint32_t batchSize = 2; |
| 2451 | uint32_t vecSize = 4; |
| 2452 | armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2453 | std::vector<float> input = |
| 2454 | { 0.1f, 0.2f, 0.3f, 0.4f, //batch 0 |
| 2455 | 0.9f, 1.0f, 1.1f, 1.2f }; //batch 1 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2456 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2457 | std::vector<float> expectedOutput = |
| 2458 | { -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f, //batch 0 |
| 2459 | -1.34163153f, -0.447210163f, 0.447211236f, 1.3416326f }; //batch 1 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2460 | |
| 2461 | return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2462 | vecSize, batchSize, expectedOutput, inputDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2463 | } |
| 2464 | |
| 2465 | void LstmUtilsMeanStddevNormalizationAllZeroInputTest() |
| 2466 | { |
| 2467 | uint32_t batchSize = 2; |
| 2468 | uint32_t vecSize = 4; |
| 2469 | armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2470 | std::vector<float> input = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2471 | { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0 |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2472 | 0.0f, 0.0f, 0.0f, 0.0f }; //batch 1 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2473 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2474 | std::vector<float> expectedOutput = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2475 | { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0 |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2476 | 0.0f, 0.0f, 0.0f, 0.0f }; //batch 1 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2477 | |
| 2478 | return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2479 | vecSize, batchSize, expectedOutput, inputDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2480 | } |
| 2481 | |
| 2482 | void LstmUtilsMeanStddevNormalizationMixedZeroInputTest() |
| 2483 | { |
| 2484 | uint32_t batchSize = 2; |
| 2485 | uint32_t vecSize = 4; |
| 2486 | armnn::TensorInfo inputDesc({batchSize, vecSize}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2487 | std::vector<float> input = |
| 2488 | { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0 |
| 2489 | 0.1f, 0.2f, 0.3f, 0.4f }; //batch 1 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2490 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2491 | std::vector<float> expectedOutput = |
| 2492 | { 0.0f, 0.0f, 0.0f, 0.0f, //batch 0 |
| 2493 | -1.34164071f, -0.447213531f, 0.44721365f, 1.34164071f }; //batch 1 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2494 | |
| 2495 | return LstmUtilsMeanStddevNormalizationTestImpl<armnn::DataType::Float32>(input, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2496 | vecSize, batchSize, expectedOutput, inputDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2497 | } |
| 2498 | |
| 2499 | void LstmUtilsVectorBatchVectorCwiseProductTest() |
| 2500 | { |
| 2501 | uint32_t batchSize = 4; |
| 2502 | uint32_t vecSize = 29; |
| 2503 | armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2504 | std::vector<float> vector = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2505 | { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f, |
| 2506 | 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2507 | 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2508 | |
| 2509 | armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2510 | std::vector<float> batchVector = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2511 | { /* batch 0 */ |
| 2512 | 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f, 9.9f, 10.1f, |
| 2513 | 11.11f, 12.12f, 13.13f, 14.14f, 15.15f, 16.16f, 17.17f, 18.18f, 19.19f, 20.2f, |
| 2514 | 21.21f, 22.22f, 23.23f, 24.24f, 25.25f, 26.26f, 27.27f, 28.28f, 0.0f, |
| 2515 | /* batch 1 */ |
| 2516 | -1.1f, -2.2f, -3.3f, -4.4f, -5.5f, -6.6f, -7.7f, -8.8f, -9.9f, -10.1f, |
| 2517 | -11.11f, -12.12f, -13.13f, -14.14f, -15.15f, -16.16f, -17.17f, -18.18f, -19.19f, -20.2f, |
| 2518 | -21.21f, -22.22f, -23.23f, -24.24f, -25.25f, -26.26f, -27.27f, -28.28f, 0.0f, |
| 2519 | /* batch 2 */ |
| 2520 | 1.1f, -2.2f, 3.3f, -4.4f, 5.5f, -6.6f, 7.7f, -8.8f, 9.9f, -10.1f, |
| 2521 | 11.11f, -12.12f, 13.13f, -14.14f, 15.15f, -16.16f, 17.17f, -18.18f, 19.19f, -20.2f, |
| 2522 | 21.21f, -22.22f, 23.23f, -24.24f, 25.25f, -26.26f, 27.27f, -28.28f, 0.0f, |
| 2523 | /* batch 3 */ |
| 2524 | -1.1f, 2.2f, -3.3f, 4.4f, -5.5f, 6.6f, -7.7f, 8.8f, -9.9f, 10.1f, |
| 2525 | -11.11f, 12.12f, -13.13f, 14.14f, -15.15f, 16.16f, -17.17f, 18.18f, -19.19f, 20.2f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2526 | -21.21f, 22.22f, -23.23f, 24.24f, -25.25f, 26.26f, -27.27f, 28.28f, 0.0f}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2527 | |
| 2528 | // Expect output = input * output + output. |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2529 | std::vector<float> expectedOutput = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2530 | { /* batch 0 */ |
| 2531 | 1.210000f, 4.840000f, 10.889999f, 19.360001f, 30.250000f, 43.559998f, |
| 2532 | 59.289997f, 77.440002f, 98.009995f, 102.010010f, 123.432091f, 146.894394f, |
| 2533 | 172.396896f, 199.939606f, 229.522491f, 261.145599f, 294.808899f, 330.512421f, |
| 2534 | 368.256134f, 408.040039f, 449.864075f, 493.728363f, 539.632874f, 587.577576f, |
| 2535 | 637.562500f, 689.587585f, 743.652954f, 799.758423f, 0.000000f, |
| 2536 | /* batch 1 */ |
| 2537 | -1.210000f, -4.840000f, -10.889999f, -19.360001f, -30.250000f, -43.559998f, |
| 2538 | -59.289997f, -77.440002f, -98.009995f, -102.010010f, -123.432091f, -146.894394f, |
| 2539 | -172.396896f, -199.939606f, -229.522491f, -261.145599f, -294.808899f, -330.512421f, |
| 2540 | -368.256134f, -408.040039f, -449.864075f, -493.728363f, -539.632874f, -587.577576f, |
| 2541 | -637.562500f, -689.587585f, -743.652954f, -799.758423f, 0.000000f, |
| 2542 | /* batch 2 */ |
| 2543 | 1.210000f, -4.840000f, 10.889999f, -19.360001f, 30.250000f, -43.559998f, |
| 2544 | 59.289997f, -77.440002f, 98.009995f, -102.010010f, 123.432091f, -146.894394f, |
| 2545 | 172.396896f, -199.939606f, 229.522491f, -261.145599f, 294.808899f, -330.512421f, |
| 2546 | 368.256134f, -408.040039f, 449.864075f, -493.728363f, 539.632874f, -587.577576f, |
| 2547 | 637.562500f, -689.587585f, 743.652954f, -799.758423f, 0.000000f, |
| 2548 | /* batch 3 */ |
| 2549 | -1.210000f, 4.840000f, -10.889999f, 19.360001f, -30.250000f, 43.559998f, |
| 2550 | -59.289997f, 77.440002f, -98.009995f, 102.010010f, -123.432091f, 146.894394f, |
| 2551 | -172.396896f, 199.939606f, -229.522491f, 261.145599f, -294.808899f, 330.512421f, |
| 2552 | -368.256134f, 408.040039f, -449.864075f, 493.728363f, -539.632874f, 587.577576f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2553 | -637.562500f, 689.587585f, -743.652954f, 799.758423f, 0.000000f}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2554 | |
| 2555 | return LstmUtilsVectorBatchVectorCwiseProductTestImpl<armnn::DataType::Float32>(vector, batchVector, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2556 | vecSize, batchSize, expectedOutput, vecDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2557 | } |
| 2558 | |
| 2559 | void LstmUtilsVectorBatchVectorAddTest() |
| 2560 | { |
| 2561 | uint32_t batchSize = 2; |
| 2562 | uint32_t vecSize = 3; |
| 2563 | armnn::TensorInfo vecDesc({vecSize}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2564 | std::vector<float> vector = { 0.0f, -0.5f, 1.0f}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2565 | |
| 2566 | armnn::TensorInfo batchVecDesc({batchSize, vecSize}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2567 | std::vector<float> batchVector = |
| 2568 | { |
| 2569 | 1.0f, 2.0f, 3.0f, //batch 0 |
| 2570 | 4.0f, 5.0f, 6.0f //batch 1 |
| 2571 | }; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2572 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2573 | std::vector<float> expectedOutput = |
| 2574 | { |
| 2575 | 1.0f, 1.5f, 4.0f, |
| 2576 | 4.0f, 4.5f, 7.0f |
| 2577 | }; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2578 | |
| 2579 | return LstmUtilsVectorBatchVectorAddTestImpl<armnn::DataType::Float32>(vector, batchVector, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2580 | vecSize, batchSize, expectedOutput, batchVecDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2581 | } |
| 2582 | |
| 2583 | #endif |
| 2584 | |
| 2585 | LayerTestResult<float, 2> LstmLayerFloat32WithCifgWithPeepholeNoProjectionTest( |
| 2586 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2587 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2588 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2589 | { |
| 2590 | armnn::TensorInfo inputDesc({ 2, 2 }, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2591 | std::vector<float> input = { 2., 3., 3., 4. }; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2592 | |
| 2593 | armnn::TensorInfo outputDesc({ 2, 4 }, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2594 | std::vector<float> expectedOutput = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2595 | {-0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2596 | -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2597 | return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<armnn::DataType::Float32>( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2598 | workloadFactory, memoryManager, tensorHandleFactory, |
| 2599 | input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2600 | } |
| 2601 | |
| 2602 | LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionTest( |
| 2603 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2604 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2605 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2606 | { |
| 2607 | armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2608 | std::vector<float> input = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2609 | {0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f, |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2610 | 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2611 | |
| 2612 | armnn::TensorInfo outputDesc({ 2, 16 }, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2613 | std::vector<float> expectedOutput = |
| 2614 | {-0.00396806f, 0.029352f, -0.00279226f, 0.0159977f, -0.00835576f, |
| 2615 | -0.0211779f, 0.0283512f, -0.0114597f, 0.00907307f, -0.0244004f, |
| 2616 | -0.0152191f, -0.0259063f, 0.00914318f, 0.00415118f, 0.017147f, |
| 2617 | 0.0134203f, -0.013869f, 0.0287268f, -0.00334693f, 0.00733398f, -0.0287926f, |
| 2618 | -0.0186926f, 0.0193662f, -0.0115437f, 0.00422612f, -0.0345232f, |
| 2619 | 0.00223253f, -0.00957321f, 0.0210624f, 0.013331f, 0.0150954f, 0.02168f}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2620 | return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<armnn::DataType::Float32>( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2621 | workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2622 | } |
| 2623 | |
| 2624 | LayerTestResult<float, 2> LstmLayerFloat32NoCifgNoPeepholeNoProjectionTest( |
| 2625 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2626 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2627 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2628 | { |
| 2629 | armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2630 | std::vector<float> input = {2., 3., 3., 4.}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2631 | |
| 2632 | armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2633 | std::vector<float> expectedOutput = |
| 2634 | {-0.02973187f, 0.1229473f, 0.20885126f, -0.15358765f, |
| 2635 | -0.0185422f, 0.11281417f, 0.24466537f, -0.1826292f}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2636 | |
| 2637 | return LstmNoCifgNoPeepholeNoProjectionTestImpl<armnn::DataType::Float32>( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2638 | workloadFactory, memoryManager, tensorHandleFactory, |
| 2639 | input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2640 | } |
| 2641 | |
| 2642 | LayerTestResult<float, 2> LstmLayerFloat32NoCifgWithPeepholeWithProjectionWithLayerNormTest( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2643 | armnn::IWorkloadFactory& workloadFactory, |
| 2644 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2645 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2646 | { |
| 2647 | armnn::TensorInfo inputDesc({ 2, 5 }, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2648 | std::vector<float> input = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2649 | {0.7f, 0.8f, 0.1f, 0.2f, 0.3f, //batch 0 |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2650 | 0.3f, 0.2f, 0.9f, 0.8f, 0.1f}; //batch 1 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2651 | |
| 2652 | armnn::TensorInfo outputDesc({ 2, 3 }, armnn::DataType::Float32); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2653 | std::vector<float> expectedOutput = |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2654 | { 0.0244077f, 0.128027f, -0.00170918f, //batch 0 |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2655 | -0.00692428f, 0.0848741f, 0.063445f}; //batch 1 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2656 | return LstmLayerNoCifgWithPeepholeWithProjectionWithLayerNormTestImpl<armnn::DataType::Float32>( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2657 | workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2658 | } |
| 2659 | |
| 2660 | LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionTest( |
| 2661 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2662 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2663 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2664 | { |
| 2665 | const float qScale = 1.0f; |
| 2666 | const int32_t qOffset = 0; |
| 2667 | |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 2668 | const armnn::DataType datatype = armnn::DataType::QSymmS16; |
| 2669 | const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2670 | |
| 2671 | armnn::TensorInfo inputDesc({2, 2}, datatype); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2672 | std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2673 | |
| 2674 | armnn::TensorInfo outputDesc({2, 4}, datatype); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2675 | std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>( |
| 2676 | { |
| 2677 | -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f, |
| 2678 | -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f |
| 2679 | }, |
| 2680 | qScale, qOffset); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2681 | |
| 2682 | return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2683 | workloadFactory, memoryManager, tensorHandleFactory, |
| 2684 | input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(), |
| 2685 | qScale, qOffset, constantDatatype); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2686 | |
| 2687 | } |
| 2688 | |
| 2689 | LayerTestResult<int16_t, 2> LstmLayerInt16WithCifgWithPeepholeNoProjectionTest( |
| 2690 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2691 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2692 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2693 | { |
| 2694 | const float qScale = 1.0f; |
| 2695 | const int32_t qOffset = 0; |
| 2696 | |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 2697 | const armnn::DataType datatype = armnn::DataType::QSymmS16; |
| 2698 | const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2699 | |
| 2700 | armnn::TensorInfo inputDesc({ 2, 2 }, datatype); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2701 | std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2702 | |
| 2703 | armnn::TensorInfo outputDesc({ 2, 4 }, datatype); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2704 | std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>( |
| 2705 | { |
| 2706 | -0.36444446f, -0.00352185f, 0.12886585f, -0.05163646f, |
| 2707 | -0.42734814f, -0.00478661f, 0.13455015f, -0.03560682f |
| 2708 | }, |
| 2709 | qScale, qOffset); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2710 | |
| 2711 | return LstmLayerWithCifgWithPeepholeNoProjectionTestImpl<datatype>( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2712 | workloadFactory, memoryManager, tensorHandleFactory, |
| 2713 | input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(), |
| 2714 | qScale, qOffset, constantDatatype); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2715 | } |
| 2716 | |
| 2717 | LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgWithPeepholeWithProjectionTest( |
| 2718 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2719 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2720 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2721 | { |
| 2722 | const float qScale = 2.0f; |
| 2723 | const int32_t qOffset = 0; |
| 2724 | |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 2725 | const armnn::DataType datatype = armnn::DataType::QSymmS16; |
| 2726 | const armnn::DataType constantDatatype = armnn::DataType::QAsymmU8; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2727 | |
| 2728 | armnn::TensorInfo inputDesc({ 2, 5 }, datatype); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2729 | std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>( |
| 2730 | { |
| 2731 | 0.787926f, 0.151646f, 0.071352f, 0.118426f, 0.458058f, |
| 2732 | 0.295743f, 0.544053f, 0.690064f, 0.858138f, 0.497181f |
| 2733 | }, |
| 2734 | qScale, qOffset); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2735 | |
| 2736 | armnn::TensorInfo outputDesc({ 2, 16 }, datatype); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2737 | std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>( |
| 2738 | { |
| 2739 | -0.00396806f, 0.02935200f, -0.00279226f, 0.01599770f, |
| 2740 | -0.00835576f, -0.02117790f, 0.02835120f, -0.01145970f, |
| 2741 | 0.00907307f, -0.02440040f, -0.01521910f, -0.02590630f, |
| 2742 | 0.00914318f, 0.00415118f, 0.01714700f, 0.01342030f, |
| 2743 | -0.01386900f, 0.02872680f, -0.00334693f, 0.00733398f, |
| 2744 | -0.02879260f, -0.01869260f, 0.01936620f, -0.01154370f, |
| 2745 | 0.00422612f, -0.03452320f, 0.00223253f, -0.00957321f, |
| 2746 | 0.02106240f, 0.01333100f, 0.01509540f, 0.02168000f |
| 2747 | }, |
| 2748 | qScale, qOffset); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2749 | |
| 2750 | return LstmLayerNoCifgWithPeepholeWithProjectionTestImpl<datatype>( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2751 | workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput, qScale, qOffset, constantDatatype); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2752 | } |
| 2753 | |
| 2754 | LayerTestResult<int16_t, 2> LstmLayerInt16NoCifgNoPeepholeNoProjectionInt16ConstantTest( |
| 2755 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2756 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2757 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2758 | { |
| 2759 | const float qScale = 1.0f; |
| 2760 | const int32_t qOffset = 0; |
| 2761 | |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 2762 | const armnn::DataType datatype = armnn::DataType::QSymmS16; // datatype & constants set to QSymm16 |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2763 | |
| 2764 | armnn::TensorInfo inputDesc({2, 2}, datatype); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2765 | std::vector<int16_t> input = armnnUtils::QuantizedVector<int16_t>({ 2.f, 3.f, 3.f, 4.f }, qScale, qOffset); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2766 | |
| 2767 | armnn::TensorInfo outputDesc({2, 4}, datatype); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2768 | std::vector<int16_t> expectedOutput = armnnUtils::QuantizedVector<int16_t>( |
| 2769 | { |
| 2770 | -0.02973187f, 0.12294730f, 0.20885126f, -0.15358765f, |
| 2771 | -0.01854220f, 0.11281417f, 0.24466537f, -0.18262920f |
| 2772 | }, |
| 2773 | qScale, qOffset); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2774 | |
| 2775 | return LstmNoCifgNoPeepholeNoProjectionTestImpl<datatype>( |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2776 | workloadFactory, memoryManager, tensorHandleFactory, |
| 2777 | input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape(), |
| 2778 | qScale, qOffset, datatype); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2779 | } |
| 2780 | |
| 2781 | // |
| 2782 | // QuantizedLstm |
| 2783 | // |
| 2784 | |
| 2785 | LayerTestResult<uint8_t, 2> QuantizedLstmTest( |
| 2786 | armnn::IWorkloadFactory& workloadFactory, |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2787 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2788 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2789 | { |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 2790 | armnn::TensorInfo inputDesc({2, 2}, armnn::DataType::QAsymmU8); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2791 | std::vector<uint8_t> input = {166, 179, 50, 150}; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2792 | |
Derek Lamberti | f90c56d | 2020-01-10 17:14:08 +0000 | [diff] [blame] | 2793 | armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmU8); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2794 | std::vector<uint8_t> expectedOutput = {140, 151, 146, 112, 136, 156, 142, 112 }; |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2795 | |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2796 | return QuantizedLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory, |
| 2797 | input, expectedOutput, inputDesc.GetShape(), outputDesc.GetShape()); |
Aron Virginas-Tar | 00d306e | 2019-08-28 18:08:46 +0100 | [diff] [blame] | 2798 | } |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 2799 | |
| 2800 | // QLSTM |
| 2801 | LayerTestResult<int8_t, 2> QLstmTest( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2802 | armnn::IWorkloadFactory& workloadFactory, |
| 2803 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2804 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 2805 | { |
| 2806 | armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2807 | std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 2808 | |
| 2809 | armnn::TensorInfo outputDesc({2, 4}, armnn::DataType::QAsymmS8); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2810 | std::vector<int8_t> expectedOutput = {-15, 21, 14, 20, -15, 15, 5, 27}; |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 2811 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2812 | return QLstmTestImpl(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput); |
James Conroy | 4f1f899 | 2020-04-29 20:01:10 +0100 | [diff] [blame] | 2813 | } |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2814 | |
| 2815 | LayerTestResult<int8_t, 2> QLstmTest1( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2816 | armnn::IWorkloadFactory& workloadFactory, |
| 2817 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2818 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2819 | { |
| 2820 | armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2821 | std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2822 | |
| 2823 | armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2824 | std::vector<int8_t> expectedOutput = {127, 127, -108, -67, 127, 127}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2825 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2826 | return QLstmTestImpl1(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2827 | } |
| 2828 | |
| 2829 | LayerTestResult<int8_t, 2> QLstmTest2( |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2830 | armnn::IWorkloadFactory& workloadFactory, |
| 2831 | const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, |
| 2832 | const armnn::ITensorHandleFactory& tensorHandleFactory) |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2833 | { |
| 2834 | armnn::TensorInfo inputDesc({2, 5}, armnn::DataType::QAsymmS8); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2835 | std::vector<int8_t> input = {90, 102, 13, 26, 38, 102, 13, 26, 51, 64}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2836 | |
| 2837 | armnn::TensorInfo outputDesc({2, 3}, armnn::DataType::QAsymmS8); |
Sadik Armagan | 483c811 | 2021-06-01 09:24:52 +0100 | [diff] [blame] | 2838 | std::vector<int8_t> expectedOutput = {127, 127, 127, -128, 127, 127}; |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2839 | |
Finn Williams | c43de6a | 2020-08-27 11:13:25 +0100 | [diff] [blame] | 2840 | return QLstmTestImpl2(workloadFactory, memoryManager, tensorHandleFactory, input, expectedOutput); |
James Conroy | b22a75e | 2020-06-08 14:53:10 +0100 | [diff] [blame] | 2841 | } |