diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 3ffe1c7b1d..e1ef36e616 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -235,31 +235,6 @@ def _get_sys_extension() -> str: raise RuntimeError(f"Unsupported operating system ({system})") -@functools.lru_cache(maxsize=None) -def _load_nvidia_cuda_library(lib_name: str): - """ - Attempts to load shared object file installed via pip. - - `lib_name`: Name of package as found in the `nvidia` dir in python environment. - """ - - so_paths = glob.glob( - os.path.join( - sysconfig.get_path("purelib"), - f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]", - ) - ) - - path_found = len(so_paths) > 0 - ctypes_handles = [] - - if path_found: - for so_path in so_paths: - ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)) - - return path_found, ctypes_handles - - @functools.lru_cache(maxsize=None) def _nvidia_cudart_include_dir() -> str: """Returns the include directory for cuda_runtime.h if exists in python environment.""" @@ -279,101 +254,87 @@ def _nvidia_cudart_include_dir() -> str: @functools.lru_cache(maxsize=None) -def _load_cudnn(): - """Load CUDNN shared library.""" +def _load_cuda_library_from_python(lib_name: str): + """ + Attempts to load shared object file installed via python packages. - # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set - cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH") - if cudnn_home: - libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + `lib_name`: Name of package as found in the `nvidia` dir in python environment.
+ """ - # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" - libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + ext = _get_sys_extension() + nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia") - # Attempt to locate cuDNN in Python dist-packages - found, handle = _load_nvidia_cuda_library("cudnn") - if found: - return handle + # PyPI packages provided by nvidia libs exist + # in 3 possible directories inside `nvidia`. + if os.path.isdir(os.path.join(nvidia_dir, "cu13")): + so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib*{ext}.*[0-9]")) + elif os.path.isdir(os.path.join(nvidia_dir, lib_name)): + so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]")) + else: + so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]")) - # Attempt to locate libcudnn via ldconfig - libs = subprocess.check_output(["ldconfig", "-p"]) - libs = libs.decode("utf-8").split("\n") - sos = [] - for lib in libs: - if "libcudnn" in lib and "=>" in lib: - sos.append(lib.split(">")[1].strip()) - if sos: - return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + path_found = len(so_paths) > 0 + ctypes_handles = [] - # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise - return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + if path_found: + for so_path in so_paths: + ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)) + + return path_found, ctypes_handles @functools.lru_cache(maxsize=None) -def _load_nvrtc(): - """Load NVRTC shared library.""" - # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or
"/usr/local/cuda" - libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True) - libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs)) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - - # Attempt to locate NVRTC in Python dist-packages - found, handle = _load_nvidia_cuda_library("cuda_nvrtc") - if found: - return handle +def _load_cuda_library_from_system(lib_name: str): + """ + Attempts to load shared object file installed via system/cuda-toolkit. + + `lib_name`: Name of library to load without extension or `lib` prefix. + """ - # Attempt to locate NVRTC via ldconfig - libs = subprocess.check_output(["ldconfig", "-p"]) - libs = libs.decode("utf-8").split("\n") - sos = [] - for lib in libs: - if "libnvrtc" in lib and "=>" in lib: - sos.append(lib.split(">")[1].strip()) - if sos: - return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + # Where to look for the shared lib in decreasing order of preference. + paths = ( + os.environ.get(f"{lib_name.upper()}_HOME"), + os.environ.get(f"{lib_name.upper()}_PATH"), + os.environ.get("CUDA_HOME"), + os.environ.get("CUDA_PATH"), + "/usr/local/cuda", + ) - # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise - return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + for path in paths: + if not path: # Unset or empty; "" would glob from the filesystem root. + continue + libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + # Fall back to the dynamic loader's default search (e.g. LD_LIBRARY_PATH).
+ try: + _lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + return True, _lib_handle + except OSError: + return False, None @functools.lru_cache(maxsize=None) -def _load_curand(): - """Load cuRAND shared library.""" - # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda - cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" - libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True) - libs = list(filter(lambda x: not ("stub" in x), libs)) - libs.sort(reverse=True, key=os.path.basename) - if libs: - return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) - - # Attempt to locate cuRAND in Python dist-packages - found, handle = _load_nvidia_cuda_library("curand") +def _load_cuda_library(lib_name: str): + """ + Load given shared library. + Prioritize loading from system/toolkit + before checking python packages. + """ + + # Attempt to locate library in system. + found, handle = _load_cuda_library_from_system(lib_name) if found: return handle - # Attempt to locate cuRAND via ldconfig - libs = subprocess.check_output(["ldconfig", "-p"]) - libs = libs.decode("utf-8").split("\n") - sos = [] - for lib in libs: - if "libcurand" in lib and "=>" in lib: - sos.append(lib.split(">")[1].strip()) - if sos: - return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + # Attempt to locate library in Python dist-packages.
+ found, handle = _load_cuda_library_from_python(lib_name) + if found: + return handle - # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise - return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + raise RuntimeError(f"{lib_name} shared object not found.") @functools.lru_cache(maxsize=None) @@ -384,11 +345,16 @@ def _load_core_library(): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): sanity_checks_for_pypi_installation() - _CUDNN_LIB_CTYPES = _load_cudnn() - _NVRTC_LIB_CTYPES = _load_nvrtc() - _CURAND_LIB_CTYPES = _load_curand() - _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") - _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") + + # `_load_cuda_library` is used for packages that must be loaded + # during runtime. Both system and pypi packages are searched + # and an error is thrown if not found. + _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn") + _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc") + _CURAND_LIB_CTYPES = _load_cuda_library("curand") + _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas") + _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cuda_runtime") + + _TE_LIB_CTYPES = _load_core_library() # Needed to find the correct headers for NVRTC kernels.