blob: a62cd227b367544087105c5aa0b836643734588f [file] [log] [blame]
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import argparse
import os
import pathlib
import setuptools
import setuptools.command.build_ext
import shutil
import sys
TOSA_CHECKER_VERSION = "0.1.0"
# TensorFlow release whose sources are cloned (git tag "v<version>") when no
# --tensorflow_src_dir is given on the command line.
TENSORFLOW_VERSION = "2.9.0"

# Get the TensorFlow™ source directory passed on the command line (if any).
# If none is given, the sources are cloned from the official TensorFlow
# repository.
#
# add_help=False: this parser only extracts its own option and forwards every
# other argument to setuptools, so "--help"/"-h" must reach setuptools instead
# of making this parser print its own help and exit.
# allow_abbrev=False: never treat a prefix such as "--tensorflow" as an
# abbreviation of --tensorflow_src_dir; unknown options are forwarded as-is.
argparser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
argparser.add_argument(
    "--tensorflow_src_dir", help="TensorFlow source directory path", required=False
)
args, unknown = argparser.parse_known_args()
# Remove the option consumed above so setuptools only sees its own arguments.
sys.argv = [sys.argv[0]] + unknown
class BazelExtensionModule(setuptools.Extension):
    """A setuptools extension whose shared library is produced by Bazel.

    No sources are handed to setuptools. The extra attributes stored here are
    read by ``BazelBuildExtension.build_extension`` to drive the Bazel build
    and to copy its output into place.
    """

    def __init__(
        self,
        py_module_name,
        library_name,
        bazel_target,
        bazel_shared_lib_output,
        tensorflow_version,
    ):
        """Record the Bazel build settings for one extension module.

        Args:
            py_module_name: Extension name given to setuptools; it also
                determines where the built library is copied to.
            library_name: Python package directory whose ``__init__.py`` is
                copied next to the built library.
            bazel_target: Bazel label to build.
            bazel_shared_lib_output: Path of the shared library Bazel emits.
            tensorflow_version: TensorFlow release cloned (tag
                ``v<version>``) when no local source directory is supplied.
        """
        super().__init__(py_module_name, sources=[])
        # Settings consumed by the Bazel invocation itself.
        self.bazel_target = bazel_target
        self.bazel_shared_lib_output = bazel_shared_lib_output
        # Settings used for packaging and for fetching TensorFlow.
        self.library_name = library_name
        self.tensorflow_version = tensorflow_version
class BazelBuildExtension(setuptools.command.build_ext.build_ext):
    """Override build_extension to build the library with Bazel and copy the
    result into the build tree instead of compiling sources with setuptools."""

    def build_extension(self, ext):
        """Build ``ext`` with Bazel and install its outputs into the build tree.

        The TensorFlow sources come from ``--tensorflow_src_dir`` when it was
        given on the command line; otherwise they are cloned into
        ``<build_temp>/tensorflow`` at ``ext.tensorflow_version``.

        Args:
            ext: The ``BazelExtensionModule`` to build.
        """
        tensorflow_src_dir = args.tensorflow_src_dir
        if not tensorflow_src_dir:
            tensorflow_src_dir = os.path.join(self.build_temp, "tensorflow")
            self._clone_tf_repository(
                tensorflow_src_dir,
                ext.tensorflow_version,
            )
        self._run_bazel_build(ext, tensorflow_src_dir)
        self._copy_build_outputs(ext)
        # Deliberately no super().build_extension(ext): the extension has no
        # sources, so the base implementation has nothing to compile. It only
        # returns early because the copied library looks up to date; under
        # --force it would instead try to compile and link an empty object
        # list onto the library that was just copied from Bazel's output.

    def _run_bazel_build(self, ext, tensorflow_src_dir):
        """Run an optimized ``bazel build`` of ``ext.bazel_target`` against the
        TensorFlow sources in ``tensorflow_src_dir``."""
        self.spawn(
            [
                "bazel",
                "build",
                "-c",
                "opt",
                # FIXME Some of the Bazel targets dependencies we use have
                # a 'friends' visibility, check if our Bazel target can be added
                # to the 'friends' list.
                "--check_visibility=false",
                # Point the org_tensorflow external repository at the local
                # TensorFlow checkout.
                "--override_repository=org_tensorflow="
                + os.path.abspath(tensorflow_src_dir),
                ext.bazel_target,
            ]
        )

    def _copy_build_outputs(self, ext):
        """Copy Bazel's shared library and the package ``__init__.py`` to the
        location setuptools expects for ``ext``.

        When ``ext.bazel_shared_lib_output`` and ``ext.library_name`` are
        relative, they are resolved against the current working directory
        (the directory Bazel was run from).
        """
        shared_lib_dest_path = self.get_ext_fullpath(ext.name)
        shared_lib_dest_dir = os.path.dirname(shared_lib_dest_path)
        package_dir = os.path.join(shared_lib_dest_dir, ext.library_name)
        os.makedirs(shared_lib_dest_dir, exist_ok=True)
        os.makedirs(package_dir, exist_ok=True)
        # copyfile copies the contents only, not the permission bits.
        # NOTE(review): presumably chosen because Bazel outputs are read-only,
        # which would make a later rebuild fail to overwrite the copy if the
        # mode were preserved; confirm before switching to shutil.copy/copy2.
        shutil.copyfile(ext.bazel_shared_lib_output, shared_lib_dest_path)
        shutil.copy(os.path.join(ext.library_name, "__init__.py"), package_dir)

    def _clone_tf_repository(self, tensorflow_src_dir, tensorflow_version):
        """Shallow-clone TensorFlow tag ``v<tensorflow_version>`` into
        ``tensorflow_src_dir``.

        An existing ``tensorflow_src_dir`` is reused as-is: neither its
        version nor the completeness of an earlier clone is checked.
        """
        if os.path.exists(tensorflow_src_dir):
            return
        tensorflow_repo = "https://github.com/tensorflow/tensorflow.git"
        self.spawn(
            [
                "git",
                "clone",
                "--depth=1",
                "--branch",
                "v" + tensorflow_version,
                tensorflow_repo,
                tensorflow_src_dir,
            ]
        )
setuptools.setup(
    name="tosa-checker",
    version=TOSA_CHECKER_VERSION,
    description="Tool to check if a ML model is compatible with the TOSA specification",
    # Read the README explicitly as UTF-8: Path.read_text() otherwise uses the
    # locale's preferred encoding and can fail on non-ASCII characters (for
    # example on Windows or under a C/POSIX locale).
    long_description=(pathlib.Path(__file__).parent / "README.md").read_text(
        encoding="utf-8"
    ),
    long_description_content_type="text/markdown",
    author="Arm Limited",
    url="https://git.mlplatform.org/tosa/tosa_checker.git/",
    license="Apache-2.0",
    python_requires=">=3.7",
    # The extension is built by Bazel via BazelBuildExtension rather than by
    # setuptools' own compiler.
    cmdclass={"build_ext": BazelBuildExtension},
    ext_modules=[
        BazelExtensionModule(
            py_module_name="_tosa_checker_wrapper",
            library_name="tosa_checker",
            bazel_target="//tosa_checker:tosa_checker",
            bazel_shared_lib_output="bazel-bin/tosa_checker/_tosa_checker_wrapper.so",
            tensorflow_version=TENSORFLOW_VERSION,
        ),
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3",
        "Topic :: Utilities",
    ],
)