|
1 | 1 | """Serialization and deserialization""" |
2 | 2 |
|
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import base64 |
3 | 6 | import json |
4 | 7 | import logging |
| 8 | +import uuid |
5 | 9 | from abc import ABC, abstractmethod |
6 | 10 | from dataclasses import dataclass |
7 | | -from typing import Generic, TypeVar |
| 11 | +from datetime import date, datetime |
| 12 | +from decimal import Decimal |
| 13 | +from typing import TYPE_CHECKING, Any, Generic, TypeVar |
| 14 | + |
| 15 | +if TYPE_CHECKING: |
| 16 | + from collections.abc import Callable |
8 | 17 |
|
9 | 18 | from aws_durable_execution_sdk_python.exceptions import FatalError |
10 | 19 |
|
@@ -37,38 +46,223 @@ def deserialize(self, data: str, _: SerDesContext) -> T: # noqa: PLR6301 |
37 | 46 | return json.loads(data) |
38 | 47 |
|
39 | 48 |
|
| 49 | +class TypeHandler(ABC): |
| 50 | + def __init__(self, next_handler: TypeHandler | None = None) -> None: |
| 51 | + self._next: TypeHandler | None = next_handler |
| 52 | + |
| 53 | + @abstractmethod |
| 54 | + def encode(self, obj: Any) -> dict[str, Any]: |
| 55 | + pass |
| 56 | + |
| 57 | + @abstractmethod |
| 58 | + def decode(self, tag: str, value: Any) -> Any: |
| 59 | + pass |
| 60 | + |
| 61 | + |
| 62 | +class UnsupportedHandler(TypeHandler): |
| 63 | + def encode(self, obj: Any) -> dict[str, Any]: # noqa: PLR6301 |
| 64 | + msg = f"Unsupported type: {type(obj)!r}" |
| 65 | + raise TypeError(msg) |
| 66 | + |
| 67 | + def decode(self, tag: str, value: Any) -> Any: # noqa: PLR6301, ARG002 |
| 68 | + msg = f"Unknown type tag: {tag!r}" |
| 69 | + raise ValueError(msg) |
| 70 | + |
| 71 | + |
| 72 | +class BytesHandler(TypeHandler): |
| 73 | + def encode(self, obj: Any) -> dict[str, Any]: |
| 74 | + if isinstance(obj, bytes | bytearray | memoryview): |
| 75 | + encoded: str = base64.b64encode(bytes(obj)).decode() |
| 76 | + return {"_": {"t": "bytes", "v": encoded}} |
| 77 | + return self._next.encode(obj) |
| 78 | + |
| 79 | + def decode(self, tag: str, value: Any) -> Any: |
| 80 | + if tag == "bytes": |
| 81 | + return base64.b64decode(value) |
| 82 | + return self._next.decode(tag, value) |
| 83 | + |
| 84 | + |
| 85 | +class UuidHandler(TypeHandler): |
| 86 | + def encode(self, obj: Any) -> dict[str, Any]: |
| 87 | + if isinstance(obj, uuid.UUID): |
| 88 | + return {"_": {"t": "uuid", "v": str(obj)}} |
| 89 | + return self._next.encode(obj) |
| 90 | + |
| 91 | + def decode(self, tag: str, value: Any) -> Any: |
| 92 | + if tag == "uuid": |
| 93 | + return uuid.UUID(value) |
| 94 | + return self._next.decode(tag, value) |
| 95 | + |
| 96 | + |
| 97 | +class DecimalHandler(TypeHandler): |
| 98 | + def encode(self, obj: Any) -> dict[str, Any]: |
| 99 | + if isinstance(obj, Decimal): |
| 100 | + return {"_": {"t": "decimal", "v": str(obj)}} |
| 101 | + return self._next.encode(obj) |
| 102 | + |
| 103 | + def decode(self, tag: str, value: Any) -> Any: |
| 104 | + if tag == "decimal": |
| 105 | + return Decimal(value) |
| 106 | + return self._next.decode(tag, value) |
| 107 | + |
| 108 | + |
| 109 | +class DateTimeHandler(TypeHandler): |
| 110 | + def encode(self, obj: Any) -> dict[str, Any]: |
| 111 | + if isinstance(obj, datetime): |
| 112 | + return {"_": {"t": "datetime", "v": obj.isoformat()}} |
| 113 | + if isinstance(obj, date): |
| 114 | + return {"_": {"t": "date", "v": obj.isoformat()}} |
| 115 | + return self._next.encode(obj) |
| 116 | + |
| 117 | + def decode(self, tag: str, value: Any) -> Any: |
| 118 | + if tag == "datetime": |
| 119 | + return datetime.fromisoformat(value) |
| 120 | + if tag == "date": |
| 121 | + return date.fromisoformat(value) |
| 122 | + return self._next.decode(tag, value) |
| 123 | + |
| 124 | + |
| 125 | +class ContainerHandler(TypeHandler): |
| 126 | + def __init__(self, next_handler: TypeHandler | None = None) -> None: |
| 127 | + super().__init__(next_handler) |
| 128 | + self._dispatch_encode: Callable[[Any], dict[str, Any]] | None = None |
| 129 | + self._dispatch_decode: Callable[[str, Any], Any] | None = None |
| 130 | + |
| 131 | + def _enc(self, obj: Any) -> dict[str, Any]: |
| 132 | + if self._dispatch_encode is None: |
| 133 | + msg = "ContainerHandler not initialized with encode dispatcher." |
| 134 | + raise RuntimeError(msg) |
| 135 | + return self._dispatch_encode(obj) |
| 136 | + |
| 137 | + def _dec(self, tag: str, value: Any) -> Any: |
| 138 | + if self._dispatch_decode is None: |
| 139 | + msg = "ContainerHandler not initialized with decode dispatcher." |
| 140 | + raise RuntimeError(msg) |
| 141 | + return self._dispatch_decode(tag, value) |
| 142 | + |
| 143 | + def encode(self, obj: Any) -> dict[str, Any]: |
| 144 | + if isinstance(obj, tuple): |
| 145 | + items: list[dict[str, Any]] = [self._enc(x) for x in obj] |
| 146 | + return {"_": {"t": "tuple", "v": items}} |
| 147 | + if isinstance(obj, list): |
| 148 | + items_list: list[dict[str, Any]] = [self._enc(x) for x in obj] |
| 149 | + return {"_": {"t": "list", "v": items_list}} |
| 150 | + if isinstance(obj, dict): |
| 151 | + self._validate_dict_keys(obj) |
| 152 | + wrapped: dict[str, dict[str, Any]] = { |
| 153 | + k: self._enc(v) for k, v in obj.items() |
| 154 | + } |
| 155 | + return {"_": {"t": "dict", "v": wrapped}} |
| 156 | + return self._next.encode(obj) |
| 157 | + |
| 158 | + def decode(self, tag: str, value: Any) -> Any: |
| 159 | + if tag == "tuple": |
| 160 | + if not isinstance(value, list): |
| 161 | + msg = 'Malformed envelope: "tuple" expects array value.' |
| 162 | + raise TypeError(msg) |
| 163 | + return tuple(self._dec(v["_"]["t"], v["_"]["v"]) for v in value) |
| 164 | + if tag == "list": |
| 165 | + if not isinstance(value, list): |
| 166 | + msg = 'Malformed envelope: "list" expects array value.' |
| 167 | + raise TypeError(msg) |
| 168 | + return [self._dec(v["_"]["t"], v["_"]["v"]) for v in value] |
| 169 | + if tag == "dict": |
| 170 | + if not isinstance(value, dict): |
| 171 | + msg = 'Malformed envelope: "dict" expects object value.' |
| 172 | + raise TypeError(msg) |
| 173 | + return {k: self._dec(v["_"]["t"], v["_"]["v"]) for k, v in value.items()} |
| 174 | + return self._next.decode(tag, value) |
| 175 | + |
| 176 | + @staticmethod |
| 177 | + def _validate_dict_keys(mapping: dict[Any, Any]) -> None: |
| 178 | + bad: list[Any] = [k for k in mapping if not isinstance(k, str)] |
| 179 | + if bad: |
| 180 | + ex: Any = bad[0] |
| 181 | + msg = f"Unsupported mapping key type: {type(ex)!r}. JSON object keys must be strings." |
| 182 | + raise TypeError(msg) |
| 183 | + |
| 184 | + |
| 185 | +class PrimitiveHandler(TypeHandler): |
| 186 | + def encode(self, obj: Any) -> dict[str, Any]: |
| 187 | + if obj is None or isinstance(obj, str | int | float | bool): |
| 188 | + tag: str = type(obj).__name__ |
| 189 | + return {"_": {"t": tag, "v": obj}} |
| 190 | + return self._next.encode(obj) |
| 191 | + |
| 192 | + def decode(self, tag: str, value: Any) -> Any: |
| 193 | + if tag == "NoneType": |
| 194 | + return None |
| 195 | + if tag in {"str", "int", "float", "bool"}: |
| 196 | + return value |
| 197 | + return self._next.decode(tag, value) |
| 198 | + |
| 199 | + |
| 200 | +@dataclass(frozen=True) |
| 201 | +class HandlerChain: |
| 202 | + root: TypeHandler |
| 203 | + container: ContainerHandler |
| 204 | + |
| 205 | + @classmethod |
| 206 | + def create(cls) -> HandlerChain: |
| 207 | + unsupported: UnsupportedHandler = UnsupportedHandler() |
| 208 | + bytes_h: BytesHandler = BytesHandler(unsupported) |
| 209 | + uuid_h: UuidHandler = UuidHandler(bytes_h) |
| 210 | + decimal_h: DecimalHandler = DecimalHandler(uuid_h) |
| 211 | + dt_h: DateTimeHandler = DateTimeHandler(decimal_h) |
| 212 | + container_h: ContainerHandler = ContainerHandler(dt_h) |
| 213 | + primitive_h: PrimitiveHandler = PrimitiveHandler(container_h) |
| 214 | + |
| 215 | + # Wire dispatchers to always go through the root |
| 216 | + container_h._dispatch_encode = primitive_h.encode # noqa: SLF001 |
| 217 | + container_h._dispatch_decode = primitive_h.decode # noqa: SLF001 |
| 218 | + |
| 219 | + return cls(root=primitive_h, container=container_h) |
| 220 | + |
| 221 | + |
| 222 | +class EnvelopeSerDes(SerDes[T]): |
| 223 | + def __init__(self) -> None: |
| 224 | + self._chain: HandlerChain = HandlerChain.create() |
| 225 | + |
| 226 | + def serialize(self, value: T, _: SerDesContext) -> str: |
| 227 | + wrapped: dict[str, Any] = self._chain.root.encode(value) |
| 228 | + return json.dumps(wrapped, separators=(",", ":")) |
| 229 | + |
| 230 | + def deserialize(self, data: str, _: SerDesContext) -> T: |
| 231 | + obj: Any = json.loads(data) |
| 232 | + if not (isinstance(obj, dict) and "_" in obj and isinstance(obj["_"], dict)): |
| 233 | + msg = 'Malformed envelope: root must be {"_": {"t": ..., "v": ...}}.' |
| 234 | + raise TypeError(msg) |
| 235 | + inner: dict[str, Any] = obj["_"] |
| 236 | + if not (isinstance(inner, dict) and "t" in inner and "v" in inner): |
| 237 | + msg = 'Malformed envelope: missing "t" or "v" at root.' |
| 238 | + raise TypeError(msg) |
| 239 | + return self._chain.root.decode(inner["t"], inner["v"]) |
| 240 | + |
| 241 | + |
40 | 242 | _DEFAULT_JSON_SERDES: SerDes = JsonSerDes() |
41 | 243 |
|
42 | 244 |
|
43 | 245 | def serialize( |
44 | 246 | serdes: SerDes[T] | None, value: T, operation_id: str, durable_execution_arn: str |
45 | 247 | ) -> str: |
46 | 248 | serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn) |
47 | | - if serdes is None: |
48 | | - serdes = _DEFAULT_JSON_SERDES |
| 249 | + active_serdes: SerDes[T] = serdes or _DEFAULT_JSON_SERDES |
49 | 250 | try: |
50 | | - return serdes.serialize(value, serdes_context) |
| 251 | + return active_serdes.serialize(value, serdes_context) |
51 | 252 | except Exception as e: |
52 | | - logger.exception( |
53 | | - "⚠️ Serialization failed for id: %s", |
54 | | - operation_id, |
55 | | - ) |
56 | | - msg = f"Serialization failed for id: {operation_id}, error: {e}." |
| 253 | + logger.exception("⚠️ Serialization failed for id: %s", operation_id) |
| 254 | + msg: str = f"Serialization failed for id: {operation_id}, error: {e}." |
57 | 255 | raise FatalError(msg) from e |
58 | 256 |
|
59 | 257 |
|
60 | 258 | def deserialize( |
61 | 259 | serdes: SerDes[T] | None, data: str, operation_id: str, durable_execution_arn: str |
62 | 260 | ) -> T: |
63 | 261 | serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn) |
64 | | - if serdes is None: |
65 | | - serdes = _DEFAULT_JSON_SERDES |
| 262 | + active_serdes: SerDes[T] = serdes or _DEFAULT_JSON_SERDES |
66 | 263 | try: |
67 | | - return serdes.deserialize(data, serdes_context) |
| 264 | + return active_serdes.deserialize(data, serdes_context) |
68 | 265 | except Exception as e: |
69 | | - logger.exception( |
70 | | - "⚠️ Deserialization failed for id: %s", |
71 | | - operation_id, |
72 | | - ) |
73 | | - msg = f"Deserialization failed for id: {operation_id}" |
| 266 | + logger.exception("⚠️ Deserialization failed for id: %s", operation_id) |
| 267 | + msg: str = f"Deserialization failed for id: {operation_id}" |
74 | 268 | raise FatalError(msg) from e |
0 commit comments