Narumol Prangnawarat | e5339e7 | 2021-07-28 17:33:28 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "Activation.hpp" |
| 7 | #include "Lstm.hpp" |
| 8 | #include "LstmUtils.hpp" |
| 9 | |
| 10 | namespace armnn |
| 11 | { |
| 12 | |
| 13 | void LstmImpl(const LstmDescriptor& descriptor, |
| 14 | const TensorInfo& inputInfo, |
| 15 | const TensorInfo& outputInfo, |
| 16 | const TensorShape& inputToOutputWeightsShape, |
| 17 | const TensorShape& recurrentToOutputWeightsShape, |
| 18 | std::unique_ptr<Decoder<float>>& inputData, |
| 19 | std::unique_ptr<Decoder<float>>& outputStateIn, |
| 20 | std::unique_ptr<Decoder<float>>& cellStateIn, |
| 21 | std::unique_ptr<Encoder<float>>& outputStateOut, |
| 22 | std::unique_ptr<Encoder<float>>& cellStateOut, |
| 23 | std::unique_ptr<Encoder<float>>& output, |
| 24 | std::unique_ptr<Decoder<float>>& cellStateOutDecoder, |
| 25 | std::unique_ptr<Decoder<float>>& outputDecoder, |
| 26 | std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor, |
| 27 | std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor, |
| 28 | std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor, |
| 29 | std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor, |
| 30 | std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor, |
| 31 | std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor, |
| 32 | std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor, |
| 33 | std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor, |
| 34 | std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor, |
| 35 | std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor, |
| 36 | std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor, |
| 37 | std::unique_ptr<Decoder<float>>& inputGateBiasTensor, |
| 38 | std::unique_ptr<Decoder<float>>& forgetGateBiasTensor, |
| 39 | std::unique_ptr<Decoder<float>>& cellBiasTensor, |
| 40 | std::unique_ptr<Decoder<float>>& outputGateBiasTensor, |
| 41 | std::unique_ptr<Decoder<float>>& projectionWeightsTensor, |
| 42 | std::unique_ptr<Decoder<float>>& projectionBiasTensor, |
| 43 | std::unique_ptr<Decoder<float>>& inputLayerNormWeights, |
| 44 | std::unique_ptr<Decoder<float>>& forgetLayerNormWeights, |
| 45 | std::unique_ptr<Decoder<float>>& cellLayerNormWeights, |
| 46 | std::unique_ptr<Decoder<float>>& outputLayerNormWeights, |
| 47 | std::unique_ptr<Encoder<float>>& inputGateScratch, |
| 48 | std::unique_ptr<Encoder<float>>& cellScratch, |
| 49 | std::unique_ptr<Encoder<float>>& forgetGateScratch, |
| 50 | std::unique_ptr<Encoder<float>>& outputGateScratch, |
| 51 | std::unique_ptr<Decoder<float>>& inputGateScratchDecoder, |
| 52 | std::unique_ptr<Decoder<float>>& cellScratchDecoder, |
| 53 | std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder, |
| 54 | std::unique_ptr<Decoder<float>>& outputGateScratchDecoder, |
| 55 | float layerNormEpsilon) |
| 56 | { |
| 57 | // This is a porting of the LSTM::Eval() method in the Android code base |
| 58 | // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp |
| 59 | |
| 60 | const TensorShape& inputShape = inputInfo.GetShape(); |
| 61 | const DataType& outputType = outputInfo.GetDataType(); |
| 62 | |
| 63 | const uint32_t nBatch = inputShape[0]; |
| 64 | const uint32_t nInput = inputShape[1]; |
| 65 | |
| 66 | const uint32_t nCell = inputToOutputWeightsShape[0]; |
| 67 | const uint32_t nOutput = recurrentToOutputWeightsShape[1]; |
| 68 | |
| 69 | const bool useCifg = descriptor.m_CifgEnabled; |
| 70 | const bool usePeephole = descriptor.m_PeepholeEnabled; |
| 71 | const bool useLayerNorm = descriptor.m_LayerNormEnabled; |
| 72 | |
| 73 | if (!useLayerNorm) |
| 74 | { |
| 75 | // Initialize scratch buffers with bias. |
| 76 | if (!useCifg) |
| 77 | { |
| 78 | VectorBatchVectorAssign(*inputGateBiasTensor, |
| 79 | nCell, nBatch, *inputGateScratch); |
| 80 | } |
| 81 | VectorBatchVectorAssign(*forgetGateBiasTensor, |
| 82 | nCell, nBatch, *forgetGateScratch); |
| 83 | VectorBatchVectorAssign(*cellBiasTensor, |
| 84 | nCell, nBatch, *cellScratch); |
| 85 | VectorBatchVectorAssign(*outputGateBiasTensor, |
| 86 | nCell, nBatch, *outputGateScratch); |
| 87 | } |
| 88 | else |
| 89 | { |
| 90 | // Initialize scratch buffers with zeroes. |
| 91 | if (!useCifg) |
| 92 | { |
| 93 | ZeroVector(*inputGateScratch, nCell * nBatch); |
| 94 | } |
| 95 | ZeroVector(*forgetGateScratch, nCell * nBatch); |
| 96 | ZeroVector(*cellScratch , nCell * nBatch); |
| 97 | ZeroVector(*outputGateScratch, nCell * nBatch); |
| 98 | } |
| 99 | |
| 100 | // For each batch and cell: compute input_weight * input. |
| 101 | if (!useCifg) |
| 102 | { |
| 103 | MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, |
| 104 | nCell, nInput, *inputData, nBatch, *inputGateScratch); |
| 105 | } |
| 106 | MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, |
| 107 | nCell, nInput, *inputData, nBatch, *forgetGateScratch); |
| 108 | MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, |
| 109 | nCell, nInput, *inputData, nBatch, *cellScratch); |
| 110 | MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, |
| 111 | nCell, nInput, *inputData, nBatch, *outputGateScratch); |
| 112 | |
| 113 | // For each batch and cell: compute recurrent_weight * output_state. |
| 114 | if (!useCifg) |
| 115 | { |
| 116 | MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, |
| 117 | nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); |
| 118 | } |
| 119 | MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, |
| 120 | nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); |
| 121 | MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, |
| 122 | nCell, nOutput, *outputStateIn, nBatch, *cellScratch); |
| 123 | MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, |
| 124 | nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); |
| 125 | |
| 126 | // For each batch and cell: update input gate. |
| 127 | if (!useCifg) |
| 128 | { |
| 129 | if (usePeephole) |
| 130 | { |
| 131 | VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, |
| 132 | nCell, *cellStateIn, nBatch, *inputGateScratch); |
| 133 | } |
| 134 | if (useLayerNorm) |
| 135 | { |
| 136 | MeanStddevNormalization(*inputGateScratchDecoder, |
| 137 | *inputGateScratch, nCell, nBatch, layerNormEpsilon); |
| 138 | VectorBatchVectorCwiseProduct(*inputLayerNormWeights, |
| 139 | nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); |
| 140 | VectorBatchVectorAdd(*inputGateBiasTensor, |
| 141 | nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); |
| 142 | } |
| 143 | Activation(*inputGateScratchDecoder, *inputGateScratch, |
| 144 | TensorInfo({nCell, nBatch}, outputType), |
| 145 | ActivationFunction::Sigmoid, 0, 0); |
| 146 | } |
| 147 | |
| 148 | // For each batch and cell: update forget gate. |
| 149 | if (usePeephole) |
| 150 | { |
| 151 | VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, |
| 152 | *cellStateIn, nBatch, *forgetGateScratch); |
| 153 | } |
| 154 | if (useLayerNorm) |
| 155 | { |
| 156 | MeanStddevNormalization(*forgetGateScratchDecoder, |
| 157 | *forgetGateScratch, nCell, nBatch, layerNormEpsilon); |
| 158 | VectorBatchVectorCwiseProduct(*forgetLayerNormWeights, |
| 159 | nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); |
| 160 | VectorBatchVectorAdd(*forgetGateBiasTensor, |
| 161 | nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); |
| 162 | } |
| 163 | Activation(*forgetGateScratchDecoder, *forgetGateScratch, |
| 164 | TensorInfo({nCell, nBatch}, outputType), |
| 165 | ActivationFunction::Sigmoid, 0, 0); |
| 166 | |
| 167 | // For each batch and cell: update the cell. |
| 168 | if (useLayerNorm) |
| 169 | { |
| 170 | MeanStddevNormalization(*cellScratchDecoder, |
| 171 | *cellScratch, nCell, nBatch, layerNormEpsilon); |
| 172 | VectorBatchVectorCwiseProduct(*cellLayerNormWeights, |
| 173 | nCell, *cellScratchDecoder, nBatch, *cellScratch); |
| 174 | VectorBatchVectorAdd(*cellBiasTensor, |
| 175 | nCell, *cellScratchDecoder, nBatch, *cellScratch); |
| 176 | } |
| 177 | |
| 178 | VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); |
| 179 | |
| 180 | ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; |
| 181 | float a = 0; |
| 182 | float b = 0; |
| 183 | SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b); |
| 184 | |
| 185 | if (descriptor.m_ActivationFunc > 0) |
| 186 | { |
| 187 | Activation(*cellScratchDecoder, *cellScratch, |
| 188 | TensorInfo({nCell, nBatch}, outputType), |
| 189 | armnnActivationFunc, a, b); |
| 190 | } |
| 191 | if (useCifg) |
| 192 | { |
| 193 | Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); |
| 194 | VectorVectorCwiseProductAccumulate( |
| 195 | *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); |
| 196 | } |
| 197 | else |
| 198 | { |
| 199 | VectorVectorCwiseProductAccumulate( |
| 200 | *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); |
| 201 | } |
| 202 | if (descriptor.m_ClippingThresCell > 0.0) |
| 203 | { |
| 204 | ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut); |
| 205 | } |
| 206 | |
| 207 | // For each batch and cell: update the output gate. |
| 208 | if (usePeephole) |
| 209 | { |
| 210 | VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, |
| 211 | nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); |
| 212 | } |
| 213 | if (useLayerNorm) |
| 214 | { |
| 215 | MeanStddevNormalization(*outputGateScratchDecoder, |
| 216 | *outputGateScratch, nCell, nBatch, layerNormEpsilon); |
| 217 | VectorBatchVectorCwiseProduct(*outputLayerNormWeights, |
| 218 | nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); |
| 219 | VectorBatchVectorAdd(*outputGateBiasTensor, |
| 220 | nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); |
| 221 | } |
| 222 | Activation(*outputGateScratchDecoder, *outputGateScratch, |
| 223 | TensorInfo({nCell, nBatch}, outputType), |
| 224 | ActivationFunction::Sigmoid, 0, 0); |
| 225 | |
| 226 | if (descriptor.m_ActivationFunc > 0) |
| 227 | { |
| 228 | Activation(*cellStateOutDecoder, *cellScratch, |
| 229 | TensorInfo({nCell, nBatch}, outputType), |
| 230 | armnnActivationFunc, a, b); |
| 231 | } |
| 232 | |
| 233 | VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); |
| 234 | |
| 235 | // For each batch: update the projection and output_state. |
| 236 | if (descriptor.m_ProjectionEnabled) |
| 237 | { |
| 238 | if (projectionBiasTensor) |
| 239 | { |
| 240 | VectorBatchVectorAssign(*projectionBiasTensor, |
| 241 | nOutput, nBatch, *output); |
| 242 | } |
| 243 | MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, |
| 244 | nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); |
| 245 | |
| 246 | if (descriptor.m_ClippingThresProj > 0.0) |
| 247 | { |
| 248 | ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output); |
| 249 | } |
| 250 | } |
| 251 | else |
| 252 | { |
| 253 | CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); |
| 254 | } |
| 255 | |
| 256 | CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); |
| 257 | } |
| 258 | |
| 259 | } //namespace armnn |