//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#include "Runtime.hpp"
6
7#include "armnn/Version.hpp"
8
9#ifdef ARMCOMPUTECL_ENABLED
10#include <arm_compute/core/CL/OpenCL.h>
11#include <arm_compute/core/CL/CLKernelLibrary.h>
12#endif
13
14#include <boost/log/trivial.hpp>
15#include <boost/polymorphic_cast.hpp>
16
17using namespace armnn;
18using namespace std;
19
20namespace armnn
21{
22
23IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
24{
25 return new Runtime(options);
26}
27
28IRuntimePtr IRuntime::Create(const CreationOptions& options)
29{
30 return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
31}
32
33void IRuntime::Destroy(IRuntime* runtime)
34{
35 delete boost::polymorphic_downcast<Runtime*>(runtime);
36}
37
38int Runtime::GenerateNetworkId()
39{
40 return m_NetworkIdCounter++;
41}
42
43Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
44{
45 IOptimizedNetwork* rawNetwork = inNetwork.release();
46 unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
47 std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
48 m_WorkloadFactories);
49
50 if (!loadedNetwork)
51 {
52 return Status::Failure;
53 }
54
55 networkIdOut = GenerateNetworkId();
56
57 // store the network
58 m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
59
60 return Status::Success;
61
62}
63
64Status Runtime::UnloadNetwork(NetworkId networkId)
65{
66 if (m_LoadedNetworks.erase(networkId) == 0)
67 {
68 BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
69 return Status::Failure;
70 }
71#ifdef ARMCOMPUTECL_ENABLED
72 arm_compute::CLKernelLibrary::get().clear_programs_cache();
73#endif
74 BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
75 return Status::Success;
76}
77
78Runtime::Runtime(const CreationOptions& options)
79: m_NetworkIdCounter(0)
80{
81 BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
82 BOOST_LOG_TRIVIAL(info) << "Using compute device: " << options.m_DefaultComputeDevice << "\n";
83 m_DeviceSpec.DefaultComputeDevice = options.m_DefaultComputeDevice;
84
85 // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
86 // operation workloads, unless the default compute device is precisely the reference backend.
87 m_WorkloadFactories.m_CpuRef = make_shared<RefWorkloadFactory>(
88 options.m_DefaultComputeDevice == Compute::CpuRef ? true : options.m_UseCpuRefAsFallback);
89 m_WorkloadFactories.m_CpuAcc = make_shared<NeonWorkloadFactory>();
90 m_WorkloadFactories.m_GpuAcc = make_shared<ClWorkloadFactory>();
91
92 if (options.m_DefaultComputeDevice == Compute::GpuAcc)
93 {
94 m_WorkloadFactories.m_GpuAcc.get()->LoadOpenClRuntime(options.m_ClTunedParameters);
95 }
96}
97
98TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
99{
100 LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
101 return net->GetInputTensorInfo(layerId);
102}
103
104TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
105{
106 const LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
107 return net->GetOutputTensorInfo(layerId);
108}
109
110Status Runtime::EnqueueWorkload(NetworkId networkId,
111 const InputTensors& inputTensors,
112 const OutputTensors& outputTensors)
113{
114 LoadedNetwork* loadedNetwork = m_LoadedNetworks.at(networkId).get();
115 return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors, m_WorkloadFactories);
116}
117
118}