//
// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Types.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <map>
#include <memory>
#include <vector>

14namespace armnn
15{
16
17//Forward
18class IMemoryManager;
19
Narumol Prangnawaratb275da52021-12-17 17:27:37 +000020using CopyAndImportFactoryPairs = std::map<ITensorHandleFactory::FactoryId, ITensorHandleFactory::FactoryId>;
21
Derek Lamberti84da38b2019-06-13 11:40:08 +010022///
23class TensorHandleFactoryRegistry
24{
25public:
26 TensorHandleFactoryRegistry() = default;
27
28 TensorHandleFactoryRegistry(const TensorHandleFactoryRegistry& other) = delete;
29 TensorHandleFactoryRegistry(TensorHandleFactoryRegistry&& other) = delete;
30
31 /// Register a TensorHandleFactory and transfer ownership
32 void RegisterFactory(std::unique_ptr<ITensorHandleFactory> allocator);
33
34 /// Register a memory manager with shared ownership
35 void RegisterMemoryManager(std::shared_ptr<IMemoryManager> memoryManger);
36
37 /// Find a TensorHandleFactory by Id
38 /// Returns nullptr if not found
39 ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id) const;
40
Francis Murtagh73d3e2e2021-04-29 14:23:04 +010041 /// Overload of above allowing specification of Memory Source
42 ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id,
43 MemorySource memSource) const;
44
Narumol Prangnawaratb275da52021-12-17 17:27:37 +000045 /// Register a pair of TensorHandleFactory Id for Memory Copy and TensorHandleFactory Id for Memory Import
46 void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId,
47 ITensorHandleFactory::FactoryId importFactoryId);
48
49 /// Get a matching TensorHandleFatory Id for Memory Import given TensorHandleFactory Id for Memory Copy
50 ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId);
51
Derek Lamberti84da38b2019-06-13 11:40:08 +010052 /// Aquire memory required for inference
53 void AquireMemory();
54
55 /// Release memory required for inference
56 void ReleaseMemory();
57
Finn Williams01097942021-04-26 12:06:34 +010058 std::vector<std::shared_ptr<IMemoryManager>>& GetMemoryManagers()
59 {
60 return m_MemoryManagers;
61 }
62
Derek Lamberti84da38b2019-06-13 11:40:08 +010063private:
64 std::vector<std::unique_ptr<ITensorHandleFactory>> m_Factories;
65 std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers;
Narumol Prangnawaratb275da52021-12-17 17:27:37 +000066 CopyAndImportFactoryPairs m_FactoryMappings;
Derek Lamberti84da38b2019-06-13 11:40:08 +010067};
68
69} // namespace armnn