blob: f8b2462f966766c0cdc1dd604fa08267c0cebaac [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "Runtime.hpp"
6
David Beck056be3c2018-10-22 13:16:00 +01007#include <armnn/Version.hpp>
Aron Virginas-Tarc9cc8042018-11-01 16:15:57 +00008#include <backendsCommon/BackendRegistry.hpp>
David Beck1b61be52018-11-08 09:19:14 +00009#include <backendsCommon/IBackendContext.hpp>
telsoa014fcda012018-03-09 14:13:49 +000010
surmeh013537c2c2018-05-18 16:31:43 +010011#include <iostream>
12
telsoa014fcda012018-03-09 14:13:49 +000013#include <boost/log/trivial.hpp>
14#include <boost/polymorphic_cast.hpp>
15
16using namespace armnn;
17using namespace std;
18
19namespace armnn
20{
21
22IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
23{
24 return new Runtime(options);
25}
26
27IRuntimePtr IRuntime::Create(const CreationOptions& options)
28{
29 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
30}
31
32void IRuntime::Destroy(IRuntime* runtime)
33{
34 delete boost::polymorphic_downcast<Runtime*>(runtime);
35}
36
37int Runtime::GenerateNetworkId()
38{
39 return m_NetworkIdCounter++;
40}
41
42Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
43{
telsoa01c577f2c2018-08-31 09:22:23 +010044 std::string ignoredErrorMessage;
45 return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
46}
47
48Status Runtime::LoadNetwork(NetworkId& networkIdOut,
49 IOptimizedNetworkPtr inNetwork,
50 std::string & errorMessage)
51{
telsoa014fcda012018-03-09 14:13:49 +000052 IOptimizedNetwork* rawNetwork = inNetwork.release();
David Beck1b61be52018-11-08 09:19:14 +000053
54 networkIdOut = GenerateNetworkId();
55
56 for (auto&& context : m_BackendContexts)
57 {
58 context.second->BeforeLoadNetwork(networkIdOut);
59 }
60
telsoa014fcda012018-03-09 14:13:49 +000061 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
62 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
telsoa01c577f2c2018-08-31 09:22:23 +010063 errorMessage);
telsoa014fcda012018-03-09 14:13:49 +000064
65 if (!loadedNetwork)
66 {
67 return Status::Failure;
68 }
69
telsoa01c577f2c2018-08-31 09:22:23 +010070 {
71 std::lock_guard<std::mutex> lockGuard(m_Mutex);
72
73 // Stores the network
74 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
75 }
telsoa014fcda012018-03-09 14:13:49 +000076
David Beck1b61be52018-11-08 09:19:14 +000077 for (auto&& context : m_BackendContexts)
78 {
79 context.second->AfterLoadNetwork(networkIdOut);
80 }
81
telsoa014fcda012018-03-09 14:13:49 +000082 return Status::Success;
telsoa014fcda012018-03-09 14:13:49 +000083}
84
85Status Runtime::UnloadNetwork(NetworkId networkId)
86{
David Beck1b61be52018-11-08 09:19:14 +000087 bool unloadOk = true;
88 for (auto&& context : m_BackendContexts)
David Beck9efb57d2018-11-05 13:40:33 +000089 {
David Beck1b61be52018-11-08 09:19:14 +000090 unloadOk &= context.second->BeforeUnloadNetwork(networkId);
David Beck9efb57d2018-11-05 13:40:33 +000091 }
David Beck1b61be52018-11-08 09:19:14 +000092
93 if (!unloadOk)
94 {
95 BOOST_LOG_TRIVIAL(warning) << "Runtime::UnloadNetwork(): failed to unload "
96 "network with ID:" << networkId << " because BeforeUnloadNetwork failed";
97 return Status::Failure;
98 }
David Beck9efb57d2018-11-05 13:40:33 +000099
telsoa014fcda012018-03-09 14:13:49 +0000100 {
telsoa01c577f2c2018-08-31 09:22:23 +0100101 std::lock_guard<std::mutex> lockGuard(m_Mutex);
102
103 if (m_LoadedNetworks.erase(networkId) == 0)
104 {
105 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
106 return Status::Failure;
107 }
David Beck1b61be52018-11-08 09:19:14 +0000108 }
David Beck9efb57d2018-11-05 13:40:33 +0000109
David Beck1b61be52018-11-08 09:19:14 +0000110 for (auto&& context : m_BackendContexts)
111 {
112 context.second->AfterUnloadNetwork(networkId);
telsoa01c577f2c2018-08-31 09:22:23 +0100113 }
114
telsoa014fcda012018-03-09 14:13:49 +0000115 BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
116 return Status::Success;
117}
118
telsoa01c577f2c2018-08-31 09:22:23 +0100119const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
120{
121 auto it = m_LoadedNetworks.find(networkId);
122 if (it != m_LoadedNetworks.end())
123 {
124 auto& loadedNetwork = it->second;
125 return loadedNetwork->GetProfiler();
126 }
127
128 return nullptr;
129}
130
telsoa014fcda012018-03-09 14:13:49 +0000131Runtime::Runtime(const CreationOptions& options)
David Beck1b61be52018-11-08 09:19:14 +0000132 : m_NetworkIdCounter(0)
David Beck056be3c2018-10-22 13:16:00 +0100133 , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
telsoa014fcda012018-03-09 14:13:49 +0000134{
135 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
David Beck1b61be52018-11-08 09:19:14 +0000136
137 for (const auto& id : BackendRegistryInstance().GetBackendIds())
138 {
139 // Store backend contexts for the supported ones
140 if (m_DeviceSpec.GetSupportedBackends().count(id) > 0)
141 {
142 auto factoryFun = BackendRegistryInstance().GetFactory(id);
143 auto backend = factoryFun();
144 BOOST_ASSERT(backend.get() != nullptr);
145
146 auto context = backend->CreateBackendContext(options);
147
148 // backends are allowed to return nullptrs if they
149 // don't wish to create a backend specific context
150 if (context)
151 {
152 m_BackendContexts.emplace(std::make_pair(id, std::move(context)));
153 }
154 }
155 }
surmeh01bceff2f2018-03-29 16:29:27 +0100156}
157
158Runtime::~Runtime()
159{
160 std::vector<int> networkIDs;
surmeh013537c2c2018-05-18 16:31:43 +0100161 try
162 {
163 // Coverity fix: The following code may throw an exception of type std::length_error.
164 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
165 std::back_inserter(networkIDs),
166 [](const auto &pair) { return pair.first; });
167 }
168 catch (const std::exception& e)
169 {
170 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
171 // exception of type std::length_error.
172 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
173 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
174 << "\nSome of the loaded networks may not be unloaded" << std::endl;
175 }
176 // We then proceed to unload all the networks which IDs have been appended to the list
177 // up to the point the exception was thrown (if any).
surmeh01bceff2f2018-03-29 16:29:27 +0100178
179 for (auto networkID : networkIDs)
180 {
surmeh013537c2c2018-05-18 16:31:43 +0100181 try
182 {
183 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
184 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
185 UnloadNetwork(networkID);
186 }
187 catch (const std::exception& e)
188 {
189 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
190 // exception of type std::length_error.
191 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
192 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
193 << std::endl;
194 }
telsoa014fcda012018-03-09 14:13:49 +0000195 }
196}
197
surmeh013537c2c2018-05-18 16:31:43 +0100198LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
199{
200 std::lock_guard<std::mutex> lockGuard(m_Mutex);
201 return m_LoadedNetworks.at(networkId).get();
202}
203
telsoa014fcda012018-03-09 14:13:49 +0000204TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
205{
surmeh013537c2c2018-05-18 16:31:43 +0100206 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000207}
208
209TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
210{
surmeh013537c2c2018-05-18 16:31:43 +0100211 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000212}
213
Derek Lamberti03614f62018-10-02 15:52:46 +0100214
telsoa014fcda012018-03-09 14:13:49 +0000215Status Runtime::EnqueueWorkload(NetworkId networkId,
telsoa01c577f2c2018-08-31 09:22:23 +0100216 const InputTensors& inputTensors,
217 const OutputTensors& outputTensors)
telsoa014fcda012018-03-09 14:13:49 +0000218{
surmeh013537c2c2018-05-18 16:31:43 +0100219 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
Derek Lamberti03614f62018-10-02 15:52:46 +0100220
221 static thread_local NetworkId lastId = networkId;
222 if (lastId != networkId)
223 {
224 LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
225 {
226 network->FreeWorkingMemory();
227 });
228 }
229 lastId=networkId;
230
surmeh013537c2c2018-05-18 16:31:43 +0100231 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
telsoa014fcda012018-03-09 14:13:49 +0000232}
233
Nattapat Chaimanowong6e948202019-03-22 14:01:46 +0000234void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
235{
236 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
237 loadedNetwork->RegisterDebugCallback(func);
238}
239
telsoa014fcda012018-03-09 14:13:49 +0000240}