Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 1 | // |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 2 | // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "RefLstmWorkload.hpp" |
| 7 | #include "Activation.hpp" |
| 8 | #include "Encoders.hpp" |
| 9 | #include "Decoders.hpp" |
Narumol Prangnawarat | e5339e7 | 2021-07-28 17:33:28 +0100 | [diff] [blame] | 10 | #include "Lstm.hpp" |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 11 | #include "LstmUtils.hpp" |
| 12 | #include "RefWorkloadUtils.hpp" |
| 13 | |
| 14 | namespace armnn |
| 15 | { |
| 16 | |
| 17 | RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) |
Finn Williams | 73c547d | 2022-02-15 20:47:34 +0000 | [diff] [blame] | 18 | : RefBaseWorkload<LstmQueueDescriptor>(descriptor, info) |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 19 | , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) |
| 20 | , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) |
| 21 | , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) |
| 22 | , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) |
| 23 | , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) |
| 24 | , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) |
| 25 | , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) |
| 26 | , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) |
| 27 | , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) |
| 28 | , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) |
| 29 | , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) |
| 30 | , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) |
| 31 | , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) |
| 32 | , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) |
| 33 | , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) |
| 34 | , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) |
| 35 | , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) |
| 36 | , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) |
| 37 | , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) |
| 38 | , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) |
| 39 | , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 40 | {} |
| 41 | |
| 42 | void RefLstmWorkload::Execute() const |
| 43 | { |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 44 | Execute(m_Data.m_Inputs, m_Data.m_Outputs); |
| 45 | } |
| 46 | |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 47 | void RefLstmWorkload::ExecuteAsync(ExecutionData& executionData) |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 48 | { |
Matthew Sloyan | 2d213a7 | 2022-06-30 17:13:04 +0100 | [diff] [blame] | 49 | WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); |
| 50 | Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 51 | } |
| 52 | |
| 53 | void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const |
| 54 | { |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 55 | // This is a porting of the LSTM::Eval() method in the Android code base |
| 56 | // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp |
| 57 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 58 | const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); |
| 59 | const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 60 | |
| 61 | const TensorShape& inputShape = inputInfo.GetShape(); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 62 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 63 | std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map()); |
| 64 | std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map()); |
| 65 | std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 66 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 67 | std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map()); |
| 68 | std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 69 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 70 | std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map()); |
| 71 | std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map()); |
| 72 | std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 73 | |
| 74 | const uint32_t nBatch = inputShape[0]; |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 75 | const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 76 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 77 | const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; |
| 78 | const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; |
| 79 | const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 80 | |
| 81 | // Index the scratch buffers pointers to the global scratch buffer. |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 82 | std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 83 | std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 84 | std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 85 | std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 86 | |
| 87 | std::unique_ptr<Decoder<float>> inputGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 88 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 89 | std::unique_ptr<Decoder<float>> cellScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 90 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 91 | std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 92 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 93 | std::unique_ptr<Decoder<float>> outputGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 94 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 95 | |
| 96 | if (useCifg) |
| 97 | { |
| 98 | *cellScratch += (0 * nCell * nBatch); |
| 99 | *forgetGateScratch += (1 * nCell * nBatch); |
| 100 | *outputGateScratch += (2 * nCell * nBatch); |
| 101 | |
| 102 | *cellScratchDecoder += (0 * nCell * nBatch); |
| 103 | *forgetGateScratchDecoder += (1 * nCell * nBatch); |
| 104 | *outputGateScratchDecoder += (2 * nCell * nBatch); |
| 105 | } |
| 106 | else |
| 107 | { |
| 108 | *inputGateScratch += (0 * nCell * nBatch); |
| 109 | *cellScratch += (1 * nCell * nBatch); |
| 110 | *forgetGateScratch += (2 * nCell * nBatch); |
| 111 | *outputGateScratch += (3 * nCell * nBatch); |
| 112 | |
| 113 | *inputGateScratchDecoder += (0 * nCell * nBatch); |
| 114 | *cellScratchDecoder += (1 * nCell * nBatch); |
| 115 | *forgetGateScratchDecoder += (2 * nCell * nBatch); |
| 116 | *outputGateScratchDecoder += (3 * nCell * nBatch); |
| 117 | } |
| 118 | |
| 119 | std::unique_ptr<Decoder<float>> inputToInputWeightsTensor; |
| 120 | std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 121 | m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 122 | std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 123 | m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 124 | std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 125 | m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 126 | |
| 127 | std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor; |
| 128 | std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 129 | m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 130 | std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 131 | m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 132 | std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 133 | m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 134 | |
| 135 | std::unique_ptr<Decoder<float>> inputGateBiasTensor; |
| 136 | std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 137 | m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 138 | std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 139 | m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 140 | std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 141 | m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 142 | |
| 143 | std::unique_ptr<Decoder<float>> cellToInputWeightsTensor; |
| 144 | std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor; |
| 145 | std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor; |
| 146 | |
| 147 | std::unique_ptr<Decoder<float>> projectionWeightsTensor; |
| 148 | std::unique_ptr<Decoder<float>> projectionBiasTensor; |
| 149 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 150 | std::unique_ptr<Decoder<float>> inputLayerNormWeights; |
| 151 | std::unique_ptr<Decoder<float>> forgetLayerNormWeights; |
| 152 | std::unique_ptr<Decoder<float>> cellLayerNormWeights; |
| 153 | std::unique_ptr<Decoder<float>> outputLayerNormWeights; |
| 154 | |
Narumol Prangnawarat | e5339e7 | 2021-07-28 17:33:28 +0100 | [diff] [blame] | 155 | const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape(); |
| 156 | const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape(); |
| 157 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 158 | if (useLayerNorm) |
| 159 | { |
| 160 | if (!useCifg) |
| 161 | { |
| 162 | inputLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 163 | m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 164 | } |
| 165 | forgetLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 166 | m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 167 | cellLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 168 | m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 169 | outputLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 170 | m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 171 | } |
| 172 | |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 173 | if (!useCifg) |
| 174 | { |
| 175 | inputToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 176 | m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 177 | inputGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 178 | m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 179 | recurrentToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 180 | m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 181 | } |
| 182 | |
| 183 | if (usePeephole) |
| 184 | { |
| 185 | cellToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 186 | m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 187 | cellToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 188 | m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 189 | } |
| 190 | |
| 191 | if (!useCifg && usePeephole) |
| 192 | { |
| 193 | cellToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 194 | m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 195 | } |
| 196 | |
| 197 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 198 | { |
| 199 | projectionWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 200 | m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 201 | if (m_ProjectionBiasTensor) |
| 202 | { |
| 203 | projectionBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 204 | m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 205 | } |
| 206 | } |
| 207 | |
Narumol Prangnawarat | e5339e7 | 2021-07-28 17:33:28 +0100 | [diff] [blame] | 208 | LstmImpl(m_Data.m_Parameters, |
| 209 | inputInfo, |
| 210 | outputInfo, |
| 211 | inputToOutputWeightsShape, |
| 212 | recurrentToOutputWeightsShape, |
| 213 | inputData, |
| 214 | outputStateIn, |
| 215 | cellStateIn, |
| 216 | outputStateOut, |
| 217 | cellStateOut, |
| 218 | output, |
| 219 | cellStateOutDecoder, |
| 220 | outputDecoder, |
| 221 | inputToInputWeightsTensor, |
| 222 | inputToForgetWeightsTensor, |
| 223 | inputToCellWeightsTensor, |
| 224 | inputToOutputWeightsTensor, |
| 225 | recurrentToInputWeightsTensor, |
| 226 | recurrentToForgetWeightsTensor, |
| 227 | recurrentToCellWeightsTensor, |
| 228 | recurrentToOutputWeightsTensor, |
| 229 | cellToInputWeightsTensor, |
| 230 | cellToForgetWeightsTensor, |
| 231 | cellToOutputWeightsTensor, |
| 232 | inputGateBiasTensor, |
| 233 | forgetGateBiasTensor, |
| 234 | cellBiasTensor, |
| 235 | outputGateBiasTensor, |
| 236 | projectionWeightsTensor, |
| 237 | projectionBiasTensor, |
| 238 | inputLayerNormWeights, |
| 239 | forgetLayerNormWeights, |
| 240 | cellLayerNormWeights, |
| 241 | outputLayerNormWeights, |
| 242 | inputGateScratch, |
| 243 | cellScratch, |
| 244 | forgetGateScratch, |
| 245 | outputGateScratch, |
| 246 | inputGateScratchDecoder, |
| 247 | cellScratchDecoder, |
| 248 | forgetGateScratchDecoder, |
| 249 | outputGateScratchDecoder, |
| 250 | m_LayerNormEpsilon); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 251 | } |
| 252 | |
| 253 | } //namespace armnn |