blob: c4345d497886f35dc6c1d92832745675c6ecae57 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefUnidirectionalSequenceLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "Lstm.hpp"
11#include "LstmUtils.hpp"
12#include "RefWorkloadUtils.hpp"
13
14#include <armnnUtils/Permute.hpp>
15
16namespace armnn
17{
18
19RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
20 const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
21 const WorkloadInfo& info)
Finn Williams73c547d2022-02-15 20:47:34 +000022 : RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
Narumol Prangnawarate5339e72021-07-28 17:33:28 +010023 , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
24 , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
25 , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
26 , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
27 , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
28 , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
29 , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
30 , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
31 , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
32 , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
33 , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
34 , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
35 , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
36 , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
37 , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
38 , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
39 , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
40 , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
41 , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
42 , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
43 , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
44{}
45
46void RefUnidirectionalSequenceLstmWorkload::Execute() const
47{
48 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
49}
50
51void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor)
52{
53 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
54}
55
56void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs,
57 std::vector<ITensorHandle*> outputs) const
58{
59 TensorInfo inputInfo = GetTensorInfo(inputs[0]);
60 const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
61 const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
Mike Kelly12994962022-04-21 11:57:09 +010062 TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
63 TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
64 TensorInfo outputInfo = GetTensorInfo(outputs[2]);
Narumol Prangnawarate5339e72021-07-28 17:33:28 +010065 TensorShape& inputShape = inputInfo.GetShape();
66 TensorShape& outputShape= outputInfo.GetShape();
67 auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());
68
69 if (!m_Data.m_Parameters.m_TimeMajor)
70 {
71 // Permute to time major
72 const PermutationVector& mappings = {1U, 0U, 2U};
73 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements());
74 inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
75 inputInfo.SetShape(inputShape);
76 armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float));
77
78 outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
79 outputInfo.SetShape(outputShape);
80 }
81 unsigned int maxTime = inputShape[0];
82 unsigned int batchSize = inputShape[1];
83 unsigned int outputSize = outputShape[2];
84 unsigned int inputSize = inputShape[2];
85
86 TensorInfo scratchInfo = outputInfo;
87 scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});
88
89 std::vector<float> inputGateScratchBuffer;
90 std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
91 std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
92 std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
93
94 std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
95 std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);
96
97 void* outputStateOutData = outputStateOutBuffer.data();
98 void* cellStateOutData = cellStateOutBuffer.data();
99
100 std::unique_ptr<Encoder<float>> inputGateScratch;
101 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
102 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
103 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
104
105 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
106 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
107 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
108 forgetGateScratchBuffer.data());
109 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
110 outputGateScratchBuffer.data());
111
112 const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
113 const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
114 const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
115
116 if (!useCifg)
117 {
118 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
119 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
120 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
121 }
122
123 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
124 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
125 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
126
127 TensorInfo lstmInputInfo = inputInfo;
128 TensorShape batchInputShape = TensorShape({batchSize, inputSize});
129 lstmInputInfo.SetShape(batchInputShape);
130
131 TensorInfo lstmOutputInfo = outputInfo;
132 lstmOutputInfo.SetShape({batchSize, outputSize});
133
134 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
135 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
136 unsigned int nOutput = recurrentToOutputWeightsShape[1];
137 auto outputStateInData = inputs[1]->Map();
138 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
139
140 auto cellStateInData = inputs[2]->Map();
141 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
142
143 auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
144 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
Mike Kelly12994962022-04-21 11:57:09 +0100145 auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
Narumol Prangnawarate5339e72021-07-28 17:33:28 +0100146 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
147 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
148
149 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
150 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
151 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
152 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
153 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
154 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
155 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
156
157 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
158 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
159 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
160 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
161 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
162 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
163 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
164
165 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
166 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
167 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
168 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
169 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
170 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
171 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
172
173 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
174 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
175 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
176
177 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
178 std::unique_ptr<Decoder<float>> projectionBiasTensor;
179
180 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
181 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
182 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
183 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
184
185 if (useLayerNorm)
186 {
187 if (!useCifg)
188 {
189 inputLayerNormWeights = MakeDecoder<float>(
190 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
191 }
192 forgetLayerNormWeights = MakeDecoder<float>(
193 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
194 cellLayerNormWeights = MakeDecoder<float>(
195 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
196 outputLayerNormWeights = MakeDecoder<float>(
197 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
198 }
199
200 if (!useCifg)
201 {
202 inputToInputWeightsTensor = MakeDecoder<float>(
203 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
204 inputGateBiasTensor = MakeDecoder<float>(
205 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
206 recurrentToInputWeightsTensor = MakeDecoder<float>(
207 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
208 }
209
210 if (usePeephole)
211 {
212 cellToForgetWeightsTensor = MakeDecoder<float>(
213 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
214 cellToOutputWeightsTensor = MakeDecoder<float>(
215 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
216 }
217
218 if (!useCifg && usePeephole)
219 {
220 cellToInputWeightsTensor = MakeDecoder<float>(
221 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
222 }
223
224 if (m_Data.m_Parameters.m_ProjectionEnabled)
225 {
226 projectionWeightsTensor = MakeDecoder<float>(
227 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
228 if (m_ProjectionBiasTensor)
229 {
230 projectionBiasTensor = MakeDecoder<float>(
231 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
232 }
233 }
234
235 unsigned int batchInputSize = batchSize * inputSize;
236 unsigned int batchOutputSize = batchSize * nOutput;
237
238 for (unsigned int t = 0; t < maxTime; ++t)
239 {
240 LstmImpl(m_Data.m_Parameters,
241 lstmInputInfo,
242 lstmOutputInfo,
243 inputToOutputWeightsShape,
244 recurrentToOutputWeightsShape,
245 inputData,
246 outputStateIn,
247 cellStateIn,
248 outputStateOut,
249 cellStateOut,
250 output,
251 cellStateOutDecoder,
252 outputDecoder,
253 inputToInputWeightsTensor,
254 inputToForgetWeightsTensor,
255 inputToCellWeightsTensor,
256 inputToOutputWeightsTensor,
257 recurrentToInputWeightsTensor,
258 recurrentToForgetWeightsTensor,
259 recurrentToCellWeightsTensor,
260 recurrentToOutputWeightsTensor,
261 cellToInputWeightsTensor,
262 cellToForgetWeightsTensor,
263 cellToOutputWeightsTensor,
264 inputGateBiasTensor,
265 forgetGateBiasTensor,
266 cellBiasTensor,
267 outputGateBiasTensor,
268 projectionWeightsTensor,
269 projectionBiasTensor,
270 inputLayerNormWeights,
271 forgetLayerNormWeights,
272 cellLayerNormWeights,
273 outputLayerNormWeights,
274 inputGateScratch,
275 cellScratch,
276 forgetGateScratch,
277 outputGateScratch,
278 inputGateScratchDecoder,
279 cellScratchDecoder,
280 forgetGateScratchDecoder,
281 outputGateScratchDecoder,
282 m_LayerNormEpsilon);
283
284 currentInputData += batchInputSize;
285 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
286 currentOutputData += batchOutputSize;
287 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
288 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
289
290 // Assign output state out to the next output state in
291 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
292
293 // Assign cell state out to the next cell state in
294 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
295 }
296
297 if (!m_Data.m_Parameters.m_TimeMajor)
298 {
299 // Permute Output back to batch major
300 const PermutationVector& mappings = {1U, 0U, 2U};
Mike Kelly12994962022-04-21 11:57:09 +0100301 auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
Narumol Prangnawarate5339e72021-07-28 17:33:28 +0100302 std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
303 outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
304 outputInfo.SetShape(outputShape);
305 armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
306 }
307}
308
309} //namespace armnn