Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pybind11_stubgen/parser/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import abc
import types
from typing import Any
from typing import Any, Callable, TypeVar

from pybind11_stubgen.parser.errors import ParserError
from pybind11_stubgen.structs import (
Expand All @@ -23,6 +23,8 @@
Value,
)

T = TypeVar("T")


class IParser(abc.ABC):
@abc.abstractmethod
Expand Down Expand Up @@ -80,6 +82,13 @@ def handle_type(self, type_: type) -> QualifiedName: ...
@abc.abstractmethod
def handle_value(self, value: Any) -> Value: ...

def call_with_local_types(self, parameters: list[str], func: Callable[[], T]) -> T:
"""
PEP 695 added template syntax to classes and functions.
This will call the function with these additional local types.
"""
...

@abc.abstractmethod
def parse_args_str(self, args_str: str) -> list[Argument]: ...

Expand Down
25 changes: 23 additions & 2 deletions pybind11_stubgen/parser/mixins/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import types
from logging import getLogger
from typing import Any, Sequence
from typing import Any, Callable, Sequence, TypeVar

from pybind11_stubgen.parser.errors import (
InvalidExpressionError,
Expand Down Expand Up @@ -38,6 +38,8 @@

logger = getLogger("pybind11_stubgen")

T = TypeVar("T")


class RemoveSelfAnnotation(IParser):

Expand Down Expand Up @@ -88,6 +90,7 @@ def __init__(self):
self.__extra_imports: set[Import] = set()
self.__current_module: types.ModuleType | None = None
self.__current_class: type | None = None
self.__local_types: set[str] = set()

def handle_alias(self, path: QualifiedName, origin: Any) -> Alias | None:
result = super().handle_alias(path, origin)
Expand Down Expand Up @@ -144,6 +147,13 @@ def handle_value(self, value: Any) -> Value:
self._add_import(QualifiedName.from_str(result.repr))
return result

def call_with_local_types(self, parameters: list[str], func: Callable[[], T]) -> T:
original_local_types = self.__local_types.copy()
self.__local_types.update(parameters)
result = super().call_with_local_types(parameters, func)
self.__local_types = original_local_types
return result

def parse_annotation_str(
self, annotation_str: str
) -> ResolvedType | InvalidExpression | Value:
Expand All @@ -155,7 +165,7 @@ def parse_annotation_str(
def _add_import(self, name: QualifiedName) -> None:
if len(name) == 0:
return
if len(name) == 1 and len(name[0]) == 0:
if len(name) == 1 and (len(name[0]) == 0 or name[0] in self.__local_types):
return
if hasattr(builtins, name[0]):
return
Expand Down Expand Up @@ -636,6 +646,7 @@ class FixNumpyArrayDimTypeVar(IParser):
numpy_primitive_types = FixNumpyArrayDimAnnotation.numpy_primitive_types

__DIM_VARS: set[str] = set()
__local_types: set[str] = set()

def handle_module(
self, path: QualifiedName, module: types.ModuleType
Expand All @@ -662,6 +673,13 @@ def handle_module(

return result

def call_with_local_types(self, parameters: list[str], func: Callable[[], T]) -> T:
original_local_types = self.__local_types.copy()
self.__local_types.update(parameters)
result = super().call_with_local_types(parameters, func)
self.__local_types = original_local_types
return result

def parse_annotation_str(
self, annotation_str: str
) -> ResolvedType | InvalidExpression | Value:
Expand All @@ -675,6 +693,9 @@ def parse_annotation_str(
if not isinstance(result, ResolvedType):
return result

if len(result.name) == 1 and result.name[0] in self.__local_types:
return result

# handle unqualified, single-letter annotation as a TypeVar
if len(result.name) == 1 and len(result.name[0]) == 1:
result.name = QualifiedName.from_str(result.name[0].upper())
Expand Down
63 changes: 55 additions & 8 deletions pybind11_stubgen/parser/mixins/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import ast
import inspect
import re
import sys
import types
from typing import Any
from typing import Any, Callable, TypeVar

from pybind11_stubgen.parser.errors import (
InvalidExpressionError,
Expand Down Expand Up @@ -40,6 +41,8 @@
Argument(name=Identifier("kwargs"), kw_variadic=True),
]

T = TypeVar("T")


class ParserDispatchMixin(IParser):
def handle_class(self, path: QualifiedName, class_: type) -> Class | None:
Expand Down Expand Up @@ -384,6 +387,9 @@ def handle_type(self, type_: type) -> QualifiedName:
)
)

def call_with_local_types(self, parameters: list[str], func: Callable[[], T]) -> T:
return func()

def parse_value_str(self, value: str) -> Value | InvalidExpression:
return self._parse_expression_str(value)

Expand Down Expand Up @@ -624,32 +630,53 @@ def parse_function_docstring(
return []

top_signature_regex = re.compile(
rf"^{func_name}\((?P<args>.*)\)\s*(->\s*(?P<returns>.+))?$"
rf"^{func_name}"
r"(\[(?P<type_vars>[\w\s,]*)])?"
r"\((?P<args>.*)\)\s*(->\s*(?P<returns>.+))?$"
)

match = top_signature_regex.match(doc_lines[0])
if match is None:
return []

if len(doc_lines) < 2 or doc_lines[1] != "Overloaded function.":
# TODO: Update to support more complex formats.
# This only supports bare type parameters.
type_vars_group = match.group("type_vars")
if sys.version_info < (3, 12) and type_vars_group:
# This syntax is not supported before Python 3.12.
return []
type_vars: list[str] = list(
filter(
bool, map(str.strip, (type_vars_group or "").split(","))
)
)
args = self.call_with_local_types(
type_vars, lambda: self.parse_args_str(match.group("args"))
)

returns_str = match.group("returns")
if returns_str is not None:
returns = self.parse_annotation_str(returns_str)
returns = self.call_with_local_types(
type_vars, lambda: self.parse_annotation_str(returns_str)
)
else:
returns = None

return [
Function(
name=func_name,
args=self.parse_args_str(match.group("args")),
args=args,
doc=self._strip_empty_lines(doc_lines[1:]),
returns=returns,
type_vars=type_vars,
)
]

overload_signature_regex = re.compile(
rf"^(\s*(?P<overload_number>\d+).\s*)"
rf"{func_name}\((?P<args>.*)\)\s*->\s*(?P<returns>.+)$"
rf"^(\s*(?P<overload_number>\d+)\.\s*){func_name}"
r"(\[(?P<type_vars>[\w\s,]*)])?"
r"\((?P<args>.*)\)\s*->\s*(?P<returns>.+)$"
)

doc_start = 0
Expand All @@ -661,18 +688,38 @@ def parse_function_docstring(
if match:
if match.group("overload_number") != f"{len(overloads)}":
continue
type_vars_group = match.group("type_vars")
if sys.version_info < (3, 12) and type_vars_group:
# This syntax is not supported before Python 3.12.
continue
overloads[-1].doc = self._strip_empty_lines(doc_lines[doc_start:i])
doc_start = i + 1
# TODO: Update to support more complex formats.
# This only supports bare type parameters.

type_vars: list[str] = list(
filter(
bool,
map(str.strip, (type_vars_group or "").split(",")),
)
)
args = self.call_with_local_types(
type_vars, lambda: self.parse_args_str(match.group("args"))
)
returns = self.call_with_local_types(
type_vars, lambda: self.parse_annotation_str(match.group("returns"))
)
overloads.append(
Function(
name=func_name,
args=self.parse_args_str(match.group("args")),
returns=self.parse_annotation_str(match.group("returns")),
args=args,
returns=returns,
doc=None,
decorators=[
# use `parse_annotation_str()` to trigger typing import
Decorator(str(self.parse_annotation_str("typing.overload")))
],
type_vars=type_vars,
)
)

Expand Down
17 changes: 12 additions & 5 deletions pybind11_stubgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,18 @@ def print_function(self, func: Function) -> list[str]:
args.append(self.print_argument(arg))
if len(args) > 0 and args[0] == "/":
args = args[1:]
signature = [
f"def {func.name}(",
", ".join(args),
")",
]
signature = [f"def {func.name}"]

if func.type_vars:
signature.extend(["[", ", ".join(func.type_vars), "]"])

signature.extend(
[
"(",
", ".join(args),
")",
]
)

if func.returns is not None:
signature.append(f" -> {self.print_annotation(func.returns)}")
Expand Down
1 change: 1 addition & 0 deletions pybind11_stubgen/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class Function:
returns: Annotation | None = field_(default=None)
doc: Docstring | None = field_(default=None)
decorators: list[Decorator] = field_(default_factory=list)
type_vars: list[str] = field_(default_factory=list)

def __str__(self):
return (
Expand Down
37 changes: 37 additions & 0 deletions tests/py-demo/bindings/src/modules/functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,41 @@ void bind_functions_module(py::module &&m) {
m.def("default_custom_arg", [](Foo &foo) {}, py::arg_v("foo", Foo(5), "Foo(5)"));
m.def("pass_callback", [](std::function<Foo(Foo &)> &callback) { return Foo(13); });
m.def("nested_types", [](std::variant<std::list<Foo>, Foo> arg){ return arg; });

py::options options;
options.disable_function_signatures();
m.def(
"passthrough1",
[](py::object obj) { return obj; },
py::doc("passthrough1[T](obj: T) -> T\n"));
m.def(
"passthrough2",
[](py::object obj) { return obj; },
py::doc(
"passthrough2(*args, **kwargs)\n"
"Overloaded function.\n"
"1. passthrough2() -> None\n"
"2. passthrough2[T](obj: T) -> T\n"),
py::arg("obj") = py::none());
m.def(
"passthrough3",
[](py::object obj1, py::object obj2) { return py::make_tuple(obj1, obj2); },
py::doc(
"passthrough3(*args, **kwargs)\n"
"Overloaded function.\n"
"1. passthrough3() -> tuple[None, None]\n"
"2. passthrough3[T](obj: T) -> tuple[T, None]\n"
"3. passthrough3[T1, T2](obj1: T1, obj2: T2) -> tuple[T1, T2]\n"),
py::arg("obj1") = py::none(),
py::arg("obj2") = py::none());
m.def(
"passthrough_backwards",
[](py::object obj) { return obj; },
#if PY_MAJOR_VERSION > 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 12)
py::doc("passthrough_backwards[T](obj: T) -> T\n"));
#else
py::doc("passthrough_backwards(obj: U) -> U\n"));
m.attr("U") = py::module::import("typing").attr("TypeVar")("U");
#endif
options.enable_function_signatures();
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import typing

__all__: list[str] = [
"Foo",
"U",
"accept_callable",
"accept_frozenset",
"accept_py_handle",
Expand All @@ -20,6 +21,10 @@ __all__: list[str] = [
"mul",
"nested_types",
"pass_callback",
"passthrough1",
"passthrough2",
"passthrough3",
"passthrough_backwards",
"pos_kw_only_mix",
"pos_kw_only_variadic_mix",
]
Expand Down Expand Up @@ -54,5 +59,26 @@ def mul(p: float, q: float) -> float:

def nested_types(arg0: list[Foo] | Foo) -> list[Foo] | Foo: ...
def pass_callback(arg0: typing.Callable[[Foo], Foo]) -> Foo: ...
def passthrough1(*args, **kwargs):
"""
passthrough1[T](obj: T) -> T
"""

@typing.overload
def passthrough2() -> None:
"""
2. passthrough2[T](obj: T) -> T
"""

@typing.overload
def passthrough3() -> tuple[None, None]:
"""
2. passthrough3[T](obj: T) -> tuple[T, None]
3. passthrough3[T1, T2](obj1: T1, obj2: T2) -> tuple[T1, T2]
"""

def passthrough_backwards(obj: U) -> U: ...
def pos_kw_only_mix(i: int, /, j: int, *, k: int) -> tuple: ...
def pos_kw_only_variadic_mix(i: int, /, j: int, *args, k: int, **kwargs) -> tuple: ...

U: typing.TypeVar # value = ~U
Loading