blob: 57deb98ca12c80873a83ef5c718668d7cbb2c00f [file] [log] [blame]
Mike Kellyb5fdf382019-06-11 16:35:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "ArmnnDriver.hpp"
9#include "ArmnnDriverImpl.hpp"
10#include "RequestThread.hpp"
11#include "ModelToINetworkConverter.hpp"
12
13#include <NeuralNetworks.h>
14#include <armnn/ArmNN.hpp>
Finn Williamsca3a3e02021-06-11 15:04:02 +010015#include <armnn/Threadpool.hpp>
Mike Kellyb5fdf382019-06-11 16:35:25 +010016
17#include <string>
18#include <vector>
19
20namespace armnn_driver
21{
22
Derek Lamberti4de83c52020-03-17 13:40:18 +000023using CallbackAsync_1_2 = std::function<
24 void(V1_0::ErrorStatus errorStatus,
25 std::vector<::android::hardware::neuralnetworks::V1_2::OutputShape> outputShapes,
26 const ::android::hardware::neuralnetworks::V1_2::Timing& timing,
27 std::string callingFunction)>;
Mike Kelly65c42dc2019-07-22 14:06:00 +010028
Derek Lamberti4de83c52020-03-17 13:40:18 +000029struct ExecutionContext_1_2
Mike Kelly65c42dc2019-07-22 14:06:00 +010030{
Derek Lamberti4de83c52020-03-17 13:40:18 +000031 ::android::hardware::neuralnetworks::V1_2::MeasureTiming measureTimings =
32 ::android::hardware::neuralnetworks::V1_2::MeasureTiming::NO;
Mike Kelly65c42dc2019-07-22 14:06:00 +010033 TimePoint driverStart;
Mike Kelly65c42dc2019-07-22 14:06:00 +010034};
35
Derek Lamberti4de83c52020-03-17 13:40:18 +000036using CallbackContext_1_2 = CallbackContext<CallbackAsync_1_2, ExecutionContext_1_2>;
37
Mike Kellyb5fdf382019-06-11 16:35:25 +010038template <typename HalVersion>
39class ArmnnPreparedModel_1_2 : public V1_2::IPreparedModel
40{
41public:
42 using HalModel = typename V1_2::Model;
43
44 ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
45 armnn::IRuntime* runtime,
46 const HalModel& model,
47 const std::string& requestInputsAndOutputsDumpDir,
Finn Williamsd8fb5402021-05-19 20:52:00 +010048 const bool gpuProfilingEnabled,
Finn Williamsca3a3e02021-06-11 15:04:02 +010049 const bool asyncModelExecutionEnabled = false,
Narumol Prangnawaratd1a947f2022-02-07 13:12:24 +000050 const unsigned int numberOfThreads = 1,
David Monahanbe9d99e2022-04-29 16:25:24 +010051 const bool importEnabled = false,
52 const bool exportEnabled = false);
Mike Kellyb5fdf382019-06-11 16:35:25 +010053
Sadik Armagan0a2dfab2021-10-06 16:41:44 +010054 ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
55 armnn::IRuntime* runtime,
56 const std::string& requestInputsAndOutputsDumpDir,
57 const bool gpuProfilingEnabled,
58 const bool asyncModelExecutionEnabled = false,
59 const unsigned int numberOfThreads = 1,
David Monahanbe9d99e2022-04-29 16:25:24 +010060 const bool importEnabled = false,
61 const bool exportEnabled = false,
Sadik Armagan0a2dfab2021-10-06 16:41:44 +010062 const bool preparedFromCache = false);
63
Mike Kellyb5fdf382019-06-11 16:35:25 +010064 virtual ~ArmnnPreparedModel_1_2();
65
Kevin Mayec1e5b82020-02-26 17:00:39 +000066 virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
Sadik Armagan188675f2021-02-12 17:16:42 +000067 const ::android::sp<V1_0::IExecutionCallback>& callback) override;
Mike Kellyb5fdf382019-06-11 16:35:25 +010068
Sadik Armagan188675f2021-02-12 17:16:42 +000069 virtual Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request& request, V1_2::MeasureTiming measure,
70 const ::android::sp<V1_2::IExecutionCallback>& callback) override;
Mike Kellyb5fdf382019-06-11 16:35:25 +010071
Kevin Mayec1e5b82020-02-26 17:00:39 +000072 virtual Return<void> executeSynchronously(const V1_0::Request &request,
Sadik Armagan188675f2021-02-12 17:16:42 +000073 V1_2::MeasureTiming measure,
Mike Kellyb5fdf382019-06-11 16:35:25 +010074 V1_2::IPreparedModel::executeSynchronously_cb cb) override;
75
76 virtual Return<void> configureExecutionBurst(
Sadik Armagan188675f2021-02-12 17:16:42 +000077 const ::android::sp<V1_2::IBurstCallback>& callback,
Mike Kellyb5fdf382019-06-11 16:35:25 +010078 const android::hardware::MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
79 const android::hardware::MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
80 configureExecutionBurst_cb cb) override;
81
82 /// execute the graph prepared from the request
Derek Lamberti4de83c52020-03-17 13:40:18 +000083 template<typename CallbackContext>
84 bool ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
85 armnn::InputTensors& inputTensors,
86 armnn::OutputTensors& outputTensors,
87 CallbackContext callback);
Mike Kellyb5fdf382019-06-11 16:35:25 +010088
89 /// Executes this model with dummy inputs (e.g. all zeroes).
90 /// \return false on failure, otherwise true
Sadik Armagan0a2dfab2021-10-06 16:41:44 +010091 bool ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs);
Mike Kellyb5fdf382019-06-11 16:35:25 +010092
93private:
Finn Williamsd8fb5402021-05-19 20:52:00 +010094
95 template<typename CallbackContext>
96 class ArmnnThreadPoolCallback_1_2 : public armnn::IAsyncExecutionCallback
97 {
98 public:
99 ArmnnThreadPoolCallback_1_2(ArmnnPreparedModel_1_2<HalVersion>* model,
100 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
101 std::vector<V1_2::OutputShape> outputShapes,
102 std::shared_ptr<armnn::InputTensors>& inputTensors,
103 std::shared_ptr<armnn::OutputTensors>& outputTensors,
104 CallbackContext callbackContext) :
105 m_Model(model),
106 m_MemPools(pMemPools),
107 m_OutputShapes(outputShapes),
108 m_InputTensors(inputTensors),
109 m_OutputTensors(outputTensors),
110 m_CallbackContext(callbackContext)
111 {}
112
113 void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
114
Finn Williamsd8fb5402021-05-19 20:52:00 +0100115 ArmnnPreparedModel_1_2<HalVersion>* m_Model;
116 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
117 std::vector<V1_2::OutputShape> m_OutputShapes;
118 std::shared_ptr<armnn::InputTensors> m_InputTensors;
119 std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
120 CallbackContext m_CallbackContext;
121 };
122
Derek Lamberti4de83c52020-03-17 13:40:18 +0000123 Return<V1_0::ErrorStatus> Execute(const V1_0::Request& request,
Sadik Armagan188675f2021-02-12 17:16:42 +0000124 V1_2::MeasureTiming measureTiming,
Derek Lamberti4de83c52020-03-17 13:40:18 +0000125 CallbackAsync_1_2 callback);
126
127 Return<V1_0::ErrorStatus> PrepareMemoryForInputs(
128 armnn::InputTensors& inputs,
129 const V1_0::Request& request,
130 const std::vector<android::nn::RunTimePoolInfo>& memPools);
131
132 Return<V1_0::ErrorStatus> PrepareMemoryForOutputs(
133 armnn::OutputTensors& outputs,
Sadik Armagan188675f2021-02-12 17:16:42 +0000134 std::vector<V1_2::OutputShape> &outputShapes,
Derek Lamberti4de83c52020-03-17 13:40:18 +0000135 const V1_0::Request& request,
136 const std::vector<android::nn::RunTimePoolInfo>& memPools);
137
138 Return <V1_0::ErrorStatus> PrepareMemoryForIO(
139 armnn::InputTensors& inputs,
140 armnn::OutputTensors& outputs,
141 std::vector<android::nn::RunTimePoolInfo>& memPools,
142 const V1_0::Request& request,
143 CallbackAsync_1_2 callback);
Mike Kellyb5fdf382019-06-11 16:35:25 +0100144
145 template <typename TensorBindingCollection>
146 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
147
Finn Williamsd8fb5402021-05-19 20:52:00 +0100148 /// schedule the graph prepared from the request for execution
149 template<typename CallbackContext>
150 void ScheduleGraphForExecution(
151 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
152 std::shared_ptr<armnn::InputTensors>& inputTensors,
153 std::shared_ptr<armnn::OutputTensors>& outputTensors,
154 CallbackContext m_CallbackContext);
155
Finn Williamsca3a3e02021-06-11 15:04:02 +0100156 armnn::NetworkId m_NetworkId;
157 armnn::IRuntime* m_Runtime;
Finn Williamsca3a3e02021-06-11 15:04:02 +0100158 V1_2::Model m_Model;
Mike Kellyb5fdf382019-06-11 16:35:25 +0100159 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
160 // It is specific to this class, so it is declared as static here
Derek Lamberti4de83c52020-03-17 13:40:18 +0000161 static RequestThread<ArmnnPreparedModel_1_2,
162 HalVersion,
Finn Williamsca3a3e02021-06-11 15:04:02 +0100163 CallbackContext_1_2> m_RequestThread;
164 uint32_t m_RequestCount;
165 const std::string& m_RequestInputsAndOutputsDumpDir;
166 const bool m_GpuProfilingEnabled;
Finn Williamsfdf2eae2021-07-08 13:07:19 +0100167 // Static to allow sharing of threadpool between ArmnnPreparedModel instances
168 static std::unique_ptr<armnn::Threadpool> m_Threadpool;
Finn Williamsca3a3e02021-06-11 15:04:02 +0100169 std::shared_ptr<IWorkingMemHandle> m_WorkingMemHandle;
170 const bool m_AsyncModelExecutionEnabled;
Narumol Prangnawaratd1a947f2022-02-07 13:12:24 +0000171 const bool m_EnableImport;
172 const bool m_EnableExport;
Sadik Armagan0a2dfab2021-10-06 16:41:44 +0100173 const bool m_PreparedFromCache;
Mike Kellyb5fdf382019-06-11 16:35:25 +0100174};
175
176}