IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported
* outputStateOut TensorInfo is not optional.
* cellStateOut TensorInfo is not optional.
* TensorInfo order now matches the other QLSTM/LSTM layers.
* Added missing parameters to UnidirectionalSequenceLstmOperator for the
delegate.
* Added quantized UnidirectionalSequenceLstm support to Neon
!android-nn-driver:7457
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
index d447a46..c4345d4 100644
--- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
@@ -59,7 +59,9 @@
TensorInfo inputInfo = GetTensorInfo(inputs[0]);
const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
- TensorInfo outputInfo = GetTensorInfo(outputs[0]);
+ TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
+ TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
+ TensorInfo outputInfo = GetTensorInfo(outputs[2]);
TensorShape& inputShape = inputInfo.GetShape();
TensorShape& outputShape= outputInfo.GetShape();
auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());
@@ -140,7 +142,7 @@
auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
- auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map());
+ auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
@@ -296,7 +298,7 @@
{
// Permute Output back to batch major
const PermutationVector& mappings = {1U, 0U, 2U};
- auto outputData = reinterpret_cast<float*>(outputs[0]->Map());
+ auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
outputInfo.SetShape(outputShape);