diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 554c036db..cde55e58d 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -101,6 +101,7 @@ stop_after_attempt, wait_exponential, ) +from tenacity.wait import wait_base from typing_extensions import Self, is_typeddict from langchain_google_genai._common import ( @@ -148,31 +149,51 @@ class ChatGoogleGenerativeAIError(GoogleGenerativeAIError): """ -def _create_retry_decorator( - max_retries: int = 6, - wait_exponential_multiplier: float = 2.0, - wait_exponential_min: float = 1.0, - wait_exponential_max: float = 60.0, -) -> Callable[[Any], Any]: +class wait_with_server_retry_delay(wait_base): + def __init__(self, fallback_wait): + self.fallback_wait = fallback_wait + + def __call__(self, retry_state): + exception = retry_state.outcome.exception() + # Check if it's a ResourceExhausted with retry_delay + if ( + isinstance(exception, google.api_core.exceptions.ResourceExhausted) + and hasattr(exception, "retry_delay") + and exception.retry_delay is not None + and hasattr(exception.retry_delay, "seconds") + ): + delay = exception.retry_delay.seconds + logger.warning(f"Respecting server-suggested retry_delay: {delay}s") + return delay + # Otherwise use fallback (exponential backoff) + return self.fallback_wait(retry_state) + + +def _create_retry_decorator() -> Callable[[Any], Any]: """ Creates and returns a preconfigured tenacity retry decorator. The retry decorator is configured to handle specific Google API exceptions - such as ResourceExhausted and ServiceUnavailable. It uses an exponential - backoff strategy for retries. + such as ResourceExhausted and ServiceUnavailable. It uses a custom strategy + that respects retry_delay if provided by the API response. Returns: Callable[[Any], Any]: A retry decorator configured for handling specific Google API exceptions. """ + multiplier = 2 + min_seconds = 1 + max_seconds = 60 + max_retries = 2 + + fallback_wait = wait_exponential( + multiplier=multiplier, min=min_seconds, max=max_seconds + ) + return retry( reraise=True, stop=stop_after_attempt(max_retries), - wait=wait_exponential( - multiplier=wait_exponential_multiplier, - min=wait_exponential_min, - max=wait_exponential_max, - ), + wait=wait_with_server_retry_delay(fallback_wait), retry=( retry_if_exception_type(google.api_core.exceptions.ResourceExhausted) | retry_if_exception_type(google.api_core.exceptions.ServiceUnavailable)