telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2017 Arm Ltd. All rights reserved. |
David Beck | ecb56cd | 2018-09-05 12:52:57 +0100 | [diff] [blame^] | 3 | // SPDX-License-Identifier: MIT |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 4 | // |
| 5 | #pragma once |
| 6 | |
| 7 | #include "LoadedNetwork.hpp" |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 8 | #include "DeviceSpec.hpp" |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 9 | #include "armnn/INetwork.hpp" |
| 10 | #include "armnn/IRuntime.hpp" |
| 11 | #include "armnn/Tensor.hpp" |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 12 | #include "backends/ClContextControl.hpp" |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 13 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 14 | #include <mutex> |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 15 | #include <unordered_map> |
| 16 | |
| 17 | namespace armnn |
| 18 | { |
| 19 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 20 | class Runtime final : public IRuntime |
| 21 | { |
| 22 | public: |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 23 | /// Loads a complete network into the Runtime. |
| 24 | /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference. |
| 25 | /// @param [in] network - Complete network to load into the Runtime. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 26 | /// The runtime takes ownership of the network once passed in. |
| 27 | /// @return armnn::Status |
| 28 | virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override; |
| 29 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 30 | /// Load a complete network into the IRuntime. |
| 31 | /// @param [out] networkIdOut Unique identifier for the network is returned in this reference. |
| 32 | /// @param [in] network Complete network to load into the IRuntime. |
| 33 | /// @param [out] errorMessage Error message if there were any errors. |
| 34 | /// The runtime takes ownership of the network once passed in. |
| 35 | /// @return armnn::Status |
| 36 | virtual Status LoadNetwork(NetworkId& networkIdOut, |
| 37 | IOptimizedNetworkPtr network, |
| 38 | std::string & errorMessage) override; |
| 39 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 40 | virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override; |
| 41 | virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override; |
| 42 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 43 | // Evaluates network using input in inputTensors, outputs filled into outputTensors. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 44 | virtual Status EnqueueWorkload(NetworkId networkId, |
| 45 | const InputTensors& inputTensors, |
| 46 | const OutputTensors& outputTensors) override; |
| 47 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 48 | /// Unloads a network from the Runtime. |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 49 | /// At the moment this only removes the network from the m_Impl->m_Network. |
| 50 | /// This might need more work in the future to be AndroidNN compliant. |
| 51 | /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork(). |
| 52 | /// @return armnn::Status |
| 53 | virtual Status UnloadNetwork(NetworkId networkId) override; |
| 54 | |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 55 | virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; } |
| 56 | |
| 57 | /// Gets the profiler corresponding to the given network id. |
| 58 | /// @param networkId The id of the network for which to get the profile. |
| 59 | /// @return A pointer to the requested profiler, or nullptr if not found. |
| 60 | virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 61 | |
| 62 | /// Creates a runtime for workload execution. |
| 63 | /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but |
| 64 | /// it cannot be setup for some reason. |
| 65 | Runtime(const CreationOptions& options); |
| 66 | |
surmeh01 | bceff2f | 2018-03-29 16:29:27 +0100 | [diff] [blame] | 67 | ~Runtime(); |
| 68 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 69 | private: |
telsoa01 | c577f2c | 2018-08-31 09:22:23 +0100 | [diff] [blame] | 70 | friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 71 | |
| 72 | int GenerateNetworkId(); |
| 73 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 74 | LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const; |
| 75 | |
| 76 | mutable std::mutex m_Mutex; |
| 77 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 78 | std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks; |
| 79 | |
surmeh01 | 3537c2c | 2018-05-18 16:31:43 +0100 | [diff] [blame] | 80 | ClContextControl m_ClContextControl; |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 81 | |
| 82 | int m_NetworkIdCounter; |
| 83 | |
telsoa01 | 4fcda01 | 2018-03-09 14:13:49 +0000 | [diff] [blame] | 84 | DeviceSpec m_DeviceSpec; |
| 85 | }; |
| 86 | |
| 87 | } |