|
| 1 | +import socket |
| 2 | +import threading |
| 3 | +import socketserver |
| 4 | +import time |
| 5 | +import sys |
| 6 | +import ssl |
| 7 | +import json |
| 8 | +import datetime |
| 9 | +import email.utils |
| 10 | +from urllib.parse import urlparse |
| 11 | + |
| 12 | +# Rate limiting settings |
| 13 | +TOKEN_BUCKET_CAPACITY = 100 # requests |
| 14 | +REFILL_RATE = 10 # requests per second |
| 15 | + |
| 16 | +# ============================================================================ |
| 17 | +# SERVICE-SPECIFIC CONFIGURATION: Customize this section for your integration |
| 18 | +# ============================================================================ |
| 19 | +# This configuration mimics Trello's rate limiting response format. |
| 20 | +# When adapting this proxy for a different third-party service, modify these |
| 21 | +# settings to match that service's 429 response behavior. |
| 22 | +# ============================================================================ |
| 23 | + |
| 24 | +RATE_LIMIT_DELAY = 3 # seconds - Time to wait before retrying |
| 25 | + |
| 26 | +class RateLimiterState: |
| 27 | + """A thread-safe class to manage the global rate limiting state.""" |
| 28 | + def __init__(self): |
| 29 | + self.lock = threading.Lock() |
| 30 | + self.rate_limiting_active = False |
| 31 | + self.test_name = None |
| 32 | + |
| 33 | + def start_rate_limiting(self, test_name): |
| 34 | + with self.lock: |
| 35 | + self.rate_limiting_active = True |
| 36 | + self.test_name = test_name |
| 37 | + |
| 38 | + def end_rate_limiting(self): |
| 39 | + with self.lock: |
| 40 | + self.rate_limiting_active = False |
| 41 | + self.test_name = None |
| 42 | + |
| 43 | + def is_rate_limiting_active(self): |
| 44 | + with self.lock: |
| 45 | + return self.rate_limiting_active, self.test_name |
| 46 | + |
| 47 | +rate_limiter_state = RateLimiterState() |
| 48 | + |
| 49 | +class TokenBucket: |
| 50 | + """A thread-safe token bucket for rate limiting.""" |
| 51 | + def __init__(self, capacity, refill_rate): |
| 52 | + self.capacity = float(capacity) |
| 53 | + self.refill_rate = float(refill_rate) |
| 54 | + self.tokens = float(capacity) |
| 55 | + self.last_refill = time.time() |
| 56 | + self.lock = threading.Lock() |
| 57 | + |
| 58 | + def consume(self, tokens): |
| 59 | + """Consumes tokens from the bucket. Returns True if successful, False otherwise.""" |
| 60 | + with self.lock: |
| 61 | + now = time.time() |
| 62 | + time_since_refill = now - self.last_refill |
| 63 | + new_tokens = time_since_refill * self.refill_rate |
| 64 | + self.tokens = min(self.capacity, self.tokens + new_tokens) |
| 65 | + self.last_refill = now |
| 66 | + |
| 67 | + if self.tokens >= tokens: |
| 68 | + self.tokens -= tokens |
| 69 | + return True |
| 70 | + return False |
| 71 | + |
| 72 | +rate_limiter = TokenBucket(TOKEN_BUCKET_CAPACITY, REFILL_RATE) |
| 73 | + |
| 74 | +def create_rate_limit_response(): |
| 75 | + """ |
| 76 | + TODO: Adopt this based on the 3rd party service's rate limiting response format. |
| 77 | +
|
| 78 | + ======================================================================== |
| 79 | + SERVICE-SPECIFIC: Customize this function for your third-party service |
| 80 | + ======================================================================== |
| 81 | + |
| 82 | + Generates the 429 Rate Limit response matching the third-party service's |
| 83 | + format. Different services may use different: |
| 84 | + - Response body structures (e.g., {"detail": "..."} vs {"error": "..."}) |
| 85 | + - Retry-After header formats (HTTP date vs seconds) |
| 86 | + - Error messages and field names |
| 87 | + |
| 88 | + This implementation matches Trello's rate limiting response format. |
| 89 | + |
| 90 | + Returns: |
| 91 | + tuple: (status_code, status_message, response_body_dict, headers_dict) |
| 92 | + """ |
| 93 | + retry_after_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=RATE_LIMIT_DELAY) |
| 94 | + retry_after_str = email.utils.formatdate( |
| 95 | + timeval=retry_after_time.timestamp(), |
| 96 | + localtime=False, |
| 97 | + usegmt=True |
| 98 | + ) |
| 99 | + |
| 100 | + response_body = {"detail": "Rate limit exceeded"} |
| 101 | + headers = {"Retry-After": retry_after_str} |
| 102 | + |
| 103 | + return 429, "Too Many Requests", response_body, headers |
| 104 | + |
| 105 | +class ProxyHandler(socketserver.BaseRequestHandler): |
| 106 | + """Handles incoming proxy requests.""" |
| 107 | + def handle(self): |
| 108 | + if not rate_limiter.consume(1): |
| 109 | + print("Rate limit exceeded. Dropping connection.") |
| 110 | + try: |
| 111 | + self.request.sendall(b'HTTP/1.1 429 Too Many Requests\r\n\r\n') |
| 112 | + except OSError: |
| 113 | + pass # Client might have already closed the connection. |
| 114 | + finally: |
| 115 | + self.request.close() |
| 116 | + return |
| 117 | + |
| 118 | + try: |
| 119 | + data = self.request.recv(4096) |
| 120 | + except ConnectionResetError: |
| 121 | + return # Client closed connection. |
| 122 | + |
| 123 | + if not data: |
| 124 | + return |
| 125 | + |
| 126 | + first_line = data.split(b'\r\n')[0] |
| 127 | + try: |
| 128 | + method, target, _ = first_line.split() |
| 129 | + except ValueError: |
| 130 | + print(f"Could not parse request: {first_line}") |
| 131 | + self.request.close() |
| 132 | + return |
| 133 | + |
| 134 | + print(f"Received request: {method.decode('utf-8')} {target.decode('utf-8')}") |
| 135 | + |
| 136 | + path = target.decode('utf-8') |
| 137 | + # Check for control plane endpoints on the proxy itself |
| 138 | + if path.startswith(('/start_rate_limiting', '/end_rate_limiting')): |
| 139 | + self.handle_control_request(method, path, data) |
| 140 | + return |
| 141 | + |
| 142 | + # Check if global rate limiting is active |
| 143 | + is_active, test_name = rate_limiter_state.is_rate_limiting_active() |
| 144 | + if is_active: |
| 145 | + print(f"Rate limiting is active for test: '{test_name}'. Blocking request.") |
| 146 | + |
| 147 | + # Generate service-specific rate limit response |
| 148 | + status_code, status_message, response_body, headers = create_rate_limit_response() |
| 149 | + self.send_json_response(status_code, status_message, response_body, headers=headers) |
| 150 | + return |
| 151 | + |
| 152 | + if method == b'CONNECT': |
| 153 | + self.handle_connect(target) |
| 154 | + else: |
| 155 | + self.handle_http_request(target, data) |
| 156 | + |
| 157 | + def get_request_body(self, data): |
| 158 | + header_end = data.find(b'\r\n\r\n') |
| 159 | + if header_end != -1: |
| 160 | + return data[header_end + 4:].decode('utf-8') |
| 161 | + return "" |
| 162 | + |
| 163 | + def send_json_response(self, status_code, status_message, body_json, headers=None): |
| 164 | + body_bytes = json.dumps(body_json).encode('utf-8') |
| 165 | + |
| 166 | + response_headers = [ |
| 167 | + f"HTTP/1.1 {status_code} {status_message}", |
| 168 | + "Content-Type: application/json", |
| 169 | + f"Content-Length: {len(body_bytes)}", |
| 170 | + "Connection: close", |
| 171 | + ] |
| 172 | + |
| 173 | + if headers: |
| 174 | + for key, value in headers.items(): |
| 175 | + response_headers.append(f"{key}: {value}") |
| 176 | + |
| 177 | + response_headers.append("") |
| 178 | + response_headers.append("") |
| 179 | + |
| 180 | + response = '\r\n'.join(response_headers).encode('utf-8') + body_bytes |
| 181 | + try: |
| 182 | + self.request.sendall(response) |
| 183 | + except OSError: |
| 184 | + pass # Client might have closed the connection. |
| 185 | + finally: |
| 186 | + self.request.close() |
| 187 | + |
| 188 | + def handle_control_request(self, method, path, data): |
| 189 | + if method != b'POST': |
| 190 | + self.send_json_response(405, "Method Not Allowed", {"error": "Only POST method is allowed"}) |
| 191 | + return |
| 192 | + |
| 193 | + if path == '/start_rate_limiting': |
| 194 | + body_str = self.get_request_body(data) |
| 195 | + if not body_str: |
| 196 | + self.send_json_response(400, "Bad Request", {"error": "Request body is missing or empty"}) |
| 197 | + return |
| 198 | + try: |
| 199 | + body_json = json.loads(body_str) |
| 200 | + test_name = body_json.get('test_name') |
| 201 | + if not test_name or not isinstance(test_name, str): |
| 202 | + self.send_json_response(400, "Bad Request", {"error": "'test_name' is missing or not a string"}) |
| 203 | + return |
| 204 | + except json.JSONDecodeError: |
| 205 | + self.send_json_response(400, "Bad Request", {"error": "Invalid JSON in request body"}) |
| 206 | + return |
| 207 | + |
| 208 | + rate_limiter_state.start_rate_limiting(test_name) |
| 209 | + response_body = {"status": f"rate limiting started for test: {test_name}"} |
| 210 | + self.send_json_response(200, "OK", response_body) |
| 211 | + |
| 212 | + elif path == '/end_rate_limiting': |
| 213 | + rate_limiter_state.end_rate_limiting() |
| 214 | + response_body = {"status": "rate limiting ended"} |
| 215 | + self.send_json_response(200, "OK", response_body) |
| 216 | + else: |
| 217 | + self.send_json_response(404, "Not Found", {"error": "Endpoint not found"}) |
| 218 | + |
| 219 | + def handle_http_request(self, target, data): |
| 220 | + """Handles HTTP requests like GET, POST, etc.""" |
| 221 | + try: |
| 222 | + parsed_url = urlparse(target.decode('utf-8')) |
| 223 | + host = parsed_url.hostname |
| 224 | + port = parsed_url.port |
| 225 | + if port is None: |
| 226 | + port = 443 if parsed_url.scheme == 'https' else 80 |
| 227 | + except Exception as e: |
| 228 | + print(f"Could not parse URL for HTTP request: {target}. Error: {e}") |
| 229 | + self.request.close() |
| 230 | + return |
| 231 | + |
| 232 | + if not host: |
| 233 | + print(f"Invalid host in URL: {target}") |
| 234 | + self.request.close() |
| 235 | + return |
| 236 | + |
| 237 | + try: |
| 238 | + remote_socket = socket.create_connection((host, port), timeout=10) |
| 239 | + if parsed_url.scheme == 'https': |
| 240 | + context = ssl.create_default_context() |
| 241 | + remote_socket = context.wrap_socket(remote_socket, server_hostname=host) |
| 242 | + except (socket.error, ssl.SSLError) as e: |
| 243 | + print(f"Failed to connect or SSL wrap to {host}:{port}: {e}") |
| 244 | + self.request.close() |
| 245 | + return |
| 246 | + |
| 247 | + # Modify the request to use a relative path and force connection closing |
| 248 | + # This ensures each request gets its own connection and is logged. |
| 249 | + header_end = data.find(b'\r\n\r\n') |
| 250 | + if header_end == -1: |
| 251 | + # If no header-body separator is found, assume it's a simple request with no body. |
| 252 | + header_end = len(data) |
| 253 | + |
| 254 | + header_data = data[:header_end] |
| 255 | + body = data[header_end:] |
| 256 | + |
| 257 | + lines = header_data.split(b'\r\n') |
| 258 | + first_line = lines[0] |
| 259 | + headers = lines[1:] |
| 260 | + |
| 261 | + method, _, http_version = first_line.split(b' ', 2) |
| 262 | + |
| 263 | + path = parsed_url.path or '/' |
| 264 | + if parsed_url.query: |
| 265 | + path += '?' + parsed_url.query |
| 266 | + |
| 267 | + new_first_line = b' '.join([method, path.encode('utf-8'), http_version]) |
| 268 | + |
| 269 | + new_headers = [] |
| 270 | + for header in headers: |
| 271 | + # Remove existing connection-related headers, as we're forcing it to close. |
| 272 | + if not header.lower().startswith(b'connection:') and \ |
| 273 | + not header.lower().startswith(b'proxy-connection:'): |
| 274 | + new_headers.append(header) |
| 275 | + new_headers.append(b'Connection: close') |
| 276 | + |
| 277 | + modified_header_part = new_first_line + b'\r\n' + b'\r\n'.join(new_headers) |
| 278 | + modified_request = modified_header_part + body |
| 279 | + |
| 280 | + try: |
| 281 | + remote_socket.sendall(modified_request) |
| 282 | + except OSError: |
| 283 | + remote_socket.close() |
| 284 | + return |
| 285 | + |
| 286 | + self.tunnel(self.request, remote_socket) |
| 287 | + |
| 288 | + def handle_connect(self, target): |
| 289 | + """Handles CONNECT requests for HTTPS traffic.""" |
| 290 | + try: |
| 291 | + host, port_str = target.split(b':') |
| 292 | + port = int(port_str) |
| 293 | + except ValueError: |
| 294 | + print(f"Invalid target for CONNECT: {target}") |
| 295 | + self.request.close() |
| 296 | + return |
| 297 | + |
| 298 | + try: |
| 299 | + remote_socket = socket.create_connection((host.decode('utf-8'), port), timeout=10) |
| 300 | + except socket.error as e: |
| 301 | + print(f"Failed to connect to {host.decode('utf-8')}:{port}: {e}") |
| 302 | + self.request.close() |
| 303 | + return |
| 304 | + |
| 305 | + try: |
| 306 | + self.request.sendall(b'HTTP/1.1 200 Connection Established\r\n\r\n') |
| 307 | + except OSError: |
| 308 | + remote_socket.close() |
| 309 | + return |
| 310 | + |
| 311 | + self.tunnel(self.request, remote_socket) |
| 312 | + |
| 313 | + def tunnel(self, client_socket, remote_socket): |
| 314 | + """Tunnels data between the client and the remote server.""" |
| 315 | + stop_event = threading.Event() |
| 316 | + |
| 317 | + def forward(src, dst): |
| 318 | + try: |
| 319 | + while not stop_event.is_set(): |
| 320 | + data = src.recv(4096) |
| 321 | + if not data: |
| 322 | + break |
| 323 | + dst.sendall(data) |
| 324 | + except OSError: |
| 325 | + pass |
| 326 | + finally: |
| 327 | + stop_event.set() |
| 328 | + |
| 329 | + client_thread = threading.Thread(target=forward, args=(client_socket, remote_socket)) |
| 330 | + remote_thread = threading.Thread(target=forward, args=(remote_socket, client_socket)) |
| 331 | + |
| 332 | + client_thread.start() |
| 333 | + remote_thread.start() |
| 334 | + |
| 335 | + client_thread.join() |
| 336 | + remote_thread.join() |
| 337 | + |
| 338 | + client_socket.close() |
| 339 | + remote_socket.close() |
| 340 | + |
| 341 | +class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): |
| 342 | + daemon_threads = True |
| 343 | + allow_reuse_address = True |
| 344 | + |
| 345 | +def main(): |
| 346 | + HOST, PORT = "localhost", 8004 |
| 347 | + |
| 348 | + try: |
| 349 | + server = ThreadingTCPServer((HOST, PORT), ProxyHandler) |
| 350 | + print(f"Starting proxy server on {HOST}:{PORT}") |
| 351 | + server.serve_forever() |
| 352 | + except Exception as e: |
| 353 | + print(f"Could not start proxy server: {e}", file=sys.stderr) |
| 354 | + # The script `run_devrev_snapin_conformance_tests.sh` checks for exit code 69. |
| 355 | + sys.exit(69) |
| 356 | + |
| 357 | +if __name__ == "__main__": |
| 358 | + main() |
0 commit comments