Skip to content

Commit eb18e77

Browse files
author
Rares Polenciuc
committed
feat: add envelope serializer for extended python types
Implement EnvelopeSerDes to handle datetime, Decimal, bytes, UUID, and tuple types using a wrapper envelope format. Maintains backward compatibility with the existing JSON serializer while providing comprehensive type support.

- Add TypeHandler chain architecture for extensible serialization
- Support datetime/date with ISO format encoding
- Handle Decimal with string representation
- Support bytes/bytearray/memoryview with base64 encoding
- Add UUID serialization with string format
- Implement tuple/list/dict container handling
- Provide clear error messages for unsupported types
- Add comprehensive test coverage for all supported types
1 parent 6beb550 commit eb18e77

File tree

2 files changed

+723
-17
lines changed

2 files changed

+723
-17
lines changed

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 211 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
"""Serialization and deserialization"""
22

3+
from __future__ import annotations
4+
5+
import base64
36
import json
47
import logging
8+
import uuid
59
from abc import ABC, abstractmethod
610
from dataclasses import dataclass
7-
from typing import Generic, TypeVar
11+
from datetime import date, datetime
12+
from decimal import Decimal
13+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Callable
817

918
from aws_durable_execution_sdk_python.exceptions import FatalError
1019

@@ -37,38 +46,223 @@ def deserialize(self, data: str, _: SerDesContext) -> T: # noqa: PLR6301
3746
return json.loads(data)
3847

3948

49+
class TypeHandler(ABC):
    """Abstract link in a chain of envelope type handlers.

    Every handler keeps an optional reference to its successor in
    ``self._next``; concrete subclasses implement ``encode`` and ``decode``.
    """

    def __init__(self, next_handler: TypeHandler | None = None) -> None:
        """Remember ``next_handler`` (possibly ``None``) as the successor."""
        self._next: TypeHandler | None = next_handler

    @abstractmethod
    def encode(self, obj: Any) -> dict[str, Any]:
        """Return the envelope dictionary representing ``obj``."""

    @abstractmethod
    def decode(self, tag: str, value: Any) -> Any:
        """Return the Python value for the envelope's ``tag`` and ``value``."""
60+
61+
62+
class UnsupportedHandler(TypeHandler):
    """Handler that rejects every object and every type tag.

    ``HandlerChain.create`` installs it as the last link, so anything no
    other handler claimed ends up here and produces an error.
    """

    def encode(self, obj: Any) -> dict[str, Any]:  # noqa: PLR6301
        """Always raise ``TypeError`` naming the type of ``obj``."""
        kind = type(obj)
        msg = f"Unsupported type: {kind!r}"
        raise TypeError(msg)

    def decode(self, tag: str, value: Any) -> Any:  # noqa: PLR6301, ARG002
        """Always raise ``ValueError`` naming the unrecognised ``tag``."""
        msg = f"Unknown type tag: {tag!r}"
        raise ValueError(msg)
70+
71+
72+
class BytesHandler(TypeHandler):
    """Encode bytes-like objects as base64 text under the ``"bytes"`` tag.

    ``bytes``, ``bytearray`` and ``memoryview`` are all accepted on encode;
    decode always yields ``bytes``, so the two mutable/view types do not
    round-trip to their original type.
    """

    def encode(self, obj: Any) -> dict[str, Any]:
        """Return a ``"bytes"`` envelope for bytes-like ``obj``.

        Anything else is delegated to the next handler in the chain.
        """
        if isinstance(obj, (bytes, bytearray, memoryview)):
            # b64encode output is pure ASCII, so the text form is JSON-safe.
            encoded: str = base64.b64encode(bytes(obj)).decode("ascii")
            return {"_": {"t": "bytes", "v": encoded}}
        return self._next.encode(obj)

    def decode(self, tag: str, value: Any) -> Any:
        """Return the ``bytes`` for a ``"bytes"`` envelope, else delegate.

        Raises:
            binascii.Error: if ``value`` is not strictly valid base64.
        """
        if tag == "bytes":
            # validate=True: without it b64decode silently drops characters
            # outside the base64 alphabet, turning a corrupted payload into
            # wrong bytes instead of an error.
            return base64.b64decode(value, validate=True)
        return self._next.decode(tag, value)
83+
84+
85+
class UuidHandler(TypeHandler):
    """Encode ``uuid.UUID`` values as their string form under the ``"uuid"`` tag."""

    def encode(self, obj: Any) -> dict[str, Any]:
        """Return a ``"uuid"`` envelope for UUIDs; delegate everything else."""
        if not isinstance(obj, uuid.UUID):
            return self._next.encode(obj)
        return {"_": {"t": "uuid", "v": str(obj)}}

    def decode(self, tag: str, value: Any) -> Any:
        """Rebuild a ``uuid.UUID`` from a ``"uuid"`` envelope; delegate other tags."""
        if tag != "uuid":
            return self._next.decode(tag, value)
        return uuid.UUID(value)
95+
96+
97+
class DecimalHandler(TypeHandler):
    """Encode ``Decimal`` values as strings under the ``"decimal"`` tag."""

    def encode(self, obj: Any) -> dict[str, Any]:
        """Return a ``"decimal"`` envelope for Decimals; delegate everything else."""
        if not isinstance(obj, Decimal):
            return self._next.encode(obj)
        text = str(obj)
        return {"_": {"t": "decimal", "v": text}}

    def decode(self, tag: str, value: Any) -> Any:
        """Rebuild a ``Decimal`` from a ``"decimal"`` envelope; delegate other tags."""
        if tag != "decimal":
            return self._next.decode(tag, value)
        return Decimal(value)
107+
108+
109+
class DateTimeHandler(TypeHandler):
    """Encode ``datetime`` and ``date`` values as ISO-format strings."""

    def encode(self, obj: Any) -> dict[str, Any]:
        """Return a ``"datetime"`` or ``"date"`` envelope; delegate other types."""
        # datetime is tested first because it is a subclass of date.
        if isinstance(obj, datetime):
            tag = "datetime"
        elif isinstance(obj, date):
            tag = "date"
        else:
            return self._next.encode(obj)
        return {"_": {"t": tag, "v": obj.isoformat()}}

    def decode(self, tag: str, value: Any) -> Any:
        """Parse the ISO string of a ``"datetime"``/``"date"`` envelope; delegate other tags."""
        if tag == "datetime":
            parse = datetime.fromisoformat
        elif tag == "date":
            parse = date.fromisoformat
        else:
            return self._next.decode(tag, value)
        return parse(value)
123+
124+
125+
class ContainerHandler(TypeHandler):
    """Encode and decode tuples, lists and string-keyed dicts.

    Container elements are themselves encoded/decoded through the
    ``_dispatch_encode`` / ``_dispatch_decode`` callables. These start as
    ``None`` and must be assigned after construction
    (``HandlerChain.create`` points them at the first handler of the chain),
    so that nested values are offered to every handler.
    """

    def __init__(self, next_handler: TypeHandler | None = None) -> None:
        """Store the successor and leave both dispatchers unset."""
        super().__init__(next_handler)
        self._dispatch_encode: Callable[[Any], dict[str, Any]] | None = None
        self._dispatch_decode: Callable[[str, Any], Any] | None = None

    def _enc(self, obj: Any) -> dict[str, Any]:
        """Encode one element via the encode dispatcher.

        Raises:
            RuntimeError: if the encode dispatcher has not been assigned.
        """
        if self._dispatch_encode is None:
            msg = "ContainerHandler not initialized with encode dispatcher."
            raise RuntimeError(msg)
        return self._dispatch_encode(obj)

    def _dec(self, tag: str, value: Any) -> Any:
        """Decode one ``(tag, value)`` pair via the decode dispatcher.

        Raises:
            RuntimeError: if the decode dispatcher has not been assigned.
        """
        if self._dispatch_decode is None:
            msg = "ContainerHandler not initialized with decode dispatcher."
            raise RuntimeError(msg)
        return self._dispatch_decode(tag, value)

    def _dec_envelope(self, envelope: Any) -> Any:
        """Validate the shape of a nested envelope and decode it.

        Raises:
            TypeError: if ``envelope`` is not ``{"_": {"t": ..., "v": ...}}``.
        """
        inner: Any = envelope.get("_") if isinstance(envelope, dict) else None
        if not (isinstance(inner, dict) and "t" in inner and "v" in inner):
            # Fail with the same kind of error the root-level check uses
            # rather than an opaque KeyError / "indices must be integers".
            msg = 'Malformed envelope: nested value must be {"_": {"t": ..., "v": ...}}.'
            raise TypeError(msg)
        return self._dec(inner["t"], inner["v"])

    def encode(self, obj: Any) -> dict[str, Any]:
        """Return a ``"tuple"``, ``"list"`` or ``"dict"`` envelope for ``obj``.

        Every element (or dict value) is encoded recursively. Other types
        are delegated to the next handler.

        Raises:
            TypeError: if ``obj`` is a dict with a non-string key.
        """
        if isinstance(obj, tuple):
            items: list[dict[str, Any]] = [self._enc(x) for x in obj]
            return {"_": {"t": "tuple", "v": items}}
        if isinstance(obj, list):
            items_list: list[dict[str, Any]] = [self._enc(x) for x in obj]
            return {"_": {"t": "list", "v": items_list}}
        if isinstance(obj, dict):
            # Reject bad keys before encoding any value.
            self._validate_dict_keys(obj)
            wrapped: dict[str, dict[str, Any]] = {
                k: self._enc(v) for k, v in obj.items()
            }
            return {"_": {"t": "dict", "v": wrapped}}
        return self._next.encode(obj)

    def decode(self, tag: str, value: Any) -> Any:
        """Rebuild a tuple, list or dict from its envelope; delegate other tags.

        Raises:
            TypeError: if ``value`` has the wrong JSON shape for ``tag`` or a
                nested element is not a well-formed envelope.
        """
        if tag == "tuple":
            if not isinstance(value, list):
                msg = 'Malformed envelope: "tuple" expects array value.'
                raise TypeError(msg)
            return tuple(self._dec_envelope(v) for v in value)
        if tag == "list":
            if not isinstance(value, list):
                msg = 'Malformed envelope: "list" expects array value.'
                raise TypeError(msg)
            return [self._dec_envelope(v) for v in value]
        if tag == "dict":
            if not isinstance(value, dict):
                msg = 'Malformed envelope: "dict" expects object value.'
                raise TypeError(msg)
            return {k: self._dec_envelope(v) for k, v in value.items()}
        return self._next.decode(tag, value)

    @staticmethod
    def _validate_dict_keys(mapping: dict[Any, Any]) -> None:
        """Raise ``TypeError`` for the first key of ``mapping`` that is not a ``str``."""
        for key in mapping:
            if not isinstance(key, str):
                msg = f"Unsupported mapping key type: {type(key)!r}. JSON object keys must be strings."
                raise TypeError(msg)
183+
184+
185+
class PrimitiveHandler(TypeHandler):
    """Wrap JSON-native scalars (``None``, ``bool``, ``int``, ``float``, ``str``)."""

    # Canonical tag for each supported scalar base type. ``bool`` precedes
    # ``int`` because ``bool`` is a subclass of ``int``.
    _SCALAR_TAGS: tuple[tuple[str, type], ...] = (
        ("bool", bool),
        ("int", int),
        ("float", float),
        ("str", str),
    )

    def encode(self, obj: Any) -> dict[str, Any]:
        """Return an envelope tagged with the scalar's canonical type name.

        The tag is chosen from the base type rather than
        ``type(obj).__name__``: a subclass instance (for example an
        ``IntEnum`` member) would otherwise be written with a tag that
        ``decode`` does not recognise. Non-scalar objects are delegated to
        the next handler.
        """
        if obj is None:
            return {"_": {"t": "NoneType", "v": None}}
        for tag, base in self._SCALAR_TAGS:
            if isinstance(obj, base):
                return {"_": {"t": tag, "v": obj}}
        return self._next.encode(obj)

    def decode(self, tag: str, value: Any) -> Any:
        """Return ``None`` for ``"NoneType"``, ``value`` unchanged for scalar tags.

        Any other tag is delegated to the next handler.
        """
        if tag == "NoneType":
            return None
        if tag in {"str", "int", "float", "bool"}:
            return value
        return self._next.decode(tag, value)
198+
199+
200+
@dataclass(frozen=True)
class HandlerChain:
    """Immutable holder for an assembled chain of type handlers."""

    # First handler in the chain; encode/decode calls start here.
    root: TypeHandler
    # The chain's container handler, whose dispatchers point back at ``root``.
    container: ContainerHandler

    @classmethod
    def create(cls) -> HandlerChain:
        """Build the default chain.

        Lookup order is primitive -> container -> datetime/date -> decimal
        -> uuid -> bytes -> unsupported (which always raises).
        """
        # Innermost call is built first, so construction runs from the
        # terminal handler outwards.
        scalar_tail = DateTimeHandler(
            DecimalHandler(UuidHandler(BytesHandler(UnsupportedHandler())))
        )
        containers = ContainerHandler(scalar_tail)
        head = PrimitiveHandler(containers)

        # Nested elements must restart at the head of the chain so that
        # every handler gets a chance at them.
        containers._dispatch_encode = head.encode  # noqa: SLF001
        containers._dispatch_decode = head.decode  # noqa: SLF001

        return cls(root=head, container=containers)
220+
221+
222+
class EnvelopeSerDes(SerDes[T]):
    """SerDes that writes values as ``{"_": {"t": <tag>, "v": <value>}}`` JSON envelopes."""

    def __init__(self) -> None:
        """Create the handler chain used for all encode/decode calls."""
        self._chain: HandlerChain = HandlerChain.create()

    def serialize(self, value: T, _: SerDesContext) -> str:
        """Encode ``value`` through the handler chain and dump it as compact JSON."""
        envelope: dict[str, Any] = self._chain.root.encode(value)
        return json.dumps(envelope, separators=(",", ":"))

    def deserialize(self, data: str, _: SerDesContext) -> T:
        """Parse ``data`` as JSON, check the root envelope shape and decode it.

        Raises:
            TypeError: if the root is not a dict holding a dict under ``"_"``,
                or that inner dict lacks ``"t"`` or ``"v"``.
        """
        parsed: Any = json.loads(data)
        inner: Any = parsed.get("_") if isinstance(parsed, dict) else None
        if not isinstance(inner, dict):
            msg = 'Malformed envelope: root must be {"_": {"t": ..., "v": ...}}.'
            raise TypeError(msg)
        if "t" not in inner or "v" not in inner:
            msg = 'Malformed envelope: missing "t" or "v" at root.'
            raise TypeError(msg)
        return self._chain.root.decode(inner["t"], inner["v"])
240+
241+
40242
# Shared JsonSerDes instance that the module-level serialize()/deserialize()
# helpers fall back to when the caller does not supply a serdes.
_DEFAULT_JSON_SERDES: SerDes = JsonSerDes()
41243

42244

43245
def serialize(
    serdes: SerDes[T] | None, value: T, operation_id: str, durable_execution_arn: str
) -> str:
    """Serialize ``value`` with ``serdes``, or the default JSON serdes if ``None``.

    Args:
        serdes: Serializer to use; ``None`` selects ``_DEFAULT_JSON_SERDES``.
        value: Object to serialize.
        operation_id: Passed to the ``SerDesContext`` and used in error text.
        durable_execution_arn: Passed to the ``SerDesContext``.

    Returns:
        The serialized string produced by the chosen serdes.

    Raises:
        FatalError: if the serdes raises any exception; the original
            exception is chained as the cause.
    """
    serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
    # Compare against None explicitly: ``serdes or default`` would discard a
    # caller-supplied serdes that happens to be falsy (e.g. one defining
    # ``__len__`` or ``__bool__``).
    active_serdes: SerDes[T] = _DEFAULT_JSON_SERDES if serdes is None else serdes
    try:
        return active_serdes.serialize(value, serdes_context)
    except Exception as e:
        logger.exception("⚠️ Serialization failed for id: %s", operation_id)
        msg: str = f"Serialization failed for id: {operation_id}, error: {e}."
        raise FatalError(msg) from e
58256

59257

60258
def deserialize(
    serdes: SerDes[T] | None, data: str, operation_id: str, durable_execution_arn: str
) -> T:
    """Deserialize ``data`` with ``serdes``, or the default JSON serdes if ``None``.

    Args:
        serdes: Deserializer to use; ``None`` selects ``_DEFAULT_JSON_SERDES``.
        data: Serialized string to decode.
        operation_id: Passed to the ``SerDesContext`` and used in error text.
        durable_execution_arn: Passed to the ``SerDesContext``.

    Returns:
        The value produced by the chosen serdes.

    Raises:
        FatalError: if the serdes raises any exception; the original
            exception is chained as the cause.
    """
    serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
    # Compare against None explicitly: ``serdes or default`` would discard a
    # caller-supplied serdes that happens to be falsy (e.g. one defining
    # ``__len__`` or ``__bool__``).
    active_serdes: SerDes[T] = _DEFAULT_JSON_SERDES if serdes is None else serdes
    try:
        return active_serdes.deserialize(data, serdes_context)
    except Exception as e:
        logger.exception("⚠️ Deserialization failed for id: %s", operation_id)
        msg: str = f"Deserialization failed for id: {operation_id}"
        raise FatalError(msg) from e

0 commit comments

Comments
 (0)