diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 2becfb4fb..1501e91a1 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -147,11 +147,14 @@ def open(self): else: self.__pool = pool_class(self.host, self.port, **_pool_kwargs) - def close(self): + def release_connection(self): self.__resp and self.__resp.drain_conn() self.__resp and self.__resp.release_conn() self.__resp = None + def close(self): + self.__pool.close() + def read(self, sz): return self.__resp.read(sz) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 75d2c665c..5f10f2df4 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -287,9 +287,9 @@ def close_session(self, session_id: SessionId) -> None: logger.debug("SeaDatabricksClient.close_session(session_id=%s)", session_id) - if session_id.backend_type != BackendType.SEA: - raise ValueError("Not a valid SEA session ID") sea_session_id = session_id.to_sea_session_id() + if sea_session_id is None: + raise ValueError("Not a valid SEA session ID") request_data = DeleteSessionRequest( warehouse_id=self.warehouse_id, @@ -302,6 +302,9 @@ def close_session(self, session_id: SessionId) -> None: data=request_data.to_dict(), ) + # close the HTTP client + self._http_client.close() + def _extract_description_from_manifest( self, manifest: ResultManifest ) -> List[Tuple]: diff --git a/src/databricks/sql/backend/sea/utils/http_client.py b/src/databricks/sql/backend/sea/utils/http_client.py index b47f2add2..069f61e52 100644 --- a/src/databricks/sql/backend/sea/utils/http_client.py +++ b/src/databricks/sql/backend/sea/utils/http_client.py @@ -183,7 +183,7 @@ def _open(self): def close(self): """Close the connection pool.""" if self._pool: - self._pool.clear() + self._pool.close() def using_proxy(self) -> bool: """Check if proxy is being used.""" diff --git a/src/databricks/sql/backend/thrift_backend.py b/src/databricks/sql/backend/thrift_backend.py index d2b10e718..6544c554c 100644 --- a/src/databricks/sql/backend/thrift_backend.py +++ b/src/databricks/sql/backend/thrift_backend.py @@ -483,8 +483,8 @@ def attempt_request(attempt): ) ) finally: - # Calling `close()` here releases the active HTTP connection back to the pool - self._transport.close() + # Calling `release_connection()` here releases the active HTTP connection back to the pool + self._transport.release_connection() return RequestErrorInfo( error=error, diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 71fcc40c6..c7c4289ec 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -2,8 +2,7 @@ import time import logging import json -from concurrent.futures import ThreadPoolExecutor -from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor, wait from datetime import datetime, timezone from typing import List, Dict, Any, Optional, TYPE_CHECKING from databricks.sql.telemetry.models.event import ( @@ -182,6 +181,7 @@ def __init__( self._user_agent = None self._events_batch = [] self._lock = threading.RLock() + self._pending_futures = set() self._driver_connection_params = None self._host_url = host_url self._executor = executor @@ -245,6 +245,9 @@ def _send_telemetry(self, events): timeout=900, ) + with self._lock: + self._pending_futures.add(future) + future.add_done_callback( lambda fut: self._telemetry_request_callback(fut, sent_count=sent_count) ) @@ -303,6 +306,9 @@ def _telemetry_request_callback(self, future, sent_count: int): except Exception as e: logger.debug("Telemetry request failed with exception: %s", e) + finally: + with self._lock: + self._pending_futures.discard(future) def _export_telemetry_log(self, **telemetry_event_kwargs): """ @@ -356,10 +362,30 @@ def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id): ) def close(self): - """Flush remaining events before closing""" - logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) + """Schedule client closure.""" + logger.debug( + "Scheduling closure for TelemetryClient of connection %s", + self._session_id_hex, + ) + self._executor.submit(self._close_and_wait) + + def _close_and_wait(self): + """Flush remaining events and wait for them to complete before closing.""" self._flush() + with self._lock: + pending_events = list(self._pending_futures) + + if pending_events: + logger.debug( + "Waiting for %s pending telemetry requests to complete.", + len(pending_events), + ) + wait(pending_events) + + logger.debug("Closing TelemetryClient for connection %s", self._session_id_hex) + self._http_client.close() + class TelemetryClientFactory: """ diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 26a898cb8..639c311d2 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -222,6 +222,7 @@ def test_session_management(self, sea_client, mock_http_client, thrift_session_i path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), data={"session_id": "test-session-789", "warehouse_id": "abc123"}, ) + mock_http_client.close.assert_called_once() # Test close_session with invalid ID type with pytest.raises(ValueError) as excinfo: diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7254b66cb..d16cd4205 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -1436,8 +1436,12 @@ def test_op_handle_respected_in_close_command(self, tcli_service_class): ) @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) - def test_session_handle_respected_in_close_session(self, tcli_service_class): + @patch("databricks.sql.auth.thrift_http_client.THttpClient", autospec=True) + def test_session_handle_respected_in_close_session( + self, mock_http_client_class, tcli_service_class + ): tcli_service_instance = tcli_service_class.return_value + mock_http_client_instance = mock_http_client_class.return_value thrift_backend = ThriftDatabricksClient( "foobar", 443, @@ -1447,12 +1451,15 @@ def test_session_handle_respected_in_close_session(self, tcli_service_class): ssl_options=SSLOptions(), http_client=MagicMock(), ) + thrift_backend._transport = mock_http_client_instance + session_id = SessionId.from_thrift_handle(self.session_handle) thrift_backend.close_session(session_id) self.assertEqual( tcli_service_instance.CloseSession.call_args[0][0].sessionHandle, self.session_handle, ) + mock_http_client_instance.close.assert_called_once() @patch("databricks.sql.backend.thrift_backend.TCLIService.Client", autospec=True) def test_non_arrow_non_column_based_set_triggers_exception(