forked from OpenBMB/ChatDev
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmiddleware.py
More file actions
executable file
·95 lines (77 loc) · 3.4 KB
/
middleware.py
File metadata and controls
executable file
·95 lines (77 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Custom middleware for the DevAll workflow system."""
import uuid
from typing import Callable, Awaitable
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
import time
import re
from utils.structured_logger import get_server_logger, LogType
from utils.exceptions import SecurityError
async def correlation_id_middleware(request: Request, call_next: Callable):
    """Attach a correlation ID to the request and log the request/response pair.

    The ID is taken from the incoming ``X-Correlation-ID`` header when it is
    present and non-empty; otherwise a fresh UUID4 string is generated. It is
    stored on ``request.state.correlation_id`` for downstream code and echoed
    back in the ``X-Correlation-ID`` response header.

    Args:
        request: The incoming request.
        call_next: Continuation that forwards the request to the next layer
            and returns its response.

    Returns:
        The downstream response with the ``X-Correlation-ID`` header set.

    Raises:
        Whatever ``call_next`` raises is re-raised unchanged, after the
        request itself has been logged.
    """
    correlation_id = request.headers.get("X-Correlation-ID") or str(uuid.uuid4())
    request.state.correlation_id = correlation_id
    logger = get_server_logger()

    # perf_counter is monotonic; time.time() follows the wall clock and can
    # jump (e.g. NTP adjustments), yielding negative or inflated durations.
    start = time.perf_counter()
    try:
        response = await call_next(request)
        duration = time.perf_counter() - start
    finally:
        # Log the request even when a downstream layer raised, so failing
        # requests still leave a trace tied to their correlation ID.
        logger.log_request(
            request.method,
            str(request.url),
            correlation_id=correlation_id,
            path=request.url.path,
            query_params=dict(request.query_params),
            client_host=request.client.host if request.client else None,
            user_agent=request.headers.get("user-agent"),
        )

    logger.log_response(
        response.status_code,
        duration,
        correlation_id=correlation_id,
        content_length=response.headers.get("content-length"),
    )

    # Echo the ID so clients can quote it when reporting problems.
    response.headers["X-Correlation-ID"] = correlation_id
    return response
# Matches a ".." segment adjacent to a path separator ("../", "..\\", "/..",
# "\\..") -- an attempt to climb out of a directory. Compiled once at import
# time rather than looked up on every request.
_PATH_TRAVERSAL_RE = re.compile(r"(\.{2}[/\\])|([/\\]\.{2})")


def _json_error(status_code: int, detail: str) -> JSONResponse:
    """Build an error response shaped like FastAPI's default HTTPException body."""
    return JSONResponse(status_code=status_code, content={"detail": detail})


async def security_middleware(request: Request, call_next: Callable):
    """Reject requests with path-traversal patterns or a wrong API content type.

    Checks, in order:

    1. The URL path must not contain a ``..`` segment next to a ``/`` or
       ``\\`` separator. Matches are logged as a ``PATH_TRAVERSAL_ATTEMPT``
       security event and rejected with 400 ``"Invalid path"``.
    2. ``POST``/``PUT``/``PATCH`` requests under ``/api/`` must send an
       ``application/json`` (or, for file uploads, ``multipart/form-data``)
       Content-Type. Otherwise they are rejected with 400.

    Rejections are *returned* as responses rather than raised. An
    ``HTTPException`` raised inside an ``app.middleware("http")`` function
    never reaches FastAPI's exception handlers, because those sit inside the
    user middleware. The client would then see a 500 instead of the
    intended 400.

    Args:
        request: The incoming request.
        call_next: Continuation that forwards the request to the next layer.

    Returns:
        A 400 ``JSONResponse`` with a ``{"detail": ...}`` body on rejection,
        otherwise the downstream response.
    """
    path = request.url.path

    # Traversal is checked first so the attempt is always recorded as a
    # security event, even if the request would also fail the content-type
    # check.
    if _PATH_TRAVERSAL_RE.search(path):
        correlation_id = (
            getattr(request.state, "correlation_id", None)
            or request.headers.get("X-Correlation-ID")
            or str(uuid.uuid4())
        )
        get_server_logger().log_security_event(
            "PATH_TRAVERSAL_ATTEMPT",
            f"Suspicious path detected: {path}",
            correlation_id=correlation_id,
        )
        return _json_error(400, "Invalid path")

    if path.startswith("/api/") and request.method in ("POST", "PUT", "PATCH"):
        content_type = request.headers.get("content-type", "").lower()
        # multipart/form-data is accepted so file uploads keep working.
        if not content_type.startswith(("application/json", "multipart/form-data")):
            return _json_error(
                400, "Content-Type must be application/json for API endpoints"
            )

    return await call_next(request)
async def rate_limit_middleware(request: Request, call_next: Callable):
    """Placeholder hook for request rate limiting.

    No limiting is enforced yet. Every request is handed straight to the
    next layer and its response is returned untouched. A production version
    would keep request counts in Redis or a similar store instead of doing
    nothing here.
    """
    return await call_next(request)
def add_middleware(app):
    """Register this module's HTTP middleware on a FastAPI application.

    Starlette wraps each newly registered middleware *around* the ones
    registered before it. The middleware added last is therefore the
    outermost layer and sees the request first.

    ``correlation_id_middleware`` is registered last so that it runs first.
    That way ``request.state.correlation_id`` is already set when
    ``security_middleware`` runs; that middleware reads the attribute when
    logging security events.

    Args:
        app: The FastAPI application to configure.

    Returns:
        The same ``app`` instance, for chaining.
    """
    # Registration order is innermost -> outermost.
    app.middleware("http")(security_middleware)
    # app.middleware("http")(rate_limit_middleware)  # Enable if needed
    app.middleware("http")(correlation_id_middleware)
    return app