//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefUnidirectionalSequenceLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "Lstm.hpp"
11#include "LstmUtils.hpp"
12#include "RefWorkloadUtils.hpp"
13
14#include <armnnUtils/Permute.hpp>
15
16namespace armnn
17{
18
19RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
20 const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
21 const WorkloadInfo& info)
22 : BaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
23 , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
24 , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
25 , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
26 , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
27 , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
28 , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
29 , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
30 , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
31 , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
32 , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
33 , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
34 , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
35 , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
36 , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
37 , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
38 , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
39 , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
40 , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
41 , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
42 , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
43 , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
44{}
45
46void RefUnidirectionalSequenceLstmWorkload::Execute() const
47{
48 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
49}
50
51void RefUnidirectionalSequenceLstmWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor)
52{
53 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
54}
55
56void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> inputs,
57 std::vector<ITensorHandle*> outputs) const
58{
59 TensorInfo inputInfo = GetTensorInfo(inputs[0]);
60 const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
61 const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
62 TensorInfo outputInfo = GetTensorInfo(outputs[0]);
63 TensorShape& inputShape = inputInfo.GetShape();
64 TensorShape& outputShape= outputInfo.GetShape();
65 auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());
66
67 if (!m_Data.m_Parameters.m_TimeMajor)
68 {
69 // Permute to time major
70 const PermutationVector& mappings = {1U, 0U, 2U};
71 std::vector<float> inputValue(inputTensor, inputTensor + inputInfo.GetNumElements());
72 inputShape = armnnUtils::Permuted(inputInfo.GetShape(), mappings);
73 inputInfo.SetShape(inputShape);
74 armnnUtils::Permute(inputShape, mappings, inputValue.data(), inputTensor, sizeof(float));
75
76 outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
77 outputInfo.SetShape(outputShape);
78 }
79 unsigned int maxTime = inputShape[0];
80 unsigned int batchSize = inputShape[1];
81 unsigned int outputSize = outputShape[2];
82 unsigned int inputSize = inputShape[2];
83
84 TensorInfo scratchInfo = outputInfo;
85 scratchInfo.SetShape({batchSize, cellStateInfo.GetShape()[1]});
86
87 std::vector<float> inputGateScratchBuffer;
88 std::vector<float> cellScratchBuffer(scratchInfo.GetNumElements(), 0.);
89 std::vector<float> forgetGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
90 std::vector<float> outputGateScratchBuffer(scratchInfo.GetNumElements(), 0.);
91
92 std::vector<float> outputStateOutBuffer(outputStateInfo.GetNumElements(), 0.);
93 std::vector<float> cellStateOutBuffer(cellStateInfo.GetNumElements(), 0.);
94
95 void* outputStateOutData = outputStateOutBuffer.data();
96 void* cellStateOutData = cellStateOutBuffer.data();
97
98 std::unique_ptr<Encoder<float>> inputGateScratch;
99 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(scratchInfo, cellScratchBuffer.data());
100 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(scratchInfo, forgetGateScratchBuffer.data());
101 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(scratchInfo, outputGateScratchBuffer.data());
102
103 std::unique_ptr<Decoder<float>> inputGateScratchDecoder;
104 std::unique_ptr<Decoder<float>> cellScratchDecoder = MakeDecoder<float>(scratchInfo, cellScratchBuffer.data());
105 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder = MakeDecoder<float>(scratchInfo,
106 forgetGateScratchBuffer.data());
107 std::unique_ptr<Decoder<float>> outputGateScratchDecoder = MakeDecoder<float>(scratchInfo,
108 outputGateScratchBuffer.data());
109
110 const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
111 const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
112 const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
113
114 if (!useCifg)
115 {
116 inputGateScratchBuffer.resize(scratchInfo.GetNumElements(), 0.);
117 inputGateScratch = MakeEncoder<float>(scratchInfo, inputGateScratchBuffer.data());
118 inputGateScratchDecoder = MakeDecoder<float>(scratchInfo, inputGateScratchBuffer.data());
119 }
120
121 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputStateInfo, outputStateOutData);
122 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(cellStateInfo, cellStateOutData);
123 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(cellStateInfo, cellStateOutData);
124
125 TensorInfo lstmInputInfo = inputInfo;
126 TensorShape batchInputShape = TensorShape({batchSize, inputSize});
127 lstmInputInfo.SetShape(batchInputShape);
128
129 TensorInfo lstmOutputInfo = outputInfo;
130 lstmOutputInfo.SetShape({batchSize, outputSize});
131
132 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
133 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
134 unsigned int nOutput = recurrentToOutputWeightsShape[1];
135 auto outputStateInData = inputs[1]->Map();
136 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateInData);
137
138 auto cellStateInData = inputs[2]->Map();
139 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateInData);
140
141 auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
142 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
143 auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map());
144 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
145 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
146
147 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
148 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
149 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
150 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
151 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
152 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
153 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
154
155 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
156 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
157 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
158 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
159 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
160 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
161 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
162
163 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
164 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
165 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
166 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
167 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
168 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
169 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
170
171 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
172 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
173 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
174
175 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
176 std::unique_ptr<Decoder<float>> projectionBiasTensor;
177
178 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
179 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
180 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
181 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
182
183 if (useLayerNorm)
184 {
185 if (!useCifg)
186 {
187 inputLayerNormWeights = MakeDecoder<float>(
188 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
189 }
190 forgetLayerNormWeights = MakeDecoder<float>(
191 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
192 cellLayerNormWeights = MakeDecoder<float>(
193 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
194 outputLayerNormWeights = MakeDecoder<float>(
195 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
196 }
197
198 if (!useCifg)
199 {
200 inputToInputWeightsTensor = MakeDecoder<float>(
201 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
202 inputGateBiasTensor = MakeDecoder<float>(
203 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
204 recurrentToInputWeightsTensor = MakeDecoder<float>(
205 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
206 }
207
208 if (usePeephole)
209 {
210 cellToForgetWeightsTensor = MakeDecoder<float>(
211 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
212 cellToOutputWeightsTensor = MakeDecoder<float>(
213 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
214 }
215
216 if (!useCifg && usePeephole)
217 {
218 cellToInputWeightsTensor = MakeDecoder<float>(
219 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
220 }
221
222 if (m_Data.m_Parameters.m_ProjectionEnabled)
223 {
224 projectionWeightsTensor = MakeDecoder<float>(
225 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
226 if (m_ProjectionBiasTensor)
227 {
228 projectionBiasTensor = MakeDecoder<float>(
229 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
230 }
231 }
232
233 unsigned int batchInputSize = batchSize * inputSize;
234 unsigned int batchOutputSize = batchSize * nOutput;
235
236 for (unsigned int t = 0; t < maxTime; ++t)
237 {
238 LstmImpl(m_Data.m_Parameters,
239 lstmInputInfo,
240 lstmOutputInfo,
241 inputToOutputWeightsShape,
242 recurrentToOutputWeightsShape,
243 inputData,
244 outputStateIn,
245 cellStateIn,
246 outputStateOut,
247 cellStateOut,
248 output,
249 cellStateOutDecoder,
250 outputDecoder,
251 inputToInputWeightsTensor,
252 inputToForgetWeightsTensor,
253 inputToCellWeightsTensor,
254 inputToOutputWeightsTensor,
255 recurrentToInputWeightsTensor,
256 recurrentToForgetWeightsTensor,
257 recurrentToCellWeightsTensor,
258 recurrentToOutputWeightsTensor,
259 cellToInputWeightsTensor,
260 cellToForgetWeightsTensor,
261 cellToOutputWeightsTensor,
262 inputGateBiasTensor,
263 forgetGateBiasTensor,
264 cellBiasTensor,
265 outputGateBiasTensor,
266 projectionWeightsTensor,
267 projectionBiasTensor,
268 inputLayerNormWeights,
269 forgetLayerNormWeights,
270 cellLayerNormWeights,
271 outputLayerNormWeights,
272 inputGateScratch,
273 cellScratch,
274 forgetGateScratch,
275 outputGateScratch,
276 inputGateScratchDecoder,
277 cellScratchDecoder,
278 forgetGateScratchDecoder,
279 outputGateScratchDecoder,
280 m_LayerNormEpsilon);
281
282 currentInputData += batchInputSize;
283 inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
284 currentOutputData += batchOutputSize;
285 output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
286 outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
287
288 // Assign output state out to the next output state in
289 outputStateIn = MakeDecoder<float>(outputStateInfo, outputStateOutData);
290
291 // Assign cell state out to the next cell state in
292 cellStateIn = MakeDecoder<float>(cellStateInfo, cellStateOutData);
293 }
294
295 if (!m_Data.m_Parameters.m_TimeMajor)
296 {
297 // Permute Output back to batch major
298 const PermutationVector& mappings = {1U, 0U, 2U};
299 auto outputData = reinterpret_cast<float*>(outputs[0]->Map());
300 std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
301 outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
302 outputInfo.SetShape(outputShape);
303 armnnUtils::Permute(outputShape, mappings, outputValue.data(), outputData, sizeof(float));
304 }
305}
306
307} //namespace armnn