//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefQLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

namespace armnn
{

// Constructor: snapshot all constant (weight/bias) tensors out of the queue
// descriptor into workload-owned handles, so the workload does not depend on
// the descriptor outliving it. Tensors belonging to a disabled optional
// feature (CIFG / peephole / projection / layer norm) may be null here;
// Execute() only dereferences them under the corresponding feature flag
// (and explicitly null-checks m_ProjectionBiasTensor).
RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
        : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
        // Mandatory (unless CIFG) input-to-gate weights
        , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
        , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
        , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
        , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))

        // Recurrent (output-state-to-gate) weights
        , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
        , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
        , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
        , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))

        // Peephole (cell-state-to-gate) weights, only used when peephole is enabled
        , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
        , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
        , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))

        // Gate biases (only read by Execute() when layer norm is enabled)
        , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
        , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
        , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
        , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))

        // Projection weights/bias, only used when projection is enabled
        , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
        , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))

        // Layer normalization weights, only used when layer norm is enabled
        , m_InputLayerNormWeightsTensor   (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
        , m_ForgetLayerNormWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
        , m_CellLayerNormWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
        , m_OutputLayerNormWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
{}
45
void RefQLstmWorkload::Execute() const
{
    // This is a porting of the QLSTM::Execute() method in the Android code base
    // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
    // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
    // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
    //
    // Tensor layout used here:
    //   Inputs:  [0] input, [1] outputStateIn (previous output), [2] cellStateIn (previous cell state)
    //   Outputs: [0] outputStateOut, [1] cellStateOut, [2] output
    const DataType& internalType = armnn::DataType::QSymmS16;

    const TensorInfo& inputInfo         = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorInfo& outputStateInInfo = GetTensorInfo(m_Data.m_Inputs[1]);
    const TensorInfo& cellStateInInfo   = GetTensorInfo(m_Data.m_Inputs[2]);

    const TensorInfo& outputStateOutInfo = GetTensorInfo(m_Data.m_Outputs[0]);
    const TensorInfo& cellStateOutInfo   = GetTensorInfo(m_Data.m_Outputs[1]);
    const TensorInfo& outputInfo         = GetTensorInfo(m_Data.m_Outputs[2]);

    const TensorShape& inputShape         = inputInfo.GetShape();
    const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
    const TensorShape& cellStateInShape   = cellStateInInfo.GetShape();

    // Infer numBatches, inputSize, outputSize and numUnits
    const uint32_t numBatches = inputShape[0];
    const uint32_t inputSize  = inputShape[1];
    const uint32_t outputSize = outputStateInShape[1];
    const uint32_t numUnits   = cellStateInShape[1];

    // Optional param settings
    const bool cifgEnabled       = m_Data.m_Parameters.m_CifgEnabled;
    const bool peepholeEnabled   = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
    const bool layerNormEnabled  = m_Data.m_Parameters.m_LayerNormEnabled;

    // Input decoders
    std::unique_ptr<Decoder<float>> inputDecoder =
            MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateInDecoder =
            MakeDecoder<float>(outputStateInInfo, m_Data.m_Inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateInDecoder =
            MakeDecoder<float>(cellStateInInfo, m_Data.m_Inputs[2]->Map());

    // Output decoders
    std::unique_ptr<Decoder<float>> outputStateOutDecoder =
            MakeDecoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellStateOutDecoder =
            MakeDecoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder =
            MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());

    // Output encoders (decoders and encoders over the same mapped buffers, so
    // a value written through an encoder is immediately visible to the decoder)
    std::unique_ptr<Encoder<float>> outputStateOutEncoder =
            MakeEncoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellStateOutEncoder =
            MakeEncoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
    std::unique_ptr<Encoder<float>> outputEncoder =
            MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());

    // Weights decoders
    std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
            m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
            m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
            m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());

    // Optional CIFG params (only populated below when CIFG is disabled)
    std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> inputGateBiasDecoder;

    // Optional Peephole params
    std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;

    // Optional Projection params
    std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
    std::unique_ptr<Decoder<float>> projectionBiasDecoder;

    // Optional Layer Norm params
    std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;

    // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
    std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
    std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
    std::unique_ptr<Decoder<float>> outputGateBiasDecoder;

    // Int16 vectors for internal state data (to be decoded/encoded)
    const uint32_t stateTensorSize = numBatches * numUnits;
    std::vector<int16_t> inputGateData(stateTensorSize);
    std::vector<int16_t> cellGateData(stateTensorSize);
    std::vector<int16_t> forgetGateData(stateTensorSize);
    std::vector<int16_t> outputGateData(stateTensorSize);
    // NOTE(review): hiddenStateData is int32_t-backed although hiddenStateInfo
    // below is QAsymmS8 (int8 access); storage is larger than strictly needed —
    // confirm this sizing is intentional.
    std::vector<int32_t> hiddenStateData(stateTensorSize);

    // Per-gate quantized TensorInfos. The quantization scale of each gate is
    // changed several times below (layer norm product scale -> 1/4096 bias-add
    // scale -> cellStateOut scale); every time the scale changes, the matching
    // Decoder/Encoder must be re-created because they capture the TensorInfo at
    // construction time.
    armnn::TensorInfo inputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
    armnn::TensorInfo cellGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
    armnn::TensorInfo forgetGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
    armnn::TensorInfo outputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
    armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
                                      armnn::DataType::QAsymmS8,
                                      m_Data.m_Parameters.m_HiddenStateScale,
                                      m_Data.m_Parameters.m_HiddenStateZeroPoint);

    // Decoders/Encoders for internal states
    std::unique_ptr<Decoder<float>> inputGateDecoder =
            MakeDecoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Decoder<float>> cellGateDecoder =
            MakeDecoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Decoder<float>> forgetGateDecoder =
            MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Decoder<float>> outputGateDecoder =
            MakeDecoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Decoder<float>> hiddenStateDecoder =
            MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());

    std::unique_ptr<Encoder<float>> inputGateEncoder =
            MakeEncoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Encoder<float>> cellGateEncoder =
            MakeEncoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Encoder<float>> forgetGateEncoder =
            MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Encoder<float>> outputGateEncoder =
            MakeEncoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Encoder<float>> hiddenStateEncoder =
            MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());

    // Create decoders for optional params if they are enabled
    if (!cifgEnabled)
    {
        inputToInputWeightsDecoder = MakeDecoder<float>(
                m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
        recurrentToInputWeightsDecoder = MakeDecoder<float>(
                m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
    }

    if (peepholeEnabled)
    {
        if (!cifgEnabled)
        {
            cellToInputWeightsDecoder = MakeDecoder<float>(
                    m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
        }
        cellToForgetWeightsDecoder = MakeDecoder<float>(
                m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
        cellToOutputWeightsDecoder = MakeDecoder<float>(
                m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
    }

    if (projectionEnabled)
    {
        projectionWeightsDecoder = MakeDecoder<float>(
                m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
        // Projection bias is optional even when projection is enabled
        if (m_ProjectionBiasTensor)
        {
            projectionBiasDecoder = MakeDecoder<float>(
                    m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
        }
    }

    if (layerNormEnabled)
    {
        if (!cifgEnabled)
        {
            inputLayerNormWeightsDecoder = MakeDecoder<float>(
                    m_InputLayerNormWeightsTensor->GetTensorInfo(), m_InputLayerNormWeightsTensor->GetTensor<void>());

            // Bias only used if layer norm enabled
            // NOTE(review): bias TensorInfo uses {outputSize}, but each gate has
            // numUnits elements; these only differ when projection is enabled —
            // confirm {numUnits} was not intended here (same applies to the
            // three bias infos below).
            armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                    m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
            inputGateBiasDecoder = MakeDecoder<float>(
                    inputGateBiasTensorInfo, m_InputGateBiasTensor->GetTensor<void>());
        }

        forgetLayerNormWeightsDecoder = MakeDecoder<float>(
                m_ForgetLayerNormWeightsTensor->GetTensorInfo(), m_ForgetLayerNormWeightsTensor->GetTensor<void>());
        cellLayerNormWeightsDecoder = MakeDecoder<float>(
                m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetTensor<void>());
        outputLayerNormWeightsDecoder = MakeDecoder<float>(
                m_OutputLayerNormWeightsTensor->GetTensorInfo(), m_OutputLayerNormWeightsTensor->GetTensor<void>());

        // Bias only used if layer norm enabled
        armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        forgetGateBiasDecoder = MakeDecoder<float>(
                forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetTensor<void>());

        armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        cellGateBiasDecoder = MakeDecoder<float>(
                cellGateBiasTensorInfo, m_CellBiasTensor->GetTensor<void>());

        armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        outputGateBiasDecoder = MakeDecoder<float>(
                outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetTensor<void>());
    }

    // Initialize internal state tensors with zeroes.
    if (!cifgEnabled)
    {
        ZeroVector(*inputGateEncoder, stateTensorSize);
    }
    ZeroVector(*forgetGateEncoder, stateTensorSize);
    ZeroVector(*cellGateEncoder, stateTensorSize);
    ZeroVector(*outputGateEncoder, stateTensorSize);
    ZeroVector(*hiddenStateEncoder, stateTensorSize);

    // Input weights * Input (accumulated into each gate's pre-activation)
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
                                            numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);

    // Recurrent weights * OutputStateIn (previous output, accumulated on top)
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
                                            numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);

    // Input gate.
    if (!cifgEnabled)
    {
        // Optional peephole contribution: cellToInputWeights (.) cellStateIn
        if (peepholeEnabled)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
                                                    numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
        }

        if (layerNormEnabled)
        {
            // Quantize layer norm output to Input Scale * InputLayerNormWeights Scale * 1024
            inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                               m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                               1024);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            MeanStddevNormalization(*inputGateDecoder,
                                    *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
                                          numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            // Dequantize layer norm output to (1 / 4096) before adding the bias
            inputGateInfo.SetQuantizationScale(1.f / 4096);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorAdd(*inputGateBiasDecoder,
                                 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
        }

        inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
        inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

        // Input gate sigmoid
        Activation(*inputGateDecoder, *inputGateEncoder,
                   TensorInfo({numUnits, numBatches}, internalType),
                   ActivationFunction::Sigmoid, 0, 0);

        inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
    }

    // Forget gate
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
                                                *cellStateInDecoder, numBatches, *forgetGateEncoder);
    }

    if (layerNormEnabled)
    {
        // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
        forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        MeanStddevNormalization(*forgetGateDecoder,
                                *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
                                      numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        // Dequantize layer norm output to (1 / 4096)
        forgetGateInfo.SetQuantizationScale(1.f / 4096);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorAdd(*forgetGateBiasDecoder,
                             numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    }

    forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

    // Forget gate sigmoid
    Activation(*forgetGateDecoder, *forgetGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

    // Cell (Modulation) gate
    if (layerNormEnabled)
    {
        cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                          m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                          1024);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
                                      numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateInfo.SetQuantizationScale(1.f / 4096);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorAdd(*cellGateBiasDecoder,
                             numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
    }

    cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

    // Cell (Modulation) gate tanH
    Activation(*cellGateDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

    // New cell state: cellStateOut = forgetGate (.) cellStateIn ...
    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);

    if (cifgEnabled)
    {
        // ... + (1 - forgetGate) (.) cellGate  (CIFG: input gate is coupled to forget gate)
        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }
    else
    {
        // ... + inputGate (.) cellGate
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }

    // Final cell state out calculated here
    if (m_Data.m_Parameters.m_CellClip > 0.0)
    {
        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
    }

    // Output gate.
    if (peepholeEnabled)
    {
        // Peephole here uses the NEW cell state (cellStateOut), unlike the other gates
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
                                                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
    }

    if (layerNormEnabled)
    {
        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
                                      numBatches, *outputGateEncoder);

        outputGateInfo.SetQuantizationScale(1.f / 4096);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
    }

    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

    // Output gate sigmoid
    Activation(*outputGateDecoder, *outputGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

    // Hidden state tanH. Note: the cellGate buffer (via cellGateEncoder /
    // cellGateDecoder) is reused here as scratch to hold tanh(cellStateOut);
    // the cell gate value itself is no longer needed at this point.
    Activation(*cellStateOutDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    // Final hidden state output: hiddenState = outputGate (.) tanh(cellStateOut)
    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);

    // Projection
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        // Seed the output with the projection bias (if any) so the matmul
        // below accumulates on top of it
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasDecoder,
                                    outputSize, numBatches, *outputEncoder);
        }

        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder,
                                            outputSize, numUnits, *hiddenStateDecoder, numBatches, *outputEncoder);

        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
        {
            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
        }
    }
    else
    {
        // Output has same quantization scale as hidden state if projection is disabled
        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
    }

    // output == outputStateOut
    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
}

} //namespace armnn