Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 158 additions & 49 deletions dataclass_type_validator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import dataclasses
import typing
import functools
from typing import Any
from typing import Optional
import logging
import typing
from typing import Any, Optional

from pydantic import BaseModel

logger = logging.getLogger(__name__)


class TypeValidationError(Exception):
"""Exception raised on type validation errors.
"""
"""Exception raised on type validation errors."""

def __init__(self, *args, target: dataclasses.dataclass, errors: dict):
super(TypeValidationError, self).__init__(*args)
Expand All @@ -16,59 +19,59 @@ def __init__(self, *args, target: dataclasses.dataclass, errors: dict):

def __repr__(self):
cls = self.class_
cls_name = (
f"{cls.__module__}.{cls.__name__}"
if cls.__module__ != "__main__"
else cls.__name__
)
cls_name = f"{cls.__module__}.{cls.__name__}" if cls.__module__ != "__main__" else cls.__name__
attrs = ", ".join([repr(v) for v in self.args])
return f"{cls_name}({attrs}, errors={repr(self.errors)})"

def __str__(self):
cls = self.class_
cls_name = (
f"{cls.__module__}.{cls.__name__}"
if cls.__module__ != "__main__"
else cls.__name__
)
cls_name = f"{cls.__module__}.{cls.__name__}" if cls.__module__ != "__main__" else cls.__name__
s = cls_name
return f"{s} (errors = {self.errors})"


class EnforceError(Exception):
"""Exception raised on enforcing validation errors."""

def __init__(self, *args):
super(EnforceError, self).__init__(*args)
pass


def _validate_type(expected_type: type, value: Any) -> Optional[str]:
if not isinstance(value, expected_type):
return f'must be an instance of {expected_type}, but received {type(value)}'
return f"must be an instance of {expected_type}, but received {type(value)}"


def _validate_iterable_items(expected_type: type, value: Any, strict: bool) -> Optional[str]:
expected_item_type = expected_type.__args__[0]
errors = [_validate_types(expected_type=expected_item_type, value=v, strict=strict) for v in value]
errors = [x for x in errors if x]
if len(errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors: {errors}'
return f"must be an instance of {expected_type}, but there are some errors: {errors}"


def _validate_typing_list(expected_type: type, value: Any, strict: bool) -> Optional[str]:
if not isinstance(value, list):
return f'must be an instance of list, but received {type(value)}'
return f"must be an instance of list, but received {type(value)}"
return _validate_iterable_items(expected_type, value, strict)


def _validate_typing_tuple(expected_type: type, value: Any, strict: bool) -> Optional[str]:
if not isinstance(value, tuple):
return f'must be an instance of tuple, but received {type(value)}'
return f"must be an instance of tuple, but received {type(value)}"
return _validate_iterable_items(expected_type, value, strict)


def _validate_typing_frozenset(expected_type: type, value: Any, strict: bool) -> Optional[str]:
if not isinstance(value, frozenset):
return f'must be an instance of frozenset, but received {type(value)}'
return f"must be an instance of frozenset, but received {type(value)}"
return _validate_iterable_items(expected_type, value, strict)


def _validate_typing_dict(expected_type: type, value: Any, strict: bool) -> Optional[str]:
if not isinstance(value, dict):
return f'must be an instance of dict, but received {type(value)}'
return f"must be an instance of dict, but received {type(value)}"

expected_key_type = expected_type.__args__[0]
expected_value_type = expected_type.__args__[1]
Expand All @@ -80,18 +83,20 @@ def _validate_typing_dict(expected_type: type, value: Any, strict: bool) -> Opti
val_errors = [v for v in val_errors if v]

if len(key_errors) > 0 and len(val_errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors in keys and values. '\
f'key errors: {key_errors}, value errors: {val_errors}'
return (
f"must be an instance of {expected_type}, but there are some errors in keys and values. "
f"key errors: {key_errors}, value errors: {val_errors}"
)
elif len(key_errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors in keys: {key_errors}'
return f"must be an instance of {expected_type}, but there are some errors in keys: {key_errors}"
elif len(val_errors) > 0:
return f'must be an instance of {expected_type}, but there are some errors in values: {val_errors}'
return f"must be an instance of {expected_type}, but there are some errors in values: {val_errors}"


def _validate_typing_callable(expected_type: type, value: Any, strict: bool) -> Optional[str]:
_ = strict
if not isinstance(value, type(lambda a: a)):
return f'must be an instance of {expected_type._name}, but received {type(value)}'
return f"must be an instance of {expected_type._name}, but received {type(value)}"


def _validate_typing_literal(expected_type: type, value: Any, strict: bool) -> Optional[str]:
Expand All @@ -101,11 +106,11 @@ def _validate_typing_literal(expected_type: type, value: Any, strict: bool) -> O


_validate_typing_mappings = {
'List': _validate_typing_list,
'Tuple': _validate_typing_tuple,
'FrozenSet': _validate_typing_frozenset,
'Dict': _validate_typing_dict,
'Callable': _validate_typing_callable,
"List": _validate_typing_list,
"Tuple": _validate_typing_tuple,
"FrozenSet": _validate_typing_frozenset,
"Dict": _validate_typing_dict,
"Callable": _validate_typing_callable,
}


Expand All @@ -114,18 +119,17 @@ def _validate_sequential_types(expected_type: type, value: Any, strict: bool) ->
if validate_func is not None:
return validate_func(expected_type, value, strict)

if str(expected_type).startswith('typing.Literal'):
if str(expected_type).startswith("typing.Literal"):
return _validate_typing_literal(expected_type, value, strict)

if str(expected_type).startswith('typing.Union'):
is_valid = any(_validate_types(expected_type=t, value=value, strict=strict) is None
for t in expected_type.__args__)
if str(expected_type).startswith("typing.Union"):
is_valid = any(_validate_types(expected_type=t, value=value, strict=strict) is None for t in expected_type.__args__)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Flake8] line too long (124 > 120 characters) (view)

Details
Rule
E501

You can close this issue if no need to fix it. Learn more.

if not is_valid:
return f'must be an instance of {expected_type}, but received {value}'
return f"must be an instance of {expected_type}, but received {value}"
return

if strict:
raise RuntimeError(f'Unknown type of {expected_type} (_name = {expected_type._name})')
raise RuntimeError(f"Unknown type of {expected_type} (_name = {expected_type._name})")


def _validate_types(expected_type: type, value: Any, strict: bool) -> Optional[str]:
Expand All @@ -136,8 +140,11 @@ def _validate_types(expected_type: type, value: Any, strict: bool) -> Optional[s
return _validate_sequential_types(expected_type=expected_type, value=value, strict=strict)


def dataclass_type_validator(target, strict: bool = False):
fields = dataclasses.fields(target)
def dataclass_type_validator(target, strict: bool = False, enforce: bool = False):
if isinstance(target, BaseModel):
fields = target.fields.values()
else:
fields = dataclasses.fields(target)

errors = {}
for field in fields:
Expand All @@ -148,14 +155,61 @@ def dataclass_type_validator(target, strict: bool = False):
err = _validate_types(expected_type=expected_type, value=value, strict=strict)
if err is not None:
errors[field_name] = err
if enforce:
val = (
field.default
if not isinstance(field.default, (dataclasses._MISSING_TYPE, type(None)))
else field.default_factory()
)
if isinstance(val, (dataclasses._MISSING_TYPE, type(None))):
raise EnforceError("Can't enforce values as there is no default")
setattr(target, field_name, val)

if len(errors) > 0 and not enforce:
raise TypeValidationError("Dataclass Type Validation Error", target=target, errors=errors)

elif len(errors) > 0 and enforce:
cls = target.__class__
cls_name = f"{cls.__module__}.{cls.__name__}" if cls.__module__ != "__main__" else cls.__name__
logger.warning(f"Dataclass type validation failed, types are enforced. {cls_name} errors={repr(errors)})")


def pydantic_type_validator(cls, values: dict, strict: bool = False, enforce: bool = False):
fields = cls.__fields__.values()
errors = {}
for field in fields:
field_name = field.name
expected_type = field.type_
value = values[field_name] if field_name in values.keys() else None

if len(errors) > 0:
raise TypeValidationError(
"Dataclass Type Validation Error", target=target, errors=errors
)


def dataclass_validate(cls=None, *, strict: bool = False, before_post_init: bool = False):
err = _validate_types(expected_type=expected_type, value=value, strict=strict)
new_values = values
if err is not None:
errors[field_name] = err
if enforce:
val = field.default if not isinstance(field.default, type(None)) else None
if val is None:
val = field.default_factory() if not isinstance(field.default_factory, type(None)) else None
if val is None:
raise EnforceError("Can't enforce values as there is no default")
new_values[field_name] = val

if len(errors) > 0 and not enforce:
raise TypeValidationError("Pydantic Type Validation Error", target=cls, errors=errors)

elif len(errors) > 0 and enforce:
cls_name = cls.__name__
logger.warning(f"Pydantic type validation failed, types are enforced. {cls_name} errors={repr(errors)})")
return new_values


def dataclass_validate(
cls=None,
*,
strict: bool = False,
before_post_init: bool = False,
enforce: bool = False,
):
"""Dataclass decorator to automatically add validation to a dataclass.

So you don't have to add a __post_init__ method, or if you have one, you don't have
Expand All @@ -170,7 +224,12 @@ def dataclass_validate(cls=None, *, strict: bool = False, before_post_init: bool
validation. Default: False.
"""
if cls is None:
return functools.partial(dataclass_validate, strict=strict, before_post_init=before_post_init)
return functools.partial(
dataclass_validate,
strict=strict,
before_post_init=before_post_init,
enforce=enforce,
)

if not hasattr(cls, "__post_init__"):
# No post-init method, so no processing. Wrap the constructor instead.
Expand All @@ -186,15 +245,65 @@ def dataclass_validate(cls=None, *, strict: bool = False, before_post_init: bool
# before the wrapped function.
@functools.wraps(orig_method)
def method_wrapper(self, *args, **kwargs):
dataclass_type_validator(self, strict=strict)
dataclass_type_validator(self, strict=strict, enforce=enforce)
return orig_method(self, *args, **kwargs)

else:
# Normal case - call validator at the end of __init__ or __post_init__.
@functools.wraps(orig_method)
def method_wrapper(self, *args, **kwargs):
x = orig_method(self, *args, **kwargs)
dataclass_type_validator(self, strict=strict)
dataclass_type_validator(self, strict=strict, enforce=enforce)
return x

setattr(cls, wrapped_method_name, method_wrapper)

return cls


if __name__ == "__main__":
# @dataclasses.dataclass
# class TestClass:
# k: str = "key"
# v: float = 1.2

# test_class = TestClass(k=1.2, v="key")

# @dataclasses.dataclass
# class TestClass:
# k: str = "key"
# v: float = 1.2

# def __post_init__(self):
# dataclass_type_validator(self, enforce=True)

# test_class = TestClass(k=1.2, v="key")
from pydantic import root_validator

class TestClass(BaseModel):
k: str = "key"
v: float = 1.2

@root_validator(pre=True)
def enforce_validator(cls, values):
values = pydantic_type_validator(cls, values, enforce=True)
return values

def validate_class(self):
from pydantic import validate_model

object_setattr = object.__setattr__
values, fields_set, validation_error = validate_model(self.__class__, self.dict())
if validation_error:
raise validation_error
object_setattr(self, "__dict__", values)
object_setattr(self, "__fields_set__", fields_set)
self._init_private_attributes()

test_class = TestClass(k=1.2, v="key")
print(test_class)
setattr(test_class, "v", "key")
print(test_class)
test_class.validate_class()
print(test_class)
TestClass()
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pytest
pydantic
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ typing-extensions==3.10.0.0
# via importlib-metadata
zipp==3.4.1
# via importlib-metadata
pydantic