|
| 1 | +import os |
| 2 | +import subprocess |
1 | 3 | import sys
|
2 |
| -from setuptools import setup, Extension |
| 4 | + |
| 5 | +from setuptools import Extension, setup |
| 6 | +from setuptools.command.bdist_egg import bdist_egg |
3 | 7 | from setuptools.command.build_ext import build_ext
|
| 8 | +from wheel.bdist_wheel import bdist_wheel |
4 | 9 |
|
5 | 10 |
|
6 |
| -class CustomBuildExt(build_ext): |
7 |
| - def build_extensions(self): |
8 |
| - # Import torch here, when we actually need it |
9 |
| - from torch.utils.cpp_extension import include_paths, library_paths |
| 11 | +ROOT = os.path.realpath(os.path.dirname(__file__)) |
10 | 12 |
|
11 |
| - # Update extension settings with torch-specific information |
12 |
| - for ext in self.extensions: |
13 |
| - ext.include_dirs = include_paths() |
14 |
| - ext.library_dirs = library_paths() |
15 |
| - ext.libraries = ["c10", "torch", "torch_cpu"] |
16 | 13 |
|
17 |
| - super().build_extensions() |
| 14 | +class universal_wheel(bdist_wheel): |
| 15 | + # When building the wheel, the `wheel` package assumes that if we have a |
| 16 | + # binary extension then we are linking to `libpython.so`; and thus the wheel |
| 17 | + # is only usable with a single python version. This is not the case for |
| 18 | + # here, and the wheel will be compatible with any Python >=3.7. This is |
| 19 | + # tracked in https://github.com/pypa/wheel/issues/185, but until then we |
| 20 | + # manually override the wheel tag. |
| 21 | + def get_tag(self): |
| 22 | + tag = bdist_wheel.get_tag(self) |
| 23 | + # tag[2:] contains the os/arch tags, we want to keep them |
| 24 | + return ("py3", "none") + tag[2:] |
18 | 25 |
|
19 | 26 |
|
20 |
| -if __name__ == "__main__": |
21 |
| - # Basic compilation settings |
22 |
| - extra_compile_args = ["-std=c++17"] |
23 |
| - extra_link_args = [] |
24 |
| - if sys.platform == "darwin": |
25 |
| - shared_lib_ext = ".dylib" |
26 |
| - extra_compile_args.append("-stdlib=libc++") |
27 |
| - extra_link_args.extend(["-stdlib=libc++", "-mmacosx-version-min=10.9"]) |
28 |
| - elif sys.platform == "linux": |
29 |
| - extra_compile_args.append("-fPIC") |
30 |
| - shared_lib_ext = ".so" |
31 |
| - else: |
32 |
| - raise RuntimeError(f"Unsupported platform {sys.platform}") |
33 |
| - |
34 |
| - neighbors_convert_extension = Extension( |
35 |
| - name="pet_neighbors_convert.neighbors_convert", |
36 |
| - sources=["src/pet_neighbors_convert/neighbors_convert.cpp"], |
37 |
| - language="c++", |
38 |
| - extra_compile_args=extra_compile_args, |
39 |
| - extra_link_args=extra_link_args, |
40 |
| - ) |
| 27 | +class cmake_ext(build_ext): |
| 28 | + """Build the native library using cmake""" |
| 29 | + |
| 30 | + def run(self): |
| 31 | + import torch |
| 32 | + |
| 33 | + torch_major, torch_minor, *_ = torch.__version__.split(".") |
| 34 | + |
| 35 | + source_dir = ROOT |
| 36 | + build_dir = os.path.join(ROOT, "build", "cmake-build") |
| 37 | + install_dir = os.path.join( |
| 38 | + os.path.realpath(self.build_lib), |
| 39 | + f"pet_neighbors_convert/torch-{torch_major}.{torch_minor}", |
| 40 | + ) |
| 41 | + |
| 42 | + os.makedirs(build_dir, exist_ok=True) |
| 43 | + |
| 44 | + cmake_options = [ |
| 45 | + "-DCMAKE_BUILD_TYPE=Release", |
| 46 | + f"-DCMAKE_INSTALL_PREFIX={install_dir}", |
| 47 | + f"-DPYTHON_EXECUTABLE={sys.executable}", |
| 48 | + ] |
| 49 | + |
| 50 | + if sys.platform.startswith("darwin"): |
| 51 | + cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0") |
41 | 52 |
|
| 53 | + # ARCHFLAGS is used by cibuildwheel to pass the requested arch to the |
| 54 | + # compilers |
| 55 | + ARCHFLAGS = os.environ.get("ARCHFLAGS") |
| 56 | + if ARCHFLAGS is not None: |
| 57 | + cmake_options.append(f"-DCMAKE_C_FLAGS={ARCHFLAGS}") |
| 58 | + cmake_options.append(f"-DCMAKE_CXX_FLAGS={ARCHFLAGS}") |
| 59 | + |
| 60 | + subprocess.run( |
| 61 | + ["cmake", source_dir, *cmake_options], |
| 62 | + cwd=build_dir, |
| 63 | + check=True, |
| 64 | + ) |
| 65 | + build_command = [ |
| 66 | + "cmake", |
| 67 | + "--build", |
| 68 | + build_dir, |
| 69 | + "--target", |
| 70 | + "install", |
| 71 | + ] |
| 72 | + |
| 73 | + subprocess.run(build_command, check=True) |
| 74 | + |
| 75 | + |
| 76 | +class bdist_egg_disabled(bdist_egg): |
| 77 | + """Disabled version of bdist_egg |
| 78 | +
|
| 79 | + Prevents setup.py install performing setuptools' default easy_install, |
| 80 | + which it should never ever do. |
| 81 | + """ |
| 82 | + |
| 83 | + def run(self): |
| 84 | + sys.exit( |
| 85 | + "Aborting implicit building of eggs. " |
| 86 | + "Use `pip install .` to install from source." |
| 87 | + ) |
| 88 | + |
| 89 | + |
| 90 | +if __name__ == "__main__": |
42 | 91 | try:
|
43 | 92 | import torch
|
44 | 93 |
|
45 |
| - # if we have torch, we are building a wheel, which will only be compatible with |
46 |
| - # a single torch version |
| 94 | + # if we have torch, we are building a wheel - requires specific torch version |
47 | 95 | torch_v_major, torch_v_minor, *_ = torch.__version__.split(".")
|
48 | 96 | torch_version = f"== {torch_v_major}.{torch_v_minor}.*"
|
49 | 97 | except ImportError:
|
50 | 98 | # otherwise we are building a sdist
|
51 |
| - torch_version = ">= 1.12" |
| 99 | + torch_version = ">= 2.1" |
| 100 | + |
| 101 | + install_requires = [f"torch {torch_version}"] |
52 | 102 |
|
53 | 103 | setup(
|
54 |
| - install_requires=[f"torch {torch_version}"], |
55 |
| - ext_modules=[neighbors_convert_extension], |
56 |
| - cmdclass={"build_ext": CustomBuildExt}, |
57 |
| - package_data={"pet_neighbors_convert": [f"neighbors_convert{shared_lib_ext}"]}, |
| 104 | + install_requires=install_requires, |
| 105 | + ext_modules=[ |
| 106 | + Extension(name="neighbors_convert", sources=[]), |
| 107 | + ], |
| 108 | + cmdclass={ |
| 109 | + "build_ext": cmake_ext, |
| 110 | + "bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled, |
| 111 | + "bdist_wheel": universal_wheel, |
| 112 | + }, |
58 | 113 | )
|
0 commit comments