Matthew Bentham | 245d64c | 2019-12-02 12:59:43 +0000 | [diff] [blame^] | 1 | # Copyright © 2019 Arm Ltd. All rights reserved. |
| 2 | # SPDX-License-Identifier: MIT |
| 3 | import logging |
| 4 | import os |
| 5 | import sys |
| 6 | from functools import lru_cache |
| 7 | from pathlib import Path |
| 8 | from itertools import chain |
| 9 | |
| 10 | from setuptools import setup |
| 11 | from distutils.core import Extension |
| 12 | from setuptools.command.build_py import build_py |
| 13 | from setuptools.command.build_ext import build_ext |
| 14 | |
# Module logger. logging.getLogger registers it in the logging hierarchy (a directly
# instantiated logging.Logger has no parent), so records can reach handlers configured
# on the root logger.
logger = logging.getLogger(__name__)

# Placeholders; the exec of src/pyarmnn/_version.py further down is expected to assign the
# real values (__version__ is passed to setup(), __arm_ml_version__ to the version check).
__version__ = None
__arm_ml_version__ = None
| 19 | |
| 20 | |
def check_armnn_version(*args):
    """
    No-op fallback for the ArmNN version check.

    Accepts any positional arguments and ignores them. The exec of
    src/pyarmnn/_version.py later in this script may rebind this name to a real check
    (presumably; that file is not visible here).
    """
    return None
| 23 | |
| 24 | |
# Execute src/pyarmnn/_version.py in this module's namespace so the names it defines
# (presumably __version__, __arm_ml_version__ and a real check_armnn_version) are visible to
# the rest of this script. The file is read explicitly as UTF-8 so a non-ASCII character
# (e.g. in a copyright header) cannot fail under an ASCII default locale, and the `with`
# closes the handle deterministically instead of leaking it.
_version_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src', 'pyarmnn', '_version.py')
with open(_version_file, encoding='utf-8') as _version_fp:
    # compile() with the real path makes tracebacks point at _version.py if it is broken
    exec(compile(_version_fp.read(), _version_file, 'exec'))
| 26 | |
| 27 | |
class ExtensionPriorityBuilder(build_py):
    """
    build_py command that runs build_ext before doing its own work.

    Building the extensions first means the files they generate already exist when build_py
    collects package contents; otherwise those generated files would be left out of the
    distribution.
    """

    def run(self):
        # trigger the extension build up front
        self.run_command('build_ext')
        outcome = super().run()
        return outcome
| 36 | |
| 37 | |
class ArmnnVersionCheckerExtBuilder(build_ext):
    """
    build_ext command that tolerates per-extension build failures and checks the ArmNN version.

    Any extension whose build raises is logged, remembered in ``failed_ext`` and dropped
    before built extensions are copied back to the source tree. Once the
    ``pyarmnn._generated._pyarmnn_version`` extension has been built successfully, it is
    imported and ``check_armnn_version(GetVersion(), __arm_ml_version__)`` is called.
    """

    def __init__(self, dist):
        super().__init__(dist)
        # extensions whose build raised; excluded in copy_extensions_to_source
        self.failed_ext = []

    def build_extension(self, ext):
        """
        Build one extension, logging (rather than propagating) any failure.

        Args:
            ext: the extension to build.
        """
        try:
            super().build_extension(ext)
        except Exception as err:
            # Deliberate best-effort: presumably so the optional parser extensions can be
            # skipped when the matching ArmNN parser library is absent.
            self.failed_ext.append(ext)
            logger.warning('Failed to build extension %s. \n %s', ext.name, str(err))
            # Nothing was built, so there is no module to import for the version check;
            # attempting it would only raise a misleading ImportError.
            return

        if ext.name == 'pyarmnn._generated._pyarmnn_version':
            # Make the freshly built module importable from its build output directory.
            sys.path.append(os.path.abspath(os.path.join(self.build_lib, str(Path(ext._file_name).parent))))
            from _pyarmnn_version import GetVersion
            check_armnn_version(GetVersion(), __arm_ml_version__)

    def copy_extensions_to_source(self):
        """Copy built extensions to the source tree, skipping those that failed to build."""
        for ext in self.failed_ext:
            # guard so a repeated call does not raise ValueError on an already-removed entry
            if ext in self.extensions:
                self.extensions.remove(ext)
        super().copy_extensions_to_source()
| 61 | |
| 62 | |
def linux_gcc_lib_search(compiler: str = 'gcc'):
    """
    Calls the compiler (`gcc` by default) to get the linker's default library search paths.

    Runs ``<compiler> --print-search-dirs`` and parses its ``libraries: =a:b:c`` line.

    Args:
        compiler: name or path of a gcc-compatible compiler to query. Default is 'gcc'.

    Returns:
        tuple of library search paths (empty entries dropped), or None if the compiler
        cannot be executed or its output has no ``libraries: =...`` line.
    """
    import subprocess

    try:
        # argv list instead of a shell pipeline; universal_newlines keeps Python 3.5 support
        result = subprocess.run([compiler, '--print-search-dirs'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.DEVNULL,
                                universal_newlines=True,
                                check=False)
    except OSError:
        # compiler not installed or not executable
        return None

    # splitlines() strips the trailing newline that would otherwise corrupt the last path
    for line in result.stdout.splitlines():
        if not line.startswith('libraries:'):
            continue
        # partition on the first '=' only, so paths containing '=' survive intact
        _, separator, paths = line.partition('=')
        if not separator:
            return None
        return tuple(path for path in paths.split(':') if path)
    return None
| 75 | |
| 76 | |
def find_includes(armnn_include_env: str = 'ARMNN_INCLUDE'):
    """
    Return the directories to search for ArmNN headers.

    Args:
        armnn_include_env: environment variable that may hold a custom ArmNN include path.

    Returns:
        A one-element list with that variable's value when it is set and non-empty,
        otherwise the default system include directories.
    """
    custom_dir = os.environ.get(armnn_include_env)
    if not custom_dir:
        return ['/usr/local/include', '/usr/include']
    return [custom_dir]
| 80 | |
| 81 | |
# Cache per argument combination: the search globs the file system and is re-queried every
# time a LazyArmnnFinderExtension property is read, alternating between library names, so a
# single-entry cache would be evicted constantly. The key space (a handful of libs) is small.
@lru_cache(maxsize=32)
def find_armnn(lib_name: str,
               optional: bool = False,
               armnn_libs_env: str = 'ARMNN_LIB',
               default_lib_search: tuple = linux_gcc_lib_search()):
    """
    Searches for ArmNN installation on the local machine.

    Results are cached, so the same list objects are returned on every cache hit; callers
    must not mutate them.

    Args:
        lib_name: lib name to find; used as a glob pattern inside each search directory
        optional: Do not fail if optional. Default is False - fail if library was not found.
        armnn_libs_env: custom environment variable pointing to ArmNN libraries location,
            default is 'ARMNN_LIB'
        default_lib_search: paths to search for ArmNN when the `armnn_libs_env` variable is
            unset or empty. Defaults to the gcc library search directories, computed once when
            this function is defined; may be None if gcc could not be queried, in which case
            nothing is searched.

    Returns:
        tuple of (list of linker library names in ':<file name>' form,
                  list of unique directories containing those libraries)

    Raises:
        RuntimeError: if no matching library was found and `optional` is False.
    """
    armnn_lib_path = os.getenv(armnn_libs_env, "")

    # `or ()` keeps a None default (gcc unavailable) from crashing the iteration below
    lib_search = [armnn_lib_path] if armnn_lib_path else (default_lib_search or ())

    # linker name ':<file name>' -> found path; a later match of the same file name overrides
    armnn_libs = {}
    for search_dir in lib_search:
        for lib_file in Path(search_dir).glob(lib_name):
            armnn_libs[':{}'.format(lib_file.name)] = lib_file

    if not optional and not armnn_libs:
        raise RuntimeError("""ArmNN library {} was not found in {}. Please install ArmNN to one of the standard
locations or set correct ARMNN_INCLUDE and ARMNN_LIB env variables.""".format(lib_name,
                                                                             lib_search))

    # Unique library directories in discovery order (deterministic, unlike a set).
    lib_dirs = []
    for lib_file in armnn_libs.values():
        lib_dir = str(lib_file.absolute().parent)
        if lib_dir not in lib_dirs:
            lib_dirs.append(lib_dir)

    return list(armnn_libs.keys()), lib_dirs
| 117 | |
| 118 | |
class LazyArmnnFinderExtension(Extension):
    """
    Derived from `Extension` this class adds ArmNN libraries search on the user's machine.
    Compilation and link settings are extended with the ArmNN header locations (include_dirs),
    library locations (library_dirs, runtime_library_dirs) and library names (libraries).

    The ArmNN search is executed only when the attributes include_dirs, library_dirs,
    runtime_library_dirs or libraries are read, not when the extension is declared.
    """

    def __init__(self, name, sources, armnn_libs, include_dirs=None, define_macros=None, undef_macros=None,
                 library_dirs=None,
                 libraries=None, runtime_library_dirs=None, extra_objects=None, extra_compile_args=None,
                 extra_link_args=None, export_symbols=None, language=None, optional=None, **kw):
        """
        Args:
            name: full dotted module name of the extension.
            sources: source files to compile.
            armnn_libs: ArmNN library file names (glob patterns for find_armnn) to link against.
            Remaining arguments are forwarded to `Extension`.
        """
        # Backing fields for the lazy properties below; the base constructor assigns the
        # public attributes, which routes through the property setters.
        self._include_dirs = None
        self._library_dirs = None
        self._runtime_library_dirs = None
        self._libraries = None
        self._armnn_libs = armnn_libs
        # Forward by keyword: the base Extension declares swig_opts and depends between
        # export_symbols and language, so positional forwarding would bind `language` to
        # swig_opts and `optional` to depends.
        super().__init__(name, sources,
                         include_dirs=include_dirs,
                         define_macros=define_macros,
                         undef_macros=undef_macros,
                         library_dirs=library_dirs,
                         libraries=libraries,
                         runtime_library_dirs=runtime_library_dirs,
                         extra_objects=extra_objects,
                         extra_compile_args=extra_compile_args,
                         extra_link_args=extra_link_args,
                         export_symbols=export_symbols,
                         language=language,
                         optional=optional,
                         **kw)

    def _armnn_lib_dirs(self):
        """Return the unique directories (in discovery order) holding this extension's ArmNN libraries."""
        lib_dirs = []
        for lib in self._armnn_libs:
            _, lib_paths = find_armnn(lib)
            for lib_path in lib_paths:
                if lib_path not in lib_dirs:
                    lib_dirs.append(lib_path)
        return lib_dirs

    @property
    def include_dirs(self):
        return self._include_dirs + find_includes()

    @include_dirs.setter
    def include_dirs(self, include_dirs):
        self._include_dirs = include_dirs

    @property
    def library_dirs(self):
        return self._library_dirs + self._armnn_lib_dirs()

    @library_dirs.setter
    def library_dirs(self, library_dirs):
        self._library_dirs = library_dirs

    @property
    def runtime_library_dirs(self):
        return self._runtime_library_dirs + self._armnn_lib_dirs()

    @runtime_library_dirs.setter
    def runtime_library_dirs(self, runtime_library_dirs):
        self._runtime_library_dirs = runtime_library_dirs

    @property
    def libraries(self):
        libraries = self._libraries
        for lib in self._armnn_libs:
            lib_names, _ = find_armnn(lib)
            libraries = libraries + lib_names
        return libraries

    @libraries.setter
    def libraries(self, libraries):
        self._libraries = libraries

    # Identity is the module name, so failed extensions can be located in a list by equality.
    def __eq__(self, other):
        return self.__class__ == other.__class__ and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return self.name.__hash__()
| 197 | |
if __name__ == '__main__':
    def _armnn_cpp_extension(module_name, wrapper_source, armnn_libs):
        """Declare one C++14 wrapper extension linking against the given ArmNN libraries."""
        return LazyArmnnFinderExtension(module_name,
                                        sources=[wrapper_source],
                                        extra_compile_args=['-std=c++14'],
                                        language='c++',
                                        armnn_libs=armnn_libs)

    # mandatory extensions (version module listed first)
    extensions_to_build = [
        _armnn_cpp_extension('pyarmnn._generated._pyarmnn_version',
                             'src/pyarmnn/_generated/armnn_version_wrap.cpp',
                             ['libarmnn.so']),
        _armnn_cpp_extension('pyarmnn._generated._pyarmnn',
                             'src/pyarmnn/_generated/armnn_wrap.cpp',
                             ['libarmnn.so']),
    ]

    # optional extensions: one wrapper per ArmNN parser, each also linking libarmnn<Parser>.so
    for parser_name in ('CaffeParser', 'OnnxParser', 'TfParser', 'TfLiteParser'):
        module_suffix = parser_name.lower()
        extensions_to_build.append(_armnn_cpp_extension(
            'pyarmnn._generated._pyarmnn_{}'.format(module_suffix),
            'src/pyarmnn/_generated/armnn_{}_wrap.cpp'.format(module_suffix),
            ['libarmnn.so', 'libarmnn{}.so'.format(parser_name)]))

    setup(
        name='pyarmnn',
        version=__version__,
        author='Arm ltd',
        description='Arm NN python wrapper',
        url='https://www.arm.com',
        license='MIT',
        package_dir={'': 'src'},
        packages=[
            'pyarmnn',
            'pyarmnn._generated',
            'pyarmnn._quantization',
            'pyarmnn._tensor',
            'pyarmnn._utilities'
        ],
        data_files=[('', ['licences.txt'])],
        python_requires='>=3.5',
        install_requires=['numpy'],
        cmdclass={'build_py': ExtensionPriorityBuilder, 'build_ext': ArmnnVersionCheckerExtBuilder},
        ext_modules=extensions_to_build
    )