blob: cb4dc865a74b1befbc39cefac14d35a4784e902e [file] [log] [blame]
Matthew Bentham245d64c2019-12-02 12:59:43 +00001# Copyright © 2019 Arm Ltd. All rights reserved.
2# SPDX-License-Identifier: MIT
3import logging
4import os
5import sys
6from functools import lru_cache
7from pathlib import Path
8from itertools import chain
9
10from setuptools import setup
11from distutils.core import Extension
12from setuptools.command.build_py import build_py
13from setuptools.command.build_ext import build_ext
14
# Module logger obtained through getLogger so it is registered in the logging
# hierarchy and honours the application's logging configuration (handlers, levels).
logger = logging.getLogger(__name__)

# Placeholders; presumably overwritten when src/pyarmnn/_version.py is executed.
# __version__ is passed to setup(version=...), __arm_ml_version__ to check_armnn_version.
__version__ = None
__arm_ml_version__ = None
19
20
def check_armnn_version(*args):
    """
    Placeholder ArmNN version check: accepts any positional arguments and does nothing.

    NOTE(review): the module-level exec of src/pyarmnn/_version.py presumably replaces
    this with a real implementation; this stub keeps the name defined otherwise.
    """
    return None
23
24
# Execute src/pyarmnn/_version.py in this module's namespace; it presumably assigns
# __version__ and __arm_ml_version__ (and may redefine check_armnn_version).
# The file is opened in a `with` block so the handle is closed, read as UTF-8 (the
# default encoding of Python source), and compiled with its real path so tracebacks
# raised from it point at _version.py rather than "<string>".
_version_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src', 'pyarmnn', '_version.py')
with open(_version_path, encoding='utf-8') as _version_file:
    exec(compile(_version_file.read(), _version_path, 'exec'))
26
27
class ExtensionPriorityBuilder(build_py):
    """
    ``build_py`` variant that triggers ``build_ext`` before doing its own work.

    Building the extensions first makes sure files they generate already exist when
    the Python modules are collected; otherwise those generated files would be left
    out of the distribution.
    """

    def run(self):
        """Run ``build_ext``, then the regular ``build_py`` step, returning its result."""
        self.run_command('build_ext')
        outcome = super().run()
        return outcome
36
37
class ArmnnVersionCheckerExtBuilder(build_ext):
    """
    ``build_ext`` command that tolerates individual extension build failures and runs an
    ArmNN version check once the version-query extension has been built.

    Extensions whose build raises are logged, recorded in ``failed_ext`` and dropped
    from ``self.extensions`` before ``copy_extensions_to_source`` runs.
    """

    # Extension whose freshly built module is imported to query the ArmNN version.
    _VERSION_EXT_NAME = 'pyarmnn._generated._pyarmnn_version'

    def __init__(self, dist):
        super().__init__(dist)
        # Extensions whose build raised an exception.
        self.failed_ext = []

    def build_extension(self, ext):
        """Build ``ext``; on failure log and remember it instead of aborting the whole build."""
        try:
            super().build_extension(ext)
        except Exception as err:
            # Deliberately best-effort: one failing extension must not stop the others.
            self.failed_ext.append(ext)
            logger.warning('Failed to build extension %s. \n %s', ext.name, str(err))
            if ext.name == self._VERSION_EXT_NAME:
                # Nothing fresh to import; importing here would fail or pick up a stale module.
                logger.warning('Skipping ArmNN version check because %s failed to build.', ext.name)
            return

        if ext.name == self._VERSION_EXT_NAME:
            self._check_built_version(ext)

    def _check_built_version(self, ext):
        """Import the just-built version module and pass its version to check_armnn_version."""
        # Directory the extension binary was written to (public API instead of ext._file_name).
        ext_dir = os.path.dirname(os.path.abspath(self.get_ext_fullpath(ext.name)))
        # Search the build directory first so the fresh module wins over any other copy,
        # and restore sys.path afterwards.
        sys.path.insert(0, ext_dir)
        try:
            from _pyarmnn_version import GetVersion
        finally:
            sys.path.remove(ext_dir)
        check_armnn_version(GetVersion(), __arm_ml_version__)

    def copy_extensions_to_source(self):
        """Copy built extensions to the source tree, skipping those that failed to build."""
        # Filter in place (like the list.remove() calls this replaces) so any other holder
        # of this list sees the change; entries already absent are simply ignored.
        self.extensions[:] = [ext for ext in self.extensions if ext not in self.failed_ext]
        super().copy_extensions_to_source()
61
62
def _parse_gcc_library_dirs(search_dirs_output):
    """
    Extracts the linker library directories from ``gcc --print-search-dirs`` output.

    Args:
        search_dirs_output: text printed by ``gcc --print-search-dirs``.
    Returns:
        tuple of the non-empty ':'-separated paths following the first '=' on the line
        starting with 'libraries:', or None if there is no such line or it has no '='.
    """
    for line in search_dirs_output.splitlines():
        if not line.startswith('libraries:'):
            continue
        _, separator, dirs = line.partition('=')
        if not separator:
            return None
        # splitlines() has already removed the newline that would otherwise stick to the
        # last path; empty entries (from '::') are dropped instead of meaning the cwd.
        return tuple(lib_dir for lib_dir in dirs.split(':') if lib_dir)
    return None


def linux_gcc_lib_search():
    """
    Calls the `gcc` to get linker default system paths.

    Returns:
        tuple of library directories reported by ``gcc --print-search-dirs``, or None if
        gcc could not be executed or its output has no 'libraries: =' entry.
    """
    import subprocess

    try:
        completed = subprocess.run(['gcc', '--print-search-dirs'],
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.DEVNULL,
                                   universal_newlines=True,
                                   # gcc may localise its output labels; force the C locale.
                                   env=dict(os.environ, LC_ALL='C'))
    except OSError:
        # gcc is not installed or cannot be executed.
        return None
    return _parse_gcc_library_dirs(completed.stdout)
75
76
def find_includes(armnn_include_env: str = 'ARMNN_INCLUDE'):
    """
    Returns the header search directories for ArmNN.

    Args:
        armnn_include_env: name of the environment variable that may point to the ArmNN headers.
    Returns:
        a one-element list with the variable's value when it is set and non-empty,
        otherwise ['/usr/local/include', '/usr/include'].
    """
    custom_dir = os.getenv(armnn_include_env, '')
    if not custom_dir:
        return ['/usr/local/include', '/usr/include']
    return [custom_dir]
80
81
@lru_cache(maxsize=32)
def find_armnn(lib_name: str,
               optional: bool = False,
               armnn_libs_env: str = 'ARMNN_LIB',
               default_lib_search: tuple = None):
    """
    Searches for an ArmNN library on the local machine.

    Results are memoised per argument combination (``lru_cache``): the environment and
    the file system are only inspected on the first call with given arguments. The cache
    holds several entries because different library names are looked up alternately.

    Args:
        lib_name: lib name to find (a file name or ``Path.glob`` pattern, e.g. 'libarmnn.so')
        optional: Do not fail if optional. Default is False - fail if library was not found.
        armnn_libs_env: custom environment variable pointing to ArmNN libraries location, default is 'ARMNN_LIB'
        default_lib_search: tuple of directories to search when the variable named by
            ``armnn_libs_env`` is unset or empty. Directories are searched in order and the
            first match for a given file name wins. When None (the default), the result of
            ``linux_gcc_lib_search()`` is used, evaluated lazily on first need; if that is
            None, no directories are searched. Must be hashable (it is part of the cache key).

    Returns:
        tuple of two lists: the found library file names prefixed with ':' (e.g.
        ':libarmnn.so', presumably for the GNU ld ``-l:<file>`` form), and the unique
        absolute directories containing them, in discovery order.

    Raises:
        RuntimeError: if no library was found and ``optional`` is False.
    """
    armnn_lib_path = os.getenv(armnn_libs_env, "")
    if armnn_lib_path:
        lib_search = [armnn_lib_path]
    elif default_lib_search is None:
        # Not evaluated at import time, and tolerant of gcc being unavailable.
        lib_search = linux_gcc_lib_search() or ()
    else:
        lib_search = default_lib_search

    # ':'-prefixed file name -> path of the first matching file.
    armnn_libs = {}
    for search_dir in lib_search:
        # sorted() makes the result independent of directory listing order.
        for lib_file in sorted(Path(search_dir).glob(lib_name)):
            # setdefault keeps the earliest (highest-priority) directory's match.
            armnn_libs.setdefault(':{}'.format(lib_file.name), lib_file)

    if not optional and not armnn_libs:
        raise RuntimeError('ArmNN library {} was not found in {}. Please install ArmNN to one of the standard '
                           'locations or set correct ARMNN_INCLUDE and {} env variables.'.format(lib_name,
                                                                                                lib_search,
                                                                                                armnn_libs_env))

    # Unique parent directories in first-seen order (a set would make the order vary between runs).
    lib_dirs = []
    for lib_file in armnn_libs.values():
        lib_dir = str(lib_file.absolute().parent)
        if lib_dir not in lib_dirs:
            lib_dirs.append(lib_dir)

    return list(armnn_libs.keys()), lib_dirs
117
118
class LazyArmnnFinderExtension(Extension):
    """
    Derived from `Extension` this class adds ArmNN libraries search on the user's machine.
    Build flags are extended with the relevant ArmNN library locations (-L), library
    names (-l) and header directories (-I).

    The ArmNN library search (``find_armnn``) runs only when ``library_dirs``,
    ``runtime_library_dirs`` or ``libraries`` are read; reading ``include_dirs`` appends
    the result of ``find_includes``. Values assigned to these attributes are stored as-is
    and the discovered entries are appended to a fresh list on every read.
    """

    def __init__(self, name, sources, armnn_libs, include_dirs=None, define_macros=None, undef_macros=None,
                 library_dirs=None,
                 libraries=None, runtime_library_dirs=None, extra_objects=None, extra_compile_args=None,
                 extra_link_args=None, export_symbols=None, language=None, optional=None, **kw):
        """
        Args:
            name: full dotted name of the extension module.
            sources: source files to compile.
            armnn_libs: ArmNN library file names (e.g. 'libarmnn.so'), each passed to
                ``find_armnn`` to extend the library settings.
            Remaining arguments are forwarded unchanged to ``Extension``.
        """
        # Backing storage for the properties below; Extension.__init__ assigns the public
        # attributes, which routes through the setters into these fields.
        self._include_dirs = None
        self._library_dirs = None
        self._runtime_library_dirs = None
        self._libraries = None
        self._armnn_libs = armnn_libs
        # Forward by keyword: Extension.__init__ declares swig_opts and depends between
        # export_symbols and language, so passing these positionally bound `language`
        # to swig_opts and `optional` to depends.
        super().__init__(name, sources,
                         include_dirs=include_dirs,
                         define_macros=define_macros,
                         undef_macros=undef_macros,
                         library_dirs=library_dirs,
                         libraries=libraries,
                         runtime_library_dirs=runtime_library_dirs,
                         extra_objects=extra_objects,
                         extra_compile_args=extra_compile_args,
                         extra_link_args=extra_link_args,
                         export_symbols=export_symbols,
                         language=language,
                         optional=optional,
                         **kw)

    def _armnn_search_results(self):
        """Collect (library names, library directories) from find_armnn for every entry of armnn_libs."""
        found_names = []
        found_dirs = []
        for lib in self._armnn_libs:
            lib_names, lib_dirs = find_armnn(lib)
            # Extend local lists only; the lists returned by find_armnn are not modified.
            found_names += lib_names
            found_dirs += lib_dirs
        return found_names, found_dirs

    @property
    def include_dirs(self):
        """Configured include directories followed by ``find_includes()``."""
        return list(self._include_dirs or []) + find_includes()

    @include_dirs.setter
    def include_dirs(self, include_dirs):
        self._include_dirs = include_dirs

    @property
    def library_dirs(self):
        """Configured library directories followed by the directories of the ArmNN libraries."""
        return list(self._library_dirs or []) + self._armnn_search_results()[1]

    @library_dirs.setter
    def library_dirs(self, library_dirs):
        self._library_dirs = library_dirs

    @property
    def runtime_library_dirs(self):
        """Configured runtime library directories followed by the directories of the ArmNN libraries."""
        return list(self._runtime_library_dirs or []) + self._armnn_search_results()[1]

    @runtime_library_dirs.setter
    def runtime_library_dirs(self, runtime_library_dirs):
        self._runtime_library_dirs = runtime_library_dirs

    @property
    def libraries(self):
        """Configured libraries followed by the ':'-prefixed ArmNN library file names."""
        return list(self._libraries or []) + self._armnn_search_results()[0]

    @libraries.setter
    def libraries(self, libraries):
        self._libraries = libraries

    # Extensions of this class compare equal when they have the same module name.
    def __eq__(self, other):
        return self.__class__ == other.__class__ and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    # Defining __eq__ would otherwise make instances unhashable; hash consistently with __eq__.
    def __hash__(self):
        return hash(self.name)
197
if __name__ == '__main__':
    def make_armnn_extension(module_name, wrapper_source, armnn_libs):
        """Create a C++14 LazyArmnnFinderExtension for one generated wrapper source file."""
        return LazyArmnnFinderExtension(module_name,
                                        sources=[wrapper_source],
                                        extra_compile_args=['-std=c++14'],
                                        language='c++',
                                        armnn_libs=armnn_libs)

    # Mandatory extensions: the core wrapper and the version-query module.
    core_ext = make_armnn_extension('pyarmnn._generated._pyarmnn',
                                    'src/pyarmnn/_generated/armnn_wrap.cpp',
                                    ['libarmnn.so'])
    version_ext = make_armnn_extension('pyarmnn._generated._pyarmnn_version',
                                       'src/pyarmnn/_generated/armnn_version_wrap.cpp',
                                       ['libarmnn.so'])
    # The version module is listed first, presumably so that it is built (and the
    # ArmNN version checked) before the core wrapper.
    extensions_to_build = [version_ext, core_ext]

    # Optional parser extensions, each additionally linked against that parser's library.
    for parser in ('CaffeParser', 'OnnxParser', 'TfParser', 'TfLiteParser'):
        lowered = parser.lower()
        extensions_to_build.append(
            make_armnn_extension('pyarmnn._generated._pyarmnn_{}'.format(lowered),
                                 'src/pyarmnn/_generated/armnn_{}_wrap.cpp'.format(lowered),
                                 ['libarmnn.so', 'libarmnn{}.so'.format(parser)]))

    python_packages = [
        'pyarmnn',
        'pyarmnn._generated',
        'pyarmnn._quantization',
        'pyarmnn._tensor',
        'pyarmnn._utilities',
    ]
    custom_commands = {'build_py': ExtensionPriorityBuilder, 'build_ext': ArmnnVersionCheckerExtBuilder}

    setup(
        name='pyarmnn',
        version=__version__,
        author='Arm ltd',
        description='Arm NN python wrapper',
        url='https://www.arm.com',
        license='MIT',
        package_dir={'': 'src'},
        packages=python_packages,
        data_files=[('', ['licences.txt'])],
        python_requires='>=3.5',
        install_requires=['numpy'],
        cmdclass=custom_commands,
        ext_modules=extensions_to_build,
    )