//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
10#include "LstmUtils.hpp"
11#include "RefWorkloadUtils.hpp"
12
13namespace armnn
14{
15
16RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
17 : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
18 , m_InputToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
19 , m_InputToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
20 , m_InputToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
21 , m_InputToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
22 , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
23 , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
24 , m_RecurrentToCellWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
25 , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
26 , m_CellToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
27 , m_CellToForgetWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
28 , m_CellToOutputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
29 , m_InputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
30 , m_ForgetGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
31 , m_CellBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
32 , m_OutputGateBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
33 , m_ProjectionWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
34 , m_ProjectionBiasTensor (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
35{}
36
// Evaluates one LSTM step on the reference (CPU) backend.
//
// Inputs  (m_Data.m_Inputs):  [0] input {nBatch, nInput}, [1] output state in,
//                             [2] cell state in.
// Outputs (m_Data.m_Outputs): [0] scratch buffer (consecutive nCell*nBatch gate
//                             sections), [1] output state out, [2] cell state out,
//                             [3] output.
// Optional features handled below: CIFG (coupled input/forget gate), peephole
// connections, projection layer, and cell/projection-output clipping.
void RefLstmWorkload::Execute() const
{
    // This is a porting of the LSTM::Eval() method in the Android code base
    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp

    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);

    const TensorShape& inputShape = inputInfo.GetShape();
    const DataType& outputType = outputInfo.GetDataType();

    // Encoders write results; matching decoders re-read values produced earlier in
    // this same step (they map the same underlying tensors).
    // NOTE(review): every encoder/decoder here is built from outputInfo or inputInfo
    // even when it maps a different tensor — presumably MakeEncoder/MakeDecoder only
    // consult the data type; confirm against their implementation.
    std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[1]->Map());
    std::unique_ptr<Encoder<float>> cellStateOut   = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
    std::unique_ptr<Encoder<float>> output         = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());

    std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[2]->Map());
    std::unique_ptr<Decoder<float>> outputDecoder       = MakeDecoder<float>(outputInfo, m_Data.m_Outputs[3]->Map());

    std::unique_ptr<Decoder<float>> inputData     = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[1]->Map());
    std::unique_ptr<Decoder<float>> cellStateIn   = MakeDecoder<float>(inputInfo, m_Data.m_Inputs[2]->Map());

    // Dimensions: batch size and input width come from the input tensor; the number
    // of cells and the output width come from the weight tensor shapes.
    const uint32_t nBatch = inputShape[0];
    const uint32_t nInput = inputShape[1];

    const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];

    const bool useCifg      = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole  = m_Data.m_Parameters.m_PeepholeEnabled;

    // Index the scratch buffers pointers to the global scratch buffer.
    // All four gate scratch views alias outputs[0]; the += offsets below slide each
    // view to its own nCell*nBatch section.
    std::unique_ptr<Encoder<float>> inputGateScratch  = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> cellScratch       = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());

    std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> cellScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());
    std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
        MakeDecoder<float>(outputInfo, m_Data.m_Outputs[0]->Map());

    if (useCifg)
    {
        // CIFG: no input gate, so the scratch buffer holds only 3 sections
        // (cell, forget, output); inputGateScratch stays at offset 0 but is unused.
        *cellScratch += (0 * nCell * nBatch);
        *forgetGateScratch += (1 * nCell * nBatch);
        *outputGateScratch += (2 * nCell * nBatch);

        *cellScratchDecoder += (0 * nCell * nBatch);
        *forgetGateScratchDecoder += (1 * nCell * nBatch);
        *outputGateScratchDecoder += (2 * nCell * nBatch);
    }
    else
    {
        // Full LSTM: 4 sections in order input, cell, forget, output.
        *inputGateScratch += (0 * nCell * nBatch);
        *cellScratch += (1 * nCell * nBatch);
        *forgetGateScratch += (2 * nCell * nBatch);
        *outputGateScratch += (3 * nCell * nBatch);

        *inputGateScratchDecoder += (0 * nCell * nBatch);
        *cellScratchDecoder += (1 * nCell * nBatch);
        *forgetGateScratchDecoder += (2 * nCell * nBatch);
        *outputGateScratchDecoder += (3 * nCell * nBatch);
    }

    // Decoders over the constant weights/biases captured at construction time.
    // Optional ones (input gate, peephole, projection) stay null until the
    // corresponding feature flag proves they were provided.
    std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
        m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
        m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
        m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetTensor<void>());

    std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
        m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
        m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
        m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetTensor<void>());

    std::unique_ptr<Decoder<float>> inputGateBiasTensor;
    std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
        m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
        m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetTensor<void>());
    std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
        m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetTensor<void>());

    std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
    std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;

    std::unique_ptr<Decoder<float>> projectionWeightsTensor;
    std::unique_ptr<Decoder<float>> projectionBiasTensor;

    if (!useCifg)
    {
        // Input-gate weights/bias exist only when CIFG is disabled.
        inputToInputWeightsTensor = MakeDecoder<float>(
            m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetTensor<void>());
        inputGateBiasTensor = MakeDecoder<float>(
            m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetTensor<void>());
        recurrentToInputWeightsTensor = MakeDecoder<float>(
            m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetTensor<void>());
    }

    if (usePeephole)
    {
        cellToForgetWeightsTensor = MakeDecoder<float>(
            m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetTensor<void>());
        cellToOutputWeightsTensor = MakeDecoder<float>(
            m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetTensor<void>());
    }

    if (!useCifg && usePeephole)
    {
        // Peephole into the input gate requires the input gate to exist (no CIFG).
        cellToInputWeightsTensor = MakeDecoder<float>(
            m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetTensor<void>());
    }

    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        projectionWeightsTensor = MakeDecoder<float>(
            m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetTensor<void>());
        if (m_ProjectionBiasTensor)
        {
            // Projection bias is optional even when projection is enabled.
            projectionBiasTensor = MakeDecoder<float>(
                m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetTensor<void>());
        }
    }

    // Initialize scratch buffers with bias.
    if (!useCifg)
    {
        VectorBatchVectorAssign(*inputGateBiasTensor,
                                nCell, nBatch, *inputGateScratch);
    }
    VectorBatchVectorAssign(*forgetGateBiasTensor,
                            nCell, nBatch, *forgetGateScratch);
    VectorBatchVectorAssign(*cellBiasTensor,
                            nCell, nBatch, *cellScratch);
    VectorBatchVectorAssign(*outputGateBiasTensor,
                            nCell, nBatch, *outputGateScratch);

    // For each batch and cell: compute input_weight * input.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
                                            nCell, nInput, *inputData, nBatch, *inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *cellScratch);
    MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
                                        nCell, nInput, *inputData, nBatch, *outputGateScratch);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
                                            nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
    MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
                                        nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);

    // For each batch and cell: update input gate.
    if (!useCifg)
    {
        if (usePeephole)
        {
            // Peephole: add cell-state contribution before the gate nonlinearity.
            VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
                                                    nCell, *cellStateIn, nBatch, *inputGateScratch);
        }
        // Gates always use sigmoid, independent of m_ActivationFunc.
        Activation(*inputGateScratchDecoder, *inputGateScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   ActivationFunction::Sigmoid, 0, 0);
    }

    // For each batch and cell: update forget gate.
    if (usePeephole)
    {
        VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
                                                *cellStateIn, nBatch, *forgetGateScratch);
    }
    Activation(*forgetGateScratchDecoder, *forgetGateScratch,
               TensorInfo({nCell, nBatch}, outputType),
               ActivationFunction::Sigmoid, 0, 0);

    // For each batch and cell: update the cell.
    // cellStateOut = forgetGate .* cellStateIn  (the "keep" part of the new cell state).
    VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);

    // Translate the descriptor's activation selector into an ArmNN activation plus
    // its a/b parameters; a value of 0 presumably means "no activation" — confirm
    // against SetActivationParameters.
    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
    float a = 0;
    float b = 0;
    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);

    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        // Apply the cell-input activation (typically tanh) in place on cellScratch.
        Activation(*cellScratchDecoder, *cellScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   armnnActivationFunc, a, b);
    }
    if (useCifg)
    {
        // CIFG couples the gates: input gate = 1 - forget gate.
        // Sub1Vector overwrites forgetGateScratch with (1 - forgetGate), so the
        // forget-gate values themselves are no longer needed past this point.
        Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
        VectorVectorCwiseProductAccumulate(
            *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
    }
    else
    {
        // cellStateOut += inputGate .* cellInput.
        VectorVectorCwiseProductAccumulate(
            *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
    }
    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
    {
        // Clip the new cell state to [-thres, +thres].
        ClipVector(*cellStateOutDecoder, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, *cellStateOut);
    }

    // For each batch and cell: update the output gate.
    if (usePeephole)
    {
        // Output-gate peephole reads the NEW cell state (cellStateOutDecoder).
        VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
                                                nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
    }
    Activation(*outputGateScratchDecoder, *outputGateScratch,
               TensorInfo({nCell, nBatch}, outputType),
               ActivationFunction::Sigmoid, 0, 0);

    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        // Reuse cellScratch to hold activation(cellStateOut) for the output product.
        // NOTE(review): when m_ActivationFunc == 0 this is skipped and cellScratch
        // still holds the pre-gate cell input (not the new cell state) when used
        // below — verify against the Android LSTM::Eval() reference.
        Activation(*cellStateOutDecoder, *cellScratch,
                   TensorInfo({nCell, nBatch}, outputType),
                   armnnActivationFunc, a, b);
    }

    // outputGateScratch = outputGate .* activation(cellState).
    VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);

    // For each batch: update the projection and output_state.
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        if (m_ProjectionBiasTensor)
        {
            // Seed the output with the projection bias before accumulating.
            VectorBatchVectorAssign(*projectionBiasTensor,
                                    nOutput, nBatch, *output);
        }
        // output (+)= projectionWeights * gatedOutput, mapping nCell -> nOutput.
        MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
                                            nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);

        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
        {
            ClipVector(*outputDecoder, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, *output);
        }
    }
    else
    {
        // No projection: the gated output is the output.
        // NOTE(review): this copies nBatch * nOutput values out of an nCell-sized
        // scratch section, so nOutput is assumed equal to nCell here — confirm the
        // layer validation enforces that when projection is disabled.
        CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
    }

    // The output also becomes the recurrent output state for the next step.
    CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
}
306
307} //namespace armnn