blob: 1ff6f50ed53bdedabe13fe6f043d45d41d4a2672 [file] [log] [blame]
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "RefLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
Narumol Prangnawarate5339e72021-07-28 17:33:28 +010010#include "Lstm.hpp"
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010011#include "LstmUtils.hpp"
12#include "RefWorkloadUtils.hpp"
13
14namespace armnn
15{
16
17RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
18 : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
James Conroy1f58f032021-04-27 17:13:27 +010019 , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20 , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21 , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22 , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23 , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24 , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25 , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26 , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27 , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28 , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29 , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30 , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31 , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32 , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
33 , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34 , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35 , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36 , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37 , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38 , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39 , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010040{}
41
42void RefLstmWorkload::Execute() const
43{
Finn Williamsb8181f72021-04-07 10:23:21 +010044 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
45}
46
47void RefLstmWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
48{
49 Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
50}
51
52void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
53{
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010054 // This is a porting of the LSTM::Eval() method in the Android code base
55 // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
56
Finn Williamsb8181f72021-04-07 10:23:21 +010057 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
58 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010059
60 const TensorShape& inputShape = inputInfo.GetShape();
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010061
Finn Williamsb8181f72021-04-07 10:23:21 +010062 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
63 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
64 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010065
Finn Williamsb8181f72021-04-07 10:23:21 +010066 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
67 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010068
Finn Williamsb8181f72021-04-07 10:23:21 +010069 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
70 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
71 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010072
73 const uint32_t nBatch = inputShape[0];
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010074 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010075
Jan Eilers38e05bd2019-06-26 13:10:09 +010076 const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
77 const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
78 const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010079
80 // Index the scratch buffers pointers to the global scratch buffer.
Finn Williamsb8181f72021-04-07 10:23:21 +010081 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
82 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
83 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
84 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010085
86 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010087 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010088 std::unique_ptr<Decoder<float>> cellScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010089 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010090 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010091 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010092 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010093 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010094
95 if (useCifg)
96 {
97 *cellScratch += (0 * nCell * nBatch);
98 *forgetGateScratch += (1 * nCell * nBatch);
99 *outputGateScratch += (2 * nCell * nBatch);
100
101 *cellScratchDecoder += (0 * nCell * nBatch);
102 *forgetGateScratchDecoder += (1 * nCell * nBatch);
103 *outputGateScratchDecoder += (2 * nCell * nBatch);
104 }
105 else
106 {
107 *inputGateScratch += (0 * nCell * nBatch);
108 *cellScratch += (1 * nCell * nBatch);
109 *forgetGateScratch += (2 * nCell * nBatch);
110 *outputGateScratch += (3 * nCell * nBatch);
111
112 *inputGateScratchDecoder += (0 * nCell * nBatch);
113 *cellScratchDecoder += (1 * nCell * nBatch);
114 *forgetGateScratchDecoder += (2 * nCell * nBatch);
115 *outputGateScratchDecoder += (3 * nCell * nBatch);
116 }
117
118 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
119 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000120 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100121 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000122 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100123 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000124 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100125
126 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
127 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000128 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100129 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000130 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100131 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000132 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100133
134 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
135 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000136 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100137 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000138 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100139 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000140 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100141
142 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
143 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
144 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
145
146 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
147 std::unique_ptr<Decoder<float>> projectionBiasTensor;
148
Jan Eilers38e05bd2019-06-26 13:10:09 +0100149 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
150 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
151 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
152 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
153
Narumol Prangnawarate5339e72021-07-28 17:33:28 +0100154 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
155 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
156
Jan Eilers38e05bd2019-06-26 13:10:09 +0100157 if (useLayerNorm)
158 {
159 if (!useCifg)
160 {
161 inputLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000162 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100163 }
164 forgetLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000165 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100166 cellLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000167 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100168 outputLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000169 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100170 }
171
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100172 if (!useCifg)
173 {
174 inputToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000175 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100176 inputGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000177 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100178 recurrentToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000179 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100180 }
181
182 if (usePeephole)
183 {
184 cellToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000185 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100186 cellToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000187 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100188 }
189
190 if (!useCifg && usePeephole)
191 {
192 cellToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000193 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100194 }
195
196 if (m_Data.m_Parameters.m_ProjectionEnabled)
197 {
198 projectionWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000199 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100200 if (m_ProjectionBiasTensor)
201 {
202 projectionBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000203 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100204 }
205 }
206
Narumol Prangnawarate5339e72021-07-28 17:33:28 +0100207 LstmImpl(m_Data.m_Parameters,
208 inputInfo,
209 outputInfo,
210 inputToOutputWeightsShape,
211 recurrentToOutputWeightsShape,
212 inputData,
213 outputStateIn,
214 cellStateIn,
215 outputStateOut,
216 cellStateOut,
217 output,
218 cellStateOutDecoder,
219 outputDecoder,
220 inputToInputWeightsTensor,
221 inputToForgetWeightsTensor,
222 inputToCellWeightsTensor,
223 inputToOutputWeightsTensor,
224 recurrentToInputWeightsTensor,
225 recurrentToForgetWeightsTensor,
226 recurrentToCellWeightsTensor,
227 recurrentToOutputWeightsTensor,
228 cellToInputWeightsTensor,
229 cellToForgetWeightsTensor,
230 cellToOutputWeightsTensor,
231 inputGateBiasTensor,
232 forgetGateBiasTensor,
233 cellBiasTensor,
234 outputGateBiasTensor,
235 projectionWeightsTensor,
236 projectionBiasTensor,
237 inputLayerNormWeights,
238 forgetLayerNormWeights,
239 cellLayerNormWeights,
240 outputLayerNormWeights,
241 inputGateScratch,
242 cellScratch,
243 forgetGateScratch,
244 outputGateScratch,
245 inputGateScratchDecoder,
246 cellScratchDecoder,
247 forgetGateScratchDecoder,
248 outputGateScratchDecoder,
249 m_LayerNormEpsilon);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100250}
251
252} //namespace armnn