blob: d1c830d4877bb68f59cf5f87e35c20413017040d [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
6#pragma once
7
#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>

#include <chrono>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>
17
18namespace armnn_driver
19{
/// Completion callback for a HAL 1.0 execution: receives the resulting
/// V1_0::ErrorStatus and the name of the driver entry point that ran.
using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>;

/// Wraps the HAL 1.0 execution callback so it can travel through the generic
/// CallbackContext machinery shared with later HAL versions.
struct ArmnnCallback_1_0
{
    armnnExecuteCallback_1_0 callback;
};

/// Intentionally empty: HAL 1.0 carries no extra per-execution state
/// (later HAL versions add fields such as measurement/timing options).
struct ExecutionContext_1_0 {};

/// Callback + context pair threaded through ExecuteGraph and the RequestThread.
using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;
30
Matteo Martincighe48bdff2018-09-03 13:50:50 +010031template <typename HalVersion>
Matthew Bentham912b3622019-05-03 15:49:14 +010032class ArmnnPreparedModel : public V1_0::IPreparedModel
telsoa015307bc12018-03-09 13:51:08 +000033{
34public:
Matteo Martincighe48bdff2018-09-03 13:50:50 +010035 using HalModel = typename HalVersion::Model;
36
telsoa015307bc12018-03-09 13:51:08 +000037 ArmnnPreparedModel(armnn::NetworkId networkId,
38 armnn::IRuntime* runtime,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010039 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +010040 const std::string& requestInputsAndOutputsDumpDir,
Finn Williamsd8fb5402021-05-19 20:52:00 +010041 const bool gpuProfilingEnabled,
42 const bool asyncModelExecutionEnabled = false);
telsoa015307bc12018-03-09 13:51:08 +000043
44 virtual ~ArmnnPreparedModel();
45
Kevin Mayec1e5b82020-02-26 17:00:39 +000046 virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
47 const ::android::sp<V1_0::IExecutionCallback>& callback) override;
telsoa015307bc12018-03-09 13:51:08 +000048
49 /// execute the graph prepared from the request
50 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
Derek Lamberti4de83c52020-03-17 13:40:18 +000051 armnn::InputTensors& inputTensors,
52 armnn::OutputTensors& outputTensors,
53 CallbackContext_1_0 callback);
telsoa015307bc12018-03-09 13:51:08 +000054
55 /// Executes this model with dummy inputs (e.g. all zeroes).
Matthew Bentham16196e22019-04-01 17:17:58 +010056 /// \return false on failure, otherwise true
57 bool ExecuteWithDummyInputs();
telsoa015307bc12018-03-09 13:51:08 +000058
59private:
Finn Williamsd8fb5402021-05-19 20:52:00 +010060
61 template<typename CallbackContext>
62 class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback
63 {
64 public:
65 ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model,
66 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
67 std::shared_ptr<armnn::InputTensors>& inputTensors,
68 std::shared_ptr<armnn::OutputTensors>& outputTensors,
69 CallbackContext callbackContext) :
70 m_Model(model),
71 m_MemPools(pMemPools),
72 m_InputTensors(inputTensors),
73 m_OutputTensors(outputTensors),
74 m_CallbackContext(callbackContext)
75 {}
76
77 void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
78
79 // Retrieve the ArmNN Status from the AsyncExecutionCallback that has been notified
80 virtual armnn::Status GetStatus() const override
81 {
82 return armnn::Status::Success;
83 }
84
85 // Block the calling thread until the AsyncExecutionCallback object allows it to proceed
86 virtual void Wait() const override
87 {}
88
89 // Retrieve the start time before executing the inference
90 virtual armnn::HighResolutionClock GetStartTime() const override
91 {
92 return std::chrono::high_resolution_clock::now();
93 }
94
95 // Retrieve the time after executing the inference
96 virtual armnn::HighResolutionClock GetEndTime() const override
97 {
98 return std::chrono::high_resolution_clock::now();
99 }
100
101 ArmnnPreparedModel<HalVersion>* m_Model;
102 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
103 std::shared_ptr<armnn::InputTensors> m_InputTensors;
104 std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
105 CallbackContext m_CallbackContext;
106 };
107
telsoa015307bc12018-03-09 13:51:08 +0000108 template <typename TensorBindingCollection>
109 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
110
Finn Williamsd8fb5402021-05-19 20:52:00 +0100111 /// schedule the graph prepared from the request for execution
112 template<typename CallbackContext>
113 void ScheduleGraphForExecution(
114 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
115 std::shared_ptr<armnn::InputTensors>& inputTensors,
116 std::shared_ptr<armnn::OutputTensors>& outputTensors,
117 CallbackContext m_CallbackContext);
118
Mike Kelly65c42dc2019-07-22 14:06:00 +0100119 armnn::NetworkId m_NetworkId;
120 armnn::IRuntime* m_Runtime;
121 HalModel m_Model;
telsoa015307bc12018-03-09 13:51:08 +0000122 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
123 // It is specific to this class, so it is declared as static here
Derek Lamberti4de83c52020-03-17 13:40:18 +0000124 static RequestThread<ArmnnPreparedModel, HalVersion, CallbackContext_1_0> m_RequestThread;
Mike Kelly65c42dc2019-07-22 14:06:00 +0100125 uint32_t m_RequestCount;
126 const std::string& m_RequestInputsAndOutputsDumpDir;
127 const bool m_GpuProfilingEnabled;
Finn Williamsd8fb5402021-05-19 20:52:00 +0100128
129 std::unique_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
130 const bool m_AsyncModelExecutionEnabled;
telsoa015307bc12018-03-09 13:51:08 +0000131};
132
133}