Skip to content

Commit 3c5fc3b

Browse files
committed
Chore: Add type hints to manager.py
1 parent c8b9ece commit 3c5fc3b

7 files changed

Lines changed: 97 additions & 72 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ clean-dist: clean
2121
rm -rf dist/
2222

2323
lint: venv
24-
$(VENV_ACTIVATE); python -m ruff check .
24+
$(VENV_ACTIVATE); python -m ruff check . && python -m mypy
2525

2626
format: venv
2727
$(VENV_ACTIVATE); python -m ruff format . && python -m ruff check . --fix

mypy.ini

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[mypy]
2+
explicit_package_bases = true
3+
files=plux/runtime/manager.py,tests/test_manager.py
4+
ignore_missing_imports = False
5+
follow_imports = silent
6+
ignore_errors = False
7+
disallow_untyped_defs = True
8+
disallow_untyped_calls = True
9+
disallow_any_generics = True
10+
disallow_subclassing_any = True
11+
warn_unused_ignores = True

plux/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@
3333
"PluginSpecResolver",
3434
"PluginType",
3535
"plugin",
36-
"__version__"
36+
"__version__",
3737
]

plux/core/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class PluginDisabled(PluginException):
2727

2828
reason: str
2929

30-
def __init__(self, namespace: str, name: str, reason: str = None):
30+
def __init__(self, namespace: str, name: str, reason: str | None = None):
3131
message = f"plugin {namespace}:{name} is disabled"
3232
if reason:
3333
message = f"{message}, reason: {reason}"

plux/runtime/manager.py

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import logging
22
import threading
33
import typing as t
4+
from collections.abc import Iterable
5+
from importlib.metadata import EntryPoint
46

57
from plux.core.plugin import (
68
Plugin,
@@ -18,9 +20,11 @@
1820
LOG = logging.getLogger(__name__)
1921

2022
P = t.TypeVar("P", bound=Plugin)
23+
PS = t.ParamSpec("PS")
24+
T = t.TypeVar("T")
2125

2226

23-
def _call_safe(func: t.Callable, args: tuple, exception_message: str):
27+
def _call_safe(func: t.Callable[PS, T], args: t.Any, exception_message: str) -> None:
2428
"""
2529
Call the given function with the given arguments, and if it fails, log the given exception_message. If
2630
logging.DEBUG is set for the logger, then we also log the traceback. An exception is made for any
@@ -32,7 +36,7 @@ def _call_safe(func: t.Callable, args: tuple, exception_message: str):
3236
:return: whatever the func returns
3337
"""
3438
try:
35-
return func(*args)
39+
func(*args, **{})
3640
except PluginException:
3741
# re-raise PluginExceptions, since they should be handled by the caller
3842
raise
@@ -54,23 +58,25 @@ class PluginLifecycleNotifierMixin:
5458

5559
listeners: list[PluginLifecycleListener]
5660

57-
def _fire_on_resolve_after(self, plugin_spec):
61+
def _fire_on_resolve_after(self, plugin_spec: PluginSpec) -> None:
5862
for listener in self.listeners:
5963
_call_safe(
6064
listener.on_resolve_after,
6165
(plugin_spec,), #
6266
"error while calling on_resolve_after",
6367
)
6468

65-
def _fire_on_resolve_exception(self, namespace, entrypoint, exception):
69+
def _fire_on_resolve_exception(
70+
self, namespace: str, entrypoint: EntryPoint, exception: Exception
71+
) -> None:
6672
for listener in self.listeners:
6773
_call_safe(
6874
listener.on_resolve_exception,
6975
(namespace, entrypoint, exception),
7076
"error while calling on_resolve_exception",
7177
)
7278

73-
def _fire_on_init_after(self, plugin_spec, plugin):
79+
def _fire_on_init_after(self, plugin_spec: PluginSpec, plugin: P) -> None:
7480
for listener in self.listeners:
7581
_call_safe(
7682
listener.on_init_after,
@@ -81,31 +87,35 @@ def _fire_on_init_after(self, plugin_spec, plugin):
8187
"error while calling on_init_after",
8288
)
8389

84-
def _fire_on_init_exception(self, plugin_spec, exception):
90+
def _fire_on_init_exception(self, plugin_spec: PluginSpec, exception: Exception) -> None:
8591
for listener in self.listeners:
8692
_call_safe(
8793
listener.on_init_exception,
8894
(plugin_spec, exception),
8995
"error while calling on_init_exception",
9096
)
9197

92-
def _fire_on_load_before(self, plugin_spec, plugin, load_args, load_kwargs):
98+
def _fire_on_load_before(
99+
self, plugin_spec: PluginSpec, plugin: P, load_args: t.Any, load_kwargs: t.Any
100+
) -> None:
93101
for listener in self.listeners:
94102
_call_safe(
95103
listener.on_load_before,
96104
(plugin_spec, plugin, load_args, load_kwargs),
97105
"error while calling on_load_before",
98106
)
99107

100-
def _fire_on_load_after(self, plugin_spec, plugin, result):
108+
def _fire_on_load_after(self, plugin_spec: PluginSpec, plugin: P | None, result: t.Any) -> None:
101109
for listener in self.listeners:
102110
_call_safe(
103111
listener.on_load_after,
104112
(plugin_spec, plugin, result),
105113
"error while calling on_load_after",
106114
)
107115

108-
def _fire_on_load_exception(self, plugin_spec, plugin, exception):
116+
def _fire_on_load_exception(
117+
self, plugin_spec: PluginSpec, plugin: P | None, exception: Exception
118+
) -> None:
109119
for listener in self.listeners:
110120
_call_safe(
111121
listener.on_load_exception,
@@ -123,20 +133,20 @@ class PluginContainer(t.Generic[P]):
123133
lock: threading.RLock
124134

125135
plugin_spec: PluginSpec
126-
plugin: P = None
127-
load_value: t.Any | None = None
136+
plugin: P | None = None
137+
load_value: t.Any = None
128138

129139
is_init: bool = False
130140
is_loaded: bool = False
131141

132-
init_error: Exception = None
133-
load_error: Exception = None
142+
init_error: Exception | None = None
143+
load_error: Exception | None = None
134144

135145
is_disabled: bool = False
136-
disabled_reason = str = None
146+
disabled_reason: str | None = None
137147

138148
@property
139-
def distribution(self) -> Distribution:
149+
def distribution(self) -> Distribution | None:
140150
"""
141151
Uses metadata from importlib to resolve the distribution information for this plugin.
142152
@@ -160,19 +170,19 @@ class PluginManager(PluginLifecycleNotifierMixin, t.Generic[P]):
160170

161171
namespace: str
162172

163-
load_args: list | tuple
173+
load_args: list[t.Any] | tuple[t.Any, ...]
164174
load_kwargs: dict[str, t.Any]
165175
listeners: list[PluginLifecycleListener]
166176
filters: list[PluginFilter]
167177

168178
def __init__(
169179
self,
170180
namespace: str,
171-
load_args: list | tuple = None,
172-
load_kwargs: dict = None,
173-
listener: PluginLifecycleListener | t.Iterable[PluginLifecycleListener] = None,
174-
finder: PluginFinder = None,
175-
filters: list[PluginFilter] = None,
181+
load_args: list[t.Any] | tuple[t.Any, ...] | None = None,
182+
load_kwargs: dict[str, t.Any] | None = None,
183+
listener: PluginLifecycleListener | t.Iterable[PluginLifecycleListener] | None = None,
184+
finder: PluginFinder | None = None,
185+
filters: list[PluginFilter] | None = None,
176186
):
177187
"""
178188
Create a new ``PluginManager`` that can be used to load plugins. The simplest ``PluginManager`` only needs
@@ -231,7 +241,7 @@ def on_load_before(self, plugin_spec: PluginSpec, plugin: Plugin, load_result: t
231241
self.load_kwargs = load_kwargs or dict()
232242

233243
if listener:
234-
if isinstance(listener, (list, set, tuple)):
244+
if isinstance(listener, Iterable):
235245
self.listeners = list(listener)
236246
else:
237247
self.listeners = [listener]
@@ -243,10 +253,10 @@ def on_load_before(self, plugin_spec: PluginSpec, plugin: Plugin, load_result: t
243253

244254
self.finder = finder or MetadataPluginFinder(self.namespace, self._fire_on_resolve_exception)
245255

246-
self._plugin_index = None
256+
self._plugin_index: dict[str, PluginContainer[P]] | None = None
247257
self._init_mutex = threading.RLock()
248258

249-
def add_listener(self, listener: PluginLifecycleListener):
259+
def add_listener(self, listener: PluginLifecycleListener) -> None:
250260
"""
251261
Adds a lifecycle listener to the plugin manager. The listener will be notified of plugin lifecycle events.
252262
@@ -326,12 +336,12 @@ def load(self, name: str) -> P:
326336
if container.load_error:
327337
raise container.load_error
328338

329-
if not container.is_loaded:
339+
if container.plugin is None or not container.is_loaded:
330340
raise PluginException("plugin did not load correctly", namespace=self.namespace, name=name)
331341

332342
return container.plugin
333343

334-
def load_all(self, propagate_exceptions=False) -> list[P]:
344+
def load_all(self, propagate_exceptions: bool = False) -> list[P]:
335345
"""
336346
Attempts to load all plugins found in the namespace and returns those that were loaded successfully.
337347
@@ -364,10 +374,10 @@ def load_all(self, propagate_exceptions=False) -> list[P]:
364374
:param propagate_exceptions: If True, re-raises any exceptions encountered during loading
365375
:return: A list of successfully loaded plugin instances
366376
"""
367-
plugins = list()
377+
plugins: list[P] = list()
368378

369379
for name, container in self._plugins.items():
370-
if container.is_loaded:
380+
if container.plugin is not None and container.is_loaded:
371381
plugins.append(container.plugin)
372382
continue
373383

@@ -552,7 +562,7 @@ def _require_plugin(self, name: str) -> PluginContainer[P]:
552562

553563
return self._plugins[name]
554564

555-
def _load_plugin(self, container: PluginContainer) -> None:
565+
def _load_plugin(self, container: PluginContainer[P]) -> None:
556566
"""
557567
Implements the core algorithm to load a plugin from a ``PluginSpec`` (contained in the ``PluginContainer``),
558568
and stores all relevant results, such as the Plugin instance, load result, or any errors into the passed
@@ -602,6 +612,7 @@ def _load_plugin(self, container: PluginContainer) -> None:
602612
return
603613

604614
plugin = container.plugin
615+
assert plugin # Make MyPy happy - plugin should exist at this point
605616

606617
if not plugin.should_load():
607618
raise PluginDisabled(
@@ -643,9 +654,9 @@ def _plugin_from_spec(self, plugin_spec: PluginSpec) -> P:
643654
if spec:
644655
factory = spec.factory
645656

646-
return factory()
657+
return factory() # type: ignore[return-value]
647658

648-
def _init_plugin_index(self) -> dict[str, PluginContainer]:
659+
def _init_plugin_index(self) -> dict[str, PluginContainer[P]]:
649660
"""
650661
Initializes the plugin index, which maps plugin names to plugin containers. This method will *resolve* plugins,
651662
meaning it loads the entry point object reference, thereby importing all its code.
@@ -654,7 +665,7 @@ def _init_plugin_index(self) -> dict[str, PluginContainer]:
654665
"""
655666
return {plugin.name: plugin for plugin in self._import_plugins() if plugin}
656667

657-
def _import_plugins(self) -> t.Iterable[PluginContainer]:
668+
def _import_plugins(self) -> t.Iterable[PluginContainer[P]]:
658669
"""
659670
Finds all ``PluginSpace`` instances in the namespace, creates a container for each spec, and yields them one
660671
by one. The plugin finder will typically load the entry point which involves importing the module it lives in.
@@ -671,14 +682,14 @@ def _import_plugins(self) -> t.Iterable[PluginContainer]:
671682

672683
yield self._create_container(spec)
673684

674-
def _create_container(self, plugin_spec: PluginSpec) -> PluginContainer:
685+
def _create_container(self, plugin_spec: PluginSpec) -> PluginContainer[P]:
675686
"""
676687
Factory method to create a ``PluginContainer`` for the given ``PluginSpec``.
677688
678689
:param plugin_spec: The ``PluginSpec`` to create a container for.
679690
:return: A new ``PluginContainer`` with the basic information of the plugin spec.
680691
"""
681-
container = PluginContainer()
692+
container = PluginContainer[P]()
682693
container.lock = threading.RLock()
683694
container.name = plugin_spec.name
684695
container.plugin_spec = plugin_spec

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dev = [
3333
"setuptools",
3434
"pytest==8.4.1",
3535
"ruff==0.9.1",
36+
"mypy",
3637
]
3738

3839
[tool.hatch.build.hooks.vcs]

0 commit comments

Comments
 (0)