//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

namespace armnn
{

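// All weight and bias tensors from the queue descriptor are captured in scoped tensor
// handles below. Handles for optional tensors (CIFG input-gate weights/bias, peephole
// weights, projection weights/bias and layer-normalisation weights) may be null; they
// are only dereferenced when the corresponding descriptor flag enables that feature.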
RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

void RefLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

void RefLstmWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
{
    Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
}

void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    // This is a port of the LSTM::Eval() method in the Android code base.
    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp

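    // Expected tensor layout (following the Android NN LSTM operation this code is
    // ported from):
    //   inputs[0]  = input           [nBatch, nInput]
    //   inputs[1]  = output_state_in [nBatch, nOutput]
    //   inputs[2]  = cell_state_in   [nBatch, nCell]
    //   outputs[0] = scratch buffer, outputs[1] = output_state_out,
    //   outputs[2] = cell_state_out, outputs[3] = output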
    const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);

    const TensorShape& inputShape = inputInfo.GetShape();
    const DataType& outputType = outputInfo.GetDataType();

    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
    std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(outputInfo, outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output         = MakeEncoder<float>(outputInfo, outputs[3]->Map());

    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder       = MakeDecoder<float>(outputInfo, outputs[3]->Map());

    std::unique_ptr<Decoder<float>> inputData     = MakeDecoder<float>(inputInfo, inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateIn   = MakeDecoder<float>(inputInfo, inputs[2]->Map());

    const uint32_t nBatch = inputShape[0];
    const uint32_t nInput = inputShape[1];

    const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];

    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

    // Index the scratch buffer pointers into the global scratch buffer.
    std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellScratch       = MakeEncoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
        MakeDecoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellScratchDecoder =
        MakeDecoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
        MakeDecoder<float>(outputInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
        MakeDecoder<float>(outputInfo, outputs[0]->Map());

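    // outputs[0] holds the gate scratch areas back to back, each nCell * nBatch
    // elements: [input gate | cell | forget gate | output gate]. With CIFG enabled the
    // input gate area is omitted, so only three areas are laid out. The offsets below
    // point each encoder/decoder pair at its own area.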
    if (useCifg)
    {
        *cellScratch       += (0 * nCell * nBatch);
        *forgetGateScratch += (1 * nCell * nBatch);
        *outputGateScratch += (2 * nCell * nBatch);

        *cellScratchDecoder       += (0 * nCell * nBatch);
        *forgetGateScratchDecoder += (1 * nCell * nBatch);
        *outputGateScratchDecoder += (2 * nCell * nBatch);
    }
    else
    {
        *inputGateScratch  += (0 * nCell * nBatch);
        *cellScratch       += (1 * nCell * nBatch);
        *forgetGateScratch += (2 * nCell * nBatch);
        *outputGateScratch += (3 * nCell * nBatch);

        *inputGateScratchDecoder  += (0 * nCell * nBatch);
        *cellScratchDecoder       += (1 * nCell * nBatch);
        *forgetGateScratchDecoder += (2 * nCell * nBatch);
        *outputGateScratchDecoder += (3 * nCell * nBatch);
    }

    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

    if (useLayerNorm)
    {
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

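    // With layer normalisation enabled the gate biases are added after the
    // normalisation step further below, so the scratch areas start from zero here;
    // otherwise they are pre-loaded with the gate biases.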
    if (!useLayerNorm)
    {
        // Initialize scratch buffers with bias.
        if (!useCifg)
        {
            VectorBatchVectorAssign(*inputGateBiasTensor,
                                    nCell, nBatch, *inputGateScratch);
        }
        VectorBatchVectorAssign(*forgetGateBiasTensor,
                                nCell, nBatch, *forgetGateScratch);
        VectorBatchVectorAssign(*cellBiasTensor,
                                nCell, nBatch, *cellScratch);
        VectorBatchVectorAssign(*outputGateBiasTensor,
                                nCell, nBatch, *outputGateScratch);
    }
    else
    {
        // Initialize scratch buffers with zeroes.
        if (!useCifg)
        {
            ZeroVector(*inputGateScratch, nCell * nBatch);
        }
        ZeroVector(*forgetGateScratch, nCell * nBatch);
        ZeroVector(*cellScratch, nCell * nBatch);
        ZeroVector(*outputGateScratch, nCell * nBatch);
    }

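    // The two blocks below accumulate the gate pre-activations:
    //   gate_scratch += InputToGateWeights     * x_t       (input contribution)
    //   gate_scratch += RecurrentToGateWeights * h_{t-1}   (recurrent contribution)
    // matching the standard LSTM gate equations act(W*x + R*h + b).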
    // For each batch and cell: compute input_weight * input.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
                                            nCell, nInput, *inputData, nBatch, *inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *cellScratch);
    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *outputGateScratch);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
                                            nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);

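    // Each gate update below optionally adds a peephole term
    // (CellToGateWeights (.) c_{t-1}), optionally applies layer normalisation
    // (in which case the gate bias is added here rather than during the scratch
    // initialisation above), and finally applies a sigmoid activation.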
    // For each batch and cell: update input gate.
    if (!useCifg)
    {
        if (usePeephole)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
                                                    nCell, *cellStateIn, nBatch, *inputGateScratch);
        }
        if (useLayerNorm)
        {
            MeanStddevNormalization(*inputGateScratchDecoder,
                                    *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
            VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
                                          nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
            VectorBatchVectorAdd(*inputGateBiasTensor,
                                 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
        }
        Activation(*inputGateScratchDecoder, *inputGateScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   ActivationFunction::Sigmoid, 0, 0);
    }

    // For each batch and cell: update forget gate.
    if (usePeephole)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
                                                *cellStateIn, nBatch, *forgetGateScratch);
    }
    if (useLayerNorm)
    {
        MeanStddevNormalization(*forgetGateScratchDecoder,
                                *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
                                      nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
        VectorBatchVectorAdd(*forgetGateBiasTensor,
                             nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
    }
    Activation(*forgetGateScratchDecoder, *forgetGateScratch,
               TensorInfo({nCell, nBatch}, outputType),
               ActivationFunction::Sigmoid, 0, 0);

    // For each batch and cell: update the cell.
    if (useLayerNorm)
    {
        MeanStddevNormalization(*cellScratchDecoder,
                                *cellScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
                                      nCell, *cellScratchDecoder, nBatch, *cellScratch);
        VectorBatchVectorAdd(*cellBiasTensor,
                             nCell, *cellScratchDecoder, nBatch, *cellScratch);
    }

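    // Cell state update: c_t = f_t (.) c_{t-1} + candidate (.) i_t, where the candidate
    // is cellScratch passed through the user-selected activation (m_ActivationFunc,
    // 0 meaning no activation). With CIFG the input gate is replaced by (1 - f_t).
    // The updated cell state is then optionally clipped.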
    VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);

    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
    float a = 0;
    float b = 0;
    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);

    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        Activation(*cellScratchDecoder, *cellScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   armnnActivationFunc, a, b);
    }
    if (useCifg)
    {
        Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
        VectorVectorCwiseProductAccumulate(
            *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
    }
    else
    {
        VectorVectorCwiseProductAccumulate(
            *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
    }
    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
    {
        ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut);
    }

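    // Output gate and output: o_t = sigmoid(pre-activation), then h_t = o_t (.) act(c_t).
    // If projection is enabled, h_t is additionally multiplied by the projection weights
    // (plus optional bias and clipping) before being written to the output and
    // output_state tensors.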
    // For each batch and cell: update the output gate.
    if (usePeephole)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
                                                nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
    }
    if (useLayerNorm)
    {
        MeanStddevNormalization(*outputGateScratchDecoder,
                                *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
                                      nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
        VectorBatchVectorAdd(*outputGateBiasTensor,
                             nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
    }
    Activation(*outputGateScratchDecoder, *outputGateScratch,
               TensorInfo({nCell, nBatch}, outputType),
               ActivationFunction::Sigmoid, 0, 0);

    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        Activation(*cellStateOutDecoder, *cellScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   armnnActivationFunc, a, b);
    }

    VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);

    // For each batch: update the projection and output_state.
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasTensor,
                                    nOutput, nBatch, *output);
        }
        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
                                            nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);

        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
        {
            ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output);
        }
    }
    else
    {
        CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
    }

    CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
}

} //namespace armnn