Skip to content

Commit 8543681

Browse files
authored
Merge pull request #115 from dapper91/dev
- model validator 'before' mode support added.
2 parents c0b64a8 + e6e7d8d commit 8543681

File tree

5 files changed

+125
-14
lines changed

5 files changed

+125
-14
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Changelog
22
=========
33

4+
2.2.2 (2023-09-15)
5+
------------------
6+
7+
- model validator 'before' mode support added.
8+
9+
410
2.2.1 (2023-09-12)
511
------------------
612

pydantic_xml/serializers/factories/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> '
5050
model_cls = schema['cls']
5151
fields_schema = schema['schema']
5252

53+
if fields_schema['type'] == 'function-before':
54+
fields_schema = fields_schema['schema']
55+
5356
assert issubclass(model_cls, pxml.BaseXmlModel), "model class must be a BaseXmlModel subclass"
5457
assert fields_schema['type'] == 'model-fields', f"unexpected schema type: {fields_schema['type']}"
5558
fields_schema = typing.cast(pcs.ModelFieldsSchema, fields_schema)

pydantic_xml/serializers/factories/primitive.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def serialize(
9797
def deserialize(
9898
self,
9999
element: Optional[XmlElementReader],
100-
*, context: Optional[Dict[str, Any]],
100+
*,
101+
context: Optional[Dict[str, Any]],
101102
) -> Optional[str]:
102103
if self._computed:
103104
return None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pydantic-xml"
3-
version = "2.2.1"
3+
version = "2.2.2"
44
description = "pydantic xml extension"
55
authors = ["Dmitry Pershin <dapper1291@gmail.com>"]
66
license = "Unlicense"

tests/test_encoder.py

Lines changed: 113 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import sys
44
from decimal import Decimal
55
from enum import Enum
6-
from typing import Any
6+
from typing import Any, Callable, Dict
77
from uuid import UUID
88

99
import pytest
1010
from helpers import assert_xml_equal
11-
from pydantic import field_serializer
11+
from pydantic import SerializerFunctionWrapHandler, ValidatorFunctionWrapHandler, field_serializer, field_validator
12+
from pydantic import model_validator
1213
from pydantic.functional_serializers import PlainSerializer, WrapSerializer
1314
from pydantic.functional_validators import AfterValidator, BeforeValidator, WrapValidator
1415

@@ -131,7 +132,7 @@ def serialize_dt(self, value: dt.datetime) -> float:
131132
WrapSerializer(lambda val, nxt: val.timestamp(), return_type=float),
132133
],
133134
)
134-
def test_serializer(Serializer: Any):
135+
def test_serializer_annotations(Serializer: Any):
135136
from typing import Annotated
136137

137138
Timestamp = Annotated[dt.datetime, Serializer]
@@ -163,31 +164,60 @@ class TestModel(BaseXmlModel, tag='model'):
163164
assert_xml_equal(actual_xml, xml.encode())
164165

165166

167+
def test_serializer_methods():
168+
class TestModel(BaseXmlModel, tag='model'):
169+
field1: dt.datetime = element()
170+
field2: dt.datetime = element()
171+
172+
@field_serializer('field1', mode='plain')
173+
def serialize_field1(self, value: dt.datetime) -> float:
174+
return value.timestamp()
175+
176+
@field_serializer('field2', mode='wrap')
177+
def serialize_field2(self, value: dt.datetime, nxt: SerializerFunctionWrapHandler) -> float:
178+
return nxt(value.timestamp())
179+
180+
xml = '''
181+
<model>
182+
<field1>1675468800.0</field1>
183+
<field2>1675468800.0</field2>
184+
</model>
185+
'''
186+
187+
obj = TestModel(
188+
field1=dt.datetime(2023, 2, 4, tzinfo=dt.timezone.utc),
189+
field2=dt.datetime(2023, 2, 4, tzinfo=dt.timezone.utc),
190+
)
191+
192+
actual_xml = obj.to_xml()
193+
assert_xml_equal(actual_xml, xml.encode())
194+
195+
166196
@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 and above")
167197
@pytest.mark.parametrize(
168198
'Validator', [
169-
AfterValidator(lambda val: val),
170-
BeforeValidator(lambda val: dt.datetime.fromtimestamp(float(val), tz=dt.timezone.utc)),
171-
WrapValidator(lambda val, nxt: dt.datetime.fromtimestamp(float(val), tz=dt.timezone.utc)),
199+
AfterValidator(lambda val: val.replace(tzinfo=dt.timezone.utc)),
200+
BeforeValidator(lambda val: dt.datetime.fromisoformat(val).replace(tzinfo=dt.timezone.utc)),
201+
WrapValidator(lambda val, hdr: hdr(dt.datetime.fromisoformat(val).replace(tzinfo=dt.timezone.utc))),
172202
],
173203
)
174-
def test_validator(Validator: Any):
204+
def test_validator_annotations(Validator: Any):
175205
from typing import Annotated
176206

177-
Timestamp = Annotated[dt.datetime, Validator]
207+
DatetimeUTC = Annotated[dt.datetime, Validator]
178208

179209
class TestSubModel(BaseXmlModel, tag='submodel'):
180-
field1: Timestamp = element(tag='field1')
210+
field1: DatetimeUTC = element(tag='field1')
181211

182212
class TestModel(BaseXmlModel, tag='model'):
183-
field1: Timestamp = element(tag='field1')
213+
field1: DatetimeUTC = element(tag='field1')
184214
field2: TestSubModel
185215

186216
xml = '''
187217
<model>
188-
<field1>1675468800.0</field1>
218+
<field1>2023-02-04T00:00:00</field1>
189219
<submodel>
190-
<field1>1675468800.0</field1>
220+
<field1>2023-02-04T00:00:00</field1>
191221
</submodel>
192222
</model>
193223
'''
@@ -202,3 +232,74 @@ class TestModel(BaseXmlModel, tag='model'):
202232
)
203233

204234
assert actual_obj == expected_obj
235+
236+
237+
def test_validator_methods():
238+
class TestModel(BaseXmlModel, tag='model'):
239+
field1: dt.datetime = element()
240+
field2: dt.datetime = element()
241+
field3: dt.datetime = element()
242+
243+
@field_validator('field1', mode='wrap')
244+
def validate_field1(cls, value: str, handler: ValidatorFunctionWrapHandler) -> dt.datetime:
245+
return handler(dt.datetime.fromisoformat(value).replace(tzinfo=dt.timezone.utc))
246+
247+
@field_validator('field2', mode='before')
248+
def validate_field2(cls, value: str) -> dt.datetime:
249+
return dt.datetime.fromisoformat(value).replace(tzinfo=dt.timezone.utc)
250+
251+
@field_validator('field3', mode='after')
252+
def validate_field3(cls, value: dt.datetime) -> dt.datetime:
253+
return value.replace(tzinfo=dt.timezone.utc)
254+
255+
xml = '''
256+
<model>
257+
<field1>2023-02-04T00:00:00</field1>
258+
<field2>2023-02-04T00:00:00</field2>
259+
<field3>2023-02-04T00:00:00</field3>
260+
</model>
261+
'''
262+
263+
actual_obj = TestModel.from_xml(xml)
264+
265+
expected_obj = TestModel.model_construct(
266+
field1=dt.datetime(2023, 2, 4, tzinfo=dt.timezone.utc),
267+
field2=dt.datetime(2023, 2, 4, tzinfo=dt.timezone.utc),
268+
field3=dt.datetime(2023, 2, 4, tzinfo=dt.timezone.utc),
269+
)
270+
271+
assert actual_obj == expected_obj
272+
273+
274+
def test_model_validator():
275+
class TestModel(BaseXmlModel, tag='model'):
276+
field1: dt.datetime = element()
277+
278+
@model_validator(mode='before')
279+
def validate_model_before(cls, data: Dict[str, Any]) -> 'TestModel':
280+
return {
281+
'field1': dt.datetime.strptime(data['field1'], '%Y-%m-%d'),
282+
}
283+
284+
@model_validator(mode='after')
285+
def validate_model_after(cls, obj: 'TestModel') -> 'TestModel':
286+
obj.field1 = obj.field1.replace(tzinfo=dt.timezone.utc)
287+
return obj
288+
289+
@model_validator(mode='wrap')
290+
def validate_model_wrap(cls, obj: 'TestModel', handler: Callable) -> 'TestModel':
291+
return handler(obj)
292+
293+
xml = '''
294+
<model>
295+
<field1>2023-02-04</field1>
296+
</model>
297+
'''
298+
299+
actual_obj = TestModel.from_xml(xml)
300+
301+
expected_obj = TestModel.model_construct(
302+
field1=dt.datetime(2023, 2, 4, tzinfo=dt.timezone.utc),
303+
)
304+
305+
assert actual_obj == expected_obj

0 commit comments

Comments
 (0)