blob: 2f1abef79cc44c8aabaac7a55ec490e026b43e9d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#define LOG_TAG "ArmnnDriver"
7
8#include "ArmnnPreparedModel.hpp"
9#include "Utils.hpp"
10
11#include <boost/format.hpp>
12#include <log/log.h>
13#include <OperationsUtils.h>
14
Mike Kellyb5fdf382019-06-11 16:35:25 +010015#if defined(ARMNN_ANDROID_P) || defined(ARMNN_ANDROID_Q)
surmeh01deb3bdb2018-07-05 12:06:04 +010016// The headers of the ML framework have changed between Android O and Android P.
17// The validation functions have been moved into their own header, ValidateHal.h.
18#include <ValidateHal.h>
19#endif
20
telsoa015307bc12018-03-09 13:51:08 +000021#include <cassert>
22#include <cinttypes>
23
24using namespace android;
25
26namespace
27{
28using namespace armnn_driver;
29
Matthew Bentham9e80cd22019-05-03 22:54:36 +010030void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback, ErrorStatus errorStatus,
telsoa015307bc12018-03-09 13:51:08 +000031 std::string callingFunction)
32{
33 Return<void> returned = callback->notify(errorStatus);
34 // This check is required, if the callback fails and it isn't checked it will bring down the service
35 if (!returned.isOk())
36 {
37 ALOGE("ArmnnDriver::%s: hidl callback failed to return properly: %s",
38 callingFunction.c_str(), returned.description().c_str());
39 }
40}
41
42bool ValidateRequestArgument(const RequestArgument& requestArg, const armnn::TensorInfo& tensorInfo)
43{
44 if (requestArg.dimensions.size() != 0)
45 {
46 if (requestArg.dimensions.size() != tensorInfo.GetNumDimensions())
47 {
48 ALOGE("Mismatched dimensions (request argument: %zu, expected: %u)",
49 requestArg.dimensions.size(), tensorInfo.GetNumDimensions());
50 return false;
51 }
52
53 for (unsigned int d = 0; d < tensorInfo.GetNumDimensions(); ++d)
54 {
55 if (requestArg.dimensions[d] != tensorInfo.GetShape()[d])
56 {
57 ALOGE("Mismatched size for dimension %d (request argument: %u, expected %u)",
58 d, requestArg.dimensions[d], tensorInfo.GetShape()[d]);
59 return false;
60 }
61 }
62 }
63
64 return true;
65}
66
67armnn::Tensor GetTensorForRequestArgument(const RequestArgument& requestArg,
68 const armnn::TensorInfo& tensorInfo,
69 const std::vector<::android::nn::RunTimePoolInfo>& requestPools)
70{
71 if (!ValidateRequestArgument(requestArg, tensorInfo))
72 {
73 return armnn::Tensor();
74 }
75
76 return armnn::Tensor(tensorInfo, GetMemoryFromPool(requestArg.location, requestPools));
77}
78
79inline std::string BuildTensorName(const char* tensorNamePrefix, std::size_t index)
80{
81 return tensorNamePrefix + std::to_string(index);
82}
83
Matteo Martincighe48bdff2018-09-03 13:50:50 +010084} // anonymous namespace
telsoa015307bc12018-03-09 13:51:08 +000085
telsoa01ce3e84a2018-08-31 09:31:35 +010086using namespace android::hardware;
87
telsoa015307bc12018-03-09 13:51:08 +000088namespace armnn_driver
89{
Matteo Martincighe48bdff2018-09-03 13:50:50 +010090template<typename HalVersion>
Mike Kelly65c42dc2019-07-22 14:06:00 +010091RequestThread<ArmnnPreparedModel, HalVersion, ArmnnCallback_1_0> ArmnnPreparedModel<HalVersion>::m_RequestThread;
telsoa015307bc12018-03-09 13:51:08 +000092
Matteo Martincighe48bdff2018-09-03 13:50:50 +010093template<typename HalVersion>
telsoa015307bc12018-03-09 13:51:08 +000094template <typename TensorBindingCollection>
Matteo Martincighe48bdff2018-09-03 13:50:50 +010095void ArmnnPreparedModel<HalVersion>::DumpTensorsIfRequired(char const* tensorNamePrefix,
96 const TensorBindingCollection& tensorBindings)
telsoa015307bc12018-03-09 13:51:08 +000097{
98 if (!m_RequestInputsAndOutputsDumpDir.empty())
99 {
100 const std::string requestName = boost::str(boost::format("%1%_%2%.dump") % m_NetworkId % m_RequestCount);
101 for (std::size_t i = 0u; i < tensorBindings.size(); ++i)
102 {
103 DumpTensor(m_RequestInputsAndOutputsDumpDir,
104 requestName,
105 BuildTensorName(tensorNamePrefix, i),
106 tensorBindings[i].second);
107 }
108 }
109}
110
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100111template<typename HalVersion>
112ArmnnPreparedModel<HalVersion>::ArmnnPreparedModel(armnn::NetworkId networkId,
113 armnn::IRuntime* runtime,
114 const HalModel& model,
115 const std::string& requestInputsAndOutputsDumpDir,
116 const bool gpuProfilingEnabled)
telsoa01ce3e84a2018-08-31 09:31:35 +0100117 : m_NetworkId(networkId)
118 , m_Runtime(runtime)
119 , m_Model(model)
120 , m_RequestCount(0)
121 , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
122 , m_GpuProfilingEnabled(gpuProfilingEnabled)
telsoa015307bc12018-03-09 13:51:08 +0000123{
telsoa01ce3e84a2018-08-31 09:31:35 +0100124 // Enable profiling if required.
125 m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
telsoa015307bc12018-03-09 13:51:08 +0000126}
127
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100128template<typename HalVersion>
129ArmnnPreparedModel<HalVersion>::~ArmnnPreparedModel()
telsoa015307bc12018-03-09 13:51:08 +0000130{
telsoa01ce3e84a2018-08-31 09:31:35 +0100131 // Get a hold of the profiler used by this model.
132 std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId);
133
134 // Unload the network associated with this model.
telsoa015307bc12018-03-09 13:51:08 +0000135 m_Runtime->UnloadNetwork(m_NetworkId);
telsoa01ce3e84a2018-08-31 09:31:35 +0100136
137 // Dump the profiling info to a file if required.
138 DumpJsonProfilingIfRequired(m_GpuProfilingEnabled, m_RequestInputsAndOutputsDumpDir, m_NetworkId, profiler.get());
telsoa015307bc12018-03-09 13:51:08 +0000139}
140
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100141template<typename HalVersion>
142Return<ErrorStatus> ArmnnPreparedModel<HalVersion>::execute(const Request& request,
Matthew Bentham9e80cd22019-05-03 22:54:36 +0100143 const ::android::sp<V1_0::IExecutionCallback>& callback)
telsoa015307bc12018-03-09 13:51:08 +0000144{
145 ALOGV("ArmnnPreparedModel::execute(): %s", GetModelSummary(m_Model).c_str());
146 m_RequestCount++;
147
148 if (callback.get() == nullptr) {
149 ALOGE("ArmnnPreparedModel::execute invalid callback passed");
150 return ErrorStatus::INVALID_ARGUMENT;
151 }
152
153 if (!android::nn::validateRequest(request, m_Model))
154 {
155 NotifyCallbackAndCheck(callback, ErrorStatus::INVALID_ARGUMENT, "ArmnnPreparedModel::execute");
156 return ErrorStatus::INVALID_ARGUMENT;
157 }
158
159 if (!m_RequestInputsAndOutputsDumpDir.empty())
160 {
161 ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(callback.get()));
162 }
163
164 // allocate the tensors on the heap, as they are passed to the request thread
165 auto pInputTensors = std::make_shared<armnn::InputTensors>();
166 auto pOutputTensors = std::make_shared<armnn::OutputTensors>();
167
168 // map the memory pool into shared pointers
169 // use a shared memory pools vector on the heap, as it is passed to the request thread
170 auto pMemPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
171 if (!setRunTimePoolInfosFromHidlMemories(pMemPools.get(), request.pools))
172 {
173 NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute");
174 return ErrorStatus::GENERAL_FAILURE;
175 }
176
177 // add the inputs and outputs with their data
178 try
179 {
180 pInputTensors->reserve(request.inputs.size());
181 for (unsigned int i = 0; i < request.inputs.size(); i++)
182 {
183 const auto& inputArg = request.inputs[i];
184
185 const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
186 const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, *pMemPools);
187 if (inputTensor.GetMemoryArea() == nullptr)
188 {
189 ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
190 return ErrorStatus::GENERAL_FAILURE;
191 }
192
193 pInputTensors->emplace_back(i, inputTensor);
194 }
195
196 pOutputTensors->reserve(request.outputs.size());
197 for (unsigned int i = 0; i < request.outputs.size(); i++)
198 {
199 const auto& outputArg = request.outputs[i];
200
201 const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
202 const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, *pMemPools);
203 if (outputTensor.GetMemoryArea() == nullptr)
204 {
205 ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
206 return ErrorStatus::GENERAL_FAILURE;
207 }
208
209 pOutputTensors->emplace_back(i, outputTensor);
210 }
211 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000212 catch (std::exception& e)
telsoa015307bc12018-03-09 13:51:08 +0000213 {
Mike Kellyc7d0d442019-12-11 19:27:11 +0000214 ALOGW("Exception caught while preparing for EnqueueWorkload: %s", e.what());
telsoa015307bc12018-03-09 13:51:08 +0000215 NotifyCallbackAndCheck(callback, ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::execute");
216 return ErrorStatus::GENERAL_FAILURE;
217 }
218
219 ALOGV("ArmnnPreparedModel::execute(...) before PostMsg");
telsoa015307bc12018-03-09 13:51:08 +0000220
Mike Kelly65c42dc2019-07-22 14:06:00 +0100221 auto cb = [callback](ErrorStatus errorStatus, std::string callingFunction)
222 {
223 NotifyCallbackAndCheck(callback, errorStatus, callingFunction);
224 };
225
226 ArmnnCallback_1_0 armnnCb;
227 armnnCb.callback = cb;
228 // post the request for asynchronous execution
229 m_RequestThread.PostMsg(this, pMemPools, pInputTensors, pOutputTensors, armnnCb);
230 ALOGV("ArmnnPreparedModel::execute(...) after PostMsg");
telsoa015307bc12018-03-09 13:51:08 +0000231 return ErrorStatus::NONE; // successfully queued
232}
233
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100234template<typename HalVersion>
235void ArmnnPreparedModel<HalVersion>::ExecuteGraph(
236 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
237 std::shared_ptr<armnn::InputTensors>& pInputTensors,
238 std::shared_ptr<armnn::OutputTensors>& pOutputTensors,
Mike Kelly65c42dc2019-07-22 14:06:00 +0100239 ArmnnCallback_1_0 cb)
telsoa015307bc12018-03-09 13:51:08 +0000240{
241 ALOGV("ArmnnPreparedModel::ExecuteGraph(...)");
242
243 DumpTensorsIfRequired("Input", *pInputTensors);
244
245 // run it
246 try
247 {
Matthew Bentham16196e22019-04-01 17:17:58 +0100248 armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, *pInputTensors, *pOutputTensors);
249 if (status != armnn::Status::Success)
250 {
251 ALOGW("EnqueueWorkload failed");
Mike Kelly65c42dc2019-07-22 14:06:00 +0100252 cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph");
Matthew Bentham16196e22019-04-01 17:17:58 +0100253 return;
254 }
telsoa015307bc12018-03-09 13:51:08 +0000255 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000256 catch (std::exception& e)
telsoa015307bc12018-03-09 13:51:08 +0000257 {
Mike Kellyc7d0d442019-12-11 19:27:11 +0000258 ALOGW("Exception caught from EnqueueWorkload: %s", e.what());
Mike Kelly65c42dc2019-07-22 14:06:00 +0100259 cb.callback(ErrorStatus::GENERAL_FAILURE, "ArmnnPreparedModel::ExecuteGraph");
telsoa015307bc12018-03-09 13:51:08 +0000260 return;
261 }
262
263 DumpTensorsIfRequired("Output", *pOutputTensors);
264
265 // Commit output buffers.
266 // Note that we update *all* pools, even if they aren't actually used as outputs -
267 // this is simpler and is what the CpuExecutor does.
268 for (android::nn::RunTimePoolInfo& pool : *pMemPools)
269 {
270 pool.update();
271 }
272
Mike Kelly65c42dc2019-07-22 14:06:00 +0100273 cb.callback(ErrorStatus::NONE, "ExecuteGraph");
telsoa015307bc12018-03-09 13:51:08 +0000274}
275
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100276template<typename HalVersion>
Matthew Bentham16196e22019-04-01 17:17:58 +0100277bool ArmnnPreparedModel<HalVersion>::ExecuteWithDummyInputs()
telsoa015307bc12018-03-09 13:51:08 +0000278{
279 std::vector<std::vector<char>> storage;
280 armnn::InputTensors inputTensors;
281 for (unsigned int i = 0; i < m_Model.inputIndexes.size(); i++)
282 {
283 const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
284 storage.emplace_back(inputTensorInfo.GetNumBytes());
285 const armnn::ConstTensor inputTensor(inputTensorInfo, storage.back().data());
286
287 inputTensors.emplace_back(i, inputTensor);
288 }
289
290 armnn::OutputTensors outputTensors;
291 for (unsigned int i = 0; i < m_Model.outputIndexes.size(); i++)
292 {
293 const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
294 storage.emplace_back(outputTensorInfo.GetNumBytes());
295 const armnn::Tensor outputTensor(outputTensorInfo, storage.back().data());
296
297 outputTensors.emplace_back(i, outputTensor);
298 }
299
300 try
301 {
Matthew Bentham16196e22019-04-01 17:17:58 +0100302 armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);
303 if (status != armnn::Status::Success)
304 {
305 ALOGW("ExecuteWithDummyInputs: EnqueueWorkload failed");
306 return false;
307 }
telsoa015307bc12018-03-09 13:51:08 +0000308 }
Mike Kellyc7d0d442019-12-11 19:27:11 +0000309 catch (std::exception& e)
telsoa015307bc12018-03-09 13:51:08 +0000310 {
Mike Kellyc7d0d442019-12-11 19:27:11 +0000311 ALOGW("ExecuteWithDummyInputs: Exception caught from EnqueueWorkload: %s", e.what());
Matthew Bentham16196e22019-04-01 17:17:58 +0100312 return false;
telsoa015307bc12018-03-09 13:51:08 +0000313 }
Matthew Bentham16196e22019-04-01 17:17:58 +0100314 return true;
telsoa015307bc12018-03-09 13:51:08 +0000315}
316
arovir01b0717b52018-09-05 17:03:25 +0100317///
318/// Class template specializations
319///
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100320
arovir01b0717b52018-09-05 17:03:25 +0100321template class ArmnnPreparedModel<hal_1_0::HalPolicy>;
322
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100323#ifdef ARMNN_ANDROID_NN_V1_1
arovir01b0717b52018-09-05 17:03:25 +0100324template class ArmnnPreparedModel<hal_1_1::HalPolicy>;
Matteo Martincighe48bdff2018-09-03 13:50:50 +0100325#endif
326
Mike Kellyb5fdf382019-06-11 16:35:25 +0100327#ifdef ARMNN_ANDROID_NN_V1_2
328template class ArmnnPreparedModel<hal_1_1::HalPolicy>;
329template class ArmnnPreparedModel<hal_1_2::HalPolicy>;
330#endif
Nikhil Raj77605822018-09-03 11:25:56 +0100331} // namespace armnn_driver