|
16 | 16 |
|
17 | 17 | import functools |
18 | 18 | import inspect |
| 19 | +import datetime |
19 | 20 | from typing import Any, AsyncGenerator, Callable, Optional, List |
20 | 21 |
|
21 | 22 | from temporalio import workflow, activity |
@@ -153,3 +154,186 @@ def default_activities(cls) -> List[Callable]: |
153 | 154 | Useful for registering activities with the Temporal Worker. |
154 | 155 | """ |
155 | 156 | return [generate_content_activity] |
| 157 | + |
| 158 | + |
| 159 | +class SessionServiceActivities: |
| 160 | + """Wraps a BaseSessionService to expose its methods as Temporal Activities.""" |
| 161 | + |
| 162 | + def __init__(self, session_service: Any): |
| 163 | + # We type hint as Any to avoid circular imports or strict dependency on BaseSessionService |
| 164 | + # definition availability at this module level if not necessary, |
| 165 | + # but logically it is a BaseSessionService. |
| 166 | + self.session_service = session_service |
| 167 | + |
| 168 | + @activity.defn(name="create_session") |
| 169 | + async def create_session( |
| 170 | + self, |
| 171 | + app_name: str, |
| 172 | + user_id: str, |
| 173 | + state: Optional[dict[str, Any]] = None, |
| 174 | + session_id: Optional[str] = None, |
| 175 | + extra_kwargs: Optional[dict[str, Any]] = None, |
| 176 | + ) -> Any: |
| 177 | + # We return Any (dict/model) that Temporal can serialize. |
| 178 | + # BaseSessionService returns a Session object. |
| 179 | + return await self.session_service.create_session( |
| 180 | + app_name=app_name, |
| 181 | + user_id=user_id, |
| 182 | + state=state, |
| 183 | + session_id=session_id, |
| 184 | + **(extra_kwargs or {}) |
| 185 | + ) |
| 186 | + |
| 187 | + @activity.defn(name="get_session") |
| 188 | + async def get_session( |
| 189 | + self, |
| 190 | + app_name: str, |
| 191 | + user_id: str, |
| 192 | + session_id: str, |
| 193 | + extra_kwargs: Optional[dict[str, Any]] = None, |
| 194 | + ) -> Optional[Any]: |
| 195 | + # Note: 'config' argument in get_session might need Pydantic serialization support |
| 196 | + # We assume kwargs handles simple args or properly serialized objects. |
| 197 | + return await self.session_service.get_session( |
| 198 | + app_name=app_name, |
| 199 | + user_id=user_id, |
| 200 | + session_id=session_id, |
| 201 | + **(extra_kwargs or {}) |
| 202 | + ) |
| 203 | + |
| 204 | + @activity.defn(name="list_sessions") |
| 205 | + async def list_sessions( |
| 206 | + self, |
| 207 | + app_name: str, |
| 208 | + user_id: Optional[str] = None |
| 209 | + ) -> Any: |
| 210 | + return await self.session_service.list_sessions( |
| 211 | + app_name=app_name, |
| 212 | + user_id=user_id |
| 213 | + ) |
| 214 | + |
| 215 | + @activity.defn(name="delete_session") |
| 216 | + async def delete_session( |
| 217 | + self, |
| 218 | + app_name: str, |
| 219 | + user_id: str, |
| 220 | + session_id: str |
| 221 | + ) -> None: |
| 222 | + await self.session_service.delete_session( |
| 223 | + app_name=app_name, |
| 224 | + user_id=user_id, |
| 225 | + session_id=session_id |
| 226 | + ) |
| 227 | + |
| 228 | + @activity.defn(name="append_event") |
| 229 | + async def append_event(self, session: Any, event: Any) -> Any: |
| 230 | + return await self.session_service.append_event(session, event) |
| 231 | + |
| 232 | + def get_activities(self) -> List[Callable]: |
| 233 | + """Returns the list of activities to register.""" |
| 234 | + return [ |
| 235 | + self.create_session, |
| 236 | + self.get_session, |
| 237 | + self.list_sessions, |
| 238 | + self.delete_session, |
| 239 | + self.append_event |
| 240 | + ] |
| 241 | + |
| 242 | + |
| 243 | +from google.adk.sessions import BaseSessionService, Session |
| 244 | +from google.adk.sessions.base_session_service import ListSessionsResponse |
| 245 | +from google.adk.events import Event |
| 246 | + |
| 247 | + |
| 248 | +class ActivitySessionService(BaseSessionService): |
| 249 | + """A SessionService that delegates all calls to Temporal Activities. |
| 250 | + |
| 251 | + This ensures determinism within a Workflow by offloading the actual |
| 252 | + session I/O (which might be non-deterministic or remote) to the Worker. |
| 253 | + """ |
| 254 | + |
| 255 | + def __init__( |
| 256 | + self, |
| 257 | + activity_options: Optional[dict[str, Any]] = None |
| 258 | + ): |
| 259 | + """Initializes the ActivitySessionService. |
| 260 | +
|
| 261 | + Args: |
| 262 | + activity_options: Default options for activity execution (e.g. timeouts). |
| 263 | + Defaults to schedule_to_close_timeout=datetime.timedelta(seconds=30) if not provided. |
| 264 | + """ |
| 265 | + self.activity_options = activity_options or { |
| 266 | + "schedule_to_close_timeout": datetime.timedelta(seconds=30) |
| 267 | + } |
| 268 | + |
| 269 | + async def create_session( |
| 270 | + self, |
| 271 | + *, |
| 272 | + app_name: str, |
| 273 | + user_id: str, |
| 274 | + state: Optional[dict[str, Any]] = None, |
| 275 | + session_id: Optional[str] = None, |
| 276 | + **kwargs: Any, |
| 277 | + ) -> Session: |
| 278 | + result = await workflow.execute_activity( |
| 279 | + "create_session", |
| 280 | + args=[app_name, user_id, state, session_id, kwargs], |
| 281 | + **self.activity_options |
| 282 | + ) |
| 283 | + if isinstance(result, dict): |
| 284 | + return Session.model_validate(result) |
| 285 | + return result |
| 286 | + |
| 287 | + async def get_session( |
| 288 | + self, |
| 289 | + *, |
| 290 | + app_name: str, |
| 291 | + user_id: str, |
| 292 | + session_id: str, |
| 293 | + config: Optional[Any] = None, |
| 294 | + ) -> Optional[Session]: |
| 295 | + kwargs = {"config": config} if config else {} |
| 296 | + result = await workflow.execute_activity( |
| 297 | + "get_session", |
| 298 | + args=[app_name, user_id, session_id, kwargs], |
| 299 | + **self.activity_options |
| 300 | + ) |
| 301 | + if result and isinstance(result, dict): |
| 302 | + return Session.model_validate(result) |
| 303 | + return result |
| 304 | + |
| 305 | + async def list_sessions( |
| 306 | + self, |
| 307 | + *, |
| 308 | + app_name: str, |
| 309 | + user_id: Optional[str] = None |
| 310 | + ) -> ListSessionsResponse: |
| 311 | + result = await workflow.execute_activity( |
| 312 | + "list_sessions", |
| 313 | + args=[app_name, user_id], |
| 314 | + **self.activity_options |
| 315 | + ) |
| 316 | + if isinstance(result, dict): |
| 317 | + return ListSessionsResponse.model_validate(result) |
| 318 | + return result |
| 319 | + |
| 320 | + async def delete_session( |
| 321 | + self, *, app_name: str, user_id: str, session_id: str |
| 322 | + ) -> None: |
| 323 | + await workflow.execute_activity( |
| 324 | + "delete_session", |
| 325 | + args=[app_name, user_id, session_id], |
| 326 | + **self.activity_options |
| 327 | + ) |
| 328 | + |
| 329 | + async def append_event(self, session: Session, event: Event) -> Event: |
| 330 | + # Note: We might need to serialize session/event to dicts if passing them as args causes issues? |
| 331 | + # Usually PydanticConverter handles serialization of args fine. |
| 332 | + result = await workflow.execute_activity( |
| 333 | + "append_event", |
| 334 | + args=[session, event], |
| 335 | + **self.activity_options |
| 336 | + ) |
| 337 | + if isinstance(result, dict): |
| 338 | + return Event.model_validate(result) |
| 339 | + return result |
0 commit comments