blob: e11ea55addeb01f1f0c157d1aa268e0fb934c742 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefQLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "LstmUtils.hpp"
11#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
16RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
17 : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
18 , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
19 , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
20 , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
21 , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
22
23 , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
24 , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
25 , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
26 , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
27
28 , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
29 , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
30 , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
31
32 , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
33 , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
34 , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
35 , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
36
37 , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
38 , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
39
40 , m_InputLayerNormWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputLayerNormWeights))
41 , m_ForgetLayerNormWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetLayerNormWeights))
42 , m_CellLayerNormWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellLayerNormWeights))
43 , m_OutputLayerNormWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputLayerNormWeights))
44{}
45
46void RefQLstmWorkload::Execute() const
47{
48 // This is a porting of the QLSTM::Execute() method in the Android code base
49 // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
50 // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
51 // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
52 const DataType& internalType = armnn::DataType::QSymmS16;
53
54 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
55 const TensorInfo& outputStateInInfo = GetTensorInfo(m_Data.m_Inputs[1]);
56 const TensorInfo& cellStateInInfo = GetTensorInfo(m_Data.m_Inputs[2]);
57
58 const TensorInfo& outputStateOutInfo = GetTensorInfo(m_Data.m_Outputs[0]);
59 const TensorInfo& cellStateOutInfo = GetTensorInfo(m_Data.m_Outputs[1]);
60 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[2]);
61
62 const TensorShape& inputShape = inputInfo.GetShape();
63 const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
64 const TensorShape& cellStateInShape = cellStateInInfo.GetShape();
65
66 // Infer numBatches, inputSize, outputSize and numUnits
67 const uint32_t numBatches = inputShape[0];
68 const uint32_t inputSize = inputShape[1];
69 const uint32_t outputSize = outputStateInShape[1];
70 const uint32_t numUnits = cellStateInShape[1];
71
72 // Optional param settings
73 const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled;
74 const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled;
75 const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
76 const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled;
77
78 // Input decoders
79 std::unique_ptr<Decoder<float>> inputDecoder =
80 MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
81 std::unique_ptr<Decoder<float>> outputStateInDecoder =
82 MakeDecoder<float>(outputStateInInfo, m_Data.m_Inputs[1]->Map());
83 std::unique_ptr<Decoder<float>> cellStateInDecoder =
84 MakeDecoder<float>(cellStateInInfo, m_Data.m_Inputs[2]->Map());
85
86 // Output decoders
87 std::unique_ptr<Decoder<float>> outputStateOutDecoder =
88 MakeDecoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
89 std::unique_ptr<Decoder<float>> cellStateOutDecoder =
90 MakeDecoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
91 std::unique_ptr<Decoder<float>> outputDecoder =
92 MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
93
94 // Output encoders
95 std::unique_ptr<Encoder<float>> outputStateOutEncoder =
96 MakeEncoder<float>(outputStateOutInfo, m_Data.m_Outputs[0]->Map());
97 std::unique_ptr<Encoder<float>> cellStateOutEncoder =
98 MakeEncoder<float>(cellStateOutInfo, m_Data.m_Outputs[1]->Map());
99 std::unique_ptr<Encoder<float>> outputEncoder =
100 MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
101
102 // Weights decoders
103 std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
104 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
105 std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
106 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
107 std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
108 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());
109
110 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
111 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
112 std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
113 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
114 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
115 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());
116
117 // Optional CIFG params
118 std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
119 std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
120 std::unique_ptr<Decoder<float>> inputGateBiasDecoder;
121
122 // Optional Peephole params
123 std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
124 std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
125 std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;
126
127 // Optional Projection params
128 std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
129 std::unique_ptr<Decoder<float>> projectionBiasDecoder;
130
131 // Optional Layer Norm params
132 std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
133 std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
134 std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
135 std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;
136
137 // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
138 std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
139 std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
140 std::unique_ptr<Decoder<float>> outputGateBiasDecoder;
141
142 // Int16 vectors for internal state data (to be decoded/encoded)
143 const uint32_t stateTensorSize = numBatches * numUnits;
144 std::vector<int16_t> inputGateData(stateTensorSize);
145 std::vector<int16_t> cellGateData(stateTensorSize);
146 std::vector<int16_t> forgetGateData(stateTensorSize);
147 std::vector<int16_t> outputGateData(stateTensorSize);
148 std::vector<int32_t> hiddenStateData(stateTensorSize);
James Conroyb22a75e2020-06-08 14:53:10 +0100149 std::vector<int16_t> outputInt16Data(numBatches * outputSize);
James Conroy4f1f8992020-04-29 20:01:10 +0100150
151 armnn::TensorInfo inputGateInfo(
152 {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
153 armnn::TensorInfo cellGateInfo(
154 {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
155 armnn::TensorInfo forgetGateInfo(
156 {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
157 armnn::TensorInfo outputGateInfo(
158 {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
159 armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
160 armnn::DataType::QAsymmS8,
161 m_Data.m_Parameters.m_HiddenStateScale,
162 m_Data.m_Parameters.m_HiddenStateZeroPoint);
James Conroyb22a75e2020-06-08 14:53:10 +0100163 armnn::TensorInfo outputInt16Info({numBatches , outputSize},
164 armnn::DataType::QSymmS16,
165 outputInfo.GetQuantizationScale(),
166 outputInfo.GetQuantizationOffset());
James Conroy4f1f8992020-04-29 20:01:10 +0100167
168 // Decoders/Encoders for internal states
169 std::unique_ptr<Decoder<float>> inputGateDecoder =
170 MakeDecoder<float>(inputGateInfo, inputGateData.data());
171 std::unique_ptr<Decoder<float>> cellGateDecoder =
172 MakeDecoder<float>(cellGateInfo, cellGateData.data());
173 std::unique_ptr<Decoder<float>> forgetGateDecoder =
174 MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
175 std::unique_ptr<Decoder<float>> outputGateDecoder =
176 MakeDecoder<float>(outputGateInfo, outputGateData.data());
177 std::unique_ptr<Decoder<float>> hiddenStateDecoder =
178 MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());
179
180 std::unique_ptr<Encoder<float>> inputGateEncoder =
181 MakeEncoder<float>(inputGateInfo, inputGateData.data());
182 std::unique_ptr<Encoder<float>> cellGateEncoder =
183 MakeEncoder<float>(cellGateInfo, cellGateData.data());
184 std::unique_ptr<Encoder<float>> forgetGateEncoder =
185 MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
186 std::unique_ptr<Encoder<float>> outputGateEncoder =
187 MakeEncoder<float>(outputGateInfo, outputGateData.data());
188 std::unique_ptr<Encoder<float>> hiddenStateEncoder =
189 MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());
190
James Conroyb22a75e2020-06-08 14:53:10 +0100191 // Int16 used to accumulate output to prevent overflowing (after Projection MatMul)
192 std::unique_ptr<Decoder<float>> outputInt16Decoder =
193 MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
194 std::unique_ptr<Encoder<float>> outputInt16Encoder =
195 MakeEncoder<float>(outputInt16Info, outputInt16Data.data());
196
James Conroy4f1f8992020-04-29 20:01:10 +0100197 // Create decoders for optional params if they are enabled
198 if (!cifgEnabled)
199 {
200 inputToInputWeightsDecoder = MakeDecoder<float>(
201 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
202 recurrentToInputWeightsDecoder = MakeDecoder<float>(
203 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
204 }
205
206 if (peepholeEnabled)
207 {
208 if (!cifgEnabled)
209 {
210 cellToInputWeightsDecoder = MakeDecoder<float>(
211 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
212 }
213 cellToForgetWeightsDecoder = MakeDecoder<float>(
214 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
215 cellToOutputWeightsDecoder = MakeDecoder<float>(
216 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
217 }
218
219 if (projectionEnabled)
220 {
221 projectionWeightsDecoder = MakeDecoder<float>(
222 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
223 if (m_ProjectionBiasTensor)
224 {
225 projectionBiasDecoder = MakeDecoder<float>(
226 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
227 }
228 }
229
230 if (layerNormEnabled)
231 {
232 if (!cifgEnabled)
233 {
234 inputLayerNormWeightsDecoder = MakeDecoder<float>(
235 m_InputLayerNormWeightsTensor->GetTensorInfo(), m_InputLayerNormWeightsTensor->GetTensor<void>());
236
237 // Bias only used if layer norm enabled
238 armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
239 m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
240 inputGateBiasDecoder = MakeDecoder<float>(
241 inputGateBiasTensorInfo, m_InputGateBiasTensor->GetTensor<void>());
242 }
243
244 forgetLayerNormWeightsDecoder = MakeDecoder<float>(
245 m_ForgetLayerNormWeightsTensor->GetTensorInfo(), m_ForgetLayerNormWeightsTensor->GetTensor<void>());
246 cellLayerNormWeightsDecoder = MakeDecoder<float>(
247 m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetTensor<void>());
248 outputLayerNormWeightsDecoder = MakeDecoder<float>(
249 m_OutputLayerNormWeightsTensor->GetTensorInfo(), m_OutputLayerNormWeightsTensor->GetTensor<void>());
250
251 // Bias only used if layer norm enabled
252 armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
253 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
254 forgetGateBiasDecoder = MakeDecoder<float>(
255 forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetTensor<void>());
256
257 armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
258 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
259 cellGateBiasDecoder = MakeDecoder<float>(
260 cellGateBiasTensorInfo, m_CellBiasTensor->GetTensor<void>());
261
262 armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
263 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
264 outputGateBiasDecoder = MakeDecoder<float>(
265 outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetTensor<void>());
266 }
267
268 // Initialize internal state tensors with zeroes.
269 if (!cifgEnabled)
270 {
271 ZeroVector(*inputGateEncoder, stateTensorSize);
272 }
273 ZeroVector(*forgetGateEncoder, stateTensorSize);
274 ZeroVector(*cellGateEncoder, stateTensorSize);
275 ZeroVector(*outputGateEncoder, stateTensorSize);
276 ZeroVector(*hiddenStateEncoder, stateTensorSize);
277
278 // Input weights * Input
279 if (!cifgEnabled)
280 {
281 MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
282 numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
283 }
284
285 MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
286 numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);
287
288 MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
289 numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);
290
291 MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
292 numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);
293
294 // Recurrent weights * OutputStateIn
295 if (!cifgEnabled)
296 {
297 MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
298 numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
299 }
300
301 MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
302 numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);
303
304 MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
305 numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);
306
307 MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
308 numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);
309
310 // Input gate.
311 if (!cifgEnabled)
312 {
313 if (peepholeEnabled)
314 {
315 VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
316 numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
317 }
318
319 if (layerNormEnabled)
320 {
321 inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
322 m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
323 1024);
324 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
325
326 MeanStddevNormalization(*inputGateDecoder,
327 *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
328
329 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
330
331 VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
332 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
333
334 inputGateInfo.SetQuantizationScale(1.f / 4096);
335 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
336
337 VectorBatchVectorAdd(*inputGateBiasDecoder,
338 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);
339
340 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
341 }
342
343 inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
344 inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());
345
346 // Input gate sigmoid
347 Activation(*inputGateDecoder, *inputGateEncoder,
348 TensorInfo({numUnits, numBatches}, internalType),
349 ActivationFunction::Sigmoid, 0, 0);
350
351 inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
352 }
353
354 // Forget gate
355 if (peepholeEnabled)
356 {
357 VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
358 *cellStateInDecoder, numBatches, *forgetGateEncoder);
359 }
360
361 if (layerNormEnabled)
362 {
363 // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
364 forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
365 m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
366 1024);
367 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
368
369
370
371 MeanStddevNormalization(*forgetGateDecoder,
372 *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
373
374
375 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
376
377 VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
378 numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
379
380
381 // Dequantize layer norm output to (1 / 4096)
382 forgetGateInfo.SetQuantizationScale(1.f / 4096);
383 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
384
385 VectorBatchVectorAdd(*forgetGateBiasDecoder,
386 numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);
387
388
389 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
390 }
391
392 forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
393 forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
394
395 // Forget gate sigmoid
396 Activation(*forgetGateDecoder, *forgetGateEncoder,
397 TensorInfo({numUnits, numBatches}, internalType),
398 ActivationFunction::Sigmoid, 0, 0);
399
400 forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
401
402 // Cell (Modulation) gate
403 if (layerNormEnabled)
404 {
405 cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
406 m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
407 1024);
408 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
409
410 MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
411
412 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
413
414 VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
415 numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
416
417 cellGateInfo.SetQuantizationScale(1.f / 4096);
418 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
419
420 VectorBatchVectorAdd(*cellGateBiasDecoder,
421 numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);
422
423 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
424 }
425
426 cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
427 cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());
428
429 // Cell (Modulation) gate tanH
430 Activation(*cellGateDecoder, *cellGateEncoder,
431 TensorInfo({numUnits, numBatches}, internalType),
432 ActivationFunction::TanH, 1.0f, 1.0f);
433
434 cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
435
436 VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);
437
438 if (cifgEnabled)
439 {
440 Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
441 VectorVectorCwiseProductAccumulate(
442 *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
443 }
444 else
445 {
446 VectorVectorCwiseProductAccumulate(
447 *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
448 }
449
450 // Final cell state out calculated here
451 if (m_Data.m_Parameters.m_CellClip > 0.0)
452 {
453 ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
454 }
455
456 // Output gate.
457 if (peepholeEnabled)
458 {
459 VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
460 numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
461 }
462
463 if (layerNormEnabled)
464 {
465 outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
466 m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
467 1024);
468 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
469
470 MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);
471
472 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
473
474 VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
475 numBatches, *outputGateEncoder);
476
477 outputGateInfo.SetQuantizationScale(1.f / 4096);
478 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
479
480 VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);
481
482 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
483 }
484
485 outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
486 outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());
487
488 // Output gate sigmoid
489 Activation(*outputGateDecoder, *outputGateEncoder,
490 TensorInfo({numUnits, numBatches}, internalType),
491 ActivationFunction::Sigmoid, 0, 0);
492
493 outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
494
495 // Hidden state tanH
496 Activation(*cellStateOutDecoder, *cellGateEncoder,
497 TensorInfo({numUnits, numBatches}, internalType),
498 ActivationFunction::TanH, 1.0f, 1.0f);
499
500 // Final hidden state output
501 VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);
502
503 // Projection
504 if (m_Data.m_Parameters.m_ProjectionEnabled)
505 {
506 if (m_ProjectionBiasTensor)
507 {
James Conroyb22a75e2020-06-08 14:53:10 +0100508 VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
James Conroy4f1f8992020-04-29 20:01:10 +0100509 }
510
James Conroyb22a75e2020-06-08 14:53:10 +0100511 MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
512 numBatches, *outputInt16Encoder);
513
514 CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);
James Conroy4f1f8992020-04-29 20:01:10 +0100515
516 if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
517 {
518 ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
519 }
520 }
521 else
522 {
523 // Output has same quantization scale as hidden state if projection is disabled
524 CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
525 }
526
527 // output == outputStateOut
528 CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
529}
530
531} //namespace armnn