diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 39fc19e2..dd4adedb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -272,3 +272,22 @@ jobs: - name: Check no test always skipped run: | python continuous_integration/check_no_test_skipped.py test_results + + typecheck: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + + - uses: astral-sh/setup-uv@v6 + + - name: mypy + run: uvx mypy --no-incremental --cache-dir=/dev/null --python-version=3.9 threadpoolctl + + - name: pyright + run: uvx pyright --pythonversion=3.9 threadpoolctl + + - name: basedpyright + run: uvx pyright --pythonversion=3.9 threadpoolctl + + - name: ty check + run: uvx ty check --python-version=3.9 threadpoolctl diff --git a/pyproject.toml b/pyproject.toml index c0ea1696..92163f50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,26 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", ] [tool.black] line-length = 88 target_version = ['py39', 'py310', 'py311', 'py312', 'py313'] preview = true + +[tool.mypy] +exclude = ["benchmarks", "continuous_integration", "tests"] +enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] +strict = true +warn_unreachable = true +local_partial_types = true +allow_redefinition_new = true + +[tool.pyright] +exclude = ["benchmarks", "continuous_integration", "tests"] +ignore = [".venv"] +stubPath = "." +typeCheckingMode = "strict" +reportPrivateUsage = false +reportConstantRedefinition = false diff --git a/threadpoolctl.py b/threadpoolctl/__init__.py similarity index 79% rename from threadpoolctl.py rename to threadpoolctl/__init__.py index e6ac58d8..17ec9313 100644 --- a/threadpoolctl.py +++ b/threadpoolctl/__init__.py @@ -11,20 +11,42 @@ # adapted from code by Intel developer @anton-malakhov available at # https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation) # and also published under the BSD 3-Clause license +import ctypes +import itertools import os import re import sys -import ctypes -import itertools import textwrap -from typing import final import warnings -from ctypes.util import find_library from abc import ABC, abstractmethod -from functools import lru_cache from contextlib import ContextDecorator +from ctypes.util import find_library +from functools import lru_cache +from typing import TYPE_CHECKING, Any, ClassVar, Final, Literal, Union, cast, final + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + from types import TracebackType + + from _typeshed import IdentityFunction + from typing_extensions import Never, Self, TypeAlias, TypedDict + + _ThreadingBackend: TypeAlias = Literal["openmp", "pthreads", "disabled", "unknown"] + _ToLimits: TypeAlias = Union[ + int, + dict[str, int], + list[dict[str, Any]], + "ThreadpoolController", + Literal["sequential_blas_under_openmp"], + None, + ] + + class _BLASParamsDict(TypedDict): + limits: Union[int, None] + user_api: Union[str, None] -__version__ = "3.7.0.dev0" + +__version__: Final = "3.7.0.dev0" __all__ = [ "threadpool_limits", "threadpool_info", @@ -63,10 +85,10 @@ class _dl_phdr_info(ctypes.Structure): # The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows. -try: - _RTLD_NOLOAD = os.RTLD_NOLOAD -except AttributeError: +if sys.platform == "win32": _RTLD_NOLOAD = ctypes.DEFAULT_MODE +else: + _RTLD_NOLOAD = os.RTLD_NOLOAD class LibController(ABC): @@ -105,8 +127,27 @@ class LibController(ABC): must be set as attributes in the `set_additional_attributes` method. """ + user_api: ClassVar[str] # abstract + internal_api: ClassVar[str] # abstract + filename_prefixes: ClassVar[tuple[str, ...]] # abstract + check_symbols: ClassVar[tuple[str, ...]] # abstract + + parent: Final["LibController | ThreadpoolController | None"] + prefix: Final[Union[str, None]] + filepath: Final[Union[str, None]] + dynlib: Final[ctypes.CDLL] + _symbol_prefix: Final[str] + _symbol_suffix: Final[str] + version: Final[Union[str, None]] + @final - def __init__(self, *, filepath=None, prefix=None, parent=None): + def __init__( + self, + *, + filepath: Union[str, None] = None, + prefix: Union[str, None] = None, + parent: "LibController | ThreadpoolController | None" = None, + ) -> None: """This is not meant to be overriden by subclasses.""" self.parent = parent self.prefix = prefix @@ -116,7 +157,7 @@ def __init__(self, *, filepath=None, prefix=None, parent=None): self.version = self.get_version() self.set_additional_attributes() - def info(self): + def info(self) -> dict[str, Any]: """Return relevant info wrapped in a dict""" hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix") return { @@ -126,11 +167,11 @@ def info(self): **{k: v for k, v in vars(self).items() if k not in hidden_attrs}, } - def set_additional_attributes(self): + def set_additional_attributes(self) -> None: """Set additional attributes meant to be exposed in the info dict""" @property - def num_threads(self): + def num_threads(self) -> Union[int, None]: """Exposes the current thread limit as a dynamic property This is not meant to be used or overriden by subclasses. @@ -138,22 +179,22 @@ def num_threads(self): return self.get_num_threads() @abstractmethod - def get_num_threads(self): + def get_num_threads(self) -> Union[int, None]: """Return the maximum number of threads available to use""" @abstractmethod - def set_num_threads(self, num_threads): + def set_num_threads(self, num_threads: int) -> None: """Set the maximum number of threads to use""" @abstractmethod - def get_version(self): + def get_version(self) -> Union[str, None]: """Return the version of the shared library""" - def _find_affixes(self): + def _find_affixes(self) -> tuple[str, str]: """Return the affixes for the symbols of the shared library""" return "", "" - def _get_symbol(self, name): + def _get_symbol(self, name: str) -> Union[Any, None]: """Return the symbol of the shared library accounding for the affixes""" return getattr( self.dynlib, f"{self._symbol_prefix}{name}{self._symbol_suffix}", None @@ -163,55 +204,70 @@ def _get_symbol(self, name): class OpenBLASController(LibController): """Controller class for OpenBLAS""" - user_api = "blas" - internal_api = "openblas" - filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas") + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "openblas" + filename_prefixes: ClassVar[tuple[str, ...]] = ( + "libopenblas", + "libblas", + "libscipy_openblas", + ) - _symbol_prefixes = ("", "scipy_") - _symbol_suffixes = ("", "64_", "_64") + _symbol_prefixes: ClassVar[tuple[str, ...]] = ("", "scipy_") + _symbol_suffixes: ClassVar[tuple[str, ...]] = ("", "64_", "_64") # All variations of "openblas_get_num_threads", accounting for the affixes - check_symbols = tuple( + check_symbols: ClassVar[tuple[str, ...]] = tuple( f"{prefix}openblas_get_num_threads{suffix}" for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes) ) - def _find_affixes(self): + threading_layer: "_ThreadingBackend" + architecture: Union[str, None] + + def _find_affixes(self) -> tuple[str, str]: for prefix, suffix in itertools.product( self._symbol_prefixes, self._symbol_suffixes ): if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"): return prefix, suffix - def set_additional_attributes(self): + return "", "" # should never happen + + def set_additional_attributes(self) -> None: self.threading_layer = self._get_threading_layer() self.architecture = self._get_architecture() - def get_num_threads(self): - get_num_threads_func = self._get_symbol("openblas_get_num_threads") + def get_num_threads(self) -> Union[int, None]: + get_num_threads_func: "Callable[[], int] | None" = self._get_symbol( + "openblas_get_num_threads" + ) if get_num_threads_func is not None: return get_num_threads_func() return None - def set_num_threads(self, num_threads): - set_num_threads_func = self._get_symbol("openblas_set_num_threads") + def set_num_threads(self, num_threads: int) -> None: + set_num_threads_func: "Callable[[int], None] | None" = self._get_symbol( + "openblas_set_num_threads" + ) if set_num_threads_func is not None: return set_num_threads_func(num_threads) return None - def get_version(self): + def get_version(self) -> Union[str, None]: # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS # did not expose its version before that. - get_version_func = self._get_symbol("openblas_get_config") + get_version_func: "Callable[[], bytes] | None" = self._get_symbol( + "openblas_get_config" + ) if get_version_func is not None: - get_version_func.restype = ctypes.c_char_p + get_version_func.restype = ctypes.c_char_p # type: ignore[attr-defined] config = get_version_func().split() if config[0] == b"OpenBLAS": return config[1].decode("utf-8") return None return None - def _get_threading_layer(self): + def _get_threading_layer(self) -> "_ThreadingBackend": """Return the threading layer of OpenBLAS""" get_threading_layer_func = self._get_symbol("openblas_get_parallel") if get_threading_layer_func is not None: @@ -223,11 +279,13 @@ def _get_threading_layer(self): return "disabled" return "unknown" - def _get_architecture(self): + def _get_architecture(self) -> Union[str, None]: """Return the architecture detected by OpenBLAS""" - get_architecture_func = self._get_symbol("openblas_get_corename") + get_architecture_func: "Callable[[], bytes] | None" = self._get_symbol( + "openblas_get_corename" + ) if get_architecture_func is not None: - get_architecture_func.restype = ctypes.c_char_p + get_architecture_func.restype = ctypes.c_char_p # type: ignore[attr-defined] return get_architecture_func().decode("utf-8") return None @@ -235,10 +293,10 @@ def _get_architecture(self): class BLISController(LibController): """Controller class for BLIS""" - user_api = "blas" - internal_api = "blis" - filename_prefixes = ("libblis", "libblas") - check_symbols = ( + user_api: ClassVar[str] = "blas" + internal_api: ClassVar[str] = "blis" + filename_prefixes: ClassVar[tuple[str, ...]] = ("libblis", "libblas") + check_symbols: ClassVar[tuple[str, ...]] = ( "bli_thread_get_num_threads", "bli_thread_set_num_threads", "bli_info_get_version_str", @@ -248,32 +306,34 @@ class BLISController(LibController): "bli_arch_string", ) - def set_additional_attributes(self): + def set_additional_attributes(self) -> None: self.threading_layer = self._get_threading_layer() self.architecture = self._get_architecture() - def get_num_threads(self): + def get_num_threads(self) -> Union[int, None]: get_func = getattr(self.dynlib, "bli_thread_get_num_threads", lambda: None) num_threads = get_func() # by default BLIS is single-threaded and get_num_threads # returns -1. We map it to 1 for consistency with other libraries. return 1 if num_threads == -1 else num_threads - def set_num_threads(self, num_threads): - set_func = getattr( + def set_num_threads(self, num_threads: int) -> None: + set_func: "Callable[[int], None]" = getattr( self.dynlib, "bli_thread_set_num_threads", lambda num_threads: None ) return set_func(num_threads) - def get_version(self): - get_version_ = getattr(self.dynlib, "bli_info_get_version_str", None) + def get_version(self) -> Union[str, None]: + get_version_: "Callable[[], bytes] | None" = getattr( + self.dynlib, "bli_info_get_version_str", None + ) if get_version_ is None: return None - get_version_.restype = ctypes.c_char_p + get_version_.restype = ctypes.c_char_p # type: ignore[attr-defined] return get_version_().decode("utf-8") - def _get_threading_layer(self): + def _get_threading_layer(self) -> Literal["openmp", "pthreads", "disabled"]: """Return the threading layer of BLIS""" if getattr(self.dynlib, "bli_info_get_enable_openmp", lambda: False)(): return "openmp" @@ -281,17 +341,21 @@ def _get_threading_layer(self): return "pthreads" return "disabled" - def _get_architecture(self): + def _get_architecture(self) -> Union[str, None]: """Return the architecture detected by BLIS""" - bli_arch_query_id = getattr(self.dynlib, "bli_arch_query_id", None) - bli_arch_string = getattr(self.dynlib, "bli_arch_string", None) + bli_arch_query_id: "Callable[[], bytes] | None" = getattr( + self.dynlib, "bli_arch_query_id", None + ) + bli_arch_string: "Callable[[bytes], bytes] | None" = getattr( + self.dynlib, "bli_arch_string", None + ) if bli_arch_query_id is None or bli_arch_string is None: return None # the true restype should be BLIS' arch_t (enum) but int should work # for us: - bli_arch_query_id.restype = ctypes.c_int - bli_arch_string.restype = ctypes.c_char_p + bli_arch_query_id.restype = ctypes.c_int # type: ignore[attr-defined] + bli_arch_string.restype = ctypes.c_char_p # type: ignore[attr-defined] return bli_arch_string(bli_arch_query_id()).decode("utf-8") @@ -310,15 +374,17 @@ class FlexiBLASController(LibController): "flexiblas_current_backend", ) + available_backends: list[str] + @property - def loaded_backends(self): + def loaded_backends(self) -> Union[list[str], Any]: return self._get_backend_list(loaded=True) @property - def current_backend(self): + def current_backend(self) -> Union[str, Any]: return self._get_current_backend() - def info(self): + def info(self) -> dict[str, Any]: """Return relevant info wrapped in a dict""" # We override the info method because the loaded and current backends # are dynamic properties @@ -328,23 +394,23 @@ def info(self): return exposed_attrs - def set_additional_attributes(self): + def set_additional_attributes(self) -> None: self.available_backends = self._get_backend_list(loaded=False) - def get_num_threads(self): + def get_num_threads(self) -> Union[int, None]: get_func = getattr(self.dynlib, "flexiblas_get_num_threads", lambda: None) num_threads = get_func() # by default BLIS is single-threaded and get_num_threads # returns -1. We map it to 1 for consistency with other libraries. return 1 if num_threads == -1 else num_threads - def set_num_threads(self, num_threads): - set_func = getattr( + def set_num_threads(self, num_threads: int) -> None: + set_func: "Callable[[int], None]" = getattr( self.dynlib, "flexiblas_set_num_threads", lambda num_threads: None ) return set_func(num_threads) - def get_version(self): + def get_version(self) -> Union[str, None]: get_version_ = getattr(self.dynlib, "flexiblas_get_version", None) if get_version_ is None: return None @@ -355,7 +421,7 @@ def get_version(self): get_version_(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch)) return f"{major.value}.{minor.value}.{patch.value}" - def _get_backend_list(self, loaded=False): + def _get_backend_list(self, loaded: bool = False) -> list[str]: """Return the list of available backends for FlexiBLAS. If loaded is False, return the list of available backends from the FlexiBLAS @@ -364,11 +430,11 @@ def _get_backend_list(self, loaded=False): func_name = f"flexiblas_list{'_loaded' if loaded else ''}" get_backend_list_ = getattr(self.dynlib, func_name, None) if get_backend_list_ is None: - return None + return [] n_backends = get_backend_list_(None, 0, 0) - backends = [] + backends: list[str] = [] for i in range(n_backends): backend_name = ctypes.create_string_buffer(1024) get_backend_list_(backend_name, 1024, i) @@ -378,7 +444,7 @@ def _get_backend_list(self, loaded=False): backends.append(backend_name.value.decode("utf-8")) return backends - def _get_current_backend(self): + def _get_current_backend(self) -> Union[str, None]: """Return the backend of FlexiBLAS""" get_backend_ = getattr(self.dynlib, "flexiblas_current_backend", None) if get_backend_ is None: @@ -388,7 +454,7 @@ def _get_current_backend(self): get_backend_(backend, ctypes.sizeof(backend)) return backend.value.decode("utf-8") - def switch_backend(self, backend): + def switch_backend(self, backend: str) -> None: """Switch the backend of FlexiBLAS Parameters @@ -398,6 +464,7 @@ def switch_backend(self, backend): the backend is not already loaded, it will be loaded first. """ if backend not in self.loaded_backends: + load_func: "Callable[[bytes], int]" if backend in self.available_backends: load_func = getattr(self.dynlib, "flexiblas_load_backend", lambda _: -1) else: # assume backend is a path to a shared library @@ -414,9 +481,11 @@ def switch_backend(self, backend): # Trigger a new search of loaded shared libraries since loading a new # backend caused a dlopen. - self.parent._load_libraries() + cast("ThreadpoolController", self.parent)._load_libraries() - switch_func = getattr(self.dynlib, "flexiblas_switch", lambda _: -1) + switch_func: "Callable[[int], int]" = getattr( + self.dynlib, "flexiblas_switch", lambda _: -1 + ) idx = self.loaded_backends.index(backend) res = switch_func(idx) if res == -1: @@ -436,18 +505,20 @@ class MKLController(LibController): "MKL_Set_Threading_Layer", ) - def set_additional_attributes(self): + def set_additional_attributes(self) -> None: self.threading_layer = self._get_threading_layer() - def get_num_threads(self): + def get_num_threads(self) -> Union[int, None]: get_func = getattr(self.dynlib, "MKL_Get_Max_Threads", lambda: None) return get_func() - def set_num_threads(self, num_threads): - set_func = getattr(self.dynlib, "MKL_Set_Num_Threads", lambda num_threads: None) + def set_num_threads(self, num_threads: int) -> None: + set_func: "Callable[[int], None]" = getattr( + self.dynlib, "MKL_Set_Num_Threads", lambda num_threads: None + ) return set_func(num_threads) - def get_version(self): + def get_version(self) -> Union[str, None]: if not hasattr(self.dynlib, "MKL_Get_Version_String"): return None @@ -460,12 +531,12 @@ def get_version(self): version = group.groups()[0] return version.strip() - def _get_threading_layer(self): + def _get_threading_layer(self) -> str: """Return the threading layer of MKL""" # The function mkl_set_threading_layer returns the current threading # layer. Calling it with an invalid threading layer allows us to safely # get the threading layer - set_threading_layer = getattr( + set_threading_layer: "Callable[[int], int]" = getattr( self.dynlib, "MKL_Set_Threading_Layer", lambda layer: -1 ) layer_map = { @@ -490,22 +561,24 @@ class OpenMPController(LibController): "omp_get_num_threads", ) - def get_num_threads(self): + def get_num_threads(self) -> Union[int, None]: get_func = getattr(self.dynlib, "omp_get_max_threads", lambda: None) return get_func() - def set_num_threads(self, num_threads): - set_func = getattr(self.dynlib, "omp_set_num_threads", lambda num_threads: None) + def set_num_threads(self, num_threads: int) -> None: + set_func: "Callable[[int], None]" = getattr( + self.dynlib, "omp_set_num_threads", lambda num_threads: None + ) return set_func(num_threads) - def get_version(self): + def get_version(self) -> None: # There is no way to get the version number programmatically in OpenMP. return None # Controllers for the libraries that we'll look for in the loaded libraries. # Third party libraries can register their own controllers. -_ALL_CONTROLLERS = [ +_ALL_CONTROLLERS: list[type[LibController]] = [ OpenBLASController, BLISController, MKLController, @@ -525,7 +598,7 @@ def get_version(self): _ALL_OPENMP_LIBRARIES = OpenMPController.filename_prefixes -def register(controller): +def register(controller: type[LibController]) -> None: """Register a new controller""" _ALL_CONTROLLERS.append(controller) _ALL_USER_APIS.append(controller.user_api) @@ -533,8 +606,8 @@ def register(controller): _ALL_PREFIXES.extend(controller.filename_prefixes) -def _format_docstring(*args, **kwargs): - def decorator(o): +def _format_docstring(*args: object, **kwargs: object) -> "IdentityFunction": + def decorator(o: Any) -> Any: if o.__doc__ is not None: o.__doc__ = o.__doc__.format(*args, **kwargs) return o @@ -543,13 +616,13 @@ def decorator(o): @lru_cache(maxsize=10000) -def _realpath(filepath): +def _realpath(filepath: str) -> str: """Small caching wrapper around os.path.realpath to limit system calls""" return os.path.realpath(filepath) @_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS) -def threadpool_info(): +def threadpool_info() -> list[dict[str, Any]]: """Return the maximal number of threads for each detected library. Return a list with all the supported libraries that have been found. Each @@ -578,7 +651,15 @@ class _ThreadpoolLimiter: that it can be used as a decorator. """ - def __init__(self, controller, *, limits=None, user_api=None): + _controller: "ThreadpoolController" + + def __init__( + self, + controller: "ThreadpoolController", + *, + limits: "_ToLimits | None" = None, + user_api: Union[str, None] = None, + ) -> None: self._controller = controller self._limits, self._user_api, self._prefixes = self._check_params( limits, user_api @@ -586,20 +667,31 @@ def __init__(self, controller, *, limits=None, user_api=None): self._original_info = self._controller.info() self._set_threadpool_limits() - def __enter__(self): + def __enter__(self) -> "Self": return self - def __exit__(self, type, value, traceback): + def __exit__( + self, + type: Union[type[BaseException], None], + value: Union[BaseException, None], + traceback: "TracebackType | None", + ) -> None: self.restore_original_limits() @classmethod - def wrap(cls, controller, *, limits=None, user_api=None): + def wrap( + cls, + controller: "ThreadpoolController", + *, + limits: "_ToLimits | None" = None, + user_api: Union[str, None] = None, + ) -> "_ThreadpoolLimiterDecorator": """Return an instance of this class that can be used as a decorator""" return _ThreadpoolLimiterDecorator( controller=controller, limits=limits, user_api=user_api ) - def restore_original_limits(self): + def restore_original_limits(self) -> None: """Set the limits back to their original values""" for lib_controller, original_info in zip( self._controller.lib_controllers, self._original_info @@ -609,13 +701,13 @@ def restore_original_limits(self): # Alias of `restore_original_limits` for backward compatibility unregister = restore_original_limits - def get_original_num_threads(self): + def get_original_num_threads(self) -> dict[str, Any]: """Original num_threads from before calling threadpool_limits Return a dict `{user_api: num_threads}`. """ - num_threads = {} - warning_apis = [] + num_threads: dict[str, Union[int, None]] = {} + warning_apis: list[str] = [] for user_api in self._user_api: limits = [ @@ -623,15 +715,15 @@ def get_original_num_threads(self): for lib_info in self._original_info if lib_info["user_api"] == user_api ] - limits = set(limits) - n_limits = len(limits) + limits_set = set(limits) + n_limits = len(limits_set) if n_limits == 1: - limit = limits.pop() + limit = limits_set.pop() elif n_limits == 0: limit = None else: - limit = min(limits) + limit = min(limits_set) warning_apis.append(user_api) num_threads[user_api] = limit @@ -645,58 +737,70 @@ def get_original_num_threads(self): return num_threads - def _check_params(self, limits, user_api): + def _check_params( + self, + limits: "_ToLimits | None", + user_api: Union[list[str], str, None], + ) -> tuple[Union[dict[str, Any], None], list[str], list[str]]: """Suitable values for the _limits, _user_api and _prefixes attributes""" if isinstance(limits, str) and limits == "sequential_blas_under_openmp": - ( - limits, - user_api, - ) = self._controller._get_params_for_sequential_blas_under_openmp().values() + params = self._controller._get_params_for_sequential_blas_under_openmp() + limits, user_api = params["limits"], params["user_api"] if limits is None or isinstance(limits, int): if user_api is None: - user_api = _ALL_USER_APIS + user_api_clean = _ALL_USER_APIS elif user_api in _ALL_USER_APIS: - user_api = [user_api] + assert isinstance(user_api, str) + user_api_clean = [user_api] else: raise ValueError( f"user_api must be either in {_ALL_USER_APIS} or None. Got " f"{user_api} instead." ) - if limits is not None: - limits = {api: limits for api in user_api} + if limits is None: + limits_clean = None + else: + limits_clean = {api: limits for api in user_api_clean} + prefixes = [] else: if isinstance(limits, list): # This should be a list of dicts of library info, for # compatibility with the result from threadpool_info. - limits = { + limits_clean = { lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits } elif isinstance(limits, ThreadpoolController): # To set the limits from the library controllers of a # ThreadpoolController object. - limits = { - lib_controller.prefix: lib_controller.num_threads + limits_clean = { + cast("str", lib_controller.prefix): lib_controller.num_threads for lib_controller in limits.lib_controllers } + else: + limits_clean = limits - if not isinstance(limits, dict): + if not isinstance( + limits_clean, dict + ): # pyright: ignore[reportUnnecessaryIsInstance] raise TypeError( "limits must either be an int, a list, a dict, or " f"'sequential_blas_under_openmp'. Got {type(limits)} instead" ) + user_api_clean = cast("list[str]", user_api) + # With a dictionary, can set both specific limit for given # libraries and global limit for user_api. Fetch each separately. - prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES] - user_api = [api for api in limits if api in _ALL_USER_APIS] + prefixes = [prefix for prefix in limits_clean if prefix in _ALL_PREFIXES] + user_api = [api for api in limits_clean if api in _ALL_USER_APIS] - return limits, user_api, prefixes + return limits_clean, user_api_clean, prefixes - def _set_threadpool_limits(self): + def _set_threadpool_limits(self) -> None: """Change the maximal number of threads in selected thread pools. Return a list with all the supported libraries that have been found @@ -723,13 +827,19 @@ def _set_threadpool_limits(self): class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator): """Same as _ThreadpoolLimiter but to be used as a decorator""" - def __init__(self, controller, *, limits=None, user_api=None): + def __init__( + self, + controller: "ThreadpoolController", + *, + limits: "_ToLimits | None" = None, + user_api: Union[list[str], str, None] = None, + ) -> None: self._limits, self._user_api, self._prefixes = self._check_params( limits, user_api ) self._controller = controller - def __enter__(self): + def __enter__(self) -> "Self": # we need to set the limits here and not in the __init__ because we want the # limits to be set when calling the decorated function, not when creating the # decorator. @@ -790,11 +900,15 @@ class threadpool_limits(_ThreadpoolLimiter): - If None, this function will apply to all supported libraries. """ - def __init__(self, limits=None, user_api=None): + def __init__( + self, limits: "_ToLimits | None" = None, user_api: Union[str, None] = None + ) -> None: super().__init__(ThreadpoolController(), limits=limits, user_api=user_api) @classmethod - def wrap(cls, limits=None, user_api=None): + def wrap( # type: ignore[override] + cls, limits: "_ToLimits | None" = None, user_api: Union[str, None] = None + ) -> _ThreadpoolLimiterDecorator: return super().wrap(ThreadpoolController(), limits=limits, user_api=user_api) @@ -811,24 +925,26 @@ class ThreadpoolController: # We use a class level cache instead of an instance level cache because # it's very unlikely that a shared library will be unloaded and reloaded # during the lifetime of a program. - _system_libraries = dict() + _system_libraries: ClassVar[dict[str, ctypes.CDLL]] = {} + + lib_controllers: list[LibController] - def __init__(self): + def __init__(self) -> None: self.lib_controllers = [] self._load_libraries() self._warn_if_incompatible_openmp() @classmethod - def _from_controllers(cls, lib_controllers): + def _from_controllers(cls, lib_controllers: list[LibController]) -> "Self": new_controller = cls.__new__(cls) new_controller.lib_controllers = lib_controllers return new_controller - def info(self): + def info(self) -> list[dict[str, Any]]: """Return lib_controllers info as a list of dicts""" return [lib_controller.info() for lib_controller in self.lib_controllers] - def select(self, **kwargs): + def select(self, **kwargs: Any) -> "ThreadpoolController": """Return a ThreadpoolController containing a subset of its current library controllers @@ -855,7 +971,7 @@ def select(self, **kwargs): return ThreadpoolController._from_controllers(lib_controllers) - def _get_params_for_sequential_blas_under_openmp(self): + def _get_params_for_sequential_blas_under_openmp(self) -> "_BLASParamsDict": """Return appropriate params to use for a sequential BLAS call in an OpenMP loop This function takes into account the unexpected behavior of OpenBLAS with the @@ -872,7 +988,9 @@ def _get_params_for_sequential_blas_under_openmp(self): BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), ) - def limit(self, *, limits=None, user_api=None): + def limit( + self, *, limits: "_ToLimits" = None, user_api: Union[str, None] = None + ) -> _ThreadpoolLimiter: """Change the maximal number of threads that can be used in thread pools. This function returns an object that can be used either as a callable (the @@ -925,7 +1043,9 @@ def limit(self, *, limits=None, user_api=None): BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES), OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES), ) - def wrap(self, *, limits=None, user_api=None): + def wrap( + self, *, limits: "_ToLimits" = None, user_api: Union[str, None] = None + ) -> _ThreadpoolLimiterDecorator: """Change the maximal number of threads that can be used in thread pools. This function returns an object that can be used as a decorator. @@ -961,10 +1081,10 @@ def wrap(self, *, limits=None, user_api=None): """ return _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api) - def __len__(self): + def __len__(self) -> int: return len(self.lib_controllers) - def _load_libraries(self): + def _load_libraries(self) -> None: """Loop through loaded shared libraries and store the supported ones""" if sys.platform == "darwin": self._find_libraries_with_dyld() @@ -975,7 +1095,7 @@ def _load_libraries(self): else: self._find_libraries_with_dl_iterate_phdr() - def _find_libraries_with_dl_iterate_phdr(self): + def _find_libraries_with_dl_iterate_phdr(self) -> Union[list["Never"], int, None]: """Loop through loaded libraries and return binders on supported ones This function is expected to work on POSIX system only. @@ -995,7 +1115,9 @@ def _find_libraries_with_dl_iterate_phdr(self): # Callback function for `dl_iterate_phdr` which is called for every # library loaded in the current process until it returns 1. - def match_library_callback(info, size, data): + def match_library_callback( + info: "ctypes._Pointer[Any]", size: int, data: ctypes.c_char_p + ) -> int: # Get the path of the current library filepath = info.contents.dlpi_name if filepath: @@ -1016,7 +1138,9 @@ def match_library_callback(info, size, data): data = ctypes.c_char_p(b"") libc.dl_iterate_phdr(c_match_library_callback, data) - def _find_libraries_with_dyld(self): + return None + + def _find_libraries_with_dyld(self) -> Union[list["Never"], None]: """Loop through loaded libraries and return binders on supported ones This function is expected to work on OSX system only @@ -1033,13 +1157,15 @@ def _find_libraries_with_dyld(self): libc._dyld_get_image_name.restype = ctypes.c_char_p for i in range(n_dyld): - filepath = ctypes.string_at(libc._dyld_get_image_name(i)) - filepath = filepath.decode("utf-8") + filepath_raw = ctypes.string_at(libc._dyld_get_image_name(i)) + filepath = filepath_raw.decode("utf-8") # Store the library controller if it is supported and selected self._make_controller_from_path(filepath) - def _find_libraries_with_enum_process_module_ex(self): + return None + + def _find_libraries_with_enum_process_module_ex(self) -> None: """Loop through loaded libraries and return binders on supported ones This function is expected to work on windows system only. @@ -1089,15 +1215,15 @@ def _find_libraries_with_enum_process_module_ex(self): # Allocate a buffer for the path 10 times the size of MAX_PATH to take # into account long path names. max_path = 10 * MAX_PATH - buf = ctypes.create_unicode_buffer(max_path) + buf_wchar = ctypes.create_unicode_buffer(max_path) n_size = DWORD() for h_module in h_modules: # Get the path of the current module if not ps_api.GetModuleFileNameExW( - h_process, h_module, ctypes.byref(buf), ctypes.byref(n_size) + h_process, h_module, ctypes.byref(buf_wchar), ctypes.byref(n_size) ): raise OSError("GetModuleFileNameEx failed") - filepath = buf.value + filepath = buf_wchar.value if len(filepath) == max_path: # pragma: no cover warnings.warn( @@ -1113,7 +1239,7 @@ def _find_libraries_with_enum_process_module_ex(self): finally: kernel_32.CloseHandle(h_process) - def _find_libraries_pyodide(self): + def _find_libraries_pyodide(self) -> None: """Pyodide specific implementation for finding loaded libraries. Adapted from suggestion in https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1946696449. @@ -1124,7 +1250,9 @@ def _find_libraries_pyodide(self): details. """ try: - from pyodide_js._module import LDSO + from pyodide_js._module import LDSO # type: ignore[import-not-found] + + LDSO = cast(Any, LDSO) # pyright: ignore[reportConstantRedefinition] except ImportError: warnings.warn( "Unable to import LDSO from pyodide_js._module. This should never " @@ -1140,7 +1268,7 @@ def _find_libraries_pyodide(self): if os.path.exists(filepath): self._make_controller_from_path(filepath) - def _make_controller_from_path(self, filepath): + def _make_controller_from_path(self, filepath: str) -> None: """Store a library controller if it is supported and selected""" # Required to resolve symlinks filepath = _realpath(filepath) @@ -1198,7 +1326,9 @@ def _make_controller_from_path(self, filepath): ): self.lib_controllers.append(lib_controller) - def _check_prefix(self, library_basename, filename_prefixes): + def _check_prefix( + self, library_basename: str, filename_prefixes: "Iterable[str]" + ) -> Union[str, None]: """Return the prefix library_basename starts with Return None if none matches. @@ -1208,7 +1338,7 @@ def _check_prefix(self, library_basename, filename_prefixes): return prefix return None - def _warn_if_incompatible_openmp(self): + def _warn_if_incompatible_openmp(self) -> None: """Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded""" prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers] msg = textwrap.dedent( @@ -1226,7 +1356,7 @@ def _warn_if_incompatible_openmp(self): warnings.warn(msg, RuntimeWarning) @classmethod - def _get_libc(cls): + def _get_libc(cls) -> ctypes.CDLL: """Load the lib-C for unix systems.""" libc = cls._system_libraries.get("libc") if libc is None: @@ -1241,52 +1371,13 @@ def _get_libc(cls): return libc @classmethod - def _get_windll(cls, dll_name): + def _get_windll(cls, dll_name: str) -> ctypes.CDLL: """Load a windows DLL""" dll = cls._system_libraries.get(dll_name) - if dll is None: + # the `sys.platform` check is required to avoid typing errors on non-windows + if dll is None and sys.platform == "win32": dll = ctypes.WinDLL(f"{dll_name}.dll") cls._system_libraries[dll_name] = dll + else: + assert dll is not None return dll - - -def _main(): - """Commandline interface to display thread-pool information and exit.""" - import argparse - import importlib - import json - import sys - - parser = argparse.ArgumentParser( - usage="python -m threadpoolctl -i numpy scipy.linalg xgboost", - description="Display thread-pool information and exit.", - ) - parser.add_argument( - "-i", - "--import", - dest="modules", - nargs="*", - default=(), - help="Python modules to import before introspecting thread-pools.", - ) - parser.add_argument( - "-c", - "--command", - help="a Python statement to execute before introspecting thread-pools.", - ) - - options = parser.parse_args(sys.argv[1:]) - for module in options.modules: - try: - importlib.import_module(module, package=None) - except ImportError: - print("WARNING: could not import", module, file=sys.stderr) - - if options.command: - exec(options.command) - - print(json.dumps(threadpool_info(), indent=2)) - - -if __name__ == "__main__": - _main() diff --git a/threadpoolctl/__main__.py b/threadpoolctl/__main__.py new file mode 100644 index 00000000..09abef99 --- /dev/null +++ b/threadpoolctl/__main__.py @@ -0,0 +1,46 @@ +"""Commandline interface to display thread-pool information and exit.""" + +__all__ = () + + +def _main() -> None: + import argparse + import importlib + import json + import sys + + from threadpoolctl import threadpool_info + + parser = argparse.ArgumentParser( + usage="python -m threadpoolctl -i numpy scipy.linalg xgboost", + description="Display thread-pool information and exit.", + ) + parser.add_argument( + "-i", + "--import", + dest="modules", + nargs="*", + default=(), + help="Python modules to import before introspecting thread-pools.", + ) + parser.add_argument( + "-c", + "--command", + help="a Python statement to execute before introspecting thread-pools.", + ) + options = parser.parse_args(sys.argv[1:]) + + for module in options.modules: + try: + importlib.import_module(module, package=None) + except ImportError: + print("WARNING: could not import", module, file=sys.stderr) + + if options.command: + exec(options.command) + + print(json.dumps(threadpool_info(), indent=2)) + + +if __name__ == "__main__": + _main() diff --git a/threadpoolctl/py.typed b/threadpoolctl/py.typed new file mode 100644 index 00000000..e69de29b