Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "RefLstmWorkload.hpp" |
| 7 | #include "Activation.hpp" |
| 8 | #include "Encoders.hpp" |
| 9 | #include "Decoders.hpp" |
| 10 | #include "LstmUtils.hpp" |
| 11 | #include "RefWorkloadUtils.hpp" |
| 12 | |
| 13 | namespace armnn |
| 14 | { |
| 15 | |
// Constructs the reference (CPU) LSTM workload.
// Each AssignScopedCpuTensorHandle call takes a private copy of the corresponding
// constant tensor from the queue descriptor. Tensors that are optional for the
// configured LSTM variant (input-gate tensors when CIFG is enabled, cell-to-*
// weights without peephole, projection weights/bias without projection, and the
// layer-norm weights without layer normalisation) may be absent in the
// descriptor; the matching members are then left unset and must not be
// dereferenced by Execute() on those paths.
RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
{}
| 40 | |
// Runs one LSTM step on the CPU reference backend.
//
// Tensor bindings used throughout:
//   Inputs[0]  = input           [nBatch, nInput]     Outputs[0] = scratch buffer
//   Inputs[1]  = output state in                      Outputs[1] = output state out
//   Inputs[2]  = cell state in                        Outputs[2] = cell state out
//                                                     Outputs[3] = output
void RefLstmWorkload::Execute() const
{
    // This is a porting of the LSTM::Eval() method in the Android code base
    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp

    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);

    const TensorShape& inputShape = inputInfo.GetShape();
    const DataType& outputType = outputInfo.GetDataType();

    // NOTE(review): inputInfo/outputInfo describe Inputs[0]/Outputs[0] only, yet they
    // are reused when wrapping the other inputs/outputs below — presumably only the
    // data type matters to MakeEncoder/MakeDecoder here; confirm for quantized types.
    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map());
    std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());

    // Decoders aliasing the same buffers as the encoders above, so values written
    // through an encoder in one step can be read back in a later step.
    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());

    std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map());

    const uint32_t nBatch = inputShape[0];
    const uint32_t nInput = inputShape[1];

    // Number of cells from the input->output weights' first dimension; output size
    // from the recurrent->output weights' second dimension.
    const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];

    const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;         // coupled input & forget gate (no input-gate tensors)
    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; // cell state feeds the gates
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

    // Index the scratch buffers pointers to the global scratch buffer.
    // Outputs[0] packs all gate scratch areas back to back, each nCell * nBatch
    // elements: [inputGate?][cell][forgetGate][outputGate]. With CIFG the input
    // gate segment is absent, so only three segments exist.
    std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());

    // Advance each encoder/decoder to its own segment inside the scratch buffer
    // (operator+= steps the wrapped position forward by that many elements).
    if (useCifg)
    {
        *cellScratch += (0 * nCell * nBatch);
        *forgetGateScratch += (1 * nCell * nBatch);
        *outputGateScratch += (2 * nCell * nBatch);

        *cellScratchDecoder += (0 * nCell * nBatch);
        *forgetGateScratchDecoder += (1 * nCell * nBatch);
        *outputGateScratchDecoder += (2 * nCell * nBatch);
    }
    else
    {
        *inputGateScratch += (0 * nCell * nBatch);
        *cellScratch += (1 * nCell * nBatch);
        *forgetGateScratch += (2 * nCell * nBatch);
        *outputGateScratch += (3 * nCell * nBatch);

        *inputGateScratchDecoder += (0 * nCell * nBatch);
        *cellScratchDecoder += (1 * nCell * nBatch);
        *forgetGateScratchDecoder += (2 * nCell * nBatch);
        *outputGateScratchDecoder += (3 * nCell * nBatch);
    }

    // Wrap the mandatory weight/bias tensors in decoders. The input-gate set
    // (declared but left null here) is only populated when CIFG is disabled.
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<void>());

    // Optional tensors: peephole weights, projection, and layer-norm weights —
    // each populated only when the corresponding feature flag is set below.
    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

    if (useLayerNorm)
    {
        if (!useCifg)
        {
            // Input-gate layer-norm weights exist only when there is an input gate.
            inputLayerNormWeights = MakeDecoder<float>(
                    m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
                m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
                m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
                m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        // Peephole into the input gate — only meaningful when the input gate exists.
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            // The projection bias is optional even when projection is enabled.
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
        }
    }

    if (!useLayerNorm)
    {
        // Initialize scratch buffers with bias.
        if (!useCifg)
        {
            VectorBatchVectorAssign(*inputGateBiasTensor,
                                    nCell, nBatch, *inputGateScratch);
        }
        VectorBatchVectorAssign(*forgetGateBiasTensor,
                                nCell, nBatch, *forgetGateScratch);
        VectorBatchVectorAssign(*cellBiasTensor,
                                nCell, nBatch, *cellScratch);
        VectorBatchVectorAssign(*outputGateBiasTensor,
                                nCell, nBatch, *outputGateScratch);
    }
    else
    {
        // Initialize scratch buffers with zeroes.
        // In the layer-norm path the biases are added AFTER normalisation (below),
        // so they must not be pre-loaded here.
        if (!useCifg)
        {
            ZeroVector(*inputGateScratch, nCell * nBatch);
        }
        ZeroVector(*forgetGateScratch, nCell * nBatch);
        ZeroVector(*cellScratch , nCell * nBatch);
        ZeroVector(*outputGateScratch, nCell * nBatch);
    }

    // For each batch and cell: compute input_weight * input.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
                                            nCell, nInput, *inputData, nBatch, *inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *cellScratch);
    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *outputGateScratch);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
                                            nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);

    // For each batch and cell: update input gate.
    if (!useCifg)
    {
        if (usePeephole)
        {
            // Peephole: add cellToInputWeights (*) previous cell state.
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
                                                    nCell, *cellStateIn, nBatch, *inputGateScratch);
        }
        if (useLayerNorm)
        {
            // Normalise the pre-activation, scale by the layer-norm weights,
            // then add the bias that was skipped during initialisation.
            MeanStddevNormalization(*inputGateScratchDecoder,
                                    *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
            VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
                                          nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
            VectorBatchVectorAdd(*inputGateBiasTensor,
                                 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
        }
        // Gate activation is always sigmoid.
        Activation(*inputGateScratchDecoder, *inputGateScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   ActivationFunction::Sigmoid, 0, 0);
    }

    // For each batch and cell: update forget gate.
    if (usePeephole)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
                                                *cellStateIn, nBatch, *forgetGateScratch);
    }
    if (useLayerNorm)
    {
        MeanStddevNormalization(*forgetGateScratchDecoder,
                                *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
                                      nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
        VectorBatchVectorAdd(*forgetGateBiasTensor,
                             nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
    }
    Activation(*forgetGateScratchDecoder, *forgetGateScratch,
               TensorInfo({nCell, nBatch}, outputType),
               ActivationFunction::Sigmoid, 0, 0);

    // For each batch and cell: update the cell.
    if (useLayerNorm)
    {
        MeanStddevNormalization(*cellScratchDecoder,
                                *cellScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
                                      nCell, *cellScratchDecoder, nBatch, *cellScratch);
        VectorBatchVectorAdd(*cellBiasTensor,
                             nCell, *cellScratchDecoder, nBatch, *cellScratch);
    }

    // cellStateOut = forgetGate (*) previous cell state.
    VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);

    // Map the descriptor's activation code to an armnn activation function with
    // parameters a/b; m_ActivationFunc == 0 means "no activation" and the
    // Activation calls guarded below are skipped.
    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
    float a = 0;
    float b = 0;
    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);

    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        // Apply the cell input activation (e.g. tanh) to the candidate values.
        Activation(*cellScratchDecoder, *cellScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   armnnActivationFunc, a, b);
    }
    if (useCifg)
    {
        // CIFG: the input gate is coupled to the forget gate as (1 - forgetGate).
        Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
        VectorVectorCwiseProductAccumulate(
            *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
    }
    else
    {
        // cellStateOut += inputGate (*) candidate cell values.
        VectorVectorCwiseProductAccumulate(
            *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
    }
    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
    {
        // Clip the new cell state to [-thresh, +thresh].
        ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut);
    }

    // For each batch and cell: update the output gate.
    if (usePeephole)
    {
        // Output-gate peephole uses the NEW cell state, unlike the other gates.
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
                                                nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
    }
    if (useLayerNorm)
    {
        MeanStddevNormalization(*outputGateScratchDecoder,
                                *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
                                      nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
        VectorBatchVectorAdd(*outputGateBiasTensor,
                             nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
    }
    Activation(*outputGateScratchDecoder, *outputGateScratch,
               TensorInfo({nCell, nBatch}, outputType),
               ActivationFunction::Sigmoid, 0, 0);

    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        // Re-use cellScratch to hold activation(cellStateOut).
        Activation(*cellStateOutDecoder, *cellScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   armnnActivationFunc, a, b);
    }

    // outputGateScratch = outputGate (*) activation(cellStateOut).
    VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);

    // For each batch: update the projection and output_state.
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        // output = projBias + projWeights * gatedOutput, optionally clipped.
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasTensor,
                                    nOutput, nBatch, *output);
        }
        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
                                            nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);

        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
        {
            ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output);
        }
    }
    else
    {
        // No projection: the gated output is the output directly
        // (nOutput equals nCell in this configuration).
        CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
    }

    // The output also becomes the recurrent state for the next time step.
    CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
}
| 382 | |
| 383 | } //namespace armnn |