Skip to content
19 changes: 1 addition & 18 deletions openhands-agent-server/openhands/agent_server/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from uuid import UUID

from fastapi import Depends, HTTPException, Query, Request, status
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader

from openhands.agent_server.config import Config
Expand All @@ -26,23 +26,6 @@ def check_session_api_key(
return check_session_api_key


def create_websocket_session_api_key_dependency(config: Config):
    """Build a dependency that validates the WebSocket session API key.

    Browsers cannot attach custom headers to a WebSocket handshake, so the
    returned dependency reads the key from the ``session_api_key`` query
    parameter instead of a header.
    """

    def check_websocket_session_api_key(
        session_api_key: str | None = Query(None, alias="session_api_key"),
    ):
        """Raise 401 when keys are configured and the query key is not one of them."""
        accepted_keys = config.session_api_keys
        # With no keys configured there is nothing to check against.
        if not accepted_keys:
            return
        if session_api_key in accepted_keys:
            return
        raise HTTPException(status.HTTP_401_UNAUTHORIZED)

    return check_websocket_session_api_key


def get_conversation_service(request: Request):
"""Get the conversation service from app state.

Expand Down
55 changes: 41 additions & 14 deletions openhands-agent-server/openhands/agent_server/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
WebSocket endpoints for OpenHands SDK.

These endpoints are separate from the main API routes to handle WebSocket-specific
authentication using query parameters instead of headers, since browsers cannot
send custom HTTP headers directly with WebSocket connections.
authentication. Browsers cannot send custom HTTP headers directly with WebSocket
connections, so we support the `session_api_key` query param. For non-browser
clients (e.g. Python/Node), we also support authenticating via headers.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Suggestion: Consider mentioning in the docstring or in documentation that header-based authentication is the recommended approach for security (avoids leaking secrets in URLs), while query param auth is maintained for backward compatibility with browser clients.

This could help guide users toward the more secure option.

"""

import logging
Expand Down Expand Up @@ -35,6 +36,42 @@
logger = logging.getLogger(__name__)


def _resolve_websocket_session_api_key(
websocket: WebSocket,
session_api_key: str | None,
) -> str | None:
"""Resolve the session API key from multiple sources.

Precedence order (highest to lowest):
1. Query parameter (session_api_key) - for browser compatibility
2. X-Session-API-Key header - for non-browser clients

Returns None if no valid key is found in any source.
"""
if session_api_key:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Nit: Using truthiness check here means an empty string query param (?session_api_key=) would be ignored and fall through to header checking. This is probably the desired behavior, but it could be surprising. Consider adding a comment to clarify this is intentional.

return session_api_key

header_key = websocket.headers.get("x-session-api-key")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Critical: The PR description claims to support both X-Session-API-Key and Authorization: Bearer headers, but only X-Session-API-Key is implemented here.

Either:

  1. Add support for Authorization: Bearer header, OR
  2. Update the PR description to remove the claim about Bearer token support

If adding Bearer support, you could add after this line:

# Support Authorization: Bearer <token> format
auth_header = websocket.headers.get("authorization")
if auth_header and auth_header.startswith("Bearer "):
    return auth_header[7:]  # Strip "Bearer " prefix

if header_key:
return header_key

return None


async def _accept_authenticated_websocket(
    websocket: WebSocket,
    session_api_key: str | None,
) -> bool:
    """Accept the socket if the caller is authenticated, otherwise close it.

    The key to check is resolved from the explicit ``session_api_key``
    argument and the socket itself via ``_resolve_websocket_session_api_key``.

    Returns:
        True when the connection was accepted, False when it was closed with
        code 4001 because authentication failed.
    """
    config = get_default_config()
    presented_key = _resolve_websocket_session_api_key(websocket, session_api_key)
    accepted_keys = config.session_api_keys

    # No configured keys means no check is performed; otherwise the presented
    # key must be one of the configured ones.
    if not accepted_keys or presented_key in accepted_keys:
        await websocket.accept()
        return True

    await websocket.close(code=4001, reason="Authentication failed")
    return False


@sockets_router.websocket("/events/{conversation_id}")
async def events_socket(
conversation_id: UUID,
Expand All @@ -43,14 +80,9 @@ async def events_socket(
resend_all: Annotated[bool, Query()] = False,
):
"""WebSocket endpoint for conversation events."""
# Perform authentication check before accepting the WebSocket connection
config = get_default_config()
if config.session_api_keys and session_api_key not in config.session_api_keys:
# Close the WebSocket connection with an authentication error code
await websocket.close(code=4001, reason="Authentication failed")
if not await _accept_authenticated_websocket(websocket, session_api_key):
return

await websocket.accept()
logger.info(f"Event Websocket Connected: {conversation_id}")
event_service = await conversation_service.get_event_service(conversation_id)
if event_service is None:
Expand Down Expand Up @@ -97,14 +129,9 @@ async def bash_events_socket(
resend_all: Annotated[bool, Query()] = False,
):
"""WebSocket endpoint for bash events."""
# Perform authentication check before accepting the WebSocket connection
config = get_default_config()
if config.session_api_keys and session_api_key not in config.session_api_keys:
# Close the WebSocket connection with an authentication error code
await websocket.close(code=4001, reason="Authentication failed")
if not await _accept_authenticated_websocket(websocket, session_api_key):
return

await websocket.accept()
logger.info("Bash Websocket Connected")
subscriber_id = await bash_event_service.subscribe_to_events(
_BashWebSocketSubscriber(websocket)
Expand Down
43 changes: 43 additions & 0 deletions tests/agent_server/test_agent_server_wsproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,46 @@ async def test_agent_server_websocket_with_wsproto(agent_server):
await ws.send(
json.dumps({"role": "user", "content": "Hello from wsproto test"})
)


@pytest.mark.asyncio
async def test_agent_server_websocket_with_wsproto_header_auth(agent_server):
    """Events WebSocket accepts the API key sent as an X-Session-API-Key header.

    Creates a conversation over HTTP, then opens the events socket without a
    ``session_api_key`` query parameter, authenticating only via the header.
    """
    port = agent_server["port"]
    api_key = agent_server["api_key"]

    conversation_request = {
        "agent": {
            "llm": {
                "usage_id": "test-llm",
                "model": "test-provider/test-model",
                "api_key": "test-key",
            },
            "tools": [],
        },
        "workspace": {"working_dir": "/tmp/test-workspace"},
    }
    created = requests.post(
        f"http://127.0.0.1:{port}/api/conversations",
        headers={"X-Session-API-Key": api_key},
        json=conversation_request,
    )
    assert created.status_code in (200, 201)
    conversation_id = created.json()["id"]

    # The URL carries no session_api_key, so only the header can authenticate.
    ws_url = f"ws://127.0.0.1:{port}/sockets/events/{conversation_id}?resend_all=true"
    outgoing = json.dumps(
        {"role": "user", "content": "Hello from wsproto header auth test"}
    )

    async with websockets.connect(
        ws_url,
        open_timeout=5,
        additional_headers={"X-Session-API-Key": api_key},
    ) as ws:
        try:
            first_message = await asyncio.wait_for(ws.recv(), timeout=2)
        except TimeoutError:
            # No message arriving within the timeout is tolerated.
            pass
        else:
            assert first_message is not None

        await ws.send(outgoing)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Suggestion: Consider adding negative test cases for header auth in the wsproto tests (e.g., connection fails without auth, connection fails with wrong header).

While test_api_authentication.py has comprehensive coverage, having at least one negative case here would ensure the feature works correctly with real WebSocket libraries like websockets.

53 changes: 35 additions & 18 deletions tests/agent_server/test_api_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,29 +213,46 @@ def test_api_websocket_authentication():
app = create_app(config)
client = TestClient(app)

# Test WebSocket connection without authentication - should fail
try:
# Without authentication -> should fail
with pytest.raises(Exception):
with client.websocket_connect("/sockets/bash-events"):
# If we get here, the connection was established without auth
# (should not happen)
assert False, (
"WebSocket connection should have failed without authentication"
)
except Exception:
# WebSocket connection should fail without proper authentication
assert False, "WebSocket connection should have failed without auth"

# Query-param authentication -> should work (browser-compatible)
with client.websocket_connect("/sockets/bash-events?session_api_key=test-key"):
pass

# Header authentication -> should work for non-browser clients
with client.websocket_connect(
"/sockets/bash-events",
headers={"X-Session-API-Key": "test-key"},
):
pass

# Test WebSocket connection with authentication via query parameter - should work
try:
with client.websocket_connect("/sockets/bash-events?session_api_key=test-key"):
# If we get here, the connection was established with proper auth
pass
except Exception:
# Connection might fail for other reasons (like missing conversation ID for
# events endpoint)
# This test mainly ensures the auth mechanism works
# Query param should take precedence over headers (browser-compatible escape hatch).
with client.websocket_connect(
"/sockets/bash-events?session_api_key=test-key",
headers={"X-Session-API-Key": "wrong-key"},
):
pass

# If query param is present and wrong, connection should fail even if the
# header is correct.
with pytest.raises(Exception):
with client.websocket_connect(
"/sockets/bash-events?session_api_key=wrong-key",
headers={"X-Session-API-Key": "test-key"},
):
assert False, "WebSocket connection should have failed with wrong query key"

# Wrong header -> should fail
with pytest.raises(Exception):
with client.websocket_connect(
"/sockets/bash-events",
headers={"X-Session-API-Key": "wrong-key"},
):
assert False, "WebSocket connection should have failed with wrong key"


def test_api_options_requests():
"""Test that OPTIONS requests work for CORS preflight."""
Expand Down
Loading