
Commit c7aa9a7

[CUDA] Preload dependent DLLs (#23674)
### Description

Changes:
(1) Pass `--cuda_version` to the build-wheel command line in the packaging pipeline so that `cuda_version` can be saved. Note that `cuda_version` is also required for generating `extras_require` for #23659.
(2) Update setup.py and onnxruntime_validation.py to save the CUDA version to capi/build_and_package_info.py.
(3) Add a helper function in `__init__.py` to preload dependent DLLs (MSVC, CUDA, cuDNN). It first tries to load DLLs from the nvidia site packages, then tries to load the remaining DLLs with default path settings.

```
import onnxruntime
onnxruntime.preload_dlls()
```

To show loaded DLLs, set `verbose=True`. It is also possible to disable loading some types of DLLs:

```
onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
```

#### PyTorch and onnxruntime in Windows

When working with pytorch, onnxruntime will reuse the CUDA and cuDNN DLLs loaded by pytorch as long as the CUDA and cuDNN major versions are compatible. Preloading DLLs might actually cause issues in Windows (see Examples 2 and 3 below).

Example 1: onnxruntime and torch can work together easily.
```
>>> import torch
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\curand64_10.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_heuristic64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_precompiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_adv64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc64_120_0.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc-builtins64_124.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_runtime_compiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_cnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufftw64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\caffe2_nvrtc.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
>>> session.get_providers()
['CUDAExecutionProvider', 'CPUExecutionProvider']
```

Example 2: Calling preload_dlls after `import torch` is not necessary. Unfortunately, it appears that multiple DLLs with the same filename are loaded.
They can be used in parallel, but this is not ideal since more memory is used.

```
>>> import torch
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cufft\bin\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\curand64_10.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_heuristic64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_precompiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_ops64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_adv64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc64_120_0.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\nvrtc-builtins64_124.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_engines_runtime_compiled64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_cnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\cufftw64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\torch\lib\caffe2_nvrtc.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
```

Example 3: Calling preload_dlls before `import torch` might cause a torch import error in Windows. Later we may provide an option to load DLLs from the torch directory to avoid this issue.

```
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cufft\bin\cufft64_11.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublas64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cublas\bin\cublasLt64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn_graph64_9.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
D:\anaconda3\envs\py310\Lib\site-packages\numpy.libs\msvcp140-d64049c6e3865410a7dda6a7e9f0c575.dll
D:\anaconda3\envs\py310\Lib\site-packages\nvidia\cudnn\bin\cudnn64_9.dll
D:\anaconda3\envs\py310\msvcp140.dll
D:\anaconda3\envs\py310\vcruntime140_1.dll
D:\anaconda3\envs\py310\msvcp140_1.dll
D:\anaconda3\envs\py310\vcruntime140.dll
>>> import torch
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\anaconda3\envs\py310\lib\site-packages\torch\__init__.py", line 137, in <module>
    raise err
OSError: [WinError 127] The specified procedure could not be found. Error loading "D:\anaconda3\envs\py310\lib\site-packages\torch\lib\cudnn_adv64_9.dll" or one of its dependencies.
```

#### PyTorch and onnxruntime in Linux

In Linux, pytorch uses the nvidia site packages for CUDA and cuDNN DLLs, so preloading DLLs consistently loads the same set of DLLs, which could make maintenance easier.
```
>>> import onnxruntime
>>> onnxruntime.preload_dlls(verbose=True)
----List of loaded DLLs----
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn_graph.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib/libnvrtc.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12
>>> import torch
>>> torch.rand(3, 3).cuda()
tensor([[0.4619, 0.0279, 0.2092],
        [0.0416, 0.6782, 0.5889],
        [0.9988, 0.9092, 0.7982]], device='cuda:0')
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> session.get_providers()
['CUDAExecutionProvider', 'CPUExecutionProvider']
```

```
>>> import torch
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublasLt.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cublas/lib/libcublas.so.12
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/curand/lib/libcurand.so.10
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cufft/lib/libcufft.so.11
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cudnn/lib/libcudnn.so.9
/anaconda3/envs/py310/lib/python3.10/site-packages/nvidia/cuda_runtime/lib/libcudart.so.12
```

Without preloading DLLs, onnxruntime will load CUDA and cuDNN DLLs based on `LD_LIBRARY_PATH`. Torch will reuse the same DLLs loaded by onnxruntime:

```
>>> import onnxruntime
>>> session = onnxruntime.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
>>> import torch
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
>>> torch.rand(3, 3).cuda()
tensor([[0.2233, 0.9194, 0.8078],
        [0.0906, 0.2884, 0.3655],
        [0.6249, 0.2904, 0.4568]], device='cuda:0')
>>> onnxruntime.preload_dlls(cuda=False, cudnn=False, msvc=False, verbose=True)
----List of loaded DLLs----
/cuda12.8/targets/x86_64-linux/lib/libnvrtc.so.12.8.61
/cuda12.8/targets/x86_64-linux/lib/libcufft.so.11.3.3.41
/cuda12.8/targets/x86_64-linux/lib/libcurand.so.10.3.9.55
/cuda12.8/targets/x86_64-linux/lib/libcublas.so.12.8.3.14
/cuda12.8/targets/x86_64-linux/lib/libcublasLt.so.12.8.3.14
/cudnn9.7/lib/libcudnn_graph.so.9.7.0
/cudnn9.7/lib/libcudnn.so.9.7.0
/cuda12.8/targets/x86_64-linux/lib/libcudart.so.12.8.57
```

### Motivation and Context

In many reported `import onnxruntime` failures, the root cause is that dependent DLLs are missing or not in the path. This change will make it easier to resolve those issues.

This is based on Jian's PR #22506, with an extra change to load MSVC DLLs.

#23659 can be used to install CUDA/cuDNN DLLs to site packages. Example command line after the next official release, 1.21:

```
pip install onnxruntime-gpu[cuda,cudnn]
```

If the user installed pytorch in Linux, those DLLs are usually installed together with torch.
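The import-order observations in the description can be condensed into a small decision helper. This is only a sketch of that guidance, not part of the commit: `choose_dll_strategy` and its return labels are hypothetical names, and the rules simply restate Examples 1 to 3 and the Linux notes.

```python
import importlib.util
import platform


def choose_dll_strategy(system: str, torch_installed: bool) -> str:
    """Return a hypothetical strategy label derived from the examples above."""
    if system not in ("Windows", "Linux"):
        # preload_dlls() returns early on other platforms.
        return "skip"
    if system == "Windows" and torch_installed:
        # Examples 2 and 3: preloading either duplicates DLLs already loaded
        # by torch or can break a later `import torch`.
        return "import_torch_first"
    # On Linux (with or without torch) and on Windows without torch,
    # preloading is reported to work in the description.
    return "preload"


if __name__ == "__main__":
    # find_spec returns None when the top-level package is not installed.
    torch_installed = importlib.util.find_spec("torch") is not None
    print(choose_dll_strategy(platform.system(), torch_installed))
```

With `"import_torch_first"`, the caller would run `import torch` before creating an inference session and would not call `onnxruntime.preload_dlls()` at all, as in Example 1.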
1 parent 4f66610 commit c7aa9a7

4 files changed: +161 -63 lines changed
onnxruntime/__init__.py

+95 -1
@@ -75,9 +75,103 @@
 except ImportError:
     pass

-from onnxruntime.capi.onnxruntime_validation import cuda_version, package_name, version  # noqa: F401
+
+package_name, version, cuda_version = onnxruntime_validation.get_package_name_and_version_info()

 if version:
     __version__ = version

 onnxruntime_validation.check_distro_info()
+
+
+def preload_dlls(cuda: bool = True, cudnn: bool = True, msvc: bool = True, verbose: bool = False):
+    import ctypes
+    import os
+    import platform
+    import site
+
+    if platform.system() not in ["Windows", "Linux"]:
+        return
+
+    is_windows = platform.system() == "Windows"
+    if is_windows and msvc:
+        try:
+            ctypes.CDLL("vcruntime140.dll")
+            ctypes.CDLL("msvcp140.dll")
+            if platform.machine() != "ARM64":
+                ctypes.CDLL("vcruntime140_1.dll")
+        except OSError:
+            print("Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.")
+            print("It can be downloaded at https://aka.ms/vs/17/release/vc_redist.x64.exe.")
+
+    if cuda_version and cuda_version.startswith("12.") and (cuda or cudnn):
+        # Paths are relative to nvidia root in site packages.
+        if is_windows:
+            cuda_dll_paths = [
+                ("cublas", "bin", "cublasLt64_12.dll"),
+                ("cublas", "bin", "cublas64_12.dll"),
+                ("cufft", "bin", "cufft64_11.dll"),
+                ("cuda_runtime", "bin", "cudart64_12.dll"),
+            ]
+            cudnn_dll_paths = [
+                ("cudnn", "bin", "cudnn_graph64_9.dll"),
+                ("cudnn", "bin", "cudnn64_9.dll"),
+            ]
+        else:  # Linux
+            # cublas64 depends on cublasLt64, so cublasLt64 should be loaded first.
+            cuda_dll_paths = [
+                ("cublas", "lib", "libcublasLt.so.12"),
+                ("cublas", "lib", "libcublas.so.12"),
+                ("cuda_nvrtc", "lib", "libnvrtc.so.12"),
+                ("curand", "lib", "libcurand.so.10"),
+                ("cufft", "lib", "libcufft.so.11"),
+                ("cuda_runtime", "lib", "libcudart.so.12"),
+            ]
+            cudnn_dll_paths = [
+                ("cudnn", "lib", "libcudnn_graph.so.9"),
+                ("cudnn", "lib", "libcudnn.so.9"),
+            ]
+
+        # Try load DLLs from nvidia site packages.
+        dll_paths = (cuda_dll_paths if cuda else []) + (cudnn_dll_paths if cudnn else [])
+        loaded_dlls = []
+        for site_packages_path in reversed(site.getsitepackages()):
+            nvidia_path = os.path.join(site_packages_path, "nvidia")
+            if os.path.isdir(nvidia_path):
+                for relative_path in dll_paths:
+                    dll_path = os.path.join(nvidia_path, *relative_path)
+                    if os.path.isfile(dll_path):
+                        try:
+                            _ = ctypes.CDLL(dll_path)
+                            loaded_dlls.append(relative_path[-1])
+                        except Exception as e:
+                            print(f"Failed to load {dll_path}: {e}")
+                break
+
+        # Try load DLLs with default path settings.
+        has_failure = False
+        for relative_path in dll_paths:
+            dll_filename = relative_path[-1]
+            if dll_filename not in loaded_dlls:
+                try:
+                    _ = ctypes.CDLL(dll_filename)
+                except Exception as e:
+                    has_failure = True
+                    print(f"Failed to load {dll_filename}: {e}")
+
+        if has_failure:
+            print("Please follow https://onnxruntime.ai/docs/install/#cuda-and-cudnn to install CUDA and CuDNN.")
+
+    if verbose:
+
+        def is_target_dll(path: str):
+            target_keywords = ["cufft", "cublas", "cudart", "nvrtc", "curand", "cudnn", "vcruntime140", "msvcp140"]
+            return any(keyword in path for keyword in target_keywords)
+
+        import psutil
+
+        p = psutil.Process(os.getpid())
+        print("----List of loaded DLLs----")
+        for lib in p.memory_maps():
+            if is_target_dll(lib.path.lower()):
+                print(lib.path)
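The two-step lookup in `preload_dlls` (nvidia site packages first, default loader search for whatever is left) can be exercised without any CUDA libraries installed. The sketch below is an illustration under assumptions, not the committed code: `split_by_site_packages` is a hypothetical helper, and it treats "file exists" as found, whereas the commit only records a DLL after `ctypes.CDLL` succeeds.

```python
import os
import tempfile


def split_by_site_packages(site_dirs, relative_paths):
    """Return (found_paths, remaining_names) for the two-step lookup.

    found_paths: full paths that exist under the first "nvidia" folder found,
    scanning site_dirs in reverse order as the commit does.
    remaining_names: bare file names left for the default loader search.
    """
    found = []
    for site_dir in reversed(list(site_dirs)):
        nvidia_root = os.path.join(site_dir, "nvidia")
        if os.path.isdir(nvidia_root):
            for rel in relative_paths:
                candidate = os.path.join(nvidia_root, *rel)
                if os.path.isfile(candidate):
                    found.append(candidate)
            break  # stop at the first site dir that has an nvidia folder
    found_names = {os.path.basename(p) for p in found}
    remaining = [rel[-1] for rel in relative_paths if rel[-1] not in found_names]
    return found, remaining


if __name__ == "__main__":
    wanted = [
        ("cublas", "lib", "libcublasLt.so.12"),
        ("cublas", "lib", "libcublas.so.12"),
    ]
    with tempfile.TemporaryDirectory() as tmp:
        # Fake site-packages containing only one of the two libraries.
        lib_dir = os.path.join(tmp, "nvidia", "cublas", "lib")
        os.makedirs(lib_dir)
        open(os.path.join(lib_dir, "libcublas.so.12"), "w").close()
        found, remaining = split_by_site_packages([tmp], wanted)
        print([os.path.basename(p) for p in found], remaining)
```

The names in `remaining` correspond to the second loop in the diff, where each file name is passed to `ctypes.CDLL` without a directory so that the operating system's default search path applies.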

onnxruntime/python/onnxruntime_validation.py

+47 -44
@@ -68,7 +68,28 @@ def check_distro_info():
         )


-def validate_build_package_info():
+def get_package_name_and_version_info():
+    package_name = ""
+    version = ""
+    cuda_version = ""
+
+    try:
+        from .build_and_package_info import __version__ as version
+        from .build_and_package_info import package_name
+
+        try:  # noqa: SIM105
+            from .build_and_package_info import cuda_version
+        except ImportError:
+            # cuda_version is optional. For example, cpu only package does not have the attribute.
+            pass
+    except Exception as e:
+        warnings.warn("WARNING: failed to collect package name and version info")
+        print(e)
+
+    return package_name, version, cuda_version
+
+
+def check_training_module():
     import_ortmodule_exception = None

     has_ortmodule = False
@@ -96,48 +117,33 @@ def validate_build_package_info():
         if not has_ortmodule:
             import_ortmodule_exception = e

-    package_name = ""
-    version = ""
-    cuda_version = ""
+    # collect onnxruntime package name, version, and cuda version
+    package_name, version, cuda_version = get_package_name_and_version_info()

-    if has_ortmodule:
+    if has_ortmodule and cuda_version:
         try:
-            # collect onnxruntime package name, version, and cuda version
-            from .build_and_package_info import __version__ as version
-            from .build_and_package_info import package_name
-
-            try:  # noqa: SIM105
-                from .build_and_package_info import cuda_version
-            except Exception:
-                pass
-
-            if cuda_version:
-                # collect cuda library build info. the library info may not be available
-                # when the build environment has none or multiple libraries installed
-                try:
-                    from .build_and_package_info import cudart_version
-                except Exception:
-                    warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
-                    cudart_version = None
-
-                def print_build_package_info():
-                    warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
-                    warnings.warn(f"onnxruntime training package info: __version__: {version}")
-                    warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
-                    warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")
-
-                # collection cuda library info from current environment.
-                from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions
-
-                local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
-                if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
-                    print_build_package_info()
-                    warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
-                    warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
-            else:
-                # TODO: rcom
-                pass
-
+            # collect cuda library build info. the library info may not be available
+            # when the build environment has none or multiple libraries installed
+            try:
+                from .build_and_package_info import cudart_version
+            except ImportError:
+                warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
+                cudart_version = None
+
+            def print_build_package_info():
+                warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
+                warnings.warn(f"onnxruntime training package info: __version__: {version}")
+                warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
+                warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")
+
+            # collection cuda library info from current environment.
+            from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions
+
+            local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
+            if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions:
+                print_build_package_info()
+                warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
+                warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
         except Exception as e:
             warnings.warn("WARNING: failed to collect onnxruntime version and build info")
             print(e)
@@ -146,6 +152,3 @@ def print_build_package_info():
         raise import_ortmodule_exception

     return has_ortmodule, package_name, version, cuda_version
-
-
-has_ortmodule, package_name, version, cuda_version = validate_build_package_info()

setup.py

+16 -15
@@ -724,21 +724,22 @@ def reformat_run_count(count_str):
 with open(requirements_path) as f:
     install_requires = f.read().splitlines()

-if enable_training:

-    def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
-        sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
-        from onnxruntime_collect_build_info import find_cudart_versions
+def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
+    sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
+    from onnxruntime_collect_build_info import find_cudart_versions

-        version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
-        with open(version_path, "w") as f:
-            f.write(f"package_name = '{package_name}'\n")
-            f.write(f"__version__ = '{version_number}'\n")
+    version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
+    with open(version_path, "w") as f:
+        f.write(f"package_name = '{package_name}'\n")
+        f.write(f"__version__ = '{version_number}'\n")

-            if cuda_version:
-                f.write(f"cuda_version = '{cuda_version}'\n")
+        if cuda_version:
+            f.write(f"cuda_version = '{cuda_version}'\n")

-                # cudart_versions are integers
+            # The cudart version used in building training packages in Linux.
+            # It is possible to parse version.json at cuda_home in build.py, then pass in the parameter directly.
+            if enable_training and platform.system().lower() == "linux":
                 cudart_versions = find_cudart_versions(build_env=True)
                 if cudart_versions and len(cudart_versions) == 1:
                     f.write(f"cudart_version = {cudart_versions[0]}\n")
@@ -751,10 +752,11 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
                             else "found multiple cudart libraries"
                         ),
                     )
-            elif rocm_version:
-                f.write(f"rocm_version = '{rocm_version}'\n")
+        elif rocm_version:
+            f.write(f"rocm_version = '{rocm_version}'\n")

-    save_build_and_package_info(package_name, version_number, cuda_version, rocm_version)
+
+save_build_and_package_info(package_name, version_number, cuda_version, rocm_version)

 extras_require = {}
 if package_name == "onnxruntime-gpu" and is_cuda_version_12:
@@ -770,7 +772,6 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
         ],
     }

-# Setup
 setup(
     name=package_name,
     version=version_number,
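To see how the file written by setup.py and the reader in onnxruntime_validation.py fit together, the round trip can be sketched in isolation. This is an assumption-laden illustration: `render_build_info` and `read_build_info` are hypothetical helpers, the `cudart_version` handling is left out, the file is loaded with `runpy` instead of a package-relative import, and the version strings are made-up sample values.

```python
import os
import runpy
import tempfile


def render_build_info(package_name, version, cuda_version="", rocm_version=""):
    """Render module text with the same fields the diff writes (cudart omitted)."""
    lines = [f"package_name = {package_name!r}", f"__version__ = {version!r}"]
    if cuda_version:
        lines.append(f"cuda_version = {cuda_version!r}")
    elif rocm_version:
        lines.append(f"rocm_version = {rocm_version!r}")
    return "\n".join(lines) + "\n"


def read_build_info(path):
    """Return (package_name, version, cuda_version); cuda_version is optional."""
    info = runpy.run_path(path)  # executes the file and returns its globals
    return (
        info.get("package_name", ""),
        info.get("__version__", ""),
        info.get("cuda_version", ""),  # absent in CPU-only packages
    )


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "build_and_package_info.py")
        with open(path, "w") as f:
            # Sample values for illustration only.
            f.write(render_build_info("onnxruntime-gpu", "1.21.0", cuda_version="12.8"))
        print(read_build_info(path))
```

Because `cuda_version` falls back to an empty string, a caller such as `preload_dlls` can gate CUDA-specific work on `cuda_version and cuda_version.startswith("12.")` without a separate existence check.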

tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml

+3 -3
@@ -33,7 +33,7 @@ parameters:
   - Release
   - RelWithDebInfo
   - MinSizeRel
-
+
 - name: use_tensorrt
   type: boolean
   default: false
@@ -141,7 +141,7 @@ stages:
   displayName: 'Build wheel'
   inputs:
     scriptPath: '$(Build.SourcesDirectory)\setup.py'
-    arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=${{ parameters.EP_NAME }}'
+    arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=${{ parameters.EP_NAME }} --cuda_version=${{ parameters.CudaVersion }}'
     workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}'

 - task: CopyFiles@2
@@ -195,7 +195,7 @@ stages:
     TMPDIR: "$(Agent.TempDirectory)"

 - powershell: |
-
+
     python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq
     Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*cp${{ replace(parameters.PYTHON_VERSION,'.','') }}*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate}
     mkdir -p $(Agent.TempDirectory)\ort_test_data
