blob: 398faa907417e017d339fae85961baf9fbafa89d [file] [log] [blame]
James Conroy4f1f8992020-04-29 20:01:10 +01001//
Matthew Sloyan2d213a72022-06-30 17:13:04 +01002// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
James Conroy4f1f8992020-04-29 20:01:10 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "RefQLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "LstmUtils.hpp"
11#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
// Constructor: takes private copies of all (optional and mandatory) constant tensors from the
// queue descriptor via AssignScopedTensorHandle. Optional tensors that are absent yield null
// handles, which are null-checked before use in Execute() (e.g. m_ProjectionBiasTensor).
RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : RefBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
    // Input-to-gate weights; m_InputToInputWeights is optional (absent when CIFG is enabled)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))

    // Recurrent (output-state-to-gate) weights; the input-gate one is optional (CIFG)
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))

    // Peephole (cell-state-to-gate) weights; all optional, used only when peephole is enabled
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))

    // Gate biases; m_InputGateBias is optional (CIFG)
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))

    // Projection parameters; optional, used only when projection is enabled
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))

    // Layer normalisation weights; optional, used only when layer norm is enabled
    // (the input-gate one additionally requires CIFG to be disabled)
    , m_InputLayerNormWeightsTensor   (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeightsTensor    (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}
45
// Synchronous execution path: runs the cell on the tensors bound to this
// workload's queue descriptor.
void RefQLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}
50
Matthew Sloyan2d213a72022-06-30 17:13:04 +010051void RefQLstmWorkload::ExecuteAsync(ExecutionData& executionData)
Finn Williamsb8181f72021-04-07 10:23:21 +010052{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010053 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
54 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Finn Williamsb8181f72021-04-07 10:23:21 +010055}
56
// Core quantized-LSTM cell evaluation.
//
// inputs:  [0] input           (shape {numBatches, inputSize})
//          [1] outputStateIn   (previous hidden/output state, {numBatches, outputSize})
//          [2] cellStateIn     ({numBatches, numUnits})
// outputs: [0] outputStateOut  [1] cellStateOut  [2] output
//
// All arithmetic is performed in the float domain: quantized buffers are wrapped in
// Decoder (dequantize-on-read) and Encoder (quantize-on-write) objects. Because an
// Encoder/Decoder captures the TensorInfo's quantization scale at construction, the
// code below re-creates them every time it changes the scale of an intermediate buffer.
void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    // This is a porting of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs)
    // method in the Android code base
    // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
    // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
    // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
    const DataType& internalType = armnn::DataType::QSymmS16;

    const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInInfo = GetTensorInfo(inputs[2]);

    const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]);
    const TensorInfo& cellStateOutInfo = GetTensorInfo(outputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[2]);

    const TensorShape& inputShape = inputInfo.GetShape();
    const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
    const TensorShape& cellStateInShape = cellStateInInfo.GetShape();

    // Infer numBatches, inputSize, outputSize and numUnits
    const uint32_t numBatches = inputShape[0];
    const uint32_t inputSize = inputShape[1];
    const uint32_t outputSize = outputStateInShape[1];
    const uint32_t numUnits = cellStateInShape[1];

    // Optional param settings
    const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled;
    const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
    const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled;

    // Input decoders
    std::unique_ptr<Decoder<float>> inputDecoder =
            MakeDecoder<float>(inputInfo, inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateInDecoder =
            MakeDecoder<float>(outputStateInInfo, inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateInDecoder =
            MakeDecoder<float>(cellStateInInfo, inputs[2]->Map());

    // Output decoders (outputs are also read back, e.g. for clipping and the final copies)
    std::unique_ptr<Decoder<float>> outputStateOutDecoder =
            MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellStateOutDecoder =
            MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder =
            MakeDecoder<float>(outputInfo, outputs[2]->Map());

    // Output encoders
    std::unique_ptr<Encoder<float>> outputStateOutEncoder =
            MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellStateOutEncoder =
            MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Encoder<float>> outputEncoder =
            MakeEncoder<float>(outputInfo, outputs[2]->Map());

    // Weights decoders (mandatory forget/cell/output paths)
    std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
            m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
            m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
            m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToForgetWeightsTensor->GetTensorInfo(),
            m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToOutputWeightsTensor->GetTensorInfo(),
            m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    // Optional CIFG params (populated below only when CIFG is disabled)
    std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> inputGateBiasDecoder;

    // Optional Peephole params
    std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;

    // Optional Projection params
    std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
    std::unique_ptr<Decoder<float>> projectionBiasDecoder;

    // Optional Layer Norm params
    std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;

    // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
    std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
    std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
    std::unique_ptr<Decoder<float>> outputGateBiasDecoder;

    // Int16 vectors for internal state data (to be decoded/encoded)
    const uint32_t stateTensorSize = numBatches * numUnits;
    std::vector<int16_t> inputGateData(stateTensorSize);
    std::vector<int16_t> cellGateData(stateTensorSize);
    std::vector<int16_t> forgetGateData(stateTensorSize);
    std::vector<int16_t> outputGateData(stateTensorSize);
    std::vector<int32_t> hiddenStateData(stateTensorSize);
    std::vector<int16_t> outputInt16Data(numBatches * outputSize);

    // TensorInfos for the scratch gate buffers. Scales start at the user-supplied
    // intermediate scales and are re-set as the data moves through each pipeline stage.
    armnn::TensorInfo inputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
    armnn::TensorInfo cellGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
    armnn::TensorInfo forgetGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
    armnn::TensorInfo outputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
    armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
                                      armnn::DataType::QAsymmS8,
                                      m_Data.m_Parameters.m_HiddenStateScale,
                                      m_Data.m_Parameters.m_HiddenStateZeroPoint);
    armnn::TensorInfo outputInt16Info({numBatches , outputSize},
                                      armnn::DataType::QSymmS16,
                                      outputInfo.GetQuantizationScale(),
                                      outputInfo.GetQuantizationOffset());

    // Decoders/Encoders for internal states
    std::unique_ptr<Decoder<float>> inputGateDecoder =
            MakeDecoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Decoder<float>> cellGateDecoder =
            MakeDecoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Decoder<float>> forgetGateDecoder =
            MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Decoder<float>> outputGateDecoder =
            MakeDecoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Decoder<float>> hiddenStateDecoder =
            MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());

    std::unique_ptr<Encoder<float>> inputGateEncoder =
            MakeEncoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Encoder<float>> cellGateEncoder =
            MakeEncoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Encoder<float>> forgetGateEncoder =
            MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Encoder<float>> outputGateEncoder =
            MakeEncoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Encoder<float>> hiddenStateEncoder =
            MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());

    // Int16 used to accumulate output to prevent overflowing (after Projection MatMul)
    std::unique_ptr<Decoder<float>> outputInt16Decoder =
            MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
    std::unique_ptr<Encoder<float>> outputInt16Encoder =
            MakeEncoder<float>(outputInt16Info, outputInt16Data.data());

    // Create decoders for optional params if they are enabled
    if (!cifgEnabled)
    {
        inputToInputWeightsDecoder = MakeDecoder<float>(
                m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(),
                                                            m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (peepholeEnabled)
    {
        if (!cifgEnabled)
        {
            cellToInputWeightsDecoder = MakeDecoder<float>(
                    m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
        }
        cellToForgetWeightsDecoder = MakeDecoder<float>(
                m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsDecoder = MakeDecoder<float>(
                m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (projectionEnabled)
    {
        projectionWeightsDecoder = MakeDecoder<float>(
                m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasDecoder = MakeDecoder<float>(
                    m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    if (layerNormEnabled)
    {
        // NOTE(review): the gate-bias TensorInfos below are declared with length {outputSize},
        // while each gate bias holds numUnits elements; this is benign when
        // outputSize == numUnits — confirm whether numUnits was intended here.
        if (!cifgEnabled)
        {
            inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(),
                                                              m_InputLayerNormWeightsTensor->GetConstTensor<void>());

            // Bias only used if layer norm enabled
            armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                    m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
            inputGateBiasDecoder = MakeDecoder<float>(
                    inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>());
        }

        forgetLayerNormWeightsDecoder = MakeDecoder<float>(
                m_ForgetLayerNormWeightsTensor->GetTensorInfo(),
                m_ForgetLayerNormWeightsTensor->GetConstTensor<void>());
        cellLayerNormWeightsDecoder = MakeDecoder<float>(
                m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>());
        outputLayerNormWeightsDecoder = MakeDecoder<float>(
                m_OutputLayerNormWeightsTensor->GetTensorInfo(),
                m_OutputLayerNormWeightsTensor->GetConstTensor<void>());

        // Bias only used if layer norm enabled
        armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        forgetGateBiasDecoder = MakeDecoder<float>(
                forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        cellGateBiasDecoder = MakeDecoder<float>(
                cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        outputGateBiasDecoder = MakeDecoder<float>(
                outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>());
    }

    // Initialize internal state tensors with zeroes.
    if (!cifgEnabled)
    {
        ZeroVector(*inputGateEncoder, stateTensorSize);
    }
    ZeroVector(*forgetGateEncoder, stateTensorSize);
    ZeroVector(*cellGateEncoder, stateTensorSize);
    ZeroVector(*outputGateEncoder, stateTensorSize);
    ZeroVector(*hiddenStateEncoder, stateTensorSize);

    // Input weights * Input (accumulated into each gate buffer)
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
                                            numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);

    // Recurrent weights * OutputStateIn (accumulated on top of the input contributions)
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
                                            numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);

    // Input gate.
    if (!cifgEnabled)
    {
        if (peepholeEnabled)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
                                                    numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
        }

        if (layerNormEnabled)
        {
            // Quantize layer norm output to Input Scale * InputLayerNormWeights Scale * 1024.
            // The encoder/decoder must be re-made after each scale change (scale is captured
            // at construction time).
            inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                               m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                               1024);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            MeanStddevNormalization(*inputGateDecoder,
                                    *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
                                          numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            // Dequantize layer norm output to (1 / 4096)
            inputGateInfo.SetQuantizationScale(1.f / 4096);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorAdd(*inputGateBiasDecoder,
                                 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
        }

        inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
        inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

        // Input gate sigmoid
        Activation(*inputGateDecoder, *inputGateEncoder,
                   TensorInfo({numUnits, numBatches}, internalType),
                   ActivationFunction::Sigmoid, 0, 0);

        inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
    }

    // Forget gate
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
                                                *cellStateInDecoder, numBatches, *forgetGateEncoder);
    }

    if (layerNormEnabled)
    {
        // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
        forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        MeanStddevNormalization(*forgetGateDecoder,
                                *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
                                      numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        // Dequantize layer norm output to (1 / 4096)
        forgetGateInfo.SetQuantizationScale(1.f / 4096);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorAdd(*forgetGateBiasDecoder,
                             numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    }

    forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

    // Forget gate sigmoid
    Activation(*forgetGateDecoder, *forgetGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

    // Cell (Modulation) gate
    if (layerNormEnabled)
    {
        cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                          m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                          1024);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
                                      numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateInfo.SetQuantizationScale(1.f / 4096);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorAdd(*cellGateBiasDecoder,
                             numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
    }

    cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

    // Cell (Modulation) gate tanH
    Activation(*cellGateDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

    // cellStateOut = forgetGate . cellStateIn (element-wise)
    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);

    if (cifgEnabled)
    {
        // With CIFG the input gate is derived from the forget gate: inputGate = 1 - forgetGate.
        // The forget gate buffer is overwritten in place by Sub1Vector.
        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }
    else
    {
        // cellStateOut += cellGate . inputGate
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }

    // Final cell state out calculated here
    if (m_Data.m_Parameters.m_CellClip > 0.0)
    {
        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
    }

    // Output gate.
    if (peepholeEnabled)
    {
        // Peephole on the output gate uses the freshly computed cellStateOut.
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
                                                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
    }

    if (layerNormEnabled)
    {
        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
                                      numBatches, *outputGateEncoder);

        outputGateInfo.SetQuantizationScale(1.f / 4096);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
    }

    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

    // Output gate sigmoid
    Activation(*outputGateDecoder, *outputGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

    // Hidden state tanH.
    // Note: the cell gate buffer (cellGateEncoder/cellGateDecoder) is reused here as
    // scratch space holding tanh(cellStateOut); the cell gate values are no longer needed.
    Activation(*cellStateOutDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    // Final hidden state output: hiddenState = outputGate . tanh(cellStateOut)
    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);

    // Projection
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        // Seed the int16 accumulator with the projection bias (if present), then accumulate
        // the projection matmul on top. Int16 accumulation avoids overflow before the final
        // requantize into the output tensor.
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
        }

        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
                                            numBatches, *outputInt16Encoder);

        CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);

        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
        {
            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
        }
    }
    else
    {
        // Output has same quantization scale as hidden state if projection is disabled
        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
    }

    // output == outputStateOut
    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
}
546
547} //namespace armnn