Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@

language: python
python:
- "pypy"
- "pypy3"
matrix:
include:
- { python: '2.7', env: }
- { arch: arm64, python: '2.7' }
- { python: '3.4', env: }
- { python: '3.5', env: }
- { python: '3.6', env: }
- { python: '3.7', env: }
- { python: '3.8', env: }
- { arch: arm64, python: '3.8' }
- { python: '3.9', env: }
- { arch: arm64, python: '3.9' }
install:
- pip install --upgrade pip
- pip install coverage
- pip install --upgrade pytest pytest-benchmark

script:
- |
if [[ $(bc <<< "$TRAVIS_PYTHON_VERSION >= 3.3") -eq 1 ]]; then
py.test --doctest-modules multipledispatch
else
py.test --doctest-modules --ignore=multipledispatch/tests/test_dispatcher_3only.py multipledispatch
fi
py.test --doctest-modules multipledispatch

after_success:
- |
Expand Down
30 changes: 18 additions & 12 deletions multipledispatch/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, get_type_hints
from warnings import warn
import inspect
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
Expand Down Expand Up @@ -95,7 +96,10 @@ def variadic_signature_matches(types, full_signature):
return all(variadic_signature_matches_iter(types, full_signature))


class Dispatcher(object):
DISPATCHED_RETURN = TypeVar("DISPATCHED_RETURN")


class Dispatcher(Generic[DISPATCHED_RETURN]):
""" Dispatch methods based on type signature

Use ``dispatch`` to add implementations
Expand All @@ -119,14 +123,16 @@ class Dispatcher(object):
"""
__slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'

def __init__(self, name, doc=None):
def __init__(self, name: str, doc: Optional[None] = None) -> None:
self.name = self.__name__ = name
self.funcs = {}
self.doc = doc

self._cache = {}

def register(self, *types, **kwargs):
def register(
self, *types: type, **kwargs: Any
) -> Callable[[Callable[..., DISPATCHED_RETURN]], Callable[..., DISPATCHED_RETURN]]:
""" register dispatcher with new implementation

>>> f = Dispatcher('f')
Expand Down Expand Up @@ -171,19 +177,19 @@ def get_func_annotations(cls, func):
if params:
Parameter = inspect.Parameter

params = (param for param in params
if param.kind in
(Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD))

hints = get_type_hints(func)
annotations = tuple(
param.annotation
for param in params)
hints.get(param.name, Parameter.empty)
for param in params
if param.kind in (
Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD
)
)

if all(ann is not Parameter.empty for ann in annotations):
return annotations

def add(self, signature, func):
def add(self, signature: Tuple[type, ...], func: Callable[..., DISPATCHED_RETURN]) -> None:
""" Add new types/method pair to dispatcher

>>> D = Dispatcher('add')
Expand Down Expand Up @@ -263,7 +269,7 @@ def reorder(self, on_ambiguity=ambiguity_warn):
on_ambiguity(self, amb)
return od

def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> DISPATCHED_RETURN:
types = tuple([type(arg) for arg in args])
try:
func = self._cache[types]
Expand Down
33 changes: 17 additions & 16 deletions multipledispatch/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from multipledispatch import dispatch
from multipledispatch import dispatch as orig_dispatch
from multipledispatch.utils import raises
from functools import partial
import pytest

test_namespace = dict()

orig_dispatch = dispatch
dispatch = partial(dispatch, namespace=test_namespace)
@pytest.fixture(name="dispatch")
def fixture_dispatch():
return partial(orig_dispatch, namespace=dict())


def test_singledispatch():
def test_singledispatch(dispatch):
@dispatch(int)
def f(x):
return x + 1
Expand All @@ -28,7 +29,7 @@ def f(x):
assert raises(NotImplementedError, lambda: f('hello'))


def test_multipledispatch(benchmark):
def test_multipledispatch(dispatch):
@dispatch(int, int)
def f(x, y):
return x + y
Expand All @@ -48,7 +49,7 @@ class D(C): pass
class E(C): pass


def test_inheritance():
def test_inheritance(dispatch):
@dispatch(A)
def f(x):
return 'a'
Expand All @@ -62,7 +63,7 @@ def f(x):
assert f(C()) == 'a'


def test_inheritance_and_multiple_dispatch():
def test_inheritance_and_multiple_dispatch(dispatch):
@dispatch(A, A)
def f(x, y):
return type(x), type(y)
Expand All @@ -78,7 +79,7 @@ def f(x, y):
assert raises(NotImplementedError, lambda: f(B(), B()))


def test_competing_solutions():
def test_competing_solutions(dispatch):
@dispatch(A)
def h(x):
return 1
Expand All @@ -90,7 +91,7 @@ def h(x):
assert h(D()) == 2


def test_competing_multiple():
def test_competing_multiple(dispatch):
@dispatch(A, B)
def h(x, y):
return 1
Expand All @@ -102,7 +103,7 @@ def h(x, y):
assert h(D(), B()) == 2


def test_competing_ambiguous():
def test_competing_ambiguous(dispatch):
@dispatch(A, C)
def f(x, y):
return 2
Expand All @@ -115,7 +116,7 @@ def f(x, y):
# assert raises(Warning, lambda : f(C(), C()))


def test_caching_correct_behavior():
def test_caching_correct_behavior(dispatch):
@dispatch(A)
def f(x):
return 1
Expand All @@ -129,7 +130,7 @@ def f(x):
assert f(C()) == 2


def test_union_types():
def test_union_types(dispatch):
@dispatch((A, C))
def f(x):
return 1
Expand All @@ -156,7 +157,7 @@ def foo(x):

"""
Fails
def test_dispatch_on_dispatch():
def test_dispatch_on_dispatch(dispatch):
@dispatch(A)
@dispatch(C)
def q(x):
Expand All @@ -167,7 +168,7 @@ def q(x):
"""


def test_methods():
def test_methods(dispatch):
class Foo(object):
@dispatch(float)
def f(self, x):
Expand All @@ -188,7 +189,7 @@ def g(self, x):
assert foo.g(1) == 4


def test_methods_multiple_dispatch():
def test_methods_multiple_dispatch(dispatch):
class Foo(object):
@dispatch(A, A)
def f(x, y):
Expand Down
4 changes: 4 additions & 0 deletions multipledispatch/tests/test_dispatcher_3only.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def inc(x: int):
def inc(x: float):
return x - 1

@f.register()
def inc(x: 'float'):
return x - 1

assert f(1) == 2
assert f(1.0) == 0.0

Expand Down