blob: 7d1a9faaeaedca599df98f3d722c7f39f3e49994 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#include "Runtime.hpp"
6
7#include "armnn/Version.hpp"
8
surmeh013537c2c2018-05-18 16:31:43 +01009#include <iostream>
10
telsoa014fcda012018-03-09 14:13:49 +000011#ifdef ARMCOMPUTECL_ENABLED
12#include <arm_compute/core/CL/OpenCL.h>
13#include <arm_compute/core/CL/CLKernelLibrary.h>
surmeh01bceff2f2018-03-29 16:29:27 +010014#include <arm_compute/runtime/CL/CLScheduler.h>
telsoa014fcda012018-03-09 14:13:49 +000015#endif
16
17#include <boost/log/trivial.hpp>
18#include <boost/polymorphic_cast.hpp>
19
20using namespace armnn;
21using namespace std;
22
23namespace armnn
24{
25
26IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
27{
28 return new Runtime(options);
29}
30
31IRuntimePtr IRuntime::Create(const CreationOptions& options)
32{
33 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
34}
35
36void IRuntime::Destroy(IRuntime* runtime)
37{
38 delete boost::polymorphic_downcast<Runtime*>(runtime);
39}
40
41int Runtime::GenerateNetworkId()
42{
43 return m_NetworkIdCounter++;
44}
45
46Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
47{
telsoa01c577f2c2018-08-31 09:22:23 +010048 std::string ignoredErrorMessage;
49 return LoadNetwork(networkIdOut, std::move(inNetwork), ignoredErrorMessage);
50}
51
52Status Runtime::LoadNetwork(NetworkId& networkIdOut,
53 IOptimizedNetworkPtr inNetwork,
54 std::string & errorMessage)
55{
telsoa014fcda012018-03-09 14:13:49 +000056 IOptimizedNetwork* rawNetwork = inNetwork.release();
57 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
58 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
telsoa01c577f2c2018-08-31 09:22:23 +010059 errorMessage);
telsoa014fcda012018-03-09 14:13:49 +000060
61 if (!loadedNetwork)
62 {
63 return Status::Failure;
64 }
65
66 networkIdOut = GenerateNetworkId();
67
telsoa01c577f2c2018-08-31 09:22:23 +010068 {
69 std::lock_guard<std::mutex> lockGuard(m_Mutex);
70
71 // Stores the network
72 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
73 }
telsoa014fcda012018-03-09 14:13:49 +000074
75 return Status::Success;
telsoa014fcda012018-03-09 14:13:49 +000076}
77
78Status Runtime::UnloadNetwork(NetworkId networkId)
79{
surmeh01bceff2f2018-03-29 16:29:27 +010080#ifdef ARMCOMPUTECL_ENABLED
81 if (arm_compute::CLScheduler::get().context()() != NULL)
82 {
telsoa01c577f2c2018-08-31 09:22:23 +010083 // Waits for all queued CL requests to finish before unloading the network they may be using.
surmeh013537c2c2018-05-18 16:31:43 +010084 try
85 {
86 // Coverity fix: arm_compute::CLScheduler::sync() may throw an exception of type cl::Error.
87 arm_compute::CLScheduler::get().sync();
88 }
89 catch (const cl::Error&)
90 {
91 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): an error occurred while waiting for "
92 "the queued CL requests to finish";
93 return Status::Failure;
94 }
surmeh01bceff2f2018-03-29 16:29:27 +010095 }
96#endif
surmeh013537c2c2018-05-18 16:31:43 +010097
telsoa014fcda012018-03-09 14:13:49 +000098 {
telsoa01c577f2c2018-08-31 09:22:23 +010099 std::lock_guard<std::mutex> lockGuard(m_Mutex);
100
101 if (m_LoadedNetworks.erase(networkId) == 0)
102 {
103 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
104 return Status::Failure;
105 }
106
telsoa014fcda012018-03-09 14:13:49 +0000107#ifdef ARMCOMPUTECL_ENABLED
telsoa01c577f2c2018-08-31 09:22:23 +0100108 if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty())
109 {
110 // There are no loaded networks left, so clear the CL cache to free up memory
111 m_ClContextControl.ClearClCache();
112 }
telsoa014fcda012018-03-09 14:13:49 +0000113#endif
telsoa01c577f2c2018-08-31 09:22:23 +0100114 }
115
telsoa014fcda012018-03-09 14:13:49 +0000116 BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
117 return Status::Success;
118}
119
telsoa01c577f2c2018-08-31 09:22:23 +0100120const std::shared_ptr<IProfiler> Runtime::GetProfiler(NetworkId networkId) const
121{
122 auto it = m_LoadedNetworks.find(networkId);
123 if (it != m_LoadedNetworks.end())
124 {
125 auto& loadedNetwork = it->second;
126 return loadedNetwork->GetProfiler();
127 }
128
129 return nullptr;
130}
131
telsoa014fcda012018-03-09 14:13:49 +0000132Runtime::Runtime(const CreationOptions& options)
telsoa01c577f2c2018-08-31 09:22:23 +0100133 : m_ClContextControl(options.m_GpuAccTunedParameters.get(),
134 options.m_EnableGpuProfiling)
surmeh013537c2c2018-05-18 16:31:43 +0100135 , m_NetworkIdCounter(0)
telsoa014fcda012018-03-09 14:13:49 +0000136{
137 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
telsoa014fcda012018-03-09 14:13:49 +0000138
telsoa01c577f2c2018-08-31 09:22:23 +0100139 m_DeviceSpec.m_SupportedComputeDevices.insert(armnn::Compute::CpuRef);
140 #if ARMCOMPUTECL_ENABLED
141 m_DeviceSpec.m_SupportedComputeDevices.insert(armnn::Compute::GpuAcc);
142 #endif
143 #if ARMCOMPUTENEON_ENABLED
144 m_DeviceSpec.m_SupportedComputeDevices.insert(armnn::Compute::CpuAcc);
145 #endif
surmeh01bceff2f2018-03-29 16:29:27 +0100146}
147
148Runtime::~Runtime()
149{
150 std::vector<int> networkIDs;
surmeh013537c2c2018-05-18 16:31:43 +0100151 try
152 {
153 // Coverity fix: The following code may throw an exception of type std::length_error.
154 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
155 std::back_inserter(networkIDs),
156 [](const auto &pair) { return pair.first; });
157 }
158 catch (const std::exception& e)
159 {
160 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
161 // exception of type std::length_error.
162 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
163 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
164 << "\nSome of the loaded networks may not be unloaded" << std::endl;
165 }
166 // We then proceed to unload all the networks which IDs have been appended to the list
167 // up to the point the exception was thrown (if any).
surmeh01bceff2f2018-03-29 16:29:27 +0100168
169 for (auto networkID : networkIDs)
170 {
surmeh013537c2c2018-05-18 16:31:43 +0100171 try
172 {
173 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
174 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
175 UnloadNetwork(networkID);
176 }
177 catch (const std::exception& e)
178 {
179 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
180 // exception of type std::length_error.
181 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
182 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
183 << std::endl;
184 }
telsoa014fcda012018-03-09 14:13:49 +0000185 }
186}
187
surmeh013537c2c2018-05-18 16:31:43 +0100188LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
189{
190 std::lock_guard<std::mutex> lockGuard(m_Mutex);
191 return m_LoadedNetworks.at(networkId).get();
192}
193
telsoa014fcda012018-03-09 14:13:49 +0000194TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
195{
surmeh013537c2c2018-05-18 16:31:43 +0100196 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000197}
198
199TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
200{
surmeh013537c2c2018-05-18 16:31:43 +0100201 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000202}
203
204Status Runtime::EnqueueWorkload(NetworkId networkId,
telsoa01c577f2c2018-08-31 09:22:23 +0100205 const InputTensors& inputTensors,
206 const OutputTensors& outputTensors)
telsoa014fcda012018-03-09 14:13:49 +0000207{
surmeh013537c2c2018-05-18 16:31:43 +0100208 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
209 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
telsoa014fcda012018-03-09 14:13:49 +0000210}
211
212}