diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 05af2d67d8..0ca3d40c9b 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -22,6 +22,7 @@ get_http_target_regex_matching_callback_function, ) from pyrit.prompt_target.http_target.httpx_api_target import HTTPXAPITarget +from pyrit.prompt_target.http_target.mcp_auth_bypass_target import MCPAuthBypassTarget from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import HuggingFaceEndpointTarget from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig @@ -50,6 +51,7 @@ "get_http_target_regex_matching_callback_function", "HTTPTarget", "HTTPXAPITarget", + "MCPAuthBypassTarget", "HuggingFaceChatTarget", "HuggingFaceEndpointTarget", "limit_requests_per_minute", diff --git a/pyrit/prompt_target/http_target/mcp_auth_bypass_target.py b/pyrit/prompt_target/http_target/mcp_auth_bypass_target.py new file mode 100644 index 0000000000..0db3cf69f9 --- /dev/null +++ b/pyrit/prompt_target/http_target/mcp_auth_bypass_target.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import json +import logging +from typing import Any, Optional + +import httpx + +from pyrit.models import ( + MessagePiece, + construct_response_from_request, +) +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.utils import limit_requests_per_minute + +logger = logging.getLogger(__name__) + + +class MCPAuthBypassTarget(PromptTarget): + """ + MCPAuthBypassTarget tests MCP server endpoints for authentication bypass vulnerabilities. + Implements OWASP MCP-07 (Insufficient Authentication/Authorization) testing. + + Args: + mcp_server_url (str): The base URL of the MCP server endpoint. + bypass_technique (str): One of: no_auth, empty_token, malformed_token, role_escalation + mcp_method (str): The MCP JSON-RPC method to test. Defaults to tools/list. + timeout (int): Request timeout in seconds. Defaults to 30. + max_requests_per_minute (int, Optional): Rate limit for requests. + """ + + BYPASS_TECHNIQUES = ["no_auth", "empty_token", "malformed_token", "role_escalation"] + + def __init__( + self, + mcp_server_url: str, + bypass_technique: str = "no_auth", + mcp_method: str = "tools/list", + timeout: int = 30, + max_requests_per_minute: Optional[int] = None, + ) -> None: + """ + Initialize MCPAuthBypassTarget. + + Args: + mcp_server_url: The base URL of the MCP server endpoint. + bypass_technique: The auth bypass technique to use. + mcp_method: The MCP JSON-RPC method to test. + timeout: Request timeout in seconds. + max_requests_per_minute: Rate limit for requests. + + Raises: + ValueError: If bypass_technique is not a valid option. + """ + if bypass_technique not in self.BYPASS_TECHNIQUES: + raise ValueError( + f"Invalid bypass_technique '{bypass_technique}'. " + f"Must be one of: {self.BYPASS_TECHNIQUES}" + ) + super().__init__( + max_requests_per_minute=max_requests_per_minute, + endpoint=mcp_server_url, + ) + self.mcp_server_url = mcp_server_url + self.bypass_technique = bypass_technique + self.mcp_method = mcp_method + self.timeout = timeout + + def _build_headers(self) -> dict[str, str]: + base_headers = {"Content-Type": "application/json", "Accept": "application/json"} + if self.bypass_technique == "no_auth": + return base_headers + if self.bypass_technique == "empty_token": + return {**base_headers, "Authorization": "Bearer "} + if self.bypass_technique == "malformed_token": + return {**base_headers, "Authorization": "Bearer eyJhbGciOiJIUzI1NiJ9.invalid"} + if self.bypass_technique == "role_escalation": + tampered = "eyJhbGciOiJub25lIn0.eyJzdWIiOiJ0ZXN0Iiwicm9sZSI6ImFkbWluIn0." + return {**base_headers, "Authorization": f"Bearer {tampered}"} + return base_headers + + def _build_jsonrpc_payload(self, prompt: str) -> dict[str, Any]: + return {"jsonrpc": "2.0", "id": 1, "method": self.mcp_method, "params": {"prompt": prompt}} + + def _evaluate_response(self, status_code: int, response_body: str) -> str: + if status_code == 200: + return ( + f"[VULNERABILITY DETECTED] MCP-07 Auth Bypass succeeded using " + f"'{self.bypass_technique}'. Server returned HTTP 200.\nResponse: {response_body[:500]}" + ) + if status_code in (401, 403): + return ( + f"[SECURE] Server correctly rejected with HTTP {status_code} " + f"using '{self.bypass_technique}'.\nResponse: {response_body[:200]}" + ) + return ( + f"[INVESTIGATE] Unexpected HTTP {status_code} " + f"using '{self.bypass_technique}'.\nResponse: {response_body[:200]}" + ) + + def _validate_request(self, *, message) -> None: + """ + Validate the request message. MCP target accepts all text messages. + + Raises: + ValueError: If the message is None or empty. + """ + if not message: + raise ValueError("Message cannot be None or empty.") + + @limit_requests_per_minute + async def send_prompt_async(self, *, prompt_request: MessagePiece) -> MessagePiece: + """ + Send a prompt to the MCP server using the configured auth bypass technique. + + Args: + prompt_request: The prompt request to send. + + Returns: + MessagePiece: The response containing bypass test results. + """ + prompt_text = prompt_request.converted_value + headers = self._build_headers() + payload = self._build_jsonrpc_payload(prompt_text) + logger.info(f"MCPAuthBypassTarget: Testing '{self.bypass_technique}' against {self.mcp_server_url}") + try: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post(self.mcp_server_url, headers=headers, content=json.dumps(payload)) + result = self._evaluate_response(response.status_code, response.text) + except httpx.TimeoutException: + result = f"[ERROR] Request timed out after {self.timeout}s" + except httpx.ConnectError as e: + result = f"[ERROR] Connection failed to {self.mcp_server_url}: {e}" + except Exception as e: + result = f"[ERROR] Unexpected error: {type(e).__name__}: {e}" + return construct_response_from_request(request=prompt_request, response_text_pieces=[result]) diff --git a/tests/unit/target/test_mcp_auth_bypass_target.py b/tests/unit/target/test_mcp_auth_bypass_target.py new file mode 100644 index 0000000000..d34fbea911 --- /dev/null +++ b/tests/unit/target/test_mcp_auth_bypass_target.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from unittest.mock import MagicMock + +import pytest + +from pyrit.prompt_target.http_target.mcp_auth_bypass_target import MCPAuthBypassTarget + + +def make_mock_request(text="test prompt"): + req = MagicMock() + req.converted_value = text + return req + + +class TestMCPAuthBypassTargetInit: + def test_valid_bypass_technique(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert target.bypass_technique == "no_auth" + + def test_invalid_bypass_technique_raises(self, sqlite_instance): + with pytest.raises(ValueError, match="Invalid bypass_technique"): + MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="invalid") + + def test_default_values(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080") + assert target.bypass_technique == "no_auth" + assert target.mcp_method == "tools/list" + assert target.timeout == 30 + + +class TestMCPAuthBypassTargetHeaders: + def test_no_auth_has_no_authorization_header(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "Authorization" not in target._build_headers() + + def test_empty_token_has_empty_bearer(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="empty_token") + assert target._build_headers()["Authorization"] == "Bearer " + + def test_malformed_token_has_invalid_jwt(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="malformed_token") + assert "invalid" in target._build_headers()["Authorization"] + + def test_role_escalation_has_tampered_token(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="role_escalation") + assert "eyJhbGciOiJub25lIn0" in target._build_headers()["Authorization"] + + +class TestMCPAuthBypassTargetEvaluate: + def test_200_detected_as_vulnerability(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "VULNERABILITY DETECTED" in target._evaluate_response(200, "ok") + + def test_401_detected_as_secure(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "SECURE" in target._evaluate_response(401, "Unauthorized") + + def test_403_detected_as_secure(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "SECURE" in target._evaluate_response(403, "Forbidden") + + def test_500_flagged_for_investigation(self, sqlite_instance): + target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth") + assert "INVESTIGATE" in target._evaluate_response(500, "Server Error")