blob: a5f939668bf370dfad3376aa414ac4c3a42c868d [file] [log] [blame]
James Conroy4f1f8992020-04-29 20:01:10 +01001//
Mike Kelly7cbe7812023-07-25 17:37:33 +01002// Copyright © 2020-2023 Arm Ltd and Contributors. All rights reserved.
James Conroy4f1f8992020-04-29 20:01:10 +01003// SPDX-License-Identifier: MIT
4//
5
6#include "RefQLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "LstmUtils.hpp"
11#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
// Constructs the reference QLSTM workload, caching every constant tensor from
// the queue descriptor into a workload-owned scoped tensor handle so the
// weights/biases remain valid for the lifetime of the workload.
// Optional tensors (CIFG, peephole, projection and layer-norm params) may be
// absent; their handles can then be null — see the null check on
// m_ProjectionBiasTensor in Execute().
RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : RefBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
    // Input-to-gate weights (InputToInput only used when CIFG is disabled)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))

    // Recurrent (output-state-to-gate) weights
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))

    // Peephole (cell-state-to-gate) weights, only used when peephole is enabled
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))

    // Gate biases
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))

    // Projection params, only used when projection is enabled (bias optional even then)
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))

    // Layer normalization weights, only used when layer norm is enabled
    , m_InputLayerNormWeightsTensor   (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeightsTensor    (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}
45
// Synchronous entry point: runs the cell on the tensor handles that were
// bound to this workload's queue descriptor at creation time.
void RefQLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}
50
Matthew Sloyan2d213a72022-06-30 17:13:04 +010051void RefQLstmWorkload::ExecuteAsync(ExecutionData& executionData)
Finn Williamsb8181f72021-04-07 10:23:21 +010052{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010053 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
54 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Finn Williamsb8181f72021-04-07 10:23:21 +010055}
56
// Reference (CPU) implementation of one time step of the quantized LSTM cell.
//
// inputs:  [0] input          — shape [numBatches, inputSize]  (indices read below)
//          [1] outputStateIn  — shape [.., outputSize]
//          [2] cellStateIn    — shape [.., numUnits]
// outputs: [0] outputStateOut
//          [1] cellStateOut
//          [2] output
//
// All arithmetic is performed in float via Decoder/Encoder wrappers that
// dequantize on read and requantize on write; the recurring pattern
// "info.SetQuantizationScale(...); encoder = MakeEncoder(...)" re-targets the
// scale an intermediate buffer is quantized to before the next op writes it,
// and the matching MakeDecoder re-read picks the new scale up. The statement
// order is therefore load-bearing throughout.
void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefQLstmWorkload_Execute");

    // This is a porting of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs)
    // method in the Android code base
    // Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
    // computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
    // Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
    const DataType& internalType = armnn::DataType::QSymmS16;

    const TensorInfo& inputInfo         = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInInfo   = GetTensorInfo(inputs[2]);

    const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]);
    const TensorInfo& cellStateOutInfo   = GetTensorInfo(outputs[1]);
    const TensorInfo& outputInfo         = GetTensorInfo(outputs[2]);

    const TensorShape& inputShape         = inputInfo.GetShape();
    const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
    const TensorShape& cellStateInShape   = cellStateInInfo.GetShape();

    // Infer numBatches, inputSize, outputSize and numUnits
    const uint32_t numBatches = inputShape[0];
    const uint32_t inputSize  = inputShape[1];
    const uint32_t outputSize = outputStateInShape[1];
    const uint32_t numUnits   = cellStateInShape[1];

    // Optional param settings
    const bool cifgEnabled       = m_Data.m_Parameters.m_CifgEnabled;
    const bool peepholeEnabled   = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
    const bool layerNormEnabled  = m_Data.m_Parameters.m_LayerNormEnabled;

    // Input decoders
    std::unique_ptr<Decoder<float>> inputDecoder =
            MakeDecoder<float>(inputInfo, inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateInDecoder =
            MakeDecoder<float>(outputStateInInfo, inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateInDecoder =
            MakeDecoder<float>(cellStateInInfo, inputs[2]->Map());

    // Output decoders (outputs are both written via encoders and re-read via
    // these decoders, e.g. for clipping and the final outputStateOut copy)
    std::unique_ptr<Decoder<float>> outputStateOutDecoder =
            MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellStateOutDecoder =
            MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder =
            MakeDecoder<float>(outputInfo, outputs[2]->Map());

    // Output encoders
    std::unique_ptr<Encoder<float>> outputStateOutEncoder =
            MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellStateOutEncoder =
            MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Encoder<float>> outputEncoder =
            MakeEncoder<float>(outputInfo, outputs[2]->Map());

    // Weights decoders (mandatory, non-CIFG-gated weights)
    std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
            m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
            m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
            m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToForgetWeightsTensor->GetTensorInfo(),
            m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToOutputWeightsTensor->GetTensorInfo(),
            m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    // Optional CIFG params (populated below only when CIFG is disabled)
    std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> inputGateBiasDecoder;

    // Optional Peephole params
    std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;

    // Optional Projection params
    std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
    std::unique_ptr<Decoder<float>> projectionBiasDecoder;

    // Optional Layer Norm params
    std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;

    // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
    std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
    std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
    std::unique_ptr<Decoder<float>> outputGateBiasDecoder;

    // Int16 vectors for internal state data (to be decoded/encoded)
    const uint32_t stateTensorSize = numBatches * numUnits;
    std::vector<int16_t> inputGateData(stateTensorSize);
    std::vector<int16_t> cellGateData(stateTensorSize);
    std::vector<int16_t> forgetGateData(stateTensorSize);
    std::vector<int16_t> outputGateData(stateTensorSize);
    // NOTE(review): element type is int32_t while hiddenStateInfo below is
    // QAsymmS8 — the encoder treats the buffer as int8 storage, so this
    // over-allocates; confirm this is intentional rather than a leftover.
    std::vector<int32_t> hiddenStateData(stateTensorSize);
    std::vector<int16_t> outputInt16Data(numBatches * outputSize);

    // Intermediate gate tensors are QSymmS16 at the per-gate intermediate
    // scales supplied in the descriptor; these infos are mutated (rescaled)
    // repeatedly during the gate computations below.
    armnn::TensorInfo inputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
    armnn::TensorInfo cellGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
    armnn::TensorInfo forgetGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
    armnn::TensorInfo outputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
    armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
                                      armnn::DataType::QAsymmS8,
                                      m_Data.m_Parameters.m_HiddenStateScale,
                                      m_Data.m_Parameters.m_HiddenStateZeroPoint);
    armnn::TensorInfo outputInt16Info({numBatches , outputSize},
                                      armnn::DataType::QSymmS16,
                                      outputInfo.GetQuantizationScale(),
                                      outputInfo.GetQuantizationOffset());

    // Decoders/Encoders for internal states
    std::unique_ptr<Decoder<float>> inputGateDecoder =
            MakeDecoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Decoder<float>> cellGateDecoder =
            MakeDecoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Decoder<float>> forgetGateDecoder =
            MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Decoder<float>> outputGateDecoder =
            MakeDecoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Decoder<float>> hiddenStateDecoder =
            MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());

    std::unique_ptr<Encoder<float>> inputGateEncoder =
            MakeEncoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Encoder<float>> cellGateEncoder =
            MakeEncoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Encoder<float>> forgetGateEncoder =
            MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Encoder<float>> outputGateEncoder =
            MakeEncoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Encoder<float>> hiddenStateEncoder =
            MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());

    // Int16 used to accumulate output to prevent overflowing (after Projection MatMul)
    std::unique_ptr<Decoder<float>> outputInt16Decoder =
            MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
    std::unique_ptr<Encoder<float>> outputInt16Encoder =
            MakeEncoder<float>(outputInt16Info, outputInt16Data.data());

    // Create decoders for optional params if they are enabled
    if (!cifgEnabled)
    {
        inputToInputWeightsDecoder = MakeDecoder<float>(
                m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(),
                                                            m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (peepholeEnabled)
    {
        if (!cifgEnabled)
        {
            cellToInputWeightsDecoder = MakeDecoder<float>(
                    m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
        }
        cellToForgetWeightsDecoder = MakeDecoder<float>(
                m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsDecoder = MakeDecoder<float>(
                m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (projectionEnabled)
    {
        projectionWeightsDecoder = MakeDecoder<float>(
                m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        // Projection bias is optional even when projection is enabled
        if (m_ProjectionBiasTensor)
        {
            projectionBiasDecoder = MakeDecoder<float>(
                    m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    if (layerNormEnabled)
    {
        if (!cifgEnabled)
        {
            inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(),
                                                              m_InputLayerNormWeightsTensor->GetConstTensor<void>());

            // Bias only used if layer norm enabled
            // NOTE(review): shape {outputSize} looks like it should be {numUnits}
            // for a gate bias; appears harmless here since the decoder is driven
            // for numUnits elements and only scale/offset matter — confirm.
            armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                    m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
            inputGateBiasDecoder = MakeDecoder<float>(
                    inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>());
        }

        forgetLayerNormWeightsDecoder = MakeDecoder<float>(
                m_ForgetLayerNormWeightsTensor->GetTensorInfo(),
                m_ForgetLayerNormWeightsTensor->GetConstTensor<void>());
        cellLayerNormWeightsDecoder = MakeDecoder<float>(
                m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>());
        outputLayerNormWeightsDecoder = MakeDecoder<float>(
                m_OutputLayerNormWeightsTensor->GetTensorInfo(),
                m_OutputLayerNormWeightsTensor->GetConstTensor<void>());

        // Bias only used if layer norm enabled
        armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        forgetGateBiasDecoder = MakeDecoder<float>(
                forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        cellGateBiasDecoder = MakeDecoder<float>(
                cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        outputGateBiasDecoder = MakeDecoder<float>(
                outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>());
    }

    // Initialize internal state tensors with zeroes.
    if (!cifgEnabled)
    {
        ZeroVector(*inputGateEncoder, stateTensorSize);
    }
    ZeroVector(*forgetGateEncoder, stateTensorSize);
    ZeroVector(*cellGateEncoder, stateTensorSize);
    ZeroVector(*outputGateEncoder, stateTensorSize);
    ZeroVector(*hiddenStateEncoder, stateTensorSize);

    // Input weights * Input (accumulated into each gate buffer)
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
                                            numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);

    // Recurrent weights * OutputStateIn (accumulated on top of the above)
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
                                            numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);

    // Input gate.
    if (!cifgEnabled)
    {
        if (peepholeEnabled)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
                                                    numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
        }

        if (layerNormEnabled)
        {
            // Re-quantize the layer-norm output to Input Scale * LayerNormWeights Scale * 1024,
            // then re-make the encoder so writes use the new scale.
            inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                               m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                               1024);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            MeanStddevNormalization(*inputGateDecoder,
                                    *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
                                          numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            // Dequantize to the fixed (1 / 4096) scale before adding the bias
            inputGateInfo.SetQuantizationScale(1.f / 4096);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorAdd(*inputGateBiasDecoder,
                                 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
        }

        inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
        inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

        // Input gate sigmoid
        // NOTE(review): the shape {numUnits, numBatches} is transposed relative
        // to the gate buffers; presumably only the element count matters to
        // Activation — confirm.
        Activation(*inputGateDecoder, *inputGateEncoder,
                   TensorInfo({numUnits, numBatches}, internalType),
                   ActivationFunction::Sigmoid, 0, 0);

        inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
    }

    // Forget gate
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
                                                *cellStateInDecoder, numBatches, *forgetGateEncoder);
    }

    if (layerNormEnabled)
    {
        // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
        forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        MeanStddevNormalization(*forgetGateDecoder,
                                *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
                                      numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        // Dequantize layer norm output to (1 / 4096)
        forgetGateInfo.SetQuantizationScale(1.f / 4096);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorAdd(*forgetGateBiasDecoder,
                             numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    }

    forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

    // Forget gate sigmoid
    Activation(*forgetGateDecoder, *forgetGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

    // Cell (Modulation) gate
    if (layerNormEnabled)
    {
        cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                          m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                          1024);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
                                      numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateInfo.SetQuantizationScale(1.f / 4096);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorAdd(*cellGateBiasDecoder,
                             numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
    }

    cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

    // Cell (Modulation) gate tanH
    Activation(*cellGateDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

    // cellStateOut = forgetGate (.) cellStateIn ...
    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);

    if (cifgEnabled)
    {
        // ... + cellGate (.) (1 - forgetGate): CIFG derives the input gate
        // from the forget gate (forgetGateData is overwritten in place here).
        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }
    else
    {
        // ... + cellGate (.) inputGate
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }

    // Final cell state out calculated here
    if (m_Data.m_Parameters.m_CellClip > 0.0)
    {
        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
    }

    // Output gate.
    if (peepholeEnabled)
    {
        // Output-gate peephole reads the NEW cell state, unlike the other gates
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
                                                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
    }

    if (layerNormEnabled)
    {
        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
                                      numBatches, *outputGateEncoder);

        outputGateInfo.SetQuantizationScale(1.f / 4096);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
    }

    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

    // Output gate sigmoid
    Activation(*outputGateDecoder, *outputGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

    // Hidden state tanH — note the result is written back into the (now spent)
    // cellGate buffer, which is then re-read via cellGateDecoder below
    Activation(*cellStateOutDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    // Final hidden state output: outputGate (.) tanH(cellStateOut)
    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);

    // Projection
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        // Pre-load the int16 accumulator with the projection bias (if any)
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
        }

        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
                                            numBatches, *outputInt16Encoder);

        CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);

        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
        {
            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
        }
    }
    else
    {
        // Output has same quantization scale as hidden state if projection is disabled
        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
    }

    // output == outputStateOut
    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
}
548
549} //namespace armnn