Skip to content

Commit 1962e35

Browse files
committed
Add CUDA version compatibility check
Warn when cuda-bindings was compiled against a newer CUDA major version than the installed driver supports. This helps users understand why certain features may not work correctly. The check runs once after cuInit and can be suppressed via CUDA_PYTHON_DISABLE_VERSION_CHECK=1.
1 parent 62a8cb3 commit 1962e35

File tree

4 files changed

+176
-0
lines changed

4 files changed

+176
-0
lines changed

cuda_bindings/docs/source/environment_variables.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ Runtime Environment Variables
99

1010
- ``CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM`` : When set to 1, the default stream is the per-thread default stream. When set to 0, the default stream is the legacy default stream. This defaults to 0, for the legacy default stream. See `Stream Synchronization Behavior <https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html>`_ for an explanation of the legacy and per-thread default streams.
1111

12+
- ``CUDA_PYTHON_DISABLE_VERSION_CHECK`` : When set to 1, suppresses the warning that is issued when ``cuda.core`` detects that ``cuda-bindings`` was compiled against a newer CUDA major version than the installed driver supports. This warning helps identify version mismatches that may cause features to not work correctly.
13+
1214

1315
Build-Time Environment Variables
1416
--------------------------------

cuda_core/cuda/core/_device.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ from cuda.core._graph import GraphBuilder
1717
from cuda.core._stream import IsStreamT, Stream, StreamOptions
1818
from cuda.core._utils.clear_error_support import assert_type
1919
from cuda.core._utils.cuda_utils import (
20+
check_cuda_version_compatibility,
2021
ComputeCapability,
2122
CUDAError,
2223
driver,
@@ -963,6 +964,8 @@ class Device:
963964
with _lock, nogil:
964965
HANDLE_RETURN(cydriver.cuInit(0))
965966
_is_cuInit = True
967+
# Check version compatibility after CUDA is initialized
968+
check_cuda_version_compatibility()
966969

967970
# important: creating a Device instance does not initialize the GPU!
968971
cdef cydriver.CUdevice dev

cuda_core/cuda/core/_utils/cuda_utils.pyx

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import functools
66
from functools import partial
77
import importlib.metadata
88
import multiprocessing
9+
import os
910
import platform
1011
import warnings
1112
from collections import namedtuple
@@ -288,6 +289,67 @@ class Transaction:
288289
self._stack.pop_all()
289290

290291

292+
# Track whether we've already checked version compatibility
293+
_version_compatibility_checked = False
294+
295+
296+
def reset_version_compatibility_check():
297+
"""Reset the version compatibility check flag for testing purposes.
298+
299+
This function is intended for use in tests to allow multiple test runs
300+
to check the warning behavior.
301+
"""
302+
global _version_compatibility_checked
303+
_version_compatibility_checked = False
304+
305+
306+
def check_cuda_version_compatibility():
307+
"""Check if the CUDA driver version is compatible with cuda-bindings compile-time version.
308+
309+
This function compares the CUDA version that cuda-bindings was compiled against
310+
with the CUDA version supported by the installed driver. If the compile-time
311+
major version is greater than the driver's major version, a warning is issued.
312+
313+
The warning can be suppressed by setting the environment variable
314+
CUDA_PYTHON_DISABLE_VERSION_CHECK=1.
315+
"""
316+
global _version_compatibility_checked
317+
if _version_compatibility_checked:
318+
return
319+
_version_compatibility_checked = True
320+
321+
# Allow users to suppress the warning
322+
if os.environ.get("CUDA_PYTHON_DISABLE_VERSION_CHECK"):
323+
return
324+
325+
# Get compile-time CUDA version from cuda-bindings
326+
try:
327+
compile_version = driver.CUDA_VERSION # e.g., 13010
328+
except AttributeError:
329+
# Older cuda-bindings may not expose CUDA_VERSION
330+
return
331+
332+
# Get runtime driver version
333+
err, runtime_version = driver.cuDriverGetVersion()
334+
if err != driver.CUresult.CUDA_SUCCESS:
335+
return # Can't check, skip silently
336+
337+
compile_major = compile_version // 1000
338+
runtime_major = runtime_version // 1000
339+
340+
if compile_major > runtime_major:
341+
compile_minor = (compile_version % 1000) // 10
342+
runtime_minor = (runtime_version % 1000) // 10
343+
warnings.warn(
344+
f"cuda-python was built against CUDA {compile_major}.{compile_minor}, "
345+
f"but the installed driver only supports CUDA {runtime_major}.{runtime_minor}. "
346+
f"Some features may not work correctly. Consider updating your NVIDIA driver. "
347+
f"Set CUDA_PYTHON_DISABLE_VERSION_CHECK=1 to suppress this warning.",
348+
UserWarning,
349+
stacklevel=3,
350+
)
351+
352+
291353
# Track whether we've already warned about fork method
292354
_fork_warning_checked = False
293355

cuda_core/tests/test_cuda_utils.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
#
33
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
44

5+
import os
6+
import warnings
7+
from unittest import mock
8+
59
import pytest
610
from cuda.bindings import driver, runtime
711
from cuda.core._utils import cuda_utils
@@ -75,3 +79,108 @@ def test_check_runtime_error():
7579
assert enum_name in msg
7680
# Smoke test: We don't want most to be unexpected.
7781
assert num_unexpected < len(driver.CUresult) * 0.5
82+
83+
84+
class TestVersionCompatibilityCheck:
85+
"""Tests for check_cuda_version_compatibility function."""
86+
87+
def setup_method(self):
88+
"""Reset the version compatibility check flag before each test."""
89+
cuda_utils.reset_version_compatibility_check()
90+
91+
def teardown_method(self):
92+
"""Reset the version compatibility check flag after each test."""
93+
cuda_utils.reset_version_compatibility_check()
94+
95+
def test_no_warning_when_driver_newer(self):
96+
"""No warning should be issued when driver version >= compile version."""
97+
# Mock compile version 12.9 and driver version 13.0
98+
with (
99+
mock.patch.object(driver, "CUDA_VERSION", 12090),
100+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 13000)),
101+
warnings.catch_warnings(record=True) as w,
102+
):
103+
warnings.simplefilter("always")
104+
cuda_utils.check_cuda_version_compatibility()
105+
assert len(w) == 0
106+
107+
def test_no_warning_when_same_major_version(self):
108+
"""No warning should be issued when major versions match."""
109+
# Mock compile version 12.9 and driver version 12.8
110+
with (
111+
mock.patch.object(driver, "CUDA_VERSION", 12090),
112+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
113+
warnings.catch_warnings(record=True) as w,
114+
):
115+
warnings.simplefilter("always")
116+
cuda_utils.check_cuda_version_compatibility()
117+
assert len(w) == 0
118+
119+
def test_warning_when_compile_major_newer(self):
120+
"""Warning should be issued when compile major version > driver major version."""
121+
# Mock compile version 13.0 and driver version 12.8
122+
with (
123+
mock.patch.object(driver, "CUDA_VERSION", 13000),
124+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
125+
warnings.catch_warnings(record=True) as w,
126+
):
127+
warnings.simplefilter("always")
128+
cuda_utils.check_cuda_version_compatibility()
129+
assert len(w) == 1
130+
assert issubclass(w[0].category, UserWarning)
131+
assert "cuda-python was built against CUDA 13.0" in str(w[0].message)
132+
assert "driver only supports CUDA 12.8" in str(w[0].message)
133+
134+
def test_warning_only_issued_once(self):
135+
"""Warning should only be issued once per process."""
136+
with (
137+
mock.patch.object(driver, "CUDA_VERSION", 13000),
138+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
139+
warnings.catch_warnings(record=True) as w,
140+
):
141+
warnings.simplefilter("always")
142+
cuda_utils.check_cuda_version_compatibility()
143+
cuda_utils.check_cuda_version_compatibility()
144+
cuda_utils.check_cuda_version_compatibility()
145+
# Only one warning despite multiple calls
146+
assert len(w) == 1
147+
148+
def test_warning_suppressed_by_env_var(self):
149+
"""Warning should be suppressed when CUDA_PYTHON_DISABLE_VERSION_CHECK is set."""
150+
with (
151+
mock.patch.object(driver, "CUDA_VERSION", 13000),
152+
mock.patch.object(driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_SUCCESS, 12080)),
153+
mock.patch.dict(os.environ, {"CUDA_PYTHON_DISABLE_VERSION_CHECK": "1"}),
154+
warnings.catch_warnings(record=True) as w,
155+
):
156+
warnings.simplefilter("always")
157+
cuda_utils.check_cuda_version_compatibility()
158+
assert len(w) == 0
159+
160+
def test_silent_when_driver_version_fails(self):
161+
"""Should silently skip if cuDriverGetVersion fails."""
162+
with (
163+
mock.patch.object(driver, "CUDA_VERSION", 13000),
164+
mock.patch.object(
165+
driver, "cuDriverGetVersion", return_value=(driver.CUresult.CUDA_ERROR_NOT_INITIALIZED, 0)
166+
),
167+
warnings.catch_warnings(record=True) as w,
168+
):
169+
warnings.simplefilter("always")
170+
cuda_utils.check_cuda_version_compatibility()
171+
assert len(w) == 0
172+
173+
def test_silent_when_cuda_version_not_available(self):
174+
"""Should silently skip if CUDA_VERSION attribute is not available."""
175+
# Simulate older cuda-bindings without CUDA_VERSION
176+
with mock.patch.object(driver, "CUDA_VERSION", None):
177+
# Make accessing CUDA_VERSION raise AttributeError
178+
original = driver.CUDA_VERSION
179+
del driver.CUDA_VERSION
180+
try:
181+
with warnings.catch_warnings(record=True) as w:
182+
warnings.simplefilter("always")
183+
cuda_utils.check_cuda_version_compatibility()
184+
assert len(w) == 0
185+
finally:
186+
driver.CUDA_VERSION = original

0 commit comments

Comments
 (0)