From 1a40613946ade4845e311f382a6e531e1c11481e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florentin=20D=C3=B6rre?=
Date: Fri, 25 Apr 2025 16:31:01 +0200
Subject: [PATCH] Change polling for progress logging to exponential backoff

---
 changelog.md                                  |  1 +
 .../query_runner/neo4j_query_runner.py        |  9 ++-
 .../progress/query_progress_logger.py         | 62 ++++++++++++-------
 .../query_runner/session_query_runner.py      | 10 ++-
 graphdatascience/retry_utils/retry_utils.py   | 14 ++++-
 .../progress/test_query_progress_logger.py    | 30 +++++++++
 .../gds_plugin_enterprise/compose.yml         |  2 +-
 7 files changed, 99 insertions(+), 29 deletions(-)

diff --git a/changelog.md b/changelog.md
index 1eb1094ac..7b7f123b8 100644
--- a/changelog.md
+++ b/changelog.md
@@ -21,6 +21,7 @@
 * Improve error message if session is expired.
 * Improve robustness of Arrow client against connection errors such as `FlightUnavailableError` and `FlightTimedOutError`.
 * Return dedicated error class `SessionStatusError` if a session failed or expired.
+* Reduce the number of calls that check for progress updates. Previously the check ran every 0.5 seconds; it now uses exponential backoff capped at 10 seconds.

 ## Other changes

diff --git a/graphdatascience/query_runner/neo4j_query_runner.py b/graphdatascience/query_runner/neo4j_query_runner.py
index 4a8f552d5..85de21729 100644
--- a/graphdatascience/query_runner/neo4j_query_runner.py
+++ b/graphdatascience/query_runner/neo4j_query_runner.py
@@ -8,6 +8,7 @@

 import neo4j
 from pandas import DataFrame
+from tenacity import wait_exponential

 from ..call_parameters import CallParameters
 from ..error.endpoint_suggester import generate_suggestive_error_message
@@ -114,7 +115,13 @@ def __init__(
         self._server_version: Optional[ServerVersion] = None
         self._show_progress = show_progress
         self._progress_logger = QueryProgressLogger(
-            self.__run_cypher_simplified_for_query_progress_logger, self.server_version
+            run_cypher_func=self.__run_cypher_simplified_for_query_progress_logger,
+            server_version_func=self.server_version,
+            log_interval=wait_exponential(
+                max=10,
+                exp_base=1.5,
+                min=0.5,
+            ),
         )
         self._instance_description = instance_description

diff --git a/graphdatascience/query_runner/progress/query_progress_logger.py b/graphdatascience/query_runner/progress/query_progress_logger.py
index 1e2ffd1a2..e8421b3dc 100644
--- a/graphdatascience/query_runner/progress/query_progress_logger.py
+++ b/graphdatascience/query_runner/progress/query_progress_logger.py
@@ -1,10 +1,13 @@
 import warnings
-from concurrent.futures import Future, ThreadPoolExecutor, wait
-from typing import Any, Callable, NoReturn, Optional
+from concurrent import futures
+from typing import Any, Callable, NoReturn, Optional, Union

 from pandas import DataFrame
+from tenacity import Retrying, wait
 from tqdm.auto import tqdm

+from graphdatascience.retry_utils.retry_utils import retry_until_future
+
 from ...server_version.server_version import ServerVersion
 from .progress_provider import ProgressProvider, TaskWithProgress
 from .query_progress_provider import CypherQueryFunction, QueryProgressProvider, ServerVersionFunction
@@ -18,16 +21,24 @@ def __init__(
         self,
         run_cypher_func: CypherQueryFunction,
         server_version_func: ServerVersionFunction,
-        polling_interval: float = 0.5,
+        log_interval: Union[float, wait.wait_base] = 0.5,
+        initial_wait_time: float = 0.5,
         progress_bar_options: dict[str, Any] = {},
     ):
         self._run_cypher_func = run_cypher_func
         self._server_version_func = server_version_func
         self._static_progress_provider = StaticProgressProvider()
         self._query_progress_provider = QueryProgressProvider(run_cypher_func, server_version_func)
-        self._polling_interval = polling_interval
         self._progress_bar_options = progress_bar_options
+        self._initial_wait_time = initial_wait_time
+        if isinstance(log_interval, float):
+            self._wait_base: wait.wait_base = wait.wait_fixed(log_interval)
+        elif isinstance(log_interval, wait.wait_base):
+            self._wait_base = log_interval
+        else:
+            raise ValueError("log_interval must be a float or an instance of wait_base")
+

     def run_with_progress_logging(
         self, runnable: DataFrameProducer, job_id: str, database: Optional[str] = None
     ) -> DataFrame:
@@ -38,9 +49,10 @@ def run_with_progress_logging(
         # Entries in the static progress store are already visible at this point.
         progress_provider = self._select_progress_provider(job_id)

-        with ThreadPoolExecutor() as executor:
+        with futures.ThreadPoolExecutor() as executor:
             future = executor.submit(runnable)

+            futures.wait([future], timeout=self._initial_wait_time)  # wait for progress task to be available
             self._log(future, job_id, progress_provider, database)

             if future.exception():
@@ -56,29 +68,33 @@ def _select_progress_provider(self, job_id: str) -> ProgressProvider:
         )

     def _log(
-        self, future: Future[Any], job_id: str, progress_provider: ProgressProvider, database: Optional[str] = None
+        self,
+        future: futures.Future[Any],
+        job_id: str,
+        progress_provider: ProgressProvider,
+        database: Optional[str] = None,
     ) -> None:
         pbar: Optional[tqdm[NoReturn]] = None
         warn_if_failure = True

-        while wait([future], timeout=self._polling_interval).not_done:
-            try:
-                task_with_progress = progress_provider.root_task_with_progress(job_id, database)
-                if pbar is None:
-                    pbar = self._init_pbar(task_with_progress)
-
-                self._update_pbar(pbar, task_with_progress)
-            except Exception as e:
-                # Do nothing if the procedure either:
-                # * has not started yet,
-                # * has already completed.
-                if f"No task with job id `{job_id}` was found" in str(e):
-                    continue
-                else:
+        for attempt in Retrying(wait=self._wait_base, retry=retry_until_future(future)):
+            with attempt:
+                try:
+                    task_with_progress = progress_provider.root_task_with_progress(job_id, database)
+                    if pbar is None:
+                        pbar = self._init_pbar(task_with_progress)
+
+                    self._update_pbar(pbar, task_with_progress)
+                except Exception as e:
+                    # Do nothing if the procedure either:
+                    # * has not started yet,
+                    # * has already completed.
+                    if f"No task with job id `{job_id}` was found" in str(e):
+                        continue
+                    if warn_if_failure:
                         warnings.warn(f"Unable to get progress: {str(e)}", RuntimeWarning)
                         warn_if_failure = False
-                    continue

         if pbar is not None:
             self._finish_pbar(future, pbar)

@@ -91,7 +107,6 @@ def _init_pbar(self, task_with_progress: TaskWithProgress) -> tqdm:  # type: ign
                 total=None,
                 unit="",
                 desc=root_task_name,
-                maxinterval=self._polling_interval,
                 bar_format="{desc} [elapsed: {elapsed} {postfix}]",
                 **self._progress_bar_options,
             )
@@ -100,7 +115,6 @@ def _init_pbar(self, task_with_progress: TaskWithProgress) -> tqdm:  # type: ign
                 total=100,
                 unit="%",
                 desc=root_task_name,
-                maxinterval=self._polling_interval,
                 **self._progress_bar_options,
             )

@@ -118,7 +132,7 @@ def _update_pbar(self, pbar: tqdm, task_with_progress: TaskWithProgress) -> None
         else:
             pbar.refresh()

-    def _finish_pbar(self, future: Future[Any], pbar: tqdm) -> None:  # type: ignore
+    def _finish_pbar(self, future: futures.Future[Any], pbar: tqdm) -> None:  # type: ignore
         if future.exception():
             pbar.set_postfix_str("status: FAILED", refresh=True)
             return
diff --git a/graphdatascience/query_runner/session_query_runner.py b/graphdatascience/query_runner/session_query_runner.py
index a28c9bdc5..7e9f32a59 100644
--- a/graphdatascience/query_runner/session_query_runner.py
+++ b/graphdatascience/query_runner/session_query_runner.py
@@ -5,6 +5,7 @@
 from uuid import uuid4

 from pandas import DataFrame
+from tenacity import wait_exponential

 from graphdatascience.query_runner.graph_constructor import GraphConstructor
 from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger
@@ -41,8 +42,13 @@ def __init__(
         self._resolved_protocol_version = ProtocolVersionResolver(db_query_runner).resolve()
         self._show_progress = show_progress
         self._progress_logger = QueryProgressLogger(
-            lambda query, database: self._gds_query_runner.run_cypher(query=query, database=database),
-            self._gds_query_runner.server_version,
+            run_cypher_func=lambda query, database: self._gds_query_runner.run_cypher(query=query, database=database),
+            server_version_func=self._gds_query_runner.server_version,
+            log_interval=wait_exponential(
+                max=10,
+                exp_base=1.5,
+                min=0.5,
+            ),
         )

     def run_cypher(
diff --git a/graphdatascience/retry_utils/retry_utils.py b/graphdatascience/retry_utils/retry_utils.py
index 6d7415062..acd0f552c 100644
--- a/graphdatascience/retry_utils/retry_utils.py
+++ b/graphdatascience/retry_utils/retry_utils.py
@@ -1,7 +1,8 @@
 import logging
 import typing
+from concurrent.futures import Future

-from tenacity import RetryCallState
+from tenacity import RetryCallState, retry_base


 def before_log(
@@ -18,3 +19,14 @@ def log_it(retry_state: RetryCallState) -> None:
         )

     return log_it
+
+
+class retry_until_future(retry_base):
+    def __init__(
+        self,
+        future: Future[typing.Any],
+    ):
+        self._future = future
+
+    def __call__(self, retry_state: "RetryCallState") -> bool:
+        return not self._future.done()
diff --git a/graphdatascience/tests/unit/query_runner/progress/test_query_progress_logger.py b/graphdatascience/tests/unit/query_runner/progress/test_query_progress_logger.py
index 1fe9cd9a7..3d7d9e6e1 100644
--- a/graphdatascience/tests/unit/query_runner/progress/test_query_progress_logger.py
+++ b/graphdatascience/tests/unit/query_runner/progress/test_query_progress_logger.py
@@ -5,6 +5,7 @@
 from typing import Optional

 from pandas import DataFrame
+from tenacity import wait

 from graphdatascience import ServerVersion
 from graphdatascience.query_runner.progress.progress_provider import TaskWithProgress
@@ -31,6 +32,35 @@ def fake_query() -> DataFrame:
     assert df["result"][0] == 42


+def test_log_interval() -> None:
+    def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
+        assert "CALL gds.listProgress('foo')" in query
+        assert database == "database"
+
+        return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}])
+
+    def fake_query() -> DataFrame:
+        time.sleep(0.5)
+        return DataFrame([{"result": 42}])
+
+    with StringIO() as pbarOutputStream:
+        qpl = QueryProgressLogger(
+            fake_run_cypher,
+            lambda: ServerVersion(3, 0, 0),
+            log_interval=wait.wait_fixed(0.1),
+            initial_wait_time=0,
+            progress_bar_options={"file": pbarOutputStream, "mininterval": 0},
+        )
+        df = qpl.run_with_progress_logging(fake_query, "foo", "database")
+
+        running_output = pbarOutputStream.getvalue().split("\r")[:-1]
+
+        assert len(running_output) > 4
+        assert len(running_output) < 15
+
+        assert df["result"][0] == 42
+
+
 def test_skips_progress_logging_for_old_server_version() -> None:
     def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
         print("Should not be called!")
diff --git a/scripts/test_envs/gds_plugin_enterprise/compose.yml b/scripts/test_envs/gds_plugin_enterprise/compose.yml
index 65893b663..fab0e0c05 100755
--- a/scripts/test_envs/gds_plugin_enterprise/compose.yml
+++ b/scripts/test_envs/gds_plugin_enterprise/compose.yml
@@ -1,6 +1,6 @@
 services:
   neo4j:
-    image: neo4j:enterprise
+    image: neo4j:5-enterprise
     volumes:
       - ${HOME}/.gds_license:/licenses/.gds_license
     environment:
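
Note on the chosen backoff schedule: the polling cadence produced by `wait_exponential(max=10, exp_base=1.5, min=0.5)` can be previewed with a few lines of tenacity. This is a minimal sketch, assuming tenacity 8.x, where a wait strategy is a callable taking a `RetryCallState`; building a bare `RetryCallState` outside a `Retrying` controller is done here purely for illustration and is not part of the library's documented API contract.

    from tenacity import RetryCallState, wait_exponential

    # Same wait strategy as configured for QueryProgressLogger in this patch.
    schedule = wait_exponential(max=10, exp_base=1.5, min=0.5)

    # Probe the schedule attempt by attempt (illustration only; RetryCallState
    # is normally created and advanced by a Retrying controller).
    state = RetryCallState(retry_object=None, fn=None, args=(), kwargs={})
    for attempt in range(1, 11):
        state.attempt_number = attempt
        print(f"poll #{attempt:2d}: wait {schedule(state):5.2f}s before the next check")

The waits grow by a factor of 1.5 per check and are clamped between 0.5s and 10s, so a long-running query quickly settles at one progress check every 10 seconds instead of two per second.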
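
The new `retry_until_future` helper is a tenacity retry strategy that keeps a retry loop alive for exactly as long as a given future is still running, which is what lets `_log` above replace its `while wait(...)` loop. A minimal usage sketch follows; `slow_job` is a hypothetical stand-in for a long-running query, and a fixed 0.5s wait is used for brevity:

    import time
    from concurrent import futures

    from tenacity import Retrying, wait_fixed

    from graphdatascience.retry_utils.retry_utils import retry_until_future


    def slow_job() -> int:
        # Hypothetical stand-in for a long-running GDS query.
        time.sleep(2)
        return 42


    with futures.ThreadPoolExecutor() as executor:
        future = executor.submit(slow_job)
        # Mirrors _log(): each attempt performs one progress check, and the
        # loop ends once the future completes, because retry_until_future
        # then evaluates to False and Retrying stops iterating.
        for attempt in Retrying(wait=wait_fixed(0.5), retry=retry_until_future(future)):
            with attempt:
                print("checking progress...")  # stand-in for the gds.listProgress poll
        print("result:", future.result())

Because the retry predicate is evaluated after every attempt, no explicit break is needed; the wait strategy alone controls how often the check runs.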