//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "RefLstmWorkload.hpp"
7#include "Activation.hpp"
8#include "Encoders.hpp"
9#include "Decoders.hpp"
Narumol Prangnawarate5339e72021-07-28 17:33:28 +010010#include "Lstm.hpp"
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010011#include "LstmUtils.hpp"
12#include "RefWorkloadUtils.hpp"
13
14namespace armnn
15{
16
17RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
Finn Williams73c547d2022-02-15 20:47:34 +000018 : RefBaseWorkload<LstmQueueDescriptor>(descriptor, info)
James Conroy1f58f032021-04-27 17:13:27 +010019 , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
20 , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
21 , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
22 , m_InputToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToOutputWeights))
23 , m_RecurrentToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToInputWeights))
24 , m_RecurrentToForgetWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToForgetWeights))
25 , m_RecurrentToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_RecurrentToCellWeights))
26 , m_RecurrentToOutputWeightsTensor(AssignScopedTensorHandle(descriptor.m_RecurrentToOutputWeights))
27 , m_CellToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToInputWeights))
28 , m_CellToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToForgetWeights))
29 , m_CellToOutputWeightsTensor (AssignScopedTensorHandle(descriptor.m_CellToOutputWeights))
30 , m_InputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_InputGateBias))
31 , m_ForgetGateBiasTensor (AssignScopedTensorHandle(descriptor.m_ForgetGateBias))
32 , m_CellBiasTensor (AssignScopedTensorHandle(descriptor.m_CellBias))
33 , m_OutputGateBiasTensor (AssignScopedTensorHandle(descriptor.m_OutputGateBias))
34 , m_ProjectionWeightsTensor (AssignScopedTensorHandle(descriptor.m_ProjectionWeights))
35 , m_ProjectionBiasTensor (AssignScopedTensorHandle(descriptor.m_ProjectionBias))
36 , m_InputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_InputLayerNormWeights))
37 , m_ForgetLayerNormWeights (AssignScopedTensorHandle(descriptor.m_ForgetLayerNormWeights))
38 , m_CellLayerNormWeights (AssignScopedTensorHandle(descriptor.m_CellLayerNormWeights))
39 , m_OutputLayerNormWeights (AssignScopedTensorHandle(descriptor.m_OutputLayerNormWeights))
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010040{}
41
42void RefLstmWorkload::Execute() const
43{
Finn Williamsb8181f72021-04-07 10:23:21 +010044 Execute(m_Data.m_Inputs, m_Data.m_Outputs);
45}
46
Matthew Sloyan21a6a1a2022-06-30 17:13:04 +010047void RefLstmWorkload::ExecuteAsync(ExecutionData& executionData)
Finn Williamsb8181f72021-04-07 10:23:21 +010048{
Matthew Sloyan21a6a1a2022-06-30 17:13:04 +010049 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
50 Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
Finn Williamsb8181f72021-04-07 10:23:21 +010051}
52
53void RefLstmWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
54{
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010055 // This is a porting of the LSTM::Eval() method in the Android code base
56 // Refer to: android/frameworks/ml/nn/common/operations/LSTM.cpp
57
Finn Williamsb8181f72021-04-07 10:23:21 +010058 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
59 const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010060
61 const TensorShape& inputShape = inputInfo.GetShape();
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010062
Finn Williamsb8181f72021-04-07 10:23:21 +010063 std::unique_ptr<Encoder<float>> outputStateOut = MakeEncoder<float>(outputInfo, outputs[1]->Map());
64 std::unique_ptr<Encoder<float>> cellStateOut = MakeEncoder<float>(outputInfo, outputs[2]->Map());
65 std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(outputInfo, outputs[3]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010066
Finn Williamsb8181f72021-04-07 10:23:21 +010067 std::unique_ptr<Decoder<float>> cellStateOutDecoder = MakeDecoder<float>(outputInfo, outputs[2]->Map());
68 std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(outputInfo, outputs[3]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010069
Finn Williamsb8181f72021-04-07 10:23:21 +010070 std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(inputInfo, inputs[0]->Map());
71 std::unique_ptr<Decoder<float>> outputStateIn = MakeDecoder<float>(inputInfo, inputs[1]->Map());
72 std::unique_ptr<Decoder<float>> cellStateIn = MakeDecoder<float>(inputInfo, inputs[2]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010073
74 const uint32_t nBatch = inputShape[0];
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010075 const uint32_t nCell = m_InputToOutputWeightsTensor->GetShape()[0];
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010076
Jan Eilers38e05bd2019-06-26 13:10:09 +010077 const bool useCifg = m_Data.m_Parameters.m_CifgEnabled;
78 const bool usePeephole = m_Data.m_Parameters.m_PeepholeEnabled;
79 const bool useLayerNorm = m_Data.m_Parameters.m_LayerNormEnabled;
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010080
81 // Index the scratch buffers pointers to the global scratch buffer.
Finn Williamsb8181f72021-04-07 10:23:21 +010082 std::unique_ptr<Encoder<float>> inputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
83 std::unique_ptr<Encoder<float>> cellScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
84 std::unique_ptr<Encoder<float>> forgetGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
85 std::unique_ptr<Encoder<float>> outputGateScratch = MakeEncoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010086
87 std::unique_ptr<Decoder<float>> inputGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010088 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010089 std::unique_ptr<Decoder<float>> cellScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010090 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010091 std::unique_ptr<Decoder<float>> forgetGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010092 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010093 std::unique_ptr<Decoder<float>> outputGateScratchDecoder =
Finn Williamsb8181f72021-04-07 10:23:21 +010094 MakeDecoder<float>(outputInfo, outputs[0]->Map());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +010095
96 if (useCifg)
97 {
98 *cellScratch += (0 * nCell * nBatch);
99 *forgetGateScratch += (1 * nCell * nBatch);
100 *outputGateScratch += (2 * nCell * nBatch);
101
102 *cellScratchDecoder += (0 * nCell * nBatch);
103 *forgetGateScratchDecoder += (1 * nCell * nBatch);
104 *outputGateScratchDecoder += (2 * nCell * nBatch);
105 }
106 else
107 {
108 *inputGateScratch += (0 * nCell * nBatch);
109 *cellScratch += (1 * nCell * nBatch);
110 *forgetGateScratch += (2 * nCell * nBatch);
111 *outputGateScratch += (3 * nCell * nBatch);
112
113 *inputGateScratchDecoder += (0 * nCell * nBatch);
114 *cellScratchDecoder += (1 * nCell * nBatch);
115 *forgetGateScratchDecoder += (2 * nCell * nBatch);
116 *outputGateScratchDecoder += (3 * nCell * nBatch);
117 }
118
119 std::unique_ptr<Decoder<float>> inputToInputWeightsTensor;
120 std::unique_ptr<Decoder<float>> inputToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000121 m_InputToForgetWeightsTensor->GetTensorInfo(), m_InputToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100122 std::unique_ptr<Decoder<float>> inputToCellWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000123 m_InputToCellWeightsTensor->GetTensorInfo(), m_InputToCellWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100124 std::unique_ptr<Decoder<float>> inputToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000125 m_InputToOutputWeightsTensor->GetTensorInfo(), m_InputToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100126
127 std::unique_ptr<Decoder<float>> recurrentToInputWeightsTensor;
128 std::unique_ptr<Decoder<float>> recurrentToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000129 m_RecurrentToForgetWeightsTensor->GetTensorInfo(), m_RecurrentToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100130 std::unique_ptr<Decoder<float>> recurrentToCellWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000131 m_RecurrentToCellWeightsTensor->GetTensorInfo(), m_RecurrentToCellWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100132 std::unique_ptr<Decoder<float>> recurrentToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000133 m_RecurrentToOutputWeightsTensor->GetTensorInfo(), m_RecurrentToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100134
135 std::unique_ptr<Decoder<float>> inputGateBiasTensor;
136 std::unique_ptr<Decoder<float>> forgetGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000137 m_ForgetGateBiasTensor->GetTensorInfo(), m_ForgetGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100138 std::unique_ptr<Decoder<float>> cellBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000139 m_CellBiasTensor->GetTensorInfo(), m_CellBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100140 std::unique_ptr<Decoder<float>> outputGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000141 m_OutputGateBiasTensor->GetTensorInfo(), m_OutputGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100142
143 std::unique_ptr<Decoder<float>> cellToInputWeightsTensor;
144 std::unique_ptr<Decoder<float>> cellToForgetWeightsTensor;
145 std::unique_ptr<Decoder<float>> cellToOutputWeightsTensor;
146
147 std::unique_ptr<Decoder<float>> projectionWeightsTensor;
148 std::unique_ptr<Decoder<float>> projectionBiasTensor;
149
Jan Eilers38e05bd2019-06-26 13:10:09 +0100150 std::unique_ptr<Decoder<float>> inputLayerNormWeights;
151 std::unique_ptr<Decoder<float>> forgetLayerNormWeights;
152 std::unique_ptr<Decoder<float>> cellLayerNormWeights;
153 std::unique_ptr<Decoder<float>> outputLayerNormWeights;
154
Narumol Prangnawarate5339e72021-07-28 17:33:28 +0100155 const TensorShape& inputToOutputWeightsShape = m_InputToOutputWeightsTensor->GetShape();
156 const TensorShape& recurrentToOutputWeightsShape = m_RecurrentToOutputWeightsTensor->GetShape();
157
Jan Eilers38e05bd2019-06-26 13:10:09 +0100158 if (useLayerNorm)
159 {
160 if (!useCifg)
161 {
162 inputLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000163 m_InputLayerNormWeights->GetTensorInfo(), m_InputLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100164 }
165 forgetLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000166 m_ForgetLayerNormWeights->GetTensorInfo(), m_ForgetLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100167 cellLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000168 m_CellLayerNormWeights->GetTensorInfo(), m_CellLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100169 outputLayerNormWeights = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000170 m_OutputLayerNormWeights->GetTensorInfo(), m_OutputLayerNormWeights->GetConstTensor<void>());
Jan Eilers38e05bd2019-06-26 13:10:09 +0100171 }
172
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100173 if (!useCifg)
174 {
175 inputToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000176 m_InputToInputWeightsTensor->GetTensorInfo(), m_InputToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100177 inputGateBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000178 m_InputGateBiasTensor->GetTensorInfo(), m_InputGateBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100179 recurrentToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000180 m_RecurrentToInputWeightsTensor->GetTensorInfo(), m_RecurrentToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100181 }
182
183 if (usePeephole)
184 {
185 cellToForgetWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000186 m_CellToForgetWeightsTensor->GetTensorInfo(), m_CellToForgetWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100187 cellToOutputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000188 m_CellToOutputWeightsTensor->GetTensorInfo(), m_CellToOutputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100189 }
190
191 if (!useCifg && usePeephole)
192 {
193 cellToInputWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000194 m_CellToInputWeightsTensor->GetTensorInfo(), m_CellToInputWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100195 }
196
197 if (m_Data.m_Parameters.m_ProjectionEnabled)
198 {
199 projectionWeightsTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000200 m_ProjectionWeightsTensor->GetTensorInfo(), m_ProjectionWeightsTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100201 if (m_ProjectionBiasTensor)
202 {
203 projectionBiasTensor = MakeDecoder<float>(
Finn Williams4422cec2021-03-22 17:51:06 +0000204 m_ProjectionBiasTensor->GetTensorInfo(), m_ProjectionBiasTensor->GetConstTensor<void>());
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100205 }
206 }
207
Narumol Prangnawarate5339e72021-07-28 17:33:28 +0100208 LstmImpl(m_Data.m_Parameters,
209 inputInfo,
210 outputInfo,
211 inputToOutputWeightsShape,
212 recurrentToOutputWeightsShape,
213 inputData,
214 outputStateIn,
215 cellStateIn,
216 outputStateOut,
217 cellStateOut,
218 output,
219 cellStateOutDecoder,
220 outputDecoder,
221 inputToInputWeightsTensor,
222 inputToForgetWeightsTensor,
223 inputToCellWeightsTensor,
224 inputToOutputWeightsTensor,
225 recurrentToInputWeightsTensor,
226 recurrentToForgetWeightsTensor,
227 recurrentToCellWeightsTensor,
228 recurrentToOutputWeightsTensor,
229 cellToInputWeightsTensor,
230 cellToForgetWeightsTensor,
231 cellToOutputWeightsTensor,
232 inputGateBiasTensor,
233 forgetGateBiasTensor,
234 cellBiasTensor,
235 outputGateBiasTensor,
236 projectionWeightsTensor,
237 projectionBiasTensor,
238 inputLayerNormWeights,
239 forgetLayerNormWeights,
240 cellLayerNormWeights,
241 outputLayerNormWeights,
242 inputGateScratch,
243 cellScratch,
244 forgetGateScratch,
245 outputGateScratch,
246 inputGateScratchDecoder,
247 cellScratchDecoder,
248 forgetGateScratchDecoder,
249 outputGateScratchDecoder,
250 m_LayerNormEpsilon);
Nattapat Chaimanowongeb2b3292019-05-07 12:02:30 +0100251}
252
253} //namespace armnn