diff --git a/databricks/sdk/mixins/open_ai_client.py b/databricks/sdk/mixins/open_ai_client.py index 4ab08ee5a..e910fd47b 100644 --- a/databricks/sdk/mixins/open_ai_client.py +++ b/databricks/sdk/mixins/open_ai_client.py @@ -28,9 +28,72 @@ def auth_flow(self, request: httpx.Request) -> httpx.Request: databricks_token_auth = BearerAuth(self._api._cfg.authenticate) # Create an HTTP client with Bearer Token authentication - http_client = httpx.Client(auth=databricks_token_auth) + http_client = httpx.AsyncClient(auth=databricks_token_auth) return http_client + def get_async_open_ai_client(self, **kwargs): + """Create an OpenAI client configured for Databricks Model Serving. + + Returns an OpenAI client instance that is pre-configured to send requests to + Databricks Model Serving endpoints. The client uses Databricks authentication + to query endpoints within the workspace associated with the current WorkspaceClient + instance. + + Args: + **kwargs: Additional parameters to pass to the OpenAI client constructor. + Common parameters include: + - timeout (float): Request timeout in seconds (e.g., 30.0) + - max_retries (int): Maximum number of retries for failed requests (e.g., 3) + - default_headers (dict): Additional headers to include with requests + - default_query (dict): Additional query parameters to include with requests + + Any parameter accepted by the OpenAI client constructor can be passed here, + except for the following parameters which are reserved for Databricks integration: + base_url, api_key, http_client + + Returns: + OpenAI: An OpenAI client instance configured for Databricks Model Serving. + + Raises: + ImportError: If the OpenAI library is not installed. + ValueError: If any reserved Databricks parameters are provided in kwargs. + + Example: + >>> client = workspace_client.serving_endpoints.get_open_ai_client() + >>> # With custom timeout and retries + >>> client = workspace_client.serving_endpoints.get_open_ai_client( + ... timeout=30.0, + ... max_retries=5 + ... ) + """ + try: + from openai import AsyncOpenAI + except Exception: + raise ImportError( + "Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`" + ) + + # Check for reserved parameters that should not be overridden + reserved_params = {"base_url", "api_key", "http_client"} + conflicting_params = reserved_params.intersection(kwargs.keys()) + if conflicting_params: + raise ValueError( + f"Cannot override reserved Databricks parameters: {', '.join(sorted(conflicting_params))}. " + f"These parameters are automatically configured for Databricks Model Serving." + ) + + # Default parameters that are required for Databricks integration + client_params = { + "base_url": self._api._cfg.host + "/serving-endpoints", + "api_key": "no-token", # Passing in a placeholder to pass validations, this will not be used + "http_client": self._get_authorized_http_client(), + } + + # Update with any additional parameters passed by the user + client_params.update(kwargs) + + return AsyncOpenAI(**client_params) + def get_open_ai_client(self, **kwargs): """Create an OpenAI client configured for Databricks Model Serving. diff --git a/databricks/sdk/version.py b/databricks/sdk/version.py index e082e13ed..37a100052 100644 --- a/databricks/sdk/version.py +++ b/databricks/sdk/version.py @@ -1 +1 @@ -__version__ = "0.65.0" +__version__ = "0.65.0.dev0"