IVGCVSW-5829 Segfault in TfLiteDelegate
* Updated Split function to read correct axis data.
* Improved validation in Split and SplitV functions.
* Moved ComputeWrappedIndex function to DelegateUtils.hpp.
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: I8c7d0c9b747d1ab548df98da930d838c2f57659e
diff --git a/delegate/src/Split.hpp b/delegate/src/Split.hpp
index 8248be9..ad55e53 100644
--- a/delegate/src/Split.hpp
+++ b/delegate/src/Split.hpp
@@ -47,7 +47,20 @@
ARMNN_ASSERT(GetTensorInfoForTfLiteTensor(tfLiteAxisTensor).GetNumElements() == 1);
auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
- const unsigned int splitDim = axisTensorData[0];
+ int32_t axis = axisTensorData[0];
+
+ auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
+ if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
+ {
+ // Square bracket denotes inclusive n while parenthesis denotes exclusive n
+ // E.g. Rank 4 tensor can have axis in range [-4, 4)
+ // -1 == 3, -2 == 2, -3 == 1, -4 == 0
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
+ axis, nodeIndex);
+ }
+ const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
std::vector<armnn::TensorInfo> outputs;
for (unsigned int i = 0; i < numSplits; ++i)
@@ -171,19 +184,17 @@
auto* axisTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteAxisTensor);
std::vector<int32_t> axisTensorData(axisTensorDataPtr, axisTensorDataPtr + 1);
+ int32_t axis = axisTensorData[0];
- auto ComputeWrappedIndex = [](int index, unsigned int numDimensions)
+ auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
+ if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
{
- int numDims = armnn::numeric_cast<int>(numDimensions);
- int wrappedIndex = index < 0 ? numDims + index : index;
- ARMNN_ASSERT(wrappedIndex >= 0);
- ARMNN_ASSERT(wrappedIndex < numDims);
-
- return static_cast<unsigned int>(wrappedIndex);
- };
-
- const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0],
- inputTensorInfo.GetNumDimensions());
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Operation has invalid axis: #%d. Axis must be in range [-n, n) in node #%d:",
+ axis, nodeIndex);
+ }
+ const unsigned int splitDim = ComputeWrappedIndex(axisTensorData[0], inputTensorInfo.GetNumDimensions());
auto* splitVParameters = reinterpret_cast<TfLiteSplitVParams*>(tfLiteNode->builtin_data);
unsigned int numSplits = 0;