blob: 03a741fb758451c882e92160b95b45f1b7339c7f [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include "Network.hpp"
#include "LayerFwd.hpp"
#include "Profiling.hpp"

#include <backendsCommon/IBackendInternal.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadFactory.hpp>

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Forward declarations of types in namespace `cl` (presumably the OpenCL C++ bindings).
// NOTE(review): no declaration elsewhere in this header refers to these types; they may be
// relied upon by files that include this header — confirm before removing.
namespace cl
{
class CommandQueue;
class Context;
class Device;
} // namespace cl

28namespace armnn
29{
30
telsoa014fcda012018-03-09 14:13:49 +000031class LoadedNetwork
32{
33public:
Derek Lamberti03614f62018-10-02 15:52:46 +010034 using WorkloadQueue = std::vector< std::unique_ptr<IWorkload> >;
35 ~LoadedNetwork(){ FreeWorkingMemory(); }
36
telsoa014fcda012018-03-09 14:13:49 +000037 TensorInfo GetInputTensorInfo(LayerBindingId layerId) const;
38 TensorInfo GetOutputTensorInfo(LayerBindingId layerId) const;
39
surmeh013537c2c2018-05-18 16:31:43 +010040 Status EnqueueWorkload(const InputTensors& inputTensors, const OutputTensors& outputTensors);
telsoa014fcda012018-03-09 14:13:49 +000041
42 static std::unique_ptr<LoadedNetwork> MakeLoadedNetwork(std::unique_ptr<OptimizedNetwork> net,
telsoa01c577f2c2018-08-31 09:22:23 +010043 std::string & errorMessage);
44
45 // NOTE we return by reference as the purpose of this method is only to provide
46 // access to the private m_Profiler and in theory we should not need to increment
47 // the shared_ptr's reference counter
48 const std::shared_ptr<Profiler>& GetProfiler() const { return m_Profiler; }
telsoa014fcda012018-03-09 14:13:49 +000049
Derek Lamberti03614f62018-10-02 15:52:46 +010050 void AllocateWorkingMemory();
51 void FreeWorkingMemory();
52
telsoa014fcda012018-03-09 14:13:49 +000053private:
David Beck9efb57d2018-11-05 13:40:33 +000054 LoadedNetwork(std::unique_ptr<OptimizedNetwork> net);
telsoa014fcda012018-03-09 14:13:49 +000055
surmeh013537c2c2018-05-18 16:31:43 +010056 void EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
telsoa014fcda012018-03-09 14:13:49 +000057
surmeh013537c2c2018-05-18 16:31:43 +010058 void EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo);
telsoa014fcda012018-03-09 14:13:49 +000059
60 bool Execute();
61
surmeh013537c2c2018-05-18 16:31:43 +010062 const IWorkloadFactory& GetWorkloadFactory(const Layer& layer) const;
63
David Beck29c75de2018-10-23 13:35:58 +010064 using BackendPtrMap = std::unordered_map<BackendId, IBackendInternalUniquePtr>;
Aron Virginas-Tar5caf9072018-11-14 18:35:18 +000065
66 using WorkloadFactoryWithMemoryManager =
67 std::pair<IBackendInternal::IWorkloadFactoryPtr, IBackendInternal::IMemoryManagerSharedPtr>;
68
69 using WorkloadFactoryMap = std::unordered_map<BackendId, WorkloadFactoryWithMemoryManager>;
David Beck29c75de2018-10-23 13:35:58 +010070
71 BackendPtrMap m_Backends;
72 WorkloadFactoryMap m_WorkloadFactories;
telsoa014fcda012018-03-09 14:13:49 +000073
74 std::unique_ptr<OptimizedNetwork> m_OptimizedNetwork;
Derek Lamberti03614f62018-10-02 15:52:46 +010075 WorkloadQueue m_InputQueue;
76 WorkloadQueue m_WorkloadQueue;
77 WorkloadQueue m_OutputQueue;
telsoa01c577f2c2018-08-31 09:22:23 +010078 std::shared_ptr<Profiler> m_Profiler;
Derek Lamberti03614f62018-10-02 15:52:46 +010079
80 using UniqueMutexLock = std::unique_lock<std::mutex>;
81 mutable std::mutex m_WorkingMemMutex;
82 UniqueMutexLock m_WorkingMemLock;
83
84 bool m_IsWorkingMemAllocated=false;
telsoa014fcda012018-03-09 14:13:49 +000085};
86
87}