blob: 7ece05f5ffeccb382eea6358404c98b76eed1d3a [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "ClQLstmWorkload.hpp"
7#include "ClWorkloadUtils.hpp"
8
9#include "aclCommon/ArmComputeTensorUtils.hpp"
10
11#include "cl/ClTensorHandle.hpp"
12
13namespace armnn
14{
15using namespace armcomputetensorutils;
16
Teresa Charlinbe727be2020-09-25 15:08:21 +010017ClQLstmWorkload::ClQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info)
Ryan OShea2323af42020-05-13 16:36:19 +010018 : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
19{
20 arm_compute::LSTMParams<arm_compute::ICLTensor> qLstmParams;
21
22 // Mandatory params
23 m_InputToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
24 BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
25
26 m_InputToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
27 BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
28
29 m_InputToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
30 BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
31
32 m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
33 BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
34
35 m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
36 BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
37
38 m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
39 BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
40
41 m_ForgetGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
42 BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
43
44 m_CellBiasTensor = std::make_unique<arm_compute::CLTensor>();
45 BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
46
47 m_OutputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
48 BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
49
50 // Create tensors for optional params if they are enabled
51 if (m_Data.m_Parameters.m_PeepholeEnabled)
52 {
53 m_CellToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
54
55 if (!m_Data.m_Parameters.m_CifgEnabled)
56 {
57 // In ACL this is categorised as a CIFG param and not a Peephole param
58 BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
59 }
60
61 m_CellToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
62 BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());
63
64 m_CellToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
65 BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());
66
67 // Set Peephole params
68 qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(),
69 m_CellToOutputWeightsTensor.get());
70 }
71
72 if (m_Data.m_Parameters.m_ProjectionEnabled)
73 {
74 m_ProjectionWeightsTensor = std::make_unique<arm_compute::CLTensor>();
75 BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());
76
77 m_ProjectionBiasTensor = std::make_unique<arm_compute::CLTensor>();
78 if (m_Data.m_ProjectionBias != nullptr)
79 {
80 BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
81 }
82
83 // Set projection params
84 qLstmParams.set_projection_params(
85 m_ProjectionWeightsTensor.get(),
86 m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
87 }
88
89 if (m_Data.m_Parameters.m_LayerNormEnabled)
90 {
91 m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
92
93 if (!m_Data.m_Parameters.m_CifgEnabled)
94 {
95 BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
96 }
97
98 m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
99 BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
100
101 m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
102 BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
103
104 m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
105 BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
106
Teresa Charlinbe727be2020-09-25 15:08:21 +0100107 // Set layer norm params
Ryan OShea2323af42020-05-13 16:36:19 +0100108 qLstmParams.set_layer_normalization_params(
109 m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
110 m_ForgetLayerNormWeightsTensor.get(),
111 m_CellLayerNormWeightsTensor.get(),
112 m_OutputLayerNormWeightsTensor.get());
113 }
114
115 if (!m_Data.m_Parameters.m_CifgEnabled)
116 {
117 m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
118 BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
119
120 m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
121 BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
122
123 m_InputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
124 BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
125
Teresa Charlinbe727be2020-09-25 15:08:21 +0100126 // Set CIFG params
Ryan OShea2323af42020-05-13 16:36:19 +0100127 qLstmParams.set_cifg_params(
128 m_InputToInputWeightsTensor.get(),
129 m_RecurrentToInputWeightsTensor.get(),
Teresa Charlinbe727be2020-09-25 15:08:21 +0100130 m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
Ryan OShea2323af42020-05-13 16:36:19 +0100131 m_InputGateBiasTensor.get());
132 }
133
134 // Input/Output tensors
135 const arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
Nikhil Rajf3c27182020-09-24 17:58:34 +0100136 arm_compute::ICLTensor& outputStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
137 arm_compute::ICLTensor& cellStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
Ryan OShea2323af42020-05-13 16:36:19 +0100138
139 arm_compute::ICLTensor& outputStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
140 arm_compute::ICLTensor& cellStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
141 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
142
143 // Set scalar descriptor params
144 qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip);
145 qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip);
146 qLstmParams.set_hidden_state_params(m_Data.m_Parameters.m_HiddenStateZeroPoint,
147 m_Data.m_Parameters.m_HiddenStateScale);
148 qLstmParams.set_matmul_scale_params(m_Data.m_Parameters.m_InputIntermediateScale,
149 m_Data.m_Parameters.m_ForgetIntermediateScale,
150 m_Data.m_Parameters.m_CellIntermediateScale,
151 m_Data.m_Parameters.m_OutputIntermediateScale);
152
Teresa Charlinbe727be2020-09-25 15:08:21 +0100153 // QLSTM NEON configure
Ryan OShea2323af42020-05-13 16:36:19 +0100154 m_QLstmLayer.configure(&input,
155 m_InputToForgetWeightsTensor.get(),
156 m_InputToCellWeightsTensor.get(),
157 m_InputToOutputWeightsTensor.get(),
158 m_RecurrentToForgetWeightsTensor.get(),
159 m_RecurrentToCellWeightsTensor.get(),
160 m_RecurrentToOutputWeightsTensor.get(),
161 m_ForgetGateBiasTensor.get(),
162 m_CellBiasTensor.get(),
163 m_OutputGateBiasTensor.get(),
164 &cellStateIn,
165 &outputStateIn,
166 &cellStateOut,
167 &outputStateOut,
168 &output,
169 qLstmParams);
170
Teresa Charlinbe727be2020-09-25 15:08:21 +0100171 // Initialise ACL tensor data for mandatory params
Ryan OShea2323af42020-05-13 16:36:19 +0100172 InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
173 InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
174 InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
175
176 InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
177 InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
178 InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
179
180 InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
181 InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
182 InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
183
Teresa Charlinbe727be2020-09-25 15:08:21 +0100184 // Initialise ACL tensor data for optional params
Ryan OShea2323af42020-05-13 16:36:19 +0100185 if (!m_Data.m_Parameters.m_CifgEnabled)
186 {
187 InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
188 InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
189 InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
190 }
191
192 if (m_Data.m_Parameters.m_ProjectionEnabled)
193 {
194 InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
195
196 if (m_Data.m_ProjectionBias != nullptr)
197 {
198 InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
199 }
200 }
201
202 if (m_Data.m_Parameters.m_PeepholeEnabled)
203 {
204 if (!m_Data.m_Parameters.m_CifgEnabled)
205 {
206 InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
207 }
208
209 InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
210 InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
211 }
212
213 if (m_Data.m_Parameters.m_LayerNormEnabled)
214 {
215 if (!m_Data.m_Parameters.m_CifgEnabled)
216 {
217 InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
218 }
219 InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
220 InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
221 InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
222 }
223
224 m_QLstmLayer.prepare();
225
226 FreeUnusedTensors();
227}
228
229void ClQLstmWorkload::Execute() const
230{
231 m_QLstmLayer.run();
232}
233
234arm_compute::Status ClQLstmWorkloadValidate(const TensorInfo& input,
235 const TensorInfo& cellStateIn,
236 const TensorInfo& outputStateIn,
237 const TensorInfo& cellStateOut,
238 const TensorInfo& outputStateOut,
239 const TensorInfo& output,
240 const QLstmDescriptor& descriptor,
241 const LstmInputParamsInfo& paramsInfo)
242{
243 arm_compute::LSTMParams<arm_compute::ITensorInfo> aclParamsInfo;
244
Teresa Charlinbe727be2020-09-25 15:08:21 +0100245 // Input/Output tensor info
Ryan OShea2323af42020-05-13 16:36:19 +0100246 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
247 const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
248 const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
249
250 const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
251 const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
252 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
253
254 // Mandatory tensor info
255 const arm_compute::TensorInfo aclInputToForgetWeightsInfo
256 = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
257 const arm_compute::TensorInfo aclInputToCellWeightsInfo
258 = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
259 const arm_compute::TensorInfo aclInputToOutputWeightsInfo
260 = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
261 const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
262 = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
263 const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
264 = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
265 const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
266 = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
267 const arm_compute::TensorInfo aclForgetGateBiasInfo
268 = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
269 const arm_compute::TensorInfo aclCellBiasInfo
270 = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
271 const arm_compute::TensorInfo aclOutputGateBiasInfo
272 = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
273
274 // Optional tensor info
275 arm_compute::TensorInfo aclInputToInputWeightsInfo;
276 arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
Teresa Charlinbe727be2020-09-25 15:08:21 +0100277
Ryan OShea2323af42020-05-13 16:36:19 +0100278 arm_compute::TensorInfo aclCellToInputWeightsInfo;
279 arm_compute::TensorInfo aclCellToForgetWeightsInfo;
280 arm_compute::TensorInfo aclCellToOutputWeightsInfo;
Teresa Charlinbe727be2020-09-25 15:08:21 +0100281
Ryan OShea2323af42020-05-13 16:36:19 +0100282 arm_compute::TensorInfo aclInputGateBiasInfo;
Teresa Charlinbe727be2020-09-25 15:08:21 +0100283
Ryan OShea2323af42020-05-13 16:36:19 +0100284 arm_compute::TensorInfo aclProjectionWeightsInfo;
285 arm_compute::TensorInfo aclProjectionBiasInfo;
Teresa Charlinbe727be2020-09-25 15:08:21 +0100286
Ryan OShea2323af42020-05-13 16:36:19 +0100287 arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
288 arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
289 arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
290 arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
291
Teresa Charlinbe727be2020-09-25 15:08:21 +0100292 // Create tensor info for optional params if they are enabled
Ryan OShea2323af42020-05-13 16:36:19 +0100293 if (descriptor.m_PeepholeEnabled)
294 {
295 if (!descriptor.m_CifgEnabled)
296 {
297 aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
298 }
299
300 aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
301 aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());
302
Teresa Charlinbe727be2020-09-25 15:08:21 +0100303 // Set peephole params info
Ryan OShea2323af42020-05-13 16:36:19 +0100304 aclParamsInfo.set_peephole_params(&aclCellToForgetWeightsInfo,
305 &aclCellToOutputWeightsInfo);
306 }
307
308 if (descriptor.m_ProjectionEnabled)
309 {
310 aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());
311
312 if (paramsInfo.m_ProjectionBias != nullptr)
313 {
314 aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
315 }
316
Teresa Charlinbe727be2020-09-25 15:08:21 +0100317 // Set projection params info
Ryan OShea2323af42020-05-13 16:36:19 +0100318 aclParamsInfo.set_projection_params(
319 &aclProjectionWeightsInfo,
320 paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
321 }
322
323 if (descriptor.m_LayerNormEnabled)
324 {
325 if (!descriptor.m_CifgEnabled)
326 {
327 aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
328 }
329
330 aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
331 aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
332 aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());
333
Teresa Charlinbe727be2020-09-25 15:08:21 +0100334 // Set layer norm params info
Ryan OShea2323af42020-05-13 16:36:19 +0100335 aclParamsInfo.set_layer_normalization_params(
336 paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr,
337 &aclForgetLayerNormWeightsInfo,
338 &aclCellLayerNormWeightsInfo,
339 &aclOutputLayerNormWeightsInfo);
340 }
341
342 if (!descriptor.m_CifgEnabled)
343 {
344 aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
345 aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
346 aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
347
Teresa Charlinbe727be2020-09-25 15:08:21 +0100348 // Set CIFG params info
Ryan OShea2323af42020-05-13 16:36:19 +0100349 aclParamsInfo.set_cifg_params(
350 &aclInputToInputWeightsInfo,
351 &aclRecurrentToInputWeightsInfo,
Teresa Charlinbe727be2020-09-25 15:08:21 +0100352 paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
Ryan OShea2323af42020-05-13 16:36:19 +0100353 &aclInputGateBiasInfo);
354 }
355
Teresa Charlinbe727be2020-09-25 15:08:21 +0100356 // Set scalar descriptor params
Ryan OShea2323af42020-05-13 16:36:19 +0100357 aclParamsInfo.set_cell_clip_params(descriptor.m_CellClip);
358 aclParamsInfo.set_projection_clip_params(descriptor.m_ProjectionClip);
359 aclParamsInfo.set_hidden_state_params(descriptor.m_HiddenStateZeroPoint, descriptor.m_HiddenStateScale);
360 aclParamsInfo.set_matmul_scale_params(descriptor.m_InputIntermediateScale,
361 descriptor.m_ForgetIntermediateScale,
362 descriptor.m_CellIntermediateScale,
363 descriptor.m_OutputIntermediateScale);
364
Teresa Charlinbe727be2020-09-25 15:08:21 +0100365 // QLSTM CL validate
Ryan OShea2323af42020-05-13 16:36:19 +0100366 return arm_compute::CLQLSTMLayer::validate(&aclInputInfo,
367 &aclInputToForgetWeightsInfo,
368 &aclInputToCellWeightsInfo,
369 &aclInputToOutputWeightsInfo,
370 &aclRecurrentToForgetWeightsInfo,
371 &aclRecurrentToCellWeightsInfo,
372 &aclRecurrentToOutputWeightsInfo,
373 &aclForgetGateBiasInfo,
374 &aclCellBiasInfo,
375 &aclOutputGateBiasInfo,
376 &aclCellStateInInfo,
377 &aclOutputStateInInfo,
378 &aclCellStateOutInfo,
379 &aclOutputStateOutInfo,
380 &aclOutputInfo,
381 aclParamsInfo);
382}
383
384void ClQLstmWorkload::FreeUnusedTensors()
385{
386 FreeTensorIfUnused(m_InputToInputWeightsTensor);
387 FreeTensorIfUnused(m_InputToForgetWeightsTensor);
388 FreeTensorIfUnused(m_InputToCellWeightsTensor);
389 FreeTensorIfUnused(m_InputToOutputWeightsTensor);
390
391 FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
392 FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
393 FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
394 FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
395
396 FreeTensorIfUnused(m_CellToInputWeightsTensor);
397 FreeTensorIfUnused(m_CellToForgetWeightsTensor);
398 FreeTensorIfUnused(m_CellToOutputWeightsTensor);
399
400 FreeTensorIfUnused(m_InputGateBiasTensor);
401 FreeTensorIfUnused(m_ForgetGateBiasTensor);
402 FreeTensorIfUnused(m_CellBiasTensor);
403 FreeTensorIfUnused(m_OutputGateBiasTensor);
404
405 FreeTensorIfUnused(m_ProjectionWeightsTensor);
406 FreeTensorIfUnused(m_ProjectionBiasTensor);
407
408 FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
409 FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
410 FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
411 FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
412}
413
414} //namespace armnn