Skip to content

Commit aebc91e

Browse files
authored
Merge pull request #729 from messense/import-hook
Add a Python import hook
2 parents 7819bce + ac16e94 commit aebc91e

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

Changelog.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
* Don't package non-path-dep crates in sdist for workspaces in [#720](https://github.com/PyO3/maturin/pull/720)
1212
* Build release packages with `password-storage` feature in [#725](https://github.com/PyO3/maturin/pull/725)
1313
* Add support for x86_64 DargonFly BSD in [#727](https://github.com/PyO3/maturin/pull/727)
14+
* Add a Python import hook in [#729](https://github.com/PyO3/maturin/pull/729)
1415

1516
## [0.12.3] - 2021-11-29
1617

maturin/import_hook.py

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import contextlib
2+
import importlib
3+
import importlib.util
4+
from importlib import abc
5+
from importlib.machinery import ModuleSpec
6+
import os
7+
import pathlib
8+
import shutil
9+
import sys
10+
import subprocess
11+
from typing import Optional
12+
13+
import toml
14+
15+
16+
class Importer(abc.MetaPathFinder):
17+
"""A meta-path importer for the maturin based packages"""
18+
19+
def __init__(self, bindings: Optional[str] = None, release: bool = False):
20+
self.bindings = bindings
21+
self.release = release
22+
23+
def find_spec(self, fullname, path, target=None):
24+
if fullname in sys.modules:
25+
return
26+
mod_parts = fullname.split(".")
27+
module_name = mod_parts[-1]
28+
29+
cwd = pathlib.Path(os.getcwd())
30+
# Full Cargo project in cwd
31+
cargo_toml = cwd / "Cargo.toml"
32+
if _is_cargo_project(cargo_toml, module_name):
33+
return self._build_and_load(fullname, cargo_toml)
34+
35+
# Full Cargo project in subdirectory of cwd
36+
cargo_toml = cwd / module_name / "Cargo.toml"
37+
if _is_cargo_project(cargo_toml, module_name):
38+
return self._build_and_load(fullname, cargo_toml)
39+
# module name with '-' instead of '_'
40+
cargo_toml = cwd / module_name.replace("_", "-") / "Cargo.toml"
41+
if _is_cargo_project(cargo_toml, module_name):
42+
return self._build_and_load(fullname, cargo_toml)
43+
44+
# Single .rs file
45+
rust_file = cwd / (module_name + ".rs")
46+
if rust_file.exists():
47+
project_dir = generate_project(rust_file, bindings=self.bindings or "pyo3")
48+
cargo_toml = project_dir / "Cargo.toml"
49+
return self._build_and_load(fullname, cargo_toml)
50+
51+
def _build_and_load(self, fullname: str, cargo_toml: pathlib.Path) -> ModuleSpec:
52+
build_module(cargo_toml, bindings=self.bindings)
53+
loader = Loader(fullname)
54+
return importlib.util.spec_from_loader(fullname, loader)
55+
56+
57+
class Loader(abc.Loader):
58+
def __init__(self, fullname):
59+
self.fullname = fullname
60+
61+
def load_module(self, fullname):
62+
return importlib.import_module(self.fullname)
63+
64+
65+
def _is_cargo_project(cargo_toml: pathlib.Path, module_name: str) -> bool:
66+
with contextlib.suppress(FileNotFoundError):
67+
with open(cargo_toml) as f:
68+
cargo = toml.load(f)
69+
package_name = cargo.get("package", {}).get("name")
70+
if (
71+
package_name == module_name
72+
or package_name.replace("-", "_") == module_name
73+
):
74+
return True
75+
return False
76+
77+
78+
def generate_project(rust_file: pathlib.Path, bindings: str = "pyo3") -> pathlib.Path:
79+
build_dir = pathlib.Path(os.getcwd()) / "build"
80+
project_dir = build_dir / rust_file.stem
81+
if project_dir.exists():
82+
shutil.rmtree(project_dir)
83+
84+
command = ["maturin", "new", "-b", bindings, project_dir]
85+
result = subprocess.run(command, stdout=subprocess.PIPE)
86+
if result.returncode != 0:
87+
sys.stderr.write(
88+
f"Error: command {command} returned non-zero exit status {result.returncode}\n"
89+
)
90+
raise ImportError("Failed to generate cargo project")
91+
92+
with open(rust_file) as f:
93+
lib_rs_content = f.read()
94+
lib_rs = project_dir / "src" / "lib.rs"
95+
with open(lib_rs, "w") as f:
96+
f.write(lib_rs_content)
97+
return project_dir
98+
99+
100+
def build_module(
101+
manifest_path: pathlib.Path, bindings: Optional[str] = None, release: bool = False
102+
):
103+
command = ["maturin", "develop", "-m", manifest_path]
104+
if bindings:
105+
command.append("-b")
106+
command.append(bindings)
107+
if release:
108+
command.append("--release")
109+
result = subprocess.run(command, stdout=subprocess.PIPE)
110+
sys.stdout.buffer.write(result.stdout)
111+
sys.stdout.flush()
112+
if result.returncode != 0:
113+
sys.stderr.write(
114+
f"Error: command {command} returned non-zero exit status {result.returncode}\n"
115+
)
116+
raise ImportError("Failed to build module with maturin")
117+
118+
119+
def _have_importer() -> bool:
120+
for importer in sys.meta_path:
121+
if isinstance(importer, Importer):
122+
return True
123+
return False
124+
125+
126+
def install(bindings: Optional[str] = None, release: bool = False):
127+
"""
128+
Install the import hook.
129+
130+
:param bindings: Which kind of bindings to use.
131+
Possible values are pyo3, rust-cpython and cffi
132+
133+
:param release: Build in release mode, otherwise debug mode by default
134+
"""
135+
if _have_importer():
136+
return
137+
importer = Importer(bindings=bindings, release=release)
138+
sys.meta_path.append(importer)
139+
return importer
140+
141+
142+
def uninstall(importer: Importer):
143+
"""
144+
Uninstall the import hook.
145+
"""
146+
try:
147+
sys.meta_path.remove(importer)
148+
except ValueError:
149+
pass

0 commit comments

Comments
 (0)