diff --git a/.travis.yml b/.travis.yml index b5c7633..7530e89 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,18 +1,14 @@ language: python python: - - "pypy" - "pypy3" matrix: include: - - { python: '2.7', env: } - - { arch: arm64, python: '2.7' } - - { python: '3.4', env: } - - { python: '3.5', env: } - { python: '3.6', env: } - { python: '3.7', env: } - { python: '3.8', env: } - - { arch: arm64, python: '3.8' } + - { python: '3.9', env: } + - { arch: arm64, python: '3.9' } install: - pip install --upgrade pip - pip install coverage @@ -20,11 +16,7 @@ install: script: - | - if [[ $(bc <<< "$TRAVIS_PYTHON_VERSION >= 3.3") -eq 1 ]]; then - py.test --doctest-modules multipledispatch - else - py.test --doctest-modules --ignore=multipledispatch/tests/test_dispatcher_3only.py multipledispatch - fi + py.test --doctest-modules multipledispatch after_success: - | diff --git a/multipledispatch/dispatcher.py b/multipledispatch/dispatcher.py index 7568595..8409205 100644 --- a/multipledispatch/dispatcher.py +++ b/multipledispatch/dispatcher.py @@ -1,3 +1,4 @@ +from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, get_type_hints from warnings import warn import inspect from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning @@ -95,7 +96,10 @@ def variadic_signature_matches(types, full_signature): return all(variadic_signature_matches_iter(types, full_signature)) -class Dispatcher(object): +DISPATCHED_RETURN = TypeVar("DISPATCHED_RETURN") + + +class Dispatcher(Generic[DISPATCHED_RETURN]): """ Dispatch methods based on type signature Use ``dispatch`` to add implementations @@ -119,14 +123,16 @@ class Dispatcher(object): """ __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc' - def __init__(self, name, doc=None): + def __init__(self, name: str, doc: Optional[None] = None) -> None: self.name = self.__name__ = name self.funcs = {} self.doc = doc self._cache = {} - def register(self, *types, **kwargs): + def register( + self, *types: type, **kwargs: Any + ) -> Callable[[Callable[..., DISPATCHED_RETURN]], Callable[..., DISPATCHED_RETURN]]: """ register dispatcher with new implementation >>> f = Dispatcher('f') @@ -171,19 +177,19 @@ def get_func_annotations(cls, func): if params: Parameter = inspect.Parameter - params = (param for param in params - if param.kind in - (Parameter.POSITIONAL_ONLY, - Parameter.POSITIONAL_OR_KEYWORD)) - + hints = get_type_hints(func) annotations = tuple( - param.annotation - for param in params) + hints.get(param.name, Parameter.empty) + for param in params + if param.kind in ( + Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD + ) + ) if all(ann is not Parameter.empty for ann in annotations): return annotations - def add(self, signature, func): + def add(self, signature: Tuple[type, ...], func: Callable[..., DISPATCHED_RETURN]) -> None: """ Add new types/method pair to dispatcher >>> D = Dispatcher('add') @@ -263,7 +269,7 @@ def reorder(self, on_ambiguity=ambiguity_warn): on_ambiguity(self, amb) return od - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> DISPATCHED_RETURN: types = tuple([type(arg) for arg in args]) try: func = self._cache[types] diff --git a/multipledispatch/tests/test_core.py b/multipledispatch/tests/test_core.py index d3f6eec..9483559 100644 --- a/multipledispatch/tests/test_core.py +++ b/multipledispatch/tests/test_core.py @@ -1,14 +1,15 @@ -from multipledispatch import dispatch +from multipledispatch import dispatch as orig_dispatch from multipledispatch.utils import raises from functools import partial +import pytest -test_namespace = dict() -orig_dispatch = dispatch -dispatch = partial(dispatch, namespace=test_namespace) +@pytest.fixture(name="dispatch") +def fixture_dispatch(): + return partial(orig_dispatch, namespace=dict()) -def test_singledispatch(): +def test_singledispatch(dispatch): @dispatch(int) def f(x): return x + 1 @@ -28,7 +29,7 @@ def f(x): assert raises(NotImplementedError, lambda: f('hello')) -def test_multipledispatch(benchmark): +def test_multipledispatch(dispatch): @dispatch(int, int) def f(x, y): return x + y @@ -48,7 +49,7 @@ class D(C): pass class E(C): pass -def test_inheritance(): +def test_inheritance(dispatch): @dispatch(A) def f(x): return 'a' @@ -62,7 +63,7 @@ def f(x): assert f(C()) == 'a' -def test_inheritance_and_multiple_dispatch(): +def test_inheritance_and_multiple_dispatch(dispatch): @dispatch(A, A) def f(x, y): return type(x), type(y) @@ -78,7 +79,7 @@ def f(x, y): assert raises(NotImplementedError, lambda: f(B(), B())) -def test_competing_solutions(): +def test_competing_solutions(dispatch): @dispatch(A) def h(x): return 1 @@ -90,7 +91,7 @@ def h(x): assert h(D()) == 2 -def test_competing_multiple(): +def test_competing_multiple(dispatch): @dispatch(A, B) def h(x, y): return 1 @@ -102,7 +103,7 @@ def h(x, y): assert h(D(), B()) == 2 -def test_competing_ambiguous(): +def test_competing_ambiguous(dispatch): @dispatch(A, C) def f(x, y): return 2 @@ -115,7 +116,7 @@ def f(x, y): # assert raises(Warning, lambda : f(C(), C())) -def test_caching_correct_behavior(): +def test_caching_correct_behavior(dispatch): @dispatch(A) def f(x): return 1 @@ -129,7 +130,7 @@ def f(x): assert f(C()) == 2 -def test_union_types(): +def test_union_types(dispatch): @dispatch((A, C)) def f(x): return 1 @@ -156,7 +157,7 @@ def foo(x): """ Fails -def test_dispatch_on_dispatch(): +def test_dispatch_on_dispatch(dispatch): @dispatch(A) @dispatch(C) def q(x): @@ -167,7 +168,7 @@ def q(x): """ -def test_methods(): +def test_methods(dispatch): class Foo(object): @dispatch(float) def f(self, x): @@ -188,7 +189,7 @@ def g(self, x): assert foo.g(1) == 4 -def test_methods_multiple_dispatch(): +def test_methods_multiple_dispatch(dispatch): class Foo(object): @dispatch(A, A) def f(x, y): diff --git a/multipledispatch/tests/test_dispatcher_3only.py b/multipledispatch/tests/test_dispatcher_3only.py index b041450..f23bc21 100644 --- a/multipledispatch/tests/test_dispatcher_3only.py +++ b/multipledispatch/tests/test_dispatcher_3only.py @@ -17,6 +17,10 @@ def inc(x: int): def inc(x: float): return x - 1 + @f.register() + def inc(x: 'float'): + return x - 1 + assert f(1) == 2 assert f(1.0) == 0.0