Skip to content

Commit 194942f

Browse files
committed
refactor logic lock + reconnecor + stream - > lock should be resosible only for lock logic (release and so on) , stream about session work -> making stream_stream canal and so on, reconnector -> make this stream stable
1 parent 31e4fb6 commit 194942f

File tree

3 files changed

+106
-52
lines changed

3 files changed

+106
-52
lines changed

ydb/aio/coordination/lock.py

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
AcquireSemaphore,
77
ReleaseSemaphore,
88
FromServer,
9-
SessionStart,
109
)
1110
from ydb.aio.coordination.stream import CoordinationStream
1211
from ydb.aio.coordination.reconnector import CoordinationReconnector
@@ -31,13 +30,15 @@ def __init__(
3130
self._timeout_millis: int = timeout_millis
3231
self._next_req_id: int = 1
3332

34-
self._closed: asyncio.Event = asyncio.Event()
3533
self._request_queue: asyncio.Queue = asyncio.Queue()
3634
self._stream: Optional[CoordinationStream] = None
37-
self._reader_task: Optional[asyncio.Task] = None
38-
self.session_id: Optional[int] = None
39-
self._session_ready: asyncio.Event = asyncio.Event()
40-
self._reconnector = CoordinationReconnector(self)
35+
36+
self._reconnector = CoordinationReconnector(
37+
driver=self._driver,
38+
request_queue=self._request_queue,
39+
node_path=self._node_path,
40+
timeout_millis=self._timeout_millis,
41+
)
4142

4243
self._wait_timeout: float = self._timeout_millis / 1000.0
4344

@@ -51,33 +52,17 @@ async def send(self, req):
5152
raise issues.Error("Stream is not started yet")
5253
await self._stream.send(req)
5354

54-
async def _start_session(self):
55-
if self.session_id is not None:
55+
async def _ensure_session(self):
56+
if self._stream is not None and self._stream.session_id is not None:
5657
return
5758

5859
if not self._node_path:
5960
raise issues.Error("node_path is not set for CoordinationLock")
6061

61-
await self._request_queue.put(
62-
SessionStart(
63-
path=self._node_path,
64-
session_id=0,
65-
timeout_millis=self._timeout_millis,
66-
).to_proto()
67-
)
68-
6962
self._reconnector.start()
70-
await self._session_ready.wait()
63+
await self._reconnector.wait_ready()
7164

72-
async def _stop_session(self):
73-
self._closed.set()
74-
if self._stream:
75-
await self._stream.close()
76-
self._stream = None
77-
78-
await self._reconnector.stop()
79-
self.session_id = None
80-
self._node_path = None
65+
self._stream = self._reconnector.get_stream()
8166

8267
async def _wait_for_acquire_response(self):
8368
try:
@@ -95,7 +80,7 @@ async def _wait_for_acquire_response(self):
9580
)
9681

9782
async def __aenter__(self):
98-
await self._start_session()
83+
await self._ensure_session()
9984

10085
self._req_id = self.next_req_id()
10186

@@ -123,7 +108,9 @@ async def __aexit__(self, exc_type, exc, tb):
123108
except issues.Error:
124109
pass
125110

126-
await self._stop_session()
111+
await self._reconnector.stop()
112+
self._stream = None
113+
self._node_path = None
127114

128115
async def acquire(self):
129116
return await self.__aenter__()

ydb/aio/coordination/reconnector.py

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,51 +6,99 @@
66

77

88
class CoordinationReconnector:
9-
def __init__(self, client):
10-
self._client = client
9+
def __init__(
10+
self,
11+
driver,
12+
request_queue: asyncio.Queue,
13+
node_path: str,
14+
timeout_millis: int,
15+
):
16+
self._driver = driver
17+
self._request_queue = request_queue
18+
self._node_path = node_path
19+
self._timeout_millis = timeout_millis
20+
1121
self._task: Optional[asyncio.Task] = None
22+
self._stream: Optional[CoordinationStream] = None
23+
24+
self._ready = asyncio.Event()
25+
self._stopped = False
26+
1227
self._first_error: asyncio.Future = asyncio.get_running_loop().create_future()
1328
self._state_changed = asyncio.Event()
1429

1530
def start(self):
31+
if self._stopped:
32+
return
1633
if self._task is None or self._task.done():
1734
self._task = asyncio.create_task(self._connection_loop())
1835

1936
async def stop(self):
37+
self._stopped = True
38+
2039
if self._task:
2140
self._task.cancel()
2241
with contextlib.suppress(asyncio.CancelledError):
2342
await self._task
2443
self._task = None
2544

26-
if self._client._stream:
27-
await self._client._stream.close()
28-
self._client._stream = None
45+
if self._stream:
46+
await self._stream.close()
47+
self._stream = None
48+
49+
self._ready.clear()
50+
51+
async def wait_ready(self):
52+
await self._ready.wait()
53+
54+
def get_stream(self) -> CoordinationStream:
55+
if self._stream is None or self._stream.session_id is None:
56+
raise RuntimeError("Coordination stream is not ready")
57+
return self._stream
2958

3059
async def _connection_loop(self):
3160
attempt = 0
3261
backoff = 0.1
33-
while not self._client._closed.is_set():
62+
63+
while not self._stopped:
3464
try:
35-
self._client._stream = CoordinationStream(
36-
self._client._driver, self._client._request_queue
65+
stream = CoordinationStream(
66+
self._driver,
67+
self._request_queue,
68+
)
69+
70+
await stream.start_session(
71+
self._node_path,
72+
self._timeout_millis,
73+
)
74+
75+
self._stream = stream
76+
self._ready.set()
77+
78+
await asyncio.wait(
79+
stream._background_tasks,
80+
return_when=asyncio.FIRST_EXCEPTION,
3781
)
38-
await self._client._stream.start()
39-
self._client.session_id = self._client._stream.session_id
40-
self._client._session_ready.set()
4182

42-
await asyncio.wait(self._client._stream._background_tasks, return_when=asyncio.FIRST_EXCEPTION)
4383
except asyncio.CancelledError:
44-
raise
84+
break
85+
4586
except Exception as exc:
46-
self._client.session_id = None
47-
self._client._session_ready.clear()
87+
self._ready.clear()
88+
self._stream = None
89+
4890
if not self._first_error.done():
4991
self._first_error.set_result(exc)
92+
self._state_changed.set()
93+
94+
if self._stopped:
95+
break
96+
5097
await asyncio.sleep(backoff)
5198
attempt += 1
5299
backoff = min(backoff * 2, 3.0)
100+
53101
finally:
54-
if self._client._stream:
55-
await self._client._stream.close()
56-
self._client._stream = None
102+
if self._stream:
103+
await self._stream.close()
104+
self._stream = None

ydb/aio/coordination/stream.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Set, Optional
44

55
from ydb import issues, _apis
6-
from ydb._grpc.grpcwrapper.ydb_coordination import FromServer, Ping
6+
from ydb._grpc.grpcwrapper.ydb_coordination import FromServer, Ping, SessionStart
77

88

99
class CoordinationStream:
@@ -19,7 +19,21 @@ def __init__(self, driver: "ydb.aio.Driver", request_queue: asyncio.Queue):
1919
self.session_id: Optional[int] = None
2020
self._started: bool = False
2121

22-
async def start(self):
22+
async def start_session(self, path: str, timeout_millis: int):
23+
if self._started:
24+
raise issues.Error("CoordinationStream already started")
25+
26+
await self.send(
27+
SessionStart(
28+
path=path,
29+
session_id=0,
30+
timeout_millis=timeout_millis,
31+
).to_proto()
32+
)
33+
34+
await self._start_internal()
35+
36+
async def _start_internal(self):
2337
if self._started:
2438
raise issues.Error("CoordinationStream already started")
2539
self._started = True
@@ -30,13 +44,16 @@ async def request_gen():
3044
yield req
3145

3246
self._stream = await self._driver(
33-
request_gen(), _apis.CoordinationService.Stub, _apis.CoordinationService.Session
47+
request_gen(),
48+
_apis.CoordinationService.Stub,
49+
_apis.CoordinationService.Session,
3450
)
3551

3652
try:
3753
async for resp in self._stream:
38-
if FromServer.from_proto(resp).session_started:
39-
self.session_id = FromServer.from_proto(resp).session_started
54+
fs = FromServer.from_proto(resp)
55+
if fs.session_started:
56+
self.session_id = fs.session_started
4057
break
4158
except Exception as exc:
4259
self._set_first_error(exc)
@@ -84,6 +101,8 @@ async def close(self):
84101
self._stream.close()
85102
except Exception:
86103
pass
104+
105+
self.session_id = None
87106
self._state_changed.set()
88107

89108
def _set_first_error(self, exc: Exception):
@@ -98,4 +117,4 @@ def _get_first_error(self):
98117
def _check_error(self):
99118
err = self._get_first_error()
100119
if err:
101-
raise err
120+
raise err

0 commit comments

Comments
 (0)