11from __future__ import annotations
22
3+ import functools
34import itertools
45from collections .abc import AsyncIterator
56from collections .abc import Generator
67from collections .abc import Iterator
78from collections .abc import Sequence
9+ from inspect import iscoroutinefunction
810from typing import TYPE_CHECKING
911from typing import Any
12+ from typing import Callable
1013from typing import Union
1114
1215import ydb
2023from .utils import maybe_get_current_trace_id
2124
2225if TYPE_CHECKING :
26+ from .connections import AsyncConnection
27+ from .connections import Connection
28+
2329 ParametersType = dict [
2430 str ,
2531 Union [
@@ -34,6 +40,34 @@ def _get_column_type(type_obj: Any) -> str:
3440 return str (ydb .convert .type_to_native (type_obj ))
3541
3642
43+ def invalidate_cursor_on_ydb_error (func : Callable ) -> Callable :
44+ if iscoroutinefunction (func ):
45+
46+ @functools .wraps (func )
47+ async def awrapper (
48+ self : AsyncCursor , * args : tuple , ** kwargs : dict
49+ ) -> Any :
50+ try :
51+ return await func (self , * args , ** kwargs )
52+ except ydb .Error :
53+ self ._state = CursorStatus .finished
54+ await self ._connection ._invalidate_session ()
55+ raise
56+
57+ return awrapper
58+
59+ @functools .wraps (func )
60+ def wrapper (self : Cursor , * args : tuple , ** kwargs : dict ) -> Any :
61+ try :
62+ return func (self , * args , ** kwargs )
63+ except ydb .Error :
64+ self ._state = CursorStatus .finished
65+ self ._connection ._invalidate_session ()
66+ raise
67+
68+ return wrapper
69+
70+
3771class BufferedCursor :
3872 def __init__ (self ) -> None :
3973 self .arraysize : int = 1
@@ -154,13 +188,15 @@ def _append_table_path_prefix(self, query: str) -> str:
154188class Cursor (BufferedCursor ):
155189 def __init__ (
156190 self ,
191+ connection : Connection ,
157192 session_pool : ydb .QuerySessionPool ,
158193 tx_mode : ydb .BaseQueryTxMode ,
159194 request_settings : ydb .BaseRequestSettings ,
160195 tx_context : ydb .QueryTxContext | None = None ,
161196 table_path_prefix : str = "" ,
162197 ) -> None :
163198 super ().__init__ ()
199+ self ._connection = connection
164200 self ._session_pool = session_pool
165201 self ._tx_mode = tx_mode
166202 self ._request_settings = request_settings
@@ -188,6 +224,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
188224 return settings
189225
190226 @handle_ydb_errors
227+ @invalidate_cursor_on_ydb_error
191228 def _execute_generic_query (
192229 self , query : str , parameters : ParametersType | None = None
193230 ) -> Iterator [ydb .convert .ResultSet ]:
@@ -205,6 +242,7 @@ def callee(
205242 return self ._session_pool .retry_operation_sync (callee )
206243
207244 @handle_ydb_errors
245+ @invalidate_cursor_on_ydb_error
208246 def _execute_session_query (
209247 self ,
210248 query : str ,
@@ -225,6 +263,7 @@ def callee(
225263 return self ._session_pool .retry_operation_sync (callee )
226264
227265 @handle_ydb_errors
266+ @invalidate_cursor_on_ydb_error
228267 def _execute_transactional_query (
229268 self ,
230269 tx_context : ydb .QueryTxContext ,
@@ -283,6 +322,7 @@ def executemany(
283322 self .execute (query , parameters )
284323
285324 @handle_ydb_errors
325+ @invalidate_cursor_on_ydb_error
286326 def nextset (self , replace_current : bool = True ) -> bool :
287327 if self ._stream is None :
288328 return False
@@ -328,13 +368,15 @@ def __exit__(
328368class AsyncCursor (BufferedCursor ):
329369 def __init__ (
330370 self ,
371+ connection : AsyncConnection ,
331372 session_pool : ydb .aio .QuerySessionPool ,
332373 tx_mode : ydb .BaseQueryTxMode ,
333374 request_settings : ydb .BaseRequestSettings ,
334375 tx_context : ydb .aio .QueryTxContext | None = None ,
335376 table_path_prefix : str = "" ,
336377 ) -> None :
337378 super ().__init__ ()
379+ self ._connection = connection
338380 self ._session_pool = session_pool
339381 self ._tx_mode = tx_mode
340382 self ._request_settings = request_settings
@@ -362,6 +404,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
362404 return settings
363405
364406 @handle_ydb_errors
407+ @invalidate_cursor_on_ydb_error
365408 async def _execute_generic_query (
366409 self , query : str , parameters : ParametersType | None = None
367410 ) -> AsyncIterator [ydb .convert .ResultSet ]:
@@ -379,6 +422,7 @@ async def callee(
379422 return await self ._session_pool .retry_operation_async (callee )
380423
381424 @handle_ydb_errors
425+ @invalidate_cursor_on_ydb_error
382426 async def _execute_session_query (
383427 self ,
384428 query : str ,
@@ -399,6 +443,7 @@ async def callee(
399443 return await self ._session_pool .retry_operation_async (callee )
400444
401445 @handle_ydb_errors
446+ @invalidate_cursor_on_ydb_error
402447 async def _execute_transactional_query (
403448 self ,
404449 tx_context : ydb .aio .QueryTxContext ,
@@ -457,6 +502,7 @@ async def executemany(
457502 await self .execute (query , parameters )
458503
459504 @handle_ydb_errors
505+ @invalidate_cursor_on_ydb_error
460506 async def nextset (self , replace_current : bool = True ) -> bool :
461507 if self ._stream is None :
462508 return False
0 commit comments