Commit 8264853

Rate limiting proxy (#19)
Closed #18 and opened this one for a clear diff.

## Summary

Related PR: devrev/airdrop-shared#31

After discussion with Gasper and Patricija, we're moving this away from the shared repo and into the 3rd-party-service-specific snap-in repos. Why: for the rate limiting response to match the actual response, the rate limiting proxy has to be adapted slightly for each 3rd party service provider (I indicated and documented what has to be changed).

- [#ISS-217157](https://app.devrev.ai/devrev/works/ISS-217157)
1 parent ab95792 commit 8264853

File tree

1 file changed (+358, -0)


rate_limiting_proxy.py

Lines changed: 358 additions & 0 deletions
@@ -0,0 +1,358 @@
import socket
import threading
import socketserver
import time
import sys
import ssl
import json
import datetime
import email.utils
from urllib.parse import urlparse

# Rate limiting settings
TOKEN_BUCKET_CAPACITY = 100  # requests
REFILL_RATE = 10  # requests per second

# ============================================================================
# SERVICE-SPECIFIC CONFIGURATION: Customize this section for your integration
# ============================================================================
# This configuration mimics Trello's rate limiting response format.
# When adapting this proxy for a different third-party service, modify these
# settings to match that service's 429 response behavior.
# ============================================================================

RATE_LIMIT_DELAY = 3  # seconds - Time to wait before retrying

class RateLimiterState:
    """A thread-safe class to manage the global rate limiting state."""
    def __init__(self):
        self.lock = threading.Lock()
        self.rate_limiting_active = False
        self.test_name = None

    def start_rate_limiting(self, test_name):
        with self.lock:
            self.rate_limiting_active = True
            self.test_name = test_name

    def end_rate_limiting(self):
        with self.lock:
            self.rate_limiting_active = False
            self.test_name = None

    def is_rate_limiting_active(self):
        with self.lock:
            return self.rate_limiting_active, self.test_name

rate_limiter_state = RateLimiterState()

class TokenBucket:
    """A thread-safe token bucket for rate limiting."""
    def __init__(self, capacity, refill_rate):
        self.capacity = float(capacity)
        self.refill_rate = float(refill_rate)
        self.tokens = float(capacity)
        self.last_refill = time.time()
        self.lock = threading.Lock()

    def consume(self, tokens):
        """Consumes tokens from the bucket. Returns True if successful, False otherwise."""
        with self.lock:
            now = time.time()
            time_since_refill = now - self.last_refill
            new_tokens = time_since_refill * self.refill_rate
            self.tokens = min(self.capacity, self.tokens + new_tokens)
            self.last_refill = now

            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

rate_limiter = TokenBucket(TOKEN_BUCKET_CAPACITY, REFILL_RATE)
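
# Illustration (editorial note, not from the original commit): with
# TOKEN_BUCKET_CAPACITY = 100 and REFILL_RATE = 10, an idle bucket absorbs a
# burst of up to 100 connections; after that, roughly one connection is
# admitted every 0.1 s, e.g. a 0.5 s pause refills about 5 tokens.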

def create_rate_limit_response():
    """
    TODO: Adapt this based on the 3rd party service's rate limiting response format.

    ========================================================================
    SERVICE-SPECIFIC: Customize this function for your third-party service
    ========================================================================

    Generates the 429 Rate Limit response matching the third-party service's
    format. Different services may use different:
    - Response body structures (e.g., {"detail": "..."} vs {"error": "..."})
    - Retry-After header formats (HTTP date vs seconds)
    - Error messages and field names

    This implementation matches Trello's rate limiting response format.

    Returns:
        tuple: (status_code, status_message, response_body_dict, headers_dict)
    """
    retry_after_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=RATE_LIMIT_DELAY)
    retry_after_str = email.utils.formatdate(
        timeval=retry_after_time.timestamp(),
        localtime=False,
        usegmt=True
    )

    response_body = {"detail": "Rate limit exceeded"}
    headers = {"Retry-After": retry_after_str}

    return 429, "Too Many Requests", response_body, headers
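
# Sketch (assumption, not part of this file): a service that reports the retry
# delay in seconds rather than as an HTTP date could be mimicked by swapping
# the body of create_rate_limit_response() for something like:
#
#     response_body = {"error": "RATE_LIMIT", "message": "Too many requests"}
#     headers = {"Retry-After": str(RATE_LIMIT_DELAY)}
#     return 429, "Too Many Requests", response_body, headers
#
# The (status_code, status_message, body, headers) return shape is what
# ProxyHandler.send_json_response() expects, so only the contents change.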

class ProxyHandler(socketserver.BaseRequestHandler):
    """Handles incoming proxy requests."""
    def handle(self):
        if not rate_limiter.consume(1):
            print("Rate limit exceeded. Dropping connection.")
            try:
                self.request.sendall(b'HTTP/1.1 429 Too Many Requests\r\n\r\n')
            except OSError:
                pass  # Client might have already closed the connection.
            finally:
                self.request.close()
            return

        try:
            data = self.request.recv(4096)
        except ConnectionResetError:
            return  # Client closed connection.

        if not data:
            return

        first_line = data.split(b'\r\n')[0]
        try:
            method, target, _ = first_line.split()
        except ValueError:
            print(f"Could not parse request: {first_line}")
            self.request.close()
            return

        print(f"Received request: {method.decode('utf-8')} {target.decode('utf-8')}")

        path = target.decode('utf-8')
        # Check for control plane endpoints on the proxy itself
        if path.startswith(('/start_rate_limiting', '/end_rate_limiting')):
            self.handle_control_request(method, path, data)
            return

        # Check if global rate limiting is active
        is_active, test_name = rate_limiter_state.is_rate_limiting_active()
        if is_active:
            print(f"Rate limiting is active for test: '{test_name}'. Blocking request.")

            # Generate service-specific rate limit response
            status_code, status_message, response_body, headers = create_rate_limit_response()
            self.send_json_response(status_code, status_message, response_body, headers=headers)
            return

        if method == b'CONNECT':
            self.handle_connect(target)
        else:
            self.handle_http_request(target, data)
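
    # Request flow in handle(), summarized for readers (editorial comment):
    # 1. Every connection consumes one token from the global TokenBucket; if
    #    none is available, a bare 429 is returned and the socket is closed.
    # 2. POSTs to /start_rate_limiting and /end_rate_limiting are handled by
    #    the proxy itself and never forwarded upstream.
    # 3. While the rate-limiting flag is set, every other request (plain HTTP
    #    or CONNECT) receives the service-specific 429 from
    #    create_rate_limit_response() instead of being forwarded.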

    def get_request_body(self, data):
        header_end = data.find(b'\r\n\r\n')
        if header_end != -1:
            return data[header_end + 4:].decode('utf-8')
        return ""

    def send_json_response(self, status_code, status_message, body_json, headers=None):
        body_bytes = json.dumps(body_json).encode('utf-8')

        response_headers = [
            f"HTTP/1.1 {status_code} {status_message}",
            "Content-Type: application/json",
            f"Content-Length: {len(body_bytes)}",
            "Connection: close",
        ]

        if headers:
            for key, value in headers.items():
                response_headers.append(f"{key}: {value}")

        response_headers.append("")
        response_headers.append("")

        response = '\r\n'.join(response_headers).encode('utf-8') + body_bytes
        try:
            self.request.sendall(response)
        except OSError:
            pass  # Client might have closed the connection.
        finally:
            self.request.close()
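
    # For the Trello-style response above, the bytes written to the client
    # look roughly like this (illustrative; the Retry-After date is generated
    # at request time):
    #
    #   HTTP/1.1 429 Too Many Requests
    #   Content-Type: application/json
    #   Content-Length: 33
    #   Connection: close
    #   Retry-After: Wed, 01 Jan 2025 00:00:03 GMT
    #
    #   {"detail": "Rate limit exceeded"}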

    def handle_control_request(self, method, path, data):
        if method != b'POST':
            self.send_json_response(405, "Method Not Allowed", {"error": "Only POST method is allowed"})
            return

        if path == '/start_rate_limiting':
            body_str = self.get_request_body(data)
            if not body_str:
                self.send_json_response(400, "Bad Request", {"error": "Request body is missing or empty"})
                return
            try:
                body_json = json.loads(body_str)
                test_name = body_json.get('test_name')
                if not test_name or not isinstance(test_name, str):
                    self.send_json_response(400, "Bad Request", {"error": "'test_name' is missing or not a string"})
                    return
            except json.JSONDecodeError:
                self.send_json_response(400, "Bad Request", {"error": "Invalid JSON in request body"})
                return

            rate_limiter_state.start_rate_limiting(test_name)
            response_body = {"status": f"rate limiting started for test: {test_name}"}
            self.send_json_response(200, "OK", response_body)

        elif path == '/end_rate_limiting':
            rate_limiter_state.end_rate_limiting()
            response_body = {"status": "rate limiting ended"}
            self.send_json_response(200, "OK", response_body)
        else:
            self.send_json_response(404, "Not Found", {"error": "Endpoint not found"})
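
    # Example control-plane calls (assumed invocation, not part of this
    # commit) -- a test harness pointed directly at the proxy could use:
    #
    #   curl -X POST http://localhost:8004/start_rate_limiting \
    #        -d '{"test_name": "incremental-sync"}'
    #   curl -X POST http://localhost:8004/end_rate_limiting
    #
    # The request target must start with /start_rate_limiting or
    # /end_rate_limiting so that handle() routes it here instead of
    # forwarding it upstream.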

    def handle_http_request(self, target, data):
        """Handles HTTP requests like GET, POST, etc."""
        try:
            parsed_url = urlparse(target.decode('utf-8'))
            host = parsed_url.hostname
            port = parsed_url.port
            if port is None:
                port = 443 if parsed_url.scheme == 'https' else 80
        except Exception as e:
            print(f"Could not parse URL for HTTP request: {target}. Error: {e}")
            self.request.close()
            return

        if not host:
            print(f"Invalid host in URL: {target}")
            self.request.close()
            return

        try:
            remote_socket = socket.create_connection((host, port), timeout=10)
            if parsed_url.scheme == 'https':
                context = ssl.create_default_context()
                remote_socket = context.wrap_socket(remote_socket, server_hostname=host)
        except (socket.error, ssl.SSLError) as e:
            print(f"Failed to connect or SSL wrap to {host}:{port}: {e}")
            self.request.close()
            return

        # Modify the request to use a relative path and force connection closing.
        # This ensures each request gets its own connection and is logged.
        header_end = data.find(b'\r\n\r\n')
        if header_end == -1:
            # If no header-body separator is found, assume it's a simple request with no body.
            header_end = len(data)

        header_data = data[:header_end]
        body = data[header_end:]

        lines = header_data.split(b'\r\n')
        first_line = lines[0]
        headers = lines[1:]

        method, _, http_version = first_line.split(b' ', 2)

        path = parsed_url.path or '/'
        if parsed_url.query:
            path += '?' + parsed_url.query

        new_first_line = b' '.join([method, path.encode('utf-8'), http_version])

        new_headers = []
        for header in headers:
            # Remove existing connection-related headers, as we're forcing it to close.
            if not header.lower().startswith(b'connection:') and \
               not header.lower().startswith(b'proxy-connection:'):
                new_headers.append(header)
        new_headers.append(b'Connection: close')

        modified_header_part = new_first_line + b'\r\n' + b'\r\n'.join(new_headers)
        modified_request = modified_header_part + body

        try:
            remote_socket.sendall(modified_request)
        except OSError:
            remote_socket.close()
            return

        self.tunnel(self.request, remote_socket)
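
    # Rewrite example (hypothetical URL): a proxied request line such as
    #   GET http://api.example.com/v1/boards?limit=10 HTTP/1.1
    # is forwarded upstream as
    #   GET /v1/boards?limit=10 HTTP/1.1
    # with any Connection/Proxy-Connection headers replaced by
    # "Connection: close", so each request uses a fresh upstream connection.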

    def handle_connect(self, target):
        """Handles CONNECT requests for HTTPS traffic."""
        try:
            host, port_str = target.split(b':')
            port = int(port_str)
        except ValueError:
            print(f"Invalid target for CONNECT: {target}")
            self.request.close()
            return

        try:
            remote_socket = socket.create_connection((host.decode('utf-8'), port), timeout=10)
        except socket.error as e:
            print(f"Failed to connect to {host.decode('utf-8')}:{port}: {e}")
            self.request.close()
            return

        try:
            self.request.sendall(b'HTTP/1.1 200 Connection Established\r\n\r\n')
        except OSError:
            remote_socket.close()
            return

        self.tunnel(self.request, remote_socket)

    def tunnel(self, client_socket, remote_socket):
        """Tunnels data between the client and the remote server."""
        stop_event = threading.Event()

        def forward(src, dst):
            try:
                while not stop_event.is_set():
                    data = src.recv(4096)
                    if not data:
                        break
                    dst.sendall(data)
            except OSError:
                pass
            finally:
                stop_event.set()

        client_thread = threading.Thread(target=forward, args=(client_socket, remote_socket))
        remote_thread = threading.Thread(target=forward, args=(remote_socket, client_socket))

        client_thread.start()
        remote_thread.start()

        client_thread.join()
        remote_thread.join()

        client_socket.close()
        remote_socket.close()

class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
    daemon_threads = True
    allow_reuse_address = True

def main():
    HOST, PORT = "localhost", 8004

    try:
        server = ThreadingTCPServer((HOST, PORT), ProxyHandler)
        print(f"Starting proxy server on {HOST}:{PORT}")
        server.serve_forever()
    except Exception as e:
        print(f"Could not start proxy server: {e}", file=sys.stderr)
        # The script `run_devrev_snapin_conformance_tests.sh` checks for exit code 69.
        sys.exit(69)

if __name__ == "__main__":
    main()
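
# Example invocation (assumed, not part of this commit): run the proxy and
# route a client through it, e.g.
#
#   python rate_limiting_proxy.py
#   HTTP_PROXY=http://localhost:8004 HTTPS_PROXY=http://localhost:8004 <client command>
#
# Plain HTTP requests are rewritten and forwarded by handle_http_request();
# HTTPS clients issue CONNECT, which handle_connect() tunnels after the
# initial 200 response.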
