diff --git a/tests/aio/test_credentials.py b/tests/aio/test_credentials.py index 5000541c..e73f3522 100644 --- a/tests/aio/test_credentials.py +++ b/tests/aio/test_credentials.py @@ -5,6 +5,8 @@ import tempfile import os import json +import asyncio +from unittest.mock import patch, AsyncMock, MagicMock import tests.auth.test_credentials import tests.oauth2_token_exchange @@ -112,3 +114,152 @@ def serve(s): except Exception: os.remove(cfg_file_name) raise + + +@pytest.mark.asyncio +async def test_token_lazy_refresh(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + credentials._tp.submit = MagicMock() + + mock_response = {"access_token": "token_v1", "expires_in": 3600} + credentials._make_token_request = AsyncMock(return_value=mock_response) + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + token1 = await credentials.token() + assert token1 == "token_v1" + assert credentials._make_token_request.call_count == 1 + + token2 = await credentials.token() + assert token2 == "token_v1" + assert credentials._make_token_request.call_count == 1 + + mock_time.return_value = 1000 + 3600 - 30 + 1 + credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600} + + token3 = await credentials.token() + assert token3 == "token_v2" + assert credentials._make_token_request.call_count == 2 + + +@pytest.mark.asyncio +async def test_token_double_check_locking(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + credentials._tp.submit = MagicMock() + + call_count = 0 + + async def mock_make_request(): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) + return {"access_token": f"token_v{call_count}", "expires_in": 3600} + + credentials._make_token_request = mock_make_request + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + tasks = [credentials.token() for _ in range(10)] + results = await asyncio.gather(*tasks) + + assert len(set(results)) == 1 + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_token_expiration_calculation(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + credentials._tp.submit = MagicMock() + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + credentials._make_token_request = AsyncMock(return_value={"access_token": "token", "expires_in": 3600}) + + await credentials.token() + + expected_expires = 1000 + 3600 - 30 + assert credentials._expires_in == expected_expires + + +@pytest.mark.asyncio +async def test_token_refresh_error_handling(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + credentials._tp.submit = MagicMock() + + credentials._make_token_request = AsyncMock(side_effect=Exception("Network error")) + + with pytest.raises(Exception) as exc_info: + await credentials.token() + + assert "Network error" in str(exc_info.value) + assert credentials.last_error == "Network error" + + +@pytest.mark.asyncio +async def test_hybrid_background_and_sync_refresh(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + call_count = 0 + background_calls = [] + + async def mock_make_request(): + nonlocal call_count + call_count += 1 + return {"access_token": f"token_v{call_count}", "expires_in": 3600} + + def mock_submit(callback): + background_calls.append(callback) + + credentials._make_token_request = mock_make_request + credentials._tp.submit = mock_submit + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + token1 = await credentials.token() + assert token1 == "token_v1" + assert call_count == 1 + assert len(background_calls) == 0 + + mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1 + token2 = await credentials.token() + assert token2 == "token_v1" + assert call_count == 1 + assert len(background_calls) == 1 + + mock_time.return_value = 1000 + 3600 - 30 + 1 + token3 = await credentials.token() + assert token3 == "token_v2" + assert call_count == 2 diff --git a/tests/auth/test_static_credentials.py b/tests/auth/test_static_credentials.py index a9239f2a..1e2938c9 100644 --- a/tests/auth/test_static_credentials.py +++ b/tests/auth/test_static_credentials.py @@ -1,5 +1,6 @@ import pytest import ydb +from unittest.mock import patch, MagicMock USERNAME = "root" @@ -45,3 +46,131 @@ def test_static_credentials_wrong_creds(endpoint, database): with pytest.raises(ydb.ConnectionFailure): with ydb.Driver(driver_config=driver_config) as driver: driver.wait(5, fail_fast=True) + + +def test_token_lazy_refresh(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + + credentials._tp.submit = MagicMock() + + mock_response = {"access_token": "token_v1", "expires_in": 3600} + credentials._make_token_request = MagicMock(return_value=mock_response) + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + token1 = credentials.token + assert token1 == "token_v1" + assert credentials._make_token_request.call_count == 1 + + token2 = credentials.token + assert token2 == "token_v1" + assert credentials._make_token_request.call_count == 1 + + mock_time.return_value = 1000 + 3600 - 30 + 1 + credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600} + + token3 = credentials.token + assert token3 == "token_v2" + assert credentials._make_token_request.call_count == 2 + + +def test_token_double_check_locking(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + credentials._tp.submit = MagicMock() + + call_count = 0 + + def mock_make_request(): + nonlocal call_count + call_count += 1 + return {"access_token": f"token_v{call_count}", "expires_in": 3600} + + credentials._make_token_request = mock_make_request + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + import threading + + results = [] + + def get_token(): + results.append(credentials.token) + + threads = [threading.Thread(target=get_token) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(set(results)) == 1 + assert call_count == 1 + + +def test_token_expiration_calculation(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + + credentials._tp.submit = MagicMock() + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + credentials._make_token_request = MagicMock(return_value={"access_token": "token", "expires_in": 3600}) + + credentials.token + + expected_expires = 1000 + 3600 - 30 + assert credentials._expires_in == expected_expires + + +def test_token_refresh_error_handling(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + credentials._tp.submit = MagicMock() + credentials._make_token_request = MagicMock(side_effect=Exception("Network error")) + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + 3600 + + with pytest.raises(ydb.ConnectionError) as exc_info: + credentials.token + + assert "Network error" in str(exc_info.value) + assert credentials.last_error == "Network error" + + +def test_hybrid_background_and_sync_refresh(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + + call_count = 0 + background_calls = [] + + def mock_make_request(): + nonlocal call_count + call_count += 1 + return {"access_token": f"token_v{call_count}", "expires_in": 3600} + + def mock_submit(callback): + background_calls.append(callback) + + credentials._make_token_request = mock_make_request + credentials._tp.submit = mock_submit + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + token1 = credentials.token + assert token1 == "token_v1" + assert call_count == 1 + assert len(background_calls) == 0 + + mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1 + token2 = credentials.token + assert token2 == "token_v1" + assert call_count == 1 + assert len(background_calls) == 1 + + mock_time.return_value = 1000 + 3600 - 30 + 1 + token3 = credentials.token + assert token3 == "token_v2" + assert call_count == 2 diff --git a/ydb/aio/credentials.py b/ydb/aio/credentials.py index 03c96a37..6a1d6333 100644 --- a/ydb/aio/credentials.py +++ b/ydb/aio/credentials.py @@ -10,57 +10,31 @@ YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket" -class _OneToManyValue(object): - def __init__(self): - self._value = None - self._condition = asyncio.Condition() - - async def consume(self, timeout=3): - async with self._condition: - if self._value is None: - try: - await asyncio.wait_for(self._condition.wait(), timeout=timeout) - except Exception: - return self._value - return self._value - - async def update(self, n_value): - async with self._condition: - prev_value = self._value - self._value = n_value - if prev_value is None: - self._condition.notify_all() - - -class _AtMostOneExecution(object): +class AtMostOneExecution(object): def __init__(self): self._can_schedule = True - self._lock = asyncio.Lock() # Lock to guarantee only one execution - - async def _wrapped_execution(self, callback): - await self._lock.acquire() - try: - res = callback() - if asyncio.iscoroutine(res): - await res - except Exception: - pass + self._lock = asyncio.Lock() - finally: - self._lock.release() - self._can_schedule = True + async def wrapped_execution(self, callback): + async with self._lock: + try: + await callback() + except Exception: + pass + finally: + self._can_schedule = True def submit(self, callback): if self._can_schedule: self._can_schedule = False - asyncio.ensure_future(self._wrapped_execution(callback)) + asyncio.create_task(self.wrapped_execution(callback)) class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self): super(AbstractExpiringTokenCredentials, self).__init__() - self._tp = _AtMostOneExecution() - self._cached_token = _OneToManyValue() + self._token_lock = asyncio.Lock() + self._tp = AtMostOneExecution() @abc.abstractmethod async def _make_token_request(self): @@ -72,51 +46,42 @@ async def get_auth_token(self) -> str: return token return "" - async def _refresh(self): + async def _refresh_token(self, should_raise=False): current_time = time.time() - self._log_refresh_start(current_time) try: - auth_metadata = await self._make_token_request() - await self._cached_token.update(auth_metadata["access_token"]) - self._update_expiration_info(auth_metadata) - self.logger.info( - "Token refresh successful. current_time %s, refresh_in %s", - current_time, - self._refresh_in, + self.logger.debug( + "Refreshing token async, current_time: %s, expires_in: %s", current_time, self._expires_in ) - except (KeyboardInterrupt, SystemExit): - return + token_response = await self._make_token_request() + self._update_token_info(token_response, current_time) - except Exception as e: - self.last_error = str(e) - await asyncio.sleep(1) - self._tp.submit(self._refresh) + self.logger.info("Token refreshed successfully async, expires_in: %s", self._expires_in) + self.last_error = None - except BaseException as e: + except Exception as e: self.last_error = str(e) - raise + self.logger.error("Failed to refresh token async: %s", e) + if should_raise: + raise issues.ConnectionError( + "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) + ) async def token(self): - current_time = time.time() - if current_time > self._refresh_in: - self._tp.submit(self._refresh) + if self._is_token_valid(): + if self._should_refresh(): + self._tp.submit(self._refresh_token) - cached_token = await self._cached_token.consume(timeout=3) - if cached_token is None: - if self.last_error is None: - raise issues.ConnectionError( - "%s: timeout occurred while waiting for token.\n%s" - % ( - self.__class__.__name__, - self.extra_error_message, - ) - ) - raise issues.ConnectionError( - "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) - ) - return cached_token + return self._cached_token + + async with self._token_lock: + if self._is_token_valid(): + return self._cached_token + + await self._refresh_token(should_raise=True) + + return self._cached_token async def auth_metadata(self): return [(credentials.YDB_AUTH_TICKET_HEADER, await self.token())] diff --git a/ydb/aio/iam.py b/ydb/aio/iam.py index 5a2a29f6..6c7f762c 100644 --- a/ydb/aio/iam.py +++ b/ydb/aio/iam.py @@ -102,8 +102,8 @@ def __init__(self, metadata_url=None): super(MetadataUrlCredentials, self).__init__() assert aiohttp is not None, "Install aiohttp library to use metadata credentials provider" self._metadata_url = auth.DEFAULT_METADATA_URL if metadata_url is None else metadata_url - self._tp.submit(self._refresh) self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud." + self._tp.submit(self._refresh_token) async def _make_token_request(self): timeout = aiohttp.ClientTimeout(total=2) diff --git a/ydb/credentials.py b/ydb/credentials.py index ab721d0b..c7e1cec2 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -4,8 +4,8 @@ from . import tracing, issues, connection from . import settings as settings_impl -import threading from concurrent import futures +import threading import logging import time @@ -22,6 +22,32 @@ logger = logging.getLogger(__name__) +class AtMostOneExecution(object): + def __init__(self): + self._can_schedule = True + self._lock = threading.Lock() + self._tp = futures.ThreadPoolExecutor(1) + + def wrapped_execution(self, callback): + try: + callback() + except Exception: + pass + + finally: + self.cleanup() + + def submit(self, callback): + with self._lock: + if self._can_schedule: + self._tp.submit(self.wrapped_execution, callback) + self._can_schedule = False + + def cleanup(self): + with self._lock: + self._can_schedule = True + + class AbstractCredentials(abc.ABC): """ An abstract class that provides auth metadata @@ -49,130 +75,76 @@ def _update_driver_config(self, driver_config): pass -class OneToManyValue(object): - def __init__(self): - self._value = None - self._condition = threading.Condition() - - def consume(self, timeout=3): - with self._condition: - if self._value is None: - self._condition.wait(timeout=timeout) - return self._value - - def update(self, n_value): - with self._condition: - prev_value = self._value - self._value = n_value - if prev_value is None: - self._condition.notify_all() - - -class AtMostOneExecution(object): - def __init__(self): - self._can_schedule = True - self._lock = threading.Lock() - self._tp = futures.ThreadPoolExecutor(1) - - def wrapped_execution(self, callback): - try: - callback() - except Exception: - pass - - finally: - self.cleanup() - - def submit(self, callback): - with self._lock: - if self._can_schedule: - self._tp.submit(self.wrapped_execution, callback) - self._can_schedule = False - - def cleanup(self): - with self._lock: - self._can_schedule = True - - class AbstractExpiringTokenCredentials(Credentials): def __init__(self, tracer=None): super(AbstractExpiringTokenCredentials, self).__init__(tracer) - self._expires_in = 0 self._refresh_in = 0 - self._hour = 60 * 60 - self._cached_token = OneToManyValue() - self._tp = AtMostOneExecution() + self._expires_in = 0 + self._cached_token = None + self._token_lock = threading.Lock() self.logger = logger.getChild(self.__class__.__name__) self.last_error = None self.extra_error_message = "" + self._hour = 60 * 60 + self._tp = AtMostOneExecution() + self._time_shift_protection_seconds = 30 @abc.abstractmethod def _make_token_request(self): pass - def _log_refresh_start(self, current_time): - self.logger.debug("Start refresh token from metadata") - if current_time > self._refresh_in: - self.logger.info( - "Cached token reached refresh_in deadline, current time %s, deadline %s", - current_time, - self._refresh_in, - ) + def _is_token_valid(self): + return self._cached_token is not None and time.time() <= self._expires_in - if current_time > self._expires_in and self._expires_in > 0: - self.logger.error( - "Cached token reached expires_in deadline, current time %s, deadline %s", - current_time, - self._expires_in, - ) + def _should_refresh(self): + return time.time() >= self._refresh_in - def _update_expiration_info(self, auth_metadata): - self._expires_in = time.time() + min(self._hour, auth_metadata["expires_in"] / 2) - self._refresh_in = time.time() + min(self._hour / 2, auth_metadata["expires_in"] / 4) + def _update_token_info(self, token_response, current_time): + self._refresh_in = current_time + min(self._hour / 2, token_response["expires_in"] / 10) + self._expires_in = current_time + token_response["expires_in"] - self._time_shift_protection_seconds + self._cached_token = token_response["access_token"] - def _refresh(self): + def _refresh_token(self, should_raise=False): current_time = time.time() - self._log_refresh_start(current_time) + try: + self.logger.debug("Refreshing token, current_time: %s, expires_in: %s", current_time, self._expires_in) + token_response = self._make_token_request() - self._cached_token.update(token_response["access_token"]) - self._update_expiration_info(token_response) - self.logger.info( - "Token refresh successful. current_time %s, refresh_in %s", - current_time, - self._refresh_in, - ) + self._update_token_info(token_response, current_time) - except (KeyboardInterrupt, SystemExit): - return + self.logger.info("Token refreshed successfully, expires_in: %s", self._expires_in) + self.last_error = None except Exception as e: self.last_error = str(e) - time.sleep(1) - self._tp.submit(self._refresh) + self.logger.error("Failed to refresh token: %s", e) + if should_raise: + raise issues.ConnectionError( + "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) + ) @property @tracing.with_trace() def token(self): - current_time = time.time() - if current_time > self._refresh_in: + if self._is_token_valid(): + if self._should_refresh(): + tracing.trace(self.tracer, {"refresh": True}) + self._tp.submit(self._refresh_token) + + tracing.trace(self.tracer, {"consumed": True}) + return self._cached_token + + with self._token_lock: + if self._is_token_valid(): + tracing.trace(self.tracer, {"consumed": True}) + return self._cached_token + tracing.trace(self.tracer, {"refresh": True}) - self._tp.submit(self._refresh) - cached_token = self._cached_token.consume(timeout=3) + self._refresh_token(should_raise=True) + tracing.trace(self.tracer, {"consumed": True}) - if cached_token is None: - if self.last_error is None: - raise issues.ConnectionError( - "%s: timeout occurred while waiting for token.\n%s" - % ( - self.__class__.__name__, - self.extra_error_message, - ) - ) - raise issues.ConnectionError( - "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) - ) - return cached_token + return self._cached_token def auth_metadata(self): return [(YDB_AUTH_TICKET_HEADER, self.token)] diff --git a/ydb/iam/auth.py b/ydb/iam/auth.py index 688deded..21ce9529 100644 --- a/ydb/iam/auth.py +++ b/ydb/iam/auth.py @@ -185,7 +185,6 @@ def __init__(self, metadata_url=None, tracer=None): "Check that metadata service configured properly since we failed to fetch it from metadata_url." ) self._metadata_url = DEFAULT_METADATA_URL if metadata_url is None else metadata_url - self._tp.submit(self._refresh) @tracing.with_trace() def _make_token_request(self):