blob: c95b2c45e2d789f0be14f779f2b185bfed09bff9 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "Runtime.hpp"
6
David Beck056be3c2018-10-22 13:16:00 +01007#include <armnn/Version.hpp>
Matteo Martincighc601aa62019-10-29 15:03:22 +00008#include <armnn/BackendRegistry.hpp>
Matteo Martincighe54aa062019-08-05 14:12:11 +01009
David Beck1b61be52018-11-08 09:19:14 +000010#include <backendsCommon/IBackendContext.hpp>
Matteo Martincighe54aa062019-08-05 14:12:11 +010011#include <backendsCommon/DynamicBackendUtils.hpp>
telsoa014fcda012018-03-09 14:13:49 +000012
Jim Flynnc4728ad2019-10-07 15:15:12 +010013#include "../profiling/ProfilingService.hpp"
14
surmeh013537c2c2018-05-18 16:31:43 +010015#include <iostream>
16
telsoa014fcda012018-03-09 14:13:49 +000017#include <boost/log/trivial.hpp>
18#include <boost/polymorphic_cast.hpp>
19
20using namespace armnn;
21using namespace std;
22
23namespace armnn
24{
25
26IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
27{
28 return new Runtime(options);
29}
30
31IRuntimePtr IRuntime::Create(const CreationOptions& options)
32{
33 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
34}
35
36void IRuntime::Destroy(IRuntime* runtime)
37{
38 delete boost::polymorphic_downcast<Runtime*>(runtime);
39}
40
41int Runtime::GenerateNetworkId()
42{
43 return m_NetworkIdCounter++;
44}
45
46Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
47{
telsoa01c577f2c2018-08-31 09:22:23 +010048 std::string ignoredErrorMessage;
49 return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
50}
51
52Status Runtime::LoadNetwork(NetworkId& networkIdOut,
53 IOptimizedNetworkPtr inNetwork,
David Monahan4f1e8e42019-09-04 09:22:10 +010054 std::string& errorMessage)
55{
56 INetworkProperties networkProperties;
57 return LoadNetwork(networkIdOut, std::move(inNetwork), errorMessage, networkProperties);
58}
59
60Status Runtime::LoadNetwork(NetworkId& networkIdOut,
61 IOptimizedNetworkPtr inNetwork,
62 std::string& errorMessage,
63 const INetworkProperties& networkProperties)
telsoa01c577f2c2018-08-31 09:22:23 +010064{
telsoa014fcda012018-03-09 14:13:49 +000065 IOptimizedNetwork* rawNetwork = inNetwork.release();
David Beck1b61be52018-11-08 09:19:14 +000066
67 networkIdOut = GenerateNetworkId();
68
69 for (auto&& context : m_BackendContexts)
70 {
71 context.second->BeforeLoadNetwork(networkIdOut);
72 }
73
telsoa014fcda012018-03-09 14:13:49 +000074 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
75 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
David Monahan4f1e8e42019-09-04 09:22:10 +010076 errorMessage,
77 networkProperties);
telsoa014fcda012018-03-09 14:13:49 +000078
79 if (!loadedNetwork)
80 {
81 return Status::Failure;
82 }
83
telsoa01c577f2c2018-08-31 09:22:23 +010084 {
85 std::lock_guard<std::mutex> lockGuard(m_Mutex);
86
87 // Stores the network
88 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
89 }
telsoa014fcda012018-03-09 14:13:49 +000090
David Beck1b61be52018-11-08 09:19:14 +000091 for (auto&& context : m_BackendContexts)
92 {
93 context.second->AfterLoadNetwork(networkIdOut);
94 }
95
telsoa014fcda012018-03-09 14:13:49 +000096 return Status::Success;
telsoa014fcda012018-03-09 14:13:49 +000097}
98
99Status Runtime::UnloadNetwork(NetworkId networkId)
100{
David Beck1b61be52018-11-08 09:19:14 +0000101 bool unloadOk = true;
102 for (auto&& context : m_BackendContexts)
David Beck9efb57d2018-11-05 13:40:33 +0000103 {
David Beck1b61be52018-11-08 09:19:14 +0000104 unloadOk &= context.second->BeforeUnloadNetwork(networkId);
David Beck9efb57d2018-11-05 13:40:33 +0000105 }
David Beck1b61be52018-11-08 09:19:14 +0000106
107 if (!unloadOk)
108 {
109 BOOST_LOG_TRIVIAL(warning) << "Runtime::UnloadNetwork(): failed to unload "
110 "network with ID:" << networkId << " because BeforeUnloadNetwork failed";
111 return Status::Failure;
112 }
David Beck9efb57d2018-11-05 13:40:33 +0000113
telsoa014fcda012018-03-09 14:13:49 +0000114 {
telsoa01c577f2c2018-08-31 09:22:23 +0100115 std::lock_guard<std::mutex> lockGuard(m_Mutex);
116
117 if (m_LoadedNetworks.erase(networkId) == 0)
118 {
119 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
120 return Status::Failure;
121 }
David Beck1b61be52018-11-08 09:19:14 +0000122 }
David Beck9efb57d2018-11-05 13:40:33 +0000123
David Beck1b61be52018-11-08 09:19:14 +0000124 for (auto&& context : m_BackendContexts)
125 {
126 context.second->AfterUnloadNetwork(networkId);
telsoa01c577f2c2018-08-31 09:22:23 +0100127 }
128
telsoa014fcda012018-03-09 14:13:49 +0000129 BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
130 return Status::Success;
131}
132
telsoa01c577f2c2018-08-31 09:22:23 +0100133const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
134{
135 auto it = m_LoadedNetworks.find(networkId);
136 if (it != m_LoadedNetworks.end())
137 {
138 auto& loadedNetwork = it->second;
139 return loadedNetwork->GetProfiler();
140 }
141
142 return nullptr;
143}
144
telsoa014fcda012018-03-09 14:13:49 +0000145Runtime::Runtime(const CreationOptions& options)
David Beck1b61be52018-11-08 09:19:14 +0000146 : m_NetworkIdCounter(0)
David Beck056be3c2018-10-22 13:16:00 +0100147 , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
telsoa014fcda012018-03-09 14:13:49 +0000148{
149 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
David Beck1b61be52018-11-08 09:19:14 +0000150
Jim Flynnc4728ad2019-10-07 15:15:12 +0100151 // pass configuration info to the profiling service
Jim Flynn672d06e2019-10-15 10:18:11 +0100152 armnn::profiling::ProfilingService::Instance().ConfigureProfilingService(options.m_ProfilingOptions);
Jim Flynnc4728ad2019-10-07 15:15:12 +0100153
Matteo Martincighe54aa062019-08-05 14:12:11 +0100154 // Load any available/compatible dynamic backend before the runtime
155 // goes through the backend registry
156 LoadDynamicBackends(options.m_DynamicBackendsPath);
157
David Beck1b61be52018-11-08 09:19:14 +0000158 for (const auto& id : BackendRegistryInstance().GetBackendIds())
159 {
160 // Store backend contexts for the supported ones
Matteo Martincigh3d8a9ed2019-08-08 10:49:03 +0100161 const BackendIdSet& supportedBackends = m_DeviceSpec.GetSupportedBackends();
Matteo Martincigh89533902019-08-15 12:08:06 +0100162 if (supportedBackends.find(id) != supportedBackends.end())
David Beck1b61be52018-11-08 09:19:14 +0000163 {
164 auto factoryFun = BackendRegistryInstance().GetFactory(id);
165 auto backend = factoryFun();
166 BOOST_ASSERT(backend.get() != nullptr);
167
168 auto context = backend->CreateBackendContext(options);
169
170 // backends are allowed to return nullptrs if they
171 // don't wish to create a backend specific context
172 if (context)
173 {
174 m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
175 }
176 }
177 }
surmeh01bceff2f2018-03-29 16:29:27 +0100178}
179
180Runtime::~Runtime()
181{
182 std::vector<int> networkIDs;
surmeh013537c2c2018-05-18 16:31:43 +0100183 try
184 {
185 // Coverity fix: The following code may throw an exception of type std::length_error.
186 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
187 std::back_inserter(networkIDs),
188 [](const auto &pair) { return pair.first; });
189 }
190 catch (const std::exception& e)
191 {
192 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
193 // exception of type std::length_error.
194 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
195 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
196 << "\nSome of the loaded networks may not be unloaded" << std::endl;
197 }
198 // We then proceed to unload all the networks which IDs have been appended to the list
199 // up to the point the exception was thrown (if any).
surmeh01bceff2f2018-03-29 16:29:27 +0100200
201 for (auto networkID : networkIDs)
202 {
surmeh013537c2c2018-05-18 16:31:43 +0100203 try
204 {
205 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
206 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
207 UnloadNetwork(networkID);
208 }
209 catch (const std::exception& e)
210 {
211 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
212 // exception of type std::length_error.
213 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
214 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
215 << std::endl;
216 }
telsoa014fcda012018-03-09 14:13:49 +0000217 }
218}
219
surmeh013537c2c2018-05-18 16:31:43 +0100220LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
221{
222 std::lock_guard<std::mutex> lockGuard(m_Mutex);
223 return m_LoadedNetworks.at(networkId).get();
224}
225
telsoa014fcda012018-03-09 14:13:49 +0000226TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
227{
surmeh013537c2c2018-05-18 16:31:43 +0100228 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000229}
230
231TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
232{
surmeh013537c2c2018-05-18 16:31:43 +0100233 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000234}
235
Derek Lamberti03614f62018-10-02 15:52:46 +0100236
telsoa014fcda012018-03-09 14:13:49 +0000237Status Runtime::EnqueueWorkload(NetworkId networkId,
telsoa01c577f2c2018-08-31 09:22:23 +0100238 const InputTensors& inputTensors,
239 const OutputTensors& outputTensors)
telsoa014fcda012018-03-09 14:13:49 +0000240{
surmeh013537c2c2018-05-18 16:31:43 +0100241 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
Derek Lamberti03614f62018-10-02 15:52:46 +0100242
243 static thread_local NetworkId lastId = networkId;
244 if (lastId != networkId)
245 {
246 LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
247 {
248 network->FreeWorkingMemory();
249 });
250 }
251 lastId=networkId;
252
surmeh013537c2c2018-05-18 16:31:43 +0100253 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
telsoa014fcda012018-03-09 14:13:49 +0000254}
255
Nattapat Chaimanowong6e948202019-03-22 14:01:46 +0000256void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
257{
258 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
259 loadedNetwork->RegisterDebugCallback(func);
260}
261
Matteo Martincighe54aa062019-08-05 14:12:11 +0100262void Runtime::LoadDynamicBackends(const std::string& overrideBackendPath)
263{
264 // Get the paths where to load the dynamic backends from
265 std::vector<std::string> backendPaths = DynamicBackendUtils::GetBackendPaths(overrideBackendPath);
266
267 // Get the shared objects to try to load as dynamic backends
268 std::vector<std::string> sharedObjects = DynamicBackendUtils::GetSharedObjects(backendPaths);
269
270 // Create a list of dynamic backends
Matteo Martincigh0c2b2892019-08-05 14:12:11 +0100271 m_DynamicBackends = DynamicBackendUtils::CreateDynamicBackends(sharedObjects);
272
273 // Register the dynamic backends in the backend registry
Matteo Martincigh89533902019-08-15 12:08:06 +0100274 BackendIdSet registeredBackendIds = DynamicBackendUtils::RegisterDynamicBackends(m_DynamicBackends);
275
276 // Add the registered dynamic backend ids to the list of supported backends
277 m_DeviceSpec.AddSupportedBackends(registeredBackendIds);
telsoa014fcda012018-03-09 14:13:49 +0000278}
Matteo Martincighe54aa062019-08-05 14:12:11 +0100279
280} // namespace armnn