Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def __init__(
http_options: Optional[HttpOptionsOrDict] = None,
):
self.vertexai = vertexai
self.custom_base_url = None
if self.vertexai is None:
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
'true',
Expand Down Expand Up @@ -536,29 +537,33 @@ def __init__(
)
self.api_key = None

self.custom_base_url = (
validated_http_options.base_url
if validated_http_options.base_url
else None
)

# Skip fetching project from ADC if base url is provided in http options.
if (
not self.project
and not self.api_key
and not validated_http_options.base_url
):
if not self.project and not self.api_key and not self.custom_base_url:
credentials, self.project = _load_auth(project=None)
if not self._credentials:
self._credentials = credentials

has_sufficient_auth = (self.project and self.location) or self.api_key

if (not has_sufficient_auth and not validated_http_options.base_url):
if not has_sufficient_auth and not self.custom_base_url:
# Skip sufficient auth check if base url is provided in http options.
raise ValueError(
'Project and location or API key must be set when using the Vertex '
'AI API.'
)
if self.api_key or self.location == 'global':
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
elif validated_http_options.base_url and not has_sufficient_auth:
elif self.custom_base_url and not has_sufficient_auth:
# Avoid setting default base url and api version if base_url provided.
self._http_options.base_url = validated_http_options.base_url
# API gateway proxy can use the auth in custom headers, not url.
# Enable custom url if auth is not sufficient.
self._http_options.base_url = self.custom_base_url
else:
self._http_options.base_url = (
f'https://{self.location}-aiplatform.googleapis.com/'
Expand Down Expand Up @@ -793,6 +798,11 @@ def _use_aiohttp(self) -> bool:
)

def _websocket_base_url(self) -> str:
has_sufficient_auth = (self.project and self.location) or self.api_key
if self.custom_base_url and not has_sufficient_auth:
# API gateway proxy can use the auth in custom headers, not url.
# Enable custom url if auth is not sufficient.
return self.custom_base_url
url_parts = urlparse(self._http_options.base_url)
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]

Expand Down
21 changes: 17 additions & 4 deletions google/genai/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,8 @@ async def connect(
api_key = self._api_client.api_key
version = self._api_client._http_options.api_version
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
headers = self._api_client._http_options.headers or {}
original_headers = self._api_client._http_options.headers
headers = original_headers.copy() if original_headers is not None else {}

request_dict = _common.convert_to_dict(
live_converters._LiveConnectParameters_to_vertex(
Expand Down Expand Up @@ -1003,12 +1004,24 @@ async def connect(
bearer_token = creds.token
original_headers = self._api_client._http_options.headers
headers = original_headers.copy() if original_headers is not None else {}
headers['Authorization'] = f'Bearer {bearer_token}'
if not headers.get('Authorization'):
headers['Authorization'] = f'Bearer {bearer_token}'
version = self._api_client._http_options.api_version
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'

has_sufficient_auth = (
self._api_client.project and self._api_client.location
)
if self._api_client.custom_base_url and not has_sufficient_auth:
# API gateway proxy can use the auth in custom headers, not url.
# Enable custom url if auth is not sufficient.
uri = self._api_client.custom_base_url
# Keep the model as is.
transformed_model = model
else:
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
location = self._api_client.location
project = self._api_client.project
if transformed_model.startswith('publishers/'):
if transformed_model.startswith('publishers/') and project and location:
transformed_model = (
f'projects/{project}/locations/{location}/' + transformed_model
)
Expand Down
Loading