Skip to content

Remove default timeout from topic stream #693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ydb/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DEFAULT_INITIAL_RESPONSE_TIMEOUT = 600
DEFAULT_LONG_STREAM_TIMEOUT = 31536000 # year
33 changes: 29 additions & 4 deletions ydb/_grpc/grpcwrapper/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from ..common.protos import ydb_topic_pb2, ydb_issue_message_pb2

from ... import issues, connection
from ...settings import BaseRequestSettings
from ..._constants import DEFAULT_LONG_STREAM_TIMEOUT


class IFromProto(abc.ABC):
Expand Down Expand Up @@ -131,7 +133,7 @@ async def __anext__(self):

class IGrpcWrapperAsyncIO(abc.ABC):
@abc.abstractmethod
async def receive(self) -> Any:
async def receive(self, timeout: Optional[int] = None) -> Any:
...

@abc.abstractmethod
Expand Down Expand Up @@ -161,6 +163,13 @@ def __init__(self, convert_server_grpc_to_wrapper):
self._stream_call = None
self._wait_executor = None

self._stream_settings: BaseRequestSettings = (
BaseRequestSettings()
.with_operation_timeout(DEFAULT_LONG_STREAM_TIMEOUT)
.with_cancel_after(DEFAULT_LONG_STREAM_TIMEOUT)
.with_timeout(DEFAULT_LONG_STREAM_TIMEOUT)
)

def __del__(self):
self._clean_executor(wait=False)

Expand Down Expand Up @@ -188,6 +197,7 @@ async def _start_asyncio_driver(self, driver: DriverIO, stub, method):
requests_iterator,
stub,
method,
settings=self._stream_settings,
)
self._stream_call = stream_call
self.from_server_grpc = stream_call.__aiter__()
Expand All @@ -196,14 +206,29 @@ async def _start_sync_driver(self, driver: Driver, stub, method):
requests_iterator = AsyncQueueToSyncIteratorAsyncIO(self.from_client_grpc)
self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

stream_call = await to_thread(driver, requests_iterator, stub, method, executor=self._wait_executor)
stream_call = await to_thread(
driver,
requests_iterator,
stub,
method,
executor=self._wait_executor,
settings=self._stream_settings,
)
self._stream_call = stream_call
self.from_server_grpc = SyncToAsyncIterator(stream_call.__iter__(), self._wait_executor)

async def receive(self) -> Any:
async def receive(self, timeout: Optional[int] = None) -> Any:
# todo handle grpc exceptions and convert it to internal exceptions
try:
grpc_message = await self.from_server_grpc.__anext__()
if timeout is None:
grpc_message = await self.from_server_grpc.__anext__()
else:

async def get_response():
return await self.from_server_grpc.__anext__()

grpc_message = await asyncio.wait_for(get_response(), timeout)

except (grpc.RpcError, grpc.aio.AioRpcError) as e:
raise connection._rpc_error_handler(self._connection_state, e)

Expand Down
2 changes: 1 addition & 1 deletion ydb/_topic_common/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
self.from_client = asyncio.Queue()
self._closed = False

async def receive(self) -> typing.Any:
async def receive(self, timeout: typing.Optional[int] = None) -> typing.Any:
if self._closed:
raise Exception("read from closed StreamMock")

Expand Down
10 changes: 9 additions & 1 deletion ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
if typing.TYPE_CHECKING:
from ..query.transaction import BaseQueryTxContext

from .._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -490,7 +492,13 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess
logger.debug("reader stream %s send init request", self._id)

stream.write(StreamReadMessage.FromClient(client_message=init_message))
init_response = await stream.receive() # type: StreamReadMessage.FromServer
try:
init_response = await stream.receive(
timeout=DEFAULT_INITIAL_RESPONSE_TIMEOUT
) # type: StreamReadMessage.FromServer
except asyncio.TimeoutError:
raise TopicReaderError("Timeout waiting for init response")

if isinstance(init_response.server_message, StreamReadMessage.InitResponse):
self._session_id = init_response.server_message.session_id
logger.debug("reader stream %s initialized session=%s", self._id, self._session_id)
Expand Down
44 changes: 43 additions & 1 deletion ydb/_topic_reader/topic_reader_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from . import datatypes, topic_reader_asyncio
from .datatypes import PublicBatch, PublicMessage
from .topic_reader import PublicReaderSettings
from .topic_reader_asyncio import ReaderStream, ReaderReconnector
from .topic_reader_asyncio import ReaderStream, ReaderReconnector, TopicReaderError
from .._grpc.grpcwrapper.common_utils import SupportedDriverType, ServerStatus
from .._grpc.grpcwrapper.ydb_topic import (
StreamReadMessage,
Expand All @@ -36,6 +36,8 @@
else:
from .._grpc.common.protos import ydb_status_codes_pb2

from .._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT


@pytest.fixture(autouse=True)
def handle_exceptions(event_loop):
Expand Down Expand Up @@ -1475,6 +1477,46 @@ def logged():

await wait_condition(logged)

async def test_init_timeout_parameter(self, stream, default_reader_settings):
"""Test that ReaderStream._start calls stream.receive with timeout=10"""
reader = ReaderStream(self.default_reader_reconnector_id, default_reader_settings)
init_message = default_reader_settings._init_message()

# Mock stream.receive to check if timeout is passed
with mock.patch.object(stream, "receive") as mock_receive:
mock_receive.return_value = StreamReadMessage.FromServer(
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
server_message=StreamReadMessage.InitResponse(session_id="test_session"),
)

await reader._start(stream, init_message)

# Verify that receive was called with timeout
mock_receive.assert_called_with(timeout=DEFAULT_INITIAL_RESPONSE_TIMEOUT)

await reader.close(False)

async def test_init_timeout_behavior(self, stream, default_reader_settings):
"""Test that ReaderStream._start raises TopicReaderError when receive times out"""
reader = ReaderStream(self.default_reader_reconnector_id, default_reader_settings)
init_message = default_reader_settings._init_message()

# Mock stream.receive to directly raise TimeoutError when called with timeout
async def timeout_receive(timeout=None):
if timeout == DEFAULT_INITIAL_RESPONSE_TIMEOUT:
raise asyncio.TimeoutError("Simulated timeout")
return StreamReadMessage.FromServer(
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
server_message=StreamReadMessage.InitResponse(session_id="test_session"),
)

with mock.patch.object(stream, "receive", side_effect=timeout_receive):
# Should raise TopicReaderError with timeout message
with pytest.raises(TopicReaderError, match="Timeout waiting for init response"):
await reader._start(stream, init_message)

await reader.close(False)


@pytest.mark.asyncio
class TestReaderReconnector:
Expand Down
8 changes: 7 additions & 1 deletion ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
if typing.TYPE_CHECKING:
from ..query.transaction import BaseQueryTxContext

from .._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -799,7 +801,11 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes
logger.debug("writer stream %s send init request", self._id)
stream.write(StreamWriteMessage.FromClient(init_message))

resp = await stream.receive()
try:
resp = await stream.receive(timeout=DEFAULT_INITIAL_RESPONSE_TIMEOUT)
except asyncio.TimeoutError:
raise TopicWriterError("Timeout waiting for init response")

self._ensure_ok(resp)
if not isinstance(resp, StreamWriteMessage.InitResponse):
raise TopicWriterError("Unexpected answer for init request: %s" % resp)
Expand Down
51 changes: 51 additions & 0 deletions ydb/_topic_writer/topic_writer_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

from ..credentials import AnonymousCredentials

from .._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT


FAKE_TRANSACTION_IDENTITY = TransactionIdentity(
tx_id="transaction_id",
Expand Down Expand Up @@ -231,6 +233,55 @@ async def test_update_token(self, stream: StreamMock):

await writer.close()

async def test_init_timeout_parameter(self, stream):
"""Test that WriterAsyncIOStream._start calls stream.receive with timeout=10"""
writer_id = 1
settings = WriterSettings(PublicWriterSettings("test-topic", "test-producer"))

# Mock stream.receive to check if timeout is passed
with mock.patch.object(stream, "receive") as mock_receive:
mock_receive.return_value = StreamWriteMessage.InitResponse(
last_seq_no=0,
session_id="test_session",
partition_id=1,
supported_codecs=[Codec.CODEC_RAW],
status=ServerStatus(StatusCode.SUCCESS, []),
)

writer = WriterAsyncIOStream(writer_id, settings)
await writer._start(stream, settings.create_init_request())

# Verify that receive was called with timeout
mock_receive.assert_called_with(timeout=DEFAULT_INITIAL_RESPONSE_TIMEOUT)

await writer.close()

async def test_init_timeout_behavior(self, stream):
"""Test that WriterAsyncIOStream._start raises TopicWriterError when receive times out"""
writer_id = 1
settings = WriterSettings(PublicWriterSettings("test-topic", "test-producer"))

# Mock stream.receive to directly raise TimeoutError when called with timeout
async def timeout_receive(timeout=None):
if timeout == DEFAULT_INITIAL_RESPONSE_TIMEOUT:
raise asyncio.TimeoutError("Simulated timeout")
return StreamWriteMessage.InitResponse(
last_seq_no=0,
session_id="test_session",
partition_id=1,
supported_codecs=[Codec.CODEC_RAW],
status=ServerStatus(StatusCode.SUCCESS, []),
)

with mock.patch.object(stream, "receive", side_effect=timeout_receive):
writer = WriterAsyncIOStream(writer_id, settings)

# Should raise TopicWriterError with timeout message
with pytest.raises(TopicWriterError, match="Timeout waiting for init response"):
await writer._start(stream, settings.create_init_request())

# Don't close writer since _start failed and _stream was never set


@pytest.mark.asyncio
class TestWriterAsyncIOReconnector:
Expand Down
5 changes: 3 additions & 2 deletions ydb/aio/query/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from ...query import base
from ...query.session import (
BaseQuerySession,
DEFAULT_ATTACH_FIRST_RESP_TIMEOUT,
QuerySessionStateEnum,
)

from ..._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT


class QuerySession(BaseQuerySession):
"""Session object for Query Service. It is not recommended to control
Expand Down Expand Up @@ -47,7 +48,7 @@ async def _attach(self) -> None:
try:
first_response = await _utilities.get_first_message_with_timeout(
self._status_stream,
DEFAULT_ATTACH_FIRST_RESP_TIMEOUT,
DEFAULT_INITIAL_RESPONSE_TIMEOUT,
)
if first_response.status != issues.StatusCode.SUCCESS:
raise RuntimeError("Failed to attach session")
Expand Down
14 changes: 6 additions & 8 deletions ydb/query/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@

from .transaction import QueryTxContext


logger = logging.getLogger(__name__)
from .._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT, DEFAULT_LONG_STREAM_TIMEOUT


DEFAULT_ATTACH_FIRST_RESP_TIMEOUT = 600
DEFAULT_ATTACH_LONG_TIMEOUT = 31536000 # year
logger = logging.getLogger(__name__)


class QuerySessionStateEnum(enum.Enum):
Expand Down Expand Up @@ -142,9 +140,9 @@ def __init__(self, driver: common_utils.SupportedDriverType, settings: Optional[
self._state = QuerySessionState(settings)
self._attach_settings: BaseRequestSettings = (
BaseRequestSettings()
.with_operation_timeout(DEFAULT_ATTACH_LONG_TIMEOUT)
.with_cancel_after(DEFAULT_ATTACH_LONG_TIMEOUT)
.with_timeout(DEFAULT_ATTACH_LONG_TIMEOUT)
.with_operation_timeout(DEFAULT_LONG_STREAM_TIMEOUT)
.with_cancel_after(DEFAULT_LONG_STREAM_TIMEOUT)
.with_timeout(DEFAULT_LONG_STREAM_TIMEOUT)
)

self._last_query_stats = None
Expand Down Expand Up @@ -233,7 +231,7 @@ class QuerySession(BaseQuerySession):

_stream = None

def _attach(self, first_resp_timeout: int = DEFAULT_ATTACH_FIRST_RESP_TIMEOUT) -> None:
def _attach(self, first_resp_timeout: int = DEFAULT_INITIAL_RESPONSE_TIMEOUT) -> None:
self._stream = self._attach_call()
status_stream = _utilities.SyncResponseIterator(
self._stream,
Expand Down
Loading