//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefUnidirectionalSequenceLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "Lstm.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

#include <armnnUtils/Permute.hpp>

namespace armnn
{

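// The constructor copies each (optional) parameter tensor from the queue descriptor into a
// scoped tensor handle owned by this workload, so the weights and biases remain valid for the
// lifetime of the workload regardless of what happens to the descriptor afterwards.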
RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
    const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
    const WorkloadInfo& info)
    : RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

void RefUnidirectionalSequenceLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

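// The asynchronous path executes against the tensor handles supplied through the caller's
// working memory rather than the handles captured in m_Data.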
void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(ExecutionData& executionData)
{
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}

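// Reference implementation of the whole sequence:
//   inputs[0]  - input sequence,    inputs[1]  - output state in, inputs[2]  - cell state in
//   outputs[0] - output state out,  outputs[1] - cell state out,  outputs[2] - output sequence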
void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs,
                                                    std::vector<ITensorHandle*> outputs) const
{
    TensorInfo inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
    TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
    TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
    TensorInfo outputInfo = GetTensorInfo(outputs[2]);
    TensorShape& inputShape = inputInfo.GetShape();
    TensorShape& outputShape = outputInfo.GetShape();
    auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());

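    // The timestep loop below assumes time-major data [maxTime, batchSize, inputSize]. If the
    // network is batch major, permute the mapped input in place and adjust the shapes to match.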
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute to time major
        const PermutationVector& mappings = {1U, 0U, 2U};
        std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements());
        inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
        inputInfo.SetShape(inputShape);
        armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float));

        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
    }
    unsigned int maxTime = inputShape[0];
    unsigned int batchSize = inputShape[1];
    unsigned int outputSize = outputShape[2];
    unsigned int inputSize = inputShape[2];

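    // Per-gate scratch buffers, one [batchSize, numUnits] slice each (numUnits taken from the
    // cell state shape), used by LstmImpl to hold intermediate gate activations for one timestep.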
    TensorInfo scratchInfo = outputInfo;
    scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});

    std::vector<float> inputGateScratchBuffer;
    std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);

    std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
    std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);

    void* outputStateOutData = outputStateOutBuffer.data();
    void* cellStateOutData = cellStateOutBuffer.data();

    std::unique_ptr<Encoder<float>> inputGateScratch;
    std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
    std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  forgetGateScratchBuffer.data());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  outputGateScratchBuffer.data());

    const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

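    // With CIFG (coupled input and forget gate) enabled there is no separate input gate, so its
    // scratch buffer, encoder and decoder are only created when CIFG is disabled.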
    if (!useCifg)
    {
        inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
        inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
        inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
    }

    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
    std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);

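    // LstmImpl processes one timestep at a time, so build 2D per-step infos of shape
    // [batchSize, inputSize] for the input slice and [batchSize, outputSize] for the output slice.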
    TensorInfo lstmInputInfo = inputInfo;
    TensorShape batchInputShape = TensorShape({batchSize, inputSize});
    lstmInputInfo.SetShape(batchInputShape);

    TensorInfo lstmOutputInfo = outputInfo;
    lstmOutputInfo.SetShape({batchSize, outputSize});

    const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
    const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
    unsigned int nOutput = recurrentToOutputWeightsShape[1];
    auto outputStateInData = inputs[1]->Map();
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);

    auto cellStateInData = inputs[2]->Map();
    std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);

    auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
    std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
    auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
    std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

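    // Decoders over the mandatory weights and biases. Decoders for optional tensors (CIFG,
    // peephole, projection and layer normalisation) are declared empty here and only created
    // further down when the corresponding feature is enabled.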
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

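    // Create decoders for the optional parameters of each enabled feature.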
    if (useLayerNorm)
    {
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

    unsigned int batchInputSize = batchSize * inputSize;
    unsigned int batchOutputSize = batchSize * nOutput;

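    // Run the LSTM cell once per timestep. After each step the input and output pointers are
    // advanced by one batch worth of elements, and the state written by this step is re-read as
    // the input state of the next step.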
    for (unsigned int t = 0; t < maxTime; ++t)
    {
        LstmImpl(m_Data.m_Parameters,
                 lstmInputInfo,
                 lstmOutputInfo,
                 inputToOutputWeightsShape,
                 recurrentToOutputWeightsShape,
                 inputData,
                 outputStateIn,
                 cellStateIn,
                 outputStateOut,
                 cellStateOut,
                 output,
                 cellStateOutDecoder,
                 outputDecoder,
                 inputToInputWeightsTensor,
                 inputToForgetWeightsTensor,
                 inputToCellWeightsTensor,
                 inputToOutputWeightsTensor,
                 recurrentToInputWeightsTensor,
                 recurrentToForgetWeightsTensor,
                 recurrentToCellWeightsTensor,
                 recurrentToOutputWeightsTensor,
                 cellToInputWeightsTensor,
                 cellToForgetWeightsTensor,
                 cellToOutputWeightsTensor,
                 inputGateBiasTensor,
                 forgetGateBiasTensor,
                 cellBiasTensor,
                 outputGateBiasTensor,
                 projectionWeightsTensor,
                 projectionBiasTensor,
                 inputLayerNormWeights,
                 forgetLayerNormWeights,
                 cellLayerNormWeights,
                 outputLayerNormWeights,
                 inputGateScratch,
                 cellScratch,
                 forgetGateScratch,
                 outputGateScratch,
                 inputGateScratchDecoder,
                 cellScratchDecoder,
                 forgetGateScratchDecoder,
                 outputGateScratchDecoder,
                 m_LayerNormEpsilon);

        currentInputData += batchInputSize;
        inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
        currentOutputData += batchOutputSize;
        output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
        outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

        // Assign output state out to the next output state in
        outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);

        // Assign cell state out to the next cell state in
        cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
    }

    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute Output back to batch major
        const PermutationVector& mappings = {1U, 0U, 2U};
        auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
        std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
        armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
    }
}

} //namespace armnn