//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "LstmUtils.hpp"
11#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
// Copies each constant tensor from the queue descriptor into a workload-owned
// handle via AssignScopedCpuTensorHandle. Optional tensors (CIFG, peephole,
// projection and layer-norm weights) may be absent in the descriptor, in which
// case the corresponding member is left null — Execute() guards on the enabled
// flags (and on m_ProjectionBiasTensor directly) before dereferencing them.
RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
{}
40
// Evaluates one step of a float LSTM cell, optionally with CIFG (coupled
// input/forget gate), peephole connections, a projection layer and layer
// normalisation.
//
// This is a port of the LSTM::Eval() method in the Android code base.
// Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
void RefLstmWorkload::Execute() const
{
    // Tensor slots, as used by the Map() calls below:
    //   Inputs[0]  : input for this step, shape [nBatch, nInput]
    //   Inputs[1]  : output state from the previous step
    //   Inputs[2]  : cell state from the previous step
    //   Outputs[0] : scratch buffer holding the per-gate work areas
    //   Outputs[1] : output state for this step
    //   Outputs[2] : cell state for this step
    //   Outputs[3] : output for this step

    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);

    const TensorShape& inputShape = inputInfo.GetShape();
    const DataType& outputType = outputInfo.GetDataType();

    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map());
    std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output         = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());

    // Decoders over the same output buffers, used where a result written above
    // must be re-read later in the computation.
    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder       = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());

    std::unique_ptr<Decoder<float>> inputData     = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateIn   = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map());

    const uint32_t nBatch = inputShape[0];
    const uint32_t nInput = inputShape[1];

    // Sizes are derived from the mandatory weight tensors: input-to-output
    // weights are [nCell, nInput], recurrent-to-output weights are
    // [nCell, nOutput].
    const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];

    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

    // Index the scratch buffers pointers to the global scratch buffer.
    // All four encoder/decoder pairs initially alias Outputs[0]; the +=
    // offsets below advance each one to its own nCell*nBatch sub-buffer.
    std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellScratch       = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());

    if (useCifg)
    {
        // With CIFG there is no input gate, so the scratch buffer holds only
        // three sub-buffers: cell, forget gate, output gate.
        *cellScratch       += (0 * nCell * nBatch);
        *forgetGateScratch += (1 * nCell * nBatch);
        *outputGateScratch += (2 * nCell * nBatch);

        *cellScratchDecoder       += (0 * nCell * nBatch);
        *forgetGateScratchDecoder += (1 * nCell * nBatch);
        *outputGateScratchDecoder += (2 * nCell * nBatch);
    }
    else
    {
        // Four sub-buffers: input gate, cell, forget gate, output gate.
        *inputGateScratch  += (0 * nCell * nBatch);
        *cellScratch       += (1 * nCell * nBatch);
        *forgetGateScratch += (2 * nCell * nBatch);
        *outputGateScratch += (3 * nCell * nBatch);

        *inputGateScratchDecoder  += (0 * nCell * nBatch);
        *cellScratchDecoder       += (1 * nCell * nBatch);
        *forgetGateScratchDecoder += (2 * nCell * nBatch);
        *outputGateScratchDecoder += (3 * nCell * nBatch);
    }

    // Decoders over the constant weights/biases. Mandatory tensors are built
    // unconditionally; optional ones (declared null here) are only built when
    // the corresponding feature is enabled.
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

    if (useLayerNorm)
    {
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                    m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
                m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
                m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
                m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    // Peephole into the input gate only exists when the input gate itself
    // exists (i.e. CIFG disabled).
    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    if (!useLayerNorm)
    {
        // Initialize scratch buffers with bias.
        if (!useCifg)
        {
            VectorBatchVectorAssign(*inputGateBiasTensor,
                                    nCell, nBatch, *inputGateScratch);
        }
        VectorBatchVectorAssign(*forgetGateBiasTensor,
                                nCell, nBatch, *forgetGateScratch);
        VectorBatchVectorAssign(*cellBiasTensor,
                                nCell, nBatch, *cellScratch);
        VectorBatchVectorAssign(*outputGateBiasTensor,
                                nCell, nBatch, *outputGateScratch);
    }
    else
    {
        // Initialize scratch buffers with zeroes: with layer normalisation
        // the biases are added after normalisation, inside each gate block
        // below, so they must not be pre-loaded here.
        if (!useCifg)
        {
            ZeroVector(*inputGateScratch, nCell * nBatch);
        }
        ZeroVector(*forgetGateScratch, nCell * nBatch);
        ZeroVector(*cellScratch, nCell * nBatch);
        ZeroVector(*outputGateScratch, nCell * nBatch);
    }

    // For each batch and cell: compute input_weight * input.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
                                            nCell, nInput, *inputData, nBatch, *inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *cellScratch);
    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *outputGateScratch);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
                                            nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);

    // For each batch and cell: update input gate.
    // Gate order within each block: optional peephole accumulate, optional
    // layer norm (normalise, scale by gamma, add bias), then sigmoid.
    if (!useCifg)
    {
        if (usePeephole)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
                                                    nCell, *cellStateIn, nBatch, *inputGateScratch);
        }
        if (useLayerNorm)
        {
            MeanStddevNormalization(*inputGateScratchDecoder,
                                    *inputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
            VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
                                          nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
            VectorBatchVectorAdd(*inputGateBiasTensor,
                                 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
        }
        Activation(*inputGateScratchDecoder, *inputGateScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   ActivationFunction::Sigmoid, 0, 0);
    }

    // For each batch and cell: update forget gate.
    if (usePeephole)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
                                                *cellStateIn, nBatch, *forgetGateScratch);
    }
    if (useLayerNorm)
    {
        MeanStddevNormalization(*forgetGateScratchDecoder,
                                *forgetGateScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
                                      nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
        VectorBatchVectorAdd(*forgetGateBiasTensor,
                             nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
    }
    Activation(*forgetGateScratchDecoder, *forgetGateScratch,
               TensorInfo({nCell, nBatch}, outputType),
               ActivationFunction::Sigmoid, 0, 0);

    // For each batch and cell: update the cell.
    if (useLayerNorm)
    {
        MeanStddevNormalization(*cellScratchDecoder,
                                *cellScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
                                      nCell, *cellScratchDecoder, nBatch, *cellScratch);
        VectorBatchVectorAdd(*cellBiasTensor,
                             nCell, *cellScratchDecoder, nBatch, *cellScratch);
    }

    // cellStateOut = forgetGate .* previous cell state.
    VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);

    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
    float a = 0;
    float b = 0;
    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);

    // Apply the configured cell activation to the candidate values
    // (in place in cellScratch). A value of 0 means no activation.
    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        Activation(*cellScratchDecoder, *cellScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   armnnActivationFunc, a, b);
    }
    if (useCifg)
    {
        // CIFG couples the gates: input gate = 1 - forget gate
        // (stored into forgetGateScratch by Sub1Vector).
        Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
        VectorVectorCwiseProductAccumulate(
            *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
    }
    else
    {
        VectorVectorCwiseProductAccumulate(
            *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
    }
    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
    {
        ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut);
    }

    // For each batch and cell: update the output gate.
    if (usePeephole)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
                                                nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
    }
    if (useLayerNorm)
    {
        MeanStddevNormalization(*outputGateScratchDecoder,
                                *outputGateScratch, nCell, nBatch, m_LayerNormEpsilon);
        VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
                                      nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
        VectorBatchVectorAdd(*outputGateBiasTensor,
                             nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
    }
    Activation(*outputGateScratchDecoder, *outputGateScratch,
               TensorInfo({nCell, nBatch}, outputType),
               ActivationFunction::Sigmoid, 0, 0);

    // Re-use cellScratch to hold activation(new cell state) for the output
    // computation below.
    // NOTE(review): when m_ActivationFunc == 0 this step is skipped and
    // cellScratch still holds the pre-update candidate values, not the new
    // cell state — confirm ActivationFunc == 0 is not a supported
    // configuration for this backend.
    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        Activation(*cellStateOutDecoder, *cellScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   armnnActivationFunc, a, b);
    }

    // outputGateScratch = outputGate .* activation(cell state).
    VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);

    // For each batch: update the projection and output_state.
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        // output = projection_weights * gated output (+ optional bias),
        // then optionally clipped.
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasTensor,
                                    nOutput, nBatch, *output);
        }
        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
                                            nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);

        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
        {
            ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output);
        }
    }
    else
    {
        // No projection: the gated output is the output. (Copies
        // nBatch * nOutput elements — without projection nOutput is
        // presumably equal to nCell; confirm against the layer validation.)
        CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
    }

    // The output state for the next step is the output of this step.
    CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
}
382
383} //namespace armnn