//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#define LOG_TAG "ArmnnDriver"

#include "ArmnnPreparedModel_1_2.hpp"
#include "Utils.hpp"

#include <boost/format.hpp>
#include <log/log.h>
#include <OperationsUtils.h>
#include <ExecutionBurstServer.h>
#include <ValidateHal.h>

#include <cassert>
#include <cinttypes>

using namespace android;
using namespace android::hardware;

namespace {

static const Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
using namespace armnn_driver;
using TimePoint = std::chrono::steady_clock::time_point;

TimePoint Now()
{
    return std::chrono::steady_clock::now();
}

unsigned long MicrosecondsDuration(TimePoint endPoint, TimePoint startPoint)
{
    return static_cast<unsigned long>(std::chrono::duration_cast<std::chrono::microseconds>(
                                      endPoint - startPoint).count());
}

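// Reports an execution result back through the HAL 1.0 callback. The output shapes and timing
// parameters are accepted but ignored, as IExecutionCallback::notify() cannot carry them.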
void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback,
                            V1_0::ErrorStatus errorStatus,
                            std::vector<OutputShape>,
                            const Timing,
                            std::string callingFunction)
{
    Return<void> returned = callback->notify(errorStatus);
    // This check is required; if the callback fails and the error isn't checked, it will bring down the service.
    if (!returned.isOk())
    {
        ALOGE("ArmnnDriver::%s: hidl callback failed to return properly: %s",
              callingFunction.c_str(), returned.description().c_str());
    }
}

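// Reports an execution result back through the HAL 1.2 callback, forwarding the output shapes
// and timing information via IExecutionCallback::notify_1_2().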
void NotifyCallbackAndCheck(const ::android::sp<V1_2::IExecutionCallback>& callback,
                            V1_0::ErrorStatus errorStatus,
                            std::vector<OutputShape> outputShapes,
                            const Timing timing,
                            std::string callingFunction)
{
    Return<void> returned = callback->notify_1_2(errorStatus, outputShapes, timing);
    // This check is required; if the callback fails and the error isn't checked, it will bring down the service.
    if (!returned.isOk())
    {
        ALOGE("ArmnnDriver::%s: hidl callback failed to return properly: %s",
              callingFunction.c_str(), returned.description().c_str());
    }
}

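// Checks that the dimensions given in a request argument (if any) match the tensor info that
// the network expects for that input or output.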
bool ValidateRequestArgument(const RequestArgument& requestArg, const armnn::TensorInfo& tensorInfo)
{
    if (requestArg.dimensions.size() != 0)
    {
        if (requestArg.dimensions.size() != tensorInfo.GetNumDimensions())
        {
            ALOGE("Mismatched dimensions (request argument: %zu, expected: %u)",
                  requestArg.dimensions.size(), tensorInfo.GetNumDimensions());
            return false;
        }

        for (unsigned int d = 0; d < tensorInfo.GetNumDimensions(); ++d)
        {
            if (requestArg.dimensions[d] != tensorInfo.GetShape()[d])
            {
                ALOGE("Mismatched size for dimension %d (request argument: %u, expected %u)",
                      d, requestArg.dimensions[d], tensorInfo.GetShape()[d]);
                return false;
            }
        }
    }

    return true;
}

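// Builds an armnn::Tensor for a request argument, validating its dimensions and resolving its
// data pointer from the request memory pools. Returns an empty tensor on failure.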
armnn::Tensor GetTensorForRequestArgument(const RequestArgument& requestArg,
                                          const armnn::TensorInfo& tensorInfo,
                                          const std::vector<::android::nn::RunTimePoolInfo>& requestPools)
{
    if (!ValidateRequestArgument(requestArg, tensorInfo))
    {
        return armnn::Tensor();
    }

    return armnn::Tensor(tensorInfo, GetMemoryFromPool(requestArg.location, requestPools));
}

inline std::string BuildTensorName(const char* tensorNamePrefix, std::size_t index)
{
    return tensorNamePrefix + std::to_string(index);
}

} // anonymous namespace

using namespace android::hardware;

namespace armnn_driver
{

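// Static request thread that services asynchronous executions. There is one thread per
// instantiation of this class template (i.e. per HAL version), shared by all prepared models.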
template<typename HalVersion>
RequestThread<ArmnnPreparedModel_1_2, HalVersion, CallbackContext_1_2>
    ArmnnPreparedModel_1_2<HalVersion>::m_RequestThread;

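// Dumps the given tensors to files in m_RequestInputsAndOutputsDumpDir, if a dump directory
// has been configured. Used for debugging request inputs and outputs.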
template<typename HalVersion>
template<typename TensorBindingCollection>
void ArmnnPreparedModel_1_2<HalVersion>::DumpTensorsIfRequired(char const* tensorNamePrefix,
                                                               const TensorBindingCollection& tensorBindings)
{
    if (!m_RequestInputsAndOutputsDumpDir.empty())
    {
        const std::string requestName = boost::str(boost::format("%1%_%2%.dump") % m_NetworkId % m_RequestCount);
        for (std::size_t i = 0u; i < tensorBindings.size(); ++i)
        {
            DumpTensor(m_RequestInputsAndOutputsDumpDir,
                       requestName,
                       BuildTensorName(tensorNamePrefix, i),
                       tensorBindings[i].second);
        }
    }
}

template<typename HalVersion>
ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
                                                           armnn::IRuntime* runtime,
                                                           const V1_2::Model& model,
                                                           const std::string& requestInputsAndOutputsDumpDir,
                                                           const bool gpuProfilingEnabled)
    : m_NetworkId(networkId)
    , m_Runtime(runtime)
    , m_Model(model)
    , m_RequestCount(0)
    , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
    , m_GpuProfilingEnabled(gpuProfilingEnabled)
{
    // Enable profiling if required.
    m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
}

template<typename HalVersion>
ArmnnPreparedModel_1_2<HalVersion>::~ArmnnPreparedModel_1_2()
{
    // Get a hold of the profiler used by this model.
    std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId);

    // Unload the network associated with this model.
    m_Runtime->UnloadNetwork(m_NetworkId);

    // Dump the profiling info to a file if required.
    DumpJsonProfilingIfRequired(m_GpuProfilingEnabled, m_RequestInputsAndOutputsDumpDir, m_NetworkId, profiler.get());
}

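// IPreparedModel::execute() entry point (HAL 1.0). Wraps the 1.0 callback and forwards to the
// common Execute() path with timing measurement disabled.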
template<typename HalVersion>
Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::execute(const V1_0::Request& request,
        const ::android::sp<V1_0::IExecutionCallback>& callback)
{
    if (callback.get() == nullptr)
    {
        ALOGE("ArmnnPreparedModel_1_2::execute invalid callback passed");
        return V1_0::ErrorStatus::INVALID_ARGUMENT;
    }

    auto cb = [callback](V1_0::ErrorStatus errorStatus,
                         std::vector<OutputShape> outputShapes,
                         const Timing& timing,
                         std::string callingFunction)
    {
        NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
    };

    return Execute(request, MeasureTiming::NO, cb);
}

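// IPreparedModel::execute_1_2() entry point (HAL 1.2). Wraps the 1.2 callback and forwards to
// the common Execute() path, honouring the requested timing measurement.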
template<typename HalVersion>
Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::execute_1_2(
        const V1_0::Request& request,
        MeasureTiming measureTiming,
        const sp<V1_2::IExecutionCallback>& callback)
{
    if (callback.get() == nullptr)
    {
        ALOGE("ArmnnPreparedModel_1_2::execute_1_2 invalid callback passed");
        return V1_0::ErrorStatus::INVALID_ARGUMENT;
    }

    auto cb = [callback](V1_0::ErrorStatus errorStatus,
                         std::vector<OutputShape> outputShapes,
                         const Timing& timing,
                         std::string callingFunction)
    {
        NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
    };

    return Execute(request, measureTiming, cb);
}

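// Converts each request input into an armnn::Tensor backed by the mapped memory pools and adds
// it to the InputTensors collection. Fails if any input cannot be converted.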
template<typename HalVersion>
Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForInputs(
    armnn::InputTensors& inputs,
    const V1_0::Request& request,
    const std::vector<android::nn::RunTimePoolInfo>& memPools)
{
    inputs.reserve(request.inputs.size());
    for (unsigned int i = 0; i < request.inputs.size(); i++)
    {
        const auto& inputArg = request.inputs[i];

        const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
        const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, memPools);

        if (inputTensor.GetMemoryArea() == nullptr)
        {
            ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
            return V1_0::ErrorStatus::GENERAL_FAILURE;
        }

        inputs.emplace_back(i, inputTensor);
    }

    return V1_0::ErrorStatus::NONE;
}

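// Converts each request output into an armnn::Tensor backed by the mapped memory pools, checks
// that the destination buffer is large enough, and records the expected output shape.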
template<typename HalVersion>
Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForOutputs(
    armnn::OutputTensors& outputs,
    std::vector<OutputShape>& outputShapes,
    const V1_0::Request& request,
    const std::vector<android::nn::RunTimePoolInfo>& memPools)
{
    outputs.reserve(request.outputs.size());
    for (unsigned int i = 0; i < request.outputs.size(); i++)
    {
        const auto& outputArg = request.outputs[i];

        const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
        const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, memPools);
        if (outputTensor.GetMemoryArea() == nullptr)
        {
            ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
            return V1_0::ErrorStatus::GENERAL_FAILURE;
        }

        const size_t outputSize = outputTensorInfo.GetNumBytes();
        const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size();
        if (bufferSize < outputSize)
        {
            ALOGW("ArmnnPreparedModel_1_2::Execute failed: buffer for output %u is smaller than the expected output", i);
            return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
        }

        outputs.emplace_back(i, outputTensor);
        outputShapes[i] = ComputeShape(outputTensorInfo);
    }

    return V1_0::ErrorStatus::NONE;
}

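// Maps the request memory pools and prepares the input and output tensors for execution.
// On any failure the supplied callback is invoked with the appropriate error status before
// the status is returned to the caller.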
template<typename HalVersion>
Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForIO(
    armnn::InputTensors& inputs,
    armnn::OutputTensors& outputs,
    std::vector<android::nn::RunTimePoolInfo>& memPools,
    const V1_0::Request& request,
    CallbackAsync_1_2 callback)
{
    if (!setRunTimePoolInfosFromHidlMemories(&memPools, request.pools))
    {
        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
        return V1_0::ErrorStatus::GENERAL_FAILURE;
    }

    // add the inputs and outputs with their data
    try
    {
        if (PrepareMemoryForInputs(inputs, request, memPools) != V1_0::ErrorStatus::NONE)
        {
            callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
            return V1_0::ErrorStatus::GENERAL_FAILURE;
        }

        std::vector<OutputShape> outputShapes(request.outputs.size());

        auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools);
        if (errorStatus != V1_0::ErrorStatus::NONE)
        {
            callback(errorStatus,
                     outputShapes,
                     g_NoTiming,
                     "ArmnnPreparedModel_1_2::Execute");
            return errorStatus;
        }
    }
    catch (armnn::Exception& e)
    {
        ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
        return V1_0::ErrorStatus::GENERAL_FAILURE;
    }
    catch (std::exception& e)
    {
        ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
        return V1_0::ErrorStatus::GENERAL_FAILURE;
    }

    return V1_0::ErrorStatus::NONE;
}

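// IPreparedModel::executeSynchronously() entry point (HAL 1.2). Prepares the request memory
// and runs the graph on the calling thread, reporting the result through the synchronous
// callback together with output shapes and (optionally) timing information.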
template<typename HalVersion>
Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const V1_0::Request& request,
                                                                      MeasureTiming measureTiming,
                                                                      executeSynchronously_cb cb)
{
    ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
    m_RequestCount++;

    if (cb == nullptr)
    {
        ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid callback passed");
        return Void();
    }

    TimePoint driverStart;

    if (measureTiming == MeasureTiming::YES)
    {
        driverStart = Now();
    }

    if (!android::nn::validateRequest(request, m_Model))
    {
        ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model");
        cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming);
        return Void();
    }

    auto cbWrapper = [cb](V1_0::ErrorStatus errorStatus,
                          std::vector<OutputShape> outputShapes,
                          const Timing& timing,
                          std::string)
    {
        cb(errorStatus, outputShapes, timing);
    };

    // map the memory pool into shared pointers
    // use a shared memory pools vector on the heap, as it is passed to the request thread
    auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();

    // allocate the tensors on the heap, as they are passed to the request thread
    auto inputs = std::make_shared<armnn::InputTensors>();
    auto outputs = std::make_shared<armnn::OutputTensors>();

    auto prepareStatus = PrepareMemoryForIO(*inputs, *outputs, *memPools, request, cbWrapper);
    if (prepareStatus != V1_0::ErrorStatus::NONE)
    {
        return Void();
    }

    ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution");

    CallbackContext_1_2 cbCtx;
    cbCtx.callback = cbWrapper;
    cbCtx.ctx.measureTimings = measureTiming;
    cbCtx.ctx.driverStart = driverStart;
    ExecuteGraph(memPools, *inputs, *outputs, cbCtx);

    return Void();
}

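// Runs the loaded network with the prepared input and output tensors, then commits the output
// memory pools and reports completion (with timing, if requested) through the callback in the
// supplied context. Returns true on success, false if the execution failed.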
template<typename HalVersion>
template<typename CallbackContext>
bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph(
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
        armnn::InputTensors& inputTensors,
        armnn::OutputTensors& outputTensors,
        CallbackContext cb)
{
    ALOGV("ArmnnPreparedModel_1_2::ExecuteGraph(...)");

    TimePoint driverEnd, deviceStart, deviceEnd;

    DumpTensorsIfRequired("Input", inputTensors);

    std::vector<OutputShape> outputShapes(outputTensors.size());
    for (unsigned int i = 0; i < outputTensors.size(); i++)
    {
        std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i];
        const armnn::Tensor outputTensor = outputTensorPair.second;
        const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();

        outputShapes[i] = ComputeShape(outputTensorInfo);
    }

    // run the graph
    try
    {
        if (cb.ctx.measureTimings == MeasureTiming::YES)
        {
            deviceStart = Now();
        }

        armnn::Status status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors);

        if (cb.ctx.measureTimings == MeasureTiming::YES)
        {
            deviceEnd = Now();
        }
        if (status != armnn::Status::Success)
        {
            ALOGW("EnqueueWorkload failed");
            cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming,
                        "ArmnnPreparedModel_1_2::ExecuteGraph");
            return false;
        }
    }
    catch (armnn::Exception& e)
    {
        ALOGW("armnn::Exception caught from EnqueueWorkload: %s", e.what());
        cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
        return false;
    }
    catch (std::exception& e)
    {
        ALOGE("std::exception caught from EnqueueWorkload: %s", e.what());
        cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
        return false;
    }

    CommitPools(*pMemPools);

    DumpTensorsIfRequired("Output", outputTensors);

    if (cb.ctx.measureTimings == MeasureTiming::YES)
    {
        driverEnd = Now();
        Timing timing;
        timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
        timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.ctx.driverStart);
        ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu", timing.timeOnDevice,
              timing.timeInDriver);
        cb.callback(V1_0::ErrorStatus::NONE, outputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph");
    }
    else
    {
        cb.callback(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
    }

    return true;
}

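// Runs the network once with zero-filled dummy inputs and throw-away outputs and reports
// whether the execution succeeded.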
template<typename HalVersion>
bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteWithDummyInputs()
{
    std::vector<std::vector<char>> storage;
    armnn::InputTensors inputTensors;
    for (unsigned int i = 0; i < getMainModel(m_Model).inputIndexes.size(); i++)
    {
        const armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
        storage.emplace_back(inputTensorInfo.GetNumBytes());
        const armnn::ConstTensor inputTensor(inputTensorInfo, storage.back().data());

        inputTensors.emplace_back(i, inputTensor);
    }

    armnn::OutputTensors outputTensors;
    for (unsigned int i = 0; i < getMainModel(m_Model).outputIndexes.size(); i++)
    {
        const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
        storage.emplace_back(outputTensorInfo.GetNumBytes());
        const armnn::Tensor outputTensor(outputTensorInfo, storage.back().data());

        outputTensors.emplace_back(i, outputTensor);
    }

    auto nullCallback = [](V1_0::ErrorStatus, std::vector<OutputShape>, const Timing&, std::string) {};
    CallbackContext_1_2 callbackContext;
    callbackContext.callback = nullCallback;
    callbackContext.ctx.measureTimings = MeasureTiming::NO;
    auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>();
    return ExecuteGraph(memPools,
                        inputTensors,
                        outputTensors,
                        callbackContext);
}

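// Common asynchronous execution path shared by execute() and execute_1_2(). Validates the
// request, prepares the memory pools and tensors, and posts the work to the request thread;
// the callback is invoked from that thread when execution completes.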
template<typename HalVersion>
Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_0::Request& request,
                                                                      MeasureTiming measureTiming,
                                                                      CallbackAsync_1_2 callback)
{
    ExecutionContext_1_2 ctx;
    if (measureTiming == MeasureTiming::YES)
    {
        ctx.measureTimings = measureTiming;
        ctx.driverStart = Now();
    }

    ALOGV("ArmnnPreparedModel_1_2::execute(): %s", GetModelSummary(m_Model).c_str());
    m_RequestCount++;

    if (!android::nn::validateRequest(request, m_Model))
    {
        callback(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
        return V1_0::ErrorStatus::INVALID_ARGUMENT;
    }

    if (!m_RequestInputsAndOutputsDumpDir.empty())
    {
        ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(&callback));
    }

    // map the memory pool into shared pointers
    // use a shared memory pools vector on the heap, as it is passed to the request thread
    auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();

    // allocate the tensors on the heap, as they are passed to the request thread
    auto inputTensors = std::make_shared<armnn::InputTensors>();
    auto outputTensors = std::make_shared<armnn::OutputTensors>();

    auto prepareStatus = PrepareMemoryForIO(*inputTensors, *outputTensors, *memPools, request, callback);
    switch(prepareStatus)
    {
        case V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
            return V1_0::ErrorStatus::NONE;
        case V1_0::ErrorStatus::GENERAL_FAILURE:
            return V1_0::ErrorStatus::GENERAL_FAILURE;
        default:
            {}
    }

    ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");

    // post the request for asynchronous execution
    CallbackContext_1_2 cb;
    cb.callback = callback;
    cb.ctx = ctx;
    m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
    ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg");
    return V1_0::ErrorStatus::NONE;
}

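// IPreparedModel::configureExecutionBurst() entry point (HAL 1.2). Creates an
// ExecutionBurstServer bound to this prepared model and returns it via the callback.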
template<typename HalVersion>
Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
    const sp<V1_2::IBurstCallback>& callback,
    const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
    const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
    V1_2::IPreparedModel::configureExecutionBurst_cb cb)
{
    ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst");
    const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(callback,
                                                                       requestChannel,
                                                                       resultChannel,
                                                                       this);

    if (burst == nullptr)
    {
        cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
    }
    else
    {
        cb(V1_0::ErrorStatus::NONE, burst);
    }
    return Void();
}

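// Explicit instantiations for the HAL 1.2 policy, compiled when the driver targets
// NN API 1.2 or 1.3.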
#if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3)
template class ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>;
template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackContext_1_2>(
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
        armnn::InputTensors& pInputTensors,
        armnn::OutputTensors& pOutputTensors,
        CallbackContext_1_2 cb);
#endif

} // namespace armnn_driver