COMPMID-1246: Fix bug in handling backends that can't be loaded in the Graph API
Change-Id: Iefd175af2f472179d86df5358a1527a79c5666ed
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/145182
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/graph/Utils.cpp b/src/graph/Utils.cpp
index 75644a8..0a85a7f 100644
--- a/src/graph/Utils.cpp
+++ b/src/graph/Utils.cpp
@@ -101,7 +101,10 @@
{
for(const auto &backend : backends::BackendRegistry::get().backends())
{
- backend.second->release_backend_context(ctx);
+ if(backend.second->is_backend_supported())
+ {
+ backend.second->release_backend_context(ctx);
+ }
}
}
@@ -109,7 +112,10 @@
{
for(const auto &backend : backends::BackendRegistry::get().backends())
{
- backend.second->setup_backend_context(ctx);
+ if(backend.second->is_backend_supported())
+ {
+ backend.second->setup_backend_context(ctx);
+ }
}
}
@@ -172,11 +178,10 @@
{
if(tensor != nullptr && tensor->handle() == nullptr)
{
- Target target = tensor->desc().target;
- auto backend = backends::BackendRegistry::get().find_backend(target);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
- auto handle = backend->create_tensor(*tensor);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Couldn't create backend handle!");
+ Target target = tensor->desc().target;
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(target);
+ std::unique_ptr<ITensorHandle> handle = backend.create_tensor(*tensor);
+ ARM_COMPUTE_ERROR_ON_MSG(!handle, "Couldn't create backend handle!");
tensor->set_handle(std::move(handle));
}
}
diff --git a/src/graph/backends/BackendRegistry.cpp b/src/graph/backends/BackendRegistry.cpp
index 2803322..dccfefc 100644
--- a/src/graph/backends/BackendRegistry.cpp
+++ b/src/graph/backends/BackendRegistry.cpp
@@ -48,6 +48,14 @@
return _registered_backends[target].get();
}
+IDeviceBackend &BackendRegistry::get_backend(Target target) // Looks up the backend for `target` via find_backend() and returns it by reference instead of by pointer.
+{
+ IDeviceBackend *backend = find_backend(target); // Pointer lookup; checked for null on the next line.
+ ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!"); // Lookup returned a null pointer for this target.
+ ARM_COMPUTE_ERROR_ON_MSG(!backend->is_backend_supported(), "Requested backend isn't supported"); // Backend was found but reports itself as unsupported.
+ return *backend; // NOTE(review): if ARM_COMPUTE_ERROR_ON_MSG is compiled out in some build configurations, this dereference is unchecked - confirm.
+}
+
bool BackendRegistry::contains(Target target) const
{
auto it = _registered_backends.find(target);
diff --git a/src/graph/detail/ExecutionHelpers.cpp b/src/graph/detail/ExecutionHelpers.cpp
index 6157b7f..f479963 100644
--- a/src/graph/detail/ExecutionHelpers.cpp
+++ b/src/graph/detail/ExecutionHelpers.cpp
@@ -44,10 +44,9 @@
{
if(node != nullptr)
{
- Target assigned_target = node->assigned_target();
- auto backend = backends::BackendRegistry::get().find_backend(assigned_target);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
- Status status = backend->validate_node(*node);
+ Target assigned_target = node->assigned_target();
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(assigned_target);
+ Status status = backend.validate_node(*node);
ARM_COMPUTE_ERROR_ON_MSG(!bool(status), status.error_description().c_str());
}
}
@@ -61,11 +60,10 @@
{
if(tensor && tensor->handle() == nullptr)
{
- Target target = tensor->desc().target;
- auto backend = backends::BackendRegistry::get().find_backend(target);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
- auto handle = backend->create_tensor(*tensor);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Couldn't create backend handle!");
+ Target target = tensor->desc().target;
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(target);
+ std::unique_ptr<ITensorHandle> handle = backend.create_tensor(*tensor);
+ ARM_COMPUTE_ERROR_ON_MSG(!handle, "Couldn't create backend handle!");
tensor->set_handle(std::move(handle));
}
}
@@ -143,10 +141,9 @@
auto node = g.node(node_id);
if(node != nullptr)
{
- Target assigned_target = node->assigned_target();
- auto backend = backends::BackendRegistry::get().find_backend(assigned_target);
- ARM_COMPUTE_ERROR_ON_MSG(!backend, "Requested backend doesn't exist!");
- auto func = backend->configure_node(*node, ctx);
+ Target assigned_target = node->assigned_target();
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(assigned_target);
+ std::unique_ptr<IFunction> func = backend.configure_node(*node, ctx);
if(func != nullptr)
{
ExecutionTask task;
@@ -264,4 +261,4 @@
}
} // namespace detail
} // namespace graph
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
index 937528d..a170c4d 100644
--- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp
+++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
@@ -77,7 +77,7 @@
});
// Create subtensors
- if(is_valid && backends::BackendRegistry::get().find_backend(output_tensor->desc().target) != nullptr)
+ if(is_valid && is_target_supported(output_tensor->desc().target))
{
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
<< node->id() << " and name : " << node->name() << std::endl);
@@ -88,8 +88,8 @@
auto input_tensor = node->input(i);
const auto input_shape = input_tensor->desc().shape;
- auto backend = backends::BackendRegistry::get().find_backend(input_tensor->desc().target);
- auto handle = backend->create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth), false);
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(input_tensor->desc().target);
+ std::unique_ptr<ITensorHandle> handle = backend.create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth), false);
input_tensor->set_handle(std::move(handle));
depth += input_shape.z();
diff --git a/src/graph/mutators/GroupedConvolutionMutator.cpp b/src/graph/mutators/GroupedConvolutionMutator.cpp
index d2643d5..1bcc11b 100644
--- a/src/graph/mutators/GroupedConvolutionMutator.cpp
+++ b/src/graph/mutators/GroupedConvolutionMutator.cpp
@@ -117,8 +117,8 @@
if(node != nullptr && node->type() == NodeType::ConvolutionLayer && arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node)->num_groups() != 1)
{
// Validate node
- backends::IDeviceBackend *backend = backends::BackendRegistry::get().find_backend(node->assigned_target());
- Status status = backend->validate_node(*node);
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target());
+ Status status = backend.validate_node(*node);
// If grouped convolution is not supported
if(!bool(status))
diff --git a/src/graph/mutators/NodeExecutionMethodMutator.cpp b/src/graph/mutators/NodeExecutionMethodMutator.cpp
index 896bf07..b420121 100644
--- a/src/graph/mutators/NodeExecutionMethodMutator.cpp
+++ b/src/graph/mutators/NodeExecutionMethodMutator.cpp
@@ -55,8 +55,8 @@
if(node != nullptr)
{
// Validate node
- backends::IDeviceBackend *backend = backends::BackendRegistry::get().find_backend(node->assigned_target());
- Status status = backend->validate_node(*node);
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target());
+ Status status = backend.validate_node(*node);
// Set default execution method in case of failure
if(!bool(status))
diff --git a/src/graph/mutators/SplitLayerSubTensorMutator.cpp b/src/graph/mutators/SplitLayerSubTensorMutator.cpp
index 5f1c9c3..e21252a 100644
--- a/src/graph/mutators/SplitLayerSubTensorMutator.cpp
+++ b/src/graph/mutators/SplitLayerSubTensorMutator.cpp
@@ -25,6 +25,7 @@
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/Logger.h"
+#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/algorithms/TopologicalSort.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
#include "arm_compute/graph/nodes/SplitLayerNode.h"
@@ -69,7 +70,7 @@
});
// Create subtensors
- if(is_valid && backends::BackendRegistry::get().find_backend(input_tensor->desc().target) != nullptr)
+ if(is_valid && is_target_supported(input_tensor->desc().target))
{
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
<< node->id() << " and name : " << node->name() << std::endl);
@@ -88,8 +89,8 @@
Coordinates coords;
std::tie(std::ignore, coords) = SplitLayerNode::compute_output_descriptor(input_tensor->desc(), num_splits, axis, i);
- backends::IDeviceBackend *backend = backends::BackendRegistry::get().find_backend(output_tensor->desc().target);
- std::unique_ptr<ITensorHandle> handle = backend->create_subtensor(input_tensor->handle(), output_shape, coords, extend_parent);
+ backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(output_tensor->desc().target);
+ std::unique_ptr<ITensorHandle> handle = backend.create_subtensor(input_tensor->handle(), output_shape, coords, extend_parent);
output_tensor->set_handle(std::move(handle));
}
}