diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 24f525bae0..f193d40d75 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -531,6 +531,19 @@ async def append_event(self, session: Session, event: Event) -> Event: schema = self._get_schema_classes() is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT use_row_level_locking = self._supports_row_level_locking() + + # Pre-analyze which state scopes have deltas so we only acquire + # FOR UPDATE locks on rows that will actually be written. + # The result is reused later to avoid calling extract_state_delta twice. + state_deltas = None + has_app_delta, has_user_delta = False, False + if event.actions and event.actions.state_delta: + state_deltas = _session_util.extract_state_delta( + event.actions.state_delta + ) + has_app_delta = bool(state_deltas.get("app")) + has_user_delta = bool(state_deltas.get("user")) + async with self._with_session_lock( app_name=session.app_name, user_id=session.user_id, @@ -554,7 +567,7 @@ async def append_event(self, session: Session, event: Event) -> Event: sql_session=sql_session, state_model=schema.StorageAppState, predicates=(schema.StorageAppState.app_name == session.app_name,), - use_row_level_locking=use_row_level_locking, + use_row_level_locking=use_row_level_locking and has_app_delta, missing_message=( "App state missing for app_name=" f"{session.app_name!r}. Session state tables should be " @@ -568,7 +581,7 @@ async def append_event(self, session: Session, event: Event) -> Event: schema.StorageUserState.app_name == session.app_name, schema.StorageUserState.user_id == session.user_id, ), - use_row_level_locking=use_row_level_locking, + use_row_level_locking=use_row_level_locking and has_user_delta, missing_message=( "User state missing for app_name=" f"{session.app_name!r}, user_id={session.user_id!r}. " @@ -599,11 +612,8 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_events = [e async for e in result] session.events = [e.to_event() for e in storage_events] - # Extract state delta - if event.actions and event.actions.state_delta: - state_deltas = _session_util.extract_state_delta( - event.actions.state_delta - ) + # Apply state deltas (reusing pre-analyzed result from above) + if state_deltas: app_state_delta = state_deltas["app"] user_state_delta = state_deltas["user"] session_state_delta = state_deltas["session"] diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 25530bed89..9d326f85d2 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -1153,3 +1153,72 @@ async def test_prepare_tables_idempotent_after_creation(): assert session.id == 's1' finally: await service.close() + + +@pytest.fixture +async def lock_spy_harness(): + """Sets up a DatabaseSessionService with a spy on _select_required_state.""" + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + session = await service.create_session( + app_name='my_app', user_id='user', session_id='s1' + ) + + calls = [] + original_fn = database_session_service._select_required_state + + async def spy(**kwargs): + calls.append(kwargs.get('use_row_level_locking')) + return await original_fn(**kwargs) + + with ( + mock.patch.object( + service, '_supports_row_level_locking', return_value=True + ), + mock.patch.object( + database_session_service, + '_select_required_state', + side_effect=spy, + ), + ): + yield service, session, calls + finally: + await service.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'state_delta, expected_app_lock, expected_user_lock', + [ + pytest.param( + {'session_key': 'v1'}, + False, + False, + id='session_only_delta', + ), + pytest.param(None, False, False, id='no_state_delta'), + pytest.param({'app:key1': 'v1'}, True, False, id='app_delta_only'), + pytest.param({'user:key1': 'v1'}, False, True, id='user_delta_only'), + pytest.param( + {'app:ak': 'av', 'user:uk': 'uv'}, + True, + True, + id='both_app_and_user_delta', + ), + ], +) +async def test_append_event_conditional_for_update_locking( + lock_spy_harness, state_delta, expected_app_lock, expected_user_lock +): + """FOR UPDATE locks should only be acquired for scopes that have deltas.""" + service, session, calls = lock_spy_harness + + kwargs = {'invocation_id': 'inv', 'author': 'user'} + if state_delta is not None: + kwargs['actions'] = EventActions(state_delta=state_delta) + event = Event(**kwargs) + await service.append_event(session, event) + + assert len(calls) == 2 + assert calls[0] is expected_app_lock + assert calls[1] is expected_user_lock