//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ConversionUtils_1_2.hpp"

using Half = half_float::half;

namespace armnn_driver
{

using namespace armnn;
using namespace android::nn;

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperandType = typename HalPolicy::OperandType;

    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input0.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    // Determine data type of input tensor
    HalOperandType inputType;
    if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    ActivationDescriptor desc;
    desc.m_Function = ActivationFunction::Elu;

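    // ELU is defined as f(x) = x for x >= 0 and f(x) = alpha * (exp(x) - 1) for x < 0;
    // the ActivationDescriptor carries alpha in m_A.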
    // Read alpha
    if (inputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half alpha;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
        }

        desc.m_A = static_cast<float>(alpha);
    }
    else if (inputType == HalOperandType::TENSOR_FLOAT32)
    {
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
        }
    }
    else
    {
        return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
    }

    return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertQuantizedLstm()");

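    // ANEURALNETWORKS_QUANTIZED_LSTM (HAL 1.3) carries 32 inputs (0-31) and 3 outputs (0-2);
    // the operands are unpacked individually below.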
    // Inputs:
    // 0: The input: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED and shape
    //    [numBatches, inputSize] specifying the input to the LSTM cell.
    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0: input", __func__);
    }

    // 18: The output state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape
    //     [batch_size, output_size].
    LayerInputHandle outputStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 18, model, data);
    if (!outputStatePrevTimeStep.IsValid())
    {
        return Fail("%s: Could not read input 18: outputStatePrevTimeStep", __func__);
    }

    // 19: The cell state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
    LayerInputHandle cellStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 19, model, data);
    if (!cellStatePrevTimeStep.IsValid())
    {
        return Fail("%s: Could not read input 19: cellStatePrevTimeStep", __func__);
    }

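    // Inputs 0, 18 and 19 are runtime tensors and become the layer's input connections; the remaining
    // tensor operands are compile-time constants (weights and biases) and are read into ConstTensorPins.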
    // Get the mandatory input tensors:

    // 02: The input-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data);

    // 03: The input-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToCellWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 3, model, data);

    // 04: The input-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 4, model, data);

    // 06: The recurrent-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 6, model, data);

    // 07: The recurrent-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToCellWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 7, model, data);

    // 08: The recurrent-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 8, model, data);

    // 13: The forget gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin forgetGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 13, model, data);

    // 14: The cell bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin cellBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 14, model, data);

    // 15: The output gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin outputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 15, model, data);

    if (!inputToForgetWeightsPin.IsValid() ||
        !inputToCellWeightsPin.IsValid() ||
        !inputToOutputWeightsPin.IsValid() ||
        !recurrentToForgetWeightsPin.IsValid() ||
        !recurrentToCellWeightsPin.IsValid() ||
        !recurrentToOutputWeightsPin.IsValid() ||
        !forgetGateBiasPin.IsValid() ||
        !cellBiasPin.IsValid() ||
        !outputGateBiasPin.IsValid())
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input tensors:
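    // For each of these the trailing 'true' marks the pin as optional: an absent operand yields a pin
    // that is invalid but flagged optional, which the validity checks below accept. g_DontPermute and the
    // nullptr shape override leave the constant data unpermuted and its shape unchanged.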

    // 01: The input-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size], where “num_units” corresponds to the number of cell units.
    const ConstTensorPin inputToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         1,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 05: The recurrent-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size], where “output_size” corresponds to either the number of cell units (i.e.,
    //     “num_units”), or the second dimension of the “projection_weights”, if defined.
    const ConstTensorPin recurrentToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         5,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 09: The cell-to-input weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         9,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 10: The cell-to-forget weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         10,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 11: The cell-to-output weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         11,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 12: The input gate bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin inputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         12,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 16: The projection weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [output_size, num_units].
    const ConstTensorPin projectionWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         16,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 17: The projection bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [output_size].
    const ConstTensorPin projectionBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         17,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputToInputWeightsPin.IsValid() && !inputToInputWeightsPin.IsOptional())
        || (!recurrentToInputWeightsPin.IsValid() && !recurrentToInputWeightsPin.IsOptional())
        || (!cellToInputWeightsPin.IsValid() && !cellToInputWeightsPin.IsOptional())
        || (!cellToForgetWeightsPin.IsValid() && !cellToForgetWeightsPin.IsOptional())
        || (!cellToOutputWeightsPin.IsValid() && !cellToOutputWeightsPin.IsOptional())
        || (!inputGateBiasPin.IsValid() && !inputGateBiasPin.IsOptional())
        || (!projectionWeightsPin.IsValid() && !projectionWeightsPin.IsOptional())
        || (!projectionBiasPin.IsValid() && !projectionBiasPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional normalization tensors

    // 20: The input layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
    //     Used to rescale normalized inputs to activation at input gate.
    const ConstTensorPin inputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         20,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 21: The forget layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
    //     Used to rescale normalized inputs to activation at forget gate.
    const ConstTensorPin forgetLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         21,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 22: The cell layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
    //     Used to rescale normalized inputs to activation at cell gate.
    const ConstTensorPin cellLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         22,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 23: The output layer normalization weights. A 1-D tensor of shape [num_units] ANEURALNETWORKS_TENSOR_QUANT16_SYMM.
    //     Used to rescale normalized inputs to activation at output gate.
    const ConstTensorPin outputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         23,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputLayerNormWeightsPin.IsValid() && !inputLayerNormWeightsPin.IsOptional())
        || (!forgetLayerNormWeightsPin.IsValid() && !forgetLayerNormWeightsPin.IsOptional())
        || (!cellLayerNormWeightsPin.IsValid() && !cellLayerNormWeightsPin.IsOptional())
        || (!outputLayerNormWeightsPin.IsValid() && !outputLayerNormWeightsPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input scalars:
    // 24: The cell clip: If provided the cell state is clipped by this value prior to the cell output activation.
    // 25: The projection clip: If provided and projection is enabled, this is used for clipping the projected values.

    // Get the mandatory input scalars:
    // 26: The scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
    // 27: The scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
    // 28: The scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
    // 29: The scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
    // 30: The zero point of the hidden state, i.e. input to projection.
    // 31: The scale of the hidden state, i.e. input to projection.
    float cellClip, projClip, matMulInputGate, matMulForgetGate, matMulCellGate, matMulOutputGate, projInputScale;
    int projInputZeroPoint;

    // Note: operand 30 is an INT32 zero point and operand 31 a FLOAT32 scale, matching the comments above.
    if (!GetInputScalar<HalPolicy>(operation, 24, HalOperandType::FLOAT32, cellClip, model, data, true) ||
        !GetInputScalar<HalPolicy>(operation, 25, HalOperandType::FLOAT32, projClip, model, data, true) ||
        !GetInputScalar<HalPolicy>(operation, 26, HalOperandType::FLOAT32, matMulInputGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 27, HalOperandType::FLOAT32, matMulForgetGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 28, HalOperandType::FLOAT32, matMulCellGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 29, HalOperandType::FLOAT32, matMulOutputGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 30, HalOperandType::INT32, projInputZeroPoint, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 31, HalOperandType::FLOAT32, projInputScale, model, data))
    {
        return Fail("%s: Operation has invalid scalar inputs", __func__);
    }

    // Outputs:
    // 0: The output state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape
    //    [batch_size, output_size].
    const HalOperand* outputStateOut = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!outputStateOut)
    {
        return Fail("%s: Could not read output 0: outputStateOut", __func__);
    }

    // 1: The cell state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
    const HalOperand* cellStateOut = GetOutputOperand<HalPolicy>(operation, 1, model);
    if (!cellStateOut)
    {
        return Fail("%s: Could not read output 1: cellStateOut", __func__);
    }

    // 2: The output: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape [batch_size, output_size].
    //    This is effectively the same as the current “output state (out)” value.
    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 2, model);
    if (!output)
    {
        return Fail("%s: Could not read output 2: output", __func__);
    }

    // set the params structure for the AddQLstmLayer call
    LstmInputParams params;
    params.m_InputToInputWeights = inputToInputWeightsPin.GetConstTensorPtr();
    params.m_InputToForgetWeights = inputToForgetWeightsPin.GetConstTensorPtr();
    params.m_InputToCellWeights = inputToCellWeightsPin.GetConstTensorPtr();
    params.m_InputToOutputWeights = inputToOutputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToInputWeights = recurrentToInputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToForgetWeights = recurrentToForgetWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToCellWeights = recurrentToCellWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToOutputWeights = recurrentToOutputWeightsPin.GetConstTensorPtr();
    params.m_CellToInputWeights = cellToInputWeightsPin.GetConstTensorPtr();
    params.m_CellToForgetWeights = cellToForgetWeightsPin.GetConstTensorPtr();
    params.m_CellToOutputWeights = cellToOutputWeightsPin.GetConstTensorPtr();
    params.m_InputGateBias = inputGateBiasPin.GetConstTensorPtr();
    params.m_ForgetGateBias = forgetGateBiasPin.GetConstTensorPtr();
    params.m_CellBias = cellBiasPin.GetConstTensorPtr();
    params.m_OutputGateBias = outputGateBiasPin.GetConstTensorPtr();
    params.m_ProjectionWeights = projectionWeightsPin.GetConstTensorPtr();
    params.m_ProjectionBias = projectionBiasPin.GetConstTensorPtr();
    params.m_InputLayerNormWeights = inputLayerNormWeightsPin.GetConstTensorPtr();
    params.m_ForgetLayerNormWeights = forgetLayerNormWeightsPin.GetConstTensorPtr();
    params.m_CellLayerNormWeights = cellLayerNormWeightsPin.GetConstTensorPtr();
    params.m_OutputLayerNormWeights = outputLayerNormWeightsPin.GetConstTensorPtr();

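    // An absent optional pin yields a null ConstTensor pointer, so the descriptor flags below can be
    // derived from which operands were supplied: CIFG (coupled input and forget gate) applies when the
    // input gate operands are missing, peephole when cell-to-gate weights are present, projection when
    // projection weights are present, and layer normalization when per-gate norm weights are present.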
    // set the layer descriptor
    QLstmDescriptor desc;
    desc.m_CellClip = cellClip;
    desc.m_ProjectionClip = projClip;
    desc.m_CifgEnabled = (params.m_InputToInputWeights == nullptr ||
                          params.m_RecurrentToInputWeights == nullptr ||
                          params.m_InputGateBias == nullptr);
    desc.m_PeepholeEnabled = (params.m_CellToForgetWeights != nullptr ||
                              params.m_CellToOutputWeights != nullptr);
    desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
    desc.m_LayerNormEnabled = (params.m_InputLayerNormWeights != nullptr ||
                               params.m_ForgetLayerNormWeights != nullptr ||
                               params.m_CellLayerNormWeights != nullptr ||
                               params.m_OutputLayerNormWeights != nullptr);
    desc.m_InputIntermediateScale = matMulInputGate;
    desc.m_ForgetIntermediateScale = matMulForgetGate;
    desc.m_CellIntermediateScale = matMulCellGate;
    desc.m_OutputIntermediateScale = matMulOutputGate;
    desc.m_HiddenStateScale = projInputScale;
    desc.m_HiddenStateZeroPoint = projInputZeroPoint;

    // validate the optional input groups
    if (desc.m_CifgEnabled &&
        (params.m_InputToInputWeights != nullptr ||
         params.m_RecurrentToInputWeights != nullptr ||
         params.m_InputGateBias != nullptr))
    {
        return Fail("%s: All, or none, of input-to-input weights, recurrent-to-input weights,"
                    " and input gate bias must be provided", __func__);
    }

    if (!desc.m_ProjectionEnabled && params.m_ProjectionBias != nullptr)
    {
        return Fail("%s: projection bias should not be provided without projection weights", __func__);
    }

    if (desc.m_PeepholeEnabled &&
        (params.m_CellToForgetWeights == nullptr ||
         params.m_CellToOutputWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_CellToInputWeights == nullptr)))
    {
        return Fail("%s: All, or none, of cell-to-forget weights and cell-to-output weights must be provided"
                    " and, if CIFG is not enabled, cell-to-input weights must also be provided", __func__);
    }

    if (desc.m_LayerNormEnabled &&
        (params.m_ForgetLayerNormWeights == nullptr ||
         params.m_CellLayerNormWeights == nullptr ||
         params.m_OutputLayerNormWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_InputLayerNormWeights == nullptr)))
    {
        return Fail("%s: All, or none, of forget-norm weights, cell-norm weights and output-norm weights must be"
                    " provided and, if CIFG is not enabled, input-norm weights must also be provided", __func__);
    }

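    // LstmInputParamsInfo mirrors LstmInputParams but holds TensorInfo pointers rather than tensor data;
    // it is what the backend support check (IsQLstmSupported) consumes below.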
    // Basic parameters
    LstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
    paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo());
    paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo());
    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
    paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo());
    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
    paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo());
    paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo());
    paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo());

    // Inputs
    const TensorInfo& inputInfo = input.GetTensorInfo();
    const TensorInfo& outputStatePrevTimeStepInfo = outputStatePrevTimeStep.GetTensorInfo();
    const TensorInfo& cellStatePrevTimeStepInfo = cellStatePrevTimeStep.GetTensorInfo();

    // Outputs
    TensorInfo outputStateOutInfo = GetTensorInfoForOperand(*outputStateOut);
    TensorInfo outputInfo = GetTensorInfoForOperand(*output);
    const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);

    // Optional parameters
    if (!desc.m_CifgEnabled)
    {
        paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
        paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
        if (desc.m_PeepholeEnabled)
        {
            paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
        }
        paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
    }

    if (desc.m_ProjectionEnabled)
    {
        paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
        if (params.m_ProjectionBias != nullptr)
        {
            paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
        }
    }
    else
    {
        // If projection is disabled, override the non-const output infos with the hidden-state quantization
        // parameters, then create new const TensorInfos based on them
        outputStateOutInfo.SetQuantizationScale(projInputScale);
        outputStateOutInfo.SetQuantizationOffset(projInputZeroPoint);
        outputInfo.SetQuantizationScale(projInputScale);
        outputInfo.SetQuantizationOffset(projInputZeroPoint);
    }

    const TensorInfo constOutputStateOutInfo(outputStateOutInfo);
    const TensorInfo constOutputInfo(outputInfo);

    if (desc.m_PeepholeEnabled)
    {
        paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
        paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
    }

    if (desc.m_LayerNormEnabled)
    {
        if (!desc.m_CifgEnabled)
        {
            paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
        }
        paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
        paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
        paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
    }


    // Check if the layer is supported

    if (IsDynamicTensor(constOutputStateOutInfo) ||
        IsDynamicTensor(cellStateOutInfo) ||
        IsDynamicTensor(constOutputInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported %d %d %d", __func__,
                    IsDynamicTensor(constOutputStateOutInfo), IsDynamicTensor(cellStateOutInfo),
                    IsDynamicTensor(constOutputInfo));
    }

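    // FORWARD_LAYER_SUPPORT_FUNC asks the backends in data.m_Backends whether they support a QLstm layer
    // with these tensor infos and this descriptor; if none does, the conversion is rejected.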
    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsQLstmSupported,
                               data.m_Backends,
                               isSupported,
                               inputInfo,
                               outputStatePrevTimeStepInfo,
                               cellStatePrevTimeStepInfo,
                               constOutputStateOutInfo,
                               cellStateOutInfo,
                               constOutputInfo,
                               desc,
                               paramsInfo);
    if (!isSupported)
    {
        return false;
    }

    // Add the layer
    IConnectableLayer* layer = data.m_Network->AddQLstmLayer(desc, params, "QLstm");

    input.Connect(layer->GetInputSlot(0));
    outputStatePrevTimeStep.Connect(layer->GetInputSlot(1));
    cellStatePrevTimeStep.Connect(layer->GetInputSlot(2));

    return ( SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data,
                                                     &constOutputStateOutInfo) &&
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
             SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
}

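// Illustrative only (not part of this header): these helpers are expected to be invoked from a
// HalPolicy's ConvertOperation switch. The sketch below shows the assumed wiring; the operation-type
// enumerators are assumptions based on the NNAPI operation names, and the real dispatch lives in each
// HAL version's HalPolicy source.
//
//     switch (operation.type)
//     {
//         case HalPolicy::OperationType::ELU:
//             return ConvertElu<HalPolicy>(operation, model, data);
//         case HalPolicy::OperationType::QUANTIZED_LSTM:
//             return ConvertQuantizedLstm<HalPolicy>(operation, model, data);
//         // ... remaining operation types ...
//     }
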
} // armnn_driver namespace