1
1
from __future__ import annotations
2
2
3
+ import functools
3
4
import itertools
4
5
from collections .abc import AsyncIterator
5
6
from collections .abc import Generator
6
7
from collections .abc import Iterator
7
8
from collections .abc import Sequence
9
+ from inspect import iscoroutinefunction
8
10
from typing import TYPE_CHECKING
9
11
from typing import Any
12
+ from typing import Callable
10
13
from typing import Union
11
14
12
15
import ydb
20
23
from .utils import maybe_get_current_trace_id
21
24
22
25
if TYPE_CHECKING :
26
+ from .connections import AsyncConnection
27
+ from .connections import Connection
28
+
23
29
ParametersType = dict [
24
30
str ,
25
31
Union [
@@ -34,6 +40,34 @@ def _get_column_type(type_obj: Any) -> str:
34
40
return str (ydb .convert .type_to_native (type_obj ))
35
41
36
42
43
+ def invalidate_cursor_on_ydb_error (func : Callable ) -> Callable :
44
+ if iscoroutinefunction (func ):
45
+
46
+ @functools .wraps (func )
47
+ async def awrapper (
48
+ self : AsyncCursor , * args : tuple , ** kwargs : dict
49
+ ) -> Any :
50
+ try :
51
+ return await func (self , * args , ** kwargs )
52
+ except ydb .Error :
53
+ self ._state = CursorStatus .finished
54
+ await self ._connection ._invalidate_session ()
55
+ raise
56
+
57
+ return awrapper
58
+
59
+ @functools .wraps (func )
60
+ def wrapper (self : Cursor , * args : tuple , ** kwargs : dict ) -> Any :
61
+ try :
62
+ return func (self , * args , ** kwargs )
63
+ except ydb .Error :
64
+ self ._state = CursorStatus .closed
65
+ self ._connection ._invalidate_session ()
66
+ raise
67
+
68
+ return wrapper
69
+
70
+
37
71
class BufferedCursor :
38
72
def __init__ (self ) -> None :
39
73
self .arraysize : int = 1
@@ -154,13 +188,15 @@ def _append_table_path_prefix(self, query: str) -> str:
154
188
class Cursor (BufferedCursor ):
155
189
def __init__ (
156
190
self ,
191
+ connection : Connection ,
157
192
session_pool : ydb .QuerySessionPool ,
158
193
tx_mode : ydb .BaseQueryTxMode ,
159
194
request_settings : ydb .BaseRequestSettings ,
160
195
tx_context : ydb .QueryTxContext | None = None ,
161
196
table_path_prefix : str = "" ,
162
197
) -> None :
163
198
super ().__init__ ()
199
+ self ._connection = connection
164
200
self ._session_pool = session_pool
165
201
self ._tx_mode = tx_mode
166
202
self ._request_settings = request_settings
@@ -188,6 +224,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
188
224
return settings
189
225
190
226
@handle_ydb_errors
227
+ @invalidate_cursor_on_ydb_error
191
228
def _execute_generic_query (
192
229
self , query : str , parameters : ParametersType | None = None
193
230
) -> Iterator [ydb .convert .ResultSet ]:
@@ -205,6 +242,7 @@ def callee(
205
242
return self ._session_pool .retry_operation_sync (callee )
206
243
207
244
@handle_ydb_errors
245
+ @invalidate_cursor_on_ydb_error
208
246
def _execute_session_query (
209
247
self ,
210
248
query : str ,
@@ -225,6 +263,7 @@ def callee(
225
263
return self ._session_pool .retry_operation_sync (callee )
226
264
227
265
@handle_ydb_errors
266
+ @invalidate_cursor_on_ydb_error
228
267
def _execute_transactional_query (
229
268
self ,
230
269
tx_context : ydb .QueryTxContext ,
@@ -283,6 +322,7 @@ def executemany(
283
322
self .execute (query , parameters )
284
323
285
324
@handle_ydb_errors
325
+ @invalidate_cursor_on_ydb_error
286
326
def nextset (self , replace_current : bool = True ) -> bool :
287
327
if self ._stream is None :
288
328
return False
@@ -328,13 +368,15 @@ def __exit__(
328
368
class AsyncCursor (BufferedCursor ):
329
369
def __init__ (
330
370
self ,
371
+ connection : AsyncConnection ,
331
372
session_pool : ydb .aio .QuerySessionPool ,
332
373
tx_mode : ydb .BaseQueryTxMode ,
333
374
request_settings : ydb .BaseRequestSettings ,
334
375
tx_context : ydb .aio .QueryTxContext | None = None ,
335
376
table_path_prefix : str = "" ,
336
377
) -> None :
337
378
super ().__init__ ()
379
+ self ._connection = connection
338
380
self ._session_pool = session_pool
339
381
self ._tx_mode = tx_mode
340
382
self ._request_settings = request_settings
@@ -362,6 +404,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
362
404
return settings
363
405
364
406
@handle_ydb_errors
407
+ @invalidate_cursor_on_ydb_error
365
408
async def _execute_generic_query (
366
409
self , query : str , parameters : ParametersType | None = None
367
410
) -> AsyncIterator [ydb .convert .ResultSet ]:
@@ -379,6 +422,7 @@ async def callee(
379
422
return await self ._session_pool .retry_operation_async (callee )
380
423
381
424
@handle_ydb_errors
425
+ @invalidate_cursor_on_ydb_error
382
426
async def _execute_session_query (
383
427
self ,
384
428
query : str ,
@@ -399,6 +443,7 @@ async def callee(
399
443
return await self ._session_pool .retry_operation_async (callee )
400
444
401
445
@handle_ydb_errors
446
+ @invalidate_cursor_on_ydb_error
402
447
async def _execute_transactional_query (
403
448
self ,
404
449
tx_context : ydb .aio .QueryTxContext ,
@@ -457,6 +502,7 @@ async def executemany(
457
502
await self .execute (query , parameters )
458
503
459
504
@handle_ydb_errors
505
+ @invalidate_cursor_on_ydb_error
460
506
async def nextset (self , replace_current : bool = True ) -> bool :
461
507
if self ._stream is None :
462
508
return False
0 commit comments