blob: d5fa9784688aa27bd4272da6b944213dd138da51 [file] [log] [blame]
telsoa01ce3e84a2018-08-31 09:31:35 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa01ce3e84a2018-08-31 09:31:35 +01004//
5
Matteo Martincighe48bdff2018-09-03 13:50:50 +01006#define LOG_TAG "ArmnnDriver"
7
telsoa01ce3e84a2018-08-31 09:31:35 +01008#include "ArmnnDriverImpl.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +01009#include "ArmnnPreparedModel.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +010010
Mike Kellyb5fdf382019-06-11 16:35:25 +010011#ifdef ARMNN_ANDROID_NN_V1_2 // Using ::android::hardware::neuralnetworks::V1_2
12#include "ArmnnPreparedModel_1_2.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +010013#endif
14
Mike Kellyb5fdf382019-06-11 16:35:25 +010015#include "ModelToINetworkConverter.hpp"
16#include "SystemPropertiesUtils.hpp"
17#include <ValidateHal.h>
telsoa01ce3e84a2018-08-31 09:31:35 +010018#include <log/log.h>
19
20using namespace std;
21using namespace android;
22using namespace android::nn;
23using namespace android::hardware;
24
25namespace
26{
27
Matthew Bentham912b3622019-05-03 15:49:14 +010028void NotifyCallbackAndCheck(const sp<V1_0::IPreparedModelCallback>& callback,
telsoa01ce3e84a2018-08-31 09:31:35 +010029 ErrorStatus errorStatus,
Matthew Bentham912b3622019-05-03 15:49:14 +010030 const sp<V1_0::IPreparedModel>& preparedModelPtr)
telsoa01ce3e84a2018-08-31 09:31:35 +010031{
32 Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
33 // This check is required, if the callback fails and it isn't checked it will bring down the service
34 if (!returned.isOk())
35 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +010036 ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
Mike Kellyb5fdf382019-06-11 16:35:25 +010037 returned.description().c_str());
telsoa01ce3e84a2018-08-31 09:31:35 +010038 }
39}
40
41Return<ErrorStatus> FailPrepareModel(ErrorStatus error,
42 const string& message,
Matthew Bentham912b3622019-05-03 15:49:14 +010043 const sp<V1_0::IPreparedModelCallback>& callback)
telsoa01ce3e84a2018-08-31 09:31:35 +010044{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010045 ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
telsoa01ce3e84a2018-08-31 09:31:35 +010046 NotifyCallbackAndCheck(callback, error, nullptr);
47 return error;
48}
49
Mike Kellyb5fdf382019-06-11 16:35:25 +010050
telsoa01ce3e84a2018-08-31 09:31:35 +010051} // namespace
52
53namespace armnn_driver
54{
telsoa01ce3e84a2018-08-31 09:31:35 +010055
arovir01b0717b52018-09-05 17:03:25 +010056template<typename HalPolicy>
arovir01b0717b52018-09-05 17:03:25 +010057Return<ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
telsoa01ce3e84a2018-08-31 09:31:35 +010058 const armnn::IRuntimePtr& runtime,
59 const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
60 const DriverOptions& options,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010061 const HalModel& model,
Matthew Bentham912b3622019-05-03 15:49:14 +010062 const sp<V1_0::IPreparedModelCallback>& cb,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010063 bool float32ToFloat16)
telsoa01ce3e84a2018-08-31 09:31:35 +010064{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010065 ALOGV("ArmnnDriverImpl::prepareModel()");
telsoa01ce3e84a2018-08-31 09:31:35 +010066
67 if (cb.get() == nullptr)
68 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +010069 ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
telsoa01ce3e84a2018-08-31 09:31:35 +010070 return ErrorStatus::INVALID_ARGUMENT;
71 }
72
73 if (!runtime)
74 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +010075 return FailPrepareModel(ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +010076 }
77
78 if (!android::nn::validateModel(model))
79 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +010080 return FailPrepareModel(ErrorStatus::INVALID_ARGUMENT, "Invalid model passed as input", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +010081 }
82
83 // Deliberately ignore any unsupported operations requested by the options -
84 // at this point we're being asked to prepare a model that we've already declared support for
85 // and the operation indices may be different to those in getSupportedOperations anyway.
86 set<unsigned int> unsupportedOperations;
Nattapat Chaimanowongd5fd9762019-04-04 13:33:10 +010087 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
88 model,
89 unsupportedOperations);
telsoa01ce3e84a2018-08-31 09:31:35 +010090
91 if (modelConverter.GetConversionResult() != ConversionResult::Success)
92 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +010093 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, "ModelToINetworkConverter failed", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +010094 return ErrorStatus::NONE;
95 }
96
Matteo Martincighe48bdff2018-09-03 13:50:50 +010097 // Optimize the network
telsoa01ce3e84a2018-08-31 09:31:35 +010098 armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
99 armnn::OptimizerOptions OptOptions;
100 OptOptions.m_ReduceFp32ToFp16 = float32ToFloat16;
101
jimfly0107dedda2018-10-09 12:29:41 +0100102 std::vector<std::string> errMessages;
telsoa01ce3e84a2018-08-31 09:31:35 +0100103 try
104 {
105 optNet = armnn::Optimize(*modelConverter.GetINetwork(),
Nattapat Chaimanowongd5fd9762019-04-04 13:33:10 +0100106 options.GetBackends(),
telsoa01ce3e84a2018-08-31 09:31:35 +0100107 runtime->GetDeviceSpec(),
jimfly0107dedda2018-10-09 12:29:41 +0100108 OptOptions,
109 errMessages);
telsoa01ce3e84a2018-08-31 09:31:35 +0100110 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000111 catch (std::exception& e)
telsoa01ce3e84a2018-08-31 09:31:35 +0100112 {
113 stringstream message;
Mike Kellyc7d0d442019-12-11 19:27:11 +0000114 message << "Exception (" << e.what() << ") caught from optimize.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100115 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
116 return ErrorStatus::NONE;
117 }
118
119 // Check that the optimized network is valid.
120 if (!optNet)
121 {
jimfly0107dedda2018-10-09 12:29:41 +0100122 stringstream message;
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100123 message << "Invalid optimized network";
124 for (const string& msg : errMessages)
125 {
jimfly0107dedda2018-10-09 12:29:41 +0100126 message << "\n" << msg;
127 }
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100128 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100129 return ErrorStatus::NONE;
130 }
131
132 // Export the optimized network graph to a dot file if an output dump directory
133 // has been specified in the drivers' arguments.
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100134 ExportNetworkGraphToDotFile<HalModel>(*optNet, options.GetRequestInputsAndOutputsDumpDir(), model);
telsoa01ce3e84a2018-08-31 09:31:35 +0100135
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100136 // Load it into the runtime.
telsoa01ce3e84a2018-08-31 09:31:35 +0100137 armnn::NetworkId netId = 0;
138 try
139 {
140 if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
141 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100142 return FailPrepareModel(ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100143 }
144 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000145 catch (std::exception& e)
telsoa01ce3e84a2018-08-31 09:31:35 +0100146 {
147 stringstream message;
Mike Kellyc7d0d442019-12-11 19:27:11 +0000148 message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100149 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
150 return ErrorStatus::NONE;
151 }
152
arovir01b0717b52018-09-05 17:03:25 +0100153 unique_ptr<ArmnnPreparedModel<HalPolicy>> preparedModel(
Mike Kellyb5fdf382019-06-11 16:35:25 +0100154 new ArmnnPreparedModel<HalPolicy>(
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100155 netId,
156 runtime.get(),
157 model,
158 options.GetRequestInputsAndOutputsDumpDir(),
159 options.IsGpuProfilingEnabled()));
telsoa01ce3e84a2018-08-31 09:31:35 +0100160
161 // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
162 // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
Matthew Bentham16196e22019-04-01 17:17:58 +0100163 if (!preparedModel->ExecuteWithDummyInputs())
164 {
165 return FailPrepareModel(ErrorStatus::GENERAL_FAILURE, "Network could not be executed", cb);
166 }
telsoa01ce3e84a2018-08-31 09:31:35 +0100167
168 if (clTunedParameters &&
169 options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
170 {
171 // Now that we've done one inference the CL kernel parameters will have been tuned, so save the updated file.
172 try
173 {
174 clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
175 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000176 catch (std::exception& error)
telsoa01ce3e84a2018-08-31 09:31:35 +0100177 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100178 ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100179 options.GetClTunedParametersFile().c_str(), error.what());
telsoa01ce3e84a2018-08-31 09:31:35 +0100180 }
181 }
182
183 NotifyCallbackAndCheck(cb, ErrorStatus::NONE, preparedModel.release());
184
185 return ErrorStatus::NONE;
186}
187
arovir01b0717b52018-09-05 17:03:25 +0100188template<typename HalPolicy>
Mike Kellyb5fdf382019-06-11 16:35:25 +0100189Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
190 const DriverOptions& options,
191 const HalModel& model,
192 HalGetSupportedOperations_cb cb)
193{
194 ALOGV("ArmnnDriverImpl::getSupportedOperations()");
195
196 vector<bool> result;
197
198 if (!runtime)
199 {
200 cb(ErrorStatus::DEVICE_UNAVAILABLE, result);
201 return Void();
202 }
203
204 // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
205 if (!android::nn::validateModel(model))
206 {
207 cb(ErrorStatus::INVALID_ARGUMENT, result);
208 return Void();
209 }
210
211 // Attempt to convert the model to an ArmNN input network (INetwork).
212 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
213 model,
214 options.GetForcedUnsupportedOperations());
215
216 if (modelConverter.GetConversionResult() != ConversionResult::Success
217 && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
218 {
219 cb(ErrorStatus::GENERAL_FAILURE, result);
220 return Void();
221 }
222
223 // Check each operation if it was converted successfully and copy the flags
224 // into the result (vector<bool>) that we need to return to Android.
225 result.reserve(model.operations.size());
226 for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); operationIdx++)
227 {
228 bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
229 result.push_back(operationSupported);
230 }
231
232 cb(ErrorStatus::NONE, result);
233 return Void();
234}
235
236template<typename HalPolicy>
arovir01b0717b52018-09-05 17:03:25 +0100237Return<DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
telsoa01ce3e84a2018-08-31 09:31:35 +0100238{
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100239 ALOGV("ArmnnDriver::getStatus()");
telsoa01ce3e84a2018-08-31 09:31:35 +0100240
241 return DeviceStatus::AVAILABLE;
242}
243
arovir01b0717b52018-09-05 17:03:25 +0100244///
245/// Class template specializations
246///
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100247
arovir01b0717b52018-09-05 17:03:25 +0100248template class ArmnnDriverImpl<hal_1_0::HalPolicy>;
249
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100250#ifdef ARMNN_ANDROID_NN_V1_1
arovir01b0717b52018-09-05 17:03:25 +0100251template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100252#endif
253
Mike Kellyb5fdf382019-06-11 16:35:25 +0100254#ifdef ARMNN_ANDROID_NN_V1_2
255template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
256template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
257#endif
258
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100259} // namespace armnn_driver