Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "RefLstmWorkload.hpp" |
| 7 | #include "Activation.hpp" |
| 8 | #include "Encoders.hpp" |
| 9 | #include "Decoders.hpp" |
Narumol Prangnawarat | e5339e7 | 2021-07-28 17:33:28 +0100 | [diff] [blame] | 10 | #include "Lstm.hpp" |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 11 | #include "LstmUtils.hpp" |
| 12 | #include "RefWorkloadUtils.hpp" |
| 13 | |
| 14 | namespace armnn |
| 15 | { |
| 16 | |
| 17 | RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) |
| 18 | : BaseWorkload<LstmQueueDescriptor>(descriptor, info) |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 19 | , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) |
| 20 | , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) |
| 21 | , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) |
| 22 | , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) |
| 23 | , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) |
| 24 | , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) |
| 25 | , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) |
| 26 | , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) |
| 27 | , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) |
| 28 | , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) |
| 29 | , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) |
| 30 | , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) |
| 31 | , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) |
| 32 | , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) |
| 33 | , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) |
| 34 | , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) |
| 35 | , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) |
| 36 | , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) |
| 37 | , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) |
| 38 | , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) |
| 39 | , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 40 | {} |
| 41 | |
| 42 | void RefLstmWorkload::Execute() const |
| 43 | { |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 44 | Execute(m_Data.m_Inputs, m_Data.m_Outputs); |
| 45 | } |
| 46 | |
| 47 | void RefLstmWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) |
| 48 | { |
| 49 | Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); |
| 50 | } |
| 51 | |
| 52 | void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const |
| 53 | { |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 54 | // This is a porting of the LSTM::Eval() method in the Android code base |
| 55 | // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp |
| 56 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 57 | const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); |
| 58 | const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 59 | |
| 60 | const TensorShape& inputShape = inputInfo.GetShape(); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 61 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 62 | std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map()); |
| 63 | std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map()); |
| 64 | std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 65 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 66 | std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map()); |
| 67 | std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 68 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 69 | std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map()); |
| 70 | std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map()); |
| 71 | std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 72 | |
| 73 | const uint32_t nBatch = inputShape[0]; |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 74 | const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 75 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 76 | const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; |
| 77 | const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; |
| 78 | const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 79 | |
| 80 | // Index the scratch buffers pointers to the global scratch buffer. |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 81 | std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 82 | std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 83 | std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 84 | std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 85 | |
| 86 | std::unique_ptr<Decoder<float>> inputGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 87 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 88 | std::unique_ptr<Decoder<float>> cellScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 89 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 90 | std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 91 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 92 | std::unique_ptr<Decoder<float>> outputGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 93 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 94 | |
| 95 | if (useCifg) |
| 96 | { |
| 97 | *cellScratch += (0 * nCell * nBatch); |
| 98 | *forgetGateScratch += (1 * nCell * nBatch); |
| 99 | *outputGateScratch += (2 * nCell * nBatch); |
| 100 | |
| 101 | *cellScratchDecoder += (0 * nCell * nBatch); |
| 102 | *forgetGateScratchDecoder += (1 * nCell * nBatch); |
| 103 | *outputGateScratchDecoder += (2 * nCell * nBatch); |
| 104 | } |
| 105 | else |
| 106 | { |
| 107 | *inputGateScratch += (0 * nCell * nBatch); |
| 108 | *cellScratch += (1 * nCell * nBatch); |
| 109 | *forgetGateScratch += (2 * nCell * nBatch); |
| 110 | *outputGateScratch += (3 * nCell * nBatch); |
| 111 | |
| 112 | *inputGateScratchDecoder += (0 * nCell * nBatch); |
| 113 | *cellScratchDecoder += (1 * nCell * nBatch); |
| 114 | *forgetGateScratchDecoder += (2 * nCell * nBatch); |
| 115 | *outputGateScratchDecoder += (3 * nCell * nBatch); |
| 116 | } |
| 117 | |
| 118 | std::unique_ptr<Decoder<float>> inputToInputWeightsTensor; |
| 119 | std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 120 | m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 121 | std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 122 | m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 123 | std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 124 | m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 125 | |
| 126 | std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor; |
| 127 | std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 128 | m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 129 | std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 130 | m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 131 | std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 132 | m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 133 | |
| 134 | std::unique_ptr<Decoder<float>> inputGateBiasTensor; |
| 135 | std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 136 | m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 137 | std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 138 | m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 139 | std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 140 | m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 141 | |
| 142 | std::unique_ptr<Decoder<float>> cellToInputWeightsTensor; |
| 143 | std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor; |
| 144 | std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor; |
| 145 | |
| 146 | std::unique_ptr<Decoder<float>> projectionWeightsTensor; |
| 147 | std::unique_ptr<Decoder<float>> projectionBiasTensor; |
| 148 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 149 | std::unique_ptr<Decoder<float>> inputLayerNormWeights; |
| 150 | std::unique_ptr<Decoder<float>> forgetLayerNormWeights; |
| 151 | std::unique_ptr<Decoder<float>> cellLayerNormWeights; |
| 152 | std::unique_ptr<Decoder<float>> outputLayerNormWeights; |
| 153 | |
Narumol Prangnawarat | e5339e7 | 2021-07-28 17:33:28 +0100 | [diff] [blame] | 154 | const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); |
| 155 | const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); |
| 156 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 157 | if (useLayerNorm) |
| 158 | { |
| 159 | if (!useCifg) |
| 160 | { |
| 161 | inputLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 162 | m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 163 | } |
| 164 | forgetLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 165 | m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 166 | cellLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 167 | m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 168 | outputLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 169 | m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 170 | } |
| 171 | |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 172 | if (!useCifg) |
| 173 | { |
| 174 | inputToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 175 | m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 176 | inputGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 177 | m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 178 | recurrentToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 179 | m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 180 | } |
| 181 | |
| 182 | if (usePeephole) |
| 183 | { |
| 184 | cellToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 185 | m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 186 | cellToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 187 | m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 188 | } |
| 189 | |
| 190 | if (!useCifg && usePeephole) |
| 191 | { |
| 192 | cellToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 193 | m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 194 | } |
| 195 | |
| 196 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 197 | { |
| 198 | projectionWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 199 | m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 200 | if (m_ProjectionBiasTensor) |
| 201 | { |
| 202 | projectionBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 203 | m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 204 | } |
| 205 | } |
| 206 | |
Narumol Prangnawarat | e5339e7 | 2021-07-28 17:33:28 +0100 | [diff] [blame] | 207 | LstmImpl(m_Data.m_Parameters, |
| 208 | inputInfo, |
| 209 | outputInfo, |
| 210 | inputToOutputWeightsShape, |
| 211 | recurrentToOutputWeightsShape, |
| 212 | inputData, |
| 213 | outputStateIn, |
| 214 | cellStateIn, |
| 215 | outputStateOut, |
| 216 | cellStateOut, |
| 217 | output, |
| 218 | cellStateOutDecoder, |
| 219 | outputDecoder, |
| 220 | inputToInputWeightsTensor, |
| 221 | inputToForgetWeightsTensor, |
| 222 | inputToCellWeightsTensor, |
| 223 | inputToOutputWeightsTensor, |
| 224 | recurrentToInputWeightsTensor, |
| 225 | recurrentToForgetWeightsTensor, |
| 226 | recurrentToCellWeightsTensor, |
| 227 | recurrentToOutputWeightsTensor, |
| 228 | cellToInputWeightsTensor, |
| 229 | cellToForgetWeightsTensor, |
| 230 | cellToOutputWeightsTensor, |
| 231 | inputGateBiasTensor, |
| 232 | forgetGateBiasTensor, |
| 233 | cellBiasTensor, |
| 234 | outputGateBiasTensor, |
| 235 | projectionWeightsTensor, |
| 236 | projectionBiasTensor, |
| 237 | inputLayerNormWeights, |
| 238 | forgetLayerNormWeights, |
| 239 | cellLayerNormWeights, |
| 240 | outputLayerNormWeights, |
| 241 | inputGateScratch, |
| 242 | cellScratch, |
| 243 | forgetGateScratch, |
| 244 | outputGateScratch, |
| 245 | inputGateScratchDecoder, |
| 246 | cellScratchDecoder, |
| 247 | forgetGateScratchDecoder, |
| 248 | outputGateScratchDecoder, |
| 249 | m_LayerNormEpsilon); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 250 | } |
| 251 | |
| 252 | } //namespace armnn |