blob: 075aa80419ec141abef72b77abbe67d39a5a9e5c [file] [log] [blame]
//
// Copyright © 2019,2021-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
Narumol Prangnawarate5339e72021-07-28 17:33:28 +010010#include "Lstm.hpp"
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010011#include "LstmUtils.hpp"
12#include "RefWorkloadUtils.hpp"
13
14namespace armnn
15{
16
17RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
Finn Williams73c547d2022-02-15 20:47:34 +000018 : RefBaseWorkload<LstmQueueDescriptor>(descriptor, info)
James Conroy1f58f032021-04-27 17:13:27 +010019 , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20 , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21 , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22 , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23 , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24 , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25 , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26 , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27 , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28 , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29 , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30 , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31 , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32 , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
33 , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34 , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35 , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36 , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37 , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38 , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39 , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010040{}
41
42void RefLstmWorkload::Execute() const
43{
Finn Williamsb8181f72021-04-07 10:23:21 +010044 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
45}
46
Matthew Sloyan2d213a72022-06-30 17:13:04 +010047void RefLstmWorkload::ExecuteAsync(ExecutionData& executionData)
Finn Williamsb8181f72021-04-07 10:23:21 +010048{
Matthew Sloyan2d213a72022-06-30 17:13:04 +010049 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
50 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Finn Williamsb8181f72021-04-07 10:23:21 +010051}
52
53void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
54{
Mike Kelly7cbe7812023-07-25 17:37:33 +010055 ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefLstmWorkload_Execute");
56
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010057 // This is a porting of the LSTM::Eval() method in the Android code base
58 // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
59
Finn Williamsb8181f72021-04-07 10:23:21 +010060 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
61 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010062
63 const TensorShape& inputShape = inputInfo.GetShape();
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010064
Finn Williamsb8181f72021-04-07 10:23:21 +010065 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
66 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
67 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010068
Finn Williamsb8181f72021-04-07 10:23:21 +010069 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
70 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010071
Finn Williamsb8181f72021-04-07 10:23:21 +010072 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
73 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
74 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010075
76 const uint32_t nBatch = inputShape[0];
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010077 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010078
Jan Eilers38e05bd2019-06-26 13:10:09 +010079 const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
80 const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
81 const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010082
83 // Index the scratch buffers pointers to the global scratch buffer.
Finn Williamsb8181f72021-04-07 10:23:21 +010084 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
85 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
86 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
87 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010088
89 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010090 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010091 std::unique_ptr<Decoder<float>> cellScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010092 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010093 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010094 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010095 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010096 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010097
98 if (useCifg)
99 {
100 *cellScratch += (0 * nCell * nBatch);
101 *forgetGateScratch += (1 * nCell * nBatch);
102 *outputGateScratch += (2 * nCell * nBatch);
103
104 *cellScratchDecoder += (0 * nCell * nBatch);
105 *forgetGateScratchDecoder += (1 * nCell * nBatch);
106 *outputGateScratchDecoder += (2 * nCell * nBatch);
107 }
108 else
109 {
110 *inputGateScratch += (0 * nCell * nBatch);
111 *cellScratch += (1 * nCell * nBatch);
112 *forgetGateScratch += (2 * nCell * nBatch);
113 *outputGateScratch += (3 * nCell * nBatch);
114
115 *inputGateScratchDecoder += (0 * nCell * nBatch);
116 *cellScratchDecoder += (1 * nCell * nBatch);
117 *forgetGateScratchDecoder += (2 * nCell * nBatch);
118 *outputGateScratchDecoder += (3 * nCell * nBatch);
119 }
120
121 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
122 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000123 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100124 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000125 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100126 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000127 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100128
129 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
130 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000131 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100132 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000133 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100134 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000135 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100136
137 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
138 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000139 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100140 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000141 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100142 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000143 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100144
145 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
146 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
147 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
148
149 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
150 std::unique_ptr<Decoder<float>> projectionBiasTensor;
151
Jan Eilers38e05bd2019-06-26 13:10:09 +0100152 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
153 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
154 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
155 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
156
Narumol Prangnawarate5339e72021-07-28 17:33:28 +0100157 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
158 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
159
Jan Eilers38e05bd2019-06-26 13:10:09 +0100160 if (useLayerNorm)
161 {
162 if (!useCifg)
163 {
164 inputLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000165 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100166 }
167 forgetLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000168 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100169 cellLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000170 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100171 outputLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000172 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100173 }
174
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100175 if (!useCifg)
176 {
177 inputToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000178 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100179 inputGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000180 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100181 recurrentToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000182 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100183 }
184
185 if (usePeephole)
186 {
187 cellToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000188 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100189 cellToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000190 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100191 }
192
193 if (!useCifg && usePeephole)
194 {
195 cellToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000196 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100197 }
198
199 if (m_Data.m_Parameters.m_ProjectionEnabled)
200 {
201 projectionWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000202 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100203 if (m_ProjectionBiasTensor)
204 {
205 projectionBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000206 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100207 }
208 }
209
Narumol Prangnawarate5339e72021-07-28 17:33:28 +0100210 LstmImpl(m_Data.m_Parameters,
211 inputInfo,
212 outputInfo,
213 inputToOutputWeightsShape,
214 recurrentToOutputWeightsShape,
215 inputData,
216 outputStateIn,
217 cellStateIn,
218 outputStateOut,
219 cellStateOut,
220 output,
221 cellStateOutDecoder,
222 outputDecoder,
223 inputToInputWeightsTensor,
224 inputToForgetWeightsTensor,
225 inputToCellWeightsTensor,
226 inputToOutputWeightsTensor,
227 recurrentToInputWeightsTensor,
228 recurrentToForgetWeightsTensor,
229 recurrentToCellWeightsTensor,
230 recurrentToOutputWeightsTensor,
231 cellToInputWeightsTensor,
232 cellToForgetWeightsTensor,
233 cellToOutputWeightsTensor,
234 inputGateBiasTensor,
235 forgetGateBiasTensor,
236 cellBiasTensor,
237 outputGateBiasTensor,
238 projectionWeightsTensor,
239 projectionBiasTensor,
240 inputLayerNormWeights,
241 forgetLayerNormWeights,
242 cellLayerNormWeights,
243 outputLayerNormWeights,
244 inputGateScratch,
245 cellScratch,
246 forgetGateScratch,
247 outputGateScratch,
248 inputGateScratchDecoder,
249 cellScratchDecoder,
250 forgetGateScratchDecoder,
251 outputGateScratchDecoder,
252 m_LayerNormEpsilon);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100253}
254
255} //namespace armnn