blob: c1416f94d7c9be4440d8322c95241bfb705e020b [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "Runtime.hpp"
6
David Beck056be3c2018-10-22 13:16:00 +01007#include <armnn/Version.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +00008#include <armnn/BackendRegistry.hpp>
Matthew Benthamf48afc62020-01-15 17:55:08 +00009#include <armnn/Logging.hpp>
Matteo Martincighe54aa062019-08-05 14:12:11 +010010
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000011#include <armnn/backends/IBackendContext.hpp>
Matteo Martincighe54aa062019-08-05 14:12:11 +010012#include <backendsCommon/DynamicBackendUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013
Narumol Prangnawarat85ad78c2019-11-18 15:34:23 +000014#include <ProfilingService.hpp>
Jim Flynnc4728ad2019-10-07 15:15:12 +010015
surmeh013537c2c2018-05-18 16:31:43 +010016#include <iostream>
17
telsoa014fcda012018-03-09 14:13:49 +000018#include <boost/polymorphic_cast.hpp>
Colm Donelan1aff3932020-02-05 17:48:59 +000019#include <backends/BackendProfiling.hpp>
telsoa014fcda012018-03-09 14:13:49 +000020
21using namespace armnn;
22using namespace std;
23
24namespace armnn
25{
26
27IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
28{
29 return new Runtime(options);
30}
31
32IRuntimePtr IRuntime::Create(const CreationOptions& options)
33{
34 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
35}
36
37void IRuntime::Destroy(IRuntime* runtime)
38{
39 delete boost::polymorphic_downcast<Runtime*>(runtime);
40}
41
42int Runtime::GenerateNetworkId()
43{
44 return m_NetworkIdCounter++;
45}
46
47Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
48{
telsoa01c577f2c2018-08-31 09:22:23 +010049 std::string ignoredErrorMessage;
50 return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
51}
52
53Status Runtime::LoadNetwork(NetworkId& networkIdOut,
54 IOptimizedNetworkPtr inNetwork,
David Monahan4f1e8e42019-09-04 09:22:10 +010055 std::string& errorMessage)
56{
57 INetworkProperties networkProperties;
58 return LoadNetwork(networkIdOut, std::move(inNetwork), errorMessage, networkProperties);
59}
60
61Status Runtime::LoadNetwork(NetworkId& networkIdOut,
62 IOptimizedNetworkPtr inNetwork,
63 std::string& errorMessage,
64 const INetworkProperties& networkProperties)
telsoa01c577f2c2018-08-31 09:22:23 +010065{
telsoa014fcda012018-03-09 14:13:49 +000066 IOptimizedNetwork* rawNetwork = inNetwork.release();
David Beck1b61be52018-11-08 09:19:14 +000067
68 networkIdOut = GenerateNetworkId();
69
70 for (auto&& context : m_BackendContexts)
71 {
72 context.second->BeforeLoadNetwork(networkIdOut);
73 }
74
telsoa014fcda012018-03-09 14:13:49 +000075 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
76 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
David Monahan4f1e8e42019-09-04 09:22:10 +010077 errorMessage,
78 networkProperties);
telsoa014fcda012018-03-09 14:13:49 +000079
80 if (!loadedNetwork)
81 {
82 return Status::Failure;
83 }
84
telsoa01c577f2c2018-08-31 09:22:23 +010085 {
86 std::lock_guard<std::mutex> lockGuard(m_Mutex);
87
88 // Stores the network
89 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
90 }
telsoa014fcda012018-03-09 14:13:49 +000091
David Beck1b61be52018-11-08 09:19:14 +000092 for (auto&& context : m_BackendContexts)
93 {
94 context.second->AfterLoadNetwork(networkIdOut);
95 }
96
Keith Davise394bd92019-12-02 15:12:19 +000097 if (profiling::ProfilingService::Instance().IsProfilingEnabled())
98 {
99 profiling::ProfilingService::Instance().IncrementCounterValue(armnn::profiling::NETWORK_LOADS);
100 }
101
telsoa014fcda012018-03-09 14:13:49 +0000102 return Status::Success;
telsoa014fcda012018-03-09 14:13:49 +0000103}
104
105Status Runtime::UnloadNetwork(NetworkId networkId)
106{
David Beck1b61be52018-11-08 09:19:14 +0000107 bool unloadOk = true;
108 for (auto&& context : m_BackendContexts)
David Beck9efb57d2018-11-05 13:40:33 +0000109 {
David Beck1b61be52018-11-08 09:19:14 +0000110 unloadOk &= context.second->BeforeUnloadNetwork(networkId);
David Beck9efb57d2018-11-05 13:40:33 +0000111 }
David Beck1b61be52018-11-08 09:19:14 +0000112
113 if (!unloadOk)
114 {
Derek Lamberti08446972019-11-26 16:38:31 +0000115 ARMNN_LOG(warning) << "Runtime::UnloadNetwork(): failed to unload "
116 "network with ID:" << networkId << " because BeforeUnloadNetwork failed";
David Beck1b61be52018-11-08 09:19:14 +0000117 return Status::Failure;
118 }
David Beck9efb57d2018-11-05 13:40:33 +0000119
telsoa014fcda012018-03-09 14:13:49 +0000120 {
telsoa01c577f2c2018-08-31 09:22:23 +0100121 std::lock_guard<std::mutex> lockGuard(m_Mutex);
122
123 if (m_LoadedNetworks.erase(networkId) == 0)
124 {
Derek Lamberti08446972019-11-26 16:38:31 +0000125 ARMNN_LOG(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
telsoa01c577f2c2018-08-31 09:22:23 +0100126 return Status::Failure;
127 }
Keith Davise394bd92019-12-02 15:12:19 +0000128 if (profiling::ProfilingService::Instance().IsProfilingEnabled())
129 {
130 profiling::ProfilingService::Instance().IncrementCounterValue(armnn::profiling::NETWORK_UNLOADS);
131 }
David Beck1b61be52018-11-08 09:19:14 +0000132 }
David Beck9efb57d2018-11-05 13:40:33 +0000133
David Beck1b61be52018-11-08 09:19:14 +0000134 for (auto&& context : m_BackendContexts)
135 {
136 context.second->AfterUnloadNetwork(networkId);
telsoa01c577f2c2018-08-31 09:22:23 +0100137 }
138
Derek Lamberti08446972019-11-26 16:38:31 +0000139 ARMNN_LOG(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
telsoa014fcda012018-03-09 14:13:49 +0000140 return Status::Success;
141}
142
telsoa01c577f2c2018-08-31 09:22:23 +0100143const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
144{
145 auto it = m_LoadedNetworks.find(networkId);
146 if (it != m_LoadedNetworks.end())
147 {
148 auto& loadedNetwork = it->second;
149 return loadedNetwork->GetProfiler();
150 }
151
152 return nullptr;
153}
154
telsoa014fcda012018-03-09 14:13:49 +0000155Runtime::Runtime(const CreationOptions& options)
David Beck1b61be52018-11-08 09:19:14 +0000156 : m_NetworkIdCounter(0)
telsoa014fcda012018-03-09 14:13:49 +0000157{
Derek Lamberti08446972019-11-26 16:38:31 +0000158 ARMNN_LOG(info) << "ArmNN v" << ARMNN_VERSION << "\n";
David Beck1b61be52018-11-08 09:19:14 +0000159
Jim Flynnc4728ad2019-10-07 15:15:12 +0100160 // pass configuration info to the profiling service
Jim Flynn672d06e2019-10-15 10:18:11 +0100161 armnn::profiling::ProfilingService::Instance().ConfigureProfilingService(options.m_ProfilingOptions);
Jim Flynnc4728ad2019-10-07 15:15:12 +0100162
Matteo Martincighe54aa062019-08-05 14:12:11 +0100163 // Load any available/compatible dynamic backend before the runtime
164 // goes through the backend registry
165 LoadDynamicBackends(options.m_DynamicBackendsPath);
166
Matthew Bentham9a61fa62020-02-04 10:03:55 +0000167 BackendIdSet supportedBackends;
David Beck1b61be52018-11-08 09:19:14 +0000168 for (const auto& id : BackendRegistryInstance().GetBackendIds())
169 {
170 // Store backend contexts for the supported ones
Matthew Bentham9a61fa62020-02-04 10:03:55 +0000171 try {
David Beck1b61be52018-11-08 09:19:14 +0000172 auto factoryFun = BackendRegistryInstance().GetFactory(id);
173 auto backend = factoryFun();
174 BOOST_ASSERT(backend.get() != nullptr);
175
176 auto context = backend->CreateBackendContext(options);
177
178 // backends are allowed to return nullptrs if they
179 // don't wish to create a backend specific context
180 if (context)
181 {
182 m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
183 }
Matthew Bentham9a61fa62020-02-04 10:03:55 +0000184 supportedBackends.emplace(id);
Colm Donelan1aff3932020-02-05 17:48:59 +0000185
186 unique_ptr<armnn::profiling::IBackendProfiling> profilingIface =
187 std::make_unique<armnn::profiling::BackendProfiling>(armnn::profiling::BackendProfiling(
188 options, armnn::profiling::ProfilingService::Instance(), id));
189
190 // Backends may also provide a profiling context. Ask for it now.
191 auto profilingContext = backend->CreateBackendProfilingContext(options, profilingIface);
192 // Backends that don't support profiling will return a null profiling context.
193 if (profilingContext)
194 {
195 // Pass the context onto the profiling service.
196 armnn::profiling::ProfilingService::Instance().AddBackendProfilingContext(id, profilingContext);
197 }
David Beck1b61be52018-11-08 09:19:14 +0000198 }
Matthew Bentham9a61fa62020-02-04 10:03:55 +0000199 catch (const BackendUnavailableException&)
200 {
201 // Ignore backends which are unavailable
202 }
203
David Beck1b61be52018-11-08 09:19:14 +0000204 }
Matthew Bentham9a61fa62020-02-04 10:03:55 +0000205 m_DeviceSpec.AddSupportedBackends(supportedBackends);
surmeh01bceff2f2018-03-29 16:29:27 +0100206}
207
208Runtime::~Runtime()
209{
210 std::vector<int> networkIDs;
surmeh013537c2c2018-05-18 16:31:43 +0100211 try
212 {
213 // Coverity fix: The following code may throw an exception of type std::length_error.
214 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
215 std::back_inserter(networkIDs),
216 [](const auto &pair) { return pair.first; });
217 }
218 catch (const std::exception& e)
219 {
220 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
221 // exception of type std::length_error.
222 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
223 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
224 << "\nSome of the loaded networks may not be unloaded" << std::endl;
225 }
226 // We then proceed to unload all the networks which IDs have been appended to the list
227 // up to the point the exception was thrown (if any).
surmeh01bceff2f2018-03-29 16:29:27 +0100228
229 for (auto networkID : networkIDs)
230 {
surmeh013537c2c2018-05-18 16:31:43 +0100231 try
232 {
233 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
234 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
235 UnloadNetwork(networkID);
236 }
237 catch (const std::exception& e)
238 {
239 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
240 // exception of type std::length_error.
241 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
242 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
243 << std::endl;
244 }
telsoa014fcda012018-03-09 14:13:49 +0000245 }
Narumol Prangnawarat60a20fb2019-12-09 17:24:41 +0000246
Colm Donelan1aff3932020-02-05 17:48:59 +0000247
Narumol Prangnawarat60a20fb2019-12-09 17:24:41 +0000248 // Clear all dynamic backends.
249 DynamicBackendUtils::DeregisterDynamicBackends(m_DeviceSpec.GetDynamicBackends());
250 m_DeviceSpec.ClearDynamicBackends();
Colm Donelan1aff3932020-02-05 17:48:59 +0000251 m_BackendContexts.clear();
telsoa014fcda012018-03-09 14:13:49 +0000252}
253
surmeh013537c2c2018-05-18 16:31:43 +0100254LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
255{
256 std::lock_guard<std::mutex> lockGuard(m_Mutex);
257 return m_LoadedNetworks.at(networkId).get();
258}
259
telsoa014fcda012018-03-09 14:13:49 +0000260TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
261{
surmeh013537c2c2018-05-18 16:31:43 +0100262 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000263}
264
265TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
266{
surmeh013537c2c2018-05-18 16:31:43 +0100267 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000268}
269
Derek Lamberti03614f62018-10-02 15:52:46 +0100270
telsoa014fcda012018-03-09 14:13:49 +0000271Status Runtime::EnqueueWorkload(NetworkId networkId,
telsoa01c577f2c2018-08-31 09:22:23 +0100272 const InputTensors& inputTensors,
273 const OutputTensors& outputTensors)
telsoa014fcda012018-03-09 14:13:49 +0000274{
surmeh013537c2c2018-05-18 16:31:43 +0100275 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
Derek Lamberti03614f62018-10-02 15:52:46 +0100276
277 static thread_local NetworkId lastId = networkId;
278 if (lastId != networkId)
279 {
280 LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
281 {
282 network->FreeWorkingMemory();
283 });
284 }
285 lastId=networkId;
286
surmeh013537c2c2018-05-18 16:31:43 +0100287 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
telsoa014fcda012018-03-09 14:13:49 +0000288}
289
Nattapat Chaimanowong6e948202019-03-22 14:01:46 +0000290void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
291{
292 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
293 loadedNetwork->RegisterDebugCallback(func);
294}
295
Matteo Martincighe54aa062019-08-05 14:12:11 +0100296void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
297{
298 // Get the paths where to load the dynamic backends from
299 std::vector<std::string> backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath);
300
301 // Get the shared objects to try to load as dynamic backends
302 std::vector<std::string> sharedObjects = DynamicBackendUtils::GetSharedObjects(backendPaths);
303
304 // Create a list of dynamic backends
Matteo Martincigh0c2b2892019-08-05 14:12:11 +0100305 m_DynamicBackends = DynamicBackendUtils::CreateDynamicBackends(sharedObjects);
306
307 // Register the dynamic backends in the backend registry
Matteo Martincigh89533902019-08-15 12:08:06 +0100308 BackendIdSet registeredBackendIds = DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);
309
310 // Add the registered dynamic backend ids to the list of supported backends
Narumol Prangnawarat60a20fb2019-12-09 17:24:41 +0000311 m_DeviceSpec.AddSupportedBackends(registeredBackendIds, true);
telsoa014fcda012018-03-09 14:13:49 +0000312}
Matteo Martincighe54aa062019-08-05 14:12:11 +0100313
314} // namespace armnn