telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 4 | // |
| 5 | #pragma once |
| 6 | |
| 7 | #include "LoadedNetwork.hpp" |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 8 | #include "DeviceSpec.hpp" |
David Beck | ac42efd | 2018-09-26 17:41:13 +0100 | [diff] [blame] | 9 | #include <armnn/INetwork.hpp> |
| 10 | #include <armnn/IRuntime.hpp> |
| 11 | #include <armnn/Tensor.hpp> |
David Beck | d4dfa68 | 2018-10-24 17:09:46 +0100 | [diff] [blame] | 12 | #include <armnn/BackendId.hpp> |
David Beck | 9efb57d | 2018-11-05 13:40:33 +0000 | [diff] [blame] | 13 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 14 | #include <mutex> |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 15 | #include <unordered_map> |
| 16 | |
| 17 | namespace armnn |
| 18 | { |
| 19 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 20 | class Runtime final : public IRuntime |
| 21 | { |
| 22 | public: |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 23 | /// Loads a complete network into the Runtime. |
| 24 | /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference. |
| 25 | /// @param [in] network - Complete network to load into the Runtime. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 26 | /// The runtime takes ownership of the network once passed in. |
| 27 | /// @return armnn::Status |
| 28 | virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override; |
| 29 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 30 | /// Load a complete network into the IRuntime. |
| 31 | /// @param [out] networkIdOut Unique identifier for the network is returned in this reference. |
| 32 | /// @param [in] network Complete network to load into the IRuntime. |
| 33 | /// @param [out] errorMessage Error message if there were any errors. |
| 34 | /// The runtime takes ownership of the network once passed in. |
| 35 | /// @return armnn::Status |
| 36 | virtual Status LoadNetwork(NetworkId& networkIdOut, |
| 37 | IOptimizedNetworkPtr network, |
| 38 | std::string & errorMessage) override; |
| 39 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 40 | virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override; |
| 41 | virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override; |
| 42 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 43 | // Evaluates network using input in inputTensors, outputs filled into outputTensors. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 44 | virtual Status EnqueueWorkload(NetworkId networkId, |
| 45 | const InputTensors& inputTensors, |
| 46 | const OutputTensors& outputTensors) override; |
| 47 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 48 | /// Unloads a network from the Runtime. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 49 | /// At the moment this only removes the network from the m_Impl->m_Network. |
| 50 | /// This might need more work in the future to be AndroidNN compliant. |
| 51 | /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork(). |
| 52 | /// @return armnn::Status |
| 53 | virtual Status UnloadNetwork(NetworkId networkId) override; |
| 54 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 55 | virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; } |
| 56 | |
| 57 | /// Gets the profiler corresponding to the given network id. |
| 58 | /// @param networkId The id of the network for which to get the profile. |
| 59 | /// @return A pointer to the requested profiler, or nullptr if not found. |
| 60 | virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 61 | |
| 62 | /// Creates a runtime for workload execution. |
| 63 | /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but |
| 64 | /// it cannot be setup for some reason. |
| 65 | Runtime(const CreationOptions& options); |
| 66 | |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 67 | ~Runtime(); |
| 68 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 69 | private: |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 70 | friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 71 | |
| 72 | int GenerateNetworkId(); |
| 73 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 74 | LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const; |
| 75 | |
Derek Lamberti | 03614f6 | 2018-10-02 15:52:46 +0100 | [diff] [blame] | 76 | template<typename Func> |
| 77 | void LoadedNetworkFuncSafe(NetworkId networkId, Func f) |
| 78 | { |
| 79 | std::lock_guard<std::mutex> lockGuard(m_Mutex); |
| 80 | auto iter = m_LoadedNetworks.find(networkId); |
| 81 | if (iter != m_LoadedNetworks.end()) |
| 82 | { |
| 83 | f(iter->second.get()); |
| 84 | } |
| 85 | } |
| 86 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 87 | mutable std::mutex m_Mutex; |
David Beck | d4dfa68 | 2018-10-24 17:09:46 +0100 | [diff] [blame] | 88 | |
David Beck | 9efb57d | 2018-11-05 13:40:33 +0000 | [diff] [blame] | 89 | std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks; |
David Beck | 1b61be5 | 2018-11-08 09:19:14 +0000 | [diff] [blame] | 90 | std::unordered_map<BackendId, IBackendInternal::IBackendContextPtr> m_BackendContexts; |
David Beck | 9efb57d | 2018-11-05 13:40:33 +0000 | [diff] [blame] | 91 | |
| 92 | int m_NetworkIdCounter; |
| 93 | |
| 94 | DeviceSpec m_DeviceSpec; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 95 | }; |
| 96 | |
| 97 | } |