diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index 904a686d2e..7934d8f508 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -1,6 +1,6 @@ from enum import Enum from pathlib import Path -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import click import pytest @@ -51,6 +51,28 @@ def opt(user: str | None = None): assert "User: Camila" in result.output +@pytest.mark.parametrize( + ("value", "expected"), + [("0", "ROOTED!"), ("12", "ID: 12"), ("name", "USER: name")], +) +def test_union(value, expected): + app = typer.Typer() + + @app.command() + def opt(id_or_name: Union[int, str]): + if isinstance(id_or_name, int): + if id_or_name == 0: + print("ROOTED!") + else: + print(f"ID: {id_or_name}") + else: + print(f"USER: {id_or_name}") + + result = runner.invoke(app, [value]) + assert result.exit_code == 0 + assert expected in result.output + + def test_optional_tuple(): app = typer.Typer() diff --git a/typer/main.py b/typer/main.py index e546e042b7..1bec796f5a 100644 --- a/typer/main.py +++ b/typer/main.py @@ -688,6 +688,30 @@ def wrapper(**kwargs: Any) -> Any: return wrapper +class UnionParamType(click.ParamType): + @property + def name(self) -> str: # type: ignore + return " | ".join(_type.name for _type in self._types) + + def __init__(self, types: List[click.ParamType]): + super().__init__() + self._types = types + + def convert( + self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] + ) -> Any: + # *types, last = self._types + error_messages = [] + for _type in self._types: + try: + return _type.convert(value, param, ctx) + except click.BadParameter as e: + print(type(e)) + error_messages.append(str(e)) + # return last.convert(value, param, ctx) + raise self.fail("\n" + "\nbut also\n".join(error_messages), param, ctx) + + def get_click_type( *, annotation: Any, parameter_info: ParameterInfo ) -> click.ParamType: @@ -783,6 +807,12 @@ def get_click_type( [item.value for item in annotation], case_sensitive=parameter_info.case_sensitive, ) + elif get_origin(annotation) is not None and is_union(get_origin(annotation)): + types = [ + get_click_type(annotation=arg, parameter_info=parameter_info) + for arg in get_args(annotation) + ] + return UnionParamType(types) elif is_literal_type(annotation): return click.Choice( literal_values(annotation), @@ -838,9 +868,14 @@ def get_click_param( if type_ is NoneType: continue types.append(type_) - assert len(types) == 1, "Typer Currently doesn't support Union types" - main_type = types[0] - origin = get_origin(main_type) + if len(types) == 1: + (main_type,) = types + origin = get_origin(main_type) + else: + for type_ in get_args(main_type): + assert not get_origin(type_), ( + "Union types with complex sub-types are not currently supported" + ) # Handle Tuples and Lists if lenient_issubclass(origin, List): main_type = get_args(main_type)[0]