blob: cbb2e642940ea617a3360f8625c71afd47727cc8 [file] [log] [blame]
telsoa01ce3e84a2018-08-31 09:31:35 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
telsoa01ce3e84a2018-08-31 09:31:35 +01004//
5
Matteo Martincighe48bdff2018-09-03 13:50:50 +01006#define LOG_TAG "ArmnnDriver"
7
telsoa01ce3e84a2018-08-31 09:31:35 +01008#include "ArmnnDriverImpl.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +01009#include "ArmnnPreparedModel.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +010010
Mike Kellyb5fdf382019-06-11 16:35:25 +010011#ifdef ARMNN_ANDROID_NN_V1_2 // Using ::android::hardware::neuralnetworks::V1_2
12#include "ArmnnPreparedModel_1_2.hpp"
telsoa01ce3e84a2018-08-31 09:31:35 +010013#endif
14
Mike Kellyb5fdf382019-06-11 16:35:25 +010015#include "ModelToINetworkConverter.hpp"
16#include "SystemPropertiesUtils.hpp"
17#include <ValidateHal.h>
telsoa01ce3e84a2018-08-31 09:31:35 +010018#include <log/log.h>
19
20using namespace std;
21using namespace android;
22using namespace android::nn;
23using namespace android::hardware;
24
25namespace
26{
27
Matthew Bentham912b3622019-05-03 15:49:14 +010028void NotifyCallbackAndCheck(const sp<V1_0::IPreparedModelCallback>& callback,
telsoa01ce3e84a2018-08-31 09:31:35 +010029 ErrorStatus errorStatus,
Matthew Bentham912b3622019-05-03 15:49:14 +010030 const sp<V1_0::IPreparedModel>& preparedModelPtr)
telsoa01ce3e84a2018-08-31 09:31:35 +010031{
32 Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
33 // This check is required, if the callback fails and it isn't checked it will bring down the service
34 if (!returned.isOk())
35 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +010036 ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
Mike Kellyb5fdf382019-06-11 16:35:25 +010037 returned.description().c_str());
telsoa01ce3e84a2018-08-31 09:31:35 +010038 }
39}
40
41Return<ErrorStatus> FailPrepareModel(ErrorStatus error,
42 const string& message,
Matthew Bentham912b3622019-05-03 15:49:14 +010043 const sp<V1_0::IPreparedModelCallback>& callback)
telsoa01ce3e84a2018-08-31 09:31:35 +010044{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010045 ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
telsoa01ce3e84a2018-08-31 09:31:35 +010046 NotifyCallbackAndCheck(callback, error, nullptr);
47 return error;
48}
49
Mike Kellyb5fdf382019-06-11 16:35:25 +010050
telsoa01ce3e84a2018-08-31 09:31:35 +010051} // namespace
52
53namespace armnn_driver
54{
telsoa01ce3e84a2018-08-31 09:31:35 +010055
arovir01b0717b52018-09-05 17:03:25 +010056template<typename HalPolicy>
arovir01b0717b52018-09-05 17:03:25 +010057Return<ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
telsoa01ce3e84a2018-08-31 09:31:35 +010058 const armnn::IRuntimePtr& runtime,
59 const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
60 const DriverOptions& options,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010061 const HalModel& model,
Matthew Bentham912b3622019-05-03 15:49:14 +010062 const sp<V1_0::IPreparedModelCallback>& cb,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010063 bool float32ToFloat16)
telsoa01ce3e84a2018-08-31 09:31:35 +010064{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010065 ALOGV("ArmnnDriverImpl::prepareModel()");
telsoa01ce3e84a2018-08-31 09:31:35 +010066
67 if (cb.get() == nullptr)
68 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +010069 ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
telsoa01ce3e84a2018-08-31 09:31:35 +010070 return ErrorStatus::INVALID_ARGUMENT;
71 }
72
73 if (!runtime)
74 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +010075 return FailPrepareModel(ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +010076 }
77
78 if (!android::nn::validateModel(model))
79 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +010080 return FailPrepareModel(ErrorStatus::INVALID_ARGUMENT, "Invalid model passed as input", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +010081 }
82
83 // Deliberately ignore any unsupported operations requested by the options -
84 // at this point we're being asked to prepare a model that we've already declared support for
85 // and the operation indices may be different to those in getSupportedOperations anyway.
86 set<unsigned int> unsupportedOperations;
Nattapat Chaimanowongd5fd9762019-04-04 13:33:10 +010087 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
88 model,
89 unsupportedOperations);
telsoa01ce3e84a2018-08-31 09:31:35 +010090
91 if (modelConverter.GetConversionResult() != ConversionResult::Success)
92 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +010093 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, "ModelToINetworkConverter failed", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +010094 return ErrorStatus::NONE;
95 }
96
Matteo Martincighe48bdff2018-09-03 13:50:50 +010097 // Optimize the network
telsoa01ce3e84a2018-08-31 09:31:35 +010098 armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
99 armnn::OptimizerOptions OptOptions;
100 OptOptions.m_ReduceFp32ToFp16 = float32ToFloat16;
101
jimfly0107dedda2018-10-09 12:29:41 +0100102 std::vector<std::string> errMessages;
telsoa01ce3e84a2018-08-31 09:31:35 +0100103 try
104 {
105 optNet = armnn::Optimize(*modelConverter.GetINetwork(),
Nattapat Chaimanowongd5fd9762019-04-04 13:33:10 +0100106 options.GetBackends(),
telsoa01ce3e84a2018-08-31 09:31:35 +0100107 runtime->GetDeviceSpec(),
jimfly0107dedda2018-10-09 12:29:41 +0100108 OptOptions,
109 errMessages);
telsoa01ce3e84a2018-08-31 09:31:35 +0100110 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000111 catch (std::exception& e)
telsoa01ce3e84a2018-08-31 09:31:35 +0100112 {
113 stringstream message;
Mike Kellyc7d0d442019-12-11 19:27:11 +0000114 message << "Exception (" << e.what() << ") caught from optimize.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100115 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
116 return ErrorStatus::NONE;
117 }
118
119 // Check that the optimized network is valid.
120 if (!optNet)
121 {
jimfly0107dedda2018-10-09 12:29:41 +0100122 stringstream message;
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100123 message << "Invalid optimized network";
124 for (const string& msg : errMessages)
125 {
jimfly0107dedda2018-10-09 12:29:41 +0100126 message << "\n" << msg;
127 }
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100128 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100129 return ErrorStatus::NONE;
130 }
131
Jim Flynn14557e72019-12-16 11:50:29 +0000132 // Export the optimized network graph to a dot file if an output dump directory
133 // has been specified in the drivers' arguments.
134 std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet, options.GetRequestInputsAndOutputsDumpDir());
135
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100136 // Load it into the runtime.
telsoa01ce3e84a2018-08-31 09:31:35 +0100137 armnn::NetworkId netId = 0;
138 try
139 {
140 if (runtime->LoadNetwork(netId, move(optNet)) != armnn::Status::Success)
141 {
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100142 return FailPrepareModel(ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
telsoa01ce3e84a2018-08-31 09:31:35 +0100143 }
144 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000145 catch (std::exception& e)
telsoa01ce3e84a2018-08-31 09:31:35 +0100146 {
147 stringstream message;
Mike Kellyc7d0d442019-12-11 19:27:11 +0000148 message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
telsoa01ce3e84a2018-08-31 09:31:35 +0100149 FailPrepareModel(ErrorStatus::GENERAL_FAILURE, message.str(), cb);
150 return ErrorStatus::NONE;
151 }
152
Jim Flynn14557e72019-12-16 11:50:29 +0000153 // Now that we have a networkId for the graph rename the dump file to use it
154 // so that we can associate the graph file and the input/output tensor dump files
155 RenameGraphDotFile(dotGraphFileName,
156 options.GetRequestInputsAndOutputsDumpDir(),
157 netId);
Jim Flynn4d3a24b2019-12-13 14:43:24 +0000158
arovir01b0717b52018-09-05 17:03:25 +0100159 unique_ptr<ArmnnPreparedModel<HalPolicy>> preparedModel(
Mike Kellyb5fdf382019-06-11 16:35:25 +0100160 new ArmnnPreparedModel<HalPolicy>(
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100161 netId,
162 runtime.get(),
163 model,
164 options.GetRequestInputsAndOutputsDumpDir(),
165 options.IsGpuProfilingEnabled()));
telsoa01ce3e84a2018-08-31 09:31:35 +0100166
167 // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
168 // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
Matthew Bentham16196e22019-04-01 17:17:58 +0100169 if (!preparedModel->ExecuteWithDummyInputs())
170 {
171 return FailPrepareModel(ErrorStatus::GENERAL_FAILURE, "Network could not be executed", cb);
172 }
telsoa01ce3e84a2018-08-31 09:31:35 +0100173
174 if (clTunedParameters &&
175 options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
176 {
177 // Now that we've done one inference the CL kernel parameters will have been tuned, so save the updated file.
178 try
179 {
180 clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
181 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000182 catch (std::exception& error)
telsoa01ce3e84a2018-08-31 09:31:35 +0100183 {
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100184 ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100185 options.GetClTunedParametersFile().c_str(), error.what());
telsoa01ce3e84a2018-08-31 09:31:35 +0100186 }
187 }
188
189 NotifyCallbackAndCheck(cb, ErrorStatus::NONE, preparedModel.release());
190
191 return ErrorStatus::NONE;
192}
193
arovir01b0717b52018-09-05 17:03:25 +0100194template<typename HalPolicy>
Mike Kellyb5fdf382019-06-11 16:35:25 +0100195Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
196 const DriverOptions& options,
197 const HalModel& model,
198 HalGetSupportedOperations_cb cb)
199{
Jim Flynn14557e72019-12-16 11:50:29 +0000200 std::stringstream ss;
201 ss << "ArmnnDriverImpl::getSupportedOperations()";
202 std::string fileName;
203 std::string timestamp;
204 if (!options.GetRequestInputsAndOutputsDumpDir().empty())
205 {
206 timestamp = GetFileTimestamp();
207 fileName = boost::str(boost::format("%1%/%2%_getSupportedOperations.txt")
208 % options.GetRequestInputsAndOutputsDumpDir()
209 % timestamp);
210 ss << " : " << fileName;
211 }
212 ALOGV(ss.str().c_str());
213
214 if (!options.GetRequestInputsAndOutputsDumpDir().empty())
215 {
216 //dump the marker file
217 std::ofstream fileStream;
218 fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
219 if (fileStream.good())
220 {
221 fileStream << timestamp << std::endl;
222 }
223 fileStream.close();
224 }
Mike Kellyb5fdf382019-06-11 16:35:25 +0100225
226 vector<bool> result;
227
228 if (!runtime)
229 {
230 cb(ErrorStatus::DEVICE_UNAVAILABLE, result);
231 return Void();
232 }
233
234 // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
235 if (!android::nn::validateModel(model))
236 {
237 cb(ErrorStatus::INVALID_ARGUMENT, result);
238 return Void();
239 }
240
241 // Attempt to convert the model to an ArmNN input network (INetwork).
242 ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
243 model,
244 options.GetForcedUnsupportedOperations());
245
246 if (modelConverter.GetConversionResult() != ConversionResult::Success
247 && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
248 {
249 cb(ErrorStatus::GENERAL_FAILURE, result);
250 return Void();
251 }
252
253 // Check each operation if it was converted successfully and copy the flags
254 // into the result (vector<bool>) that we need to return to Android.
255 result.reserve(model.operations.size());
256 for (uint32_t operationIdx = 0; operationIdx < model.operations.size(); operationIdx++)
257 {
258 bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
259 result.push_back(operationSupported);
260 }
261
262 cb(ErrorStatus::NONE, result);
263 return Void();
264}
265
266template<typename HalPolicy>
arovir01b0717b52018-09-05 17:03:25 +0100267Return<DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
telsoa01ce3e84a2018-08-31 09:31:35 +0100268{
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100269 ALOGV("ArmnnDriver::getStatus()");
telsoa01ce3e84a2018-08-31 09:31:35 +0100270
271 return DeviceStatus::AVAILABLE;
272}
273
arovir01b0717b52018-09-05 17:03:25 +0100274///
275/// Class template specializations
276///
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100277
arovir01b0717b52018-09-05 17:03:25 +0100278template class ArmnnDriverImpl<hal_1_0::HalPolicy>;
279
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100280#ifdef ARMNN_ANDROID_NN_V1_1
arovir01b0717b52018-09-05 17:03:25 +0100281template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100282#endif
283
Mike Kellyb5fdf382019-06-11 16:35:25 +0100284#ifdef ARMNN_ANDROID_NN_V1_2
285template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
286template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
287#endif
288
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100289} // namespace armnn_driver