Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include "RefLstmWorkload.hpp" |
| 7 | #include "Activation.hpp" |
| 8 | #include "Encoders.hpp" |
| 9 | #include "Decoders.hpp" |
| 10 | #include "LstmUtils.hpp" |
| 11 | #include "RefWorkloadUtils.hpp" |
| 12 | |
| 13 | namespace armnn |
| 14 | { |
| 15 | |
| 16 | RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) |
| 17 | : BaseWorkload<LstmQueueDescriptor>(descriptor, info) |
James Conroy | 1f58f03 | 2021-04-27 17:13:27 +0100 | [diff] [blame] | 18 | , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) |
| 19 | , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) |
| 20 | , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) |
| 21 | , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights)) |
| 22 | , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights)) |
| 23 | , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights)) |
| 24 | , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights)) |
| 25 | , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights)) |
| 26 | , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights)) |
| 27 | , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights)) |
| 28 | , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights)) |
| 29 | , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias)) |
| 30 | , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias)) |
| 31 | , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias)) |
| 32 | , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias)) |
| 33 | , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights)) |
| 34 | , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias)) |
| 35 | , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights)) |
| 36 | , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights)) |
| 37 | , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights)) |
| 38 | , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights)) |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 39 | {} |
| 40 | |
| 41 | void RefLstmWorkload::Execute() const |
| 42 | { |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 43 | Execute(m_Data.m_Inputs, m_Data.m_Outputs); |
| 44 | } |
| 45 | |
| 46 | void RefLstmWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) |
| 47 | { |
| 48 | Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); |
| 49 | } |
| 50 | |
| 51 | void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const |
| 52 | { |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 53 | // This is a porting of the LSTM::Eval() method in the Android code base |
| 54 | // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp |
| 55 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 56 | const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); |
| 57 | const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 58 | |
| 59 | const TensorShape& inputShape = inputInfo.GetShape(); |
| 60 | const DataType& outputType = outputInfo.GetDataType(); |
| 61 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 62 | std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map()); |
| 63 | std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map()); |
| 64 | std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 65 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 66 | std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map()); |
| 67 | std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 68 | |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 69 | std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map()); |
| 70 | std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map()); |
| 71 | std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 72 | |
| 73 | const uint32_t nBatch = inputShape[0]; |
| 74 | const uint32_t nInput = inputShape[1]; |
| 75 | |
| 76 | const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0]; |
| 77 | const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1]; |
| 78 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 79 | const bool useCifg = m_Data.m_Parameters.m_CifgEnabled; |
| 80 | const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled; |
| 81 | const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled; |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 82 | |
| 83 | // Index the scratch buffers pointers to the global scratch buffer. |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 84 | std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 85 | std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 86 | std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
| 87 | std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 88 | |
| 89 | std::unique_ptr<Decoder<float>> inputGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 90 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 91 | std::unique_ptr<Decoder<float>> cellScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 92 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 93 | std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 94 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 95 | std::unique_ptr<Decoder<float>> outputGateScratchDecoder = |
Finn Williams | b8181f7 | 2021-04-07 10:23:21 +0100 | [diff] [blame] | 96 | MakeDecoder<float>(outputInfo, outputs[0]->Map()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 97 | |
| 98 | if (useCifg) |
| 99 | { |
| 100 | *cellScratch += (0 * nCell * nBatch); |
| 101 | *forgetGateScratch += (1 * nCell * nBatch); |
| 102 | *outputGateScratch += (2 * nCell * nBatch); |
| 103 | |
| 104 | *cellScratchDecoder += (0 * nCell * nBatch); |
| 105 | *forgetGateScratchDecoder += (1 * nCell * nBatch); |
| 106 | *outputGateScratchDecoder += (2 * nCell * nBatch); |
| 107 | } |
| 108 | else |
| 109 | { |
| 110 | *inputGateScratch += (0 * nCell * nBatch); |
| 111 | *cellScratch += (1 * nCell * nBatch); |
| 112 | *forgetGateScratch += (2 * nCell * nBatch); |
| 113 | *outputGateScratch += (3 * nCell * nBatch); |
| 114 | |
| 115 | *inputGateScratchDecoder += (0 * nCell * nBatch); |
| 116 | *cellScratchDecoder += (1 * nCell * nBatch); |
| 117 | *forgetGateScratchDecoder += (2 * nCell * nBatch); |
| 118 | *outputGateScratchDecoder += (3 * nCell * nBatch); |
| 119 | } |
| 120 | |
| 121 | std::unique_ptr<Decoder<float>> inputToInputWeightsTensor; |
| 122 | std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 123 | m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 124 | std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 125 | m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 126 | std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 127 | m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 128 | |
| 129 | std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor; |
| 130 | std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 131 | m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 132 | std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 133 | m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 134 | std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 135 | m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 136 | |
| 137 | std::unique_ptr<Decoder<float>> inputGateBiasTensor; |
| 138 | std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 139 | m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 140 | std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 141 | m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 142 | std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 143 | m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 144 | |
| 145 | std::unique_ptr<Decoder<float>> cellToInputWeightsTensor; |
| 146 | std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor; |
| 147 | std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor; |
| 148 | |
| 149 | std::unique_ptr<Decoder<float>> projectionWeightsTensor; |
| 150 | std::unique_ptr<Decoder<float>> projectionBiasTensor; |
| 151 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 152 | std::unique_ptr<Decoder<float>> inputLayerNormWeights; |
| 153 | std::unique_ptr<Decoder<float>> forgetLayerNormWeights; |
| 154 | std::unique_ptr<Decoder<float>> cellLayerNormWeights; |
| 155 | std::unique_ptr<Decoder<float>> outputLayerNormWeights; |
| 156 | |
| 157 | if (useLayerNorm) |
| 158 | { |
| 159 | if (!useCifg) |
| 160 | { |
| 161 | inputLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 162 | m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 163 | } |
| 164 | forgetLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 165 | m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 166 | cellLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 167 | m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 168 | outputLayerNormWeights = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 169 | m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>()); |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 170 | } |
| 171 | |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 172 | if (!useCifg) |
| 173 | { |
| 174 | inputToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 175 | m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 176 | inputGateBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 177 | m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 178 | recurrentToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 179 | m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 180 | } |
| 181 | |
| 182 | if (usePeephole) |
| 183 | { |
| 184 | cellToForgetWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 185 | m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 186 | cellToOutputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 187 | m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 188 | } |
| 189 | |
| 190 | if (!useCifg && usePeephole) |
| 191 | { |
| 192 | cellToInputWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 193 | m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 194 | } |
| 195 | |
| 196 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 197 | { |
| 198 | projectionWeightsTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 199 | m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 200 | if (m_ProjectionBiasTensor) |
| 201 | { |
| 202 | projectionBiasTensor = MakeDecoder<float>( |
Finn Williams | 4422cec | 2021-03-22 17:51:06 +0000 | [diff] [blame] | 203 | m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>()); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 204 | } |
| 205 | } |
| 206 | |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 207 | if (!useLayerNorm) |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 208 | { |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 209 | // Initialize scratch buffers with bias. |
| 210 | if (!useCifg) |
| 211 | { |
| 212 | VectorBatchVectorAssign(*inputGateBiasTensor, |
| 213 | nCell, nBatch, *inputGateScratch); |
| 214 | } |
| 215 | VectorBatchVectorAssign(*forgetGateBiasTensor, |
| 216 | nCell, nBatch, *forgetGateScratch); |
| 217 | VectorBatchVectorAssign(*cellBiasTensor, |
| 218 | nCell, nBatch, *cellScratch); |
| 219 | VectorBatchVectorAssign(*outputGateBiasTensor, |
| 220 | nCell, nBatch, *outputGateScratch); |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 221 | } |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 222 | else |
| 223 | { |
| 224 | // Initialize scratch buffers with zeroes. |
| 225 | if (!useCifg) |
| 226 | { |
| 227 | ZeroVector(*inputGateScratch, nCell * nBatch); |
| 228 | } |
| 229 | ZeroVector(*forgetGateScratch, nCell * nBatch); |
| 230 | ZeroVector(*cellScratch , nCell * nBatch); |
| 231 | ZeroVector(*outputGateScratch, nCell * nBatch); |
| 232 | } |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 233 | |
| 234 | // For each batch and cell: compute input_weight * input. |
| 235 | if (!useCifg) |
| 236 | { |
| 237 | MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor, |
| 238 | nCell, nInput, *inputData, nBatch, *inputGateScratch); |
| 239 | } |
| 240 | MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor, |
| 241 | nCell, nInput, *inputData, nBatch, *forgetGateScratch); |
| 242 | MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor, |
| 243 | nCell, nInput, *inputData, nBatch, *cellScratch); |
| 244 | MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor, |
| 245 | nCell, nInput, *inputData, nBatch, *outputGateScratch); |
| 246 | |
| 247 | // For each batch and cell: compute recurrent_weight * output_state. |
| 248 | if (!useCifg) |
| 249 | { |
| 250 | MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor, |
| 251 | nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch); |
| 252 | } |
| 253 | MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor, |
| 254 | nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch); |
| 255 | MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor, |
| 256 | nCell, nOutput, *outputStateIn, nBatch, *cellScratch); |
| 257 | MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor, |
| 258 | nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch); |
| 259 | |
| 260 | // For each batch and cell: update input gate. |
| 261 | if (!useCifg) |
| 262 | { |
| 263 | if (usePeephole) |
| 264 | { |
| 265 | VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor, |
| 266 | nCell, *cellStateIn, nBatch, *inputGateScratch); |
| 267 | } |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 268 | if (useLayerNorm) |
| 269 | { |
| 270 | MeanStddevNormalization(*inputGateScratchDecoder, |
| 271 | *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon); |
| 272 | VectorBatchVectorCwiseProduct(*inputLayerNormWeights, |
| 273 | nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); |
| 274 | VectorBatchVectorAdd(*inputGateBiasTensor, |
| 275 | nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch); |
| 276 | } |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 277 | Activation(*inputGateScratchDecoder, *inputGateScratch, |
| 278 | TensorInfo({nCell, nBatch}, outputType), |
| 279 | ActivationFunction::Sigmoid, 0, 0); |
| 280 | } |
| 281 | |
| 282 | // For each batch and cell: update forget gate. |
| 283 | if (usePeephole) |
| 284 | { |
| 285 | VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell, |
| 286 | *cellStateIn, nBatch, *forgetGateScratch); |
| 287 | } |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 288 | if (useLayerNorm) |
| 289 | { |
| 290 | MeanStddevNormalization(*forgetGateScratchDecoder, |
| 291 | *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon); |
| 292 | VectorBatchVectorCwiseProduct(*forgetLayerNormWeights, |
| 293 | nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); |
| 294 | VectorBatchVectorAdd(*forgetGateBiasTensor, |
| 295 | nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch); |
| 296 | } |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 297 | Activation(*forgetGateScratchDecoder, *forgetGateScratch, |
| 298 | TensorInfo({nCell, nBatch}, outputType), |
| 299 | ActivationFunction::Sigmoid, 0, 0); |
| 300 | |
| 301 | // For each batch and cell: update the cell. |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 302 | if (useLayerNorm) |
| 303 | { |
| 304 | MeanStddevNormalization(*cellScratchDecoder, |
| 305 | *cellScratch, nCell, nBatch, m_LayerNormEpsilon); |
| 306 | VectorBatchVectorCwiseProduct(*cellLayerNormWeights, |
| 307 | nCell, *cellScratchDecoder, nBatch, *cellScratch); |
| 308 | VectorBatchVectorAdd(*cellBiasTensor, |
| 309 | nCell, *cellScratchDecoder, nBatch, *cellScratch); |
| 310 | } |
| 311 | |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 312 | VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut); |
| 313 | |
| 314 | ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid; |
| 315 | float a = 0; |
| 316 | float b = 0; |
| 317 | SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b); |
| 318 | |
| 319 | if (m_Data.m_Parameters.m_ActivationFunc > 0) |
| 320 | { |
| 321 | Activation(*cellScratchDecoder, *cellScratch, |
| 322 | TensorInfo({nCell, nBatch}, outputType), |
| 323 | armnnActivationFunc, a, b); |
| 324 | } |
| 325 | if (useCifg) |
| 326 | { |
| 327 | Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch); |
| 328 | VectorVectorCwiseProductAccumulate( |
| 329 | *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut); |
| 330 | } |
| 331 | else |
| 332 | { |
| 333 | VectorVectorCwiseProductAccumulate( |
| 334 | *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut); |
| 335 | } |
| 336 | if (m_Data.m_Parameters.m_ClippingThresCell > 0.0) |
| 337 | { |
| 338 | ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut); |
| 339 | } |
| 340 | |
| 341 | // For each batch and cell: update the output gate. |
| 342 | if (usePeephole) |
| 343 | { |
| 344 | VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor, |
| 345 | nCell, *cellStateOutDecoder, nBatch, *outputGateScratch); |
| 346 | } |
Jan Eilers | 38e05bd | 2019-06-26 13:10:09 +0100 | [diff] [blame] | 347 | if (useLayerNorm) |
| 348 | { |
| 349 | MeanStddevNormalization(*outputGateScratchDecoder, |
| 350 | *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon); |
| 351 | VectorBatchVectorCwiseProduct(*outputLayerNormWeights, |
| 352 | nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); |
| 353 | VectorBatchVectorAdd(*outputGateBiasTensor, |
| 354 | nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch); |
| 355 | } |
Nattapat Chaimanowong | eb2b329 | 2019-05-07 12:02:30 +0100 | [diff] [blame] | 356 | Activation(*outputGateScratchDecoder, *outputGateScratch, |
| 357 | TensorInfo({nCell, nBatch}, outputType), |
| 358 | ActivationFunction::Sigmoid, 0, 0); |
| 359 | |
| 360 | if (m_Data.m_Parameters.m_ActivationFunc > 0) |
| 361 | { |
| 362 | Activation(*cellStateOutDecoder, *cellScratch, |
| 363 | TensorInfo({nCell, nBatch}, outputType), |
| 364 | armnnActivationFunc, a, b); |
| 365 | } |
| 366 | |
| 367 | VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch); |
| 368 | |
| 369 | // For each batch: update the projection and output_state. |
| 370 | if (m_Data.m_Parameters.m_ProjectionEnabled) |
| 371 | { |
| 372 | if (m_ProjectionBiasTensor) |
| 373 | { |
| 374 | VectorBatchVectorAssign(*projectionBiasTensor, |
| 375 | nOutput, nBatch, *output); |
| 376 | } |
| 377 | MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor, |
| 378 | nOutput, nCell, *outputGateScratchDecoder, nBatch, *output); |
| 379 | |
| 380 | if (m_Data.m_Parameters.m_ClippingThresProj > 0.0) |
| 381 | { |
| 382 | ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output); |
| 383 | } |
| 384 | } |
| 385 | else |
| 386 | { |
| 387 | CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output); |
| 388 | } |
| 389 | |
| 390 | CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut); |
| 391 | } |
| 392 | |
| 393 | } //namespace armnn |