33import sys
44from decimal import Decimal
55from enum import Enum
6- from typing import Any
6+ from typing import Any , Callable , Dict
77from uuid import UUID
88
99import pytest
1010from helpers import assert_xml_equal
11- from pydantic import field_serializer
11+ from pydantic import SerializerFunctionWrapHandler , ValidatorFunctionWrapHandler , field_serializer , field_validator
12+ from pydantic import model_validator
1213from pydantic .functional_serializers import PlainSerializer , WrapSerializer
1314from pydantic .functional_validators import AfterValidator , BeforeValidator , WrapValidator
1415
@@ -131,7 +132,7 @@ def serialize_dt(self, value: dt.datetime) -> float:
131132 WrapSerializer (lambda val , nxt : val .timestamp (), return_type = float ),
132133 ],
133134)
134- def test_serializer (Serializer : Any ):
135+ def test_serializer_annotations (Serializer : Any ):
135136 from typing import Annotated
136137
137138 Timestamp = Annotated [dt .datetime , Serializer ]
@@ -163,31 +164,60 @@ class TestModel(BaseXmlModel, tag='model'):
163164 assert_xml_equal (actual_xml , xml .encode ())
164165
165166
167+ def test_serializer_methods ():
168+ class TestModel (BaseXmlModel , tag = 'model' ):
169+ field1 : dt .datetime = element ()
170+ field2 : dt .datetime = element ()
171+
172+ @field_serializer ('field1' , mode = 'plain' )
173+ def serialize_field1 (self , value : dt .datetime ) -> float :
174+ return value .timestamp ()
175+
176+ @field_serializer ('field2' , mode = 'wrap' )
177+ def serialize_field2 (self , value : dt .datetime , nxt : SerializerFunctionWrapHandler ) -> float :
178+ return nxt (value .timestamp ())
179+
180+ xml = '''
181+ <model>
182+ <field1>1675468800.0</field1>
183+ <field2>1675468800.0</field2>
184+ </model>
185+ '''
186+
187+ obj = TestModel (
188+ field1 = dt .datetime (2023 , 2 , 4 , tzinfo = dt .timezone .utc ),
189+ field2 = dt .datetime (2023 , 2 , 4 , tzinfo = dt .timezone .utc ),
190+ )
191+
192+ actual_xml = obj .to_xml ()
193+ assert_xml_equal (actual_xml , xml .encode ())
194+
195+
166196@pytest .mark .skipif (sys .version_info < (3 , 9 ), reason = "requires python 3.9 and above" )
167197@pytest .mark .parametrize (
168198 'Validator' , [
169- AfterValidator (lambda val : val ),
170- BeforeValidator (lambda val : dt .datetime .fromtimestamp ( float ( val ), tz = dt .timezone .utc )),
171- WrapValidator (lambda val , nxt : dt .datetime .fromtimestamp ( float ( val ), tz = dt .timezone .utc )),
199+ AfterValidator (lambda val : val . replace ( tzinfo = dt . timezone . utc ) ),
200+ BeforeValidator (lambda val : dt .datetime .fromisoformat ( val ). replace ( tzinfo = dt .timezone .utc )),
201+ WrapValidator (lambda val , hdr : hdr ( dt .datetime .fromisoformat ( val ). replace ( tzinfo = dt .timezone .utc ) )),
172202 ],
173203)
174- def test_validator (Validator : Any ):
204+ def test_validator_annotations (Validator : Any ):
175205 from typing import Annotated
176206
177- Timestamp = Annotated [dt .datetime , Validator ]
207+ DatetimeUTC = Annotated [dt .datetime , Validator ]
178208
179209 class TestSubModel (BaseXmlModel , tag = 'submodel' ):
180- field1 : Timestamp = element (tag = 'field1' )
210+ field1 : DatetimeUTC = element (tag = 'field1' )
181211
182212 class TestModel (BaseXmlModel , tag = 'model' ):
183- field1 : Timestamp = element (tag = 'field1' )
213+ field1 : DatetimeUTC = element (tag = 'field1' )
184214 field2 : TestSubModel
185215
186216 xml = '''
187217 <model>
188- <field1>1675468800.0 </field1>
218+ <field1>2023-02-04T00:00:00 </field1>
189219 <submodel>
190- <field1>1675468800.0 </field1>
220+ <field1>2023-02-04T00:00:00 </field1>
191221 </submodel>
192222 </model>
193223 '''
@@ -202,3 +232,74 @@ class TestModel(BaseXmlModel, tag='model'):
202232 )
203233
204234 assert actual_obj == expected_obj
235+
236+
237+ def test_validator_methods ():
238+ class TestModel (BaseXmlModel , tag = 'model' ):
239+ field1 : dt .datetime = element ()
240+ field2 : dt .datetime = element ()
241+ field3 : dt .datetime = element ()
242+
243+ @field_validator ('field1' , mode = 'wrap' )
244+ def validate_field1 (cls , value : str , handler : ValidatorFunctionWrapHandler ) -> dt .datetime :
245+ return handler (dt .datetime .fromisoformat (value ).replace (tzinfo = dt .timezone .utc ))
246+
247+ @field_validator ('field2' , mode = 'before' )
248+ def validate_field2 (cls , value : str ) -> dt .datetime :
249+ return dt .datetime .fromisoformat (value ).replace (tzinfo = dt .timezone .utc )
250+
251+ @field_validator ('field3' , mode = 'after' )
252+ def validate_field3 (cls , value : dt .datetime ) -> dt .datetime :
253+ return value .replace (tzinfo = dt .timezone .utc )
254+
255+ xml = '''
256+ <model>
257+ <field1>2023-02-04T00:00:00</field1>
258+ <field2>2023-02-04T00:00:00</field2>
259+ <field3>2023-02-04T00:00:00</field3>
260+ </model>
261+ '''
262+
263+ actual_obj = TestModel .from_xml (xml )
264+
265+ expected_obj = TestModel .model_construct (
266+ field1 = dt .datetime (2023 , 2 , 4 , tzinfo = dt .timezone .utc ),
267+ field2 = dt .datetime (2023 , 2 , 4 , tzinfo = dt .timezone .utc ),
268+ field3 = dt .datetime (2023 , 2 , 4 , tzinfo = dt .timezone .utc ),
269+ )
270+
271+ assert actual_obj == expected_obj
272+
273+
274+ def test_model_validator ():
275+ class TestModel (BaseXmlModel , tag = 'model' ):
276+ field1 : dt .datetime = element ()
277+
278+ @model_validator (mode = 'before' )
279+ def validate_model_before (cls , data : Dict [str , Any ]) -> 'TestModel' :
280+ return {
281+ 'field1' : dt .datetime .strptime (data ['field1' ], '%Y-%m-%d' ),
282+ }
283+
284+ @model_validator (mode = 'after' )
285+ def validate_model_after (cls , obj : 'TestModel' ) -> 'TestModel' :
286+ obj .field1 = obj .field1 .replace (tzinfo = dt .timezone .utc )
287+ return obj
288+
289+ @model_validator (mode = 'wrap' )
290+ def validate_model_wrap (cls , obj : 'TestModel' , handler : Callable ) -> 'TestModel' :
291+ return handler (obj )
292+
293+ xml = '''
294+ <model>
295+ <field1>2023-02-04</field1>
296+ </model>
297+ '''
298+
299+ actual_obj = TestModel .from_xml (xml )
300+
301+ expected_obj = TestModel .model_construct (
302+ field1 = dt .datetime (2023 , 2 , 4 , tzinfo = dt .timezone .utc ),
303+ )
304+
305+ assert actual_obj == expected_obj
0 commit comments