//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"
#include "ModelToINetworkConverter.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>

#include <chrono>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

19namespace armnn_driver
20{
21
Derek Lamberti4de83c52020-03-17 13:40:18 +000022using CallbackAsync_1_2 = std::function<
23 void(V1_0::ErrorStatus errorStatus,
24 std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
25 const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
26 std::string callingFunction)>;
Mike Kelly65c42dc2019-07-22 14:06:00 +010027
Derek Lamberti4de83c52020-03-17 13:40:18 +000028struct ExecutionContext_1_2
Mike Kelly65c42dc2019-07-22 14:06:00 +010029{
Derek Lamberti4de83c52020-03-17 13:40:18 +000030 ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings =
31 ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
Mike Kelly65c42dc2019-07-22 14:06:00 +010032 TimePoint driverStart;
Mike Kelly65c42dc2019-07-22 14:06:00 +010033};
34
Derek Lamberti4de83c52020-03-17 13:40:18 +000035using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;
36
Mike Kellyb5fdf382019-06-11 16:35:25 +010037template <typename HalVersion>
38class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
39{
40public:
41 using HalModel = typename V1_2::Model;
42
43 ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
44 armnn::IRuntime* runtime,
45 const HalModel& model,
46 const std::string& requestInputsAndOutputsDumpDir,
Finn Williamsd8fb5402021-05-19 20:52:00 +010047 const bool gpuProfilingEnabled,
48 const bool asyncModelExecutionEnabled = false);
Mike Kellyb5fdf382019-06-11 16:35:25 +010049
50 virtual ~ArmnnPreparedModel_1_2();
51
Kevin Mayec1e5b82020-02-26 17:00:39 +000052 virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
Sadik Armagan188675f2021-02-12 17:16:42 +000053 const ::android::sp<V1_0::IExecutionCallback>& callback) override;
Mike Kellyb5fdf382019-06-11 16:35:25 +010054
Sadik Armagan188675f2021-02-12 17:16:42 +000055 virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure,
56 const ::android::sp<V1_2::IExecutionCallback>& callback) override;
Mike Kellyb5fdf382019-06-11 16:35:25 +010057
Kevin Mayec1e5b82020-02-26 17:00:39 +000058 virtual Return<void> executeSynchronously(const V1_0::Request &request,
Sadik Armagan188675f2021-02-12 17:16:42 +000059 V1_2::MeasureTiming measure,
Mike Kellyb5fdf382019-06-11 16:35:25 +010060 V1_2::IPreparedModel::executeSynchronously_cb cb) override;
61
62 virtual Return<void> configureExecutionBurst(
Sadik Armagan188675f2021-02-12 17:16:42 +000063 const ::android::sp<V1_2::IBurstCallback>& callback,
Mike Kellyb5fdf382019-06-11 16:35:25 +010064 const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
65 const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
66 configureExecutionBurst_cb cb) override;
67
68 /// execute the graph prepared from the request
Derek Lamberti4de83c52020-03-17 13:40:18 +000069 template<typename CallbackContext>
70 bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
71 armnn::InputTensors& inputTensors,
72 armnn::OutputTensors& outputTensors,
73 CallbackContext callback);
Mike Kellyb5fdf382019-06-11 16:35:25 +010074
75 /// Executes this model with dummy inputs (e.g. all zeroes).
76 /// \return false on failure, otherwise true
77 bool ExecuteWithDummyInputs();
78
79private:
Finn Williamsd8fb5402021-05-19 20:52:00 +010080
81 template<typename CallbackContext>
82 class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback
83 {
84 public:
85 ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model,
86 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
87 std::vector<V1_2::OutputShape> outputShapes,
88 std::shared_ptr<armnn::InputTensors>& inputTensors,
89 std::shared_ptr<armnn::OutputTensors>& outputTensors,
90 CallbackContext callbackContext) :
91 m_Model(model),
92 m_MemPools(pMemPools),
93 m_OutputShapes(outputShapes),
94 m_InputTensors(inputTensors),
95 m_OutputTensors(outputTensors),
96 m_CallbackContext(callbackContext)
97 {}
98
99 void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
100
101 // Retrieve the Arm NN Status from the AsyncExecutionCallback that has been notified
102 virtual armnn::Status GetStatus() const override
103 {
104 return armnn::Status::Success;
105 }
106
107 // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
108 virtual void Wait() const override
109 {}
110
111 // Retrieve the start time before executing the inference
112 virtual armnn::HighResolutionClock GetStartTime() const override
113 {
114 return std::chrono::high_resolution_clock::now();
115 }
116
117 // Retrieve the time after executing the inference
118 virtual armnn::HighResolutionClock GetEndTime() const override
119 {
120 return std::chrono::high_resolution_clock::now();
121 }
122
123 ArmnnPreparedModel_1_2<HalVersion>* m_Model;
124 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
125 std::vector<V1_2::OutputShape> m_OutputShapes;
126 std::shared_ptr<armnn::InputTensors> m_InputTensors;
127 std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
128 CallbackContext m_CallbackContext;
129 };
130
Derek Lamberti4de83c52020-03-17 13:40:18 +0000131 Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
Sadik Armagan188675f2021-02-12 17:16:42 +0000132 V1_2::MeasureTiming measureTiming,
Derek Lamberti4de83c52020-03-17 13:40:18 +0000133 CallbackAsync_1_2 callback);
134
135 Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
136 armnn::InputTensors& inputs,
137 const V1_0::Request& request,
138 const std::vector<android::nn::RunTimePoolInfo>& memPools);
139
140 Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
141 armnn::OutputTensors& outputs,
Sadik Armagan188675f2021-02-12 17:16:42 +0000142 std::vector<V1_2::OutputShape> &outputShapes,
Derek Lamberti4de83c52020-03-17 13:40:18 +0000143 const V1_0::Request& request,
144 const std::vector<android::nn::RunTimePoolInfo>& memPools);
145
146 Return <V1_0::ErrorStatus> PrepareMemoryForIO(
147 armnn::InputTensors& inputs,
148 armnn::OutputTensors& outputs,
149 std::vector<android::nn::RunTimePoolInfo>& memPools,
150 const V1_0::Request& request,
151 CallbackAsync_1_2 callback);
Mike Kellyb5fdf382019-06-11 16:35:25 +0100152
153 template <typename TensorBindingCollection>
154 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
155
Finn Williamsd8fb5402021-05-19 20:52:00 +0100156 /// schedule the graph prepared from the request for execution
157 template<typename CallbackContext>
158 void ScheduleGraphForExecution(
159 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
160 std::shared_ptr<armnn::InputTensors>& inputTensors,
161 std::shared_ptr<armnn::OutputTensors>& outputTensors,
162 CallbackContext m_CallbackContext);
163
Mike Kelly65c42dc2019-07-22 14:06:00 +0100164 armnn::NetworkId m_NetworkId;
165 armnn::IRuntime* m_Runtime;
166 V1_2::Model m_Model;
Mike Kellyb5fdf382019-06-11 16:35:25 +0100167 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
168 // It is specific to this class, so it is declared as static here
Derek Lamberti4de83c52020-03-17 13:40:18 +0000169 static RequestThread<ArmnnPreparedModel_1_2,
170 HalVersion,
171 CallbackContext_1_2> m_RequestThread;
Mike Kelly65c42dc2019-07-22 14:06:00 +0100172 uint32_t m_RequestCount;
173 const std::string& m_RequestInputsAndOutputsDumpDir;
174 const bool m_GpuProfilingEnabled;
Finn Williamsd8fb5402021-05-19 20:52:00 +0100175
176 std::unique_ptr<IWorkingMemHandle> m_WorkingMemHandle;
177 const bool m_AsyncModelExecutionEnabled;
Mike Kellyb5fdf382019-06-11 16:35:25 +0100178};
179
180}