| # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. |
| # SPDX-License-Identifier: Apache-2.0 |
| import argparse |
| import os |
| import pathlib |
| import setuptools |
| import setuptools.command.build_ext |
| import shutil |
| import sys |
| |
| |
# Version of the tosa-checker distribution, passed to setuptools.setup().
TOSA_CHECKER_VERSION: str = "0.1.0"
# TensorFlow release whose git tag ("v" + this version) is cloned when no
# --tensorflow_src_dir is supplied on the command line.
TENSORFLOW_VERSION: str = "2.9.0"
| |
# Get the TensorFlow™ source directory passed on the command line (if any).
# If none is given, the sources are cloned from the official TensorFlow repository.
# add_help=False leaves "-h"/"--help" to setuptools: with the default
# add_help=True this parser would print its own one-option usage and exit,
# hiding the setuptools help entirely. allow_abbrev=False accepts only the exact
# option name, so every other flag is forwarded to setuptools untouched.
argparser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
argparser.add_argument(
    "--tensorflow_src_dir", help="TensorFlow source directory path", required=False
)
args, unknown = argparser.parse_known_args()
# Hand every argument this parser did not recognise back to setuptools.
sys.argv = [sys.argv[0]] + unknown
| |
| |
class BazelExtensionModule(setuptools.Extension):
    """setuptools Extension whose shared library is produced by Bazel.

    The extension is declared with an empty source list; instead it carries
    the extra metadata needed by the custom ``build_ext`` command.

    Args:
        py_module_name: Importable name of the extension module.
        library_name: Name of the Python package directory whose
            ``__init__.py`` is copied next to the built library.
        bazel_target: Bazel label to build (e.g. ``//pkg:target``).
        bazel_shared_lib_output: Path of the shared library Bazel produces
            for ``bazel_target``.
        tensorflow_version: TensorFlow version used to pick the git tag to
            clone when no TensorFlow source directory is supplied.
    """

    def __init__(
        self,
        py_module_name,
        library_name,
        bazel_target,
        bazel_shared_lib_output,
        tensorflow_version,
    ):
        # Nothing for setuptools to compile: Bazel builds the library.
        super().__init__(py_module_name, sources=[])
        (
            self.library_name,
            self.bazel_target,
            self.bazel_shared_lib_output,
            self.tensorflow_version,
        ) = (library_name, bazel_target, bazel_shared_lib_output, tensorflow_version)
| |
| |
class BazelBuildExtension(setuptools.command.build_ext.build_ext):
    """``build_ext`` command that builds extensions with Bazel.

    For each extension, the Bazel target is built against a TensorFlow source
    tree. That tree is the ``--tensorflow_src_dir`` command-line value, or a
    shallow clone placed under ``build_temp`` otherwise. The shared library
    Bazel produced is then copied to the path setuptools expects for the
    extension. The package's ``__init__.py`` is copied alongside it.
    """

    def build_extension(self, ext):
        tensorflow_src_dir = args.tensorflow_src_dir
        if not tensorflow_src_dir:
            tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
            self._clone_tf_repository(
                tensorflow_src_dir,
                ext.tensorflow_version,
            )

        self._run_bazel_build(ext, tensorflow_src_dir)
        self._copy_build_outputs(ext)

        # The base class build_extension() is deliberately not called. The
        # extension has no sources, so it would either skip the (now
        # up-to-date) output or, with --force, link an empty object list over
        # the library that was just copied from Bazel.

    def _run_bazel_build(self, ext, tensorflow_src_dir):
        """Build ``ext.bazel_target`` in opt mode against the given TF tree."""
        self.spawn(
            [
                "bazel",
                "build",
                "-c",
                "opt",
                # FIXME Some of the Bazel targets dependencies we use have
                # a 'friends' visibility, check if our Bazel target can be added
                # to the 'friends' list.
                "--check_visibility=false",
                "--override_repository=org_tensorflow="
                + os.path.abspath(tensorflow_src_dir),
                ext.bazel_target,
            ]
        )

    def _copy_build_outputs(self, ext):
        """Copy the Bazel library and the package ``__init__.py`` into place."""
        shared_lib_dest_path = self.get_ext_fullpath(ext.name)
        package_dir = os.path.join(
            os.path.dirname(shared_lib_dest_path), ext.library_name
        )
        # Also creates the parent directory that receives the shared library.
        os.makedirs(package_dir, exist_ok=True)

        shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)
        try:
            shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir)
        except shutil.SameFileError:
            # For in-place builds (e.g. ``build_ext --inplace``) package_dir can
            # be the source package directory itself; the file is already there.
            pass

    def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
        """Shallow-clone TensorFlow tag ``v<tensorflow_version>``.

        Does nothing if ``tensorflow_src_dir`` already exists. The existing
        contents are not validated, so a stale or partial clone is reused
        as-is.
        """
        if os.path.exists(tensorflow_src_dir):
            return

        tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
        self.spawn(
            [
                "git",
                "clone",
                "--depth=1",
                "--branch",
                "v" + tensorflow_version,
                tensorflow_repo,
                tensorflow_src_dir,
            ]
        )
| |
| |
setuptools.setup(
    name="tosa-checker",
    version=TOSA_CHECKER_VERSION,
    description="Tool to check if a ML model is compatible with the TOSA specification",
    # Read the README explicitly as UTF-8. The locale default encoding can
    # fail, or silently produce mojibake, on non-UTF-8 systems.
    long_description=(pathlib.Path(__file__).parent / "README.md").read_text(
        encoding="utf-8"
    ),
    long_description_content_type="text/markdown",
    author="Arm Limited",
    url="https://git.mlplatform.org/tosa/tosa_checker.git/",
    license="Apache-2.0",
    python_requires=">=3.7",
    # Extensions are built by Bazel through the custom build_ext command.
    cmdclass={"build_ext": BazelBuildExtension},
    ext_modules=[
        BazelExtensionModule(
            py_module_name="_tosa_checker_wrapper",
            library_name="tosa_checker",
            bazel_target="//tosa_checker:tosa_checker",
            bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
            tensorflow_version=TENSORFLOW_VERSION,
        ),
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Topic :: Utilities",
    ],
)