Skip to content

Commit c5e11e5

Browse files
committed
Enhance serialization tests by adding new fields and improving error handling
1 parent c57612d commit c5e11e5

File tree

2 files changed

+204
-163
lines changed

2 files changed

+204
-163
lines changed

src/iop/_serialization.py

Lines changed: 182 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from __future__ import annotations
12
import base64
23
import codecs
34
import datetime
@@ -6,191 +7,213 @@
67
import json
78
import pickle
89
import uuid
9-
from typing import Any, Dict, Type
10+
from abc import ABC, abstractmethod
11+
from dataclasses import is_dataclass
12+
from typing import Any, Dict, Type, Optional
1013

1114
from dacite import Config, from_dict
1215
import iris
1316

1417
from iop._utils import _Utils
1518

19+
# Constants
# len("YYYY-MM-DDTHH:MM:SS.mmm"): an ISO datetime truncated to milliseconds.
DATETIME_FORMAT_LENGTH = 23
# len("HH:MM:SS.mmm"): an ISO time truncated to milliseconds.
TIME_FORMAT_LENGTH = 12
# Separates the type tag from the payload in encoded strings ("date:2020-01-01").
TYPE_SEPARATOR = ':'
# Type tags the decoder is allowed to convert back into Python objects.
SUPPORTED_TYPES = {
    'datetime', 'date', 'time', 'dataframe',
    'decimal', 'uuid', 'bytes'
}


class SerializationError(Exception):
    """Base exception for serialization errors."""
    pass


class TypeConverter:
    """Handles type conversion for special data types.

    ``convert_to_string`` produces the payload that follows a type tag and
    ``convert_from_string`` turns such a payload back into a Python object.
    """

    @staticmethod
    def convert_to_string(typ: str, obj: Any) -> str:
        """Return the string payload for ``obj`` according to the tag ``typ``.

        Args:
            typ: One of the type tags ('dataframe', 'datetime', 'date',
                'time', 'bytes'); any other tag falls back to ``str(obj)``.
            obj: The value to convert.

        Returns:
            The string form of ``obj`` (without the ``<typ>:`` prefix).
        """
        if typ == 'dataframe':
            return obj.to_json(orient="table")
        elif typ == 'datetime':
            return TypeConverter._format_datetime(obj)
        elif typ == 'date':
            return obj.isoformat()
        elif typ == 'time':
            return TypeConverter._format_time(obj)
        elif typ == 'bytes':
            return base64.b64encode(obj).decode("UTF-8")
        return str(obj)

    @staticmethod
    def convert_from_string(typ: str, val: str) -> Any:
        """Convert the payload ``val`` back into the object named by ``typ``.

        Args:
            typ: The type tag that preceded the payload.
            val: The payload string.

        Returns:
            The converted object, or ``val`` unchanged for an unknown tag.

        Raises:
            SerializationError: If the conversion fails for any reason
                (including pandas being unavailable for 'dataframe').
        """
        try:
            if typ == 'datetime':
                # _format_datetime writes UTC as a trailing "Z", but
                # datetime.fromisoformat only accepts that suffix from
                # Python 3.11 on; normalise it so round-trips work everywhere.
                if val.endswith("Z"):
                    val = val[:-1] + "+00:00"
                return datetime.datetime.fromisoformat(val)
            elif typ == 'date':
                return datetime.date.fromisoformat(val)
            elif typ == 'time':
                return datetime.time.fromisoformat(val)
            elif typ == 'dataframe':
                try:
                    import pandas as pd
                except ImportError as e:
                    raise SerializationError("Failed to load pandas module") from e
                return pd.read_json(val, orient="table")
            elif typ == 'decimal':
                return decimal.Decimal(val)
            elif typ == 'uuid':
                return uuid.UUID(val)
            elif typ == 'bytes':
                return base64.b64decode(val.encode("UTF-8"))
            return val
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise SerializationError(f"Failed to convert type {typ}: {str(e)}") from e

    @staticmethod
    def _format_datetime(dt: datetime.datetime) -> str:
        """Format ``dt`` as ISO 8601 with millisecond precision and "Z" for UTC."""
        r = dt.isoformat()
        if dt.microsecond:
            # Keep the first three fractional digits; r[26:] is whatever
            # follows the six-digit microseconds (the UTC offset, if any).
            r = r[:DATETIME_FORMAT_LENGTH] + r[26:]
        if r.endswith("+00:00"):
            r = r[:-6] + "Z"
        return r

    @staticmethod
    def _format_time(t: datetime.time) -> str:
        """Format ``t`` as ISO 8601 with millisecond precision."""
        r = t.isoformat()
        if t.microsecond:
            # "HH:MM:SS.ffffff" is 15 characters; r[15:] is the UTC offset of
            # an aware time (empty for a naive one), which must be preserved
            # rather than silently dropped by the truncation.
            r = r[:TIME_FORMAT_LENGTH] + r[15:]
        return r
89+
1690
class IrisJSONEncoder(json.JSONEncoder):
    """JSON encoder that tags dates, decimals, UUIDs, bytes and DataFrames.

    Values that JSON cannot represent natively are emitted as strings of the
    form ``<type>:<payload>``; any other object exposing ``__dict__`` is
    encoded as that dictionary.
    """

    # Checked in order: datetime.datetime must come before datetime.date
    # because it is a subclass of it.
    _TAGGED_TYPES = (
        (datetime.datetime, 'datetime'),
        (datetime.date, 'date'),
        (datetime.time, 'time'),
        (decimal.Decimal, 'decimal'),
        (uuid.UUID, 'uuid'),
        (bytes, 'bytes'),
    )

    def default(self, obj: Any) -> Any:
        """Return a JSON-serializable stand-in for ``obj``."""
        # DataFrames are recognised by class name so pandas need not be imported.
        if obj.__class__.__name__ == 'DataFrame':
            return 'dataframe:' + TypeConverter.convert_to_string("dataframe", obj)

        for kind, tag in self._TAGGED_TYPES:
            if isinstance(obj, kind):
                return tag + ':' + TypeConverter.convert_to_string(tag, obj)

        if hasattr(obj, '__dict__'):
            return obj.__dict__
        return super().default(obj)
45111

46112
class IrisJSONDecoder(json.JSONDecoder):
    """JSONDecoder that handles special type annotations.

    String values of the form ``<type>:<payload>`` whose type tag is listed
    in ``SUPPORTED_TYPES`` are converted back to Python objects; every other
    value is passed through untouched.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, obj: Dict) -> Dict:
        """Decode the tagged values of one freshly parsed JSON object."""
        return {
            key: self._process_value(value)
            for key, value in obj.items()
        }

    def _process_value(self, value: Any) -> Any:
        """Convert ``value`` if it is (or contains) a tagged string.

        Lists are walked recursively: the encoder tags each element of a
        list individually, so the elements must be decoded individually too.
        Nested dicts need no handling here because the JSON parser has
        already run ``object_hook`` on them.

        Raises:
            SerializationError: Propagated from
                ``TypeConverter.convert_from_string`` when a tagged payload
                cannot be converted.
        """
        if isinstance(value, list):
            return [self._process_value(item) for item in value]
        if isinstance(value, str):
            typ, sep, val = value.partition(TYPE_SEPARATOR)
            # Only known tags are converted; a plain string that merely
            # contains the separator is returned unchanged.
            if sep and typ in SUPPORTED_TYPES:
                return TypeConverter.convert_from_string(typ, val)
        return value
130+
131+
class MessageSerializer:
    """Handles message serialization and deserialization."""

    @staticmethod
    def serialize(message: Any, use_pickle: bool = False) -> iris.cls:
        """Serializes a message to IRIS format.

        Args:
            message: The Python message object to serialize.
            use_pickle: When true, build an ``IOP.PickleMessage`` from a
                pickled copy; otherwise build an ``IOP.Message`` from JSON.

        Returns:
            The newly created IRIS message object.
        """
        if use_pickle:
            return MessageSerializer._serialize_pickle(message)
        return MessageSerializer._serialize_json(message)

    @staticmethod
    def deserialize(serial: iris.cls, use_pickle: bool = False) -> Any:
        """Deserializes a message from IRIS format.

        Args:
            serial: The IRIS message object to read.
            use_pickle: When true, read ``serial.jstr`` as a base64 pickle;
                otherwise read ``serial.json`` as JSON.

        Returns:
            The reconstructed Python message.

        Raises:
            SerializationError: On the JSON path, when the classname is
                missing, the class cannot be loaded or the JSON cannot be
                turned into an instance.
        """
        if use_pickle:
            return MessageSerializer._deserialize_pickle(serial)
        return MessageSerializer._deserialize_json(serial)

    @staticmethod
    def _serialize_pickle(message: Any) -> iris.cls:
        """Build an ``IOP.PickleMessage`` carrying ``message`` as a base64 pickle."""
        pickle_string = codecs.encode(pickle.dumps(message), "base64").decode()
        msg = iris.cls('IOP.PickleMessage')._New()
        msg.classname = f"{message.__class__.__module__}.{message.__class__.__name__}"
        msg.jstr = _Utils.string_to_stream(pickle_string)
        return msg

    @staticmethod
    def _serialize_json(message: Any) -> iris.cls:
        """Build an ``IOP.Message`` carrying ``message`` encoded as JSON."""
        json_string = json.dumps(message, cls=IrisJSONEncoder, ensure_ascii=False)
        msg = iris.cls('IOP.Message')._New()
        msg.classname = f"{message.__class__.__module__}.{message.__class__.__name__}"

        # NOTE(review): msg.buffer is presumably the largest string that can
        # be stored inline; longer payloads go through a stream - confirm.
        if hasattr(msg, 'buffer') and len(json_string) > msg.buffer:
            msg.json = _Utils.string_to_stream(json_string, msg.buffer)
        else:
            msg.json = json_string
        return msg

    @staticmethod
    def _deserialize_pickle(serial: iris.cls) -> Any:
        """Rebuild the Python object stored in ``serial.jstr``."""
        string = _Utils.stream_to_string(serial.jstr)
        # SECURITY: pickle.loads executes arbitrary code embedded in the
        # payload; only use this path for messages from a trusted source.
        return pickle.loads(codecs.decode(string.encode(), "base64"))

    @staticmethod
    def _deserialize_json(serial: iris.cls) -> Any:
        """Rebuild an instance of ``serial.classname`` from ``serial.json``.

        Raises:
            SerializationError: If the classname is empty, the class cannot
                be imported, or the JSON cannot be decoded into an instance.
        """
        if not serial.classname:
            raise SerializationError("JSON message malformed, must include classname")

        try:
            module_name, class_name = MessageSerializer._parse_classname(serial.classname)
            module = importlib.import_module(module_name)
            msg_class = getattr(module, class_name)
        except Exception as e:
            raise SerializationError(
                f"Failed to load class {serial.classname}: {str(e)}") from e

        json_string = (_Utils.stream_to_string(serial.json)
                       if serial.type == 'Stream' else serial.json)

        try:
            json_dict = json.loads(json_string, cls=IrisJSONDecoder)
            return dataclass_from_dict(msg_class, json_dict)
        except Exception as e:
            raise SerializationError(f"Failed to deserialize JSON: {str(e)}") from e

    @staticmethod
    def _parse_classname(classname: str) -> tuple[str, str]:
        """Split ``"pkg.module.Class"`` into ``("pkg.module", "Class")``.

        Raises:
            SerializationError: If ``classname`` has no module part, i.e. it
                contains no dot or starts with its only dot.
        """
        # rpartition never raises, so a name without any dot reaches the
        # explicit check below instead of failing with a bare ValueError.
        module_name, _, class_name = classname.rpartition(".")
        if not module_name:
            raise SerializationError(f"Classname must include a module: {classname}")
        return module_name, class_name
175200

176201
def dataclass_from_dict(klass: Type, dikt: Dict) -> Any:
    """Converts a dictionary to a dataclass instance.

    The instance is built with dacite's ``from_dict`` (type checking
    disabled). Keys of ``dikt`` that are not annotated on ``klass`` or any of
    its base classes are then attached to the instance as plain attributes.

    Args:
        klass: The dataclass to instantiate.
        dikt: The dictionary holding the field values.

    Returns:
        The populated instance of ``klass``.
    """
    ret = from_dict(klass, dikt, Config(check_types=False))

    # Collect annotated names across the whole MRO: on Python 3.10+ a class's
    # __annotations__ holds only its own fields, so looking at klass alone
    # would treat inherited fields as "extra" and overwrite the values dacite
    # already built for them with the raw dictionary entries.
    declared = set()
    for base in getattr(klass, '__mro__', (klass,)):
        declared.update(getattr(base, '__annotations__', None) or {})

    for key, val in dikt.items():
        if key not in declared:
            setattr(ret, key, val)
    return ret
214+
215+
# Maintain backwards compatibility: module-level entry points that delegate
# to MessageSerializer. Defined with ``def`` rather than lambdas so they keep
# a real __name__ and a docstring.
# NOTE(review): callers that pass the argument by keyword depend on the
# parameter names below - confirm they match what existing callers use.
def serialize_pickle_message(msg: Any) -> iris.cls:
    """Serialize ``msg`` to an IRIS message using pickle."""
    return MessageSerializer.serialize(msg, use_pickle=True)


def serialize_message(msg: Any) -> iris.cls:
    """Serialize ``msg`` to an IRIS message using JSON."""
    return MessageSerializer.serialize(msg, use_pickle=False)


def deserialize_pickle_message(serial: iris.cls) -> Any:
    """Deserialize a pickle-based IRIS message into a Python object."""
    return MessageSerializer.deserialize(serial, use_pickle=True)


def deserialize_message(serial: iris.cls) -> Any:
    """Deserialize a JSON-based IRIS message into a Python object."""
    return MessageSerializer.deserialize(serial, use_pickle=False)

0 commit comments

Comments
 (0)