diff --git a/dataclass_type_validator/__init__.py b/dataclass_type_validator/__init__.py index 200111e..40eac09 100644 --- a/dataclass_type_validator/__init__.py +++ b/dataclass_type_validator/__init__.py @@ -1,13 +1,16 @@ import dataclasses -import typing import functools -from typing import Any -from typing import Optional +import logging +import typing +from typing import Any, Optional + +from pydantic import BaseModel + +logger = logging.getLogger(__name__) class TypeValidationError(Exception): - """Exception raised on type validation errors. - """ + """Exception raised on type validation errors.""" def __init__(self, *args, target: dataclasses.dataclass, errors: dict): super(TypeValidationError, self).__init__(*args) @@ -16,28 +19,28 @@ def __init__(self, *args, target: dataclasses.dataclass, errors: dict): def __repr__(self): cls = self.class_ - cls_name = ( - f"{cls.__module__}.{cls.__name__}" - if cls.__module__ != "__main__" - else cls.__name__ - ) + cls_name = f"{cls.__module__}.{cls.__name__}" if cls.__module__ != "__main__" else cls.__name__ attrs = ", ".join([repr(v) for v in self.args]) return f"{cls_name}({attrs}, errors={repr(self.errors)})" def __str__(self): cls = self.class_ - cls_name = ( - f"{cls.__module__}.{cls.__name__}" - if cls.__module__ != "__main__" - else cls.__name__ - ) + cls_name = f"{cls.__module__}.{cls.__name__}" if cls.__module__ != "__main__" else cls.__name__ s = cls_name return f"{s} (errors = {self.errors})" +class EnforceError(Exception): + """Exception raised on enforcing validation errors.""" + + def __init__(self, *args): + super(EnforceError, self).__init__(*args) + pass + + def _validate_type(expected_type: type, value: Any) -> Optional[str]: if not isinstance(value, expected_type): - return f'must be an instance of {expected_type}, but received {type(value)}' + return f"must be an instance of {expected_type}, but received {type(value)}" def _validate_iterable_items(expected_type: type, value: Any, strict: bool) -> Optional[str]: @@ -45,30 +48,30 @@ def _validate_iterable_items(expected_type: type, value: Any, strict: bool) -> O errors = [_validate_types(expected_type=expected_item_type, value=v, strict=strict) for v in value] errors = [x for x in errors if x] if len(errors) > 0: - return f'must be an instance of {expected_type}, but there are some errors: {errors}' + return f"must be an instance of {expected_type}, but there are some errors: {errors}" def _validate_typing_list(expected_type: type, value: Any, strict: bool) -> Optional[str]: if not isinstance(value, list): - return f'must be an instance of list, but received {type(value)}' + return f"must be an instance of list, but received {type(value)}" return _validate_iterable_items(expected_type, value, strict) def _validate_typing_tuple(expected_type: type, value: Any, strict: bool) -> Optional[str]: if not isinstance(value, tuple): - return f'must be an instance of tuple, but received {type(value)}' + return f"must be an instance of tuple, but received {type(value)}" return _validate_iterable_items(expected_type, value, strict) def _validate_typing_frozenset(expected_type: type, value: Any, strict: bool) -> Optional[str]: if not isinstance(value, frozenset): - return f'must be an instance of frozenset, but received {type(value)}' + return f"must be an instance of frozenset, but received {type(value)}" return _validate_iterable_items(expected_type, value, strict) def _validate_typing_dict(expected_type: type, value: Any, strict: bool) -> Optional[str]: if not isinstance(value, dict): - return f'must be an instance of dict, but received {type(value)}' + return f"must be an instance of dict, but received {type(value)}" expected_key_type = expected_type.__args__[0] expected_value_type = expected_type.__args__[1] @@ -80,18 +83,20 @@ def _validate_typing_dict(expected_type: type, value: Any, strict: bool) -> Opti val_errors = [v for v in val_errors if v] if len(key_errors) > 0 and len(val_errors) > 0: - return f'must be an instance of {expected_type}, but there are some errors in keys and values. '\ - f'key errors: {key_errors}, value errors: {val_errors}' + return ( + f"must be an instance of {expected_type}, but there are some errors in keys and values. " + f"key errors: {key_errors}, value errors: {val_errors}" + ) elif len(key_errors) > 0: - return f'must be an instance of {expected_type}, but there are some errors in keys: {key_errors}' + return f"must be an instance of {expected_type}, but there are some errors in keys: {key_errors}" elif len(val_errors) > 0: - return f'must be an instance of {expected_type}, but there are some errors in values: {val_errors}' + return f"must be an instance of {expected_type}, but there are some errors in values: {val_errors}" def _validate_typing_callable(expected_type: type, value: Any, strict: bool) -> Optional[str]: _ = strict if not isinstance(value, type(lambda a: a)): - return f'must be an instance of {expected_type._name}, but received {type(value)}' + return f"must be an instance of {expected_type._name}, but received {type(value)}" def _validate_typing_literal(expected_type: type, value: Any, strict: bool) -> Optional[str]: @@ -101,11 +106,11 @@ def _validate_typing_literal(expected_type: type, value: Any, strict: bool) -> O _validate_typing_mappings = { - 'List': _validate_typing_list, - 'Tuple': _validate_typing_tuple, - 'FrozenSet': _validate_typing_frozenset, - 'Dict': _validate_typing_dict, - 'Callable': _validate_typing_callable, + "List": _validate_typing_list, + "Tuple": _validate_typing_tuple, + "FrozenSet": _validate_typing_frozenset, + "Dict": _validate_typing_dict, + "Callable": _validate_typing_callable, } @@ -114,18 +119,17 @@ def _validate_sequential_types(expected_type: type, value: Any, strict: bool) -> if validate_func is not None: return validate_func(expected_type, value, strict) - if str(expected_type).startswith('typing.Literal'): + if str(expected_type).startswith("typing.Literal"): return _validate_typing_literal(expected_type, value, strict) - if str(expected_type).startswith('typing.Union'): - is_valid = any(_validate_types(expected_type=t, value=value, strict=strict) is None - for t in expected_type.__args__) + if str(expected_type).startswith("typing.Union"): + is_valid = any(_validate_types(expected_type=t, value=value, strict=strict) is None for t in expected_type.__args__) if not is_valid: - return f'must be an instance of {expected_type}, but received {value}' + return f"must be an instance of {expected_type}, but received {value}" return if strict: - raise RuntimeError(f'Unknown type of {expected_type} (_name = {expected_type._name})') + raise RuntimeError(f"Unknown type of {expected_type} (_name = {expected_type._name})") def _validate_types(expected_type: type, value: Any, strict: bool) -> Optional[str]: @@ -136,8 +140,11 @@ def _validate_types(expected_type: type, value: Any, strict: bool) -> Optional[s return _validate_sequential_types(expected_type=expected_type, value=value, strict=strict) -def dataclass_type_validator(target, strict: bool = False): - fields = dataclasses.fields(target) +def dataclass_type_validator(target, strict: bool = False, enforce: bool = False): + if isinstance(target, BaseModel): + fields = target.fields.values() + else: + fields = dataclasses.fields(target) errors = {} for field in fields: @@ -148,14 +155,61 @@ def dataclass_type_validator(target, strict: bool = False): err = _validate_types(expected_type=expected_type, value=value, strict=strict) if err is not None: errors[field_name] = err + if enforce: + val = ( + field.default + if not isinstance(field.default, (dataclasses._MISSING_TYPE, type(None))) + else field.default_factory() + ) + if isinstance(val, (dataclasses._MISSING_TYPE, type(None))): + raise EnforceError("Can't enforce values as there is no default") + setattr(target, field_name, val) + + if len(errors) > 0 and not enforce: + raise TypeValidationError("Dataclass Type Validation Error", target=target, errors=errors) + + elif len(errors) > 0 and enforce: + cls = target.__class__ + cls_name = f"{cls.__module__}.{cls.__name__}" if cls.__module__ != "__main__" else cls.__name__ + logger.warning(f"Dataclass type validation failed, types are enforced. {cls_name} errors={repr(errors)})") + + +def pydantic_type_validator(cls, values: dict, strict: bool = False, enforce: bool = False): + fields = cls.__fields__.values() + errors = {} + for field in fields: + field_name = field.name + expected_type = field.type_ + value = values[field_name] if field_name in values.keys() else None - if len(errors) > 0: - raise TypeValidationError( - "Dataclass Type Validation Error", target=target, errors=errors - ) - - -def dataclass_validate(cls=None, *, strict: bool = False, before_post_init: bool = False): + err = _validate_types(expected_type=expected_type, value=value, strict=strict) + new_values = values + if err is not None: + errors[field_name] = err + if enforce: + val = field.default if not isinstance(field.default, type(None)) else None + if val is None: + val = field.default_factory() if not isinstance(field.default_factory, type(None)) else None + if val is None: + raise EnforceError("Can't enforce values as there is no default") + new_values[field_name] = val + + if len(errors) > 0 and not enforce: + raise TypeValidationError("Pydantic Type Validation Error", target=cls, errors=errors) + + elif len(errors) > 0 and enforce: + cls_name = cls.__name__ + logger.warning(f"Pydantic type validation failed, types are enforced. {cls_name} errors={repr(errors)})") + return new_values + + +def dataclass_validate( + cls=None, + *, + strict: bool = False, + before_post_init: bool = False, + enforce: bool = False, +): """Dataclass decorator to automatically add validation to a dataclass. So you don't have to add a __post_init__ method, or if you have one, you don't have @@ -170,7 +224,12 @@ def dataclass_validate(cls=None, *, strict: bool = False, before_post_init: bool validation. Default: False. """ if cls is None: - return functools.partial(dataclass_validate, strict=strict, before_post_init=before_post_init) + return functools.partial( + dataclass_validate, + strict=strict, + before_post_init=before_post_init, + enforce=enforce, + ) if not hasattr(cls, "__post_init__"): # No post-init method, so no processing. Wrap the constructor instead. @@ -186,15 +245,65 @@ def dataclass_validate(cls=None, *, strict: bool = False, before_post_init: bool # before the wrapped function. @functools.wraps(orig_method) def method_wrapper(self, *args, **kwargs): - dataclass_type_validator(self, strict=strict) + dataclass_type_validator(self, strict=strict, enforce=enforce) return orig_method(self, *args, **kwargs) + else: # Normal case - call validator at the end of __init__ or __post_init__. @functools.wraps(orig_method) def method_wrapper(self, *args, **kwargs): x = orig_method(self, *args, **kwargs) - dataclass_type_validator(self, strict=strict) + dataclass_type_validator(self, strict=strict, enforce=enforce) return x + setattr(cls, wrapped_method_name, method_wrapper) return cls + + +if __name__ == "__main__": + # @dataclasses.dataclass + # class TestClass: + # k: str = "key" + # v: float = 1.2 + + # test_class = TestClass(k=1.2, v="key") + + # @dataclasses.dataclass + # class TestClass: + # k: str = "key" + # v: float = 1.2 + + # def __post_init__(self): + # dataclass_type_validator(self, enforce=True) + + # test_class = TestClass(k=1.2, v="key") + from pydantic import root_validator + + class TestClass(BaseModel): + k: str = "key" + v: float = 1.2 + + @root_validator(pre=True) + def enforce_validator(cls, values): + values = pydantic_type_validator(cls, values, enforce=True) + return values + + def validate_class(self): + from pydantic import validate_model + + object_setattr = object.__setattr__ + values, fields_set, validation_error = validate_model(self.__class__, self.dict()) + if validation_error: + raise validation_error + object_setattr(self, "__dict__", values) + object_setattr(self, "__fields_set__", fields_set) + self._init_private_attributes() + + test_class = TestClass(k=1.2, v="key") + print(test_class) + setattr(test_class, "v", "key") + print(test_class) + test_class.validate_class() + print(test_class) + TestClass() diff --git a/requirements.in b/requirements.in index e079f8a..1cc0826 100644 --- a/requirements.in +++ b/requirements.in @@ -1 +1,2 @@ pytest +pydantic \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d76d9c4..0314205 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ typing-extensions==3.10.0.0 # via importlib-metadata zipp==3.4.1 # via importlib-metadata +pydantic \ No newline at end of file