//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefLstmFloat32Workload.hpp"
#include "RefWorkloadUtils.hpp"
#include "Activation.hpp"

#include <algorithm> // std::fill_n
#include <cstring>   // memcpy
#include <string>    // std::to_string

namespace
{

// Helper functions ported from the Android code base.
// Refer to: android/external/tensorflow/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc

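// For each batch b and each row r of 'matrix', accumulates the dot product of
// row r with the b-th input vector:
//   outResult[(b * mRows + r) * resultStride] += dot(matrix row r, vector b).
// 'matrix' is row-major [mRows x mCols] and 'vector' holds nBatch contiguous
// vectors of length mCols; outResult is accumulated into, not overwritten.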
void MatrixBatchVectorMultiplyAccumulate(const float* matrix,
                                         uint32_t mRows,
                                         uint32_t mCols,
                                         const float* vector,
                                         uint32_t nBatch,
                                         float* outResult,
                                         int resultStride = 1)
{
    float* resultInBatch = outResult;
    for (uint32_t b = 0; b < nBatch; b++)
    {
        const float* matrixPtr = matrix;
        for (uint32_t r = 0; r < mRows; r++)
        {
            const float* vectorInBatch = vector + b * mCols;
            for (uint32_t c = 0; c < mCols; c++)
            {
                *resultInBatch += *matrixPtr++ * *vectorInBatch++;
            }
            resultInBatch += resultStride;
        }
    }
}

void VectorBatchVectorAssign(const float* vector,
                             uint32_t vSize,
                             uint32_t nBatch,
                             float* outBatchVector)
{
    for (uint32_t b = 0; b < nBatch; b++)
    {
        memcpy(outBatchVector + b * vSize, vector, vSize * sizeof(float));
    }
}

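// For each batch b: outResult[b * vSize + v] += vector[v] * batchVector[b * vSize + v].
// Used below to accumulate the peephole term (cell-to-gate weights, elementwise
// with the cell state) into a gate's scratch buffer.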
void VectorBatchVectorCwiseProductAccumulate(const float* vector,
                                             uint32_t vSize,
                                             const float* batchVector,
                                             uint32_t nBatch,
                                             float* outResult)
{
    for (uint32_t b = 0; b < nBatch; b++)
    {
        for (uint32_t v = 0; v < vSize; v++)
        {
            *outResult++ += vector[v] * *batchVector++;
        }
    }
}

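// result[v] = 1.0f - vector[v]. With CIFG (coupled input and forget gate) the
// input gate is derived from the forget gate as inputGate = 1 - forgetGate,
// which Execute() computes with this helper.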
void Sub1Vector(const float* vector,
                uint32_t vSize,
                float* result)
{
    for (uint32_t v = 0; v < vSize; v++)
    {
        *result++ = 1.0f - *vector++;
    }
}

void VectorVectorCwiseProduct(const float* vector1,
                              const float* vector2,
                              uint32_t vSize,
                              float* outResult)
{
    for (uint32_t v = 0; v < vSize; v++)
    {
        *outResult++ = *vector1++ * *vector2++;
    }
}

void VectorVectorCwiseProductAccumulate(const float* vector1,
                                        const float* vector2,
                                        uint32_t vSize,
                                        float* outResult)
{
    for (uint32_t v = 0; v < vSize; v++)
    {
        *outResult++ += *vector1++ * *vector2++;
    }
}

float Clip(float f,
           float absLimit)
{
    float result = (absLimit < f) ? absLimit : f;
    result = (-absLimit > result) ? -absLimit : result;
    return result;
}

void ClipVector(const float* vector,
                uint32_t vSize,
                float absLimit,
                float* outResult)
{
    for (uint32_t v = 0; v < vSize; v++)
    {
        *outResult++ = Clip(*vector++, absLimit);
    }
}

void CopyVector(const float* vector,
                uint32_t vSize,
                float* outResult)
{
    memcpy(outResult, vector, vSize * sizeof(float));
}

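// Maps the fused-activation code carried in the LSTM descriptor (the values
// follow the Android/TfLite fused-activation enum: 0 None, 1 Relu, 3 Relu6,
// 4 Tanh, 6 Sigmoid) to an armnn::ActivationFunction and its a/b parameters.
// For code 0 (None) outArmnnActivation is left unchanged; Execute() skips the
// cell activation entirely in that case.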
void SetActivationParameters(uint32_t activation,
                             armnn::ActivationFunction& outArmnnActivation,
                             float& outA,
                             float& outB)
{
    switch (activation)
    {
        case 0: // None
            outA = 0;
            outB = 0;
            return;

        case 1: // Relu
            outArmnnActivation = armnn::ActivationFunction::ReLu;
            outA = 0;
            outB = 0;
            return;

        case 3: // Relu6
            outArmnnActivation = armnn::ActivationFunction::BoundedReLu;
            outA = 6;
            outB = 0;
            return;

        case 4: // Tanh
            outArmnnActivation = armnn::ActivationFunction::TanH;
            outA = 1;
            outB = 1;
            return;

        case 6: // Sigmoid
            outArmnnActivation = armnn::ActivationFunction::Sigmoid;
            outA = 0;
            outB = 0;
            return;

        default:
            throw armnn::Exception("Unsupported activation function: " + std::to_string(activation));
    }
}

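// The CIFG, peephole, and projection tensors are optional and may be null in
// the descriptor; a scoped copy is taken only when a tensor is present.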
std::unique_ptr<armnn::ScopedCpuTensorHandle> AssignScopedCpuTensorHandle(const armnn::ConstCpuTensorHandle* ptr)
{
    if (!ptr)
    {
        return nullptr;
    }

    return std::make_unique<armnn::ScopedCpuTensorHandle>(*ptr);
}

} // anonymous namespace

namespace armnn
{

RefLstmFloat32Workload::RefLstmFloat32Workload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
    : Float32Workload<LstmQueueDescriptor>(descriptor, info)
    , m_InputToInputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_InputToInputWeights))
    , m_InputToForgetWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToForgetWeights))
    , m_InputToCellWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_InputToCellWeights))
    , m_InputToOutputWeightsTensor    (AssignScopedCpuTensorHandle(descriptor.m_InputToOutputWeights))
    , m_RecurrentToInputWeightsTensor (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToInputWeights))
    , m_RecurrentToForgetWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToForgetWeights))
    , m_RecurrentToCellWeightsTensor  (AssignScopedCpuTensorHandle(descriptor.m_RecurrentToCellWeights))
    , m_RecurrentToOutputWeightsTensor(AssignScopedCpuTensorHandle(descriptor.m_RecurrentToOutputWeights))
    , m_CellToInputWeightsTensor      (AssignScopedCpuTensorHandle(descriptor.m_CellToInputWeights))
    , m_CellToForgetWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToForgetWeights))
    , m_CellToOutputWeightsTensor     (AssignScopedCpuTensorHandle(descriptor.m_CellToOutputWeights))
    , m_InputGateBiasTensor           (AssignScopedCpuTensorHandle(descriptor.m_InputGateBias))
    , m_ForgetGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ForgetGateBias))
    , m_CellBiasTensor                (AssignScopedCpuTensorHandle(descriptor.m_CellBias))
    , m_OutputGateBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_OutputGateBias))
    , m_ProjectionWeightsTensor       (AssignScopedCpuTensorHandle(descriptor.m_ProjectionWeights))
    , m_ProjectionBiasTensor          (AssignScopedCpuTensorHandle(descriptor.m_ProjectionBias))
{}

void RefLstmFloat32Workload::Execute() const
{
    // This is a port of the LSTM::Eval() method in the Android code base.
    // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp

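    // Notation for the steps below: x = input, h = outputStateIn, c = cellStateIn,
    // g = the configured cell activation and (*) = elementwise product. The code
    // evaluates the standard LSTM cell, with optional CIFG, peephole and projection:
    //   i  = sigmoid(W_xi x + W_hi h + w_ci (*) c + b_i)    (with CIFG: i = 1 - f)
    //   f  = sigmoid(W_xf x + W_hf h + w_cf (*) c + b_f)
    //   c' = f (*) c + i (*) g(W_xc x + W_hc h + b_c)
    //   o  = sigmoid(W_xo x + W_ho h + w_co (*) c' + b_o)
    //   h' = o (*) g(c')    (optionally followed by a clipped linear projection)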
    const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorShape& inputShape = inputInfo.GetShape();

    float* scratchBuffer  = GetOutputTensorDataFloat(0, m_Data);
    float* outputStateOut = GetOutputTensorDataFloat(1, m_Data);
    float* cellStateOut   = GetOutputTensorDataFloat(2, m_Data);
    float* output         = GetOutputTensorDataFloat(3, m_Data);

    const float* inputData     = GetInputTensorDataFloat(0, m_Data);
    const float* outputStateIn = GetInputTensorDataFloat(1, m_Data);
    const float* cellStateIn   = GetInputTensorDataFloat(2, m_Data);

    const uint32_t nBatch = inputShape[0];
    const uint32_t nInput = inputShape[1];

    const uint32_t nCell   = m_InputToOutputWeightsTensor->GetShape()[0];
    const uint32_t nOutput = m_RecurrentToOutputWeightsTensor->GetShape()[1];
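    // The weight shapes fix the remaining dimensions: input-to-* weights are
    // [nCell, nInput] and recurrent-to-* weights are [nCell, nOutput].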

    const bool useCifg     = m_Data.m_Parameters.m_CifgEnabled;
    const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;

    // Index the scratch buffer pointers into the single global scratch buffer.
    // With CIFG there is no input gate, so the buffer holds three nCell * nBatch
    // sections instead of four.
    float* inputGateScratch  = nullptr;
    float* cellScratch       = nullptr;
    float* forgetGateScratch = nullptr;
    float* outputGateScratch = nullptr;

    if (useCifg)
    {
        cellScratch       = scratchBuffer + 0 * nCell * nBatch;
        forgetGateScratch = scratchBuffer + 1 * nCell * nBatch;
        outputGateScratch = scratchBuffer + 2 * nCell * nBatch;
    }
    else
    {
        inputGateScratch  = scratchBuffer + 0 * nCell * nBatch;
        cellScratch       = scratchBuffer + 1 * nCell * nBatch;
        forgetGateScratch = scratchBuffer + 2 * nCell * nBatch;
        outputGateScratch = scratchBuffer + 3 * nCell * nBatch;
    }

    // Initialize scratch buffers with bias.
    if (!useCifg)
    {
        VectorBatchVectorAssign(m_InputGateBiasTensor->GetTensor<float>(),
                                nCell, nBatch, inputGateScratch);
    }
    VectorBatchVectorAssign(m_ForgetGateBiasTensor->GetTensor<float>(),
                            nCell, nBatch, forgetGateScratch);
    VectorBatchVectorAssign(m_CellBiasTensor->GetTensor<float>(),
                            nCell, nBatch, cellScratch);
    VectorBatchVectorAssign(m_OutputGateBiasTensor->GetTensor<float>(),
                            nCell, nBatch, outputGateScratch);

    // For each batch and cell: compute input_weight * input.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(m_InputToInputWeightsTensor->GetTensor<float>(),
                                            nCell, nInput, inputData, nBatch, inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(m_InputToForgetWeightsTensor->GetTensor<float>(),
                                        nCell, nInput, inputData, nBatch, forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(m_InputToCellWeightsTensor->GetTensor<float>(),
                                        nCell, nInput, inputData, nBatch, cellScratch);
    MatrixBatchVectorMultiplyAccumulate(m_InputToOutputWeightsTensor->GetTensor<float>(),
                                        nCell, nInput, inputData, nBatch, outputGateScratch);

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!useCifg)
    {
        MatrixBatchVectorMultiplyAccumulate(m_RecurrentToInputWeightsTensor->GetTensor<float>(),
                                            nCell, nOutput, outputStateIn, nBatch, inputGateScratch);
    }
    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToForgetWeightsTensor->GetTensor<float>(),
                                        nCell, nOutput, outputStateIn, nBatch, forgetGateScratch);
    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToCellWeightsTensor->GetTensor<float>(),
                                        nCell, nOutput, outputStateIn, nBatch, cellScratch);
    MatrixBatchVectorMultiplyAccumulate(m_RecurrentToOutputWeightsTensor->GetTensor<float>(),
                                        nCell, nOutput, outputStateIn, nBatch, outputGateScratch);

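    // At this point each gate's scratch buffer holds bias + inputWeights * x
    // + recurrentWeights * h for that gate.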
    // For each batch and cell: update the input gate.
    if (!useCifg)
    {
        if (usePeephole)
        {
            VectorBatchVectorCwiseProductAccumulate(m_CellToInputWeightsTensor->GetTensor<float>(),
                                                    nCell, cellStateIn, nBatch, inputGateScratch);
        }
        Activation(inputGateScratch, inputGateScratch,
                   TensorInfo({nCell, nBatch}, DataType::Float32),
                   ActivationFunction::Sigmoid, 0, 0);
    }

    // For each batch and cell: update the forget gate.
    if (usePeephole)
    {
        VectorBatchVectorCwiseProductAccumulate(m_CellToForgetWeightsTensor->GetTensor<float>(), nCell,
                                                cellStateIn, nBatch, forgetGateScratch);
    }
    Activation(forgetGateScratch, forgetGateScratch,
               TensorInfo({nCell, nBatch}, DataType::Float32),
               ActivationFunction::Sigmoid, 0, 0);

    // For each batch and cell: update the cell.
    VectorVectorCwiseProduct(forgetGateScratch, cellStateIn, nBatch * nCell, cellStateOut);
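    // cellStateOut now holds f (*) c; the input contribution i (*) g(cellInput)
    // is accumulated into it below, after the cell-input activation is applied.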

    ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
    float a = 0;
    float b = 0;
    SetActivationParameters(m_Data.m_Parameters.m_ActivationFunc, armnnActivationFunc, a, b);

    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        Activation(cellScratch, cellScratch,
                   TensorInfo({nCell, nBatch}, DataType::Float32),
                   armnnActivationFunc, a, b);
    }
    if (useCifg)
    {
        Sub1Vector(forgetGateScratch, nBatch * nCell, forgetGateScratch);
        VectorVectorCwiseProductAccumulate(cellScratch, forgetGateScratch, nBatch * nCell, cellStateOut);
    }
    else
    {
        VectorVectorCwiseProductAccumulate(cellScratch, inputGateScratch, nBatch * nCell, cellStateOut);
    }
    if (m_Data.m_Parameters.m_ClippingThresCell > 0.0)
    {
        ClipVector(cellStateOut, nBatch * nCell, m_Data.m_Parameters.m_ClippingThresCell, cellStateOut);
    }

    // For each batch and cell: update the output gate.
    if (usePeephole)
    {
        VectorBatchVectorCwiseProductAccumulate(m_CellToOutputWeightsTensor->GetTensor<float>(),
                                                nCell, cellStateOut, nBatch, outputGateScratch);
    }
    Activation(outputGateScratch, outputGateScratch,
               TensorInfo({nCell, nBatch}, DataType::Float32),
               ActivationFunction::Sigmoid, 0, 0);

    if (m_Data.m_Parameters.m_ActivationFunc > 0)
    {
        Activation(cellStateOut, cellScratch,
                   TensorInfo({nCell, nBatch}, DataType::Float32),
                   armnnActivationFunc, a, b);
    }
    VectorVectorCwiseProduct(outputGateScratch, cellScratch, nBatch * nCell, outputGateScratch);

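    // outputGateScratch now holds h' = o (*) g(c'). With projection enabled the
    // final output is clip(projectionWeights * h' + projectionBias, m_ClippingThresProj);
    // otherwise h' is copied straight to the output.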
    // For each batch: update the projection and output_state.
    if (m_Data.m_Parameters.m_ProjectionEnabled)
    {
        if (m_ProjectionBiasTensor)
        {
            VectorBatchVectorAssign(m_ProjectionBiasTensor->GetTensor<float>(),
                                    nOutput, nBatch, output);
        }
        else
        {
            // The projection below accumulates into 'output', so it must start
            // from zero when no projection bias is supplied (the Android
            // LSTM::Eval() original zero-initializes it here as well).
            std::fill_n(output, nBatch * nOutput, 0.0f);
        }
        MatrixBatchVectorMultiplyAccumulate(m_ProjectionWeightsTensor->GetTensor<float>(),
                                            nOutput, nCell, outputGateScratch, nBatch, output);

        if (m_Data.m_Parameters.m_ClippingThresProj > 0.0)
        {
            ClipVector(output, nBatch * nOutput, m_Data.m_Parameters.m_ClippingThresProj, output);
        }
    }
    else
    {
        // Without a projection layer nOutput == nCell, and the gated cell
        // output is the output.
        CopyVector(outputGateScratch, nBatch * nOutput, output);
    }

    CopyVector(output, nBatch * nOutput, outputStateOut);
}

} // namespace armnn