//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
6#include "Activation.hpp"
7#include "Lstm.hpp"
8#include "LstmUtils.hpp"
9
10namespace armnn
11{
12
13void LstmImpl(const LstmDescriptor& descriptor,
14 const TensorInfo& inputInfo,
15 const TensorInfo& outputInfo,
16 const TensorShape& inputToOutputWeightsShape,
17 const TensorShape& recurrentToOutputWeightsShape,
18 std::unique_ptr<Decoder<float>>& inputData,
19 std::unique_ptr<Decoder<float>>& outputStateIn,
20 std::unique_ptr<Decoder<float>>& cellStateIn,
21 std::unique_ptr<Encoder<float>>& outputStateOut,
22 std::unique_ptr<Encoder<float>>& cellStateOut,
23 std::unique_ptr<Encoder<float>>& output,
24 std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
25 std::unique_ptr<Decoder<float>>& outputDecoder,
26 std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
27 std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
28 std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
29 std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
30 std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
31 std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
32 std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
33 std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
34 std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
35 std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
36 std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
37 std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
38 std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
39 std::unique_ptr<Decoder<float>>& cellBiasTensor,
40 std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
41 std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
42 std::unique_ptr<Decoder<float>>& projectionBiasTensor,
43 std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
44 std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
45 std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
46 std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
47 std::unique_ptr<Encoder<float>>& inputGateScratch,
48 std::unique_ptr<Encoder<float>>& cellScratch,
49 std::unique_ptr<Encoder<float>>& forgetGateScratch,
50 std::unique_ptr<Encoder<float>>& outputGateScratch,
51 std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
52 std::unique_ptr<Decoder<float>>& cellScratchDecoder,
53 std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
54 std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
55 float layerNormEpsilon)
56{
57 // This is a porting of the LSTM::Eval() method in the Android code base
58 // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
59
60 const TensorShape& inputShape = inputInfo.GetShape();
61 const DataType& outputType = outputInfo.GetDataType();
62
63 const uint32_t nBatch = inputShape[0];
64 const uint32_t nInput = inputShape[1];
65
66 const uint32_t nCell = inputToOutputWeightsShape[0];
67 const uint32_t nOutput = recurrentToOutputWeightsShape[1];
68
69 const bool useCifg = descriptor.m_CifgEnabled;
70 const bool usePeephole = descriptor.m_PeepholeEnabled;
71 const bool useLayerNorm = descriptor.m_LayerNormEnabled;
72
73 if (!useLayerNorm)
74 {
75 // Initialize scratch buffers with bias.
76 if (!useCifg)
77 {
78 VectorBatchVectorAssign(*inputGateBiasTensor,
79 nCell, nBatch, *inputGateScratch);
80 }
81 VectorBatchVectorAssign(*forgetGateBiasTensor,
82 nCell, nBatch, *forgetGateScratch);
83 VectorBatchVectorAssign(*cellBiasTensor,
84 nCell, nBatch, *cellScratch);
85 VectorBatchVectorAssign(*outputGateBiasTensor,
86 nCell, nBatch, *outputGateScratch);
87 }
88 else
89 {
90 // Initialize scratch buffers with zeroes.
91 if (!useCifg)
92 {
93 ZeroVector(*inputGateScratch, nCell * nBatch);
94 }
95 ZeroVector(*forgetGateScratch, nCell * nBatch);
96 ZeroVector(*cellScratch , nCell * nBatch);
97 ZeroVector(*outputGateScratch, nCell * nBatch);
98 }
99
100 // For each batch and cell: compute input_weight * input.
101 if (!useCifg)
102 {
103 MatrixBatchVectorMultiplyAccumulate(*inputToInputWeightsTensor,
104 nCell, nInput, *inputData, nBatch, *inputGateScratch);
105 }
106 MatrixBatchVectorMultiplyAccumulate(*inputToForgetWeightsTensor,
107 nCell, nInput, *inputData, nBatch, *forgetGateScratch);
108 MatrixBatchVectorMultiplyAccumulate(*inputToCellWeightsTensor,
109 nCell, nInput, *inputData, nBatch, *cellScratch);
110 MatrixBatchVectorMultiplyAccumulate(*inputToOutputWeightsTensor,
111 nCell, nInput, *inputData, nBatch, *outputGateScratch);
112
113 // For each batch and cell: compute recurrent_weight * output_state.
114 if (!useCifg)
115 {
116 MatrixBatchVectorMultiplyAccumulate(*recurrentToInputWeightsTensor,
117 nCell, nOutput, *outputStateIn, nBatch, *inputGateScratch);
118 }
119 MatrixBatchVectorMultiplyAccumulate(*recurrentToForgetWeightsTensor,
120 nCell, nOutput, *outputStateIn, nBatch, *forgetGateScratch);
121 MatrixBatchVectorMultiplyAccumulate(*recurrentToCellWeightsTensor,
122 nCell, nOutput, *outputStateIn, nBatch, *cellScratch);
123 MatrixBatchVectorMultiplyAccumulate(*recurrentToOutputWeightsTensor,
124 nCell, nOutput, *outputStateIn, nBatch, *outputGateScratch);
125
126 // For each batch and cell: update input gate.
127 if (!useCifg)
128 {
129 if (usePeephole)
130 {
131 VectorBatchVectorCwiseProductAccumulate(*cellToInputWeightsTensor,
132 nCell, *cellStateIn, nBatch, *inputGateScratch);
133 }
134 if (useLayerNorm)
135 {
136 MeanStddevNormalization(*inputGateScratchDecoder,
137 *inputGateScratch, nCell, nBatch, layerNormEpsilon);
138 VectorBatchVectorCwiseProduct(*inputLayerNormWeights,
139 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
140 VectorBatchVectorAdd(*inputGateBiasTensor,
141 nCell, *inputGateScratchDecoder, nBatch, *inputGateScratch);
142 }
143 Activation(*inputGateScratchDecoder, *inputGateScratch,
144 TensorInfo({nCell, nBatch}, outputType),
145 ActivationFunction::Sigmoid, 0, 0);
146 }
147
148 // For each batch and cell: update forget gate.
149 if (usePeephole)
150 {
151 VectorBatchVectorCwiseProductAccumulate(*cellToForgetWeightsTensor, nCell,
152 *cellStateIn, nBatch, *forgetGateScratch);
153 }
154 if (useLayerNorm)
155 {
156 MeanStddevNormalization(*forgetGateScratchDecoder,
157 *forgetGateScratch, nCell, nBatch, layerNormEpsilon);
158 VectorBatchVectorCwiseProduct(*forgetLayerNormWeights,
159 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
160 VectorBatchVectorAdd(*forgetGateBiasTensor,
161 nCell, *forgetGateScratchDecoder, nBatch, *forgetGateScratch);
162 }
163 Activation(*forgetGateScratchDecoder, *forgetGateScratch,
164 TensorInfo({nCell, nBatch}, outputType),
165 ActivationFunction::Sigmoid, 0, 0);
166
167 // For each batch and cell: update the cell.
168 if (useLayerNorm)
169 {
170 MeanStddevNormalization(*cellScratchDecoder,
171 *cellScratch, nCell, nBatch, layerNormEpsilon);
172 VectorBatchVectorCwiseProduct(*cellLayerNormWeights,
173 nCell, *cellScratchDecoder, nBatch, *cellScratch);
174 VectorBatchVectorAdd(*cellBiasTensor,
175 nCell, *cellScratchDecoder, nBatch, *cellScratch);
176 }
177
178 VectorVectorCwiseProduct(*forgetGateScratchDecoder, *cellStateIn, nBatch * nCell, *cellStateOut);
179
180 ActivationFunction armnnActivationFunc = ActivationFunction::Sigmoid;
181 float a = 0;
182 float b = 0;
183 SetActivationParameters(descriptor.m_ActivationFunc, armnnActivationFunc, a, b);
184
185 if (descriptor.m_ActivationFunc > 0)
186 {
187 Activation(*cellScratchDecoder, *cellScratch,
188 TensorInfo({nCell, nBatch}, outputType),
189 armnnActivationFunc, a, b);
190 }
191 if (useCifg)
192 {
193 Sub1Vector(*forgetGateScratchDecoder, nBatch * nCell, *forgetGateScratch);
194 VectorVectorCwiseProductAccumulate(
195 *cellScratchDecoder, *forgetGateScratchDecoder, nBatch * nCell, *cellStateOut);
196 }
197 else
198 {
199 VectorVectorCwiseProductAccumulate(
200 *cellScratchDecoder, *inputGateScratchDecoder, nBatch * nCell, *cellStateOut);
201 }
202 if (descriptor.m_ClippingThresCell > 0.0)
203 {
204 ClipVector(*cellStateOutDecoder, nBatch * nCell, descriptor.m_ClippingThresCell, *cellStateOut);
205 }
206
207 // For each batch and cell: update the output gate.
208 if (usePeephole)
209 {
210 VectorBatchVectorCwiseProductAccumulate(*cellToOutputWeightsTensor,
211 nCell, *cellStateOutDecoder, nBatch, *outputGateScratch);
212 }
213 if (useLayerNorm)
214 {
215 MeanStddevNormalization(*outputGateScratchDecoder,
216 *outputGateScratch, nCell, nBatch, layerNormEpsilon);
217 VectorBatchVectorCwiseProduct(*outputLayerNormWeights,
218 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
219 VectorBatchVectorAdd(*outputGateBiasTensor,
220 nCell, *outputGateScratchDecoder, nBatch, *outputGateScratch);
221 }
222 Activation(*outputGateScratchDecoder, *outputGateScratch,
223 TensorInfo({nCell, nBatch}, outputType),
224 ActivationFunction::Sigmoid, 0, 0);
225
226 if (descriptor.m_ActivationFunc > 0)
227 {
228 Activation(*cellStateOutDecoder, *cellScratch,
229 TensorInfo({nCell, nBatch}, outputType),
230 armnnActivationFunc, a, b);
231 }
232
233 VectorVectorCwiseProduct(*outputGateScratchDecoder, *cellScratchDecoder, nBatch * nCell, *outputGateScratch);
234
235 // For each batch: update the projection and output_state.
236 if (descriptor.m_ProjectionEnabled)
237 {
238 if (projectionBiasTensor)
239 {
240 VectorBatchVectorAssign(*projectionBiasTensor,
241 nOutput, nBatch, *output);
242 }
243 MatrixBatchVectorMultiplyAccumulate(*projectionWeightsTensor,
244 nOutput, nCell, *outputGateScratchDecoder, nBatch, *output);
245
246 if (descriptor.m_ClippingThresProj > 0.0)
247 {
248 ClipVector(*outputDecoder, nBatch * nOutput, descriptor.m_ClippingThresProj, *output);
249 }
250 }
251 else
252 {
253 CopyVector(*outputGateScratchDecoder, nBatch * nOutput, *output);
254 }
255
256 CopyVector(*outputDecoder, nBatch * nOutput, *outputStateOut);
257}
258
259} //namespace armnn