Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions dataclass_type_validator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Optional
from typing import Dict
import types

GlobalNS_T = Dict[str, Any]

Expand Down Expand Up @@ -109,15 +110,23 @@ def _validate_typing_literal(expected_type: type, value: Any, strict: bool) -> O

_validate_typing_mappings = {
'List': _validate_typing_list,
'list': _validate_typing_list,
'Tuple': _validate_typing_tuple,
'tuple': _validate_typing_tuple,
'FrozenSet': _validate_typing_frozenset,
'frozenset': _validate_typing_frozenset,
'Dict': _validate_typing_dict,
'dict': _validate_typing_dict,
'Callable': _validate_typing_callable,
}


def _type_name(t: type) -> str:
return t._name if hasattr(t, '_name') else t.__name__


def _validate_sequential_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
validate_func = _validate_typing_mappings.get(expected_type._name)
validate_func = _validate_typing_mappings.get(_type_name(expected_type))
if validate_func is not None:
return validate_func(expected_type, value, strict, globalns)

Expand All @@ -132,16 +141,22 @@ def _validate_sequential_types(expected_type: type, value: Any, strict: bool, gl
return

if strict:
raise RuntimeError(f'Unknown type of {expected_type} (_name = {expected_type._name})')
raise RuntimeError(f'Unknown type of {expected_type} (_name = {_type_name(expected_type)}')

def _is_generic_alias(expected_type: type) -> bool:
if sys.version_info < (3, 9):
return isinstance(expected_type, typing._GenericAlias)
else:
return isinstance(expected_type, (typing._GenericAlias, types.GenericAlias))

def _validate_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
if isinstance(expected_type, type):
return _validate_type(expected_type=expected_type, value=value)

if isinstance(expected_type, typing._GenericAlias):
def _validate_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
if _is_generic_alias(expected_type):
return _validate_sequential_types(expected_type=expected_type, value=value,
strict=strict, globalns=globalns)

if isinstance(expected_type, type):
return _validate_type(expected_type=expected_type, value=value)

if isinstance(expected_type, typing.ForwardRef):
referenced_type = _evaluate_forward_reference(expected_type, globalns)
Expand Down
74 changes: 74 additions & 0 deletions tests/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,77 @@ def optional_type_name(arg_type_name):
return f"typing.Union\\[({arg_type_name}, NoneType|NoneType, {arg_type_name})\\]"

return f"typing.Optional\\[{arg_type_name}\\]"

# Tests for generic types, only in 3.9+
if sys.version_info >= (3, 9):

@dataclasses.dataclass(frozen=True)
class DataclassTestGenericList:
array_of_numbers: list[int]
array_of_strings: list[str]
array_of_optional_strings: list[typing.Optional[str]]

def __post_init__(self):
dataclass_type_validator(self)


class TestTypeValidationGenericList:
def test_build_success(self):
assert isinstance(DataclassTestGenericList(
array_of_numbers=[],
array_of_strings=[],
array_of_optional_strings=[],
), DataclassTestGenericList)
assert isinstance(DataclassTestGenericList(
array_of_numbers=[1, 2],
array_of_strings=['abc'],
array_of_optional_strings=['abc', None]
), DataclassTestGenericList)

def test_build_failure_on_array_numbers(self):
with pytest.raises(TypeValidationError, match='must be an instance of list\\[int\\]'):
assert isinstance(DataclassTestGenericList(
array_of_numbers=['abc'],
array_of_strings=['abc'],
array_of_optional_strings=['abc', None]
), DataclassTestGenericList)

def test_build_failure_on_array_strings(self):
with pytest.raises(TypeValidationError, match='must be an instance of list\\[str\\]'):
assert isinstance(DataclassTestGenericList(
array_of_numbers=[1, 2],
array_of_strings=[123],
array_of_optional_strings=['abc', None]
), DataclassTestGenericList)

def test_build_failure_on_array_optional_strings(self):
with pytest.raises(TypeValidationError,
match=f"must be an instance of list\\[{optional_type_name('str')}\\]"):
assert isinstance(DataclassTestGenericList(
array_of_numbers=[1, 2],
array_of_strings=['abc'],
array_of_optional_strings=[123, None]
), DataclassTestGenericList)

@dataclasses.dataclass(frozen=True)
class DataclassTestGenericDict:
str_to_str: dict[str, str]
str_to_any: dict[str, typing.Any]

def __post_init__(self):
dataclass_type_validator(self)


class TestTypeValidationGenericDict:
def test_build_success(self):
assert isinstance(DataclassTestGenericDict(
str_to_str={'str': 'str'},
str_to_any={'str': 'str', 'str2': 123}
), DataclassTestGenericDict)

def test_build_failure(self):
with pytest.raises(TypeValidationError, match='must be an instance of dict\\[str, str\\]'):
assert isinstance(DataclassTestGenericDict(
str_to_str={'str': 123},
str_to_any={'key': []}
), DataclassTestGenericDict)