//
// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefUnidirectionalSequenceLstmWorkload.hpp"
#include "Activation.hpp"
#include "Encoders.hpp"
#include "Decoders.hpp"
#include "Lstm.hpp"
#include "LstmUtils.hpp"
#include "RefWorkloadUtils.hpp"

#include <armnnUtils/Permute.hpp>

namespace armnn
{

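// The constructor copies the descriptor's constant tensors into scoped tensor handles.
// Tensors belonging to optional features (CIFG, peephole, projection, layer normalization)
// may be absent from the descriptor, in which case the corresponding handle is left null
// and is never dereferenced during Execute().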
RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
    const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
    const WorkloadInfo& info)
    : RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
    , m_InputLayerNormWeights         (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
    , m_ForgetLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
    , m_CellLayerNormWeights          (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
    , m_OutputLayerNormWeights        (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
{}

void RefUnidirectionalSequenceLstmWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

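// ExecuteAsync runs the same computation as Execute(), but takes its tensor handles from the
// WorkingMemDescriptor carried in the caller-supplied ExecutionData rather than from m_Data.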
void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(ExecutionData& executionData)
{
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}

void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs,
                                                    std::vector<ITensorHandle*> outputs) const
{
    ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefUnidirectionalSequenceLstmWorkload_Execute");

    TensorInfo inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
    const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
    TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
    TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
    TensorInfo outputInfo = GetTensorInfo(outputs[2]);
    TensorShape& inputShape = inputInfo.GetShape();
    TensorShape& outputShape = outputInfo.GetShape();
    auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());

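    // A non-time-major network presents the input as [batchSize, maxTime, inputSize]. It is
    // permuted here to time major ([maxTime, batchSize, inputSize]) so that every time step can
    // be processed as one contiguous [batchSize, inputSize] slab; the output shape is adjusted
    // the same way and the result is permuted back to batch major at the end of Execute().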
    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute to time major
        const PermutationVector& mappings = {1U, 0U, 2U};
        std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements());
        inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
        inputInfo.SetShape(inputShape);
        armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float));

        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
    }
    // As it is permuted to time major, maxTime is inputShape[0].
    unsigned int maxTime = inputShape[0];
    unsigned int batchSize = inputShape[1];
    unsigned int outputSize = outputShape[2];
    unsigned int inputSize = inputShape[2];

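    // The scratch buffers hold the intermediate per-gate values for a single time step. They are
    // shaped [batchSize, numUnits], where numUnits is taken from the cell state shape. The input
    // gate buffer is only allocated further down when CIFG is disabled, because with CIFG the
    // input gate is derived from the forget gate.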
    TensorInfo scratchInfo = outputInfo;
    scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});

    std::vector<float> inputGateScratchBuffer;
    std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
    std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);

    std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
    std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);

    void* outputStateOutData = outputStateOutBuffer.data();
    void* cellStateOutData = cellStateOutBuffer.data();

    std::unique_ptr<Encoder<float>> inputGateScratch;
    std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
    std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  forgetGateScratchBuffer.data());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
                                                                                  outputGateScratchBuffer.data());

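    // Optional LSTM variants controlled by the descriptor parameters:
    //  - CIFG couples the input and forget gates, removing the separate input gate weights/bias.
    //  - Peephole adds connections from the cell state to the gates.
    //  - Layer normalization applies per-gate normalization weights inside each step.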
    const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
    const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;

    if (!useCifg)
    {
        inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
        inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
        inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
    }

    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
    std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);

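    // lstmInputInfo and lstmOutputInfo describe a single time step: each LstmImpl call consumes
    // one [batchSize, inputSize] slab of the input and writes one [batchSize, outputSize] slab
    // of the output.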
    TensorInfo lstmInputInfo = inputInfo;
    TensorShape batchInputShape = TensorShape({batchSize, inputSize});
    lstmInputInfo.SetShape(batchInputShape);

    TensorInfo lstmOutputInfo = outputInfo;
    lstmOutputInfo.SetShape({batchSize, outputSize});

    const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
    const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
    unsigned int nOutput = recurrentToOutputWeightsShape[1];
    auto outputStateInData = inputs[1]->Map();
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);

    auto cellStateInData = inputs[2]->Map();
    std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);

    auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
    std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
    auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
    std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

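    // The constant weights and biases are wrapped in Decoder objects so that LstmImpl can read
    // them independently of their underlying data type. Decoders for optional tensors (CIFG,
    // peephole, projection, layer normalization) are left null here and only created below when
    // the corresponding feature is enabled.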
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    std::unique_ptr<Decoder<float>> inputLayerNormWeights;
    std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
    std::unique_ptr<Decoder<float>> cellLayerNormWeights;
    std::unique_ptr<Decoder<float>> outputLayerNormWeights;

    if (useLayerNorm)
    {
        if (!useCifg)
        {
            inputLayerNormWeights = MakeDecoder<float>(
                m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
        }
        forgetLayerNormWeights = MakeDecoder<float>(
            m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
        cellLayerNormWeights = MakeDecoder<float>(
            m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
        outputLayerNormWeights = MakeDecoder<float>(
            m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
    }

    if (!useCifg)
    {
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
        }
    }

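    // Run one LstmImpl call per time step. After each step the raw input/output pointers are
    // advanced by one batch-sized slab and the state decoders are rebound so that the output
    // state and cell state just written become the input state of the next step.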
    unsigned int batchInputSize = batchSize * inputSize;
    unsigned int batchOutputSize = batchSize * nOutput;

    for (unsigned int t = 0; t < maxTime; ++t)
    {
        LstmImpl(m_Data.m_Parameters,
                 lstmInputInfo,
                 lstmOutputInfo,
                 inputToOutputWeightsShape,
                 recurrentToOutputWeightsShape,
                 inputData,
                 outputStateIn,
                 cellStateIn,
                 outputStateOut,
                 cellStateOut,
                 output,
                 cellStateOutDecoder,
                 outputDecoder,
                 inputToInputWeightsTensor,
                 inputToForgetWeightsTensor,
                 inputToCellWeightsTensor,
                 inputToOutputWeightsTensor,
                 recurrentToInputWeightsTensor,
                 recurrentToForgetWeightsTensor,
                 recurrentToCellWeightsTensor,
                 recurrentToOutputWeightsTensor,
                 cellToInputWeightsTensor,
                 cellToForgetWeightsTensor,
                 cellToOutputWeightsTensor,
                 inputGateBiasTensor,
                 forgetGateBiasTensor,
                 cellBiasTensor,
                 outputGateBiasTensor,
                 projectionWeightsTensor,
                 projectionBiasTensor,
                 inputLayerNormWeights,
                 forgetLayerNormWeights,
                 cellLayerNormWeights,
                 outputLayerNormWeights,
                 inputGateScratch,
                 cellScratch,
                 forgetGateScratch,
                 outputGateScratch,
                 inputGateScratchDecoder,
                 cellScratchDecoder,
                 forgetGateScratchDecoder,
                 outputGateScratchDecoder,
                 m_LayerNormEpsilon);

        currentInputData += batchInputSize;
        inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
        currentOutputData += batchOutputSize;
        output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
        outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);

        // Assign output state out to the next output state in
        outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);

        // Assign cell state out to the next cell state in
        cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
    }

    if (!m_Data.m_Parameters.m_TimeMajor)
    {
        // Permute Output back to batch major
        const PermutationVector& mappings = {1U, 0U, 2U};
        auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
        std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
        outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
        outputInfo.SetShape(outputShape);
        armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
    }
}

} //namespace armnn