diff --git a/openhands-agent-server/openhands/agent_server/dependencies.py b/openhands-agent-server/openhands/agent_server/dependencies.py index c68d94e4ce..9b55e3e434 100644 --- a/openhands-agent-server/openhands/agent_server/dependencies.py +++ b/openhands-agent-server/openhands/agent_server/dependencies.py @@ -1,6 +1,6 @@ from uuid import UUID -from fastapi import Depends, HTTPException, Query, Request, status +from fastapi import Depends, HTTPException, Request, status from fastapi.security import APIKeyHeader from openhands.agent_server.config import Config @@ -26,23 +26,6 @@ def check_session_api_key( return check_session_api_key -def create_websocket_session_api_key_dependency(config: Config): - """Create a WebSocket session API key dependency with the given config. - - WebSocket connections cannot send custom headers directly from browsers, - so we use query parameters instead. - """ - - def check_websocket_session_api_key( - session_api_key: str | None = Query(None, alias="session_api_key"), - ): - """Check the session API key from query parameter for WebSocket connections.""" - if config.session_api_keys and session_api_key not in config.session_api_keys: - raise HTTPException(status.HTTP_401_UNAUTHORIZED) - - return check_websocket_session_api_key - - def get_conversation_service(request: Request): """Get the conversation service from app state. diff --git a/openhands-agent-server/openhands/agent_server/sockets.py b/openhands-agent-server/openhands/agent_server/sockets.py index c13e80716a..ef9586b08a 100644 --- a/openhands-agent-server/openhands/agent_server/sockets.py +++ b/openhands-agent-server/openhands/agent_server/sockets.py @@ -2,8 +2,9 @@ WebSocket endpoints for OpenHands SDK. These endpoints are separate from the main API routes to handle WebSocket-specific -authentication using query parameters instead of headers, since browsers cannot -send custom HTTP headers directly with WebSocket connections. +authentication. Browsers cannot send custom HTTP headers directly with WebSocket +connections, so we support the `session_api_key` query param. For non-browser +clients (e.g. Python/Node), we also support authenticating via headers. """ import logging @@ -35,6 +36,42 @@ logger = logging.getLogger(__name__) +def _resolve_websocket_session_api_key( + websocket: WebSocket, + session_api_key: str | None, +) -> str | None: + """Resolve the session API key from multiple sources. + + Precedence order (highest to lowest): + 1. Query parameter (session_api_key) - for browser compatibility + 2. X-Session-API-Key header - for non-browser clients + + Returns None if no valid key is found in any source. + """ + if session_api_key: + return session_api_key + + header_key = websocket.headers.get("x-session-api-key") + if header_key: + return header_key + + return None + + +async def _accept_authenticated_websocket( + websocket: WebSocket, + session_api_key: str | None, +) -> bool: + """Authenticate and accept the socket, or close with an auth error.""" + config = get_default_config() + resolved_key = _resolve_websocket_session_api_key(websocket, session_api_key) + if config.session_api_keys and resolved_key not in config.session_api_keys: + await websocket.close(code=4001, reason="Authentication failed") + return False + await websocket.accept() + return True + + @sockets_router.websocket("/events/{conversation_id}") async def events_socket( conversation_id: UUID, @@ -43,14 +80,9 @@ async def events_socket( resend_all: Annotated[bool, Query()] = False, ): """WebSocket endpoint for conversation events.""" - # Perform authentication check before accepting the WebSocket connection - config = get_default_config() - if config.session_api_keys and session_api_key not in config.session_api_keys: - # Close the WebSocket connection with an authentication error code - await websocket.close(code=4001, reason="Authentication failed") + if not await _accept_authenticated_websocket(websocket, session_api_key): return - await websocket.accept() logger.info(f"Event Websocket Connected: {conversation_id}") event_service = await conversation_service.get_event_service(conversation_id) if event_service is None: @@ -97,14 +129,9 @@ async def bash_events_socket( resend_all: Annotated[bool, Query()] = False, ): """WebSocket endpoint for bash events.""" - # Perform authentication check before accepting the WebSocket connection - config = get_default_config() - if config.session_api_keys and session_api_key not in config.session_api_keys: - # Close the WebSocket connection with an authentication error code - await websocket.close(code=4001, reason="Authentication failed") + if not await _accept_authenticated_websocket(websocket, session_api_key): return - await websocket.accept() logger.info("Bash Websocket Connected") subscriber_id = await bash_event_service.subscribe_to_events( _BashWebSocketSubscriber(websocket) diff --git a/tests/agent_server/test_agent_server_wsproto.py b/tests/agent_server/test_agent_server_wsproto.py index 3e0d8044f3..3afc178a2e 100644 --- a/tests/agent_server/test_agent_server_wsproto.py +++ b/tests/agent_server/test_agent_server_wsproto.py @@ -104,3 +104,46 @@ async def test_agent_server_websocket_with_wsproto(agent_server): await ws.send( json.dumps({"role": "user", "content": "Hello from wsproto test"}) ) + + +@pytest.mark.asyncio +async def test_agent_server_websocket_with_wsproto_header_auth(agent_server): + port = agent_server["port"] + api_key = agent_server["api_key"] + + response = requests.post( + f"http://127.0.0.1:{port}/api/conversations", + headers={"X-Session-API-Key": api_key}, + json={ + "agent": { + "llm": { + "usage_id": "test-llm", + "model": "test-provider/test-model", + "api_key": "test-key", + }, + "tools": [], + }, + "workspace": {"working_dir": "/tmp/test-workspace"}, + }, + ) + assert response.status_code in [200, 201] + conversation_id = response.json()["id"] + + ws_url = f"ws://127.0.0.1:{port}/sockets/events/{conversation_id}?resend_all=true" + + async with websockets.connect( + ws_url, + open_timeout=5, + additional_headers={"X-Session-API-Key": api_key}, + ) as ws: + try: + response = await asyncio.wait_for(ws.recv(), timeout=2) + assert response is not None + except TimeoutError: + pass + + await ws.send( + json.dumps( + {"role": "user", "content": "Hello from wsproto header auth test"} + ) + ) diff --git a/tests/agent_server/test_api_authentication.py b/tests/agent_server/test_api_authentication.py index 86865c82bd..2561d127a0 100644 --- a/tests/agent_server/test_api_authentication.py +++ b/tests/agent_server/test_api_authentication.py @@ -213,29 +213,46 @@ def test_api_websocket_authentication(): app = create_app(config) client = TestClient(app) - # Test WebSocket connection without authentication - should fail - try: + # Without authentication -> should fail + with pytest.raises(Exception): with client.websocket_connect("/sockets/bash-events"): - # If we get here, the connection was established without auth - # (should not happen) - assert False, ( - "WebSocket connection should have failed without authentication" - ) - except Exception: - # WebSocket connection should fail without proper authentication + assert False, "WebSocket connection should have failed without auth" + + # Query-param authentication -> should work (browser-compatible) + with client.websocket_connect("/sockets/bash-events?session_api_key=test-key"): + pass + + # Header authentication -> should work for non-browser clients + with client.websocket_connect( + "/sockets/bash-events", + headers={"X-Session-API-Key": "test-key"}, + ): pass - # Test WebSocket connection with authentication via query parameter - should work - try: - with client.websocket_connect("/sockets/bash-events?session_api_key=test-key"): - # If we get here, the connection was established with proper auth - pass - except Exception: - # Connection might fail for other reasons (like missing conversation ID for - # events endpoint) - # This test mainly ensures the auth mechanism works + # Query param should take precedence over headers (browser-compatible escape hatch). + with client.websocket_connect( + "/sockets/bash-events?session_api_key=test-key", + headers={"X-Session-API-Key": "wrong-key"}, + ): pass + # If query param is present and wrong, connection should fail even if the + # header is correct. + with pytest.raises(Exception): + with client.websocket_connect( + "/sockets/bash-events?session_api_key=wrong-key", + headers={"X-Session-API-Key": "test-key"}, + ): + assert False, "WebSocket connection should have failed with wrong query key" + + # Wrong header -> should fail + with pytest.raises(Exception): + with client.websocket_connect( + "/sockets/bash-events", + headers={"X-Session-API-Key": "wrong-key"}, + ): + assert False, "WebSocket connection should have failed with wrong key" + def test_api_options_requests(): """Test that OPTIONS requests work for CORS preflight."""