blob: a3c2e10f5cc8315bb065137c3830e3db590a7e1d [file] [log] [blame]
telsoa01ce3e84a2018-08-31 09:31:35 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa01ce3e84a2018-08-31 09:31:35 +01004//
5
Matteo Martincighe48bdff2018-09-03 13:50:50 +01006#define LOG_TAG "ArmnnDriver"
7
#include "ArmnnDriverImpl.hpp"
#include "ArmnnPreparedModel.hpp"
#include "ModelToINetworkConverter.hpp"
#include "SystemPropertiesUtils.hpp"

#if defined(ARMNN_ANDROID_P)
// The headers of the ML framework have changed between Android O and Android P.
// The validation functions have been moved into their own header, ValidateHal.h.
#include <ValidateHal.h>
#endif

#include <log/log.h>

#include <exception>
#include <sstream>
20
21using namespace std;
22using namespace android;
23using namespace android::nn;
24using namespace android::hardware;
25
26namespace
27{
28
telsoa01ce3e84a2018-08-31 09:31:35 +010029void NotifyCallbackAndCheck(const sp<IPreparedModelCallback>& callback,
30 ErrorStatus errorStatus,
31 const sp<IPreparedModel>& preparedModelPtr)
32{
33 Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
34 // This check is required, if the callback fails and it isn't checked it will bring down the service
35 if (!returned.isOk())
36 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +010037 ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
telsoa01ce3e84a2018-08-31 09:31:35 +010038 returned.description().c_str());
39 }
40}
41
42Return<ErrorStatus> FailPrepareModel(ErrorStatus error,
43 const string& message,
44 const sp<IPreparedModelCallback>& callback)
45{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010046 ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
telsoa01ce3e84a2018-08-31 09:31:35 +010047 NotifyCallbackAndCheck(callback, error, nullptr);
48 return error;
49}
50
51} // namespace
52
53namespace armnn_driver
54{
telsoa01ce3e84a2018-08-31 09:31:35 +010055
arovir01b0717b52018-09-05 17:03:25 +010056template<typename HalPolicy>
57Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
58 const DriverOptions& options,
59 const HalModel& model,
60 HalGetSupportedOperations_cb cb)
telsoa01ce3e84a2018-08-31 09:31:35 +010061{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010062 ALOGV("ArmnnDriverImpl::getSupportedOperations()");
telsoa01ce3e84a2018-08-31 09:31:35 +010063
64 vector<bool> result;
65
66 if (!runtime)
67 {
68 cb(ErrorStatus::DEVICE_UNAVAILABLE, result);
69 return Void();
70 }
71
Matteo Martincighe48bdff2018-09-03 13:50:50 +010072 // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
telsoa01ce3e84a2018-08-31 09:31:35 +010073 if (!android::nn::validateModel(model))
74 {
75 cb(ErrorStatus::INVALID_ARGUMENT, result);
76 return Void();
77 }
78
79 // Attempt to convert the model to an ArmNN input network (INetwork).
arovir01b0717b52018-09-05 17:03:25 +010080 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetComputeDevice(),
Matteo Martincighe48bdff2018-09-03 13:50:50 +010081 model,
82 options.GetForcedUnsupportedOperations());
telsoa01ce3e84a2018-08-31 09:31:35 +010083
84 if (modelConverter.GetConversionResult() != ConversionResult::Success
Matteo Martincighe48bdff2018-09-03 13:50:50 +010085 && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
telsoa01ce3e84a2018-08-31 09:31:35 +010086 {
87 cb(ErrorStatus::GENERAL_FAILURE, result);
88 return Void();
89 }
90
91 // Check each operation if it was converted successfully and copy the flags
Matteo Martincighe48bdff2018-09-03 13:50:50 +010092 // into the result (vector<bool>) that we need to return to Android.
telsoa01ce3e84a2018-08-31 09:31:35 +010093 result.reserve(model.operations.size());
94 for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); operationIdx++)
95 {
96 bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
97 result.push_back(operationSupported);
98 }
99
100 cb(ErrorStatus::NONE, result);
101 return Void();
102}
103
arovir01b0717b52018-09-05 17:03:25 +0100104template<typename HalPolicy>
105Return<ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
telsoa01ce3e84a2018-08-31 09:31:35 +0100106 const armnn::IRuntimePtr& runtime,
107 const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
108 const DriverOptions& options,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100109 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +0100110 const sp<IPreparedModelCallback>& cb,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100111 bool float32ToFloat16)
telsoa01ce3e84a2018-08-31 09:31:35 +0100112{
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100113 ALOGV("ArmnnDriverImpl::prepareModel()");
telsoa01ce3e84a2018-08-31 09:31:35 +0100114
115 if (cb.get() == nullptr)
116 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100117 ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
telsoa01ce3e84a2018-08-31 09:31:35 +0100118 return ErrorStatus::INVALID_ARGUMENT;
119 }
120
121 if (!runtime)
122 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100123 return FailPrepareModel(ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100124 }
125
126 if (!android::nn::validateModel(model))
127 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100128 return FailPrepareModel(ErrorStatus::INVALID_ARGUMENT, "Invalid model passed as input", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100129 }
130
131 // Deliberately ignore any unsupported operations requested by the options -
132 // at this point we're being asked to prepare a model that we've already declared support for
133 // and the operation indices may be different to those in getSupportedOperations anyway.
134 set<unsigned int> unsupportedOperations;
arovir01b0717b52018-09-05 17:03:25 +0100135 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetComputeDevice(),
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100136 model,
137 unsupportedOperations);
telsoa01ce3e84a2018-08-31 09:31:35 +0100138
139 if (modelConverter.GetConversionResult() != ConversionResult::Success)
140 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100141 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, "ModelToINetworkConverter failed", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100142 return ErrorStatus::NONE;
143 }
144
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100145 // Optimize the network
telsoa01ce3e84a2018-08-31 09:31:35 +0100146 armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
147 armnn::OptimizerOptions OptOptions;
148 OptOptions.m_ReduceFp32ToFp16 = float32ToFloat16;
149
jimfly0107dedda2018-10-09 12:29:41 +0100150 std::vector<std::string> errMessages;
telsoa01ce3e84a2018-08-31 09:31:35 +0100151 try
152 {
153 optNet = armnn::Optimize(*modelConverter.GetINetwork(),
154 {options.GetComputeDevice()},
155 runtime->GetDeviceSpec(),
jimfly0107dedda2018-10-09 12:29:41 +0100156 OptOptions,
157 errMessages);
telsoa01ce3e84a2018-08-31 09:31:35 +0100158 }
159 catch (armnn::Exception &e)
160 {
161 stringstream message;
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100162 message << "armnn::Exception (" << e.what() << ") caught from optimize.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100163 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
164 return ErrorStatus::NONE;
165 }
166
167 // Check that the optimized network is valid.
168 if (!optNet)
169 {
jimfly0107dedda2018-10-09 12:29:41 +0100170 stringstream message;
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100171 message << "Invalid optimized network";
172 for (const string& msg : errMessages)
173 {
jimfly0107dedda2018-10-09 12:29:41 +0100174 message << "\n" << msg;
175 }
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100176 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100177 return ErrorStatus::NONE;
178 }
179
180 // Export the optimized network graph to a dot file if an output dump directory
181 // has been specified in the drivers' arguments.
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100182 ExportNetworkGraphToDotFile<HalModel>(*optNet, options.GetRequestInputsAndOutputsDumpDir(), model);
telsoa01ce3e84a2018-08-31 09:31:35 +0100183
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100184 // Load it into the runtime.
telsoa01ce3e84a2018-08-31 09:31:35 +0100185 armnn::NetworkId netId = 0;
186 try
187 {
188 if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
189 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100190 return FailPrepareModel(ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100191 }
192 }
193 catch (armnn::Exception& e)
194 {
195 stringstream message;
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100196 message << "armnn::Exception (" << e.what()<< ") caught from LoadNetwork.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100197 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
198 return ErrorStatus::NONE;
199 }
200
arovir01b0717b52018-09-05 17:03:25 +0100201 unique_ptr<ArmnnPreparedModel<HalPolicy>> preparedModel(
202 new ArmnnPreparedModel<HalPolicy>(
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100203 netId,
204 runtime.get(),
205 model,
206 options.GetRequestInputsAndOutputsDumpDir(),
207 options.IsGpuProfilingEnabled()));
telsoa01ce3e84a2018-08-31 09:31:35 +0100208
209 // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
210 // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
211 preparedModel->ExecuteWithDummyInputs();
212
213 if (clTunedParameters &&
214 options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
215 {
216 // Now that we've done one inference the CL kernel parameters will have been tuned, so save the updated file.
217 try
218 {
219 clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
220 }
221 catch (const armnn::Exception& error)
222 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100223 ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100224 options.GetClTunedParametersFile().c_str(), error.what());
telsoa01ce3e84a2018-08-31 09:31:35 +0100225 }
226 }
227
228 NotifyCallbackAndCheck(cb, ErrorStatus::NONE, preparedModel.release());
229
230 return ErrorStatus::NONE;
231}
232
arovir01b0717b52018-09-05 17:03:25 +0100233template<typename HalPolicy>
234Return<DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
telsoa01ce3e84a2018-08-31 09:31:35 +0100235{
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100236 ALOGV("ArmnnDriver::getStatus()");
telsoa01ce3e84a2018-08-31 09:31:35 +0100237
238 return DeviceStatus::AVAILABLE;
239}
240
arovir01b0717b52018-09-05 17:03:25 +0100241///
242/// Class template specializations
243///
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100244
arovir01b0717b52018-09-05 17:03:25 +0100245template class ArmnnDriverImpl<hal_1_0::HalPolicy>;
246
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100247#ifdef ARMNN_ANDROID_NN_V1_1
arovir01b0717b52018-09-05 17:03:25 +0100248template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100249#endif
250
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100251} // namespace armnn_driver