Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 117 additions & 49 deletions dataclass_type_validator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from typing import Any
from typing import Optional
from typing import Dict
import types

GlobalNS_T = Dict[str, Any]


class TypeValidationError(Exception):
"""Exception raised on type validation errors.
"""
"""Exception raised on type validation errors."""

def __init__(self, *args, target: dataclasses.dataclass, errors: dict):
super(TypeValidationError, self).__init__(*args)
Expand Down Expand Up @@ -41,115 +41,176 @@ def __str__(self):

def _validate_type(expected_type: type, value: Any) -> Optional[str]:
if not isinstance(value, expected_type):
return f'must be an instance of {expected_type}, but received {type(value)}'
return f"must be an instance of {expected_type}, but received {type(value)}"


def _validate_iterable_items(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
def _validate_iterable_items(
expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T
) -> Optional[str]:
expected_item_type = expected_type.__args__[0]
errors = [_validate_types(expected_type=expected_item_type, value=v, strict=strict, globalns=globalns)
for v in value]
errors = [
_validate_types(
expected_type=expected_item_type, value=v, strict=strict, globalns=globalns
)
for v in value
]
errors = [x for x in errors if x]
if len(errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors: {errors}'
return f"must be an instance of {expected_type}, but there are some errors: {errors}"


def _validate_typing_list(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
def _validate_typing_list(
expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T
) -> Optional[str]:
if not isinstance(value, list):
return f'must be an instance of list, but received {type(value)}'
return f"must be an instance of list, but received {type(value)}"
return _validate_iterable_items(expected_type, value, strict, globalns)


def _validate_typing_tuple(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
def _validate_typing_tuple(
expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T
) -> Optional[str]:
if not isinstance(value, tuple):
return f'must be an instance of tuple, but received {type(value)}'
return f"must be an instance of tuple, but received {type(value)}"
return _validate_iterable_items(expected_type, value, strict, globalns)


def _validate_typing_frozenset(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
def _validate_typing_frozenset(
expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T
) -> Optional[str]:
if not isinstance(value, frozenset):
return f'must be an instance of frozenset, but received {type(value)}'
return f"must be an instance of frozenset, but received {type(value)}"
return _validate_iterable_items(expected_type, value, strict, globalns)


def _validate_typing_dict(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
def _validate_typing_dict(
expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T
) -> Optional[str]:
if not isinstance(value, dict):
return f'must be an instance of dict, but received {type(value)}'
return f"must be an instance of dict, but received {type(value)}"

expected_key_type = expected_type.__args__[0]
expected_value_type = expected_type.__args__[1]

key_errors = [_validate_types(expected_type=expected_key_type, value=k, strict=strict, globalns=globalns)
for k in value.keys()]
key_errors = [
_validate_types(
expected_type=expected_key_type, value=k, strict=strict, globalns=globalns
)
for k in value.keys()
]
key_errors = [k for k in key_errors if k]

val_errors = [_validate_types(expected_type=expected_value_type, value=v, strict=strict, globalns=globalns)
for v in value.values()]
val_errors = [
_validate_types(
expected_type=expected_value_type, value=v, strict=strict, globalns=globalns
)
for v in value.values()
]
val_errors = [v for v in val_errors if v]

if len(key_errors) > 0 and len(val_errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors in keys and values. '\
f'key errors: {key_errors}, value errors: {val_errors}'
return (
f"must be an instance of {expected_type}, but there are some errors in keys and values. "
f"key errors: {key_errors}, value errors: {val_errors}"
)
elif len(key_errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors in keys: {key_errors}'
return f"must be an instance of {expected_type}, but there are some errors in keys: {key_errors}"
elif len(val_errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors in values: {val_errors}'
return f"must be an instance of {expected_type}, but there are some errors in values: {val_errors}"


def _validate_typing_callable(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
def _validate_typing_callable(
expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T
) -> Optional[str]:
_ = strict
if not isinstance(value, type(lambda a: a)):
return f'must be an instance of {expected_type._name}, but received {type(value)}'
return (
f"must be an instance of {expected_type._name}, but received {type(value)}"
)


def _validate_typing_literal(expected_type: type, value: Any, strict: bool) -> Optional[str]:
def _validate_typing_literal(
expected_type: type, value: Any, strict: bool
) -> Optional[str]:
_ = strict
if value not in expected_type.__args__:
return f'must be one of [{", ".join(expected_type.__args__)}] but received {value}'
return (
f"must be one of [{', '.join(expected_type.__args__)}] but received {value}"
)


_validate_typing_mappings = {
'List': _validate_typing_list,
'Tuple': _validate_typing_tuple,
'FrozenSet': _validate_typing_frozenset,
'Dict': _validate_typing_dict,
'Callable': _validate_typing_callable,
"List": _validate_typing_list,
"list": _validate_typing_list,
"Tuple": _validate_typing_tuple,
"tuple": _validate_typing_tuple,
"FrozenSet": _validate_typing_frozenset,
"frozenset": _validate_typing_frozenset,
"Dict": _validate_typing_dict,
"dict": _validate_typing_dict,
"Callable": _validate_typing_callable,
}


def _validate_sequential_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
validate_func = _validate_typing_mappings.get(expected_type._name)
def _type_name(t: type) -> str:
return t._name if hasattr(t, "_name") else t.__name__


def _validate_sequential_types(
expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T
) -> Optional[str]:
validate_func = _validate_typing_mappings.get(_type_name(expected_type))
if validate_func is not None:
return validate_func(expected_type, value, strict, globalns)

if str(expected_type).startswith('typing.Literal'):
if str(expected_type).startswith("typing.Literal"):
return _validate_typing_literal(expected_type, value, strict)

if str(expected_type).startswith('typing.Union') or str(expected_type).startswith('typing.Optional'):
is_valid = any(_validate_types(expected_type=t, value=value, strict=strict, globalns=globalns) is None
for t in expected_type.__args__)
if str(expected_type).startswith("typing.Union") or str(expected_type).startswith(
"typing.Optional"
):
is_valid = any(
_validate_types(
expected_type=t, value=value, strict=strict, globalns=globalns
)
is None
for t in expected_type.__args__
)
if not is_valid:
return f'must be an instance of {expected_type}, but received {value}'
return f"must be an instance of {expected_type}, but received {value}"
return

if strict:
raise RuntimeError(f'Unknown type of {expected_type} (_name = {expected_type._name})')
raise RuntimeError(
f"Unknown type of {expected_type} (_name = {_type_name(expected_type)}"
)


def _validate_types(expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T) -> Optional[str]:
def _is_generic_alias(expected_type: type) -> bool:
if sys.version_info < (3, 9):
return isinstance(expected_type, typing._GenericAlias)
else:
return isinstance(expected_type, (typing._GenericAlias, types.GenericAlias))


def _validate_types(
expected_type: type, value: Any, strict: bool, globalns: GlobalNS_T
) -> Optional[str]:
if _is_generic_alias(expected_type):
return _validate_sequential_types(
expected_type=expected_type, value=value, strict=strict, globalns=globalns
)

if isinstance(expected_type, type):
return _validate_type(expected_type=expected_type, value=value)

if isinstance(expected_type, typing._GenericAlias):
return _validate_sequential_types(expected_type=expected_type, value=value,
strict=strict, globalns=globalns)

if isinstance(expected_type, typing.ForwardRef):
referenced_type = _evaluate_forward_reference(expected_type, globalns)
return _validate_type(expected_type=referenced_type, value=value)


def _evaluate_forward_reference(ref_type: typing.ForwardRef, globalns: GlobalNS_T):
""" Support evaluating ForwardRef types on both Python 3.8 and 3.9. """
"""Support evaluating ForwardRef types on both Python 3.8 and 3.9."""
if sys.version_info < (3, 9):
return ref_type._evaluate(globalns, None)
return ref_type._evaluate(globalns, None, set())
Expand All @@ -165,7 +226,9 @@ def dataclass_type_validator(target, strict: bool = False):
expected_type = field.type
value = getattr(target, field_name)

err = _validate_types(expected_type=expected_type, value=value, strict=strict, globalns=globalns)
err = _validate_types(
expected_type=expected_type, value=value, strict=strict, globalns=globalns
)
if err is not None:
errors[field_name] = err

Expand All @@ -175,7 +238,9 @@ def dataclass_type_validator(target, strict: bool = False):
)


def dataclass_validate(cls=None, *, strict: bool = False, before_post_init: bool = False):
def dataclass_validate(
cls=None, *, strict: bool = False, before_post_init: bool = False
):
"""Dataclass decorator to automatically add validation to a dataclass.

So you don't have to add a __post_init__ method, or if you have one, you don't have
Expand All @@ -190,7 +255,9 @@ def dataclass_validate(cls=None, *, strict: bool = False, before_post_init: bool
validation. Default: False.
"""
if cls is None:
return functools.partial(dataclass_validate, strict=strict, before_post_init=before_post_init)
return functools.partial(
dataclass_validate, strict=strict, before_post_init=before_post_init
)

if not hasattr(cls, "__post_init__"):
# No post-init method, so no processing. Wrap the constructor instead.
Expand All @@ -215,6 +282,7 @@ def method_wrapper(self, *args, **kwargs):
x = orig_method(self, *args, **kwargs)
dataclass_type_validator(self, strict=strict)
return x

setattr(cls, wrapped_method_name, method_wrapper)

return cls
2 changes: 2 additions & 0 deletions src/dataclass_type_validator/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def main() -> None:
print("Hello from dataclass-type-validator!")
Loading