blob: 267574c1865599592125cec4d23768519ed24d46 [file] [log] [blame]
telsoa01ce3e84a2018-08-31 09:31:35 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa01ce3e84a2018-08-31 09:31:35 +01004//
5
Matteo Martincighe48bdff2018-09-03 13:50:50 +01006#define LOG_TAG "ArmnnDriver"
7
telsoa01ce3e84a2018-08-31 09:31:35 +01008#include "ArmnnDriverImpl.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +01009#include "ArmnnPreparedModel.hpp"
arovir01b0717b52018-09-05 17:03:25 +010010#include "ModelToINetworkConverter.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +010011#include "SystemPropertiesUtils.hpp"
12
13#if defined(ARMNN_ANDROID_P)
14// The headers of the ML framework have changed between Android O and Android P.
15// The validation functions have been moved into their own header, ValidateHal.h.
16#include <ValidateHal.h>
17#endif
18
19#include <log/log.h>
20
21using namespace std;
22using namespace android;
23using namespace android::nn;
24using namespace android::hardware;
25
26namespace
27{
28
telsoa01ce3e84a2018-08-31 09:31:35 +010029void NotifyCallbackAndCheck(const sp<IPreparedModelCallback>& callback,
30 ErrorStatus errorStatus,
31 const sp<IPreparedModel>& preparedModelPtr)
32{
33 Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
34 // This check is required, if the callback fails and it isn't checked it will bring down the service
35 if (!returned.isOk())
36 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +010037 ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
telsoa01ce3e84a2018-08-31 09:31:35 +010038 returned.description().c_str());
39 }
40}
41
42Return<ErrorStatus> FailPrepareModel(ErrorStatus error,
43 const string& message,
44 const sp<IPreparedModelCallback>& callback)
45{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010046 ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
telsoa01ce3e84a2018-08-31 09:31:35 +010047 NotifyCallbackAndCheck(callback, error, nullptr);
48 return error;
49}
50
51} // namespace
52
53namespace armnn_driver
54{
telsoa01ce3e84a2018-08-31 09:31:35 +010055
arovir01b0717b52018-09-05 17:03:25 +010056template<typename HalPolicy>
57Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
58 const DriverOptions& options,
59 const HalModel& model,
60 HalGetSupportedOperations_cb cb)
telsoa01ce3e84a2018-08-31 09:31:35 +010061{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010062 ALOGV("ArmnnDriverImpl::getSupportedOperations()");
telsoa01ce3e84a2018-08-31 09:31:35 +010063
64 vector<bool> result;
65
66 if (!runtime)
67 {
68 cb(ErrorStatus::DEVICE_UNAVAILABLE, result);
69 return Void();
70 }
71
Matteo Martincighe48bdff2018-09-03 13:50:50 +010072 // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
telsoa01ce3e84a2018-08-31 09:31:35 +010073 if (!android::nn::validateModel(model))
74 {
75 cb(ErrorStatus::INVALID_ARGUMENT, result);
76 return Void();
77 }
78
79 // Attempt to convert the model to an ArmNN input network (INetwork).
arovir01b0717b52018-09-05 17:03:25 +010080 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetComputeDevice(),
Matteo Martincighe48bdff2018-09-03 13:50:50 +010081 model,
82 options.GetForcedUnsupportedOperations());
telsoa01ce3e84a2018-08-31 09:31:35 +010083
84 if (modelConverter.GetConversionResult() != ConversionResult::Success
Matteo Martincighe48bdff2018-09-03 13:50:50 +010085 && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
telsoa01ce3e84a2018-08-31 09:31:35 +010086 {
87 cb(ErrorStatus::GENERAL_FAILURE, result);
88 return Void();
89 }
90
91 // Check each operation if it was converted successfully and copy the flags
Matteo Martincighe48bdff2018-09-03 13:50:50 +010092 // into the result (vector<bool>) that we need to return to Android.
telsoa01ce3e84a2018-08-31 09:31:35 +010093 result.reserve(model.operations.size());
94 for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); operationIdx++)
95 {
96 bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
97 result.push_back(operationSupported);
98 }
99
100 cb(ErrorStatus::NONE, result);
101 return Void();
102}
103
arovir01b0717b52018-09-05 17:03:25 +0100104template<typename HalPolicy>
105Return<ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
telsoa01ce3e84a2018-08-31 09:31:35 +0100106 const armnn::IRuntimePtr& runtime,
107 const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
108 const DriverOptions& options,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100109 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +0100110 const sp<IPreparedModelCallback>& cb,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100111 bool float32ToFloat16)
telsoa01ce3e84a2018-08-31 09:31:35 +0100112{
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100113 ALOGV("ArmnnDriverImpl::prepareModel()");
telsoa01ce3e84a2018-08-31 09:31:35 +0100114
115 if (cb.get() == nullptr)
116 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100117 ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
telsoa01ce3e84a2018-08-31 09:31:35 +0100118 return ErrorStatus::INVALID_ARGUMENT;
119 }
120
121 if (!runtime)
122 {
123 return FailPrepareModel(ErrorStatus::DEVICE_UNAVAILABLE,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100124 "ArmnnDriverImpl::prepareModel: Device unavailable", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100125 }
126
127 if (!android::nn::validateModel(model))
128 {
129 return FailPrepareModel(ErrorStatus::INVALID_ARGUMENT,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100130 "ArmnnDriverImpl::prepareModel: Invalid model passed as input", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100131 }
132
133 // Deliberately ignore any unsupported operations requested by the options -
134 // at this point we're being asked to prepare a model that we've already declared support for
135 // and the operation indices may be different to those in getSupportedOperations anyway.
136 set<unsigned int> unsupportedOperations;
arovir01b0717b52018-09-05 17:03:25 +0100137 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetComputeDevice(),
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100138 model,
139 unsupportedOperations);
telsoa01ce3e84a2018-08-31 09:31:35 +0100140
141 if (modelConverter.GetConversionResult() != ConversionResult::Success)
142 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100143 FailPrepareModel(ErrorStatus::GENERAL_FAILURE,
144 "ArmnnDriverImpl::prepareModel: ModelToINetworkConverter failed", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100145 return ErrorStatus::NONE;
146 }
147
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100148 // Optimize the network
telsoa01ce3e84a2018-08-31 09:31:35 +0100149 armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
150 armnn::OptimizerOptions OptOptions;
151 OptOptions.m_ReduceFp32ToFp16 = float32ToFloat16;
152
jimfly0107dedda2018-10-09 12:29:41 +0100153 std::vector<std::string> errMessages;
telsoa01ce3e84a2018-08-31 09:31:35 +0100154 try
155 {
156 optNet = armnn::Optimize(*modelConverter.GetINetwork(),
157 {options.GetComputeDevice()},
158 runtime->GetDeviceSpec(),
jimfly0107dedda2018-10-09 12:29:41 +0100159 OptOptions,
160 errMessages);
telsoa01ce3e84a2018-08-31 09:31:35 +0100161 }
162 catch (armnn::Exception &e)
163 {
164 stringstream message;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100165 message << "ArmnnDriverImpl::prepareModel: armnn::Exception (" << e.what() << ") caught from optimize.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100166 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
167 return ErrorStatus::NONE;
168 }
169
170 // Check that the optimized network is valid.
171 if (!optNet)
172 {
jimfly0107dedda2018-10-09 12:29:41 +0100173 stringstream message;
174 message << "ArmnnDriverImpl::prepareModel: Invalid optimized network";
175 for (const string& msg : errMessages) {
176 message << "\n" << msg;
177 }
telsoa01ce3e84a2018-08-31 09:31:35 +0100178 FailPrepareModel(ErrorStatus::GENERAL_FAILURE,
jimfly0107dedda2018-10-09 12:29:41 +0100179 message.str(), cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100180 return ErrorStatus::NONE;
181 }
182
183 // Export the optimized network graph to a dot file if an output dump directory
184 // has been specified in the drivers' arguments.
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100185 ExportNetworkGraphToDotFile<HalModel>(*optNet, options.GetRequestInputsAndOutputsDumpDir(), model);
telsoa01ce3e84a2018-08-31 09:31:35 +0100186
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100187 // Load it into the runtime.
telsoa01ce3e84a2018-08-31 09:31:35 +0100188 armnn::NetworkId netId = 0;
189 try
190 {
191 if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
192 {
193 return FailPrepareModel(ErrorStatus::GENERAL_FAILURE,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100194 "ArmnnDriverImpl::prepareModel: Network could not be loaded", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100195 }
196 }
197 catch (armnn::Exception& e)
198 {
199 stringstream message;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100200 message << "ArmnnDriverImpl::prepareModel: armnn::Exception (" << e.what()<< ") caught from LoadNetwork.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100201 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
202 return ErrorStatus::NONE;
203 }
204
arovir01b0717b52018-09-05 17:03:25 +0100205 unique_ptr<ArmnnPreparedModel<HalPolicy>> preparedModel(
206 new ArmnnPreparedModel<HalPolicy>(
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100207 netId,
208 runtime.get(),
209 model,
210 options.GetRequestInputsAndOutputsDumpDir(),
211 options.IsGpuProfilingEnabled()));
telsoa01ce3e84a2018-08-31 09:31:35 +0100212
213 // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
214 // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
215 preparedModel->ExecuteWithDummyInputs();
216
217 if (clTunedParameters &&
218 options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
219 {
220 // Now that we've done one inference the CL kernel parameters will have been tuned, so save the updated file.
221 try
222 {
223 clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
224 }
225 catch (const armnn::Exception& error)
226 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100227 ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
telsoa01ce3e84a2018-08-31 09:31:35 +0100228 options.GetClTunedParametersFile().c_str(), error.what());
229 }
230 }
231
232 NotifyCallbackAndCheck(cb, ErrorStatus::NONE, preparedModel.release());
233
234 return ErrorStatus::NONE;
235}
236
arovir01b0717b52018-09-05 17:03:25 +0100237template<typename HalPolicy>
238Return<DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
telsoa01ce3e84a2018-08-31 09:31:35 +0100239{
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100240 ALOGV("ArmnnDriver::getStatus()");
telsoa01ce3e84a2018-08-31 09:31:35 +0100241
242 return DeviceStatus::AVAILABLE;
243}
244
///
/// Explicit class template instantiations
///
/// The member function templates above are defined in this source file, so each HalPolicy
/// the driver uses must be explicitly instantiated here for the definitions to be emitted.
///

template class ArmnnDriverImpl<hal_1_0::HalPolicy>;

// HAL 1.1 support is only compiled in when ARMNN_ANDROID_NN_V1_1 is defined.
#ifdef ARMNN_ANDROID_NN_V1_1
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
#endif
254
arovir01b0717b52018-09-05 17:03:25 +0100255} // namespace armnn_driver