blob: 2a7c6f36d9e1261b5198050d18e1f25b13006a68 [file] [log] [blame]
Derek Lamberti84da38b2019-06-13 11:40:08 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <map>
#include <memory>
#include <vector>
12
13namespace armnn
14{
15
16//Forward
17class IMemoryManager;
18
Narumol Prangnawaratb275da52021-12-17 17:27:37 +000019using CopyAndImportFactoryPairs = std::map<ITensorHandleFactory::FactoryId, ITensorHandleFactory::FactoryId>;
20
Derek Lamberti84da38b2019-06-13 11:40:08 +010021///
22class TensorHandleFactoryRegistry
23{
24public:
25 TensorHandleFactoryRegistry() = default;
26
27 TensorHandleFactoryRegistry(const TensorHandleFactoryRegistry& other) = delete;
28 TensorHandleFactoryRegistry(TensorHandleFactoryRegistry&& other) = delete;
29
30 /// Register a TensorHandleFactory and transfer ownership
31 void RegisterFactory(std::unique_ptr<ITensorHandleFactory> allocator);
32
33 /// Register a memory manager with shared ownership
34 void RegisterMemoryManager(std::shared_ptr<IMemoryManager> memoryManger);
35
36 /// Find a TensorHandleFactory by Id
37 /// Returns nullptr if not found
38 ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id) const;
39
Francis Murtagh73d3e2e2021-04-29 14:23:04 +010040 /// Overload of above allowing specification of Memory Source
41 ITensorHandleFactory* GetFactory(ITensorHandleFactory::FactoryId id,
42 MemorySource memSource) const;
43
Narumol Prangnawaratb275da52021-12-17 17:27:37 +000044 /// Register a pair of TensorHandleFactory Id for Memory Copy and TensorHandleFactory Id for Memory Import
45 void RegisterCopyAndImportFactoryPair(ITensorHandleFactory::FactoryId copyFactoryId,
46 ITensorHandleFactory::FactoryId importFactoryId);
47
48 /// Get a matching TensorHandleFatory Id for Memory Import given TensorHandleFactory Id for Memory Copy
49 ITensorHandleFactory::FactoryId GetMatchingImportFactoryId(ITensorHandleFactory::FactoryId copyFactoryId);
50
Derek Lamberti84da38b2019-06-13 11:40:08 +010051 /// Aquire memory required for inference
52 void AquireMemory();
53
54 /// Release memory required for inference
55 void ReleaseMemory();
56
Finn Williams01097942021-04-26 12:06:34 +010057 std::vector<std::shared_ptr<IMemoryManager>>& GetMemoryManagers()
58 {
59 return m_MemoryManagers;
60 }
61
Derek Lamberti84da38b2019-06-13 11:40:08 +010062private:
63 std::vector<std::unique_ptr<ITensorHandleFactory>> m_Factories;
64 std::vector<std::shared_ptr<IMemoryManager>> m_MemoryManagers;
Narumol Prangnawaratb275da52021-12-17 17:27:37 +000065 CopyAndImportFactoryPairs m_FactoryMappings;
Derek Lamberti84da38b2019-06-13 11:40:08 +010066};
67
68} // namespace armnn