Skip to content

Commit c6bddb7

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Support fully override base_url and raw model name when none of the project, locations, api_key are configured
PiperOrigin-RevId: 789502670
1 parent 8a45746 commit c6bddb7

File tree

3 files changed

+189
-115
lines changed

3 files changed

+189
-115
lines changed

google/genai/_api_client.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ def __init__(
451451
http_options: Optional[HttpOptionsOrDict] = None,
452452
):
453453
self.vertexai = vertexai
454+
self.custom_base_url = None
454455
if self.vertexai is None:
455456
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '0').lower() in [
456457
'true',
@@ -536,29 +537,33 @@ def __init__(
536537
)
537538
self.api_key = None
538539

540+
self.custom_base_url = (
541+
validated_http_options.base_url
542+
if validated_http_options.base_url
543+
else None
544+
)
545+
539546
# Skip fetching project from ADC if base url is provided in http options.
540-
if (
541-
not self.project
542-
and not self.api_key
543-
and not validated_http_options.base_url
544-
):
547+
if not self.project and not self.api_key and not self.custom_base_url:
545548
credentials, self.project = _load_auth(project=None)
546549
if not self._credentials:
547550
self._credentials = credentials
548551

549552
has_sufficient_auth = (self.project and self.location) or self.api_key
550553

551-
if (not has_sufficient_auth and not validated_http_options.base_url):
554+
if not has_sufficient_auth and not self.custom_base_url:
552555
# Skip sufficient auth check if base url is provided in http options.
553556
raise ValueError(
554557
'Project and location or API key must be set when using the Vertex '
555558
'AI API.'
556559
)
557560
if self.api_key or self.location == 'global':
558561
self._http_options.base_url = f'https://aiplatform.googleapis.com/'
559-
elif validated_http_options.base_url and not has_sufficient_auth:
562+
elif self.custom_base_url and not has_sufficient_auth:
560563
# Avoid setting default base url and api version if base_url provided.
561-
self._http_options.base_url = validated_http_options.base_url
564+
# API gateway proxy can use the auth in custom headers, not url.
565+
# Enable custom url if auth is not sufficient.
566+
self._http_options.base_url = self.custom_base_url
562567
else:
563568
self._http_options.base_url = (
564569
f'https://{self.location}-aiplatform.googleapis.com/'
@@ -793,6 +798,11 @@ def _use_aiohttp(self) -> bool:
793798
)
794799

795800
def _websocket_base_url(self) -> str:
801+
has_sufficient_auth = (self.project and self.location) or self.api_key
802+
if self.custom_base_url and not has_sufficient_auth:
803+
# API gateway proxy can use the auth in custom headers, not url.
804+
# Enable custom url if auth is not sufficient.
805+
return self.custom_base_url
796806
url_parts = urlparse(self._http_options.base_url)
797807
return url_parts._replace(scheme='wss').geturl() # type: ignore[arg-type, return-value]
798808

google/genai/live.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,8 @@ async def connect(
971971
api_key = self._api_client.api_key
972972
version = self._api_client._http_options.api_version
973973
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
974-
headers = self._api_client._http_options.headers or {}
974+
original_headers = self._api_client._http_options.headers
975+
headers = original_headers.copy() if original_headers is not None else {}
975976

976977
request_dict = _common.convert_to_dict(
977978
live_converters._LiveConnectParameters_to_vertex(
@@ -1003,12 +1004,24 @@ async def connect(
10031004
bearer_token = creds.token
10041005
original_headers = self._api_client._http_options.headers
10051006
headers = original_headers.copy() if original_headers is not None else {}
1006-
headers['Authorization'] = f'Bearer {bearer_token}'
1007+
if not headers.get('Authorization'):
1008+
headers['Authorization'] = f'Bearer {bearer_token}'
10071009
version = self._api_client._http_options.api_version
1008-
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
1010+
1011+
has_sufficient_auth = (
1012+
self._api_client.project and self._api_client.location
1013+
)
1014+
if self._api_client.custom_base_url and not has_sufficient_auth:
1015+
# API gateway proxy can use the auth in custom headers, not url.
1016+
# Enable custom url if auth is not sufficient.
1017+
uri = self._api_client.custom_base_url
1018+
# Keep the model as is.
1019+
transformed_model = model
1020+
else:
1021+
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
10091022
location = self._api_client.location
10101023
project = self._api_client.project
1011-
if transformed_model.startswith('publishers/'):
1024+
if transformed_model.startswith('publishers/') and project and location:
10121025
transformed_model = (
10131026
f'projects/{project}/locations/{location}/' + transformed_model
10141027
)

0 commit comments

Comments
 (0)