blob: e84cbe0a60039e0e276e22212595dead3b308ff5 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "Runtime.hpp"
6
David Beck056be3c2018-10-22 13:16:00 +01007#include <armnn/Version.hpp>
8#include <backends/BackendRegistry.hpp>
telsoa014fcda012018-03-09 14:13:49 +00009
surmeh013537c2c2018-05-18 16:31:43 +010010#include <iostream>
11
telsoa014fcda012018-03-09 14:13:49 +000012#ifdef ARMCOMPUTECL_ENABLED
13#include <arm_compute/core/CL/OpenCL.h>
14#include <arm_compute/core/CL/CLKernelLibrary.h>
surmeh01bceff2f2018-03-29 16:29:27 +010015#include <arm_compute/runtime/CL/CLScheduler.h>
telsoa014fcda012018-03-09 14:13:49 +000016#endif
17
18#include <boost/log/trivial.hpp>
19#include <boost/polymorphic_cast.hpp>
20
21using namespace armnn;
22using namespace std;
23
24namespace armnn
25{
26
27IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
28{
29 return new Runtime(options);
30}
31
32IRuntimePtr IRuntime::Create(const CreationOptions& options)
33{
34 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
35}
36
37void IRuntime::Destroy(IRuntime* runtime)
38{
39 delete boost::polymorphic_downcast<Runtime*>(runtime);
40}
41
42int Runtime::GenerateNetworkId()
43{
44 return m_NetworkIdCounter++;
45}
46
47Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
48{
telsoa01c577f2c2018-08-31 09:22:23 +010049 std::string ignoredErrorMessage;
50 return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
51}
52
53Status Runtime::LoadNetwork(NetworkId& networkIdOut,
54 IOptimizedNetworkPtr inNetwork,
55 std::string & errorMessage)
56{
telsoa014fcda012018-03-09 14:13:49 +000057 IOptimizedNetwork* rawNetwork = inNetwork.release();
58 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
59 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
telsoa01c577f2c2018-08-31 09:22:23 +010060 errorMessage);
telsoa014fcda012018-03-09 14:13:49 +000061
62 if (!loadedNetwork)
63 {
64 return Status::Failure;
65 }
66
67 networkIdOut = GenerateNetworkId();
68
telsoa01c577f2c2018-08-31 09:22:23 +010069 {
70 std::lock_guard<std::mutex> lockGuard(m_Mutex);
71
72 // Stores the network
73 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
74 }
telsoa014fcda012018-03-09 14:13:49 +000075
76 return Status::Success;
telsoa014fcda012018-03-09 14:13:49 +000077}
78
79Status Runtime::UnloadNetwork(NetworkId networkId)
80{
surmeh01bceff2f2018-03-29 16:29:27 +010081#ifdef ARMCOMPUTECL_ENABLED
82 if (arm_compute::CLScheduler::get().context()() != NULL)
83 {
telsoa01c577f2c2018-08-31 09:22:23 +010084 // Waits for all queued CL requests to finish before unloading the network they may be using.
surmeh013537c2c2018-05-18 16:31:43 +010085 try
86 {
87 // Coverity fix: arm_compute::CLScheduler::sync() may throw an exception of type cl::Error.
88 arm_compute::CLScheduler::get().sync();
89 }
90 catch (const cl::Error&)
91 {
92 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): an error occurred while waiting for "
93 "the queued CL requests to finish";
94 return Status::Failure;
95 }
surmeh01bceff2f2018-03-29 16:29:27 +010096 }
97#endif
surmeh013537c2c2018-05-18 16:31:43 +010098
telsoa014fcda012018-03-09 14:13:49 +000099 {
telsoa01c577f2c2018-08-31 09:22:23 +0100100 std::lock_guard<std::mutex> lockGuard(m_Mutex);
101
102 if (m_LoadedNetworks.erase(networkId) == 0)
103 {
104 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
105 return Status::Failure;
106 }
107
telsoa014fcda012018-03-09 14:13:49 +0000108#ifdef ARMCOMPUTECL_ENABLED
telsoa01c577f2c2018-08-31 09:22:23 +0100109 if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty())
110 {
111 // There are no loaded networks left, so clear the CL cache to free up memory
112 m_ClContextControl.ClearClCache();
113 }
telsoa014fcda012018-03-09 14:13:49 +0000114#endif
telsoa01c577f2c2018-08-31 09:22:23 +0100115 }
116
telsoa014fcda012018-03-09 14:13:49 +0000117 BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
118 return Status::Success;
119}
120
telsoa01c577f2c2018-08-31 09:22:23 +0100121const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
122{
123 auto it = m_LoadedNetworks.find(networkId);
124 if (it != m_LoadedNetworks.end())
125 {
126 auto& loadedNetwork = it->second;
127 return loadedNetwork->GetProfiler();
128 }
129
130 return nullptr;
131}
132
telsoa014fcda012018-03-09 14:13:49 +0000133Runtime::Runtime(const CreationOptions& options)
telsoa01c577f2c2018-08-31 09:22:23 +0100134 : m_ClContextControl(options.m_GpuAccTunedParameters.get(),
135 options.m_EnableGpuProfiling)
surmeh013537c2c2018-05-18 16:31:43 +0100136 , m_NetworkIdCounter(0)
David Beck056be3c2018-10-22 13:16:00 +0100137 , m_DeviceSpec{BackendRegistryInstance().GetBackendIds()}
telsoa014fcda012018-03-09 14:13:49 +0000138{
139 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
surmeh01bceff2f2018-03-29 16:29:27 +0100140}
141
142Runtime::~Runtime()
143{
144 std::vector<int> networkIDs;
surmeh013537c2c2018-05-18 16:31:43 +0100145 try
146 {
147 // Coverity fix: The following code may throw an exception of type std::length_error.
148 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
149 std::back_inserter(networkIDs),
150 [](const auto &pair) { return pair.first; });
151 }
152 catch (const std::exception& e)
153 {
154 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
155 // exception of type std::length_error.
156 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
157 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
158 << "\nSome of the loaded networks may not be unloaded" << std::endl;
159 }
160 // We then proceed to unload all the networks which IDs have been appended to the list
161 // up to the point the exception was thrown (if any).
surmeh01bceff2f2018-03-29 16:29:27 +0100162
163 for (auto networkID : networkIDs)
164 {
surmeh013537c2c2018-05-18 16:31:43 +0100165 try
166 {
167 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
168 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
169 UnloadNetwork(networkID);
170 }
171 catch (const std::exception& e)
172 {
173 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
174 // exception of type std::length_error.
175 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
176 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
177 << std::endl;
178 }
telsoa014fcda012018-03-09 14:13:49 +0000179 }
180}
181
surmeh013537c2c2018-05-18 16:31:43 +0100182LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
183{
184 std::lock_guard<std::mutex> lockGuard(m_Mutex);
185 return m_LoadedNetworks.at(networkId).get();
186}
187
telsoa014fcda012018-03-09 14:13:49 +0000188TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
189{
surmeh013537c2c2018-05-18 16:31:43 +0100190 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000191}
192
193TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
194{
surmeh013537c2c2018-05-18 16:31:43 +0100195 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000196}
197
Derek Lamberti03614f62018-10-02 15:52:46 +0100198
telsoa014fcda012018-03-09 14:13:49 +0000199Status Runtime::EnqueueWorkload(NetworkId networkId,
telsoa01c577f2c2018-08-31 09:22:23 +0100200 const InputTensors& inputTensors,
201 const OutputTensors& outputTensors)
telsoa014fcda012018-03-09 14:13:49 +0000202{
surmeh013537c2c2018-05-18 16:31:43 +0100203 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
Derek Lamberti03614f62018-10-02 15:52:46 +0100204
205 static thread_local NetworkId lastId = networkId;
206 if (lastId != networkId)
207 {
208 LoadedNetworkFuncSafe(lastId, [](LoadedNetwork* network)
209 {
210 network->FreeWorkingMemory();
211 });
212 }
213 lastId=networkId;
214
surmeh013537c2c2018-05-18 16:31:43 +0100215 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
telsoa014fcda012018-03-09 14:13:49 +0000216}
217
218}