blob: 0ae371575b8d0405e7173b256c8f3703fae18e7a [file] [log] [blame]
Ryan OShea2323af42020-05-13 16:36:19 +01001//
2// Copyright © 2020 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "ClQLstmWorkload.hpp"
7#include "ClWorkloadUtils.hpp"
8
9#include "aclCommon/ArmComputeTensorUtils.hpp"
10
11#include "cl/ClTensorHandle.hpp"
12
13namespace armnn
14{
15using namespace armcomputetensorutils;
16
Sadik Armagane9444752020-12-02 11:28:58 +000017ClQLstmWorkload::ClQLstmWorkload(const QLstmQueueDescriptor& descriptor,
18 const WorkloadInfo& info,
19 const arm_compute::CLCompileContext& clCompileContext)
Ryan OShea2323af42020-05-13 16:36:19 +010020 : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
21{
22 arm_compute::LSTMParams<arm_compute::ICLTensor> qLstmParams;
23
24 // Mandatory params
25 m_InputToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
26 BuildArmComputeTensor(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights->GetTensorInfo());
27
28 m_InputToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
29 BuildArmComputeTensor(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights->GetTensorInfo());
30
31 m_InputToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
32 BuildArmComputeTensor(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights->GetTensorInfo());
33
34 m_RecurrentToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
35 BuildArmComputeTensor(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights->GetTensorInfo());
36
37 m_RecurrentToCellWeightsTensor = std::make_unique<arm_compute::CLTensor>();
38 BuildArmComputeTensor(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights->GetTensorInfo());
39
40 m_RecurrentToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
41 BuildArmComputeTensor(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights->GetTensorInfo());
42
43 m_ForgetGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
44 BuildArmComputeTensor(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias->GetTensorInfo());
45
46 m_CellBiasTensor = std::make_unique<arm_compute::CLTensor>();
47 BuildArmComputeTensor(*m_CellBiasTensor, m_Data.m_CellBias->GetTensorInfo());
48
49 m_OutputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
50 BuildArmComputeTensor(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias->GetTensorInfo());
51
52 // Create tensors for optional params if they are enabled
53 if (m_Data.m_Parameters.m_PeepholeEnabled)
54 {
55 m_CellToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
56
57 if (!m_Data.m_Parameters.m_CifgEnabled)
58 {
59 // In ACL this is categorised as a CIFG param and not a Peephole param
60 BuildArmComputeTensor(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights->GetTensorInfo());
61 }
62
63 m_CellToForgetWeightsTensor = std::make_unique<arm_compute::CLTensor>();
64 BuildArmComputeTensor(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights->GetTensorInfo());
65
66 m_CellToOutputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
67 BuildArmComputeTensor(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights->GetTensorInfo());
68
69 // Set Peephole params
70 qLstmParams.set_peephole_params(m_CellToForgetWeightsTensor.get(),
71 m_CellToOutputWeightsTensor.get());
72 }
73
74 if (m_Data.m_Parameters.m_ProjectionEnabled)
75 {
76 m_ProjectionWeightsTensor = std::make_unique<arm_compute::CLTensor>();
77 BuildArmComputeTensor(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights->GetTensorInfo());
78
79 m_ProjectionBiasTensor = std::make_unique<arm_compute::CLTensor>();
80 if (m_Data.m_ProjectionBias != nullptr)
81 {
82 BuildArmComputeTensor(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias->GetTensorInfo());
83 }
84
85 // Set projection params
86 qLstmParams.set_projection_params(
87 m_ProjectionWeightsTensor.get(),
88 m_Data.m_ProjectionBias != nullptr ? m_ProjectionBiasTensor.get() : nullptr);
89 }
90
91 if (m_Data.m_Parameters.m_LayerNormEnabled)
92 {
93 m_InputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
94
95 if (!m_Data.m_Parameters.m_CifgEnabled)
96 {
97 BuildArmComputeTensor(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights->GetTensorInfo());
98 }
99
100 m_ForgetLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
101 BuildArmComputeTensor(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights->GetTensorInfo());
102
103 m_CellLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
104 BuildArmComputeTensor(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights->GetTensorInfo());
105
106 m_OutputLayerNormWeightsTensor = std::make_unique<arm_compute::CLTensor>();
107 BuildArmComputeTensor(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights->GetTensorInfo());
108
Teresa Charlinbe727be2020-09-25 15:08:21 +0100109 // Set layer norm params
Ryan OShea2323af42020-05-13 16:36:19 +0100110 qLstmParams.set_layer_normalization_params(
111 m_Data.m_InputLayerNormWeights != nullptr ? m_InputLayerNormWeightsTensor.get() : nullptr,
112 m_ForgetLayerNormWeightsTensor.get(),
113 m_CellLayerNormWeightsTensor.get(),
114 m_OutputLayerNormWeightsTensor.get());
115 }
116
117 if (!m_Data.m_Parameters.m_CifgEnabled)
118 {
119 m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
120 BuildArmComputeTensor(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights->GetTensorInfo());
121
122 m_RecurrentToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>();
123 BuildArmComputeTensor(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights->GetTensorInfo());
124
125 m_InputGateBiasTensor = std::make_unique<arm_compute::CLTensor>();
126 BuildArmComputeTensor(*m_InputGateBiasTensor, m_Data.m_InputGateBias->GetTensorInfo());
127
Teresa Charlinbe727be2020-09-25 15:08:21 +0100128 // Set CIFG params
Ryan OShea2323af42020-05-13 16:36:19 +0100129 qLstmParams.set_cifg_params(
130 m_InputToInputWeightsTensor.get(),
131 m_RecurrentToInputWeightsTensor.get(),
Teresa Charlinbe727be2020-09-25 15:08:21 +0100132 m_Data.m_CellToInputWeights != nullptr ? m_CellToInputWeightsTensor.get() : nullptr,
Ryan OShea2323af42020-05-13 16:36:19 +0100133 m_InputGateBiasTensor.get());
134 }
135
136 // Input/Output tensors
137 const arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
Nikhil Rajf3c27182020-09-24 17:58:34 +0100138 arm_compute::ICLTensor& outputStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
139 arm_compute::ICLTensor& cellStateIn = static_cast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
Ryan OShea2323af42020-05-13 16:36:19 +0100140
141 arm_compute::ICLTensor& outputStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
142 arm_compute::ICLTensor& cellStateOut = static_cast<IClTensorHandle*>(m_Data.m_Outputs[1])->GetTensor();
143 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[2])->GetTensor();
144
145 // Set scalar descriptor params
146 qLstmParams.set_cell_clip_params(m_Data.m_Parameters.m_CellClip);
147 qLstmParams.set_projection_clip_params(m_Data.m_Parameters.m_ProjectionClip);
148 qLstmParams.set_hidden_state_params(m_Data.m_Parameters.m_HiddenStateZeroPoint,
149 m_Data.m_Parameters.m_HiddenStateScale);
150 qLstmParams.set_matmul_scale_params(m_Data.m_Parameters.m_InputIntermediateScale,
151 m_Data.m_Parameters.m_ForgetIntermediateScale,
152 m_Data.m_Parameters.m_CellIntermediateScale,
153 m_Data.m_Parameters.m_OutputIntermediateScale);
154
Sadik Armagane9444752020-12-02 11:28:58 +0000155 // QLSTM CL configure
156 m_QLstmLayer.configure(clCompileContext,
157 &input,
Ryan OShea2323af42020-05-13 16:36:19 +0100158 m_InputToForgetWeightsTensor.get(),
159 m_InputToCellWeightsTensor.get(),
160 m_InputToOutputWeightsTensor.get(),
161 m_RecurrentToForgetWeightsTensor.get(),
162 m_RecurrentToCellWeightsTensor.get(),
163 m_RecurrentToOutputWeightsTensor.get(),
164 m_ForgetGateBiasTensor.get(),
165 m_CellBiasTensor.get(),
166 m_OutputGateBiasTensor.get(),
167 &cellStateIn,
168 &outputStateIn,
169 &cellStateOut,
170 &outputStateOut,
171 &output,
172 qLstmParams);
173
Teresa Charlinbe727be2020-09-25 15:08:21 +0100174 // Initialise ACL tensor data for mandatory params
Ryan OShea2323af42020-05-13 16:36:19 +0100175 InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
176 InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
177 InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
178
179 InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
180 InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
181 InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
182
183 InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
184 InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
185 InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
186
Teresa Charlinbe727be2020-09-25 15:08:21 +0100187 // Initialise ACL tensor data for optional params
Ryan OShea2323af42020-05-13 16:36:19 +0100188 if (!m_Data.m_Parameters.m_CifgEnabled)
189 {
190 InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
191 InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
192 InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
193 }
194
195 if (m_Data.m_Parameters.m_ProjectionEnabled)
196 {
197 InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
198
199 if (m_Data.m_ProjectionBias != nullptr)
200 {
201 InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
202 }
203 }
204
205 if (m_Data.m_Parameters.m_PeepholeEnabled)
206 {
207 if (!m_Data.m_Parameters.m_CifgEnabled)
208 {
209 InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
210 }
211
212 InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
213 InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
214 }
215
216 if (m_Data.m_Parameters.m_LayerNormEnabled)
217 {
218 if (!m_Data.m_Parameters.m_CifgEnabled)
219 {
220 InitializeArmComputeClTensorData(*m_InputLayerNormWeightsTensor, m_Data.m_InputLayerNormWeights);
221 }
222 InitializeArmComputeClTensorData(*m_ForgetLayerNormWeightsTensor, m_Data.m_ForgetLayerNormWeights);
223 InitializeArmComputeClTensorData(*m_CellLayerNormWeightsTensor, m_Data.m_CellLayerNormWeights);
224 InitializeArmComputeClTensorData(*m_OutputLayerNormWeightsTensor, m_Data.m_OutputLayerNormWeights);
225 }
226
227 m_QLstmLayer.prepare();
228
229 FreeUnusedTensors();
230}
231
232void ClQLstmWorkload::Execute() const
233{
234 m_QLstmLayer.run();
235}
236
237arm_compute::Status ClQLstmWorkloadValidate(const TensorInfo& input,
238 const TensorInfo& cellStateIn,
239 const TensorInfo& outputStateIn,
240 const TensorInfo& cellStateOut,
241 const TensorInfo& outputStateOut,
242 const TensorInfo& output,
243 const QLstmDescriptor& descriptor,
244 const LstmInputParamsInfo& paramsInfo)
245{
246 arm_compute::LSTMParams<arm_compute::ITensorInfo> aclParamsInfo;
247
Teresa Charlinbe727be2020-09-25 15:08:21 +0100248 // Input/Output tensor info
Ryan OShea2323af42020-05-13 16:36:19 +0100249 const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
250 const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
251 const arm_compute::TensorInfo aclCellStateInInfo = BuildArmComputeTensorInfo(cellStateIn);
252
253 const arm_compute::TensorInfo aclOutputStateOutInfo = BuildArmComputeTensorInfo(outputStateOut);
254 const arm_compute::TensorInfo aclCellStateOutInfo = BuildArmComputeTensorInfo(cellStateOut);
255 const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
256
257 // Mandatory tensor info
258 const arm_compute::TensorInfo aclInputToForgetWeightsInfo
259 = BuildArmComputeTensorInfo(paramsInfo.GetInputToForgetWeights());
260 const arm_compute::TensorInfo aclInputToCellWeightsInfo
261 = BuildArmComputeTensorInfo(paramsInfo.GetInputToCellWeights());
262 const arm_compute::TensorInfo aclInputToOutputWeightsInfo
263 = BuildArmComputeTensorInfo(paramsInfo.GetInputToOutputWeights());
264 const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo
265 = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToForgetWeights());
266 const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo
267 = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToCellWeights());
268 const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo
269 = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToOutputWeights());
270 const arm_compute::TensorInfo aclForgetGateBiasInfo
271 = BuildArmComputeTensorInfo(paramsInfo.GetForgetGateBias());
272 const arm_compute::TensorInfo aclCellBiasInfo
273 = BuildArmComputeTensorInfo(paramsInfo.GetCellBias());
274 const arm_compute::TensorInfo aclOutputGateBiasInfo
275 = BuildArmComputeTensorInfo(paramsInfo.GetOutputGateBias());
276
277 // Optional tensor info
278 arm_compute::TensorInfo aclInputToInputWeightsInfo;
279 arm_compute::TensorInfo aclRecurrentToInputWeightsInfo;
Teresa Charlinbe727be2020-09-25 15:08:21 +0100280
Ryan OShea2323af42020-05-13 16:36:19 +0100281 arm_compute::TensorInfo aclCellToInputWeightsInfo;
282 arm_compute::TensorInfo aclCellToForgetWeightsInfo;
283 arm_compute::TensorInfo aclCellToOutputWeightsInfo;
Teresa Charlinbe727be2020-09-25 15:08:21 +0100284
Ryan OShea2323af42020-05-13 16:36:19 +0100285 arm_compute::TensorInfo aclInputGateBiasInfo;
Teresa Charlinbe727be2020-09-25 15:08:21 +0100286
Ryan OShea2323af42020-05-13 16:36:19 +0100287 arm_compute::TensorInfo aclProjectionWeightsInfo;
288 arm_compute::TensorInfo aclProjectionBiasInfo;
Teresa Charlinbe727be2020-09-25 15:08:21 +0100289
Ryan OShea2323af42020-05-13 16:36:19 +0100290 arm_compute::TensorInfo aclInputLayerNormWeightsInfo;
291 arm_compute::TensorInfo aclForgetLayerNormWeightsInfo;
292 arm_compute::TensorInfo aclCellLayerNormWeightsInfo;
293 arm_compute::TensorInfo aclOutputLayerNormWeightsInfo;
294
Teresa Charlinbe727be2020-09-25 15:08:21 +0100295 // Create tensor info for optional params if they are enabled
Ryan OShea2323af42020-05-13 16:36:19 +0100296 if (descriptor.m_PeepholeEnabled)
297 {
298 if (!descriptor.m_CifgEnabled)
299 {
300 aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToInputWeights());
301 }
302
303 aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToForgetWeights());
304 aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellToOutputWeights());
305
Teresa Charlinbe727be2020-09-25 15:08:21 +0100306 // Set peephole params info
Ryan OShea2323af42020-05-13 16:36:19 +0100307 aclParamsInfo.set_peephole_params(&aclCellToForgetWeightsInfo,
308 &aclCellToOutputWeightsInfo);
309 }
310
311 if (descriptor.m_ProjectionEnabled)
312 {
313 aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionWeights());
314
315 if (paramsInfo.m_ProjectionBias != nullptr)
316 {
317 aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetProjectionBias());
318 }
319
Teresa Charlinbe727be2020-09-25 15:08:21 +0100320 // Set projection params info
Ryan OShea2323af42020-05-13 16:36:19 +0100321 aclParamsInfo.set_projection_params(
322 &aclProjectionWeightsInfo,
323 paramsInfo.m_ProjectionBias != nullptr ? &aclProjectionBiasInfo : nullptr);
324 }
325
326 if (descriptor.m_LayerNormEnabled)
327 {
328 if (!descriptor.m_CifgEnabled)
329 {
330 aclInputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputLayerNormWeights());
331 }
332
333 aclForgetLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetForgetLayerNormWeights());
334 aclCellLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetCellLayerNormWeights());
335 aclOutputLayerNormWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetOutputLayerNormWeights());
336
Teresa Charlinbe727be2020-09-25 15:08:21 +0100337 // Set layer norm params info
Ryan OShea2323af42020-05-13 16:36:19 +0100338 aclParamsInfo.set_layer_normalization_params(
339 paramsInfo.m_InputLayerNormWeights != nullptr ? &aclInputLayerNormWeightsInfo : nullptr,
340 &aclForgetLayerNormWeightsInfo,
341 &aclCellLayerNormWeightsInfo,
342 &aclOutputLayerNormWeightsInfo);
343 }
344
345 if (!descriptor.m_CifgEnabled)
346 {
347 aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputToInputWeights());
348 aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.GetRecurrentToInputWeights());
349 aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.GetInputGateBias());
350
Teresa Charlinbe727be2020-09-25 15:08:21 +0100351 // Set CIFG params info
Ryan OShea2323af42020-05-13 16:36:19 +0100352 aclParamsInfo.set_cifg_params(
353 &aclInputToInputWeightsInfo,
354 &aclRecurrentToInputWeightsInfo,
Teresa Charlinbe727be2020-09-25 15:08:21 +0100355 paramsInfo.m_CellToInputWeights != nullptr ? &aclCellToInputWeightsInfo : nullptr,
Ryan OShea2323af42020-05-13 16:36:19 +0100356 &aclInputGateBiasInfo);
357 }
358
Teresa Charlinbe727be2020-09-25 15:08:21 +0100359 // Set scalar descriptor params
Ryan OShea2323af42020-05-13 16:36:19 +0100360 aclParamsInfo.set_cell_clip_params(descriptor.m_CellClip);
361 aclParamsInfo.set_projection_clip_params(descriptor.m_ProjectionClip);
362 aclParamsInfo.set_hidden_state_params(descriptor.m_HiddenStateZeroPoint, descriptor.m_HiddenStateScale);
363 aclParamsInfo.set_matmul_scale_params(descriptor.m_InputIntermediateScale,
364 descriptor.m_ForgetIntermediateScale,
365 descriptor.m_CellIntermediateScale,
366 descriptor.m_OutputIntermediateScale);
367
Teresa Charlinbe727be2020-09-25 15:08:21 +0100368 // QLSTM CL validate
Ryan OShea2323af42020-05-13 16:36:19 +0100369 return arm_compute::CLQLSTMLayer::validate(&aclInputInfo,
370 &aclInputToForgetWeightsInfo,
371 &aclInputToCellWeightsInfo,
372 &aclInputToOutputWeightsInfo,
373 &aclRecurrentToForgetWeightsInfo,
374 &aclRecurrentToCellWeightsInfo,
375 &aclRecurrentToOutputWeightsInfo,
376 &aclForgetGateBiasInfo,
377 &aclCellBiasInfo,
378 &aclOutputGateBiasInfo,
379 &aclCellStateInInfo,
380 &aclOutputStateInInfo,
381 &aclCellStateOutInfo,
382 &aclOutputStateOutInfo,
383 &aclOutputInfo,
384 aclParamsInfo);
385}
386
387void ClQLstmWorkload::FreeUnusedTensors()
388{
389 FreeTensorIfUnused(m_InputToInputWeightsTensor);
390 FreeTensorIfUnused(m_InputToForgetWeightsTensor);
391 FreeTensorIfUnused(m_InputToCellWeightsTensor);
392 FreeTensorIfUnused(m_InputToOutputWeightsTensor);
393
394 FreeTensorIfUnused(m_RecurrentToInputWeightsTensor);
395 FreeTensorIfUnused(m_RecurrentToForgetWeightsTensor);
396 FreeTensorIfUnused(m_RecurrentToCellWeightsTensor);
397 FreeTensorIfUnused(m_RecurrentToOutputWeightsTensor);
398
399 FreeTensorIfUnused(m_CellToInputWeightsTensor);
400 FreeTensorIfUnused(m_CellToForgetWeightsTensor);
401 FreeTensorIfUnused(m_CellToOutputWeightsTensor);
402
403 FreeTensorIfUnused(m_InputGateBiasTensor);
404 FreeTensorIfUnused(m_ForgetGateBiasTensor);
405 FreeTensorIfUnused(m_CellBiasTensor);
406 FreeTensorIfUnused(m_OutputGateBiasTensor);
407
408 FreeTensorIfUnused(m_ProjectionWeightsTensor);
409 FreeTensorIfUnused(m_ProjectionBiasTensor);
410
411 FreeTensorIfUnused(m_InputLayerNormWeightsTensor);
412 FreeTensorIfUnused(m_ForgetLayerNormWeightsTensor);
413 FreeTensorIfUnused(m_CellLayerNormWeightsTensor);
414 FreeTensorIfUnused(m_OutputLayerNormWeightsTensor);
415}
416
417} //namespace armnn