Skip to content

Commit d3b7c60

Browse files
committed
Switch to cmake-based build
1 parent a29fa3d commit d3b7c60

File tree

6 files changed

+227
-56
lines changed

6 files changed

+227
-56
lines changed

.github/workflows/build-wheels.yml

+2
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ jobs:
145145
run: python -m pip install build
146146

147147
- name: build sdist
148+
env:
149+
PIP_EXTRA_INDEX_URL: "https://download.pytorch.org/whl/cpu"
148150
run: python -m build . --outdir=dist/
149151

150152
- uses: actions/upload-artifact@v4

CMakeLists.txt

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
cmake_minimum_required(VERSION 3.27)
2+
3+
if (POLICY CMP0076)
4+
# target_sources() converts relative paths to absolute
5+
cmake_policy(SET CMP0076 NEW)
6+
endif()
7+
8+
project(neighbors_convert CXX)
9+
10+
# Set a default build type if none was specified
11+
if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR})
12+
if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
13+
message(STATUS "Setting build type to 'relwithdebinfo' as none was specified.")
14+
set(
15+
CMAKE_BUILD_TYPE "relwithdebinfo"
16+
CACHE STRING
17+
"Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel."
18+
FORCE
19+
)
20+
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none)
21+
endif()
22+
endif()
23+
24+
# add path to the cmake configuration of the version of libtorch used
25+
# by the Python torch module. PYTHON_EXECUTABLE is provided by skbuild
26+
execute_process(
27+
COMMAND ${PYTHON_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
28+
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
29+
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
30+
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
31+
)
32+
33+
if (NOT ${TORCH_CMAKE_PATH_RESULT} EQUAL 0)
34+
message(FATAL_ERROR "failed to find your pytorch installation\n${TORCH_CMAKE_PATH_ERROR}")
35+
endif()
36+
37+
string(STRIP ${TORCH_CMAKE_PATH_OUTPUT} TORCH_CMAKE_PATH_OUTPUT)
38+
set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH};${TORCH_CMAKE_PATH_OUTPUT}")
39+
40+
find_package(Torch 2.3 REQUIRED)
41+
42+
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/_build_torch_version.py "BUILD_TORCH_VERSION = '${Torch_VERSION}'")
43+
44+
add_library(neighbors_convert SHARED
45+
"src/pet_neighbors_convert/neighbors_convert.cpp"
46+
)
47+
48+
# only link to `torch_cpu_library` instead of `torch`, which could also include
49+
# `libtorch_cuda`.
50+
target_link_libraries(neighbors_convert PUBLIC torch_cpu_library)
51+
target_include_directories(neighbors_convert PUBLIC "${TORCH_INCLUDE_DIRS}")
52+
target_compile_definitions(neighbors_convert PUBLIC "${TORCH_CXX_FLAGS}")
53+
54+
target_compile_features(neighbors_convert PUBLIC cxx_std_17)
55+
56+
install(TARGETS neighbors_convert
57+
LIBRARY DESTINATION "lib"
58+
)

build-backend/backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
if FORCED_TORCH_VERSION is not None:
88
TORCH_DEP = f"torch =={FORCED_TORCH_VERSION}"
99
else:
10-
TORCH_DEP = "torch >=1.12"
10+
TORCH_DEP = "torch >=2.3"
1111

1212
# ==================================================================================== #
1313
# Build backend functions definition #

pyproject.toml

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ classifiers = [
1616
"Topic :: Scientific/Engineering",
1717
"Topic :: Software Development :: Libraries :: Python Modules"
1818
]
19-
description = "Extension for PET model for reordering the neighbors during message passing."
19+
description = "PET model extension for processing neighbor lists"
2020
dynamic = ["version", "dependencies"]
2121
license = {text = "BSD-3-Clause"}
2222
name = "pet-neighbors-convert"
2323
readme = "README.rst"
2424
requires-python = ">=3.9"
2525

26-
[tool.check-manifest]
27-
ignore = ["src/pet_neighbors_convert/_version.py"]
26+
# [tool.check-manifest]
27+
# ignore = ["src/pet_neighbors_convert/_version.py"]
2828

29-
[tool.setuptools_scm]
30-
version_file = "src/pet_neighbors_convert/_version.py"
29+
# [tool.setuptools_scm]
30+
# version_file = "src/pet_neighbors_convert/_version.py"
3131

3232
[tool.setuptools.packages.find]
3333
where = ["src"]

setup.py

+94-39
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,113 @@
1+
import os
2+
import subprocess
13
import sys
2-
from setuptools import setup, Extension
4+
5+
from setuptools import Extension, setup
6+
from setuptools.command.bdist_egg import bdist_egg
37
from setuptools.command.build_ext import build_ext
8+
from wheel.bdist_wheel import bdist_wheel
49

510

6-
class CustomBuildExt(build_ext):
7-
def build_extensions(self):
8-
# Import torch here, when we actually need it
9-
from torch.utils.cpp_extension import include_paths, library_paths
11+
ROOT = os.path.realpath(os.path.dirname(__file__))
1012

11-
# Update extension settings with torch-specific information
12-
for ext in self.extensions:
13-
ext.include_dirs = include_paths()
14-
ext.library_dirs = library_paths()
15-
ext.libraries = ["c10", "torch", "torch_cpu"]
1613

17-
super().build_extensions()
14+
class universal_wheel(bdist_wheel):
15+
# When building the wheel, the `wheel` package assumes that if we have a
16+
# binary extension then we are linking to `libpython.so`; and thus the wheel
17+
# is only usable with a single python version. This is not the case for
18+
# here, and the wheel will be compatible with any Python >=3.7. This is
19+
# tracked in https://github.com/pypa/wheel/issues/185, but until then we
20+
# manually override the wheel tag.
21+
def get_tag(self):
22+
tag = bdist_wheel.get_tag(self)
23+
# tag[2:] contains the os/arch tags, we want to keep them
24+
return ("py3", "none") + tag[2:]
1825

1926

20-
if __name__ == "__main__":
21-
# Basic compilation settings
22-
extra_compile_args = ["-std=c++17"]
23-
extra_link_args = []
24-
if sys.platform == "darwin":
25-
shared_lib_ext = ".dylib"
26-
extra_compile_args.append("-stdlib=libc++")
27-
extra_link_args.extend(["-stdlib=libc++", "-mmacosx-version-min=10.9"])
28-
elif sys.platform == "linux":
29-
extra_compile_args.append("-fPIC")
30-
shared_lib_ext = ".so"
31-
else:
32-
raise RuntimeError(f"Unsupported platform {sys.platform}")
33-
34-
neighbors_convert_extension = Extension(
35-
name="pet_neighbors_convert.neighbors_convert",
36-
sources=["src/pet_neighbors_convert/neighbors_convert.cpp"],
37-
language="c++",
38-
extra_compile_args=extra_compile_args,
39-
extra_link_args=extra_link_args,
40-
)
27+
class cmake_ext(build_ext):
28+
"""Build the native library using cmake"""
29+
30+
def run(self):
31+
import torch
32+
33+
torch_major, torch_minor, *_ = torch.__version__.split(".")
34+
35+
source_dir = ROOT
36+
build_dir = os.path.join(ROOT, "build", "cmake-build")
37+
install_dir = os.path.join(
38+
os.path.realpath(self.build_lib),
39+
f"pet_neighbors_convert/torch-{torch_major}.{torch_minor}",
40+
)
41+
42+
os.makedirs(build_dir, exist_ok=True)
43+
44+
cmake_options = [
45+
"-DCMAKE_BUILD_TYPE=Release",
46+
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
47+
f"-DPYTHON_EXECUTABLE={sys.executable}",
48+
]
49+
50+
if sys.platform.startswith("darwin"):
51+
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0")
4152

53+
# ARCHFLAGS is used by cibuildwheel to pass the requested arch to the
54+
# compilers
55+
ARCHFLAGS = os.environ.get("ARCHFLAGS")
56+
if ARCHFLAGS is not None:
57+
cmake_options.append(f"-DCMAKE_C_FLAGS={ARCHFLAGS}")
58+
cmake_options.append(f"-DCMAKE_CXX_FLAGS={ARCHFLAGS}")
59+
60+
subprocess.run(
61+
["cmake", source_dir, *cmake_options],
62+
cwd=build_dir,
63+
check=True,
64+
)
65+
build_command = [
66+
"cmake",
67+
"--build",
68+
build_dir,
69+
"--target",
70+
"install",
71+
]
72+
73+
subprocess.run(build_command, check=True)
74+
75+
76+
class bdist_egg_disabled(bdist_egg):
77+
"""Disabled version of bdist_egg
78+
79+
Prevents setup.py install performing setuptools' default easy_install,
80+
which it should never ever do.
81+
"""
82+
83+
def run(self):
84+
sys.exit(
85+
"Aborting implicit building of eggs. "
86+
"Use `pip install .` to install from source."
87+
)
88+
89+
90+
if __name__ == "__main__":
4291
try:
4392
import torch
4493

45-
# if we have torch, we are building a wheel, which will only be compatible with
46-
# a single torch version
94+
# if we have torch, we are building a wheel - requires specific torch version
4795
torch_v_major, torch_v_minor, *_ = torch.__version__.split(".")
4896
torch_version = f"== {torch_v_major}.{torch_v_minor}.*"
4997
except ImportError:
5098
# otherwise we are building a sdist
51-
torch_version = ">= 1.12"
99+
torch_version = ">= 2.1"
100+
101+
install_requires = [f"torch {torch_version}"]
52102

53103
setup(
54-
install_requires=[f"torch {torch_version}"],
55-
ext_modules=[neighbors_convert_extension],
56-
cmdclass={"build_ext": CustomBuildExt},
57-
package_data={"pet_neighbors_convert": [f"neighbors_convert{shared_lib_ext}"]},
104+
install_requires=install_requires,
105+
ext_modules=[
106+
Extension(name="neighbors_convert", sources=[]),
107+
],
108+
cmdclass={
109+
"build_ext": cmake_ext,
110+
"bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled,
111+
"bdist_wheel": universal_wheel,
112+
},
58113
)

src/pet_neighbors_convert/__init__.py

+67-11
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,72 @@
1+
import os
2+
import sys
3+
from collections import namedtuple
4+
15
import torch
2-
import importlib.resources as pkg_resources
3-
from ._version import __version__ # noqa
46

7+
import re
8+
import glob
9+
10+
11+
Version = namedtuple("Version", ["major", "minor", "patch"])
12+
13+
14+
def parse_version(version):
15+
match = re.match(r"(\d+)\.(\d+)\.(\d+).*", version)
16+
if match:
17+
return Version(*map(int, match.groups()))
18+
else:
19+
raise ValueError("Invalid version string format")
20+
21+
22+
_HERE = os.path.realpath(os.path.dirname(__file__))
23+
24+
25+
def _lib_path():
26+
torch_version = parse_version(torch.__version__)
27+
expected_prefix = os.path.join(
28+
_HERE, f"torch-{torch_version.major}.{torch_version.minor}"
29+
)
30+
if os.path.exists(expected_prefix):
31+
if sys.platform.startswith("darwin"):
32+
path = os.path.join(expected_prefix, "lib", "libneighbors_convert.dylib")
33+
elif sys.platform.startswith("linux"):
34+
path = os.path.join(expected_prefix, "lib", "libneighbors_convert.so")
35+
elif sys.platform.startswith("win"):
36+
path = os.path.join(expected_prefix, "bin", "neighbors_convert.dll")
37+
else:
38+
raise ImportError("Unknown platform. Please edit this file")
39+
40+
if os.path.isfile(path):
41+
return path
42+
else:
43+
raise ImportError(
44+
"Could not find neighbors_convert shared library at " + path
45+
)
46+
47+
# gather which torch version(s) the current install was built
48+
# with to create the error message
49+
existing_versions = []
50+
for prefix in glob.glob(os.path.join(_HERE, "torch-*")):
51+
existing_versions.append(os.path.basename(prefix)[11:])
52+
53+
print(existing_versions)
554

6-
def load_neighbors_convert():
7-
try:
8-
# Locate the shared object file in the package
9-
with pkg_resources.files(__name__).joinpath("neighbors_convert.so") as lib_path:
10-
# Load the shared object file
11-
torch.ops.load_library(str(lib_path))
12-
except Exception as e:
13-
print(f"Failed to load neighbors_convert.so: {e}")
55+
if len(existing_versions) == 1:
56+
raise ImportError(
57+
f"Trying to load neighbors-convert with torch v{torch.__version__}, "
58+
f"but it was compiled against torch v{existing_versions[0]}, which "
59+
"is not ABI compatible"
60+
)
61+
else:
62+
all_versions = ", ".join(map(lambda version: f"v{version}", existing_versions))
63+
raise ImportError(
64+
f"Trying to load neighbors-convert with torch v{torch.__version__}, "
65+
f"we found builds for torch {all_versions}; which are not ABI compatible.\n"
66+
"You can try to re-install from source with "
67+
"`pip install neighbors-convert --no-binary=neighbors-convert`"
68+
)
1469

1570

16-
load_neighbors_convert()
71+
# load the C++ operators and custom classes
72+
torch.classes.load_library(_lib_path())

0 commit comments

Comments
 (0)