//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
| 5 | |
| 6 | #include "RefLstmWorkload.hpp" |
| 7 | #include "Activation.hpp" |
| 8 | #include "Encoders.hpp" |
| 9 | #include "Decoders.hpp" |
| 10 | #include "LstmUtils.hpp" |
| 11 | #include "RefWorkloadUtils.hpp" |
| 12 | |
| 13 | namespace armnn |
| 14 | { |
| 15 | |
| 16 | RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) |
| 17 | : BaseWorkload<LstmQueueDescriptor>(descriptor, info) |
| 18 | , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights)) |
| 19 | , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights)) |
| 20 | , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights)) |
| 21 | , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights)) |
| 22 | , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights)) |
| 23 | , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights)) |
| 24 | , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights)) |
| 25 | , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights)) |
| 26 | , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights)) |
| 27 | , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights)) |
| 28 | , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights)) |
| 29 | , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias)) |
| 30 | , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias)) |
| 31 | , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias)) |
| 32 | , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias)) |
| 33 | , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights)) |
| 34 | , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias)) |
| 35 | {} |
| 36 | |
| 37 | void RefLstmWorkload::Execute() const |
| 38 | { |
| 39 | // This is a porting of the LSTM::Eval() method in the Android code base |
| 40 | // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp |
| 41 | |
| 42 | const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); |
| 43 | const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); |
| 44 | |
| 45 | const TensorShape& inputShape = inputInfo.GetShape(); |
| 46 | const DataType& outputType = outputInfo.GetDataType(); |
| 47 | |
| 48 | std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map()); |
| 49 | std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map()); |
| 50 | std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map()); |
| 51 | |
| 52 | std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map()); |
| 53 | std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map()); |
| 54 | |
| 55 | std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map()); |
| 56 | std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map()); |
| 57 | std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map()); |
| 58 | |
| 59 | const uint32_t nBatch = inputShape[0]; |
| 60 | const uint32_t nInput = inputShape[1]; |
| 61 | |
| 62 | const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; |
| 63 | const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; |
| 64 | |
| 65 | const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; |
| 66 | const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; |
| 67 | |
| 68 | // Index the scratch buffers pointers to the global scratch buffer. |
| 69 | std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); |
| 70 | std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); |
| 71 | std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); |
| 72 | std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); |
| 73 | |
| 74 | std::unique_ptr<Decoder<float>> inputGateScratchDecoder = |
| 75 | MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); |
| 76 | std::unique_ptr<Decoder<float>> cellScratchDecoder = |
| 77 | MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); |
| 78 | std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = |
| 79 | MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); |
| 80 | std::unique_ptr<Decoder<float>> outputGateScratchDecoder = |
| 81 | MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()); |
| 82 | |
| 83 | if (useCifg) |
| 84 | { |
| 85 | *cellScratch += (0 * nCell * nBatch); |
| 86 | *forgetGateScratch += (1 * nCell * nBatch); |
| 87 | *outputGateScratch += (2 * nCell * nBatch); |
| 88 | |
| 89 | *cellScratchDecoder += (0 * nCell * nBatch); |
| 90 | *forgetGateScratchDecoder += (1 * nCell * nBatch); |
| 91 | *outputGateScratchDecoder += (2 * nCell * nBatch); |
| 92 | } |
| 93 | else |
| 94 | { |
| 95 | *inputGateScratch += (0 * nCell * nBatch); |
| 96 | *cellScratch += (1 * nCell * nBatch); |
| 97 | *forgetGateScratch += (2 * nCell * nBatch); |
| 98 | *outputGateScratch += (3 * nCell * nBatch); |
| 99 | |
| 100 | *inputGateScratchDecoder += (0 * nCell * nBatch); |
| 101 | *cellScratchDecoder += (1 * nCell * nBatch); |
| 102 | *forgetGateScratchDecoder += (2 * nCell * nBatch); |
| 103 | *outputGateScratchDecoder += (3 * nCell * nBatch); |
| 104 | } |
| 105 | |
| 106 | std::unique_ptr<Decoder<float>> inputToInputWeightsTensor; |
| 107 | std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>( |
| 108 | m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>()); |
| 109 | std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>( |
| 110 | m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>()); |
| 111 | std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>( |
| 112 | m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>()); |
| 113 | |
| 114 | std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor; |
| 115 | std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>( |
| 116 | m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>()); |
| 117 | std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>( |
| 118 | m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>()); |
| 119 | std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>( |
| 120 | m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>()); |
| 121 | |
| 122 | std::unique_ptr<Decoder<float>> inputGateBiasTensor; |
| 123 | std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>( |
| 124 | m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<void>()); |
| 125 | std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>( |
| 126 | m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<void>()); |
| 127 | std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>( |
| 128 | m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<void>()); |
| 129 | |
| 130 | std::unique_ptr<Decoder<float>> cellToInputWeightsTensor; |
| 131 | std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor; |
| 132 | std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor; |
| 133 | |
| 134 | std::unique_ptr<Decoder<float>> projectionWeightsTensor; |
| 135 | std::unique_ptr<Decoder<float>> projectionBiasTensor; |
| 136 | |
| 137 | if (!useCifg) |
| 138 | { |
| 139 | inputToInputWeightsTensor = MakeDecoder<float>( |
| 140 | m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>()); |
| 141 | inputGateBiasTensor = MakeDecoder<float>( |
| 142 | m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<void>()); |
| 143 | recurrentToInputWeightsTensor = MakeDecoder<float>( |
| 144 | m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>()); |
| 145 | } |
| 146 | |
| 147 | if (usePeephole) |
| 148 | { |
| 149 | cellToForgetWeightsTensor = MakeDecoder<float>( |
| 150 | m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>()); |
| 151 | cellToOutputWeightsTensor = MakeDecoder<float>( |
| 152 | m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>()); |
| 153 | } |
| 154 | |
| 155 | if (!useCifg && usePeephole) |
| 156 | { |
| 157 | cellToInputWeightsTensor = MakeDecoder<float>( |
| 158 | m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>()); |
| 159 | } |
| 160 | |
| 161 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 162 | { |
| 163 | projectionWeightsTensor = MakeDecoder<float>( |
| 164 | m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>()); |
| 165 | if (m_ProjectionBiasTensor) |
| 166 | { |
| 167 | projectionBiasTensor = MakeDecoder<float>( |
| 168 | m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>()); |
| 169 | } |
| 170 | } |
| 171 | |
| 172 | // Initialize scratch buffers with bias. |
| 173 | if (!useCifg) |
| 174 | { |
| 175 | VectorBatchVectorAssign(*inputGateBiasTensor, |
| 176 | nCell, nBatch, *inputGateScratch); |
| 177 | } |
| 178 | VectorBatchVectorAssign(*forgetGateBiasTensor, |
| 179 | nCell, nBatch, *forgetGateScratch); |
| 180 | VectorBatchVectorAssign(*cellBiasTensor, |
| 181 | nCell, nBatch, *cellScratch); |
| 182 | VectorBatchVectorAssign(*outputGateBiasTensor, |
| 183 | nCell, nBatch, *outputGateScratch); |
| 184 | |
| 185 | // For each batch and cell: compute input_weight * input. |
| 186 | if (!useCifg) |
| 187 | { |
| 188 | MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, |
| 189 | nCell, nInput, *inputData, nBatch, *inputGateScratch); |
| 190 | } |
| 191 | MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, |
| 192 | nCell, nInput, *inputData, nBatch, *forgetGateScratch); |
| 193 | MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, |
| 194 | nCell, nInput, *inputData, nBatch, *cellScratch); |
| 195 | MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, |
| 196 | nCell, nInput, *inputData, nBatch, *outputGateScratch); |
| 197 | |
| 198 | // For each batch and cell: compute recurrent_weight * output_state. |
| 199 | if (!useCifg) |
| 200 | { |
| 201 | MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, |
| 202 | nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); |
| 203 | } |
| 204 | MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, |
| 205 | nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); |
| 206 | MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, |
| 207 | nCell, nOutput, *outputStateIn, nBatch, *cellScratch); |
| 208 | MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, |
| 209 | nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); |
| 210 | |
| 211 | // For each batch and cell: update input gate. |
| 212 | if (!useCifg) |
| 213 | { |
| 214 | if (usePeephole) |
| 215 | { |
| 216 | VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, |
| 217 | nCell, *cellStateIn, nBatch, *inputGateScratch); |
| 218 | } |
| 219 | Activation(*inputGateScratchDecoder, *inputGateScratch, |
| 220 | TensorInfo({nCell, nBatch}, outputType), |
| 221 | ActivationFunction::Sigmoid, 0, 0); |
| 222 | } |
| 223 | |
| 224 | // For each batch and cell: update forget gate. |
| 225 | if (usePeephole) |
| 226 | { |
| 227 | VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, |
| 228 | *cellStateIn, nBatch, *forgetGateScratch); |
| 229 | } |
| 230 | Activation(*forgetGateScratchDecoder, *forgetGateScratch, |
| 231 | TensorInfo({nCell, nBatch}, outputType), |
| 232 | ActivationFunction::Sigmoid, 0, 0); |
| 233 | |
| 234 | // For each batch and cell: update the cell. |
| 235 | VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); |
| 236 | |
| 237 | ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; |
| 238 | float a = 0; |
| 239 | float b = 0; |
| 240 | SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b); |
| 241 | |
| 242 | if (m_Data.m_Parameters.m_ActivationFunc > 0) |
| 243 | { |
| 244 | Activation(*cellScratchDecoder, *cellScratch, |
| 245 | TensorInfo({nCell, nBatch}, outputType), |
| 246 | armnnActivationFunc, a, b); |
| 247 | } |
| 248 | if (useCifg) |
| 249 | { |
| 250 | Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); |
| 251 | VectorVectorCwiseProductAccumulate( |
| 252 | *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); |
| 253 | } |
| 254 | else |
| 255 | { |
| 256 | VectorVectorCwiseProductAccumulate( |
| 257 | *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); |
| 258 | } |
| 259 | if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) |
| 260 | { |
| 261 | ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut); |
| 262 | } |
| 263 | |
| 264 | // For each batch and cell: update the output gate. |
| 265 | if (usePeephole) |
| 266 | { |
| 267 | VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, |
| 268 | nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); |
| 269 | } |
| 270 | Activation(*outputGateScratchDecoder, *outputGateScratch, |
| 271 | TensorInfo({nCell, nBatch}, outputType), |
| 272 | ActivationFunction::Sigmoid, 0, 0); |
| 273 | |
| 274 | if (m_Data.m_Parameters.m_ActivationFunc > 0) |
| 275 | { |
| 276 | Activation(*cellStateOutDecoder, *cellScratch, |
| 277 | TensorInfo({nCell, nBatch}, outputType), |
| 278 | armnnActivationFunc, a, b); |
| 279 | } |
| 280 | |
| 281 | VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); |
| 282 | |
| 283 | // For each batch: update the projection and output_state. |
| 284 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 285 | { |
| 286 | if (m_ProjectionBiasTensor) |
| 287 | { |
| 288 | VectorBatchVectorAssign(*projectionBiasTensor, |
| 289 | nOutput, nBatch, *output); |
| 290 | } |
| 291 | MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, |
| 292 | nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); |
| 293 | |
| 294 | if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) |
| 295 | { |
| 296 | ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output); |
| 297 | } |
| 298 | } |
| 299 | else |
| 300 | { |
| 301 | CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); |
| 302 | } |
| 303 | |
| 304 | CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); |
| 305 | } |
| 306 | |
| 307 | } //namespace armnn |