android-nn-driver getType returns the right device
* ArmnnDriver queries the options and returns CPU or GPU
depending on which is the first backend listed in the options
* Resolves MLCE-401
Change-Id: If4e63e144507e817449f37926711fa325861b57d
Signed-off-by: Pablo Tello <pablo.tello@arm.com>
diff --git a/1.2/ArmnnDriver.hpp b/1.2/ArmnnDriver.hpp
index 5227272..a350d3f 100644
--- a/1.2/ArmnnDriver.hpp
+++ b/1.2/ArmnnDriver.hpp
@@ -129,8 +129,8 @@
Return<void> getType(getType_cb cb)
{
ALOGV("hal_1_2::ArmnnDriver::getType()");
-
- cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
+ const auto device_type = hal_1_2::HalPolicy::GetDeviceTypeFromOptions(this->m_Options);
+ cb(V1_0::ErrorStatus::NONE, device_type);
return Void();
}
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index fb6c31c..79d117a 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -4,6 +4,8 @@
//
#include "HalPolicy.hpp"
+#include "DriverOptions.hpp"
+
namespace armnn_driver
{
@@ -17,6 +19,33 @@
} // anonymous namespace
+HalPolicy::DeviceType HalPolicy::GetDeviceTypeFromOptions(const DriverOptions& options)
+{
+ // Query backends list from the options
+ auto backends = options.GetBackends();
+ // Return first backend
+ if(backends.size()>0)
+ {
+ const auto &first_backend = backends[0];
+ if(first_backend.IsCpuAcc()||first_backend.IsCpuRef())
+ {
+ return V1_2::DeviceType::CPU;
+ }
+ else if(first_backend.IsGpuAcc())
+ {
+ return V1_2::DeviceType::GPU;
+ }
+ else
+ {
+ return V1_2::DeviceType::ACCELERATOR;
+ }
+ }
+ else
+ {
+ return V1_2::DeviceType::CPU;
+ }
+}
+
bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, ConversionData& data)
{
switch (operation.type)
diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp
index a348abe..0662e1b 100644
--- a/1.2/HalPolicy.hpp
+++ b/1.2/HalPolicy.hpp
@@ -16,6 +16,7 @@
namespace armnn_driver
{
+class DriverOptions;
namespace hal_1_2
{
@@ -31,6 +32,9 @@
using ExecutionCallback = V1_2::IExecutionCallback;
using getSupportedOperations_cb = V1_2::IDevice::getSupportedOperations_1_2_cb;
using ErrorStatus = V1_0::ErrorStatus;
+ using DeviceType = V1_2::DeviceType;
+
+ static DeviceType GetDeviceTypeFromOptions(const DriverOptions& options);
static bool ConvertOperation(const Operation& operation, const Model& model, ConversionData& data);
diff --git a/1.3/ArmnnDriver.hpp b/1.3/ArmnnDriver.hpp
index 451b5ab..fd4aa74 100644
--- a/1.3/ArmnnDriver.hpp
+++ b/1.3/ArmnnDriver.hpp
@@ -244,8 +244,8 @@
Return<void> getType(getType_cb cb)
{
ALOGV("hal_1_3::ArmnnDriver::getType()");
-
- cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
+ const auto device_type = hal_1_2::HalPolicy::GetDeviceTypeFromOptions(this->m_Options);
+ cb(V1_0::ErrorStatus::NONE, device_type);
return Void();
}