blob: 03bd338090aee5a709d889cb718d3372d29fa5b4 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
#include <armnn/BackendId.hpp>
#include <armnn/Exceptions.hpp>

#include <cstddef>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
15
16namespace armnn
17{
18
/// Supplies the human-readable name of a registered type for use in registry error messages.
///
/// This primary template returns the placeholder "UNKNOWN".
/// NOTE(review): presumably specialised elsewhere for each concrete registered type - confirm.
///
/// @tparam RegisteredType  The type whose name is being queried.
template <typename RegisteredType>
struct RegisteredTypeName
{
    /// @return A pointer to a string literal naming RegisteredType ("UNKNOWN" for this primary template).
    static const char* Name() { return "UNKNOWN"; }
};
24
David Beck9efb57d2018-11-05 13:40:33 +000025template <typename RegisteredType, typename PointerType>
David Beck3e9e1152018-10-17 14:17:50 +010026class RegistryCommon
27{
28public:
David Beck9efb57d2018-11-05 13:40:33 +000029 using FactoryFunction = std::function<PointerType()>;
David Beck3e9e1152018-10-17 14:17:50 +010030
31 void Register(const BackendId& id, FactoryFunction factory)
32 {
33 if (m_Factories.count(id) > 0)
34 {
35 throw InvalidArgumentException(
36 std::string(id) + " already registered as " + RegisteredTypeName<RegisteredType>::Name() + " factory",
37 CHECK_LOCATION());
38 }
39
40 m_Factories[id] = factory;
41 }
42
43 FactoryFunction GetFactory(const BackendId& id) const
44 {
45 auto it = m_Factories.find(id);
46 if (it == m_Factories.end())
47 {
48 throw InvalidArgumentException(
49 std::string(id) + " has no " + RegisteredTypeName<RegisteredType>::Name() + " factory registered",
50 CHECK_LOCATION());
51 }
52
53 return it->second;
54 }
55
56 size_t Size() const
57 {
58 return m_Factories.size();
59 }
60
61 BackendIdSet GetBackendIds() const
62 {
63 BackendIdSet result;
64 for (const auto& it : m_Factories)
65 {
66 result.insert(it.first);
67 }
68 return result;
69 }
70
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010071 std::string GetBackendIdsAsString() const
72 {
73 static const std::string delimitator = ", ";
74
75 std::stringstream output;
76 for (auto& backendId : GetBackendIds())
77 {
78 if (output.tellp() != std::streampos(0))
79 {
80 output << delimitator;
81 }
82 output << backendId;
83 }
84
85 return output.str();
86 }
87
David Beck3e9e1152018-10-17 14:17:50 +010088 RegistryCommon() {}
89 virtual ~RegistryCommon() {}
90
91protected:
92 using FactoryStorage = std::unordered_map<BackendId, FactoryFunction>;
93
94 // For testing only
95 static void Swap(RegistryCommon& instance, FactoryStorage& other)
96 {
97 std::swap(instance.m_Factories, other);
98 }
99
100private:
101 RegistryCommon(const RegistryCommon&) = delete;
102 RegistryCommon& operator=(const RegistryCommon&) = delete;
103
104 FactoryStorage m_Factories;
105};
106
107template <typename RegistryType>
108struct StaticRegistryInitializer
109{
110 using FactoryFunction = typename RegistryType::FactoryFunction;
111
112 StaticRegistryInitializer(RegistryType& instance,
113 const BackendId& id,
114 FactoryFunction factory)
115 {
116 instance.Register(id, factory);
117 }
118};
119
David Beck9efb57d2018-11-05 13:40:33 +0000120} // namespace armnn