Skip to content

Commit 0957f91

Browse files
committed
fix: derive CUDA_CORE_BUILD_MAJOR from headers instead of bindings version
Fixes build failures when cuda-bindings reports major version 13 but the CUDA headers are version 12, which caused missing-enum errors for CU_MEM_LOCATION_TYPE_NONE and CU_MEM_ALLOCATION_TYPE_MANAGED.

The new _get_cuda_core_build_major_version() function consults the following sources, in priority order:
1. Explicit CUDA_CORE_BUILD_MAJOR env var (CI override).
2. CUDA_VERSION from the cuda.h headers (matches the compile target).
3. nvidia-smi driver-reported version (fallback).
4. cuda-bindings major version (last resort).

Adds tests for the version detection logic in test_build_hooks.py.
1 parent f83eff2 commit 0957f91

File tree

2 files changed

+152
-1
lines changed

2 files changed

+152
-1
lines changed

cuda_core/build_hooks.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,42 @@ def _get_proper_cuda_bindings_major_version() -> str:
5454
return "13"
5555

5656

57+
@functools.cache
def _get_cuda_core_build_major_version() -> str:
    """Return the CUDA major version cuda.core should be built for, as a string.

    The sources below are consulted in order and the first one that yields a
    value wins:

    1. The ``CUDA_CORE_BUILD_MAJOR`` environment variable, returned verbatim.
    2. The ``CUDA_VERSION`` macro in ``<root>/include/cuda.h`` for every root
       listed (``os.pathsep``-separated) in ``CUDA_PATH``, or in ``CUDA_HOME``
       when ``CUDA_PATH`` is not set.
    3. The "CUDA Version" printed by ``nvidia-smi``.
    4. ``_get_proper_cuda_bindings_major_version()``.

    The result is memoized with ``functools.cache``; call ``.cache_clear()``
    on this function after changing the environment to force re-detection.
    """
    # Explicit override, e.g. in CI.
    cuda_major = os.environ.get("CUDA_CORE_BUILD_MAJOR")
    if cuda_major is not None:
        return cuda_major

    # Try to derive from the CUDA headers (preferred; matches what we compile against).
    cuda_path = os.environ.get("CUDA_PATH", os.environ.get("CUDA_HOME", None))
    if cuda_path:
        for root in cuda_path.split(os.pathsep):
            cuda_h = os.path.join(root, "include", "cuda.h")
            try:
                # errors="replace": only the ASCII "#define CUDA_VERSION" line matters,
                # so a stray non-UTF-8 byte elsewhere in the header (e.g. in a comment)
                # must not abort the build with UnicodeDecodeError.
                with open(cuda_h, encoding="utf-8", errors="replace") as f:
                    for line in f:
                        m = re.match(r"^#\s*define\s+CUDA_VERSION\s+(\d+)\s*$", line)
                        if m:
                            # CUDA_VERSION is e.g. 12020 for 12.2.
                            return str(int(m.group(1)) // 1000)
            except OSError:
                # Missing or unreadable header under this root; try the next one.
                continue

    # Fall back to driver-reported CUDA version if available.
    try:
        out = subprocess.run(["nvidia-smi"], env=os.environ, capture_output=True, check=True)  # noqa: S603, S607
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing binary (FileNotFoundError) as well as other
        # launch failures such as PermissionError; this probe is best-effort.
        pass
    else:
        m = re.search(r"CUDA Version:\s*([\d\.]+)", out.stdout.decode(errors="replace"))
        if m:
            return m.group(1).split(".")[0]

    # Last resort: align to cuda-bindings major.
    return _get_proper_cuda_bindings_major_version()
91+
92+
5793
# used later by setup()
5894
_extensions = None
5995

@@ -104,7 +140,7 @@ def get_cuda_paths():
104140
)
105141

106142
nthreads = int(os.environ.get("CUDA_PYTHON_PARALLEL_LEVEL", os.cpu_count() // 2))
107-
compile_time_env = {"CUDA_CORE_BUILD_MAJOR": int(_get_proper_cuda_bindings_major_version())}
143+
compile_time_env = {"CUDA_CORE_BUILD_MAJOR": int(_get_cuda_core_build_major_version())}
108144
compiler_directives = {"embedsignature": True, "warn.deprecated.IF": False, "freethreading_compatible": True}
109145
if COMPILE_FOR_COVERAGE:
110146
compiler_directives["linetrace"] = True
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for build_hooks.py build infrastructure.
5+
6+
These tests verify the CUDA version detection logic used during builds,
7+
particularly the _get_cuda_core_build_major_version() function which
8+
derives the CUDA major version from headers (preferred) rather than
9+
cuda-bindings version to avoid mismatches.
10+
11+
Note: These tests do NOT require cuda.core to be built/installed since they
12+
test build-time infrastructure. Run with --noconftest to avoid loading
13+
conftest.py which imports cuda.core modules:
14+
15+
pytest tests/test_build_hooks.py -v --noconftest
16+
"""
17+
18+
import importlib.util
19+
import os
20+
import tempfile
21+
from pathlib import Path
22+
from unittest import mock
23+
24+
import pytest
25+
26+
27+
def _load_build_hooks():
    """Import ``build_hooks.py`` directly from its source file and return the module.

    ``build_hooks.py`` is a PEP 517 build backend rather than an installed
    module, so it is loaded through ``importlib`` by file path. This keeps
    ``sys.path`` untouched: adding the ``cuda_core/`` directory to it could let
    the ``cuda/core/`` source tree shadow the installed package.
    """
    source_file = Path(__file__).parent.parent / "build_hooks.py"
    module_spec = importlib.util.spec_from_file_location("build_hooks", source_file)
    loaded = importlib.util.module_from_spec(module_spec)
    # Execute the module body so its functions are defined on ``loaded``.
    module_spec.loader.exec_module(loaded)
    return loaded
40+
41+
42+
# Load the module once at import time
43+
build_hooks = _load_build_hooks()
44+
45+
46+
def _check_version_detection(
    cuda_version, expected_major, *, use_cuda_path=True, use_cuda_home=False, cuda_core_build_major=None
):
    """Run the version detection against a throwaway ``cuda.h`` and check the result.

    A temporary directory containing ``include/cuda.h`` with a single
    ``#define CUDA_VERSION`` line is created, the environment is replaced by
    only the variables requested below, and the detected major version is
    compared with ``expected_major``.

    Args:
        cuda_version: CUDA_VERSION to write in mock cuda.h (e.g., 12080)
        expected_major: Expected return value (e.g., "12")
        use_cuda_path: If True, set CUDA_PATH to the mock headers directory
        use_cuda_home: If True, set CUDA_HOME to the mock headers directory
        cuda_core_build_major: If set, override with this CUDA_CORE_BUILD_MAJOR env var
    """
    with tempfile.TemporaryDirectory() as toolkit_root:
        header_dir = Path(toolkit_root) / "include"
        header_dir.mkdir()
        (header_dir / "cuda.h").write_text(f"#define CUDA_VERSION {cuda_version}\n")

        # The function under test is memoized; drop any result from an earlier test.
        build_hooks._get_cuda_core_build_major_version.cache_clear()

        # Only include the variables this scenario asks for.
        env = {}
        if cuda_core_build_major is not None:
            env["CUDA_CORE_BUILD_MAJOR"] = cuda_core_build_major
        if use_cuda_path:
            env["CUDA_PATH"] = toolkit_root
        if use_cuda_home:
            env["CUDA_HOME"] = toolkit_root

        # clear=True replaces the whole environment with ``env`` for the call.
        with mock.patch.dict(os.environ, env, clear=True):
            detected = build_hooks._get_cuda_core_build_major_version()
            assert detected == expected_major
79+
80+
81+
class TestGetCudaCoreBuildMajorVersion:
    """Exercise _get_cuda_core_build_major_version() across its detection sources."""

    @pytest.mark.parametrize("version", ["11", "12", "13", "14"])
    def test_env_var_override(self, version):
        """The CUDA_CORE_BUILD_MAJOR value is returned as given."""
        detect = build_hooks._get_cuda_core_build_major_version
        # The function is memoized, so reset it before each parametrized run.
        detect.cache_clear()
        override = mock.patch.dict(os.environ, {"CUDA_CORE_BUILD_MAJOR": version}, clear=False)
        with override:
            assert detect() == version

    @pytest.mark.parametrize(
        ("cuda_version", "expected_major"),
        [
            pytest.param(11000, "11", id="11.0"),
            pytest.param(11080, "11", id="11.8"),
            pytest.param(12000, "12", id="12.0"),
            pytest.param(12020, "12", id="12.2"),
            pytest.param(12080, "12", id="12.8"),
            pytest.param(13000, "13", id="13.0"),
            pytest.param(13010, "13", id="13.1"),
        ],
    )
    def test_cuda_headers_parsing(self, cuda_version, expected_major):
        """The major version is derived from CUDA_VERSION in a mock cuda.h."""
        _check_version_detection(cuda_version, expected_major)

    def test_cuda_home_fallback(self):
        """With CUDA_PATH unset, the header is located through CUDA_HOME."""
        _check_version_detection(12050, "12", use_cuda_path=False, use_cuda_home=True)

    def test_env_var_takes_priority_over_headers(self):
        """The env var override wins even though a cuda.h with another version exists."""
        _check_version_detection(12080, "11", cuda_core_build_major="11")

0 commit comments

Comments
 (0)