Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions libs/genai/langchain_google_genai/_client_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import asyncio
from functools import lru_cache
from typing import Optional
from weakref import WeakKeyDictionary, WeakSet

from google.ai.generativelanguage_v1beta import (
GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient,
)
from google.ai.generativelanguage_v1beta import (
GenerativeServiceClient as v1betaGenerativeServiceClient,
)

from langchain_google_genai._common import get_client_info

from . import _genai_extension as genaix


# Cache sync client
@lru_cache
def _get_sync_client(
*,
api_key: Optional[str],
model: str,
transport: str,
) -> v1betaGenerativeServiceClient:
"""Return a shared sync client."""
client_info = get_client_info(f"ChatGoogleGenerativeAI:{model}")
return genaix.build_generative_service(
api_key=api_key,
client_info=client_info,
client_options=None,
transport=transport,
)


# Cache async client - must store caches per event loop
_client_caches: WeakKeyDictionary = WeakKeyDictionary()
_clients_to_close: WeakSet = WeakSet()


def _create_async_client(
*,
api_key: Optional[str],
model: str,
transport: Optional[str],
) -> v1betaGenerativeServiceAsyncClient:
"""Create a new async client."""
# async clients don't support "rest" transport
# https://github.com/googleapis/gapic-generator-python/issues/1962
if transport == "rest":
transport = "grpc_asyncio"
client = genaix.build_generative_async_service(
credentials=None,
api_key=api_key,
client_info=get_client_info(f"ChatGoogleGenerativeAI:{model}"),
client_options=None,
transport=transport,
)
_clients_to_close.add(client)
return client


def _get_async_client(
*,
api_key: Optional[str],
model: str,
transport: Optional[str],
) -> v1betaGenerativeServiceAsyncClient:
"""Return a shared async client per event loop."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# If no event loop is running, don't cache
return _create_async_client(
api_key=api_key,
model=model,
transport=transport,
)

# Get or create cache for this event loop
if loop not in _client_caches:
_client_caches[loop] = {}

cache_key = (api_key, model, transport)

if cache_key not in _client_caches[loop]:
_client_caches[loop][cache_key] = _create_async_client(
api_key=api_key,
model=model,
transport=transport,
)

return _client_caches[loop][cache_key]


# import atexit

# import asyncio, atexit

# _cleanup_loop = asyncio.new_event_loop()

# async def _close_everything():
# await asyncio.gather(
# *(c.transport.close() for c in _clients_to_close),
# return_exceptions=True,
# )
# # let grpc.aio finish any pending callbacks
# await asyncio.sleep(0)

# def _shutdown_all_clients():
# try:
# _cleanup_loop.run_until_complete(_close_everything())
# _cleanup_loop.run_until_complete(_cleanup_loop.shutdown_asyncgens())
# finally:
# _cleanup_loop.close()

# atexit.register(_shutdown_all_clients)
66 changes: 46 additions & 20 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import atexit
import base64
import io
import json
Expand Down Expand Up @@ -99,6 +100,7 @@
)
from typing_extensions import Self, is_typeddict

from langchain_google_genai._client_utils import _get_async_client, _get_sync_client
from langchain_google_genai._common import (
GoogleGenerativeAIError,
SafetySettingDict,
Expand Down Expand Up @@ -1156,24 +1158,42 @@ def validate_environment(self) -> Self:

additional_headers = self.additional_headers or {}
self.default_metadata = tuple(additional_headers.items())
client_info = get_client_info(f"ChatGoogleGenerativeAI:{self.model}")
google_api_key = None
if not self.credentials:
if isinstance(self.google_api_key, SecretStr):
google_api_key = self.google_api_key.get_secret_value()
else:
google_api_key = self.google_api_key
transport: Optional[str] = self.transport
self.client = genaix.build_generative_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=client_info,
client_options=self.client_options,
transport=transport,
)
if self.client_options is None and self.credentials is None:
self.client = _get_sync_client(
api_key=google_api_key,
model=self.model,
transport=transport,
)
else:
self.client = genaix.build_generative_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=get_client_info(f"ChatGoogleGenerativeAI:{self.model}"),
client_options=self.client_options,
transport=transport,
)
self.async_client_running = None
atexit.register(self.client.transport.close)
return self

# @atexit.register
# def close_async_client(self) -> None:
# """Close the async client."""
# loop = asyncio.get_event_loop()
# loop.run_until_complete(self.await_close_async_client())

# async def await_close_async_client(self) -> None:
# """Close the async client."""
# if self.async_client_running:
# await self.async_client_running.transport.close()

@property
def async_client(self) -> v1betaGenerativeServiceAsyncClient:
google_api_key = None
Expand All @@ -1188,20 +1208,26 @@ def async_client(self) -> v1betaGenerativeServiceAsyncClient:
# this check ensures that async client is only initialized
# within an asyncio event loop to avoid the error
if not self.async_client_running and _is_event_loop_running():
# async clients don't support "rest" transport
# https://github.com/googleapis/gapic-generator-python/issues/1962
transport = self.transport
if transport == "rest":
transport = "grpc_asyncio"
self.async_client_running = genaix.build_generative_async_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=get_client_info(f"ChatGoogleGenerativeAI:{self.model}"),
client_options=self.client_options,
transport=transport,
)
if self.credentials is None and self.client_options is None:
self.async_client_running = _get_async_client(
api_key=google_api_key,
model=self.model,
transport=self.transport,
)
else:
self.async_client_running = genaix.build_generative_async_service(
credentials=self.credentials,
api_key=google_api_key,
client_info=get_client_info(f"ChatGoogleGenerativeAI:{self.model}"),
client_options=self.client_options,
transport=self.transport,
)
return self.async_client_running

# async def __aexit__(self) -> None:
# if self.async_client_running:
# await self.async_client_running.transport.close()

@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
Expand Down
Loading
Loading