// Copyright (c) 2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15#ifndef CUSTOMREGISTRY_H
16#define CUSTOMREGISTRY_H
17
18#include "custom_op_interface.h"
19#include <dlfcn.h>
20#include <unordered_map>
21
22using namespace tosa;
23
24namespace TosaReference
25{
26
27typedef CustomOpInterface* (*op_creation_function_t)();
28typedef int (*registration_callback_t)(const std::string& domain_name,
29 const std::string& operator_name,
30 const op_creation_function_t& op_creation_function);
31
32class MasterRegistry
33{
34public:
35 static int register_function(const std::string& domain_name,
36 const std::string& operator_name,
37 const op_creation_function_t& op_creation_function)
38 {
39 std::string unique_id = domain_name + "::" + operator_name;
40 MasterRegistry& instance = get_instance();
41 if (instance.op_creation_map.find(unique_id) != instance.op_creation_map.end())
42 {
43 std::cout << std::endl;
44 printf("domain_name: %s and operator_name: %s pair has already been registered", domain_name.c_str(),
45 operator_name.c_str());
46 return 1;
47 }
48 instance.op_creation_map[unique_id] = op_creation_function;
49 return 0;
50 }
51
52 static MasterRegistry& get_instance()
53 {
54 static MasterRegistry instance;
55 return instance;
56 }
57
58 MasterRegistry(const MasterRegistry&) = delete;
59 void operator=(const MasterRegistry&) = delete;
60
61 std::unordered_map<std::string, op_creation_function_t> get_ops() const
62 {
63 return op_creation_map;
64 }
65
66 static op_creation_function_t get_op(const std::string& domain_name, const std::string& operator_name)
67 {
68 std::string unique_id = domain_name + "::" + operator_name;
69 MasterRegistry& instance = get_instance();
70 auto all_ops_map = instance.get_ops();
71 if (all_ops_map.find(unique_id) == all_ops_map.end())
72 {
73 return nullptr;
74 }
75 else
76 {
77 op_creation_function_t& op_creation_function = all_ops_map[unique_id];
78 return op_creation_function;
79 }
80 }
81
82private:
83 MasterRegistry() = default;
84 std::unordered_map<std::string, op_creation_function_t> op_creation_map;
85};
86} // namespace TosaReference
87
88#endif