blob: 0ca3446e1bf05072cc65d0c4c1dcbd1fbbab9447 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#include "Runtime.hpp"
6
7#include "armnn/Version.hpp"
8
surmeh013537c2c2018-05-18 16:31:43 +01009#include <iostream>
10
telsoa014fcda012018-03-09 14:13:49 +000011#ifdef ARMCOMPUTECL_ENABLED
12#include <arm_compute/core/CL/OpenCL.h>
13#include <arm_compute/core/CL/CLKernelLibrary.h>
surmeh01bceff2f2018-03-29 16:29:27 +010014#include <arm_compute/runtime/CL/CLScheduler.h>
telsoa014fcda012018-03-09 14:13:49 +000015#endif
16
17#include <boost/log/trivial.hpp>
18#include <boost/polymorphic_cast.hpp>
19
20using namespace armnn;
21using namespace std;
22
23namespace armnn
24{
25
26IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
27{
28 return new Runtime(options);
29}
30
31IRuntimePtr IRuntime::Create(const CreationOptions& options)
32{
33 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
34}
35
36void IRuntime::Destroy(IRuntime* runtime)
37{
38 delete boost::polymorphic_downcast<Runtime*>(runtime);
39}
40
41int Runtime::GenerateNetworkId()
42{
43 return m_NetworkIdCounter++;
44}
45
46Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
47{
48 IOptimizedNetwork* rawNetwork = inNetwork.release();
49 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
50 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
surmeh013537c2c2018-05-18 16:31:43 +010051 m_UseCpuRefAsFallback);
telsoa014fcda012018-03-09 14:13:49 +000052
53 if (!loadedNetwork)
54 {
55 return Status::Failure;
56 }
57
surmeh013537c2c2018-05-18 16:31:43 +010058 std::lock_guard<std::mutex> lockGuard(m_Mutex);
59
telsoa014fcda012018-03-09 14:13:49 +000060 networkIdOut = GenerateNetworkId();
61
62 // store the network
63 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
64
65 return Status::Success;
telsoa014fcda012018-03-09 14:13:49 +000066}
67
68Status Runtime::UnloadNetwork(NetworkId networkId)
69{
surmeh01bceff2f2018-03-29 16:29:27 +010070#ifdef ARMCOMPUTECL_ENABLED
71 if (arm_compute::CLScheduler::get().context()() != NULL)
72 {
surmeh013537c2c2018-05-18 16:31:43 +010073 // wait for all queued CL requests to finish before unloading the network they may be using
74 try
75 {
76 // Coverity fix: arm_compute::CLScheduler::sync() may throw an exception of type cl::Error.
77 arm_compute::CLScheduler::get().sync();
78 }
79 catch (const cl::Error&)
80 {
81 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): an error occurred while waiting for "
82 "the queued CL requests to finish";
83 return Status::Failure;
84 }
surmeh01bceff2f2018-03-29 16:29:27 +010085 }
86#endif
surmeh013537c2c2018-05-18 16:31:43 +010087 std::lock_guard<std::mutex> lockGuard(m_Mutex);
88
telsoa014fcda012018-03-09 14:13:49 +000089 if (m_LoadedNetworks.erase(networkId) == 0)
90 {
91 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
92 return Status::Failure;
93 }
94#ifdef ARMCOMPUTECL_ENABLED
surmeh01bceff2f2018-03-29 16:29:27 +010095 if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty())
96 {
surmeh013537c2c2018-05-18 16:31:43 +010097 // There are no loaded networks left, so clear the CL cache to free up memory
98 m_ClContextControl.ClearClCache();
surmeh01bceff2f2018-03-29 16:29:27 +010099 }
telsoa014fcda012018-03-09 14:13:49 +0000100#endif
101 BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
102 return Status::Success;
103}
104
105Runtime::Runtime(const CreationOptions& options)
surmeh013537c2c2018-05-18 16:31:43 +0100106 : m_ClContextControl(options.m_ClTunedParameters)
107 , m_NetworkIdCounter(0)
telsoa014fcda012018-03-09 14:13:49 +0000108{
109 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
110 BOOST_LOG_TRIVIAL(info) << "Using compute device: " << options.m_DefaultComputeDevice << "\n";
111 m_DeviceSpec.DefaultComputeDevice = options.m_DefaultComputeDevice;
112
surmeh013537c2c2018-05-18 16:31:43 +0100113 // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
114 // operation workloads, unless the default compute device is precisely the reference backend.
115 // This option is passed to the LoadedNetwork, which owns the workload factories.
116 m_UseCpuRefAsFallback = options.m_DefaultComputeDevice == Compute::CpuRef || options.m_UseCpuRefAsFallback;
surmeh01bceff2f2018-03-29 16:29:27 +0100117}
118
119Runtime::~Runtime()
120{
121 std::vector<int> networkIDs;
surmeh013537c2c2018-05-18 16:31:43 +0100122 try
123 {
124 // Coverity fix: The following code may throw an exception of type std::length_error.
125 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
126 std::back_inserter(networkIDs),
127 [](const auto &pair) { return pair.first; });
128 }
129 catch (const std::exception& e)
130 {
131 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
132 // exception of type std::length_error.
133 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
134 std::cerr << "WARNING: An error has occurred when getting the IDs of the networks to unload: " << e.what()
135 << "\nSome of the loaded networks may not be unloaded" << std::endl;
136 }
137 // We then proceed to unload all the networks which IDs have been appended to the list
138 // up to the point the exception was thrown (if any).
surmeh01bceff2f2018-03-29 16:29:27 +0100139
140 for (auto networkID : networkIDs)
141 {
surmeh013537c2c2018-05-18 16:31:43 +0100142 try
143 {
144 // Coverity fix: UnloadNetwork() may throw an exception of type std::length_error,
145 // boost::log::v2s_mt_posix::odr_violation or boost::log::v2s_mt_posix::system_error
146 UnloadNetwork(networkID);
147 }
148 catch (const std::exception& e)
149 {
150 // Coverity fix: BOOST_LOG_TRIVIAL (typically used to report errors) may throw an
151 // exception of type std::length_error.
152 // Using stderr instead in this context as there is no point in nesting try-catch blocks here.
153 std::cerr << "WARNING: An error has occurred when unloading network " << networkID << ": " << e.what()
154 << std::endl;
155 }
telsoa014fcda012018-03-09 14:13:49 +0000156 }
157}
158
surmeh013537c2c2018-05-18 16:31:43 +0100159LoadedNetwork* Runtime::GetLoadedNetworkPtr(NetworkId networkId) const
160{
161 std::lock_guard<std::mutex> lockGuard(m_Mutex);
162 return m_LoadedNetworks.at(networkId).get();
163}
164
telsoa014fcda012018-03-09 14:13:49 +0000165TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
166{
surmeh013537c2c2018-05-18 16:31:43 +0100167 return GetLoadedNetworkPtr(networkId)->GetInputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000168}
169
170TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
171{
surmeh013537c2c2018-05-18 16:31:43 +0100172 return GetLoadedNetworkPtr(networkId)->GetOutputTensorInfo(layerId);
telsoa014fcda012018-03-09 14:13:49 +0000173}
174
175Status Runtime::EnqueueWorkload(NetworkId networkId,
176 const InputTensors& inputTensors,
177 const OutputTensors& outputTensors)
178{
surmeh013537c2c2018-05-18 16:31:43 +0100179 LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
180 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
telsoa014fcda012018-03-09 14:13:49 +0000181}
182
183}