Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_http_target_regex_matching_callback_function,
)
from pyrit.prompt_target.http_target.httpx_api_target import HTTPXAPITarget
from pyrit.prompt_target.http_target.mcp_auth_bypass_target import MCPAuthBypassTarget
from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget
from pyrit.prompt_target.hugging_face.hugging_face_endpoint_target import HuggingFaceEndpointTarget
from pyrit.prompt_target.openai.openai_chat_audio_config import OpenAIChatAudioConfig
Expand Down Expand Up @@ -50,6 +51,7 @@
"get_http_target_regex_matching_callback_function",
"HTTPTarget",
"HTTPXAPITarget",
"MCPAuthBypassTarget",
"HuggingFaceChatTarget",
"HuggingFaceEndpointTarget",
"limit_requests_per_minute",
Expand Down
137 changes: 137 additions & 0 deletions pyrit/prompt_target/http_target/mcp_auth_bypass_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
from typing import Any, Optional

import httpx

from pyrit.models import (
MessagePiece,
construct_response_from_request,
)
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.utils import limit_requests_per_minute

logger = logging.getLogger(__name__)


class MCPAuthBypassTarget(PromptTarget):
"""
MCPAuthBypassTarget tests MCP server endpoints for authentication bypass vulnerabilities.
Implements OWASP MCP-07 (Insufficient Authentication/Authorization) testing.

Args:
mcp_server_url (str): The base URL of the MCP server endpoint.
bypass_technique (str): One of: no_auth, empty_token, malformed_token, role_escalation
mcp_method (str): The MCP JSON-RPC method to test. Defaults to tools/list.
timeout (int): Request timeout in seconds. Defaults to 30.
max_requests_per_minute (int, Optional): Rate limit for requests.
"""

BYPASS_TECHNIQUES = ["no_auth", "empty_token", "malformed_token", "role_escalation"]

def __init__(
self,
mcp_server_url: str,
bypass_technique: str = "no_auth",
mcp_method: str = "tools/list",
timeout: int = 30,
max_requests_per_minute: Optional[int] = None,
) -> None:
"""
Initialize MCPAuthBypassTarget.

Args:
mcp_server_url: The base URL of the MCP server endpoint.
bypass_technique: The auth bypass technique to use.
mcp_method: The MCP JSON-RPC method to test.
timeout: Request timeout in seconds.
max_requests_per_minute: Rate limit for requests.

Raises:
ValueError: If bypass_technique is not a valid option.
"""
if bypass_technique not in self.BYPASS_TECHNIQUES:
raise ValueError(
f"Invalid bypass_technique '{bypass_technique}'. "
f"Must be one of: {self.BYPASS_TECHNIQUES}"
)
super().__init__(
max_requests_per_minute=max_requests_per_minute,
endpoint=mcp_server_url,
)
self.mcp_server_url = mcp_server_url
self.bypass_technique = bypass_technique
self.mcp_method = mcp_method
self.timeout = timeout

def _build_headers(self) -> dict[str, str]:
base_headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.bypass_technique == "no_auth":
return base_headers
if self.bypass_technique == "empty_token":
return {**base_headers, "Authorization": "Bearer "}
if self.bypass_technique == "malformed_token":
return {**base_headers, "Authorization": "Bearer eyJhbGciOiJIUzI1NiJ9.invalid"}
if self.bypass_technique == "role_escalation":
tampered = "eyJhbGciOiJub25lIn0.eyJzdWIiOiJ0ZXN0Iiwicm9sZSI6ImFkbWluIn0."
return {**base_headers, "Authorization": f"Bearer {tampered}"}
return base_headers

def _build_jsonrpc_payload(self, prompt: str) -> dict[str, Any]:
return {"jsonrpc": "2.0", "id": 1, "method": self.mcp_method, "params": {"prompt": prompt}}

def _evaluate_response(self, status_code: int, response_body: str) -> str:
if status_code == 200:
return (
f"[VULNERABILITY DETECTED] MCP-07 Auth Bypass succeeded using "
f"'{self.bypass_technique}'. Server returned HTTP 200.\nResponse: {response_body[:500]}"
)
if status_code in (401, 403):
return (
f"[SECURE] Server correctly rejected with HTTP {status_code} "
f"using '{self.bypass_technique}'.\nResponse: {response_body[:200]}"
)
return (
f"[INVESTIGATE] Unexpected HTTP {status_code} "
f"using '{self.bypass_technique}'.\nResponse: {response_body[:200]}"
)

def _validate_request(self, *, message) -> None:
"""
Validate the request message. MCP target accepts all text messages.

Raises:
ValueError: If the message is None or empty.
"""
if not message:
raise ValueError("Message cannot be None or empty.")

@limit_requests_per_minute
async def send_prompt_async(self, *, prompt_request: MessagePiece) -> MessagePiece:
"""
Send a prompt to the MCP server using the configured auth bypass technique.

Args:
prompt_request: The prompt request to send.

Returns:
MessagePiece: The response containing bypass test results.
"""
prompt_text = prompt_request.converted_value
headers = self._build_headers()
payload = self._build_jsonrpc_payload(prompt_text)
logger.info(f"MCPAuthBypassTarget: Testing '{self.bypass_technique}' against {self.mcp_server_url}")
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(self.mcp_server_url, headers=headers, content=json.dumps(payload))
result = self._evaluate_response(response.status_code, response.text)
except httpx.TimeoutException:
result = f"[ERROR] Request timed out after {self.timeout}s"
except httpx.ConnectError as e:
result = f"[ERROR] Connection failed to {self.mcp_server_url}: {e}"
except Exception as e:
result = f"[ERROR] Unexpected error: {type(e).__name__}: {e}"
return construct_response_from_request(request=prompt_request, response_text_pieces=[result])
66 changes: 66 additions & 0 deletions tests/unit/target/test_mcp_auth_bypass_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import MagicMock

import pytest

from pyrit.prompt_target.http_target.mcp_auth_bypass_target import MCPAuthBypassTarget


def make_mock_request(text="test prompt"):
req = MagicMock()
req.converted_value = text
return req


class TestMCPAuthBypassTargetInit:
def test_valid_bypass_technique(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth")
assert target.bypass_technique == "no_auth"

def test_invalid_bypass_technique_raises(self, sqlite_instance):
with pytest.raises(ValueError, match="Invalid bypass_technique"):
MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="invalid")

def test_default_values(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080")
assert target.bypass_technique == "no_auth"
assert target.mcp_method == "tools/list"
assert target.timeout == 30


class TestMCPAuthBypassTargetHeaders:
def test_no_auth_has_no_authorization_header(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth")
assert "Authorization" not in target._build_headers()

def test_empty_token_has_empty_bearer(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="empty_token")
assert target._build_headers()["Authorization"] == "Bearer "

def test_malformed_token_has_invalid_jwt(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="malformed_token")
assert "invalid" in target._build_headers()["Authorization"]

def test_role_escalation_has_tampered_token(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="role_escalation")
assert "eyJhbGciOiJub25lIn0" in target._build_headers()["Authorization"]


class TestMCPAuthBypassTargetEvaluate:
def test_200_detected_as_vulnerability(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth")
assert "VULNERABILITY DETECTED" in target._evaluate_response(200, "ok")

def test_401_detected_as_secure(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth")
assert "SECURE" in target._evaluate_response(401, "Unauthorized")

def test_403_detected_as_secure(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth")
assert "SECURE" in target._evaluate_response(403, "Forbidden")

def test_500_flagged_for_investigation(self, sqlite_instance):
target = MCPAuthBypassTarget(mcp_server_url="http://localhost:8080", bypass_technique="no_auth")
assert "INVESTIGATE" in target._evaluate_response(500, "Server Error")