blob: 8f6c4670ebab552a393d2495b651d36f93757af0 [file] [log] [blame]
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import argparse
import git
import os
import re
import setuptools
import setuptools.command.build_ext
import shutil
import sys
import time
TOSA_CHECKER_VERSION = "0.2.0"
DEFAULT_TENSORFLOW_VERSION = "2.13.0"

# Options understood by this script itself. Anything argparse does not
# recognise is handed back to setuptools through sys.argv at the end of this
# block. When --tensorflow_src_dir is not given, the TensorFlow™ sources are
# cloned from the official TensorFlow repository at build time.
_SCRIPT_OPTIONS = (
    (
        "--tensorflow_src_dir",
        {"help": "TensorFlow source directory path", "required": False},
    ),
    (
        "--sanitizer",
        {
            "help": "Build using a sanitizer (choose from asan or ubsan)",
            "choices": ["asan", "ubsan"],
            "required": False,
        },
    ),
    (
        "--tosa_checker_copt",
        {
            "help": "Build tosa_checker with additional copt (comma separated string)",
            "default": "",
            "required": False,
        },
    ),
    (
        "--nightly",
        {"help": "Build tosa_checker as a nightly wheel", "action": "store_true"},
    ),
)

argparser = argparse.ArgumentParser()
for _flag, _settings in _SCRIPT_OPTIONS:
    argparser.add_argument(_flag, **_settings)
del _flag, _settings

# Keep only the arguments this script did not consume, so setuptools never
# sees our custom options.
args, unknown = argparser.parse_known_args()
sys.argv = [sys.argv[0], *unknown]
def increment_version(version_string):
    """Return *version_string* with its minor component bumped by one.

    Every component after the minor one is reset to ``0``, so the result names
    the next minor release: ``"0.2.0"`` -> ``"0.3.0"``, ``"0.2.5"`` -> ``"0.3.0"``.

    Args:
        version_string: Dot-separated version such as ``"0.2.0"``; the second
            component must be an integer.

    Returns:
        The incremented version, with the same number of components.

    Raises:
        ValueError: If the version has no minor component, or the minor
            component is not an integer.
    """
    major, *rest = version_string.split(".")
    if not rest:
        raise ValueError(
            "Version '{}' has no minor component to increment".format(version_string)
        )
    minor, *trailing = rest
    # Components below the minor level restart from zero for the next release.
    return ".".join([major, str(int(minor) + 1)] + ["0"] * len(trailing))
def get_package_version(nightly=False):
    """Return the version string the package is published under.

    Release builds use TOSA_CHECKER_VERSION unchanged. Nightly builds use the
    version produced by increment_version(TOSA_CHECKER_VERSION) followed by a
    ``.dev<YYYYMMDD>`` suffix taken from the current local date.
    """
    if not nightly:
        return TOSA_CHECKER_VERSION
    next_version = increment_version(TOSA_CHECKER_VERSION)
    datestamp = time.strftime("%Y%m%d")
    return f"{next_version}.dev{datestamp}"
def get_repo_version(repo_directory):
    """Describe the checked-out revision of the git repository at *repo_directory*.

    Returns:
        A tag name reported by ``git tag --points-at`` when there is one,
        otherwise the full SHA-1 of the ``HEAD`` commit. If several tags are
        reported, only the first one listed is returned. The result is
        therefore always a single line, which matters because callers embed it
        in generated Python source.
    """
    repo = git.repo.Repo(repo_directory)
    try:
        # `--points-at` without an object lists tags pointing at HEAD.
        tags = repo.git.tag("--points-at").splitlines()
        if tags:
            return tags[0]
        return repo.head.commit.hexsha
    finally:
        # Release resources held by the Repo handle.
        repo.close()
class BazelExtensionModule(setuptools.Extension):
    """A setuptools extension whose shared library is produced by Bazel.

    ``sources`` is left empty because setuptools does not compile anything for
    it. The extra attributes describe what BazelBuildExtension should build
    and where to find the result.
    """

    def __init__(
        self,
        py_module_name,
        library_name,
        bazel_target,
        bazel_shared_lib_output,
        tensorflow_version,
    ):
        super().__init__(py_module_name, sources=[])
        # library_name: package directory whose __init__.py is regenerated.
        # bazel_target: label passed to `bazel build`.
        # bazel_shared_lib_output: path of the library Bazel produces.
        # tensorflow_version: TensorFlow release cloned as tag "v<version>".
        (
            self.library_name,
            self.bazel_target,
            self.bazel_shared_lib_output,
            self.tensorflow_version,
        ) = (library_name, bazel_target, bazel_shared_lib_output, tensorflow_version)
class BazelBuildExtension(setuptools.command.build_ext.build_ext):
    """``build_ext`` command that builds each extension with Bazel.

    Rather than compiling sources itself, the command does three things:

    * runs ``bazel build`` for the extension's Bazel target;
    * copies the resulting shared library to the path setuptools expects;
    * writes the package ``__init__.py``, with an added
      ``__tensorflow_version__``, into the build directory.
    """

    def build_extension(self, ext):
        """Build *ext* (a BazelExtensionModule) with Bazel and stage its outputs.

        The TensorFlow sources come from ``--tensorflow_src_dir`` when given.
        Otherwise tag ``v<ext.tensorflow_version>`` is cloned into
        ``<build_temp>/tensorflow``, unless that directory already exists.
        """
        tensorflow_src_dir = args.tensorflow_src_dir
        if not tensorflow_src_dir:
            tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
            self._clone_tf_repository(
                tensorflow_src_dir,
                ext.tensorflow_version,
            )

        self._run_bazel(ext, tensorflow_src_dir)
        self._install_outputs(ext, tensorflow_src_dir)
        # The base-class build_extension is intentionally not called. ext has
        # no sources, so it would either skip the extension as up to date or,
        # under --force, try to compile and link an empty object list on top
        # of the library copied above. Bazel has already built the artifact.

    def _run_bazel(self, ext, tensorflow_src_dir):
        """Run ``bazel build`` for ``ext.bazel_target`` against *tensorflow_src_dir*."""
        command = ["bazel", "build"]
        if args.sanitizer:
            command.append("--config={}".format(args.sanitizer))
        if args.tosa_checker_copt:
            command.append(
                "--per_file_copt=tosa_checker/tosa_checker.*@{}".format(
                    args.tosa_checker_copt
                )
            )
        command += [
            "-c",
            "opt",
            # FIXME Some of the Bazel targets dependencies we use have
            # a 'friends' visibility, check if our Bazel target can be added
            # to the 'friends' list.
            "--check_visibility=false",
            # Point the org_tensorflow repository at the chosen TF sources.
            "--override_repository=org_tensorflow={}".format(
                os.path.abspath(tensorflow_src_dir)
            ),
            ext.bazel_target,
        ]
        self.spawn(command)

    def _install_outputs(self, ext, tensorflow_src_dir):
        """Copy the Bazel-built library and write the package ``__init__.py``."""
        shared_lib_dest_path = self.get_ext_fullpath(ext.name)
        shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path)
        package_dir = os.path.join(shared_lib_dest_dir, ext.library_name)
        os.makedirs(shared_lib_dest_dir, exist_ok=True)
        os.makedirs(package_dir, exist_ok=True)
        shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)

        # Get the TensorFlow version this is built with.
        tf_version = get_repo_version(tensorflow_src_dir)
        # NOTE(review): the source __init__.py (like the Bazel output path) is
        # resolved relative to the current working directory, which is assumed
        # to be the project root.
        source_init = os.path.join(ext.library_name, "__init__.py")
        with open(source_init, "r", encoding="utf-8") as f:
            module_init_file = f.read()
        # Make sure the appended assignment starts on its own line.
        if module_init_file and not module_init_file.endswith("\n"):
            module_init_file += "\n"
        module_init_file += "__tensorflow_version__ = \"{}\"\n".format(tf_version)
        with open(os.path.join(package_dir, "__init__.py"), "w", encoding="utf-8") as f:
            f.write(module_init_file)

    def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
        """Shallow-clone TensorFlow tag ``v<tensorflow_version>`` into *tensorflow_src_dir*.

        Does nothing if *tensorflow_src_dir* already exists. Its contents are
        not checked.
        """
        if os.path.exists(tensorflow_src_dir):
            return
        tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
        self.spawn(
            [
                "git",
                "clone",
                "--depth=1",
                "--branch",
                "v" + tensorflow_version,
                tensorflow_repo,
                tensorflow_src_dir,
            ]
        )
def get_long_description(nightly):
    """Return the contents of README.md for use as the package long description.

    For release (non-nightly) builds, each markdown link ``[text](link)``
    whose target exists as a path relative to this script's directory is
    rewritten. The new target is an absolute link into the tagged source tree
    on review.mlplatform.org. Other links are left unchanged.
    """
    # Read the contents of README.md file
    this_directory = os.path.abspath(os.path.dirname(__file__))
    with open(os.path.join(this_directory, "README.md"), encoding="utf-8") as f:
        long_description = f.read()
    if nightly:
        return long_description

    url = f"https://review.mlplatform.org/plugins/gitiles/tosa/tosa_checker/+/refs/tags/{TOSA_CHECKER_VERSION}/"

    def _absolutize(match):
        # group(1) is "[text]", group(2) is the link target.
        text, link = match.group(1), match.group(2)
        if os.path.exists(os.path.join(this_directory, link)):
            return f"{text}({url}{link})"
        return match.group(0)

    # Only the link target of each "[text](link)" is rewritten. Parentheses
    # inside the link text are left alone, and the replacement is built
    # directly so backslashes in a link are not treated as regex escapes.
    return re.sub(r"(\[.+?\])\((.+?)\)", _absolutize, long_description)
# Package metadata is computed up front (version first, then the README-based
# description) and handed to setuptools in a single call below.
_package_version = get_package_version(args.nightly)
_long_description = get_long_description(args.nightly)

# The only native extension, built through BazelBuildExtension.
_wrapper_extension = BazelExtensionModule(
    py_module_name="_tosa_checker_wrapper",
    library_name="tosa_checker",
    bazel_target="//tosa_checker:tosa_checker",
    bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
    tensorflow_version=DEFAULT_TENSORFLOW_VERSION,
)

_CLASSIFIERS = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3",
    "Topic :: Utilities",
]

setuptools.setup(
    name="tosa-checker",
    version=_package_version,
    description="Tool to check if a ML model is compatible with the TOSA specification",
    long_description=_long_description,
    long_description_content_type="text/markdown",
    author="Arm Limited",
    url="https://git.mlplatform.org/tosa/tosa_checker.git/",
    license="Apache-2.0",
    python_requires=">=3.8",
    cmdclass={"build_ext": BazelBuildExtension},
    ext_modules=[_wrapper_extension],
    classifiers=_CLASSIFIERS,
)