blob: e47835687d6bada243815e03d606553f18264d5c [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "Runtime.hpp"
6
David Beck056be3c2018-10-22 13:16:00 +01007#include <armnn/Version.hpp>
Matteo Martincighe54aa062019-08-05 14:12:11 +01008
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00009#include <backendsCommon/BackendRegistry.hpp>
David Beck1b61be52018-11-08 09:19:14 +000010#include <backendsCommon/IBackendContext.hpp>
Matteo Martincighe54aa062019-08-05 14:12:11 +010011#include <backendsCommon/DynamicBackendUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
surmeh013537c2c2018-05-18 16:31:43 +010013#include <iostream>
14
telsoa014fcda012018-03-09 14:13:49 +000015#include <boost/log/trivial.hpp>
16#include <boost/polymorphic_cast.hpp>
17
18using namespace armnn;
19using namespace std;
20
21namespace armnn
22{
23
24IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
25{
26 return new Runtime(options);
27}
28
29IRuntimePtr IRuntime::Create(const CreationOptions& options)
30{
31 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
32}
33
34void IRuntime::Destroy(IRuntime* runtime)
35{
36 delete boost::polymorphic_downcast<Runtime*>(runtime);
37}
38
39int Runtime::GenerateNetworkId()
40{
41 return m_NetworkIdCounter++;
42}
43
44Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
45{
telsoa01c577f2c2018-08-31 09:22:23 +010046 std::string ignoredErrorMessage;
47 return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
48}
49
50Status Runtime::LoadNetwork(NetworkId& networkIdOut,
51 IOptimizedNetworkPtr inNetwork,
David Monahan4f1e8e42019-09-04 09:22:10 +010052 std::string& errorMessage)
53{
54 INetworkProperties networkProperties;
55 return LoadNetwork(networkIdOut, std::move(inNetwork), errorMessage, networkProperties);
56}
57
58Status Runtime::LoadNetwork(NetworkId& networkIdOut,
59 IOptimizedNetworkPtr inNetwork,
60 std::string& errorMessage,
61 const INetworkProperties& networkProperties)
telsoa01c577f2c2018-08-31 09:22:23 +010062{
telsoa014fcda012018-03-09 14:13:49 +000063 IOptimizedNetwork* rawNetwork = inNetwork.release();
David Beck1b61be52018-11-08 09:19:14 +000064
65 networkIdOut = GenerateNetworkId();
66
67 for (auto&& context : m_BackendContexts)
68 {
69 context.second->BeforeLoadNetwork(networkIdOut);
70 }
71
telsoa014fcda012018-03-09 14:13:49 +000072 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
73 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
David Monahan4f1e8e42019-09-04 09:22:10 +010074 errorMessage,
75 networkProperties);
telsoa014fcda012018-03-09 14:13:49 +000076
77 if (!loadedNetwork)
78 {
79 return Status::Failure;
80 }
81
telsoa01c577f2c2018-08-31 09:22:23 +010082 {
83 std::lock_guard<std::mutex> lockGuard(m_Mutex);
84
85 // Stores the network
86 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
87 }
telsoa014fcda012018-03-09 14:13:49 +000088
David Beck1b61be52018-11-08 09:19:14 +000089 for (auto&& context : m_BackendContexts)
90 {
91 context.second->AfterLoadNetwork(networkIdOut);
92 }
93
telsoa014fcda012018-03-09 14:13:49 +000094 return Status::Success;
telsoa014fcda012018-03-09 14:13:49 +000095}
96
97Status Runtime::UnloadNetwork(NetworkId networkId)
98{
David Beck1b61be52018-11-08 09:19:14 +000099 bool unloadOk = true;
100 for (auto&& context : m_BackendContexts)
David Beck9efb57d2018-11-05 13:40:33 +0000101 {
David Beck1b61be52018-11-08 09:19:14 +0000102 unloadOk &= context.second->BeforeUnloadNetwork(networkId);
David Beck9efb57d2018-11-05 13:40:33 +0000103 }
David Beck1b61be52018-11-08 09:19:14 +0000104
105 if (!unloadOk)
106 {
107 BOOST_LOG_TRIVIAL(warning) << "Runtime::UnloadNetwork(): failed to unload "
108 "network with ID:" << networkId << " because BeforeUnloadNetwork failed";
109 return Status::Failure;
110 }
David Beck9efb57d2018-11-05 13:40:33 +0000111
telsoa014fcda012018-03-09 14:13:49 +0000112 {
telsoa01c577f2c2018-08-31 09:22:23 +0100113 std::lock_guard<std::mutex> lockGuard(m_Mutex);
114
115 if (m_LoadedNetworks.erase(networkId) == 0)
116 {
117 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
118 return Status::Failure;
119 }
David Beck1b61be52018-11-08 09:19:14 +0000120 }
David Beck9efb57d2018-11-05 13:40:33 +0000121
David Beck1b61be52018-11-08 09:19:14 +0000122 for (auto&& context : m_BackendContexts)
123 {
124 context.second->AfterUnloadNetwork(networkId);
telsoa01c577f2c2018-08-31 09:22:23 +0100125 }
126
telsoa014fcda012018-03-09 14:13:49 +0000127 BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
128 return Status::Success;
129}
130
telsoa01c577f2c2018-08-31 09:22:23 +0100131const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
132{
133 auto it = m_LoadedNetworks.find(networkId);
134 if (it != m_LoadedNetworks.end())
135 {
136 auto& loadedNetwork = it->second;
137 return loadedNetwork->GetProfiler();
138 }
139
140 return nullptr;
141}
142
telsoa014fcda012018-03-09 14:13:49 +0000143Runtime::Runtime(const CreationOptions& options)
David Beck1b61be52018-11-08 09:19:14 +0000144 : m_NetworkIdCounter(0)
David Beck056be3c2018-10-22 13:16:00 +0100145 , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
telsoa014fcda012018-03-09 14:13:49 +0000146{
147 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
David Beck1b61be52018-11-08 09:19:14 +0000148
Matteo Martincighe54aa062019-08-05 14:12:11 +0100149 // Load any available/compatible dynamic backend before the runtime
150 // goes through the backend registry
151 LoadDynamicBackends(options.m_DynamicBackendsPath);
152
David Beck1b61be52018-11-08 09:19:14 +0000153 for (const auto& id : BackendRegistryInstance().GetBackendIds())
154 {
155 // Store backend contexts for the supported ones
Matteo Martincigh3d8a9ed2019-08-08 10:49:03 +0100156 const BackendIdSet& supportedBackends = m_DeviceSpec.GetSupportedBackends();
Matteo Martincigh89533902019-08-15 12:08:06 +0100157 if (supportedBackends.find(id) != supportedBackends.end())
David Beck1b61be52018-11-08 09:19:14 +0000158 {
159 auto factoryFun = BackendRegistryInstance().GetFactory(id);
160 auto backend = factoryFun();
161 BOOST_ASSERT(backend.get() != nullptr);
162
163 auto context = backend->CreateBackendContext(options);
164
165 // backends are allowed to return nullptrs if they
166 // don't wish to create a backend specific context
167 if (context)
168 {
169 m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
170 }
171 }
172 }
surmeh01bceff2f2018-03-29 16:29:27 +0100173}
174
175Runtime::~Runtime()
176{
177 std::vector<int> networkIDs;
surmeh013537c2c2018-05-18 16:31:43 +0100178 try
179 {
180 // Coverity fix: The following code may throw an exception of type std::length_error.
181 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
182 std::back_inserter(networkIDs),
183 [](const auto &pair) { return pair.first; });
184 }
185 catch (const std::exception& e)
186 {
187 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
188 // exception of type std::length_error.
189 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
190 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
191 << "\nSome of the loaded networks may not be unloaded" << std::endl;
192 }
193 // We then proceed to unload all the networks which IDs have been appended to the list
194 // up to the point the exception was thrown (if any).
surmeh01bceff2f2018-03-29 16:29:27 +0100195
196 for (auto networkID : networkIDs)
197 {
surmeh013537c2c2018-05-18 16:31:43 +0100198 try
199 {
200 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
201 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
202 UnloadNetwork(networkID);
203 }
204 catch (const std::exception& e)
205 {
206 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
207 // exception of type std::length_error.
208 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
209 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
210 << std::endl;
211 }
telsoa014fcda012018-03-09 14:13:49 +0000212 }
213}
214
surmeh013537c2c2018-05-18 16:31:43 +0100215LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
216{
217 std::lock_guard<std::mutex> lockGuard(m_Mutex);
218 return m_LoadedNetworks.at(networkId).get();
219}
220
telsoa014fcda012018-03-09 14:13:49 +0000221TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
222{
surmeh013537c2c2018-05-18 16:31:43 +0100223 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000224}
225
226TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
227{
surmeh013537c2c2018-05-18 16:31:43 +0100228 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000229}
230
Derek Lamberti03614f62018-10-02 15:52:46 +0100231
telsoa014fcda012018-03-09 14:13:49 +0000232Status Runtime::EnqueueWorkload(NetworkId networkId,
telsoa01c577f2c2018-08-31 09:22:23 +0100233 const InputTensors& inputTensors,
234 const OutputTensors& outputTensors)
telsoa014fcda012018-03-09 14:13:49 +0000235{
surmeh013537c2c2018-05-18 16:31:43 +0100236 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
Derek Lamberti03614f62018-10-02 15:52:46 +0100237
238 static thread_local NetworkId lastId = networkId;
239 if (lastId != networkId)
240 {
241 LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
242 {
243 network->FreeWorkingMemory();
244 });
245 }
246 lastId=networkId;
247
surmeh013537c2c2018-05-18 16:31:43 +0100248 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
telsoa014fcda012018-03-09 14:13:49 +0000249}
250
Nattapat Chaimanowong6e948202019-03-22 14:01:46 +0000251void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
252{
253 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
254 loadedNetwork->RegisterDebugCallback(func);
255}
256
Matteo Martincighe54aa062019-08-05 14:12:11 +0100257void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
258{
259 // Get the paths where to load the dynamic backends from
260 std::vector<std::string> backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath);
261
262 // Get the shared objects to try to load as dynamic backends
263 std::vector<std::string> sharedObjects = DynamicBackendUtils::GetSharedObjects(backendPaths);
264
265 // Create a list of dynamic backends
Matteo Martincigh0c2b2892019-08-05 14:12:11 +0100266 m_DynamicBackends = DynamicBackendUtils::CreateDynamicBackends(sharedObjects);
267
268 // Register the dynamic backends in the backend registry
Matteo Martincigh89533902019-08-15 12:08:06 +0100269 BackendIdSet registeredBackendIds = DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);
270
271 // Add the registered dynamic backend ids to the list of supported backends
272 m_DeviceSpec.AddSupportedBackends(registeredBackendIds);
telsoa014fcda012018-03-09 14:13:49 +0000273}
Matteo Martincighe54aa062019-08-05 14:12:11 +0100274
275} // namespace armnn