//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"
#include "RequestThread.hpp"

#include <NeuralNetworks.h>
#include <armnn/ArmNN.hpp>
#include <armnn/Threadpool.hpp>

#include <string>
#include <vector>

namespace armnn_driver
{
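// Completion callback used to report an execution's V1_0::ErrorStatus along with the name of the calling function.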
using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>;

struct ArmnnCallback_1_0
{
    armnnExecuteCallback_1_0 callback;
};

struct ExecutionContext_1_0 {};

using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;

template <typename HalVersion>
class ArmnnPreparedModel : public V1_0::IPreparedModel
{
public:
    using HalModel = typename HalVersion::Model;

    ArmnnPreparedModel(armnn::NetworkId networkId,
                       armnn::IRuntime* runtime,
                       const HalModel& model,
                       const std::string& requestInputsAndOutputsDumpDir,
                       const bool gpuProfilingEnabled,
                       const bool asyncModelExecutionEnabled = false,
                       const unsigned int numberOfThreads = 1);

    virtual ~ArmnnPreparedModel();

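    /// Launches an asynchronous execution of the prepared network for the given request;
    /// completion is reported through the supplied IExecutionCallback.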
    virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
                                              const ::android::sp<V1_0::IExecutionCallback>& callback) override;

    /// Executes the graph prepared from the request.
    void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
                      armnn::InputTensors& inputTensors,
                      armnn::OutputTensors& outputTensors,
                      CallbackContext_1_0 callback);

    /// Executes this model with dummy inputs (e.g. all zeroes).
    /// \return false on failure, otherwise true
    bool ExecuteWithDummyInputs();

private:

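    /// Callback handed to the Arm NN Threadpool when asynchronous model execution is enabled;
    /// Notify() is invoked by the runtime once the scheduled inference has completed.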
    template<typename CallbackContext>
    class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback
    {
    public:
        ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model,
                                std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
                                std::shared_ptr<armnn::InputTensors>& inputTensors,
                                std::shared_ptr<armnn::OutputTensors>& outputTensors,
                                CallbackContext callbackContext) :
            m_Model(model),
            m_MemPools(pMemPools),
            m_InputTensors(inputTensors),
            m_OutputTensors(outputTensors),
            m_CallbackContext(callbackContext)
        {}

        void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;

        ArmnnPreparedModel<HalVersion>* m_Model;
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
        std::shared_ptr<armnn::InputTensors> m_InputTensors;
        std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
        CallbackContext m_CallbackContext;
    };

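    /// Dumps the given tensors to the requestInputsAndOutputsDumpDir, when one has been configured.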
    template <typename TensorBindingCollection>
    void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);

    /// Schedules the graph prepared from the request for asynchronous execution.
    template<typename CallbackContext>
    void ScheduleGraphForExecution(
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
        std::shared_ptr<armnn::InputTensors>& inputTensors,
        std::shared_ptr<armnn::OutputTensors>& outputTensors,
        CallbackContext m_CallbackContext);

    armnn::NetworkId m_NetworkId;
    armnn::IRuntime* m_Runtime;
    HalModel m_Model;
    // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads.
    // It is specific to this class, so it is declared as static here.
    static RequestThread<ArmnnPreparedModel,
                         HalVersion,
                         CallbackContext_1_0> m_RequestThread;
    uint32_t m_RequestCount;
    const std::string& m_RequestInputsAndOutputsDumpDir;
    const bool m_GpuProfilingEnabled;
    // Static to allow the Threadpool to be shared between ArmnnPreparedModel instances.
    static std::unique_ptr<armnn::Threadpool> m_Threadpool;
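    // Working memory handle used when executing the network through the shared Threadpool.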
    std::shared_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
    const bool m_AsyncModelExecutionEnabled;
};

} // namespace armnn_driver