diff --git a/typemap/type_eval/_apply_generic.py b/typemap/type_eval/_apply_generic.py index 02f493a..44a033e 100644 --- a/typemap/type_eval/_apply_generic.py +++ b/typemap/type_eval/_apply_generic.py @@ -15,6 +15,7 @@ if typing.TYPE_CHECKING: from typing import Any, Mapping + from typemap.typing import GenericCallable, Overloaded @dataclasses.dataclass(frozen=True) @@ -285,8 +286,10 @@ def get_annotations( return rr -def _resolved_function_signature(func, args): - """Get the signature of a function with type hints resolved to arg values""" +def _resolved_function_signature( + func, args, definition_cls: type | None = None +): + """Get the signature of a function with hints resolved to arg values.""" import typemap.typing as nt @@ -305,7 +308,7 @@ def _resolved_function_signature(func, args): finally: nt.special_form_evaluator.reset(token) - if hints := get_annotations(func, args): + if hints := get_annotations(func, args, cls=definition_cls): params = [] for name, param in sig.parameters.items(): annotation = hints.get(name, param.annotation) @@ -315,8 +318,10 @@ def _resolved_function_signature(func, args): sig = sig.replace( parameters=params, return_annotation=return_annotation ) + return sig - return sig + else: + return None def get_local_defns( @@ -324,10 +329,17 @@ def get_local_defns( ) -> tuple[ dict[str, Any], dict[ - str, types.FunctionType | classmethod | staticmethod | WrappedOverloads + str, + type[ + typing.Callable + | classmethod + | staticmethod + | GenericCallable + | Overloaded + ], ], ]: - from typemap.typing import GenericCallable + from typemap.typing import GenericCallable, Overloaded annos: dict[str, Any] = {} dct: dict[str, Any] = {} @@ -339,43 +351,21 @@ def get_local_defns( if name in EXCLUDED_ATTRIBUTES: continue + if orig is typing._no_init_or_replace_init: # type: ignore[attr-defined] + continue + stuff = inspect.unwrap(orig) if isinstance(stuff, types.FunctionType): - local_fn: Any = None - - # TODO: This annos_ok thing is a hack because processing - # __annotations__ on methods broke stuff and I didn't want - # to chase it down yet. - stuck = False - try: - rr = get_annotations( - stuff, boxed.str_args, cls=boxed.cls, annos_ok=False - ) - except _eval_typing.StuckException: - stuck = True - rr = None - - if rr is not None: - local_fn = make_func(orig, rr) - elif not stuck and getattr(stuff, "__annotations__", None): - # XXX: This is totally wrong; we still need to do - # substitute in class vars - local_fn = stuff - elif overloads := typing.get_overloads(stuff): - local_fn = WrappedOverloads(tuple(overloads)) - - # If we got stuck, we build a GenericCallable that - # computes the type once it has been given type - # variables! - if stuck and stuff.__type_params__: + # If the method has type params, we build a GenericCallable + # (in annos only) so that [Z] etc. are preserved in output. + if stuff.__type_params__: type_params = stuff.__type_params__ str_args = boxed.str_args - canonical_cls = boxed.canonical_cls - - def _make_lambda(fn, o, sa, tp, cls): - from ._eval_operators import _function_type_from_sig + receiver_cls = boxed.alias_type() + definition_cls = boxed.canonical_cls + def _make_lambda(fn, o, sa, tp, recv_cls, def_cls): def lam(*vs): args = dict(sa) args.update( @@ -385,9 +375,11 @@ def lam(*vs): strict=True, ) ) - sig = _resolved_function_signature(fn, args) + sig = _resolved_function_signature( + fn, args, definition_cls=def_cls + ) return _function_type_from_sig( - sig, o, receiver_type=cls + sig, type(o), receiver_type=recv_cls ) return lam @@ -395,24 +387,163 @@ def lam(*vs): gc = GenericCallable[ # type: ignore[valid-type,misc] tuple[*type_params], # type: ignore[valid-type] _make_lambda( - stuff, orig, str_args, type_params, canonical_cls + stuff, + orig, + str_args, + type_params, + receiver_cls, + definition_cls, ), ] - annos[name] = typing.ClassVar[gc] - elif local_fn is not None: - if orig.__class__ is classmethod: - local_fn = classmethod(local_fn) - elif orig.__class__ is staticmethod: - local_fn = staticmethod(local_fn) + dct[name] = gc - dct[name] = local_fn + elif overloads := typing.get_overloads(stuff): + # If the method is overloaded, build an Overloaded type. + overload_types: typing.Sequence[ + type[ + typing.Callable + | classmethod + | staticmethod + | GenericCallable + ] + ] = [ + _function_type( + _eval_typing.eval_typing(of), + receiver_type=boxed.alias_type(), + ) + for of in overloads + ] + + dct[name] = Overloaded[*overload_types] # type: ignore[valid-type] + continue + + else: + # Try to resolve the signature as a normal function. + resolved_sig = None + try: + resolved_sig = _resolved_function_signature( + stuff, + boxed.str_args, + definition_cls=boxed.cls, + ) + except _eval_typing.StuckException: + # We can get stuck if the signature has external type vars. + # Just fallback to the original signature for now. + resolved_sig = inspect.signature(stuff) + + if resolved_sig is not None: + dct[name] = _function_type_from_sig( + resolved_sig, + type(orig), + receiver_type=boxed.alias_type(), + ) + continue return annos, dct -@dataclasses.dataclass(frozen=True) -class WrappedOverloads: - functions: tuple[typing.Callable[..., Any], ...] +def _function_type_from_sig(sig, func_type, *, receiver_type): + from typemap.typing import Param + + empty = inspect.Parameter.empty + + def _ann(x): + return typing.Any if x is empty else None if x is type(None) else x + + specified_receiver = receiver_type + + params = [] + for i, p in enumerate(sig.parameters.values()): + ann = p.annotation + # Special handling for first argument on methods. + if i == 0 and receiver_type and func_type is not staticmethod: + if ann is empty: + ann = receiver_type + else: + if ( + func_type is classmethod + and typing.get_origin(ann) is type + and (receiver_args := typing.get_args(ann)) + ): + # The annotation for cls in a classmethod should be type[C] + specified_receiver = receiver_args[0] + else: + specified_receiver = ann + + quals = [] + if p.kind == inspect.Parameter.VAR_POSITIONAL: + quals.append("*") + if p.kind == inspect.Parameter.VAR_KEYWORD: + quals.append("**") + if p.kind == inspect.Parameter.KEYWORD_ONLY: + quals.append("keyword") + if p.kind == inspect.Parameter.POSITIONAL_ONLY: + quals.append("positional") + if p.default is not empty: + quals.append("default") + params.append( + Param[ + typing.Literal[p.name], + _ann(ann), + typing.Literal[*quals] if quals else typing.Never, + ] + ) + + ret = _ann(sig.return_annotation) + + # TODO: Is doing the tuple for staticmethod/classmethod legit? + # Putting a list in makes it unhashable... + f: typing.Any # type: ignore[annotation-unchecked] + if func_type is staticmethod: + f = staticmethod[tuple[*params], ret] + elif func_type is classmethod: + f = classmethod[specified_receiver, tuple[*params[1:]], ret] + else: + f = typing.Callable[params, ret] + + return f + + +def _function_type( + func, *, receiver_type +) -> type[typing.Callable | classmethod | staticmethod | GenericCallable]: + from typemap.typing import GenericCallable + + root = inspect.unwrap(func) + sig = inspect.signature(root) + f = _function_type_from_sig(sig, type(func), receiver_type=receiver_type) + + if root.__type_params__: + # Must store a lambda that performs type variable substitution + type_params = root.__type_params__ + callable_lambda = _create_generic_callable_lambda(f, type_params) + f = GenericCallable[tuple[*type_params], callable_lambda] # type: ignore[misc,valid-type] + return f + + +def _create_generic_callable_lambda( + f: typing.Callable | classmethod | staticmethod, + type_params: tuple[typing.TypeVar, ...], +): + if typing.get_origin(f) in (staticmethod, classmethod): + return lambda *vs: substitute( + f, dict(zip(type_params, vs, strict=True)) + ) + + else: + # Callable params are stored as a list + params, ret = typing.get_args(f) + + return lambda *vs: typing.Callable[ + [ + substitute( + p, + dict(zip(type_params, vs, strict=True)), + ) + for p in params + ], + substitute(ret, dict(zip(type_params, vs, strict=True))), + ] def flatten_class_new_proto(cls: type) -> type: diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index bf513d6..7b45a50 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -41,7 +41,6 @@ Member, Members, NewProtocol, - Overloaded, Param, RaiseError, Slice, @@ -154,35 +153,12 @@ def get_annotated_method_hints(cls, *, ctx): _, dct = _apply_generic.get_local_defns(abox) for name, attr in dct.items(): - if isinstance( + hints[name] = ( attr, - ( - types.FunctionType, - types.MethodType, - staticmethod, - classmethod, - ), - ): - if attr is typing._no_init_or_replace_init: - continue - - hints[name] = ( - _function_type(attr, receiver_type=acls), - ("ClassVar",), - object, - acls, - ) - elif isinstance(attr, _apply_generic.WrappedOverloads): - overloads = [ - _function_type(_eval_types(of, ctx), receiver_type=acls) - for of in attr.functions - ] - hints[name] = ( - Overloaded[*overloads], - ("ClassVar",), - object, - acls, - ) + ("ClassVar",), + object, + acls, + ) return hints @@ -703,106 +679,6 @@ def _callable_type_to_method(name, typ, ctx): return head(func) -def _function_type_from_sig(sig, func, *, receiver_type): - empty = inspect.Parameter.empty - - def _ann(x): - return typing.Any if x is empty else None if x is type(None) else x - - specified_receiver = receiver_type - - params = [] - for i, p in enumerate(sig.parameters.values()): - ann = p.annotation - # Special handling for first argument on methods. - if i == 0 and receiver_type and not isinstance(func, staticmethod): - if ann is empty: - ann = receiver_type - else: - if ( - isinstance(func, classmethod) - and typing.get_origin(ann) is type - and (receiver_args := typing.get_args(ann)) - ): - # The annotation for cls in a classmethod should be type[C] - specified_receiver = receiver_args[0] - else: - specified_receiver = ann - - quals = [] - if p.kind == inspect.Parameter.VAR_POSITIONAL: - quals.append("*") - if p.kind == inspect.Parameter.VAR_KEYWORD: - quals.append("**") - if p.kind == inspect.Parameter.KEYWORD_ONLY: - quals.append("keyword") - if p.kind == inspect.Parameter.POSITIONAL_ONLY: - quals.append("positional") - if p.default is not empty: - quals.append("default") - params.append( - Param[ - typing.Literal[p.name], - _ann(ann), - typing.Literal[*quals] if quals else typing.Never, - ] - ) - - ret = _ann(sig.return_annotation) - - # TODO: Is doing the tuple for staticmethod/classmethod legit? - # Putting a list in makes it unhashable... - f: typing.Any # type: ignore[annotation-unchecked] - if isinstance(func, staticmethod): - f = staticmethod[tuple[*params], ret] - elif isinstance(func, classmethod): - f = classmethod[specified_receiver, tuple[*params[1:]], ret] - else: - f = typing.Callable[params, ret] - - return f - - -def _function_type(func, *, receiver_type): - root = inspect.unwrap(func) - sig = inspect.signature(root) - f = _function_type_from_sig(sig, func, receiver_type=receiver_type) - - if root.__type_params__: - # Must store a lambda that performs type variable substitution - type_params = root.__type_params__ - callable_lambda = _create_generic_callable_lambda(f, type_params) - f = GenericCallable[tuple[*type_params], callable_lambda] - return f - - -def _create_generic_callable_lambda( - f: typing.Callable | classmethod | staticmethod, - type_params: tuple[typing.TypeVar, ...], -): - if typing.get_origin(f) in (staticmethod, classmethod): - return lambda *vs: _apply_generic.substitute( - f, dict(zip(type_params, vs, strict=True)) - ) - - else: - # Callable params are stored as a list - params, ret = typing.get_args(f) - - return lambda *vs: typing.Callable[ - [ - _apply_generic.substitute( - p, - dict(zip(type_params, vs, strict=True)), - ) - for p in params - ], - _apply_generic.substitute( - ret, dict(zip(type_params, vs, strict=True)) - ), - ] - - def _hint_to_member(n, t, qs, init, d, *, ctx): return Member[ typing.Literal[n],