1+ from dataclasses import dataclass
2+ from enum import Enum
3+ import os
4+ import traceback
5+ from .config import Config
6+ from .timer import Timer
7+ from .utils import Client
8+ from .certificate import *
9+ from . import web
10+ from .logger import logger
11+
12+ import asyncio
13+ import ssl
14+ from typing import Optional
15+
16+
17+ class Protocol (Enum ):
18+ HTTP = "HTTP"
19+ Unknown = "Unknown"
20+ DETECT = "Detect"
21+ @staticmethod
22+ def get (data : bytes ):
23+ if b'HTTP/1.1' in data :
24+ return Protocol .HTTP
25+ if check_port_key == data :
26+ return Protocol .DETECT
27+ return Protocol .Unknown
28+ @dataclass
29+ class ProxyClient :
30+ proxy : 'Proxy'
31+ origin : Client
32+ target : Client
33+ before : bytes = b''
34+ closed : bool = False
35+ def start (self ):
36+ self ._task_origin = Timer .delay (self .process_origin , (), 0 )
37+ self ._task_target = Timer .delay (self .process_target , (), 0 )
38+ async def process_origin (self ):
39+ try :
40+ self .target .write (self .before )
41+ while (buffer := await self .origin .read (IO_BUFFER , timeout = TIMEOUT )) and not self .origin .is_closed () and not self .origin .is_closed ():
42+ self .target .write (buffer )
43+ self .before = b''
44+ await self .target .writer .drain ()
45+ except :
46+ ...
47+ self .close ()
48+ async def process_target (self ):
49+ try :
50+ while (buffer := await self .target .read (IO_BUFFER , timeout = TIMEOUT )) and not self .target .is_closed () and not self .target .is_closed ():
51+ self .origin .write (buffer )
52+ await self .origin .writer .drain ()
53+ except :
54+ ...
55+ self .close ()
56+ def close (self ):
57+ if not self .closed :
58+ if not self .origin .is_closed ():
59+ self .origin .close ()
60+ if not self .target .is_closed ():
61+ self .target .close ()
62+ self .closed = True
63+ self .proxy .disconnect (self )
64+ class Proxy :
65+ def __init__ (self ) -> None :
66+ self ._tables : list [ProxyClient ] = []
67+ async def connect (self , origin : Client , target : Client , before : bytes ):
68+ client = ProxyClient (self , origin , target , before )
69+ self ._tables .append (client )
70+ client .start ()
71+ def disconnect (self , client : ProxyClient ):
72+ if client not in self ._tables :
73+ return
74+ self ._tables .remove (client )
75+
76+ ssl_server : Optional [asyncio .Server ] = None
77+ server : Optional [asyncio .Server ] = None
78+ proxy : Proxy = Proxy ()
79+ restart = False
80+ check_port_key = os .urandom (8 )
81+ PORT : int = Config .get ("web.port" )
82+ TIMEOUT : int = Config .get ("advanced.timeout" )
83+ SSL_PORT : int = Config .get ("web.ssl_port" )
84+ PROTOCOL_HEADER_BYTES = Config .get ("advanced.header_bytes" )
85+ IO_BUFFER : int = Config .get ("advanced.io_buffer" )
86+
87+ async def _handle_ssl (reader : asyncio .StreamReader , writer : asyncio .StreamWriter ):
88+ return await _handle_process (Client (reader , writer ), True )
89+
90+ async def _handle (reader : asyncio .StreamReader , writer : asyncio .StreamWriter ):
91+ return await _handle_process (Client (reader , writer ))
92+
93+ async def _handle_process (client : Client , ssl : bool = False ):
94+ global ssl_server
95+ proxying = False
96+ try :
97+ while (header := await client .read (PROTOCOL_HEADER_BYTES , timeout = 30 )) and not client .is_closed ():
98+ protocol = Protocol .get (header )
99+ if protocol == Protocol .DETECT :
100+ client .write (check_port_key )
101+ await client .writer .drain ()
102+ break
103+ if protocol == Protocol .Unknown and not ssl and ssl_server :
104+ target = Client (* (await asyncio .open_connection ("127.0.0.1" , ssl_server .sockets [0 ].getsockname ()[1 ])))
105+ proxying = True
106+ await proxy .connect (client , target , header )
107+ break
108+ elif protocol == Protocol .HTTP :
109+ await web .handle (header , client )
110+ except (
111+ TimeoutError ,
112+ asyncio .exceptions .IncompleteReadError ,
113+ ConnectionResetError ,
114+ ):
115+ ...
116+ except :
117+ logger .debug (traceback .format_exc ())
118+ if not proxying and not client .is_closed ():
119+ client .close ()
120+ async def check_ports ():
121+ global ssl_server , server , client_side_ssl , restart , check_port_key
122+ while 1 :
123+ ports : list [tuple [asyncio .Server , ssl .SSLContext | None ]] = []
124+ for service in ((server , None ), (ssl_server , client_side_ssl if get_loads () != 0 else None )):
125+ if not service [0 ]:
126+ continue
127+ ports .append ((service [0 ], service [1 ]))
128+ closed = False
129+ for port in ports :
130+ try :
131+ client = Client (* (await asyncio .open_connection ('127.0.0.1' , port [0 ].sockets [0 ].getsockname ()[1 ], ssl = port [1 ])))
132+ client .write (check_port_key )
133+ await client .writer .drain ()
134+ key = await client .read (len (check_port_key ), 5 )
135+ except :
136+ logger .warn (f"Port { port [0 ].sockets [0 ].getsockname ()[1 ]} is shutdown now! Now restarting the port!" )
137+ logger .error (traceback .format_exc ())
138+ closed = True
139+ if closed :
140+ restart = True
141+ for port in ports :
142+ port [0 ].close ()
143+ await asyncio .sleep (5 )
144+ async def main ():
145+ global ssl_server , server , server_side_ssl , restart
146+ await web .init ()
147+ Timer .delay (check_ports , (), 5 )
148+ while 1 :
149+ try :
150+ server = await asyncio .start_server (_handle , port = PORT )
151+ ssl_server = await asyncio .start_server (_handle_ssl , port = 0 if SSL_PORT == PORT else SSL_PORT , ssl = server_side_ssl if get_loads () != 0 else None )
152+ logger .info (f"Listening server on { PORT } " )
153+ logger .info (f"Listening server on { ssl_server .sockets [0 ].getsockname ()[1 ]} Loaded certificates: { get_loads ()} " )
154+ async with server , ssl_server :
155+ await asyncio .gather (server .serve_forever (), ssl_server .serve_forever ())
156+ except asyncio .CancelledError :
157+ if restart :
158+ if server :
159+ server .close ()
160+ restart = False
161+ else :
162+ logger .info ("Shutdown web service" )
163+ await web .close ()
164+ break
165+ except :
166+ if server :
167+ server .close ()
168+ logger .error (traceback .format_exc ())
169+ await asyncio .sleep (2 )
170+
171+
172+ def init ():
173+ asyncio .run (main ())
0 commit comments