1313
1414
1515class CoordinationLock :
16- def __init__ (self , client , name : str , node_path : Optional [str ] = None ):
16+ def __init__ (
17+ self ,
18+ client ,
19+ name : str ,
20+ node_path : Optional [str ] = None ,
21+ count : int = 1 ,
22+ timeout_millis : int = 30000 ,
23+ ):
1724 self ._client = client
1825 self ._driver = client ._driver
1926 self ._name = name
2027 self ._node_path = node_path
2128
2229 self ._req_id : Optional [int ] = None
23- self ._count : int = 1
24- self ._timeout_millis : int = 30000
30+ self ._count : int = count
31+ self ._timeout_millis : int = timeout_millis
2532 self ._next_req_id : int = 1
2633
2734 self ._closed : asyncio .Event = asyncio .Event ()
@@ -32,6 +39,7 @@ def __init__(self, client, name: str, node_path: Optional[str] = None):
3239 self ._session_ready : asyncio .Event = asyncio .Event ()
3340 self ._reconnector = CoordinationReconnector (self )
3441
42+ self ._wait_timeout : float = self ._timeout_millis / 1000.0
3543
3644 def next_req_id (self ) -> int :
3745 r = self ._next_req_id
@@ -43,22 +51,51 @@ async def send(self, req):
4351 raise issues .Error ("Stream is not started yet" )
4452 await self ._stream .send (req )
4553
46- async def __aenter__ (self ):
47- if self .session_id is None :
48- if not self ._node_path :
49- raise issues .Error ("node_path is not set for CoordinationLock" )
50-
51- await self ._request_queue .put (
52- SessionStart (
53- path = self ._node_path ,
54- session_id = 0 ,
55- timeout_millis = 30000 ,
56- ).to_proto ()
57- )
54+ async def _start_session (self ):
55+ if self .session_id is not None :
56+ return
57+
58+ if not self ._node_path :
59+ raise issues .Error ("node_path is not set for CoordinationLock" )
5860
59- self ._reconnector .start ()
61+ await self ._request_queue .put (
62+ SessionStart (
63+ path = self ._node_path ,
64+ session_id = 0 ,
65+ timeout_millis = self ._timeout_millis ,
66+ ).to_proto ()
67+ )
68+
69+ self ._reconnector .start ()
70+ await self ._session_ready .wait ()
71+
72+ async def _stop_session (self ):
73+ self ._closed .set ()
74+ if self ._stream :
75+ await self ._stream .close ()
76+ self ._stream = None
6077
61- await self ._session_ready .wait ()
78+ await self ._reconnector .stop ()
79+ self .session_id = None
80+ self ._node_path = None
81+
82+ async def _wait_for_acquire_response (self ):
83+ try :
84+ while True :
85+ resp = await asyncio .wait_for (
86+ self ._stream ._incoming_queue .get (),
87+ timeout = self ._wait_timeout ,
88+ )
89+ acquire_resp = FromServer .from_proto (resp ).acquire_semaphore_result
90+ if acquire_resp and acquire_resp .req_id == self ._req_id :
91+ return acquire_resp
92+ except asyncio .TimeoutError :
93+ raise issues .Error (
94+ f"Timeout waiting for lock { self ._name } acquisition"
95+ )
96+
97+ async def __aenter__ (self ):
98+ await self ._start_session ()
6299
63100 self ._req_id = self .next_req_id ()
64101
@@ -78,29 +115,18 @@ async def __aenter__(self):
78115 else :
79116 raise issues .Error (f"Failed to acquire lock: { resp .issues } " )
80117
81- async def _wait_for_acquire_response (self ):
82- try :
83- while True :
84- resp = await asyncio .wait_for (
85- self ._stream ._incoming_queue .get (),
86- timeout = 30.0 ,
87- )
88- acquire_resp = FromServer .from_proto (resp ).acquire_semaphore_result
89- if acquire_resp and acquire_resp .req_id == self ._req_id :
90- return acquire_resp
91- except asyncio .TimeoutError :
92- raise issues .Error (f"Timeout waiting for lock { self ._name } acquisition" )
93-
94118 async def __aexit__ (self , exc_type , exc , tb ):
95119 if self ._req_id is not None :
96- req = ReleaseSemaphore (req_id = self ._req_id , name = self ._name )
97- await self .send (req )
120+ try :
121+ req = ReleaseSemaphore (req_id = self ._req_id , name = self ._name )
122+ await self .send (req )
123+ except issues .Error :
124+ pass
98125
99- self ._closed .set ()
100- if self ._stream :
101- await self ._stream .close ()
102- self ._stream = None
126+ await self ._stop_session ()
103127
104- await self ._reconnector .stop ()
105- self .session_id = None
106- self ._node_path = None
128+ async def acquire (self ):
129+ return await self .__aenter__ ()
130+
131+ async def release (self ):
132+ await self .__aexit__ (None , None , None )
0 commit comments