@@ -1590,6 +1590,160 @@ async def test_run_async_impl_successful_request(self):
15901590 in mock_event .custom_metadata
15911591 )
15921592
1593+ @pytest .mark .asyncio
1594+ async def test_run_async_impl_uses_session_id_when_no_context_id (self ):
1595+ """Test that session ID is used as context_id when no existing context.
1596+
1597+ When _construct_message_parts_from_session returns None for context_id,
1598+ the agent should use ctx.session.id to maintain session identity across
1599+ local and remote agents.
1600+ """
1601+ with patch .object (self .agent , "_ensure_resolved" ):
1602+ with patch .object (
1603+ self .agent , "_create_a2a_request_for_user_function_response"
1604+ ) as mock_create_func :
1605+ mock_create_func .return_value = None
1606+
1607+ with patch .object (
1608+ self .agent , "_construct_message_parts_from_session"
1609+ ) as mock_construct :
1610+ # Create proper A2A part mocks
1611+ from a2a .client import Client as A2AClient
1612+ from a2a .types import TextPart
1613+
1614+ mock_a2a_part = Mock (spec = TextPart )
1615+ # Return None for context_id to trigger session ID fallback
1616+ mock_construct .return_value = (
1617+ [mock_a2a_part ],
1618+ None ,
1619+ ) # Tuple with parts and NO context_id
1620+
1621+ # Mock A2A client
1622+ mock_a2a_client = create_autospec (spec = A2AClient , instance = True )
1623+ mock_response = Mock ()
1624+ mock_send_message = AsyncMock ()
1625+ mock_send_message .__aiter__ .return_value = [mock_response ]
1626+ mock_a2a_client .send_message .return_value = mock_send_message
1627+ self .agent ._a2a_client = mock_a2a_client
1628+
1629+ mock_event = Event (
1630+ author = self .agent .name ,
1631+ invocation_id = self .mock_context .invocation_id ,
1632+ branch = self .mock_context .branch ,
1633+ )
1634+
1635+ with patch .object (self .agent , "_handle_a2a_response" ) as mock_handle :
1636+ mock_handle .return_value = mock_event
1637+
1638+ # Mock the logging functions to avoid iteration issues
1639+ with patch (
1640+ "google.adk.agents.remote_a2a_agent.build_a2a_request_log"
1641+ ) as mock_req_log :
1642+ with patch (
1643+ "google.adk.agents.remote_a2a_agent.build_a2a_response_log"
1644+ ) as mock_resp_log :
1645+ mock_req_log .return_value = "Mock request log"
1646+ mock_resp_log .return_value = "Mock response log"
1647+
1648+ # Mock the A2AMessage constructor to capture the arguments
1649+ with patch (
1650+ "google.adk.agents.remote_a2a_agent.A2AMessage"
1651+ ) as mock_message_class :
1652+ mock_message = Mock (spec = A2AMessage )
1653+ mock_message_class .return_value = mock_message
1654+
1655+ # Add model_dump to mock_response for metadata
1656+ mock_response .model_dump .return_value = {"test" : "response" }
1657+
1658+ # Execute
1659+ events = []
1660+ async for event in self .agent ._run_async_impl (
1661+ self .mock_context
1662+ ):
1663+ events .append (event )
1664+
1665+ # Verify A2AMessage was called with session ID as context_id
1666+ mock_message_class .assert_called_once ()
1667+ call_kwargs = mock_message_class .call_args [1 ]
1668+ assert call_kwargs ["context_id" ] == self .mock_session .id
1669+
1670+ @pytest .mark .asyncio
1671+ async def test_run_async_impl_preserves_existing_context_id (self ):
1672+ """Test that existing context_id is preserved when available.
1673+
1674+ When _construct_message_parts_from_session returns a context_id from
1675+ a previous remote agent response, that context_id should be used
1676+ for conversation continuity.
1677+ """
1678+ with patch .object (self .agent , "_ensure_resolved" ):
1679+ with patch .object (
1680+ self .agent , "_create_a2a_request_for_user_function_response"
1681+ ) as mock_create_func :
1682+ mock_create_func .return_value = None
1683+
1684+ with patch .object (
1685+ self .agent , "_construct_message_parts_from_session"
1686+ ) as mock_construct :
1687+ # Create proper A2A part mocks
1688+ from a2a .client import Client as A2AClient
1689+ from a2a .types import TextPart
1690+
1691+ mock_a2a_part = Mock (spec = TextPart )
1692+ existing_context_id = "existing-context-456"
1693+ mock_construct .return_value = (
1694+ [mock_a2a_part ],
1695+ existing_context_id ,
1696+ ) # Tuple with parts and existing context_id
1697+
1698+ # Mock A2A client
1699+ mock_a2a_client = create_autospec (spec = A2AClient , instance = True )
1700+ mock_response = Mock ()
1701+ mock_send_message = AsyncMock ()
1702+ mock_send_message .__aiter__ .return_value = [mock_response ]
1703+ mock_a2a_client .send_message .return_value = mock_send_message
1704+ self .agent ._a2a_client = mock_a2a_client
1705+
1706+ mock_event = Event (
1707+ author = self .agent .name ,
1708+ invocation_id = self .mock_context .invocation_id ,
1709+ branch = self .mock_context .branch ,
1710+ )
1711+
1712+ with patch .object (self .agent , "_handle_a2a_response" ) as mock_handle :
1713+ mock_handle .return_value = mock_event
1714+
1715+ # Mock the logging functions to avoid iteration issues
1716+ with patch (
1717+ "google.adk.agents.remote_a2a_agent.build_a2a_request_log"
1718+ ) as mock_req_log :
1719+ with patch (
1720+ "google.adk.agents.remote_a2a_agent.build_a2a_response_log"
1721+ ) as mock_resp_log :
1722+ mock_req_log .return_value = "Mock request log"
1723+ mock_resp_log .return_value = "Mock response log"
1724+
1725+ # Mock the A2AMessage constructor to capture the arguments
1726+ with patch (
1727+ "google.adk.agents.remote_a2a_agent.A2AMessage"
1728+ ) as mock_message_class :
1729+ mock_message = Mock (spec = A2AMessage )
1730+ mock_message_class .return_value = mock_message
1731+
1732+ # Add model_dump to mock_response for metadata
1733+ mock_response .model_dump .return_value = {"test" : "response" }
1734+
1735+ # Execute
1736+ events = []
1737+ async for event in self .agent ._run_async_impl (
1738+ self .mock_context
1739+ ):
1740+ events .append (event )
1741+
1742+ # Verify A2AMessage was called with existing context_id
1743+ mock_message_class .assert_called_once ()
1744+ call_kwargs = mock_message_class .call_args [1 ]
1745+ assert call_kwargs ["context_id" ] == existing_context_id
1746+
15931747 @pytest .mark .asyncio
15941748 async def test_run_async_impl_a2a_client_error (self ):
15951749 """Test _run_async_impl when A2A send_message fails."""
0 commit comments