GitHub 653: Segfault when parsing Unidirectional Sequence LSTM
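* Fixed segfault when parsing a Unidirectional Sequence LSTM operator that has fewer intermediate tensors than the parser assumed: the layer-norm intermediates and the hidden-state intermediate (index 4) are now only read when the operator actually provides them.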
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ic69a4190c60ef595be64bc2c356e540319381b7e
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 49f1f9f..479fc4f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -3346,7 +3346,7 @@
|| params.m_OutputLayerNormWeights != nullptr);
desc.m_TimeMajor = nodeParams->time_major;
- if (desc.m_LayerNormEnabled)
+ if (operatorPtr->intermediates.size() > 3 && desc.m_LayerNormEnabled)
{
auto inputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[0]].get(),
inputTensorInfo).first;
@@ -3377,12 +3377,14 @@
desc.m_OutputIntermediateScale = defaultIntermediate;
}
- auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(),
- inputTensorInfo).first;
+ if (operatorPtr->intermediates.size() > 4)
+ {
+ auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(),
+ inputTensorInfo).first;
- desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale();
- desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset();
-
+ desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale();
+ desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset();
+ }
unsigned int batchSize = inputTensorInfo.GetShape()[0];
unsigned int outputSize = outputTensorInfo.GetShape()[2];
unsigned int numUnits = cellStateInInfo.GetShape()[1];