blob: b0a393d3f806b911ece9fdecae22dd2fd038390b [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "LoadedNetwork.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01008#include "DeviceSpec.hpp"
telsoa014fcda012018-03-09 14:13:49 +00009#include "armnn/INetwork.hpp"
10#include "armnn/IRuntime.hpp"
11#include "armnn/Tensor.hpp"
surmeh013537c2c2018-05-18 16:31:43 +010012#include "backends/ClContextControl.hpp"
telsoa014fcda012018-03-09 14:13:49 +000013
surmeh013537c2c2018-05-18 16:31:43 +010014#include <mutex>
telsoa014fcda012018-03-09 14:13:49 +000015#include <unordered_map>
16
17namespace armnn
18{
19
telsoa014fcda012018-03-09 14:13:49 +000020class Runtime final : public IRuntime
21{
22public:
telsoa01c577f2c2018-08-31 09:22:23 +010023 /// Loads a complete network into the Runtime.
24 /// @param [out] networkIdOut - Unique identifier for the network is returned in this reference.
25 /// @param [in] network - Complete network to load into the Runtime.
telsoa014fcda012018-03-09 14:13:49 +000026 /// The runtime takes ownership of the network once passed in.
27 /// @return armnn::Status
28 virtual Status LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr network) override;
29
telsoa01c577f2c2018-08-31 09:22:23 +010030 /// Load a complete network into the IRuntime.
31 /// @param [out] networkIdOut Unique identifier for the network is returned in this reference.
32 /// @param [in] network Complete network to load into the IRuntime.
33 /// @param [out] errorMessage Error message if there were any errors.
34 /// The runtime takes ownership of the network once passed in.
35 /// @return armnn::Status
36 virtual Status LoadNetwork(NetworkId& networkIdOut,
37 IOptimizedNetworkPtr network,
38 std::string & errorMessage) override;
39
telsoa014fcda012018-03-09 14:13:49 +000040 virtual TensorInfo GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
41 virtual TensorInfo GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const override;
42
telsoa01c577f2c2018-08-31 09:22:23 +010043 // Evaluates network using input in inputTensors, outputs filled into outputTensors.
telsoa014fcda012018-03-09 14:13:49 +000044 virtual Status EnqueueWorkload(NetworkId networkId,
45 const InputTensors& inputTensors,
46 const OutputTensors& outputTensors) override;
47
telsoa01c577f2c2018-08-31 09:22:23 +010048 /// Unloads a network from the Runtime.
telsoa014fcda012018-03-09 14:13:49 +000049 /// At the moment this only removes the network from the m_Impl->m_Network.
50 /// This might need more work in the future to be AndroidNN compliant.
51 /// @param [in] networkId Unique identifier for the network to be unloaded. Generated in LoadNetwork().
52 /// @return armnn::Status
53 virtual Status UnloadNetwork(NetworkId networkId) override;
54
telsoa01c577f2c2018-08-31 09:22:23 +010055 virtual const IDeviceSpec& GetDeviceSpec() const override { return m_DeviceSpec; }
56
57 /// Gets the profiler corresponding to the given network id.
58 /// @param networkId The id of the network for which to get the profile.
59 /// @return A pointer to the requested profiler, or nullptr if not found.
60 virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
telsoa014fcda012018-03-09 14:13:49 +000061
62 /// Creates a runtime for workload execution.
63 /// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
64 /// it cannot be setup for some reason.
65 Runtime(const CreationOptions& options);
66
surmeh01bceff2f2018-03-29 16:29:27 +010067 ~Runtime();
68
telsoa014fcda012018-03-09 14:13:49 +000069private:
telsoa01c577f2c2018-08-31 09:22:23 +010070 friend void RuntimeLoadedNetworksReserve(armnn::Runtime* runtime); // See RuntimeTests.cpp
telsoa014fcda012018-03-09 14:13:49 +000071
72 int GenerateNetworkId();
73
surmeh013537c2c2018-05-18 16:31:43 +010074 LoadedNetwork* GetLoadedNetworkPtr(NetworkId networkId) const;
75
76 mutable std::mutex m_Mutex;
77
telsoa014fcda012018-03-09 14:13:49 +000078 std::unordered_map<NetworkId, std::unique_ptr<LoadedNetwork>> m_LoadedNetworks;
79
surmeh013537c2c2018-05-18 16:31:43 +010080 ClContextControl m_ClContextControl;
telsoa014fcda012018-03-09 14:13:49 +000081
82 int m_NetworkIdCounter;
83
telsoa014fcda012018-03-09 14:13:49 +000084 DeviceSpec m_DeviceSpec;
85};
86
87}