diff --git a/dataclass_type_validator/__init__.py b/dataclass_type_validator/__init__.py index ebf4e1a..5a09dca 100644 --- a/dataclass_type_validator/__init__.py +++ b/dataclass_type_validator/__init__.py @@ -5,6 +5,7 @@ from typing import Any from typing import Optional from typing import Dict +import types GlobalNS_T = Dict[str, Any] @@ -109,15 +110,23 @@ def _validate_typing_literal(expected_type: type, value: Any, strict: bool) -> O _validate_typing_mappings = { 'List': _validate_typing_list, + 'list': _validate_typing_list, 'Tuple': _validate_typing_tuple, + 'tuple': _validate_typing_tuple, 'FrozenSet': _validate_typing_frozenset, + 'frozenset': _validate_typing_frozenset, 'Dict': _validate_typing_dict, + 'dict': _validate_typing_dict, 'Callable': _validate_typing_callable, } +def _type_name(t: type) -> str: + return t._name if hasattr(t, '_name') else t.__name__ + + def _validate_sequential_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]: - validate_func = _validate_typing_mappings.get(expected_type._name) + validate_func = _validate_typing_mappings.get(_type_name(expected_type)) if validate_func is not None: return validate_func(expected_type, value, strict, globalns) @@ -132,16 +141,22 @@ def _validate_sequential_types(expected_type: type, value: Any, strict: bool, gl return if strict: - raise RuntimeError(f'Unknown type of {expected_type} (_name = {expected_type._name})') + raise RuntimeError(f'Unknown type of {expected_type} (_name = {_type_name(expected_type)}') +def _is_generic_alias(expected_type: type) -> bool: + if sys.version_info < (3, 9): + return isinstance(expected_type, typing._GenericAlias) + else: + return isinstance(expected_type, (typing._GenericAlias, types.GenericAlias)) -def _validate_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]: - if isinstance(expected_type, type): - return _validate_type(expected_type=expected_type, value=value) - if isinstance(expected_type, typing._GenericAlias): +def _validate_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]: + if _is_generic_alias(expected_type): return _validate_sequential_types(expected_type=expected_type, value=value, strict=strict, globalns=globalns) + + if isinstance(expected_type, type): + return _validate_type(expected_type=expected_type, value=value) if isinstance(expected_type, typing.ForwardRef): referenced_type = _evaluate_forward_reference(expected_type, globalns) diff --git a/tests/test_validator.py b/tests/test_validator.py index b54d493..5614a3b 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -361,3 +361,77 @@ def optional_type_name(arg_type_name): return f"typing.Union\\[({arg_type_name}, NoneType|NoneType, {arg_type_name})\\]" return f"typing.Optional\\[{arg_type_name}\\]" + +# Tests for generic types, only in 3.9+ +if sys.version_info >= (3, 9): + + @dataclasses.dataclass(frozen=True) + class DataclassTestGenericList: + array_of_numbers: list[int] + array_of_strings: list[str] + array_of_optional_strings: list[typing.Optional[str]] + + def __post_init__(self): + dataclass_type_validator(self) + + + class TestTypeValidationGenericList: + def test_build_success(self): + assert isinstance(DataclassTestGenericList( + array_of_numbers=[], + array_of_strings=[], + array_of_optional_strings=[], + ), DataclassTestGenericList) + assert isinstance(DataclassTestGenericList( + array_of_numbers=[1, 2], + array_of_strings=['abc'], + array_of_optional_strings=['abc', None] + ), DataclassTestGenericList) + + def test_build_failure_on_array_numbers(self): + with pytest.raises(TypeValidationError, match='must be an instance of list\\[int\\]'): + assert isinstance(DataclassTestGenericList( + array_of_numbers=['abc'], + array_of_strings=['abc'], + array_of_optional_strings=['abc', None] + ), DataclassTestGenericList) + + def test_build_failure_on_array_strings(self): + with pytest.raises(TypeValidationError, match='must be an instance of list\\[str\\]'): + assert isinstance(DataclassTestGenericList( + array_of_numbers=[1, 2], + array_of_strings=[123], + array_of_optional_strings=['abc', None] + ), DataclassTestGenericList) + + def test_build_failure_on_array_optional_strings(self): + with pytest.raises(TypeValidationError, + match=f"must be an instance of list\\[{optional_type_name('str')}\\]"): + assert isinstance(DataclassTestGenericList( + array_of_numbers=[1, 2], + array_of_strings=['abc'], + array_of_optional_strings=[123, None] + ), DataclassTestGenericList) + + @dataclasses.dataclass(frozen=True) + class DataclassTestGenericDict: + str_to_str: dict[str, str] + str_to_any: dict[str, typing.Any] + + def __post_init__(self): + dataclass_type_validator(self) + + + class TestTypeValidationGenericDict: + def test_build_success(self): + assert isinstance(DataclassTestGenericDict( + str_to_str={'str': 'str'}, + str_to_any={'str': 'str', 'str2': 123} + ), DataclassTestGenericDict) + + def test_build_failure(self): + with pytest.raises(TypeValidationError, match='must be an instance of dict\\[str, str\\]'): + assert isinstance(DataclassTestGenericDict( + str_to_str={'str': 123}, + str_to_any={'key': []} + ), DataclassTestGenericDict)