IVGCVSW-7749 DTS: Fix reshape floating point exception
* Updated Opaque Delegate, TfliteParser, OnnxParser, and Deserializer to handle the Zero In Shape edge case
Signed-off-by: Tianle Cheng <tianle.cheng@arm.com>
Change-Id: I4a0d1e72a66de1fa56de99af9b6730a84e0ff596
diff --git a/delegate/classic/src/Redefine.hpp b/delegate/classic/src/Redefine.hpp
index 6b10e44..c3422a2 100644
--- a/delegate/classic/src/Redefine.hpp
+++ b/delegate/classic/src/Redefine.hpp
@@ -166,6 +166,18 @@
return kTfLiteError;
}
+ // Reject a target shape that contains a zero dimension when the input tensor has elements.
+ if (std::find(targetShape.begin(), targetShape.end(), 0) != targetShape.end() &&
+ inputTensorInfo0.GetNumElements() != 0)
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
+ "TfLiteArmnnDelegate: Input to reshape is a tensor with elements, "
+ "but the requested shape has 0. "
+ "operator #%d node #%d: ",
+ operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+
// Use the data to create the required tensor shape.
if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
{
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp
index a74ed8b..a2cdc83 100644
--- a/delegate/common/src/DelegateUtils.hpp
+++ b/delegate/common/src/DelegateUtils.hpp
@@ -186,7 +186,16 @@
std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>()));
auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim));
- outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+
+ if (targetNumElements == 0)
+ {
+ // Avoid dividing by zero in the edge case where input and output both have zero elements.
+ outputDims[stretchIndex] = 0;
+ }
+ else
+ {
+ outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
+ }
}
armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()),
diff --git a/delegate/opaque/src/Redefine.hpp b/delegate/opaque/src/Redefine.hpp
index 5ce7a3d..6319ca7 100644
--- a/delegate/opaque/src/Redefine.hpp
+++ b/delegate/opaque/src/Redefine.hpp
@@ -201,6 +201,19 @@
return kTfLiteError;
}
+ // Reject a target shape that contains a zero dimension when the input tensor has elements.
+ if (std::find(targetShape.begin(), targetShape.end(), 0) != targetShape.end() &&
+ inputTensorInfo0.GetNumElements() != 0)
+ {
+ TF_LITE_OPAQUE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnOpaqueDelegate: Input to reshape is a tensor with elements, "
+ "but the requested shape has 0. "
+ "operator #%d node #%d: ",
+ operatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+
// Use the data to create the required tensor shape.
if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
{