IVGCVSW-7441 Check for constant input tensors before populating them.
* When the TfLiteExecutor attempted to populate the input tensors, it did
not check whether each tensor was constant. Writing input data into a
constant tensor was causing segmentation faults.
Signed-off-by: Colm Donelan <colm.donelan@arm.com>
Change-Id: I80a4cc788de4ffe08afb2df9185d04fcb8b27c3a
diff --git a/tests/ExecuteNetwork/TfliteExecutor.cpp b/tests/ExecuteNetwork/TfliteExecutor.cpp
index 810495f..3c8313b 100644
--- a/tests/ExecuteNetwork/TfliteExecutor.cpp
+++ b/tests/ExecuteNetwork/TfliteExecutor.cpp
@@ -4,6 +4,7 @@
//
#include "TfliteExecutor.hpp"
+#include "tensorflow/lite/kernels/kernel_util.h"
TfLiteExecutor::TfLiteExecutor(const ExecuteNetworkParams& params) : m_Params(params)
{
@@ -51,54 +52,62 @@
: armnn::MakeOptional<std::string>(m_Params.m_InputTensorDataFilePaths[inputIndex]);
int input = m_TfLiteInterpreter->inputs()[inputIndex];
-
- TfLiteIntArray* inputDims = m_TfLiteInterpreter->tensor(input)->dims;
-
- unsigned int inputSize = 1;
- for (unsigned int dim = 0; dim < static_cast<unsigned int>(inputDims->size); ++dim)
- {
- inputSize *= inputDims->data[dim];
- }
-
const auto& inputName = m_TfLiteInterpreter->tensor(input)->name;
- const auto& dataType = m_TfLiteInterpreter->tensor(input)->type;
- switch (dataType)
+ // Before we start, check if the tensor is constant.
+ if (!tflite::IsConstantTensor(m_TfLiteInterpreter->tensor(input)))
{
- case kTfLiteFloat32:
+ TfLiteIntArray* inputDims = m_TfLiteInterpreter->tensor(input)->dims;
+
+ unsigned int inputSize = 1;
+ for (unsigned int dim = 0; dim < static_cast<unsigned int>(inputDims->size); ++dim)
{
- auto inputData = m_TfLiteInterpreter->typed_tensor<float>(input);
- PopulateTensorWithData<float>(inputData, inputSize, dataFile, inputName);
- break;
+ inputSize *= inputDims->data[dim];
}
- case kTfLiteInt32:
+
+ const auto& dataType = m_TfLiteInterpreter->tensor(input)->type;
+
+ switch (dataType)
{
- auto inputData = m_TfLiteInterpreter->typed_tensor<int32_t>(input);
- PopulateTensorWithData<int32_t>(inputData, inputSize, dataFile, inputName);
- break;
+ case kTfLiteFloat32:
+ {
+ auto inputData = m_TfLiteInterpreter->typed_tensor<float>(input);
+ PopulateTensorWithData<float>(inputData, inputSize, dataFile, inputName);
+ break;
+ }
+ case kTfLiteInt32:
+ {
+ auto inputData = m_TfLiteInterpreter->typed_tensor<int32_t>(input);
+ PopulateTensorWithData<int32_t>(inputData, inputSize, dataFile, inputName);
+ break;
+ }
+ case kTfLiteUInt8:
+ {
+ auto inputData = m_TfLiteInterpreter->typed_tensor<uint8_t>(input);
+ PopulateTensorWithData<uint8_t>(inputData, inputSize, dataFile, inputName);
+ break;
+ }
+ case kTfLiteInt16:
+ {
+ auto inputData = m_TfLiteInterpreter->typed_tensor<int16_t>(input);
+ PopulateTensorWithData<int16_t>(inputData, inputSize, dataFile, inputName);
+ break;
+ }
+ case kTfLiteInt8:
+ {
+ auto inputData = m_TfLiteInterpreter->typed_tensor<int8_t>(input);
+ PopulateTensorWithData<int8_t>(inputData, inputSize, dataFile, inputName);
+ break;
+ }
+ default:
+ {
+ LogAndThrow("Unsupported input tensor data type");
+ }
}
- case kTfLiteUInt8:
- {
- auto inputData = m_TfLiteInterpreter->typed_tensor<uint8_t>(input);
- PopulateTensorWithData<uint8_t>(inputData, inputSize, dataFile, inputName);
- break;
- }
- case kTfLiteInt16:
- {
- auto inputData = m_TfLiteInterpreter->typed_tensor<int16_t>(input);
- PopulateTensorWithData<int16_t>(inputData, inputSize, dataFile, inputName);
- break;
- }
- case kTfLiteInt8:
- {
- auto inputData = m_TfLiteInterpreter->typed_tensor<int8_t>(input);
- PopulateTensorWithData<int8_t>(inputData, inputSize, dataFile, inputName);
- break;
- }
- default:
- {
- LogAndThrow("Unsupported input tensor data type");
- }
+ }
+ else
+ {
+ ARMNN_LOG(info) << "Input tensor \"" << inputName << "\" is constant and will not be populated with data.";
}
}
}