blob: e0d6a9add0fd4db85e37ece734bb426e57b2881b [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#include "Runtime.hpp"
6
7#include "armnn/Version.hpp"
8
9#ifdef ARMCOMPUTECL_ENABLED
10#include <arm_compute/core/CL/OpenCL.h>
11#include <arm_compute/core/CL/CLKernelLibrary.h>
surmeh01bceff2f2018-03-29 16:29:27 +010012#include <arm_compute/runtime/CL/CLScheduler.h>
telsoa014fcda012018-03-09 14:13:49 +000013#endif
14
15#include <boost/log/trivial.hpp>
16#include <boost/polymorphic_cast.hpp>
17
18using namespace armnn;
19using namespace std;
20
21namespace armnn
22{
23
24IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
25{
26 return new Runtime(options);
27}
28
29IRuntimePtr IRuntime::Create(const CreationOptions& options)
30{
31 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
32}
33
34void IRuntime::Destroy(IRuntime* runtime)
35{
36 delete boost::polymorphic_downcast<Runtime*>(runtime);
37}
38
39int Runtime::GenerateNetworkId()
40{
41 return m_NetworkIdCounter++;
42}
43
44Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
45{
46 IOptimizedNetwork* rawNetwork = inNetwork.release();
47 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
48 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
49 m_WorkloadFactories);
50
51 if (!loadedNetwork)
52 {
53 return Status::Failure;
54 }
55
56 networkIdOut = GenerateNetworkId();
57
58 // store the network
59 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
60
61 return Status::Success;
telsoa014fcda012018-03-09 14:13:49 +000062}
63
// Removes a previously loaded network from the runtime.
// Returns Status::Failure (with a warning log) if the id is unknown,
// Status::Success otherwise.
Status Runtime::UnloadNetwork(NetworkId networkId)
{
#ifdef ARMCOMPUTECL_ENABLED
    // If an OpenCL context exists, wait for any in-flight GPU work to finish
    // before the network's workloads and buffers are destroyed by erase().
    if (arm_compute::CLScheduler::get().context()() != NULL)
    {
        arm_compute::CLScheduler::get().sync();
    }
#endif
    if (m_LoadedNetworks.erase(networkId) == 0)
    {
        BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
        return Status::Failure;
    }
#ifdef ARMCOMPUTECL_ENABLED
    // NOTE(review): once the last network is gone, the GPU factory re-loads the
    // OpenCL runtime — presumably to reset/release accumulated CL kernel state;
    // confirm intent against ClWorkloadFactory::LoadOpenClRuntime.
    if (arm_compute::CLScheduler::get().context()() != NULL && m_LoadedNetworks.empty())
    {
        m_WorkloadFactories.m_GpuAcc.get()->LoadOpenClRuntime();
    }
#endif
    BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
    return Status::Success;
}
86
87Runtime::Runtime(const CreationOptions& options)
88: m_NetworkIdCounter(0)
89{
90 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
91 BOOST_LOG_TRIVIAL(info) << "Using compute device: " << options.m_DefaultComputeDevice << "\n";
92 m_DeviceSpec.DefaultComputeDevice = options.m_DefaultComputeDevice;
93
94 // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
95 // operation workloads, unless the default compute device is precisely the reference backend.
96 m_WorkloadFactories.m_CpuRef = make_shared<RefWorkloadFactory>(
97 options.m_DefaultComputeDevice == Compute::CpuRef ? true : options.m_UseCpuRefAsFallback);
98 m_WorkloadFactories.m_CpuAcc = make_shared<NeonWorkloadFactory>();
surmeh01bceff2f2018-03-29 16:29:27 +010099 m_WorkloadFactories.m_GpuAcc = make_shared<ClWorkloadFactory>(options.m_ClTunedParameters);
telsoa014fcda012018-03-09 14:13:49 +0000100
101 if (options.m_DefaultComputeDevice == Compute::GpuAcc)
102 {
surmeh01bceff2f2018-03-29 16:29:27 +0100103 m_WorkloadFactories.m_GpuAcc.get()->LoadOpenClRuntime();
104 }
105}
106
107Runtime::~Runtime()
108{
109 std::vector<int> networkIDs;
110 std::transform(m_LoadedNetworks.begin(), m_LoadedNetworks.end(),
111 std::back_inserter(networkIDs),
112 [](const auto &pair) { return pair.first; });
113
114 for (auto networkID : networkIDs)
115 {
116 UnloadNetwork(networkID);
telsoa014fcda012018-03-09 14:13:49 +0000117 }
118}
119
120TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
121{
122 LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
123 return net->GetInputTensorInfo(layerId);
124}
125
126TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
127{
128 const LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
129 return net->GetOutputTensorInfo(layerId);
130}
131
132Status Runtime::EnqueueWorkload(NetworkId networkId,
133 const InputTensors& inputTensors,
134 const OutputTensors& outputTensors)
135{
136 LoadedNetwork* loadedNetwork = m_LoadedNetworks.at(networkId).get();
137 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors, m_WorkloadFactories);
138}
139
140}