//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "LoadedNetwork.hpp"
#include "DeviceSpec.hpp"
#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/BackendId.hpp>

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

17namespace armnn
18{
19
telsoa014fcda012018-03-09 14:13:49 +000020class Runtime final : public IRuntime
21{
22public:
telsoa01c577f2c2018-08-31 09:22:23 +010023 /// Loads a complete network into the Runtime.
24 /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
25 /// @param [in] network - Complete network to load into the Runtime.
telsoa014fcda012018-03-09 14:13:49 +000026 /// The runtime takes ownership of the network once passed in.
27 /// @return armnn::Status
28 virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
29
telsoa01c577f2c2018-08-31 09:22:23 +010030 /// Load a complete network into the IRuntime.
31 /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
32 /// @param [in] network Complete network to load into the IRuntime.
33 /// @param [out] errorMessage Error message if there were any errors.
34 /// The runtime takes ownership of the network once passed in.
35 /// @return armnn::Status
36 virtual Status LoadNetwork(NetworkId& networkIdOut,
37 IOptimizedNetworkPtr network,
38 std::string & errorMessage) override;
39
telsoa014fcda012018-03-09 14:13:49 +000040 virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
41 virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
42
telsoa01c577f2c2018-08-31 09:22:23 +010043 // Evaluates network using input in inputTensors, outputs filled into outputTensors.
telsoa014fcda012018-03-09 14:13:49 +000044 virtual Status EnqueueWorkload(NetworkId networkId,
45 const InputTensors& inputTensors,
46 const OutputTensors& outputTensors) override;
47
telsoa01c577f2c2018-08-31 09:22:23 +010048 /// Unloads a network from the Runtime.
telsoa014fcda012018-03-09 14:13:49 +000049 /// At the moment this only removes the network from the m_Impl->m_Network.
50 /// This might need more work in the future to be AndroidNN compliant.
51 /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
52 /// @return armnn::Status
53 virtual Status UnloadNetwork(NetworkId networkId) override;
54
telsoa01c577f2c2018-08-31 09:22:23 +010055 virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
56
57 /// Gets the profiler corresponding to the given network id.
58 /// @param networkId The id of the network for which to get the profile.
59 /// @return A pointer to the requested profiler, or nullptr if not found.
60 virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
telsoa014fcda012018-03-09 14:13:49 +000061
62 /// Creates a runtime for workload execution.
63 /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
64 /// it cannot be setup for some reason.
65 Runtime(const CreationOptions& options);
66
surmeh01bceff2f2018-03-29 16:29:27 +010067 ~Runtime();
68
telsoa014fcda012018-03-09 14:13:49 +000069private:
telsoa01c577f2c2018-08-31 09:22:23 +010070 friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp
telsoa014fcda012018-03-09 14:13:49 +000071
72 int GenerateNetworkId();
73
surmeh013537c2c2018-05-18 16:31:43 +010074 LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
75
Derek Lamberti03614f62018-10-02 15:52:46 +010076 template<typename Func>
77 void LoadedNetworkFuncSafe(NetworkId networkId, Func f)
78 {
79 std::lock_guard<std::mutex> lockGuard(m_Mutex);
80 auto iter = m_LoadedNetworks.find(networkId);
81 if (iter != m_LoadedNetworks.end())
82 {
83 f(iter->second.get());
84 }
85 }
86
surmeh013537c2c2018-05-18 16:31:43 +010087 mutable std::mutex m_Mutex;
David Beckd4dfa682018-10-24 17:09:46 +010088
David Beck9efb57d2018-11-05 13:40:33 +000089 std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
David Beck1b61be52018-11-08 09:19:14 +000090 std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts;
David Beck9efb57d2018-11-05 13:40:33 +000091
92 int m_NetworkIdCounter;
93
94 DeviceSpec m_DeviceSpec;
telsoa014fcda012018-03-09 14:13:49 +000095};
96
97}