Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,19 @@ async def append_event(self, session: Session, event: Event) -> Event:
schema = self._get_schema_classes()
is_sqlite = self.db_engine.dialect.name == _SQLITE_DIALECT
use_row_level_locking = self._supports_row_level_locking()

# Pre-analyze which state scopes have deltas so we only acquire
# FOR UPDATE locks on rows that will actually be written.
# The result is reused later to avoid calling extract_state_delta twice.
state_deltas = None
has_app_delta, has_user_delta = False, False
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(
event.actions.state_delta
)
has_app_delta = bool(state_deltas.get("app"))
has_user_delta = bool(state_deltas.get("user"))

async with self._with_session_lock(
app_name=session.app_name,
user_id=session.user_id,
Expand All @@ -554,7 +567,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
sql_session=sql_session,
state_model=schema.StorageAppState,
predicates=(schema.StorageAppState.app_name == session.app_name,),
use_row_level_locking=use_row_level_locking,
use_row_level_locking=use_row_level_locking and has_app_delta,
missing_message=(
"App state missing for app_name="
f"{session.app_name!r}. Session state tables should be "
Expand All @@ -568,7 +581,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
schema.StorageUserState.app_name == session.app_name,
schema.StorageUserState.user_id == session.user_id,
),
use_row_level_locking=use_row_level_locking,
use_row_level_locking=use_row_level_locking and has_user_delta,
missing_message=(
"User state missing for app_name="
f"{session.app_name!r}, user_id={session.user_id!r}. "
Expand Down Expand Up @@ -599,11 +612,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
storage_events = [e async for e in result]
session.events = [e.to_event() for e in storage_events]

# Extract state delta
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(
event.actions.state_delta
)
# Apply state deltas (reusing pre-analyzed result from above)
if state_deltas:
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state_delta = state_deltas["session"]
Expand Down
69 changes: 69 additions & 0 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,3 +1153,72 @@ async def test_prepare_tables_idempotent_after_creation():
assert session.id == 's1'
finally:
await service.close()


@pytest.fixture
async def lock_spy_harness():
"""Sets up a DatabaseSessionService with a spy on _select_required_state."""
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
try:
session = await service.create_session(
app_name='my_app', user_id='user', session_id='s1'
)

calls = []
original_fn = database_session_service._select_required_state

async def spy(**kwargs):
calls.append(kwargs.get('use_row_level_locking'))
return await original_fn(**kwargs)

with (
mock.patch.object(
service, '_supports_row_level_locking', return_value=True
),
mock.patch.object(
database_session_service,
'_select_required_state',
side_effect=spy,
),
):
yield service, session, calls
finally:
await service.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
'state_delta, expected_app_lock, expected_user_lock',
[
pytest.param(
{'session_key': 'v1'},
False,
False,
id='session_only_delta',
),
pytest.param(None, False, False, id='no_state_delta'),
pytest.param({'app:key1': 'v1'}, True, False, id='app_delta_only'),
pytest.param({'user:key1': 'v1'}, False, True, id='user_delta_only'),
pytest.param(
{'app:ak': 'av', 'user:uk': 'uv'},
True,
True,
id='both_app_and_user_delta',
),
],
)
async def test_append_event_conditional_for_update_locking(
lock_spy_harness, state_delta, expected_app_lock, expected_user_lock
):
"""FOR UPDATE locks should only be acquired for scopes that have deltas."""
service, session, calls = lock_spy_harness

kwargs = {'invocation_id': 'inv', 'author': 'user'}
if state_delta is not None:
kwargs['actions'] = EventActions(state_delta=state_delta)
event = Event(**kwargs)
await service.append_event(session, event)

assert len(calls) == 2
assert calls[0] is expected_app_lock
assert calls[1] is expected_user_lock