blob: 5d2136597f74091a765bef4189ede3b616f5c16b [file] [log] [blame]
telsoa01ce3e84a2018-08-31 09:31:35 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa01ce3e84a2018-08-31 09:31:35 +01004//
5
Matteo Martincighe48bdff2018-09-03 13:50:50 +01006#define LOG_TAG "ArmnnDriver"
7
telsoa01ce3e84a2018-08-31 09:31:35 +01008#include "ArmnnDriverImpl.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +01009#include "ArmnnPreparedModel.hpp"
arovir01b0717b52018-09-05 17:03:25 +010010#include "ModelToINetworkConverter.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +010011#include "SystemPropertiesUtils.hpp"
12
13#if defined(ARMNN_ANDROID_P)
14// The headers of the ML framework have changed between Android O and Android P.
15// The validation functions have been moved into their own header, ValidateHal.h.
16#include <ValidateHal.h>
17#endif
18
19#include <log/log.h>
20
21using namespace std;
22using namespace android;
23using namespace android::nn;
24using namespace android::hardware;
25
26namespace
27{
28
telsoa01ce3e84a2018-08-31 09:31:35 +010029void NotifyCallbackAndCheck(const sp<IPreparedModelCallback>& callback,
30 ErrorStatus errorStatus,
31 const sp<IPreparedModel>& preparedModelPtr)
32{
33 Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
34 // This check is required, if the callback fails and it isn't checked it will bring down the service
35 if (!returned.isOk())
36 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +010037 ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
telsoa01ce3e84a2018-08-31 09:31:35 +010038 returned.description().c_str());
39 }
40}
41
42Return<ErrorStatus> FailPrepareModel(ErrorStatus error,
43 const string& message,
44 const sp<IPreparedModelCallback>& callback)
45{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010046 ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
telsoa01ce3e84a2018-08-31 09:31:35 +010047 NotifyCallbackAndCheck(callback, error, nullptr);
48 return error;
49}
50
51} // namespace
52
53namespace armnn_driver
54{
telsoa01ce3e84a2018-08-31 09:31:35 +010055
arovir01b0717b52018-09-05 17:03:25 +010056template<typename HalPolicy>
57Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
58 const DriverOptions& options,
59 const HalModel& model,
60 HalGetSupportedOperations_cb cb)
telsoa01ce3e84a2018-08-31 09:31:35 +010061{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010062 ALOGV("ArmnnDriverImpl::getSupportedOperations()");
telsoa01ce3e84a2018-08-31 09:31:35 +010063
64 vector<bool> result;
65
66 if (!runtime)
67 {
68 cb(ErrorStatus::DEVICE_UNAVAILABLE, result);
69 return Void();
70 }
71
Matteo Martincighe48bdff2018-09-03 13:50:50 +010072 // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
telsoa01ce3e84a2018-08-31 09:31:35 +010073 if (!android::nn::validateModel(model))
74 {
75 cb(ErrorStatus::INVALID_ARGUMENT, result);
76 return Void();
77 }
78
79 // Attempt to convert the model to an ArmNN input network (INetwork).
arovir01b0717b52018-09-05 17:03:25 +010080 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetComputeDevice(),
Matteo Martincighe48bdff2018-09-03 13:50:50 +010081 model,
82 options.GetForcedUnsupportedOperations());
telsoa01ce3e84a2018-08-31 09:31:35 +010083
84 if (modelConverter.GetConversionResult() != ConversionResult::Success
Matteo Martincighe48bdff2018-09-03 13:50:50 +010085 && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
telsoa01ce3e84a2018-08-31 09:31:35 +010086 {
87 cb(ErrorStatus::GENERAL_FAILURE, result);
88 return Void();
89 }
90
91 // Check each operation if it was converted successfully and copy the flags
Matteo Martincighe48bdff2018-09-03 13:50:50 +010092 // into the result (vector<bool>) that we need to return to Android.
telsoa01ce3e84a2018-08-31 09:31:35 +010093 result.reserve(model.operations.size());
94 for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); operationIdx++)
95 {
96 bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
97 result.push_back(operationSupported);
98 }
99
100 cb(ErrorStatus::NONE, result);
101 return Void();
102}
103
arovir01b0717b52018-09-05 17:03:25 +0100104template<typename HalPolicy>
105Return<ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
telsoa01ce3e84a2018-08-31 09:31:35 +0100106 const armnn::IRuntimePtr& runtime,
107 const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
108 const DriverOptions& options,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100109 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +0100110 const sp<IPreparedModelCallback>& cb,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100111 bool float32ToFloat16)
telsoa01ce3e84a2018-08-31 09:31:35 +0100112{
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100113 ALOGV("ArmnnDriverImpl::prepareModel()");
telsoa01ce3e84a2018-08-31 09:31:35 +0100114
115 if (cb.get() == nullptr)
116 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100117 ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
telsoa01ce3e84a2018-08-31 09:31:35 +0100118 return ErrorStatus::INVALID_ARGUMENT;
119 }
120
121 if (!runtime)
122 {
123 return FailPrepareModel(ErrorStatus::DEVICE_UNAVAILABLE,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100124 "ArmnnDriverImpl::prepareModel: Device unavailable", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100125 }
126
127 if (!android::nn::validateModel(model))
128 {
129 return FailPrepareModel(ErrorStatus::INVALID_ARGUMENT,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100130 "ArmnnDriverImpl::prepareModel: Invalid model passed as input", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100131 }
132
133 // Deliberately ignore any unsupported operations requested by the options -
134 // at this point we're being asked to prepare a model that we've already declared support for
135 // and the operation indices may be different to those in getSupportedOperations anyway.
136 set<unsigned int> unsupportedOperations;
arovir01b0717b52018-09-05 17:03:25 +0100137 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetComputeDevice(),
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100138 model,
139 unsupportedOperations);
telsoa01ce3e84a2018-08-31 09:31:35 +0100140
141 if (modelConverter.GetConversionResult() != ConversionResult::Success)
142 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100143 FailPrepareModel(ErrorStatus::GENERAL_FAILURE,
144 "ArmnnDriverImpl::prepareModel: ModelToINetworkConverter failed", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100145 return ErrorStatus::NONE;
146 }
147
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100148 // Optimize the network
telsoa01ce3e84a2018-08-31 09:31:35 +0100149 armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
150 armnn::OptimizerOptions OptOptions;
151 OptOptions.m_ReduceFp32ToFp16 = float32ToFloat16;
152
153 try
154 {
155 optNet = armnn::Optimize(*modelConverter.GetINetwork(),
156 {options.GetComputeDevice()},
157 runtime->GetDeviceSpec(),
158 OptOptions);
159 }
160 catch (armnn::Exception &e)
161 {
162 stringstream message;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100163 message << "ArmnnDriverImpl::prepareModel: armnn::Exception (" << e.what() << ") caught from optimize.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100164 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
165 return ErrorStatus::NONE;
166 }
167
168 // Check that the optimized network is valid.
169 if (!optNet)
170 {
171 FailPrepareModel(ErrorStatus::GENERAL_FAILURE,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100172 "ArmnnDriverImpl::prepareModel: Invalid optimized network", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100173 return ErrorStatus::NONE;
174 }
175
176 // Export the optimized network graph to a dot file if an output dump directory
177 // has been specified in the drivers' arguments.
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100178 ExportNetworkGraphToDotFile<HalModel>(*optNet, options.GetRequestInputsAndOutputsDumpDir(), model);
telsoa01ce3e84a2018-08-31 09:31:35 +0100179
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100180 // Load it into the runtime.
telsoa01ce3e84a2018-08-31 09:31:35 +0100181 armnn::NetworkId netId = 0;
182 try
183 {
184 if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
185 {
186 return FailPrepareModel(ErrorStatus::GENERAL_FAILURE,
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100187 "ArmnnDriverImpl::prepareModel: Network could not be loaded", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100188 }
189 }
190 catch (armnn::Exception& e)
191 {
192 stringstream message;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100193 message << "ArmnnDriverImpl::prepareModel: armnn::Exception (" << e.what()<< ") caught from LoadNetwork.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100194 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
195 return ErrorStatus::NONE;
196 }
197
arovir01b0717b52018-09-05 17:03:25 +0100198 unique_ptr<ArmnnPreparedModel<HalPolicy>> preparedModel(
199 new ArmnnPreparedModel<HalPolicy>(
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100200 netId,
201 runtime.get(),
202 model,
203 options.GetRequestInputsAndOutputsDumpDir(),
204 options.IsGpuProfilingEnabled()));
telsoa01ce3e84a2018-08-31 09:31:35 +0100205
206 // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
207 // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
208 preparedModel->ExecuteWithDummyInputs();
209
210 if (clTunedParameters &&
211 options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
212 {
213 // Now that we've done one inference the CL kernel parameters will have been tuned, so save the updated file.
214 try
215 {
216 clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
217 }
218 catch (const armnn::Exception& error)
219 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100220 ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
telsoa01ce3e84a2018-08-31 09:31:35 +0100221 options.GetClTunedParametersFile().c_str(), error.what());
222 }
223 }
224
225 NotifyCallbackAndCheck(cb, ErrorStatus::NONE, preparedModel.release());
226
227 return ErrorStatus::NONE;
228}
229
arovir01b0717b52018-09-05 17:03:25 +0100230template<typename HalPolicy>
231Return<DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
telsoa01ce3e84a2018-08-31 09:31:35 +0100232{
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100233 ALOGV("ArmnnDriver::getStatus()");
telsoa01ce3e84a2018-08-31 09:31:35 +0100234
235 return DeviceStatus::AVAILABLE;
236}
237
arovir01b0717b52018-09-05 17:03:25 +0100238///
239/// Class template specializations
240///
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100241
arovir01b0717b52018-09-05 17:03:25 +0100242template class ArmnnDriverImpl<hal_1_0::HalPolicy>;
243
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100244#ifdef ARMNN_ANDROID_NN_V1_1
arovir01b0717b52018-09-05 17:03:25 +0100245template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100246#endif
247
arovir01b0717b52018-09-05 17:03:25 +0100248} // namespace armnn_driver