//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ConversionUtils_1_2.hpp"

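// half_float::half (from the third-party "half" library) is used here to read HAL FLOAT16 scalar operands.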
using Half = half_float::half;

namespace armnn_driver
{

using namespace armnn;
using namespace android::nn;

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertElu(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperandType = typename HalPolicy::OperandType;

    LayerInputHandle input0 = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input0.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    // Determine the data type of the input tensor
    HalOperandType inputType;
    if (!GetOperandType<HalPolicy>(operation, 0, model, inputType))
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

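    // ELU computes f(x) = x for x >= 0 and f(x) = alpha * (exp(x) - 1) for x < 0; the alpha
    // scalar is read from operation input 1 below and stored in the descriptor's m_A field.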
    ActivationDescriptor desc;
    desc.m_Function = ActivationFunction::Elu;

    // Read alpha
    if (inputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half alpha;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, alpha, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT16)", __func__);
        }

        desc.m_A = static_cast<float>(alpha);
    }
    else if (inputType == HalOperandType::TENSOR_FLOAT32)
    {
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, desc.m_A, model, data))
        {
            return Fail("%s: Operation has invalid inputs (FLOAT32)", __func__);
        }
    }
    else
    {
        return Fail("%s: Unsupported input tensor type: %d", __func__, inputType);
    }

    return ::ConvertToActivation<HalPolicy>(operation, __func__, desc, model, data);
}

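// A minimal sketch of how a HalPolicy's ConvertOperation switch might dispatch to the
// converters in this file (illustrative only; the exact case labels live in each HalPolicy
// and are an assumption here):
//
//     case HalPolicy::OperationType::ELU:
//         return ConvertElu<HalPolicy>(operation, model, data);
//     case HalPolicy::OperationType::FILL:
//         return ConvertFill<HalPolicy>(operation, model, data);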
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertFill(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

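    // ANEURALNETWORKS_FILL takes the desired output shape as input 0 (a 1-D tensor) and the
    // scalar fill value as input 1; the output is a tensor of that shape filled with the value.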
    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!output)
    {
        return Fail("%s: Could not read output", __func__);
    }

    const TensorInfo& inputInfo = input.GetTensorInfo();
    const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
    if (IsDynamicTensor(outputInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported", __func__);
    }

    // Determine the data type of the output tensor
    HalOperandType outputType = output->type;
    FillDescriptor descriptor;
    // Read the scalar fill value
    if (outputType == HalOperandType::TENSOR_FLOAT16)
    {
        Half value;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT16, value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }

        descriptor.m_Value = static_cast<float>(value);
    }
    else if (outputType == HalOperandType::TENSOR_FLOAT32)
    {
        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::FLOAT32, descriptor.m_Value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }
    }
    else if (outputType == HalOperandType::TENSOR_INT32)
    {
        int32_t value;

        if (!GetInputScalar<HalPolicy>(operation, 1, HalOperandType::INT32, value, model, data))
        {
            return Fail("%s: Operation has invalid inputs %d", __func__, outputType);
        }

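        // Note: FillDescriptor stores the fill value as a float, so INT32 values with a
        // magnitude above 2^24 cannot be represented exactly after this cast.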
        descriptor.m_Value = static_cast<float>(value);
    }
    else
    {
        return Fail("%s: Unsupported input tensor type: %d", __func__, outputType);
    }

    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsFillSupported,
                               data.m_Backends,
                               isSupported,
                               inputInfo,
                               outputInfo,
                               descriptor);
    if (!isSupported)
    {
        return false;
    }

    IConnectableLayer* const layer = data.m_Network->AddFillLayer(descriptor);
    assert(layer != nullptr);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
}

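// ANEURALNETWORKS_QLSTM, added in NNAPI 1.3, is a quantized LSTM cell with 8-bit activations
// and a 16-bit cell state; the converter below maps its 32 inputs (0-31) and 3 outputs onto
// ArmNN's QLstm layer.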
template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;
    using HalOperandType = typename HalPolicy::OperandType;

    ALOGV("HalPolicy::ConvertQuantizedLstm()");

    // Inputs:
    // 0: The input: A 2-D tensor of type ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED and shape
    //    [numBatches, inputSize] specifying the input to the LSTM cell. Tensor is quantized with
    //    a fixed quantization range of [-1, 127/128].
    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0: input", __func__);
    }

    // 18: The output state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape
    //     [batch_size, output_size].
    LayerInputHandle outputStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 18, model, data);
    if (!outputStatePrevTimeStep.IsValid())
    {
        return Fail("%s: Could not read input 18: outputStatePrevTimeStep", __func__);
    }

    // 19: The cell state: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape [batch_size, num_units].
    LayerInputHandle cellStatePrevTimeStep = ConvertToLayerInputHandle<HalPolicy>(operation, 19, model, data);
    if (!cellStatePrevTimeStep.IsValid())
    {
        return Fail("%s: Could not read input 19: cellStatePrevTimeStep", __func__);
    }

    // Get the mandatory input tensors:

    // 02: The input-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 2, model, data);

    // 03: The input-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToCellWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 3, model, data);

    // 04: The input-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size].
    const ConstTensorPin inputToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 4, model, data);

    // 06: The recurrent-to-forget weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 6, model, data);

    // 07: The recurrent-to-cell weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToCellWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 7, model, data);

    // 08: The recurrent-to-output weights: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size].
    const ConstTensorPin recurrentToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 8, model, data);

    // 13: The forget gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin forgetGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 13, model, data);

    // 14: The cell bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin cellBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 14, model, data);

    // 15: The output gate bias: A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin outputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 15, model, data);

    if (!inputToForgetWeightsPin.IsValid() ||
        !inputToCellWeightsPin.IsValid() ||
        !inputToOutputWeightsPin.IsValid() ||
        !recurrentToForgetWeightsPin.IsValid() ||
        !recurrentToCellWeightsPin.IsValid() ||
        !recurrentToOutputWeightsPin.IsValid() ||
        !forgetGateBiasPin.IsValid() ||
        !cellBiasPin.IsValid() ||
        !outputGateBiasPin.IsValid())
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input tensors:

    // 01: The input-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, input_size], where “num_units” corresponds to the number of cell units.
    const ConstTensorPin inputToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         1,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 05: The recurrent-to-input weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [num_units, output_size], where “output_size” corresponds to either the number of cell units (i.e.
    //     “num_units”) or the second dimension of the “projection_weights”, if defined.
    const ConstTensorPin recurrentToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         5,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 09: The cell-to-input weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToInputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         9,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 10: The cell-to-forget weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToForgetWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         10,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 11: The cell-to-output weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //     [num_units].
    const ConstTensorPin cellToOutputWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         11,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 12: The input gate bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [num_units].
    const ConstTensorPin inputGateBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         12,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 16: The projection weights: Optional. A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_SYMM, of shape
    //     [output_size, num_units].
    const ConstTensorPin projectionWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         16,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 17: The projection bias: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_INT32, of shape [output_size].
    const ConstTensorPin projectionBiasPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         17,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputToInputWeightsPin.IsValid() && !inputToInputWeightsPin.IsOptional())
        || (!recurrentToInputWeightsPin.IsValid() && !recurrentToInputWeightsPin.IsOptional())
        || (!cellToInputWeightsPin.IsValid() && !cellToInputWeightsPin.IsOptional())
        || (!cellToForgetWeightsPin.IsValid() && !cellToForgetWeightsPin.IsOptional())
        || (!cellToOutputWeightsPin.IsValid() && !cellToOutputWeightsPin.IsOptional())
        || (!inputGateBiasPin.IsValid() && !inputGateBiasPin.IsOptional())
        || (!projectionWeightsPin.IsValid() && !projectionWeightsPin.IsOptional())
        || (!projectionBiasPin.IsValid() && !projectionBiasPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional normalization tensors:

    // 20: The input layer normalization weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM,
    //     of shape [num_units]. Used to rescale normalized inputs to the activation at the input gate.
    const ConstTensorPin inputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         20,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 21: The forget layer normalization weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM,
    //     of shape [num_units]. Used to rescale normalized inputs to the activation at the forget gate.
    const ConstTensorPin forgetLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         21,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 22: The cell layer normalization weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM,
    //     of shape [num_units]. Used to rescale normalized inputs to the activation at the cell gate.
    const ConstTensorPin cellLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         22,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    // 23: The output layer normalization weights: Optional. A 1-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM,
    //     of shape [num_units]. Used to rescale normalized inputs to the activation at the output gate.
    const ConstTensorPin outputLayerNormWeightsPin =
        ConvertOperationInputToConstTensorPin<HalPolicy>(operation,
                                                         23,
                                                         model,
                                                         data,
                                                         g_DontPermute,
                                                         nullptr,
                                                         true);

    if ((!inputLayerNormWeightsPin.IsValid() && !inputLayerNormWeightsPin.IsOptional())
        || (!forgetLayerNormWeightsPin.IsValid() && !forgetLayerNormWeightsPin.IsOptional())
        || (!cellLayerNormWeightsPin.IsValid() && !cellLayerNormWeightsPin.IsOptional())
        || (!outputLayerNormWeightsPin.IsValid() && !outputLayerNormWeightsPin.IsOptional()))
    {
        return Fail("%s: Operation has invalid tensor inputs", __func__);
    }

    // Get the optional input scalars:
    // 24: The cell clip: If provided, the cell state is clipped by this value prior to the cell output activation.
    // 25: The projection clip: If provided and projection is enabled, this is used for clipping the projected values.

    // Get the mandatory input scalars:
    // 26: The scale of the intermediate result of matmul, i.e. input to layer normalization, at the input gate.
    // 27: The scale of the intermediate result of matmul, i.e. input to layer normalization, at the forget gate.
    // 28: The scale of the intermediate result of matmul, i.e. input to layer normalization, at the cell gate.
    // 29: The scale of the intermediate result of matmul, i.e. input to layer normalization, at the output gate.
    // 30: The zero point of the hidden state, i.e. input to projection.
    // 31: The scale of the hidden state, i.e. input to projection.
    float cellClip, projClip, matMulInputGate, matMulForgetGate, matMulCellGate, matMulOutputGate, projInputScale;
    int projInputZeroPoint;

    if (!GetInputScalar<HalPolicy>(operation, 24, HalOperandType::FLOAT32, cellClip, model, data, true) ||
        !GetInputScalar<HalPolicy>(operation, 25, HalOperandType::FLOAT32, projClip, model, data, true) ||
        !GetInputScalar<HalPolicy>(operation, 26, HalOperandType::FLOAT32, matMulInputGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 27, HalOperandType::FLOAT32, matMulForgetGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 28, HalOperandType::FLOAT32, matMulCellGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 29, HalOperandType::FLOAT32, matMulOutputGate, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 30, HalOperandType::INT32, projInputZeroPoint, model, data) ||
        !GetInputScalar<HalPolicy>(operation, 31, HalOperandType::FLOAT32, projInputScale, model, data))
    {
        return Fail("%s: Operation has invalid scalar inputs", __func__);
    }

    // Outputs:
    // 0: The output state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape
    //    [batch_size, output_size].
    const HalOperand* outputStateOut = GetOutputOperand<HalPolicy>(operation, 0, model);
    if (!outputStateOut)
    {
        return Fail("%s: Could not read output 0: outputStateOut", __func__);
    }

    // 1: The cell state (out): A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT16_SYMM, of shape
    //    [batch_size, num_units].
    const HalOperand* cellStateOut = GetOutputOperand<HalPolicy>(operation, 1, model);
    if (!cellStateOut)
    {
        return Fail("%s: Could not read output 1: cellStateOut", __func__);
    }

    // 2: The output: A 2-D tensor of ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED, of shape
    //    [batch_size, output_size]. This is effectively the same as the current “output state (out)” value.
    const HalOperand* output = GetOutputOperand<HalPolicy>(operation, 2, model);
    if (!output)
    {
        return Fail("%s: Could not read output 2: output", __func__);
    }

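    // These three operands map onto the QLstm layer's output slots 0, 1 and 2 respectively;
    // they are connected via SetupAndTrackLayerOutputSlot at the end of this function.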
    // Set the params structure for the AddQLstmLayer call
    LstmInputParams params;
    params.m_InputToInputWeights = inputToInputWeightsPin.GetConstTensorPtr();
    params.m_InputToForgetWeights = inputToForgetWeightsPin.GetConstTensorPtr();
    params.m_InputToCellWeights = inputToCellWeightsPin.GetConstTensorPtr();
    params.m_InputToOutputWeights = inputToOutputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToInputWeights = recurrentToInputWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToForgetWeights = recurrentToForgetWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToCellWeights = recurrentToCellWeightsPin.GetConstTensorPtr();
    params.m_RecurrentToOutputWeights = recurrentToOutputWeightsPin.GetConstTensorPtr();
    params.m_CellToInputWeights = cellToInputWeightsPin.GetConstTensorPtr();
    params.m_CellToForgetWeights = cellToForgetWeightsPin.GetConstTensorPtr();
    params.m_CellToOutputWeights = cellToOutputWeightsPin.GetConstTensorPtr();
    params.m_InputGateBias = inputGateBiasPin.GetConstTensorPtr();
    params.m_ForgetGateBias = forgetGateBiasPin.GetConstTensorPtr();
    params.m_CellBias = cellBiasPin.GetConstTensorPtr();
    params.m_OutputGateBias = outputGateBiasPin.GetConstTensorPtr();
    params.m_ProjectionWeights = projectionWeightsPin.GetConstTensorPtr();
    params.m_ProjectionBias = projectionBiasPin.GetConstTensorPtr();
    params.m_InputLayerNormWeights = inputLayerNormWeightsPin.GetConstTensorPtr();
    params.m_ForgetLayerNormWeights = forgetLayerNormWeightsPin.GetConstTensorPtr();
    params.m_CellLayerNormWeights = cellLayerNormWeightsPin.GetConstTensorPtr();
    params.m_OutputLayerNormWeights = outputLayerNormWeightsPin.GetConstTensorPtr();

    // Set the layer descriptor
    QLstmDescriptor desc;
    desc.m_CellClip = cellClip;
    desc.m_ProjectionClip = projClip;
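    // CIFG ("Coupled Input and Forget Gate") replaces the input gate with (1 - forget gate),
    // so it is considered enabled whenever any of the input-gate tensors is omitted.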
    desc.m_CifgEnabled = (params.m_InputToInputWeights == nullptr ||
                          params.m_RecurrentToInputWeights == nullptr ||
                          params.m_InputGateBias == nullptr);
    desc.m_PeepholeEnabled = (params.m_CellToForgetWeights != nullptr ||
                              params.m_CellToOutputWeights != nullptr);
    desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
    desc.m_LayerNormEnabled = (params.m_InputLayerNormWeights != nullptr ||
                               params.m_ForgetLayerNormWeights != nullptr ||
                               params.m_CellLayerNormWeights != nullptr ||
                               params.m_OutputLayerNormWeights != nullptr);
    desc.m_InputIntermediateScale = matMulInputGate;
    desc.m_ForgetIntermediateScale = matMulForgetGate;
    desc.m_CellIntermediateScale = matMulCellGate;
    desc.m_OutputIntermediateScale = matMulOutputGate;
    desc.m_HiddenStateScale = projInputScale;
    desc.m_HiddenStateZeroPoint = projInputZeroPoint;

    // Validate the optional input groups
    if (desc.m_CifgEnabled &&
        (params.m_InputToInputWeights != nullptr ||
         params.m_RecurrentToInputWeights != nullptr ||
         params.m_InputGateBias != nullptr))
    {
        return Fail("%s: All, or none, of input-to-input weights, recurrent-to-input weights,"
                    " and input gate bias must be provided", __func__);
    }

    if (!desc.m_ProjectionEnabled && params.m_ProjectionBias != nullptr)
    {
        return Fail("%s: projection bias should not be provided without projection weights", __func__);
    }

    if (desc.m_PeepholeEnabled &&
        (params.m_CellToForgetWeights == nullptr ||
         params.m_CellToOutputWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_CellToInputWeights == nullptr)))
    {
        return Fail("%s: All, or none, of cell-to-forget weights and cell-to-output weights must be provided"
                    " and, if CIFG is not enabled, cell-to-input weights must also be provided", __func__);
    }

    if (desc.m_LayerNormEnabled &&
        (params.m_ForgetLayerNormWeights == nullptr ||
         params.m_CellLayerNormWeights == nullptr ||
         params.m_OutputLayerNormWeights == nullptr ||
         (!desc.m_CifgEnabled && params.m_InputLayerNormWeights == nullptr)))
    {
        return Fail("%s: All, or none, of forget-norm weights, cell-norm weights and output-norm weights must be"
                    " provided and, if CIFG is not enabled, input-norm weights must also be provided", __func__);
    }

    // Basic parameters
    LstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
    paramsInfo.m_InputToCellWeights = &(params.m_InputToCellWeights->GetInfo());
    paramsInfo.m_InputToOutputWeights = &(params.m_InputToOutputWeights->GetInfo());
    paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
    paramsInfo.m_RecurrentToCellWeights = &(params.m_RecurrentToCellWeights->GetInfo());
    paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
    paramsInfo.m_ForgetGateBias = &(params.m_ForgetGateBias->GetInfo());
    paramsInfo.m_CellBias = &(params.m_CellBias->GetInfo());
    paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo());

    // Inputs
    const TensorInfo& inputInfo = input.GetTensorInfo();
    const TensorInfo& outputStatePrevTimeStepInfo = outputStatePrevTimeStep.GetTensorInfo();
    const TensorInfo& cellStatePrevTimeStepInfo = cellStatePrevTimeStep.GetTensorInfo();

    // Outputs
    TensorInfo outputStateOutInfo = GetTensorInfoForOperand(*outputStateOut);
    TensorInfo outputInfo = GetTensorInfoForOperand(*output);
    const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);

    // Optional parameters
    if (!desc.m_CifgEnabled)
    {
        paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
        paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
        if (desc.m_PeepholeEnabled)
        {
            paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
        }
        paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
    }

    if (desc.m_ProjectionEnabled)
    {
        paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
        if (params.m_ProjectionBias != nullptr)
        {
            paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
        }
    }
    else
    {
        // If projection is disabled, the hidden state feeds the outputs directly, so override the
        // (non-const) output infos with the hidden state quantization parameters, then create new
        // const TensorInfos based on them.
        outputStateOutInfo.SetQuantizationScale(projInputScale);
        outputStateOutInfo.SetQuantizationOffset(projInputZeroPoint);
        outputInfo.SetQuantizationScale(projInputScale);
        outputInfo.SetQuantizationOffset(projInputZeroPoint);
    }

    const TensorInfo constOutputStateOutInfo(outputStateOutInfo);
    const TensorInfo constOutputInfo(outputInfo);

    if (desc.m_PeepholeEnabled)
    {
        paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
        paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
    }

    if (desc.m_LayerNormEnabled)
    {
        if (!desc.m_CifgEnabled)
        {
            paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
        }
        paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
        paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
        paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
    }

    // Check if the layer is supported
    if (IsDynamicTensor(constOutputStateOutInfo) ||
        IsDynamicTensor(cellStateOutInfo) ||
        IsDynamicTensor(constOutputInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported %d %d %d", __func__,
                    IsDynamicTensor(constOutputStateOutInfo), IsDynamicTensor(cellStateOutInfo),
                    IsDynamicTensor(constOutputInfo));
    }

    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsQLstmSupported,
                               data.m_Backends,
                               isSupported,
                               inputInfo,
                               outputStatePrevTimeStepInfo,
                               cellStatePrevTimeStepInfo,
                               constOutputStateOutInfo,
                               cellStateOutInfo,
                               constOutputInfo,
                               desc,
                               paramsInfo);
    if (!isSupported)
    {
        return false;
    }

    // Add the layer
    IConnectableLayer* layer = data.m_Network->AddQLstmLayer(desc, params, "QLstm");

    input.Connect(layer->GetInputSlot(0));
    outputStatePrevTimeStep.Connect(layer->GetInputSlot(1));
    cellStatePrevTimeStep.Connect(layer->GetInputSlot(2));

    return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data,
                                                    &constOutputStateOutInfo) &&
            SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) &&
            SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo));
}

template<typename HalPolicy,
         typename HalOperation = typename HalPolicy::Operation,
         typename HalModel = typename HalPolicy::Model>
bool ConvertRank(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
    using HalOperand = typename HalPolicy::Operand;

    const HalOperand* inputOperand = GetInputOperand<HalPolicy>(operation, 0, model);
    const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, 0, model);

    if (inputOperand == nullptr || outputOperand == nullptr)
    {
        return Fail("%s: Operation has invalid inputs", __func__);
    }

    const Shape inputOperandShape = GetOperandShape(*inputOperand);
    const Shape outputOperandShape = GetOperandShape(*outputOperand);

    LayerInputHandle input = ConvertToLayerInputHandle<HalPolicy>(operation, 0, model, data);
    if (!input.IsValid())
    {
        return Fail("%s: Could not read input 0", __func__);
    }

    armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand);
    if (IsDynamicTensor(outInfo))
    {
        return Fail("%s: Dynamic output tensors are not supported", __func__);
    }

    bool isSupported = false;
    FORWARD_LAYER_SUPPORT_FUNC(__func__,
                               IsRankSupported,
                               data.m_Backends,
                               isSupported,
                               input.GetTensorInfo(),
                               outInfo);
    if (!isSupported)
    {
        return false;
    }

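    // RANK produces a single scalar holding the number of dimensions of the input tensor,
    // e.g. 4 for a [batch, height, width, channels] input.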
    armnn::IConnectableLayer* layer = data.m_Network->AddRankLayer();
    assert(layer != nullptr);
    input.Connect(layer->GetInputSlot(0));

    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, &outInfo);
}

} // armnn_driver namespace