@@ -18,41 +18,56 @@ class Protocol(Enum):
1818 HTTP = "HTTP"
1919 Unknown = "Unknown"
2020 DETECT = "Detect"
21+
2122 @staticmethod
2223 def get (data : bytes ):
23- if b' HTTP/1.1' in data :
24+ if b" HTTP/1.1" in data :
2425 return Protocol .HTTP
2526 if check_port_key == data :
2627 return Protocol .DETECT
2728 return Protocol .Unknown
29+
30+
2831@dataclass
2932class ProxyClient :
30- proxy : ' Proxy'
33+ proxy : " Proxy"
3134 origin : Client
3235 target : Client
33- before : bytes = b''
36+ before : bytes = b""
3437 closed : bool = False
38+
3539 def start (self ):
3640 self ._task_origin = Timer .delay (self .process_origin , (), 0 )
37- self ._task_target = Timer .delay (self .process_target , (), 0 )
41+ self ._task_target = Timer .delay (self .process_target , (), 0 )
42+
3843 async def process_origin (self ):
3944 try :
4045 self .target .write (self .before )
41- while (buffer := await self .origin .read (IO_BUFFER , timeout = TIMEOUT )) and not self .origin .is_closed () and not self .origin .is_closed ():
46+ while (
47+ (buffer := await self .origin .read (IO_BUFFER , timeout = TIMEOUT ))
48+ and not self .origin .is_closed ()
49+ and not self .origin .is_closed ()
50+ ):
4251 self .target .write (buffer )
43- self .before = b''
52+ self .before = b""
4453 await self .target .writer .drain ()
4554 except :
4655 ...
4756 self .close ()
57+
4858 async def process_target (self ):
4959 try :
50- while (buffer := await self .target .read (IO_BUFFER , timeout = TIMEOUT )) and not self .target .is_closed () and not self .target .is_closed ():
60+ while (
61+ (buffer := await self .target .read (IO_BUFFER , timeout = TIMEOUT ))
62+ and not self .target .is_closed ()
63+ and not self .target .is_closed ()
64+ ):
5165 self .origin .write (buffer )
5266 await self .origin .writer .drain ()
5367 except :
5468 ...
5569 self .close ()
70+
5671 def close (self ):
5772 if not self .closed :
5873 if not self .origin .is_closed ():
@@ -61,47 +76,63 @@ def close(self):
6176 self .target .close ()
6277 self .closed = True
6378 self .proxy .disconnect (self )
79+
80+
6481class Proxy :
6582 def __init__ (self ) -> None :
6683 self ._tables : list [ProxyClient ] = []
84+
6785 async def connect (self , origin : Client , target : Client , before : bytes ):
6886 client = ProxyClient (self , origin , target , before )
6987 self ._tables .append (client )
7088 client .start ()
89+
7190 def disconnect (self , client : ProxyClient ):
7291 if client not in self ._tables :
7392 return
7493 self ._tables .remove (client )
7594
95+
7696ssl_server : Optional [asyncio .Server ] = None
7797server : Optional [asyncio .Server ] = None
7898proxy : Proxy = Proxy ()
7999restart = False
80100check_port_key = os .urandom (8 )
81- PORT : int = Config .get ("web.port" )
82- TIMEOUT : int = Config .get ("advanced.timeout" )
101+ PORT : int = Config .get ("web.port" )
102+ TIMEOUT : int = Config .get ("advanced.timeout" )
83103SSL_PORT : int = Config .get ("web.ssl_port" )
84104PROTOCOL_HEADER_BYTES = Config .get ("advanced.header_bytes" )
85105IO_BUFFER : int = Config .get ("advanced.io_buffer" )
86106
107+
87108async def _handle_ssl (reader : asyncio .StreamReader , writer : asyncio .StreamWriter ):
88109 return await _handle_process (Client (reader , writer ), True )
89110
111+
90112async def _handle (reader : asyncio .StreamReader , writer : asyncio .StreamWriter ):
91113 return await _handle_process (Client (reader , writer ))
92114
115+
93116async def _handle_process (client : Client , ssl : bool = False ):
94117 global ssl_server
95118 proxying = False
96119 try :
97- while (header := await client .read (PROTOCOL_HEADER_BYTES , timeout = 30 )) and not client .is_closed ():
120+ while (
121+ header := await client .read (PROTOCOL_HEADER_BYTES , timeout = 30 )
122+ ) and not client .is_closed ():
98123 protocol = Protocol .get (header )
99124 if protocol == Protocol .DETECT :
100125 client .write (check_port_key )
101126 await client .writer .drain ()
102127 break
103128 if protocol == Protocol .Unknown and not ssl and ssl_server :
104- target = Client (* (await asyncio .open_connection ("127.0.0.1" , ssl_server .sockets [0 ].getsockname ()[1 ])))
129+ target = Client (
130+ * (
131+ await asyncio .open_connection (
132+ "127.0.0.1" , ssl_server .sockets [0 ].getsockname ()[1 ]
133+ )
134+ )
135+ )
105136 proxying = True
106137 await proxy .connect (client , target , header )
107138 break
@@ -117,40 +148,64 @@ async def _handle_process(client: Client, ssl: bool = False):
117148 logger .debug (traceback .format_exc ())
118149 if not proxying and not client .is_closed ():
119150 client .close ()
151+
152+
120153async def check_ports ():
121154 global ssl_server , server , client_side_ssl , restart , check_port_key
122155 while 1 :
123156 ports : list [tuple [asyncio .Server , ssl .SSLContext | None ]] = []
124- for service in ((server , None ), (ssl_server , client_side_ssl if get_loads () != 0 else None )):
157+ for service in (
158+ (server , None ),
159+ (ssl_server , client_side_ssl if get_loads () != 0 else None ),
160+ ):
125161 if not service [0 ]:
126162 continue
127163 ports .append ((service [0 ], service [1 ]))
128164 closed = False
129165 for port in ports :
130166 try :
131- client = Client (* (await asyncio .open_connection ('127.0.0.1' , port [0 ].sockets [0 ].getsockname ()[1 ], ssl = port [1 ])))
167+ client = Client (
168+ * (
169+ await asyncio .open_connection (
170+ "127.0.0.1" ,
171+ port [0 ].sockets [0 ].getsockname ()[1 ],
172+ ssl = port [1 ],
173+ )
174+ )
175+ )
132176 client .write (check_port_key )
133177 await client .writer .drain ()
134178 key = await client .read (len (check_port_key ), 5 )
135179 except :
136- logger .warn (f"Port { port [0 ].sockets [0 ].getsockname ()[1 ]} is shutdown now! Now restarting the port!" )
180+ logger .warn (
181+ f"Port { port [0 ].sockets [0 ].getsockname ()[1 ]} has been closed! Reopening..."
182+ )
137183 logger .error (traceback .format_exc ())
138184 closed = True
139185 if closed :
140186 restart = True
141187 for port in ports :
142188 port [0 ].close ()
143189 await asyncio .sleep (5 )
190+
191+
144192async def main ():
145193 global ssl_server , server , server_side_ssl , restart
146194 await web .init ()
147195 Timer .delay (check_ports , (), 5 )
148196 while 1 :
149197 try :
150198 server = await asyncio .start_server (_handle , port = PORT )
151- ssl_server = await asyncio .start_server (_handle_ssl , port = 0 if SSL_PORT == PORT else SSL_PORT , ssl = server_side_ssl if get_loads () != 0 else None )
152- logger .info (f"Listening server on { PORT } " )
153- logger .info (f"Listening server on { ssl_server .sockets [0 ].getsockname ()[1 ]} Loaded certificates: { get_loads ()} " )
199+ ssl_server = await asyncio .start_server (
200+ _handle_ssl ,
201+ port = 0 if SSL_PORT == PORT else SSL_PORT ,
202+ ssl = server_side_ssl if get_loads () != 0 else None ,
203+ )
204+ logger .info (f"Listening server on port { PORT } ." )
205+ logger .info (
206+ f"Listening server on { ssl_server .sockets [0 ].getsockname ()[1 ]} ."
207+ )
208+ logger .info (f"Loaded { get_loads ()} certificates!" )
154209 async with server , ssl_server :
155210 await asyncio .gather (server .serve_forever (), ssl_server .serve_forever ())
156211 except asyncio .CancelledError :
@@ -159,8 +214,8 @@ async def main():
159214 server .close ()
160215 restart = False
161216 else :
162- logger .info ("Shutdown web service" )
163- await web .close ()
217+ logger .info ("Shutting down web service... " )
218+ web .close ()
164219 break
165220 except :
166221 if server :
@@ -170,4 +225,4 @@ async def main():
170225
171226
172227def init ():
173- asyncio .run (main ())
228+ asyncio .run (main ())
0 commit comments