IVGCVSW-6417: Catch AddFullyConnected API error when weights TensorInfo isn't set
* Updated code in Graph.cpp InferTensorInfos() to be more descriptive.
* Added method VerifyConstantLayerSetTensorInfo() in Graph.cpp/hpp
to throw a LayerValidationException when a ConstantLayer's output TensorInfo is not set.
* Updated Optimize() in Network.cpp to call VerifyConstantLayerSetTensorInfo().
* Added unit test with ConstantLayer TensorInfo not
set, checking that Optimize() throws a LayerValidationException.
* Added comments around method VerifyConstantLayerSetTensorInfo().
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I366596243f7c5823676222e2d0cce1335bc8c325
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp
index 7b6f56f..60bf328 100644
--- a/src/armnn/Graph.cpp
+++ b/src/armnn/Graph.cpp
@@ -526,6 +526,33 @@
subgraph.Clear();
}
+/// For each ConstantLayer in Graph, ensures TensorInfo is set on all output slots.
+/// LayerValidationException thrown if no TensorInfo is set.
+///
+/// @throws LayerValidationException
+void Graph::VerifyConstantLayerSetTensorInfo() const
+{
+ for (auto&& layer : TopologicalSort())
+ {
+ if(layer->GetType() == armnn::LayerType::Constant)
+ {
+ for (auto&& output: layer->GetOutputSlots())
+ {
+ if (!output.IsTensorInfoSet())
+ {
+ std::ostringstream message;
+ message << "Output slot TensorInfo not set on "
+ << GetLayerTypeAsCString(layer->GetType())
+ << " layer \""
+ << layer->GetName()
+ << "\"";
+ throw LayerValidationException(message.str());
+ }
+ }
+ }
+ }
+}
+
void Graph::InferTensorInfos()
{
for (auto&& layer : TopologicalSort())
@@ -536,7 +563,9 @@
if (source == NULL)
{
std::ostringstream message;
- message << "Input not connected on "
+ message << "Input slot "
+ << input.GetSlotIndex()
+ << " not connected to an output slot on "
<< GetLayerTypeAsCString(layer->GetType())
<< " layer \""
<< layer->GetName()
@@ -546,13 +575,19 @@
if (!source->IsTensorInfoSet())
{
- throw LayerValidationException("All inputs must have the TensorInfo set at this point.");
+ std::ostringstream message;
+ message << "Output slot TensorInfo not set on "
+ << GetLayerTypeAsCString(layer->GetType())
+ << " layer \""
+ << layer->GetName()
+ << "\"";
+ throw LayerValidationException(message.str());
}
+ }
- if (layer->m_ShapeInferenceMethod == ShapeInferenceMethod::ValidateOnly)
- {
- layer->ValidateTensorShapesFromInputs();
- }
+ if (layer->m_ShapeInferenceMethod == ShapeInferenceMethod::ValidateOnly)
+ {
+ layer->ValidateTensorShapesFromInputs();
}
}
}
diff --git a/src/armnn/Graph.hpp b/src/armnn/Graph.hpp
index 731ae1e..d5fbeaf 100644
--- a/src/armnn/Graph.hpp
+++ b/src/armnn/Graph.hpp
@@ -203,6 +203,10 @@
void SubstituteSubgraph(SubgraphView& subgraph, IConnectableLayer* substituteLayer);
void SubstituteSubgraph(SubgraphView& subgraph, const SubgraphView& substituteSubgraph);
+ /// For each ConstantLayer in Graph, ensures TensorInfo is set on all output slots.
+ /// LayerValidationException thrown if no TensorInfo is set.
+ void VerifyConstantLayerSetTensorInfo() const;
+
void InferTensorInfos();
void AttachObservable(IGraphObservable* const observable, GraphEvent notifyOnEvent) {
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index a39b6b1..39af10f 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1576,6 +1576,9 @@
throw InvalidArgumentException("BFloat16 and Float16 optimization cannot be enabled at the same time.");
}
+ // Ensure TensorInfo is set on all output slots of ConstantLayers in the graph
+ inNetwork.pNetworkImpl->GetGraph().VerifyConstantLayerSetTensorInfo();
+
std::unique_ptr<Graph> graph = std::make_unique<Graph>(inNetwork.pNetworkImpl->GetGraph());
auto optNet = IOptimizedNetworkPtr(new IOptimizedNetwork(std::move(graph), options.m_ModelOptions),
diff --git a/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp
index af6b568..7345ff5 100644
--- a/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp
+++ b/src/backends/backendsCommon/test/FullyConnectedEndToEndTestImpl.hpp
@@ -84,6 +84,25 @@
return network;
}
+armnn::INetworkPtr CreateFullyConnectedNetworkNoTensorInfoConstWeights(const armnn::TensorInfo& inputTensorInfo,
+ const armnn::TensorInfo& outputTensorInfo,
+ const armnn::ConstTensor& weightsConstantTensor,
+ armnn::FullyConnectedDescriptor descriptor)
+{
+ armnn::INetworkPtr network(armnn::INetwork::Create());
+
+ armnn::IConnectableLayer* inputLayer = network->AddInputLayer(0, "Input");
+ armnn::IConnectableLayer* weightsLayer = network->AddConstantLayer(weightsConstantTensor, "Weights");
+ armnn::IConnectableLayer* fullyConnectedLayer = network->AddFullyConnectedLayer(descriptor, "Fully_Connected");
+ armnn::IConnectableLayer* outputLayer = network->AddOutputLayer(0, "Output");
+
+ Connect(inputLayer, fullyConnectedLayer, inputTensorInfo, 0, 0);
+ weightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
+ Connect(fullyConnectedLayer, outputLayer, outputTensorInfo, 0, 0);
+
+ return network;
+}
+
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
void FullyConnectedWithDynamicWeightsEndToEnd(const std::vector<armnn::BackendId>& backends)
{
@@ -141,7 +160,8 @@
template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
void FullyConnectedWithDynamicOrConstantInputsEndToEnd(const std::vector<armnn::BackendId>& backends,
const bool transposeWeights,
- const bool constantWeightsOrBias)
+ const bool constantWeightsOrBias,
+ const bool tensorInfoSet)
{
unsigned int inputWidth = 1;
unsigned int inputHeight = 1;
@@ -210,7 +230,24 @@
descriptor.m_TransposeWeightMatrix = transposeWeights;
descriptor.m_ConstantWeights = constantWeightsOrBias;
- if (!constantWeightsOrBias)
+ if(!tensorInfoSet)
+ {
+ // Tests that Optimize() throws when the constant weights layer has no TensorInfo set.
+ ConstTensor weightsConstantTensor(weightsDesc, weights.data());
+
+ armnn::INetworkPtr network = CreateFullyConnectedNetworkNoTensorInfoConstWeights(inputTensorInfo,
+ outputTensorInfo,
+ weightsConstantTensor,
+ descriptor);
+ CHECK(network);
+
+ // Create runtime in which test will run
+ IRuntime::CreationOptions options;
+ IRuntimePtr runtime(IRuntime::Create(options));
+
+ CHECK_THROWS_AS( Optimize(*network, backends, runtime->GetDeviceSpec()), LayerValidationException );
+ }
+ else if (!constantWeightsOrBias)
{
// Tests non constant weights and constant bias.
ConstTensor biasConstantTensor(biasesDesc, biasValues.data());
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index ed4b229..6c11a75 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -618,12 +618,17 @@
TEST_CASE("RefFullyConnectedEndToEndTestNonConstantWeightsConstantBiasesFloat32")
{
- FullyConnectedWithDynamicOrConstantInputsEndToEnd<armnn::DataType::Float32>(defaultBackends, true, true);
+ FullyConnectedWithDynamicOrConstantInputsEndToEnd<armnn::DataType::Float32>(defaultBackends, true, true, true);
}
TEST_CASE("RefFullyConnectedEndToEndTestConstantWeightsNonConstantBiasesFloat32")
{
- FullyConnectedWithDynamicOrConstantInputsEndToEnd<armnn::DataType::Float32>(defaultBackends, true, false);
+ FullyConnectedWithDynamicOrConstantInputsEndToEnd<armnn::DataType::Float32>(defaultBackends, true, false, true);
+}
+
+TEST_CASE("RefFullyConnectedEndToEndTestConstantWeightsTensorInfoNotSet")
+{
+ FullyConnectedWithDynamicOrConstantInputsEndToEnd<armnn::DataType::Float32>(defaultBackends, true, false, false);
}
TEST_CASE("RefGatherFloatTest")