blob: 0d19c077289366ce5db2f8ce884a60b8e6bc1f83 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

6#pragma once
7
surmeh01deb3bdb2018-07-05 12:06:04 +01008#include "ArmnnDriver.hpp"
Matteo Martincighe48bdff2018-09-03 13:50:50 +01009#include "ArmnnDriverImpl.hpp"
arovir01b0717b52018-09-05 17:03:25 +010010#include "RequestThread.hpp"
surmeh01deb3bdb2018-07-05 12:06:04 +010011
telsoa01ce3e84a2018-08-31 09:31:35 +010012#include <NeuralNetworks.h>
13#include <armnn/ArmNN.hpp>
Finn Williamsca3a3e02021-06-11 15:04:02 +010014#include <armnn/Threadpool.hpp>
telsoa01ce3e84a2018-08-31 09:31:35 +010015
telsoa015307bc12018-03-09 13:51:08 +000016#include <string>
17#include <vector>
18
19namespace armnn_driver
20{
Mike Kelly65c42dc2019-07-22 14:06:00 +010021using armnnExecuteCallback_1_0 = std::function<void(V1_0::ErrorStatus status, std::string callingFunction)>;
22
23struct ArmnnCallback_1_0
24{
25 armnnExecuteCallback_1_0 callback;
26};
telsoa015307bc12018-03-09 13:51:08 +000027
Derek Lamberti4de83c52020-03-17 13:40:18 +000028struct ExecutionContext_1_0 {};
29
30using CallbackContext_1_0 = CallbackContext<armnnExecuteCallback_1_0, ExecutionContext_1_0>;
31
Matteo Martincighe48bdff2018-09-03 13:50:50 +010032template <typename HalVersion>
Matthew Bentham912b3622019-05-03 15:49:14 +010033class ArmnnPreparedModel : public V1_0::IPreparedModel
telsoa015307bc12018-03-09 13:51:08 +000034{
35public:
Matteo Martincighe48bdff2018-09-03 13:50:50 +010036 using HalModel = typename HalVersion::Model;
37
telsoa015307bc12018-03-09 13:51:08 +000038 ArmnnPreparedModel(armnn::NetworkId networkId,
39 armnn::IRuntime* runtime,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010040 const HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +010041 const std::string& requestInputsAndOutputsDumpDir,
Finn Williamsd8fb5402021-05-19 20:52:00 +010042 const bool gpuProfilingEnabled,
Finn Williamsca3a3e02021-06-11 15:04:02 +010043 const bool asyncModelExecutionEnabled = false,
Narumol Prangnawarat558a1d42022-02-07 13:12:24 +000044 const unsigned int numberOfThreads = 1,
45 const bool importEnabled = false,
46 const bool exportEnabled = true);
telsoa015307bc12018-03-09 13:51:08 +000047
48 virtual ~ArmnnPreparedModel();
49
Kevin Mayec1e5b82020-02-26 17:00:39 +000050 virtual Return<V1_0::ErrorStatus> execute(const V1_0::Request& request,
51 const ::android::sp<V1_0::IExecutionCallback>& callback) override;
telsoa015307bc12018-03-09 13:51:08 +000052
53 /// execute the graph prepared from the request
54 void ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
Derek Lamberti4de83c52020-03-17 13:40:18 +000055 armnn::InputTensors& inputTensors,
56 armnn::OutputTensors& outputTensors,
57 CallbackContext_1_0 callback);
telsoa015307bc12018-03-09 13:51:08 +000058
59 /// Executes this model with dummy inputs (e.g. all zeroes).
Matthew Bentham16196e22019-04-01 17:17:58 +010060 /// \return false on failure, otherwise true
61 bool ExecuteWithDummyInputs();
telsoa015307bc12018-03-09 13:51:08 +000062
63private:
Finn Williamsd8fb5402021-05-19 20:52:00 +010064
65 template<typename CallbackContext>
66 class ArmnnThreadPoolCallback : public armnn::IAsyncExecutionCallback
67 {
68 public:
69 ArmnnThreadPoolCallback(ArmnnPreparedModel<HalVersion>* model,
70 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
71 std::shared_ptr<armnn::InputTensors>& inputTensors,
72 std::shared_ptr<armnn::OutputTensors>& outputTensors,
73 CallbackContext callbackContext) :
74 m_Model(model),
75 m_MemPools(pMemPools),
76 m_InputTensors(inputTensors),
77 m_OutputTensors(outputTensors),
78 m_CallbackContext(callbackContext)
79 {}
80
81 void Notify(armnn::Status status, armnn::InferenceTimingPair timeTaken) override;
82
Finn Williamsd8fb5402021-05-19 20:52:00 +010083 ArmnnPreparedModel<HalVersion>* m_Model;
84 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> m_MemPools;
85 std::shared_ptr<armnn::InputTensors> m_InputTensors;
86 std::shared_ptr<armnn::OutputTensors> m_OutputTensors;
87 CallbackContext m_CallbackContext;
88 };
89
telsoa015307bc12018-03-09 13:51:08 +000090 template <typename TensorBindingCollection>
91 void DumpTensorsIfRequired(char const* tensorNamePrefix, const TensorBindingCollection& tensorBindings);
92
Finn Williamsd8fb5402021-05-19 20:52:00 +010093 /// schedule the graph prepared from the request for execution
94 template<typename CallbackContext>
95 void ScheduleGraphForExecution(
96 std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
97 std::shared_ptr<armnn::InputTensors>& inputTensors,
98 std::shared_ptr<armnn::OutputTensors>& outputTensors,
99 CallbackContext m_CallbackContext);
100
Finn Williamsfdf2eae2021-07-08 13:07:19 +0100101 armnn::NetworkId m_NetworkId;
102 armnn::IRuntime* m_Runtime;
103 HalModel m_Model;
telsoa015307bc12018-03-09 13:51:08 +0000104 // There must be a single RequestThread for all ArmnnPreparedModel objects to ensure serial execution of workloads
105 // It is specific to this class, so it is declared as static here
Finn Williamsfdf2eae2021-07-08 13:07:19 +0100106 static RequestThread<ArmnnPreparedModel,
107 HalVersion,
108 CallbackContext_1_0> m_RequestThread;
109 uint32_t m_RequestCount;
110 const std::string& m_RequestInputsAndOutputsDumpDir;
111 const bool m_GpuProfilingEnabled;
112 // Static to allow sharing of threadpool between ArmnnPreparedModel instances
113 static std::unique_ptr<armnn::Threadpool> m_Threadpool;
Finn Williamsca3a3e02021-06-11 15:04:02 +0100114 std::shared_ptr<armnn::IWorkingMemHandle> m_WorkingMemHandle;
Finn Williamsd8fb5402021-05-19 20:52:00 +0100115 const bool m_AsyncModelExecutionEnabled;
Narumol Prangnawarat558a1d42022-02-07 13:12:24 +0000116 const bool m_EnableImport;
117 const bool m_EnableExport;
telsoa015307bc12018-03-09 13:51:08 +0000118};
119
120}