//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"
#include "ModelToINetworkConverter.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>

#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
18
19namespace armnn_driver
20{
21using CallbackAsync_1_3 = std::function<
22 void(V1_3::ErrorStatus errorStatus,
23 std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
24 const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
25 std::string callingFunction)>;
26
27struct ExecutionContext_1_3
28{
29 ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings =
30 ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
31 TimePoint driverStart;
Sadik Armagand7be72e2020-04-23 12:56:05 +010032 TimePoint deviceStart;
33 TimePoint deviceEnd;
Kevin May42477c12020-03-26 13:34:14 +000034};
35
36using CallbackContext_1_3 = CallbackContext<CallbackAsync_1_3, ExecutionContext_1_3>;
37
38using executeFenced_cb = std::function<void(::android::hardware::neuralnetworks::V1_3::ErrorStatus status,
39 const ::android::hardware::hidl_handle& syncFence,
40 const ::android::sp<::android::hardware::neuralnetworks::V1_3::IFencedExecutionCallback>& callback)>;
41
42template <typename HalVersion>
43class ArmnnPreparedModel_1_3 : public V1_3::IPreparedModel
44{
45public:
46 using HalModel = typename V1_3::Model;
47
48 ArmnnPreparedModel_1_3(armnn::NetworkId networkId,
49 armnn::IRuntime* runtime,
50 const HalModel& model,
51 const std::string& requestInputsAndOutputsDumpDir,
52 const bool gpuProfilingEnabled);
53
54 virtual ~ArmnnPreparedModel_1_3();
55
56 Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
57 const sp<V1_0::IExecutionCallback>& callback) override;
58
59 Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, MeasureTiming measure,
60 const sp<V1_2::IExecutionCallback>& callback) override;
61
62 Return<V1_3::ErrorStatus> execute_1_3(const V1_3::Request& request,
63 V1_2::MeasureTiming measure,
64 const V1_3::OptionalTimePoint&,
Kevin May352d8382020-03-31 15:03:42 +010065 const V1_3::OptionalTimeoutDuration&,
Kevin May42477c12020-03-26 13:34:14 +000066 const sp<V1_3::IExecutionCallback>& callback) override;
67
68 Return<void> executeSynchronously(const V1_0::Request &request,
69 MeasureTiming measure,
70 V1_3::IPreparedModel::executeSynchronously_cb cb) override;
71
72 Return<void> executeSynchronously_1_3(const V1_3::Request &request,
73 MeasureTiming measure,
74 const V1_3::OptionalTimePoint& deadline,
Kevin May352d8382020-03-31 15:03:42 +010075 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
Kevin May42477c12020-03-26 13:34:14 +000076 V1_3::IPreparedModel::executeSynchronously_1_3_cb cb) override;
77
78 Return<void> executeFenced(const V1_3::Request& request,
Sadik Armagand7be72e2020-04-23 12:56:05 +010079 const android::hardware::hidl_vec<android::hardware::hidl_handle>& fenceWaitFor,
Kevin May42477c12020-03-26 13:34:14 +000080 MeasureTiming measure,
81 const V1_3::OptionalTimePoint& deadline,
Kevin May352d8382020-03-31 15:03:42 +010082 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
Kevin May42477c12020-03-26 13:34:14 +000083 const V1_3::OptionalTimeoutDuration& duration,
84 executeFenced_cb callback) override;
85
86 Return<void> configureExecutionBurst(
87 const sp<V1_2::IBurstCallback>& callback,
88 const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
89 const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
90 configureExecutionBurst_cb cb) override;
91
92 template<typename CallbackContext>
93 Return<void> ExecuteSynchronously(const V1_3::Request& request, CallbackContext cbCtx);
94
95 /// execute the graph prepared from the request
96 template<typename CallbackContext>
Sadik Armagand7be72e2020-04-23 12:56:05 +010097 Return <V1_3::ErrorStatus> ExecuteGraph(
98 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
99 armnn::InputTensors& inputTensors,
100 armnn::OutputTensors& outputTensors,
101 CallbackContext callback);
Kevin May42477c12020-03-26 13:34:14 +0000102
103 /// Executes this model with dummy inputs (e.g. all zeroes).
104 /// \return false on failure, otherwise true
105 bool ExecuteWithDummyInputs();
106
107private:
108 Return <V1_3::ErrorStatus> Execute(const V1_3::Request& request,
109 MeasureTiming measureTiming,
110 CallbackAsync_1_3 callback);
111
112 Return<V1_3::ErrorStatus> PrepareMemoryForInputs(
113 armnn::InputTensors& inputs,
114 const V1_3::Request& request,
115 const std::vector<android::nn::RunTimePoolInfo>& memPools);
116
117 Return<V1_3::ErrorStatus> PrepareMemoryForOutputs(
118 armnn::OutputTensors& outputs,
119 std::vector<OutputShape> &outputShapes,
120 const V1_3::Request& request,
121 const std::vector<android::nn::RunTimePoolInfo>& memPools);
122
123 std::tuple<V1_3::ErrorStatus, hidl_vec<OutputShape>, Timing, std::string> PrepareMemoryForIO(
124 armnn::InputTensors& inputs,
125 armnn::OutputTensors& outputs,
126 std::vector<android::nn::RunTimePoolInfo>& memPools,
127 const V1_3::Request& request);
128
129 template <typename TensorBindingCollection>
130 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
131
132 armnn::NetworkId m_NetworkId;
133 armnn::IRuntime* m_Runtime;
134 V1_3::Model m_Model;
135 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
136 // It is specific to this class, so it is declared as static here
137 static RequestThread<ArmnnPreparedModel_1_3, HalVersion, CallbackContext_1_3> m_RequestThread;
138 uint32_t m_RequestCount;
139 const std::string& m_RequestInputsAndOutputsDumpDir;
140 const bool m_GpuProfilingEnabled;
141};
142
143}