Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion databricks/sdk/mixins/open_ai_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,72 @@ def auth_flow(self, request: httpx.Request) -> httpx.Request:
databricks_token_auth = BearerAuth(self._api._cfg.authenticate)

# Create an HTTP client with Bearer Token authentication
http_client = httpx.Client(auth=databricks_token_auth)
http_client = httpx.AsyncClient(auth=databricks_token_auth)
return http_client

def get_async_open_ai_client(self, **kwargs):
"""Create an OpenAI client configured for Databricks Model Serving.

Returns an OpenAI client instance that is pre-configured to send requests to
Databricks Model Serving endpoints. The client uses Databricks authentication
to query endpoints within the workspace associated with the current WorkspaceClient
instance.

Args:
**kwargs: Additional parameters to pass to the OpenAI client constructor.
Common parameters include:
- timeout (float): Request timeout in seconds (e.g., 30.0)
- max_retries (int): Maximum number of retries for failed requests (e.g., 3)
- default_headers (dict): Additional headers to include with requests
- default_query (dict): Additional query parameters to include with requests

Any parameter accepted by the OpenAI client constructor can be passed here,
except for the following parameters which are reserved for Databricks integration:
base_url, api_key, http_client

Returns:
OpenAI: An OpenAI client instance configured for Databricks Model Serving.

Raises:
ImportError: If the OpenAI library is not installed.
ValueError: If any reserved Databricks parameters are provided in kwargs.

Example:
>>> client = workspace_client.serving_endpoints.get_open_ai_client()
>>> # With custom timeout and retries
>>> client = workspace_client.serving_endpoints.get_open_ai_client(
... timeout=30.0,
... max_retries=5
... )
"""
try:
from openai import AsyncOpenAI
except Exception:
raise ImportError(
"Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`"
)

# Check for reserved parameters that should not be overridden
reserved_params = {"base_url", "api_key", "http_client"}
conflicting_params = reserved_params.intersection(kwargs.keys())
if conflicting_params:
raise ValueError(
f"Cannot override reserved Databricks parameters: {', '.join(sorted(conflicting_params))}. "
f"These parameters are automatically configured for Databricks Model Serving."
)

# Default parameters that are required for Databricks integration
client_params = {
"base_url": self._api._cfg.host + "/serving-endpoints",
"api_key": "no-token", # Passing in a placeholder to pass validations, this will not be used
"http_client": self._get_authorized_http_client(),
}

# Update with any additional parameters passed by the user
client_params.update(kwargs)

return AsyncOpenAI(**client_params)

def get_open_ai_client(self, **kwargs):
"""Create an OpenAI client configured for Databricks Model Serving.

Expand Down
2 changes: 1 addition & 1 deletion databricks/sdk/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.65.0"
__version__ = "0.65.0.dev0"
Loading