blob: 74f5f1ef4c4889d6dc507cc594e66082907cc177 [file] [log] [blame]
James Conroy4f1f8992020-04-29 20:01:10 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefQLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "LstmUtils.hpp"
11#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
// Reference QLSTM workload constructor.
// Takes private scoped copies of every constant weight/bias tensor supplied in the
// queue descriptor, so the workload owns its parameters for its whole lifetime.
// Optional tensors (CIFG input-gate set, peephole, projection, layer-norm weights)
// may be absent from the descriptor; the resulting handles are then null and are
// only dereferenced by Execute() when the corresponding feature flag in
// m_Parameters is enabled (m_ProjectionBiasTensor is additionally null-checked).
RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : RefBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))

    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))

    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))

    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))

    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))

    , m_InputLayerNormWeightsTensor   (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeightsTensor    (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeightsTensor  (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}
45
46void RefQLstmWorkload::Execute() const
47{
Finn Williamsb8181f72021-04-07 10:23:21 +010048 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
49}
50
51void RefQLstmWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
52{
53 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
54}
55
// Core QLSTM cell computation.
// inputs[0..2]  = input, outputStateIn, cellStateIn
// outputs[0..2] = outputStateOut, cellStateOut, output
//
// This is a porting of the QLSTM::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs)
// method in the Android code base
// Note: this implementation wraps the arithmetic functions of the LSTM cell in Quantize/Dequantize ops, so all
// computation is done in the floating point domain. Arithmetic functions are found in LstmUtils.cpp.
// Refer to: android/frameworks/ml/nn/common/operations/QLSTM.cpp
//
// Throughout this function the gate scratch buffers keep their raw integer data,
// while fresh Encoders/Decoders are re-made after every SetQuantizationScale()
// call so that subsequent reads/writes reinterpret the same data at the new
// scale. The statement order is therefore significant.
void RefQLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    const DataType& internalType = armnn::DataType::QSymmS16;

    const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInInfo = GetTensorInfo(inputs[2]);

    const TensorInfo& outputStateOutInfo = GetTensorInfo(outputs[0]);
    const TensorInfo& cellStateOutInfo = GetTensorInfo(outputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[2]);

    const TensorShape& inputShape = inputInfo.GetShape();
    const TensorShape& outputStateInShape = outputStateInInfo.GetShape();
    const TensorShape& cellStateInShape = cellStateInInfo.GetShape();

    // Infer numBatches, inputSize, outputSize and numUnits
    const uint32_t numBatches = inputShape[0];
    const uint32_t inputSize = inputShape[1];
    const uint32_t outputSize = outputStateInShape[1];
    const uint32_t numUnits = cellStateInShape[1];

    // Optional param settings
    const bool cifgEnabled = m_Data.m_Parameters.m_CifgEnabled;
    const bool peepholeEnabled = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool projectionEnabled = m_Data.m_Parameters.m_ProjectionEnabled;
    const bool layerNormEnabled = m_Data.m_Parameters.m_LayerNormEnabled;

    // Input decoders
    std::unique_ptr<Decoder<float>> inputDecoder =
            MakeDecoder<float>(inputInfo, inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateInDecoder =
            MakeDecoder<float>(outputStateInInfo, inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateInDecoder =
            MakeDecoder<float>(cellStateInInfo, inputs[2]->Map());

    // Output decoders
    std::unique_ptr<Decoder<float>> outputStateOutDecoder =
            MakeDecoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellStateOutDecoder =
            MakeDecoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder =
            MakeDecoder<float>(outputInfo, outputs[2]->Map());

    // Output encoders
    std::unique_ptr<Encoder<float>> outputStateOutEncoder =
            MakeEncoder<float>(outputStateOutInfo, outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellStateOutEncoder =
            MakeEncoder<float>(cellStateOutInfo, outputs[1]->Map());
    std::unique_ptr<Encoder<float>> outputEncoder =
            MakeEncoder<float>(outputInfo, outputs[2]->Map());

    // Weights decoders
    std::unique_ptr<Decoder<float>> inputToForgetWeightsDecoder = MakeDecoder<float>(
            m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsDecoder = MakeDecoder<float>(
            m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsDecoder = MakeDecoder<float>(
            m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToForgetWeightsTensor->GetTensorInfo(),
            m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsDecoder = MakeDecoder<float>(
            m_RecurrentToOutputWeightsTensor->GetTensorInfo(),
            m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    // Optional CIFG params (only used when CIFG is disabled, i.e. a real input gate exists)
    std::unique_ptr<Decoder<float>> inputToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> recurrentToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> inputGateBiasDecoder;

    // Optional Peephole params
    std::unique_ptr<Decoder<float>> cellToInputWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsDecoder;

    // Optional Projection params
    std::unique_ptr<Decoder<float>> projectionWeightsDecoder;
    std::unique_ptr<Decoder<float>> projectionBiasDecoder;

    // Optional Layer Norm params
    std::unique_ptr<Decoder<float>> inputLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> cellLayerNormWeightsDecoder;
    std::unique_ptr<Decoder<float>> outputLayerNormWeightsDecoder;

    // Biases are only used when Layer Norm is enabled. Scale is defined as (XLayerNormWeights Scale / 1024)
    std::unique_ptr<Decoder<float>> forgetGateBiasDecoder;
    std::unique_ptr<Decoder<float>> cellGateBiasDecoder;
    std::unique_ptr<Decoder<float>> outputGateBiasDecoder;

    // Int16 vectors for internal state data (to be decoded/encoded)
    const uint32_t stateTensorSize = numBatches * numUnits;
    std::vector<int16_t> inputGateData(stateTensorSize);
    std::vector<int16_t> cellGateData(stateTensorSize);
    std::vector<int16_t> forgetGateData(stateTensorSize);
    std::vector<int16_t> outputGateData(stateTensorSize);
    std::vector<int32_t> hiddenStateData(stateTensorSize);
    std::vector<int16_t> outputInt16Data(numBatches * outputSize);

    // Intermediate TensorInfos for the gate scratch buffers; the quantization scales
    // below are mutated as the computation progresses (see re-quantization steps later).
    armnn::TensorInfo inputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_InputIntermediateScale, 0);
    armnn::TensorInfo cellGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_CellIntermediateScale, 0);
    armnn::TensorInfo forgetGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_ForgetIntermediateScale, 0);
    armnn::TensorInfo outputGateInfo(
            {numBatches , numUnits}, armnn::DataType::QSymmS16, m_Data.m_Parameters.m_OutputIntermediateScale, 0);
    armnn::TensorInfo hiddenStateInfo({numBatches, numUnits},
                                      armnn::DataType::QAsymmS8,
                                      m_Data.m_Parameters.m_HiddenStateScale,
                                      m_Data.m_Parameters.m_HiddenStateZeroPoint);
    armnn::TensorInfo outputInt16Info({numBatches , outputSize},
                                      armnn::DataType::QSymmS16,
                                      outputInfo.GetQuantizationScale(),
                                      outputInfo.GetQuantizationOffset());

    // Decoders/Encoders for internal states
    std::unique_ptr<Decoder<float>> inputGateDecoder =
            MakeDecoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Decoder<float>> cellGateDecoder =
            MakeDecoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Decoder<float>> forgetGateDecoder =
            MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Decoder<float>> outputGateDecoder =
            MakeDecoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Decoder<float>> hiddenStateDecoder =
            MakeDecoder<float>(hiddenStateInfo, hiddenStateData.data());

    std::unique_ptr<Encoder<float>> inputGateEncoder =
            MakeEncoder<float>(inputGateInfo, inputGateData.data());
    std::unique_ptr<Encoder<float>> cellGateEncoder =
            MakeEncoder<float>(cellGateInfo, cellGateData.data());
    std::unique_ptr<Encoder<float>> forgetGateEncoder =
            MakeEncoder<float>(forgetGateInfo, forgetGateData.data());
    std::unique_ptr<Encoder<float>> outputGateEncoder =
            MakeEncoder<float>(outputGateInfo, outputGateData.data());
    std::unique_ptr<Encoder<float>> hiddenStateEncoder =
            MakeEncoder<float>(hiddenStateInfo, hiddenStateData.data());

    // Int16 used to accumulate output to prevent overflowing (after Projection MatMul)
    std::unique_ptr<Decoder<float>> outputInt16Decoder =
            MakeDecoder<float>(outputInt16Info, outputInt16Data.data());
    std::unique_ptr<Encoder<float>> outputInt16Encoder =
            MakeEncoder<float>(outputInt16Info, outputInt16Data.data());

    // Create decoders for optional params if they are enabled
    if (!cifgEnabled)
    {
        inputToInputWeightsDecoder = MakeDecoder<float>(
                m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        recurrentToInputWeightsDecoder = MakeDecoder<float>(m_RecurrentToInputWeightsTensor->GetTensorInfo(),
                                                            m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (peepholeEnabled)
    {
        if (!cifgEnabled)
        {
            cellToInputWeightsDecoder = MakeDecoder<float>(
                    m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
        }
        cellToForgetWeightsDecoder = MakeDecoder<float>(
                m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsDecoder = MakeDecoder<float>(
                m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (projectionEnabled)
    {
        projectionWeightsDecoder = MakeDecoder<float>(
                m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasDecoder = MakeDecoder<float>(
                    m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    if (layerNormEnabled)
    {
        if (!cifgEnabled)
        {
            inputLayerNormWeightsDecoder = MakeDecoder<float>(m_InputLayerNormWeightsTensor->GetTensorInfo(),
                                                              m_InputLayerNormWeightsTensor->GetConstTensor<void>());

            // Bias only used if layer norm enabled
            // NOTE(review): the bias TensorInfo shape uses outputSize while the gate
            // vectors are numUnits long; VectorBatchVectorAdd below iterates numUnits
            // elements — confirm this is intentional for outputSize != numUnits.
            armnn::TensorInfo inputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                    m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
            inputGateBiasDecoder = MakeDecoder<float>(
                    inputGateBiasTensorInfo, m_InputGateBiasTensor->GetConstTensor<void>());
        }

        forgetLayerNormWeightsDecoder = MakeDecoder<float>(
                m_ForgetLayerNormWeightsTensor->GetTensorInfo(),
                m_ForgetLayerNormWeightsTensor->GetConstTensor<void>());
        cellLayerNormWeightsDecoder = MakeDecoder<float>(
                m_CellLayerNormWeightsTensor->GetTensorInfo(), m_CellLayerNormWeightsTensor->GetConstTensor<void>());
        outputLayerNormWeightsDecoder = MakeDecoder<float>(
                m_OutputLayerNormWeightsTensor->GetTensorInfo(),
                m_OutputLayerNormWeightsTensor->GetConstTensor<void>());

        // Bias only used if layer norm enabled
        armnn::TensorInfo forgetGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        forgetGateBiasDecoder = MakeDecoder<float>(
                forgetGateBiasTensorInfo, m_ForgetGateBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo cellGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        cellGateBiasDecoder = MakeDecoder<float>(
                cellGateBiasTensorInfo, m_CellBiasTensor->GetConstTensor<void>());

        armnn::TensorInfo outputGateBiasTensorInfo({outputSize}, armnn::DataType::Signed32,
                m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() / 1024, 0);
        outputGateBiasDecoder = MakeDecoder<float>(
                outputGateBiasTensorInfo, m_OutputGateBiasTensor->GetConstTensor<void>());
    }

    // Initialize internal state tensors with zeroes.
    if (!cifgEnabled)
    {
        ZeroVector(*inputGateEncoder, stateTensorSize);
    }
    ZeroVector(*forgetGateEncoder, stateTensorSize);
    ZeroVector(*cellGateEncoder, stateTensorSize);
    ZeroVector(*outputGateEncoder, stateTensorSize);
    ZeroVector(*hiddenStateEncoder, stateTensorSize);

    // Input weights * Input
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsDecoder,
                                            numUnits, inputSize, *inputDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsDecoder,
                                        numUnits, inputSize, *inputDecoder, numBatches, *outputGateEncoder);

    // Recurrent weights * OutputStateIn
    if (!cifgEnabled)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsDecoder,
                                            numUnits, outputSize, *outputStateInDecoder, numBatches, *inputGateEncoder);
    }

    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *forgetGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *cellGateEncoder);

    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsDecoder,
                                        numUnits, outputSize, *outputStateInDecoder, numBatches, *outputGateEncoder);

    // Input gate.
    if (!cifgEnabled)
    {
        if (peepholeEnabled)
        {
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsDecoder,
                                                    numUnits, *cellStateInDecoder, numBatches, *inputGateEncoder);
        }

        if (layerNormEnabled)
        {
            // Quantize layer norm output to Input Scale * m_InputLayerNormWeightsTensor * 1024
            inputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                               m_InputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                               1024);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            MeanStddevNormalization(*inputGateDecoder,
                                    *inputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorCwiseProduct(*inputLayerNormWeightsDecoder,
                                          numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            // Dequantize layer norm output to (1 / 4096)
            inputGateInfo.SetQuantizationScale(1.f / 4096);
            inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

            VectorBatchVectorAdd(*inputGateBiasDecoder,
                                 numUnits, *inputGateDecoder, numBatches, *inputGateEncoder);

            inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
        }

        // Re-quantize the gate to the cell-state-out scale before the activation
        inputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
        inputGateEncoder = MakeEncoder<float>(inputGateInfo, inputGateData.data());

        // Input gate sigmoid
        Activation(*inputGateDecoder, *inputGateEncoder,
                   TensorInfo({numUnits, numBatches}, internalType),
                   ActivationFunction::Sigmoid, 0, 0);

        inputGateDecoder = MakeDecoder<float>(inputGateInfo, inputGateData.data());
    }

    // Forget gate
    if (peepholeEnabled)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsDecoder, numUnits,
                                                *cellStateInDecoder, numBatches, *forgetGateEncoder);
    }

    if (layerNormEnabled)
    {
        // Quantize layer norm output to Input Scale * m_ForgetLayerNormWeightsTensor * 1024
        forgetGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_ForgetLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        MeanStddevNormalization(*forgetGateDecoder,
                                *forgetGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorCwiseProduct(*forgetLayerNormWeightsDecoder,
                                      numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        // Dequantize layer norm output to (1 / 4096)
        forgetGateInfo.SetQuantizationScale(1.f / 4096);
        forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

        VectorBatchVectorAdd(*forgetGateBiasDecoder,
                             numUnits, *forgetGateDecoder, numBatches, *forgetGateEncoder);

        forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());
    }

    // Re-quantize the gate to the cell-state-out scale before the activation
    forgetGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    forgetGateEncoder = MakeEncoder<float>(forgetGateInfo, forgetGateData.data());

    // Forget gate sigmoid
    Activation(*forgetGateDecoder, *forgetGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    forgetGateDecoder = MakeDecoder<float>(forgetGateInfo, forgetGateData.data());

    // Cell (Modulation) gate
    if (layerNormEnabled)
    {
        cellGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                          m_CellLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                          1024);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        MeanStddevNormalization(*cellGateDecoder, *cellGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorCwiseProduct(*cellLayerNormWeightsDecoder,
                                      numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateInfo.SetQuantizationScale(1.f / 4096);
        cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

        VectorBatchVectorAdd(*cellGateBiasDecoder,
                             numUnits, *cellGateDecoder, numBatches, *cellGateEncoder);

        cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());
    }

    cellGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    cellGateEncoder = MakeEncoder<float>(cellGateInfo, cellGateData.data());

    // Cell (Modulation) gate tanH
    Activation(*cellGateDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    cellGateDecoder = MakeDecoder<float>(cellGateInfo, cellGateData.data());

    // New cell state = forgetGate (.) cellStateIn, accumulated with the gated cell update below
    VectorVectorCwiseProduct(*forgetGateDecoder, *cellStateInDecoder, stateTensorSize, *cellStateOutEncoder);

    if (cifgEnabled)
    {
        // CIFG: the input gate is coupled to the forget gate as (1 - forgetGate);
        // Sub1Vector overwrites the forget gate buffer in place.
        Sub1Vector(*forgetGateDecoder, stateTensorSize, *forgetGateEncoder);
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *forgetGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }
    else
    {
        VectorVectorCwiseProductAccumulate(
                *cellGateDecoder, *inputGateDecoder, stateTensorSize, *cellStateOutEncoder);
    }

    // Final cell state out calculated here
    if (m_Data.m_Parameters.m_CellClip > 0.0)
    {
        ClipVector(*cellStateOutDecoder, stateTensorSize, m_Data.m_Parameters.m_CellClip, *cellStateOutEncoder);
    }

    // Output gate.
    if (peepholeEnabled)
    {
        // Peephole on the output gate uses the NEW cell state (cellStateOut), unlike
        // the input/forget gates which peek at cellStateIn.
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsDecoder,
                                                numUnits, *cellStateOutDecoder, numBatches, *outputGateEncoder);
    }

    if (layerNormEnabled)
    {
        outputGateInfo.SetQuantizationScale(inputInfo.GetQuantizationScale() *
                                            m_OutputLayerNormWeightsTensor->GetTensorInfo().GetQuantizationScale() *
                                            1024);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        MeanStddevNormalization(*outputGateDecoder, *outputGateEncoder, numUnits, numBatches, m_LayerNormEpsilon);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorCwiseProduct(*outputLayerNormWeightsDecoder, numUnits, *outputGateDecoder,
                                      numBatches, *outputGateEncoder);

        outputGateInfo.SetQuantizationScale(1.f / 4096);
        outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

        VectorBatchVectorAdd(*outputGateBiasDecoder, numUnits, *outputGateDecoder, numBatches, *outputGateEncoder);

        outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());
    }

    outputGateInfo.SetQuantizationScale(cellStateOutInfo.GetQuantizationScale());
    outputGateEncoder = MakeEncoder<float>(outputGateInfo, outputGateData.data());

    // Output gate sigmoid
    Activation(*outputGateDecoder, *outputGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::Sigmoid, 0, 0);

    outputGateDecoder = MakeDecoder<float>(outputGateInfo, outputGateData.data());

    // Hidden state tanH
    // Note: the cell gate scratch buffer is reused here to hold tanh(cellStateOut).
    Activation(*cellStateOutDecoder, *cellGateEncoder,
               TensorInfo({numUnits, numBatches}, internalType),
               ActivationFunction::TanH, 1.0f, 1.0f);

    // Final hidden state output: outputGate (.) tanh(cellStateOut)
    VectorVectorCwiseProduct(*outputGateDecoder, *cellGateDecoder, stateTensorSize, *hiddenStateEncoder);

    // Projection
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(*projectionBiasDecoder, outputSize, numBatches, *outputInt16Encoder);
        }

        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsDecoder, outputSize, numUnits, *hiddenStateDecoder,
                                            numBatches, *outputInt16Encoder);

        CopyVector(*outputInt16Decoder, numBatches * outputSize, *outputEncoder);

        if (m_Data.m_Parameters.m_ProjectionClip > 0.0)
        {
            ClipVector(*outputDecoder, numBatches * outputSize, m_Data.m_Parameters.m_ProjectionClip, *outputEncoder);
        }
    }
    else
    {
        // Output has same quantization scale as hidden state if projection is disabled
        CopyVector(*hiddenStateDecoder, numBatches * outputSize, *outputEncoder);
    }

    // output == outputStateOut
    CopyVector(*outputDecoder, numBatches * outputSize, *outputStateOutEncoder);
}
545
546} //namespace armnn