//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ConversionUtils_1_2.hpp"

using Half = half_float::half;

namespace armnn_driver
{

using namespace armnn;
using namespace android::nn;

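// Converts an ELU operation to an ArmNN activation layer.
// ELU computes f(x) = x for x >= 0 and f(x) = alpha * (exp(x) - 1) for x < 0;
// input 0 is the data tensor and input 1 is the scalar alpha, whose scalar type
// (FLOAT16 or FLOAT32) must match the input tensor's data type.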
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperandType = typename HalPolicy::OperandType;

    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input0.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    // Determine data type of input tensor
    HalOperandType inputType;
    if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    ActivationDescriptor desc;
    desc.m_Function = ActivationFunction::Elu;

    // Read alpha
    if (inputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half alpha;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
        }

        desc.m_A = static_cast<float>(alpha);
    }
    else if (inputType == HalOperandType::TENSOR_FLOAT32)
    {
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
        }
    }
    else
    {
        return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
    }

    return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
}

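// Converts a FILL operation to an ArmNN fill layer.
// Input 0 is a 1-D tensor holding the desired output shape and input 1 is the
// scalar fill value; the scalar is read according to the output tensor's data
// type and stored as a float in FillDescriptor::m_Value.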
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertFill(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output", __func__);
    }

    const TensorInfo& inputInfo  = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
    if (IsDynamicTensor(outputInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported", __func__);
    }

    // Determine data type of output tensor
    HalOperandType outputType = output->type;
    FillDescriptor descriptor;
    // Read the scalar fill value
    if (outputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half value;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }

        descriptor.m_Value = static_cast<float>(value);
    }
    else if (outputType == HalOperandType::TENSOR_FLOAT32)
    {
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, descriptor.m_Value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }
    }
    else if (outputType == HalOperandType::TENSOR_INT32)
    {
        int32_t value;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }

        descriptor.m_Value = static_cast<float>(value);
    }
    else
    {
        return Fail("%s: Unsupported input tensor type: %d", __func__, outputType);
    }

    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsFillSupported,
                               data.m_Backends,
                               isSupported,
                               inputInfo,
                               outputInfo,
                               descriptor);
    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* const layer = data.m_Network->AddFillLayer(descriptor);
    assert(layer != nullptr);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
}

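// Converts the two-input logical operations (LOGICAL_AND, LOGICAL_OR) to an ArmNN
// logical binary layer; the specific operation is selected by logicalOperation.
// Inputs of different shapes are broadcast against each other, which may insert a
// reshape layer ahead of the logical layer.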
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertLogicalBinary(const HalOperation& operation,
                          const HalModel& model,
                          ConversionData& data,
                          LogicalBinaryOperation logicalOperation)
{
    using HalOperand = typename HalPolicy::Operand;

    ALOGV("HalPolicy::ConvertLogicalBinary()");
    ALOGV("logicalOperation = %s", GetLogicalBinaryOperationAsCString(logicalOperation));

    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    LayerInputHandle input1 = ConvertToLayerInputHandle<HalPolicy>(operation, 1, model, data);

    if (!(input0.IsValid() && input1.IsValid()))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output 0", __func__);
    }

    const TensorInfo& inputInfo0 = input0.GetTensorInfo();
    const TensorInfo& inputInfo1 = input1.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);

    LogicalBinaryDescriptor descriptor(logicalOperation);

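    // Backend support is checked immediately when the output shape is already known;
    // for a dynamic output tensor the same check is deferred (via validateFunc) until
    // the output shape has been inferred.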
    bool isSupported = false;

    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsLogicalBinarySupported,
                                   data.m_Backends,
                                   isSupported,
                                   inputInfo0,
                                   inputInfo1,
                                   outputInfo,
                                   descriptor);
    };

    if (!IsDynamicTensor(outputInfo))
    {
        validateFunc(outputInfo, isSupported);
    }
    else
    {
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* layer = data.m_Network->AddLogicalBinaryLayer(descriptor);
    assert(layer != nullptr);

    bool isReshapeSupported = BroadcastTensor(input0, input1, layer, data);
    if (!isReshapeSupported)
    {
        return false;
    }

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}

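// Converts a QUANTIZED_LSTM operation (the quantized LSTM introduced in NN HAL 1.3)
// to an ArmNN QLstm layer. The operation has 32 inputs: three runtime tensors (the
// input plus the previous output and cell states), mandatory and optional constant
// weight/bias tensors, optional layer-normalization weights, and scalar clipping and
// quantization parameters. Which optional groups are present determines the CIFG,
// peephole, projection and layer-normalization configuration of the resulting layer.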
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertQuantizedLstm()");

    // Inputs:
    // 0: The input: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED and shape
    //    [numBatches, inputSize] specifying the input to the LSTM cell. Tensor is quantized with a fixed
    //    quantization range of [-1, 127/128].
    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0: input", __func__);
    }

    // 18: The output state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape
    //     [batch_size, output_size].
    LayerInputHandle outputStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 18, model, data);
    if (!outputStatePrevTimeStep.IsValid())
    {
        return Fail("%s: Could not read input 18: outputStatePrevTimeStep", __func__);
    }

    // 19: The cell state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
    LayerInputHandle cellStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 19, model, data);
    if (!cellStatePrevTimeStep.IsValid())
    {
        return Fail("%s: Could not read input 19: cellStatePrevTimeStep", __func__);
    }

    // Get the mandatory input tensors:

    // 02: The input-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data);

    // 03: The input-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToCellWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 3, model, data);

    // 04: The input-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 4, model, data);

    // 06: The recurrent-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 6, model, data);

    // 07: The recurrent-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToCellWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 7, model, data);

    // 08: The recurrent-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 8, model, data);

    // 13: The forget gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin forgetGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 13, model, data);

    // 14: The cell bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin cellBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 14, model, data);

    // 15: The output gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin outputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 15, model, data);

    if (!inputToForgetWeightsPin.IsValid() ||
        !inputToCellWeightsPin.IsValid() ||
        !inputToOutputWeightsPin.IsValid() ||
        !recurrentToForgetWeightsPin.IsValid() ||
        !recurrentToCellWeightsPin.IsValid() ||
        !recurrentToOutputWeightsPin.IsValid() ||
        !forgetGateBiasPin.IsValid() ||
        !cellBiasPin.IsValid() ||
        !outputGateBiasPin.IsValid())
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input tensors:

    // 01: The input-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size], where “num_units” corresponds to the number of cell units.
    const ConstTensorPin inputToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         1,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 05: The recurrent-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size], where “output_size” corresponds to either the number of cell units (i.e.,
    //     “num_units”), or the second dimension of the “projection_weights”, if defined.
    const ConstTensorPin recurrentToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         5,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 09: The cell-to-input weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         9,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 10: The cell-to-forget weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         10,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 11: The cell-to-output weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         11,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 12: The input gate bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin inputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         12,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 16: The projection weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [output_size, num_units].
    const ConstTensorPin projectionWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         16,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 17: The projection bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [output_size].
    const ConstTensorPin projectionBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         17,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputToInputWeightsPin.IsValid() && !inputToInputWeightsPin.IsOptional())
        || (!recurrentToInputWeightsPin.IsValid() && !recurrentToInputWeightsPin.IsOptional())
        || (!cellToInputWeightsPin.IsValid() && !cellToInputWeightsPin.IsOptional())
        || (!cellToForgetWeightsPin.IsValid() && !cellToForgetWeightsPin.IsOptional())
        || (!cellToOutputWeightsPin.IsValid() && !cellToOutputWeightsPin.IsOptional())
        || (!inputGateBiasPin.IsValid() && !inputGateBiasPin.IsOptional())
        || (!projectionWeightsPin.IsValid() && !projectionWeightsPin.IsOptional())
        || (!projectionBiasPin.IsValid() && !projectionBiasPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional normalization tensors:

    // 20: The input layer normalization weights: A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units]. Used to rescale normalized inputs to activation at input gate.
    const ConstTensorPin inputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         20,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 21: The forget layer normalization weights: A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units]. Used to rescale normalized inputs to activation at forget gate.
    const ConstTensorPin forgetLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         21,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 22: The cell layer normalization weights: A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units]. Used to rescale normalized inputs to activation at cell gate.
    const ConstTensorPin cellLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         22,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 23: The output layer normalization weights: A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units]. Used to rescale normalized inputs to activation at output gate.
    const ConstTensorPin outputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         23,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputLayerNormWeightsPin.IsValid() && !inputLayerNormWeightsPin.IsOptional())
        || (!forgetLayerNormWeightsPin.IsValid() && !forgetLayerNormWeightsPin.IsOptional())
        || (!cellLayerNormWeightsPin.IsValid() && !cellLayerNormWeightsPin.IsOptional())
        || (!outputLayerNormWeightsPin.IsValid() && !outputLayerNormWeightsPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input scalars:
    // 24: The cell clip: If provided the cell state is clipped by this value prior to the cell output activation.
    // 25: The projection clip: If provided and projection is enabled, this is used for clipping the projected values.

    // Get the mandatory input scalars:
    // 26: The scale of the intermediate result of matmul, i.e. input to layer normalization, at input gate.
    // 27: The scale of the intermediate result of matmul, i.e. input to layer normalization, at forget gate.
    // 28: The scale of the intermediate result of matmul, i.e. input to layer normalization, at cell gate.
    // 29: The scale of the intermediate result of matmul, i.e. input to layer normalization, at output gate.
    // 30: The zero point of the hidden state, i.e. input to projection.
    // 31: The scale of the hidden state, i.e. input to projection.
    float cellClip, projClip, matMulInputGate, matMulForgetGate, matMulCellGate, matMulOutputGate, projInputScale;
    int projInputZeroPoint;

    if (!GetInputScalar<HalPolicy>(operation, 24, HalOperandType::FLOAT32, cellClip, model, data, true) ||
        !GetInputScalar<HalPolicy>(operation, 25, HalOperandType::FLOAT32, projClip, model, data, true) ||
        !GetInputScalar<HalPolicy>(operation, 26, HalOperandType::FLOAT32, matMulInputGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 27, HalOperandType::FLOAT32, matMulForgetGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 28, HalOperandType::FLOAT32, matMulCellGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 29, HalOperandType::FLOAT32, matMulOutputGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 30, HalOperandType::INT32, projInputZeroPoint, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 31, HalOperandType::FLOAT32, projInputScale, model, data))
    {
        return Fail("%s: Operation has invalid scalar inputs", __func__);
    }

    // Outputs:
    // 0: The output state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape
    //    [batch_size, output_size].
    const HalOperand* outputStateOut = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!outputStateOut)
    {
        return Fail("%s: Could not read output 0: outputStateOut", __func__);
    }

    // 1: The cell state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
    const HalOperand* cellStateOut = GetOutputOperand<HalPolicy>(operation, 1, model);
    if (!cellStateOut)
    {
        return Fail("%s: Could not read output 1: cellStateOut", __func__);
    }

    // 2: The output: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape [batch_size, output_size].
    //    This is effectively the same as the current “output state (out)” value.
    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 2, model);
    if (!output)
    {
        return Fail("%s: Could not read output 2: output", __func__);
    }

    // Set the params structure for the AddQLstmLayer call
    LstmInputParams params;
    params.m_InputToInputWeights = inputToInputWeightsPin.GetConstTensorPtr();
    params.m_InputToForgetWeights = inputToForgetWeightsPin.GetConstTensorPtr();
    params.m_InputToCellWeights = inputToCellWeightsPin.GetConstTensorPtr();
    params.m_InputToOutputWeights = inputToOutputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToInputWeights = recurrentToInputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToForgetWeights = recurrentToForgetWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToCellWeights = recurrentToCellWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToOutputWeights = recurrentToOutputWeightsPin.GetConstTensorPtr();
    params.m_CellToInputWeights = cellToInputWeightsPin.GetConstTensorPtr();
    params.m_CellToForgetWeights = cellToForgetWeightsPin.GetConstTensorPtr();
    params.m_CellToOutputWeights = cellToOutputWeightsPin.GetConstTensorPtr();
    params.m_InputGateBias = inputGateBiasPin.GetConstTensorPtr();
    params.m_ForgetGateBias = forgetGateBiasPin.GetConstTensorPtr();
    params.m_CellBias = cellBiasPin.GetConstTensorPtr();
    params.m_OutputGateBias = outputGateBiasPin.GetConstTensorPtr();
    params.m_ProjectionWeights = projectionWeightsPin.GetConstTensorPtr();
    params.m_ProjectionBias = projectionBiasPin.GetConstTensorPtr();
    params.m_InputLayerNormWeights = inputLayerNormWeightsPin.GetConstTensorPtr();
    params.m_ForgetLayerNormWeights = forgetLayerNormWeightsPin.GetConstTensorPtr();
    params.m_CellLayerNormWeights = cellLayerNormWeightsPin.GetConstTensorPtr();
    params.m_OutputLayerNormWeights = outputLayerNormWeightsPin.GetConstTensorPtr();

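    // The presence or absence of the optional tensors gathered above drives the
    // QLstmDescriptor flags: a missing input-gate group enables CIFG, any
    // cell-to-gate weights enable peephole connections, projection weights enable
    // projection, and any layer-normalization weights enable layer normalization.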
    // Set the layer descriptor
    QLstmDescriptor desc;
    desc.m_CellClip = cellClip;
    desc.m_ProjectionClip = projClip;
    desc.m_CifgEnabled = (params.m_InputToInputWeights == nullptr ||
                          params.m_RecurrentToInputWeights == nullptr ||
                          params.m_InputGateBias == nullptr);
    desc.m_PeepholeEnabled = (params.m_CellToForgetWeights != nullptr ||
                              params.m_CellToOutputWeights != nullptr);
    desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
    desc.m_LayerNormEnabled = (params.m_InputLayerNormWeights != nullptr ||
                               params.m_ForgetLayerNormWeights != nullptr ||
                               params.m_CellLayerNormWeights != nullptr ||
                               params.m_OutputLayerNormWeights != nullptr);
    desc.m_InputIntermediateScale = matMulInputGate;
    desc.m_ForgetIntermediateScale = matMulForgetGate;
    desc.m_CellIntermediateScale = matMulCellGate;
    desc.m_OutputIntermediateScale = matMulOutputGate;
    desc.m_HiddenStateScale = projInputScale;
    desc.m_HiddenStateZeroPoint = projInputZeroPoint;

    // Validate the optional input groups
    if (desc.m_CifgEnabled &&
        (params.m_InputToInputWeights != nullptr ||
         params.m_RecurrentToInputWeights != nullptr ||
         params.m_InputGateBias != nullptr))
    {
        return Fail("%s: All, or none, of input-to-input weights, recurrent-to-input weights,"
                    " and input gate bias must be provided", __func__);
    }

    if (!desc.m_ProjectionEnabled && params.m_ProjectionBias != nullptr)
    {
        return Fail("%s: projection bias should not be provided without projection weights", __func__);
    }

    if (desc.m_PeepholeEnabled &&
        (params.m_CellToForgetWeights == nullptr ||
         params.m_CellToOutputWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_CellToInputWeights == nullptr)))
    {
        return Fail("%s: All, or none, of cell-to-forget weights and cell-to-output weights must be provided"
                    " and, if CIFG is not enabled, cell-to-input weights must also be provided", __func__);
    }

    if (desc.m_LayerNormEnabled &&
        (params.m_ForgetLayerNormWeights == nullptr ||
         params.m_CellLayerNormWeights == nullptr ||
         params.m_OutputLayerNormWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_InputLayerNormWeights == nullptr)))
    {
        return Fail("%s: All, or none, of forget-norm weights, cell-norm weights and output-norm weights must be"
                    " provided and, if CIFG is not enabled, input-norm weights must also be provided", __func__);
    }

    // Basic parameters
    LstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
    paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo());
    paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo());
    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
    paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo());
    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
    paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo());
    paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo());
    paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo());

    // Inputs
    const TensorInfo& inputInfo = input.GetTensorInfo();
    const TensorInfo& outputStatePrevTimeStepInfo = outputStatePrevTimeStep.GetTensorInfo();
    const TensorInfo& cellStatePrevTimeStepInfo = cellStatePrevTimeStep.GetTensorInfo();

    // Outputs
    TensorInfo outputStateOutInfo = GetTensorInfoForOperand(*outputStateOut);
    TensorInfo outputInfo = GetTensorInfoForOperand(*output);
    const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);

    // Optional parameters
    if (!desc.m_CifgEnabled)
    {
        paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
        paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
        if (desc.m_PeepholeEnabled)
        {
            paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
        }
        paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
    }

    if (desc.m_ProjectionEnabled)
    {
        paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
        if (params.m_ProjectionBias != nullptr)
        {
            paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
        }
    }
    else
    {
        // If projection is disabled, override the non-const outputs to change the quant info with hidden params,
        // then create new const TensorInfos based on them
        outputStateOutInfo.SetQuantizationScale(projInputScale);
        outputStateOutInfo.SetQuantizationOffset(projInputZeroPoint);
        outputInfo.SetQuantizationScale(projInputScale);
        outputInfo.SetQuantizationOffset(projInputZeroPoint);
    }

    const TensorInfo constOutputStateOutInfo(outputStateOutInfo);
    const TensorInfo constOutputInfo(outputInfo);

    if (desc.m_PeepholeEnabled)
    {
        paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
        paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
    }

    if (desc.m_LayerNormEnabled)
    {
        if (!desc.m_CifgEnabled)
        {
            paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
        }
        paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
        paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
        paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
    }

    // Check if the layer is supported
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& cellStateOutInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsQLstmSupported,
                                   data.m_Backends,
                                   isSupported,
                                   inputInfo,
                                   outputStatePrevTimeStepInfo,
                                   cellStatePrevTimeStepInfo,
                                   constOutputStateOutInfo,
                                   cellStateOutInfo,
                                   constOutputInfo,
                                   desc,
                                   paramsInfo);
    };

    bool isDynamic = false;
    if (!IsDynamicTensor(constOutputStateOutInfo) &&
        !IsDynamicTensor(cellStateOutInfo) &&
        !IsDynamicTensor(constOutputInfo))
    {
        // The lambda parameter is the cell state (out) info, so pass cellStateOutInfo here
        validateFunc(cellStateOutInfo, isSupported);
    }
    else
    {
        isDynamic = true;
        isSupported = AreDynamicTensorsSupported();
    }

    if (!isSupported)
    {
        return false;
    }

    // Add the layer
    IConnectableLayer* layer = data.m_Network->AddQLstmLayer(desc, params, "QLstm");

    input.Connect(layer->GetInputSlot(0));
    outputStatePrevTimeStep.Connect(layer->GetInputSlot(1));
    cellStatePrevTimeStep.Connect(layer->GetInputSlot(2));

    if (!isDynamic)
    {
        return ( SetupAndTrackLayerOutputSlot<HalPolicy>(
                     operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) &&
                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
    }
    else
    {
        return ( SetupAndTrackLayerOutputSlot<HalPolicy>(
                     operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) &&
                 SetupAndTrackLayerOutputSlot<HalPolicy>(
                     operation, 1, *layer, 1, model, data, nullptr, validateFunc,
                     ActivationFn::kActivationNone, true) &&
                 SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
    }
}

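// Converts a RANK operation to an ArmNN rank layer, which outputs the number of
// dimensions of input 0 as a scalar of type INT32.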
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertRank(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;

    const HalOperand* inputOperand = GetInputOperand<HalPolicy>(operation, 0, model);
    const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, 0, model);

    if (inputOperand == nullptr || outputOperand == nullptr)
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const Shape inputOperandShape = GetOperandShape(*inputOperand);
    const Shape outputOperandShape = GetOperandShape(*outputOperand);

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0", __func__);
    }

    armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand);
    if (IsDynamicTensor(outInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported", __func__);
    }

    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsRankSupported,
                               data.m_Backends,
                               isSupported,
                               input.GetTensorInfo(),
                               outInfo);
    if (!isSupported)
    {
        return false;
    }

    armnn::IConnectableLayer* layer = data.m_Network->AddRankLayer();
    assert(layer != nullptr);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, &outInfo);
}

} // armnn_driver namespace