Skip to content

Commit 31e4fb6

Browse files
committed
fix lock logic
1 parent 89fe18d commit 31e4fb6

File tree

1 file changed

+65
-39
lines changed

1 file changed

+65
-39
lines changed

ydb/aio/coordination/lock.py

Lines changed: 65 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,22 @@
1313

1414

1515
class CoordinationLock:
16-
def __init__(self, client, name: str, node_path: Optional[str] = None):
16+
def __init__(
17+
self,
18+
client,
19+
name: str,
20+
node_path: Optional[str] = None,
21+
count: int = 1,
22+
timeout_millis: int = 30000,
23+
):
1724
self._client = client
1825
self._driver = client._driver
1926
self._name = name
2027
self._node_path = node_path
2128

2229
self._req_id: Optional[int] = None
23-
self._count: int = 1
24-
self._timeout_millis: int = 30000
30+
self._count: int = count
31+
self._timeout_millis: int = timeout_millis
2532
self._next_req_id: int = 1
2633

2734
self._closed: asyncio.Event = asyncio.Event()
@@ -32,6 +39,7 @@ def __init__(self, client, name: str, node_path: Optional[str] = None):
3239
self._session_ready: asyncio.Event = asyncio.Event()
3340
self._reconnector = CoordinationReconnector(self)
3441

42+
self._wait_timeout: float = self._timeout_millis / 1000.0
3543

3644
def next_req_id(self) -> int:
3745
r = self._next_req_id
@@ -43,22 +51,51 @@ async def send(self, req):
4351
raise issues.Error("Stream is not started yet")
4452
await self._stream.send(req)
4553

46-
async def __aenter__(self):
47-
if self.session_id is None:
48-
if not self._node_path:
49-
raise issues.Error("node_path is not set for CoordinationLock")
50-
51-
await self._request_queue.put(
52-
SessionStart(
53-
path=self._node_path,
54-
session_id=0,
55-
timeout_millis=30000,
56-
).to_proto()
57-
)
54+
async def _start_session(self):
55+
if self.session_id is not None:
56+
return
57+
58+
if not self._node_path:
59+
raise issues.Error("node_path is not set for CoordinationLock")
5860

59-
self._reconnector.start()
61+
await self._request_queue.put(
62+
SessionStart(
63+
path=self._node_path,
64+
session_id=0,
65+
timeout_millis=self._timeout_millis,
66+
).to_proto()
67+
)
68+
69+
self._reconnector.start()
70+
await self._session_ready.wait()
71+
72+
async def _stop_session(self):
73+
self._closed.set()
74+
if self._stream:
75+
await self._stream.close()
76+
self._stream = None
6077

61-
await self._session_ready.wait()
78+
await self._reconnector.stop()
79+
self.session_id = None
80+
self._node_path = None
81+
82+
async def _wait_for_acquire_response(self):
83+
try:
84+
while True:
85+
resp = await asyncio.wait_for(
86+
self._stream._incoming_queue.get(),
87+
timeout=self._wait_timeout,
88+
)
89+
acquire_resp = FromServer.from_proto(resp).acquire_semaphore_result
90+
if acquire_resp and acquire_resp.req_id == self._req_id:
91+
return acquire_resp
92+
except asyncio.TimeoutError:
93+
raise issues.Error(
94+
f"Timeout waiting for lock {self._name} acquisition"
95+
)
96+
97+
async def __aenter__(self):
98+
await self._start_session()
6299

63100
self._req_id = self.next_req_id()
64101

@@ -78,29 +115,18 @@ async def __aenter__(self):
78115
else:
79116
raise issues.Error(f"Failed to acquire lock: {resp.issues}")
80117

81-
async def _wait_for_acquire_response(self):
82-
try:
83-
while True:
84-
resp = await asyncio.wait_for(
85-
self._stream._incoming_queue.get(),
86-
timeout=30.0,
87-
)
88-
acquire_resp = FromServer.from_proto(resp).acquire_semaphore_result
89-
if acquire_resp and acquire_resp.req_id == self._req_id:
90-
return acquire_resp
91-
except asyncio.TimeoutError:
92-
raise issues.Error(f"Timeout waiting for lock {self._name} acquisition")
93-
94118
async def __aexit__(self, exc_type, exc, tb):
95119
if self._req_id is not None:
96-
req = ReleaseSemaphore(req_id=self._req_id, name=self._name)
97-
await self.send(req)
120+
try:
121+
req = ReleaseSemaphore(req_id=self._req_id, name=self._name)
122+
await self.send(req)
123+
except issues.Error:
124+
pass
98125

99-
self._closed.set()
100-
if self._stream:
101-
await self._stream.close()
102-
self._stream = None
126+
await self._stop_session()
103127

104-
await self._reconnector.stop()
105-
self.session_id = None
106-
self._node_path = None
128+
async def acquire(self):
129+
return await self.__aenter__()
130+
131+
async def release(self):
132+
await self.__aexit__(None, None, None)

0 commit comments

Comments
 (0)