Skip to content

Commit 8db2c6c

Browse files
authored
Update test_remote_a2a_agent.py
1 parent b2242ae commit 8db2c6c

File tree

1 file changed

+154
-0
lines changed

1 file changed

+154
-0
lines changed

tests/unittests/agents/test_remote_a2a_agent.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,6 +1590,160 @@ async def test_run_async_impl_successful_request(self):
15901590
in mock_event.custom_metadata
15911591
)
15921592

1593+
@pytest.mark.asyncio
1594+
async def test_run_async_impl_uses_session_id_when_no_context_id(self):
1595+
"""Test that session ID is used as context_id when no existing context.
1596+
1597+
When _construct_message_parts_from_session returns None for context_id,
1598+
the agent should use ctx.session.id to maintain session identity across
1599+
local and remote agents.
1600+
"""
1601+
with patch.object(self.agent, "_ensure_resolved"):
1602+
with patch.object(
1603+
self.agent, "_create_a2a_request_for_user_function_response"
1604+
) as mock_create_func:
1605+
mock_create_func.return_value = None
1606+
1607+
with patch.object(
1608+
self.agent, "_construct_message_parts_from_session"
1609+
) as mock_construct:
1610+
# Create proper A2A part mocks
1611+
from a2a.client import Client as A2AClient
1612+
from a2a.types import TextPart
1613+
1614+
mock_a2a_part = Mock(spec=TextPart)
1615+
# Return None for context_id to trigger session ID fallback
1616+
mock_construct.return_value = (
1617+
[mock_a2a_part],
1618+
None,
1619+
) # Tuple with parts and NO context_id
1620+
1621+
# Mock A2A client
1622+
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
1623+
mock_response = Mock()
1624+
mock_send_message = AsyncMock()
1625+
mock_send_message.__aiter__.return_value = [mock_response]
1626+
mock_a2a_client.send_message.return_value = mock_send_message
1627+
self.agent._a2a_client = mock_a2a_client
1628+
1629+
mock_event = Event(
1630+
author=self.agent.name,
1631+
invocation_id=self.mock_context.invocation_id,
1632+
branch=self.mock_context.branch,
1633+
)
1634+
1635+
with patch.object(self.agent, "_handle_a2a_response") as mock_handle:
1636+
mock_handle.return_value = mock_event
1637+
1638+
# Mock the logging functions to avoid iteration issues
1639+
with patch(
1640+
"google.adk.agents.remote_a2a_agent.build_a2a_request_log"
1641+
) as mock_req_log:
1642+
with patch(
1643+
"google.adk.agents.remote_a2a_agent.build_a2a_response_log"
1644+
) as mock_resp_log:
1645+
mock_req_log.return_value = "Mock request log"
1646+
mock_resp_log.return_value = "Mock response log"
1647+
1648+
# Mock the A2AMessage constructor to capture the arguments
1649+
with patch(
1650+
"google.adk.agents.remote_a2a_agent.A2AMessage"
1651+
) as mock_message_class:
1652+
mock_message = Mock(spec=A2AMessage)
1653+
mock_message_class.return_value = mock_message
1654+
1655+
# Add model_dump to mock_response for metadata
1656+
mock_response.model_dump.return_value = {"test": "response"}
1657+
1658+
# Execute
1659+
events = []
1660+
async for event in self.agent._run_async_impl(
1661+
self.mock_context
1662+
):
1663+
events.append(event)
1664+
1665+
# Verify A2AMessage was called with session ID as context_id
1666+
mock_message_class.assert_called_once()
1667+
call_kwargs = mock_message_class.call_args[1]
1668+
assert call_kwargs["context_id"] == self.mock_session.id
1669+
1670+
@pytest.mark.asyncio
1671+
async def test_run_async_impl_preserves_existing_context_id(self):
1672+
"""Test that existing context_id is preserved when available.
1673+
1674+
When _construct_message_parts_from_session returns a context_id from
1675+
a previous remote agent response, that context_id should be used
1676+
for conversation continuity.
1677+
"""
1678+
with patch.object(self.agent, "_ensure_resolved"):
1679+
with patch.object(
1680+
self.agent, "_create_a2a_request_for_user_function_response"
1681+
) as mock_create_func:
1682+
mock_create_func.return_value = None
1683+
1684+
with patch.object(
1685+
self.agent, "_construct_message_parts_from_session"
1686+
) as mock_construct:
1687+
# Create proper A2A part mocks
1688+
from a2a.client import Client as A2AClient
1689+
from a2a.types import TextPart
1690+
1691+
mock_a2a_part = Mock(spec=TextPart)
1692+
existing_context_id = "existing-context-456"
1693+
mock_construct.return_value = (
1694+
[mock_a2a_part],
1695+
existing_context_id,
1696+
) # Tuple with parts and existing context_id
1697+
1698+
# Mock A2A client
1699+
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
1700+
mock_response = Mock()
1701+
mock_send_message = AsyncMock()
1702+
mock_send_message.__aiter__.return_value = [mock_response]
1703+
mock_a2a_client.send_message.return_value = mock_send_message
1704+
self.agent._a2a_client = mock_a2a_client
1705+
1706+
mock_event = Event(
1707+
author=self.agent.name,
1708+
invocation_id=self.mock_context.invocation_id,
1709+
branch=self.mock_context.branch,
1710+
)
1711+
1712+
with patch.object(self.agent, "_handle_a2a_response") as mock_handle:
1713+
mock_handle.return_value = mock_event
1714+
1715+
# Mock the logging functions to avoid iteration issues
1716+
with patch(
1717+
"google.adk.agents.remote_a2a_agent.build_a2a_request_log"
1718+
) as mock_req_log:
1719+
with patch(
1720+
"google.adk.agents.remote_a2a_agent.build_a2a_response_log"
1721+
) as mock_resp_log:
1722+
mock_req_log.return_value = "Mock request log"
1723+
mock_resp_log.return_value = "Mock response log"
1724+
1725+
# Mock the A2AMessage constructor to capture the arguments
1726+
with patch(
1727+
"google.adk.agents.remote_a2a_agent.A2AMessage"
1728+
) as mock_message_class:
1729+
mock_message = Mock(spec=A2AMessage)
1730+
mock_message_class.return_value = mock_message
1731+
1732+
# Add model_dump to mock_response for metadata
1733+
mock_response.model_dump.return_value = {"test": "response"}
1734+
1735+
# Execute
1736+
events = []
1737+
async for event in self.agent._run_async_impl(
1738+
self.mock_context
1739+
):
1740+
events.append(event)
1741+
1742+
# Verify A2AMessage was called with existing context_id
1743+
mock_message_class.assert_called_once()
1744+
call_kwargs = mock_message_class.call_args[1]
1745+
assert call_kwargs["context_id"] == existing_context_id
1746+
15931747
@pytest.mark.asyncio
15941748
async def test_run_async_impl_a2a_client_error(self):
15951749
"""Test _run_async_impl when A2A send_message fails."""

0 commit comments

Comments
 (0)