Skip to content

Commit c45a1d9

Browse files
committed
Refactor dataclass_from_dict to improve field processing and error handling; add tests for message serialization
1 parent 2f49734 commit c45a1d9

File tree

2 files changed

+74
-31
lines changed

2 files changed

+74
-31
lines changed

src/iop/_serialization.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Any, Dict, Type
99

1010
import iris
11-
from pydantic import BaseModel, TypeAdapter
11+
from pydantic import BaseModel, TypeAdapter, ValidationError
1212

1313
from iop._message import _PydanticPickleMessage
1414
from iop._utils import _Utils
@@ -108,38 +108,42 @@ def _parse_classname(classname: str) -> tuple[str, str]:
108108
return classname[:j], classname[j+1:]
109109

110110
def dataclass_from_dict(klass: Type, dikt: Dict) -> Any:
111-
field_types = {
112-
key: val.annotation
113-
for key, val in inspect.signature(klass).parameters.items()
114-
}
115-
processed_dict = {}
116-
for key, val in inspect.signature(klass).parameters.items():
117-
if key not in dikt and val.default != val.empty:
118-
processed_dict[key] = val.default
119-
continue
120-
121-
value = dikt.get(key)
111+
"""Converts a dictionary to a dataclass instance.
112+
Handles non attended fields and nested dataclasses."""
113+
114+
def process_field(value: Any, field_type: Type) -> Any:
122115
if value is None:
123-
processed_dict[key] = None
116+
return None
117+
if is_dataclass(field_type):
118+
return dataclass_from_dict(field_type, value)
119+
if field_type != inspect.Parameter.empty:
120+
try:
121+
return TypeAdapter(field_type).validate_python(value)
122+
except ValidationError:
123+
return value
124+
return value
125+
126+
# Get field definitions from class signature
127+
fields = inspect.signature(klass).parameters
128+
field_dict = {}
129+
130+
# Process each field
131+
for field_name, field_info in fields.items():
132+
if field_name not in dikt:
133+
if field_info.default != field_info.empty:
134+
field_dict[field_name] = field_info.default
124135
continue
125-
126-
try:
127-
field_type = field_types[key]
128-
if field_type != inspect.Parameter.empty:
129-
adapter = TypeAdapter(field_type)
130-
processed_dict[key] = adapter.validate_python(value)
131-
else:
132-
processed_dict[key] = value
133-
except Exception:
134-
processed_dict[key] = value
135-
136-
instance = klass(
137-
**processed_dict
138-
)
139-
# handle any extra fields
140-
for k, v in dikt.items():
141-
if k not in processed_dict:
142-
setattr(instance, k, v)
136+
137+
field_dict[field_name] = process_field(dikt[field_name], field_info.annotation)
138+
139+
# Create instance
140+
instance = klass(**field_dict)
141+
142+
# Add any extra fields not in the dataclass definition
143+
for key, value in dikt.items():
144+
if key not in field_dict:
145+
setattr(instance, key, value)
146+
143147
return instance
144148

145149
def dataclass_to_dict(instance: Any) -> Dict:

src/tests/test_serialization.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import decimal
3+
from typing import Optional
34
import uuid
45
from dataclasses import dataclass
56

@@ -33,6 +34,44 @@ class FullMessge:
3334
data: bytes
3435
items: list # Changed from df to a simple list
3536

37+
@dataclass
38+
class MyObject:
39+
value: str = None
40+
foo: int = None
41+
bar: float = 3.14
42+
43+
@dataclass
44+
class Msg:
45+
text: str
46+
number: int
47+
my_obj: MyObject
48+
49+
def test_message_serialization():
50+
51+
52+
msg = Msg(text="hello", number=42, my_obj=None)
53+
54+
my_obj = MyObject(value="test", foo=None)
55+
56+
# hack my_obj as a dict
57+
msg.my_obj = {}
58+
msg.my_obj['value'] = "test"
59+
msg.my_obj['foo'] = None
60+
msg.my_obj['other'] = 3.14
61+
62+
# Test serialization
63+
serial = serialize_message(msg)
64+
assert type(serial).__module__.startswith('iris') and serial._IsA("IOP.Message")
65+
assert serial.classname == f"{Msg.__module__}.{Msg.__name__}"
66+
67+
# Test deserialization
68+
result = deserialize_message(serial)
69+
assert isinstance(result, Msg)
70+
assert result.text == msg.text
71+
assert result.number == msg.number
72+
assert result.my_obj == my_obj
73+
74+
3675
def test_json_serialization():
3776
# Create test data
3877
test_items = [{'col1': 1, 'col2': 'a'}, {'col1': 2, 'col2': 'b'}] # Simple list of dicts instead of DataFrame

0 commit comments

Comments
 (0)