| # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. |
| # SPDX-License-Identifier: Apache-2.0 |
| import argparse |
| import git |
| import os |
| import re |
| import setuptools |
| import setuptools.command.build_ext |
| import shutil |
| import sys |
| import time |
| |
| |
TOSA_CHECKER_VERSION = "0.2.0"
DEFAULT_TENSORFLOW_VERSION = "2.13.0"

# Options specific to this build, parsed before setuptools sees the command
# line. Anything not recognised here (setuptools commands and their options)
# is put back into sys.argv for setuptools.setup() to handle.
#
# --tensorflow_src_dir: the TensorFlow™ source directory to build against. If
# none is given, the sources are pulled from the official TF repository.
#
# add_help=False: otherwise "-h"/"--help" would be consumed here, printing only
# these options and exiting, so setuptools' own help (e.g. "build_ext --help")
# could never be shown.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument(
    "--tensorflow_src_dir", help="TensorFlow source directory path", required=False
)
argparser.add_argument(
    "--sanitizer",
    help="Build using a sanitizer (choose from asan or ubsan)",
    choices=["asan", "ubsan"],
    required=False,
)
argparser.add_argument(
    "--tosa_checker_copt",
    help="Build tosa_checker with additional copt (comma separated string)",
    default="",
    required=False,
)
argparser.add_argument(
    "--nightly",
    help="Build tosa_checker as a nightly wheel",
    action="store_true",
)
args, unknown = argparser.parse_known_args()
# Leave only the arguments setuptools should process.
sys.argv = [sys.argv[0]] + unknown
| |
def increment_version(version_string):
    """Return *version_string* with its minor (second) component bumped by one.

    Components after the minor one are reset to ``0``, as for any new minor
    release, e.g. ``"0.2.1"`` -> ``"0.3.0"`` and ``"1.9"`` -> ``"1.10"``.

    Raises:
        ValueError: if *version_string* has fewer than two dot-separated
            components, or its minor component is not an integer.
    """
    parts = version_string.split(".")
    if len(parts) < 2:
        raise ValueError(
            f"expected a version with at least two components, got {version_string!r}"
        )

    parts[1] = str(int(parts[1]) + 1)
    # Without this reset, "0.2.1" would become "0.3.1", and a nightly of it
    # would sort after the 0.3.0 release it precedes.
    parts[2:] = ["0"] * (len(parts) - 2)

    return ".".join(parts)
| |
def get_package_version(nightly=False):
    """Return the version string to publish the package under.

    A release uses ``TOSA_CHECKER_VERSION`` unchanged. A nightly uses the next
    minor version (from ``increment_version``) as a development release
    stamped with today's local date, e.g. ``0.3.0.dev20230915``.
    """
    if not nightly:
        return TOSA_CHECKER_VERSION

    next_version = increment_version(TOSA_CHECKER_VERSION)
    build_date = time.strftime("%Y%m%d")  # local time, not UTC
    return f"{next_version}.dev{build_date}"
| |
def get_repo_version(repo_directory):
    """Describe the checked-out revision of the git repository at *repo_directory*.

    Returns the name of a tag pointing at HEAD if there is one, otherwise the
    full hex SHA of the HEAD commit.

    The result is embedded in a generated Python string literal, so it is
    always a single line. When several tags point at HEAD, ``git tag`` prints
    one per line and only the first, in git's listing order, is returned.
    """
    r = git.repo.Repo(repo_directory)
    # Name HEAD explicitly rather than relying on the implicit default of
    # --points-at, which older git versions do not provide.
    tags = r.git.tag("--points-at", "HEAD").splitlines()

    if tags:
        return tags[0]
    return r.head.commit.hexsha
| |
class BazelExtensionModule(setuptools.Extension):
    """Extension module whose shared library is produced by Bazel.

    It is declared with an empty source list. The extra attributes tell the
    build command which Bazel target to build, where the built library
    appears, and which TensorFlow release to build against.
    """

    def __init__(
        self,
        py_module_name,
        library_name,
        bazel_target,
        bazel_shared_lib_output,
        tensorflow_version,
    ):
        # No sources: compilation is delegated to Bazel.
        super().__init__(py_module_name, sources=[])

        # Name of the Python package directory shipped next to the library.
        self.library_name = library_name
        # Bazel label to build.
        self.bazel_target = bazel_target
        # Path of the shared library Bazel produces for bazel_target.
        self.bazel_shared_lib_output = bazel_shared_lib_output
        # TensorFlow version to clone when no source directory is supplied.
        self.tensorflow_version = tensorflow_version
| |
| |
class BazelBuildExtension(setuptools.command.build_ext.build_ext):
    """Override build_extension to build the library with Bazel and copy it
    into place beforehand.

    For each ``BazelExtensionModule`` this:

    1. uses the TensorFlow tree given by ``--tensorflow_src_dir``, or if none
       was given, shallow-clones the extension's TensorFlow release into the
       build temp directory;
    2. runs ``bazel build`` on the extension's Bazel target;
    3. copies the Bazel-built shared library to the path setuptools expects;
    4. writes the package ``__init__.py`` next to it, with a
       ``__tensorflow_version__`` assignment appended;
    5. defers to the stock ``build_extension``.
    """

    def build_extension(self, ext):
        tensorflow_src_dir = args.tensorflow_src_dir

        if not tensorflow_src_dir:
            tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
            self._clone_tf_repository(
                tensorflow_src_dir,
                ext.tensorflow_version,
            )

        self.spawn(self._bazel_build_command(ext, tensorflow_src_dir))

        shared_lib_dest_path = self.get_ext_fullpath(ext.name)
        shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path)
        package_dir = os.path.join(shared_lib_dest_dir, ext.library_name)

        os.makedirs(shared_lib_dest_dir, exist_ok=True)
        os.makedirs(package_dir, exist_ok=True)

        # bazel_shared_lib_output is resolved against the current working
        # directory (presumably the Bazel workspace root).
        shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)

        # Get the TensorFlow version this is built with.
        tf_version = get_repo_version(tensorflow_src_dir)
        self._write_package_init(ext, package_dir, tf_version)

        # NOTE(review): ext has no sources; presumably the base implementation
        # sees the freshly copied library as up to date and does nothing.
        # Confirm this, especially when building with --force.
        super().build_extension(ext)

    @staticmethod
    def _bazel_build_command(ext, tensorflow_src_dir):
        """Return the ``bazel build`` argv for *ext*'s target, honouring the
        --sanitizer and --tosa_checker_copt command-line options."""
        commands = ["bazel", "build"]

        if args.sanitizer:
            commands.append(f"--config={args.sanitizer}")
        if args.tosa_checker_copt:
            commands.append(
                f"--per_file_copt=tosa_checker/tosa_checker.*@{args.tosa_checker_copt}"
            )

        commands += [
            # FIXME Some of the Bazel targets dependencies we use have
            # a 'friends' visibility, check if our Bazel target can be added
            # to the 'friends' list.
            "-c",
            "opt",
            "--check_visibility=false",
            # Build against the chosen TensorFlow tree instead of the one the
            # workspace would otherwise fetch for org_tensorflow.
            f"--override_repository=org_tensorflow={os.path.abspath(tensorflow_src_dir)}",
            ext.bazel_target,
        ]
        return commands

    @staticmethod
    def _write_package_init(ext, package_dir, tf_version):
        """Copy ``<library_name>/__init__.py`` (relative to the current working
        directory) into *package_dir*, appending ``__tensorflow_version__``."""
        # Explicit UTF-8: the locale encoding may be ASCII (e.g. C locale) or
        # cp1252, and would garble or reject non-ASCII source text.
        with open(
            os.path.join(ext.library_name, "__init__.py"), "r", encoding="utf-8"
        ) as f:
            module_init_file = f.read()

        # repr() always yields a valid Python string literal, even if
        # tf_version contains quotes or newlines.
        module_init_file += f"__tensorflow_version__ = {tf_version!r}\n"

        with open(os.path.join(package_dir, "__init__.py"), "w", encoding="utf-8") as f:
            f.write(module_init_file)

    def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
        """Shallow-clone tag ``v<tensorflow_version>`` of the official
        TensorFlow repository into *tensorflow_src_dir*.

        Does nothing if *tensorflow_src_dir* already exists (e.g. from an
        earlier build). Its contents are not checked.
        """
        if os.path.exists(tensorflow_src_dir):
            return

        tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
        self.spawn(
            [
                "git",
                "clone",
                "--depth=1",
                "--branch",
                "v" + tensorflow_version,
                tensorflow_repo,
                tensorflow_src_dir,
            ]
        )
| |
| |
def get_long_description(nightly):
    """Return the contents of README.md for use as the package long description.

    For release builds (``nightly`` false), each markdown link ``[text](link)``
    whose ``link`` names a path that exists next to this script is rewritten to
    point into the ``TOSA_CHECKER_VERSION`` tag on review.mlplatform.org, so
    the link still works when the README is rendered away from the repository.
    Nightly builds return the README unchanged (presumably because no matching
    tag exists for them).
    """
    # Read the contents of README.md file
    this_directory = os.path.abspath(os.path.dirname(__file__))
    with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
        long_description = f.read()

    if nightly:
        return long_description

    url = f"https://review.mlplatform.org/plugins/gitiles/tosa/tosa_checker/+/refs/tags/{TOSA_CHECKER_VERSION}/"

    def absolutize(match):
        # Rebuild the link from its captured parts so only the target is
        # changed, even if the link text itself contains parentheses.
        text, link = match.group(1), match.group(2)
        if os.path.exists(os.path.join(this_directory, link)):
            return f"[{text}]({url}{link})"
        return match.group(0)

    # A callable replacement is inserted literally: backslashes or group
    # references in a link are not interpreted as they would be in a
    # replacement string.
    return re.sub(r"\[(.+?)\]\((.+?)\)", absolutize, long_description)
| |
| |
# Package metadata and build configuration. The single extension module is
# produced by Bazel through the custom build_ext command, not compiled from
# listed sources.
setuptools.setup(
    name="tosa-checker",
    version=get_package_version(args.nightly),
    author="Arm Limited",
    license="Apache-2.0",
    url="https://git.mlplatform.org/tosa/tosa_checker.git/",
    description="Tool to check if a ML model is compatible with the TOSA specification",
    long_description=get_long_description(args.nightly),
    long_description_content_type="text/markdown",
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Topic :: Utilities",
    ],
    python_requires=">=3.8",
    cmdclass=dict(build_ext=BazelBuildExtension),
    ext_modules=[
        BazelExtensionModule(
            py_module_name="_tosa_checker_wrapper",
            library_name="tosa_checker",
            bazel_target="//tosa_checker:tosa_checker",
            bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
            tensorflow_version=DEFAULT_TENSORFLOW_VERSION,
        ),
    ],
)