@@ -31,7 +31,7 @@ def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
3131
3232
3333class Socket :
34- def __init__ (self , url : str , params : Dict [str , Any ] = {}, hb_interval : int = 5 ) -> None :
34+ def __init__ (self , url : str , auto_reconnect : bool = False , params : Dict [str , Any ] = {}, hb_interval : int = 5 ) -> None :
3535 """
3636 `Socket` is the abstraction for an actual socket connection that receives and 'reroutes' `Message` according to its `topic` and `event`.
3737 Socket-Channel has a 1-many relationship.
@@ -47,6 +47,7 @@ def __init__(self, url: str, params: Dict[str, Any] = {}, hb_interval: int = 5)
4747 self .hb_interval = hb_interval
4848 self .ws_connection : websockets .client .WebSocketClientProtocol
4949 self .kept_alive = False
50+ self .auto_reconnect = auto_reconnect
5051
5152 self .channels = cast (defaultdict [str , List [Channel ]], self .channels )
5253
@@ -79,8 +80,15 @@ async def _listen(self) -> None:
7980 if cl .event == msg .event :
8081 cl .callback (msg .payload )
8182 except websockets .exceptions .ConnectionClosed :
82- logging .exception ("Connection closed" )
83- break
83+ if self .auto_reconnect :
84+ logging .info ("Connection with server closed, trying to reconnect..." )
85+ await self ._connect ()
86+ for topic , channels in self .channels .items ():
87+ for channel in channels :
88+ await channel ._join ()
89+ else :
90+ logging .exception ("Connection with the server closed." )
91+ break
8492
8593 def connect (self ) -> None :
8694 """
@@ -116,8 +124,12 @@ async def _keep_alive(self) -> None:
116124 await self .ws_connection .send (json .dumps (data ))
117125 await asyncio .sleep (self .hb_interval )
118126 except websockets .exceptions .ConnectionClosed :
119- logging .exception ("Connection with server closed" )
120- break
127+ if self .auto_reconnect :
128+ logging .info ("Connection with server closed, trying to reconnect..." )
129+ await self ._connect ()
130+ else :
131+ logging .exception ("Connection with the server closed." )
132+ break
121133
122134 @ensure_connection
123135 def set_channel (self , topic : str ) -> Channel :
0 commit comments