From 3b33d1a8785257c69b58ed483fded795eae00a2c Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 29 Oct 2025 13:44:57 +0300 Subject: [PATCH 1/2] Refactor auth token refresh logic --- tests/aio/test_credentials.py | 100 ++++++++++++++++++++ tests/auth/test_static_credentials.py | 84 +++++++++++++++++ ydb/aio/credentials.py | 104 +++++---------------- ydb/aio/iam.py | 1 - ydb/credentials.py | 130 +++++++------------------- ydb/iam/auth.py | 1 - 6 files changed, 239 insertions(+), 181 deletions(-) diff --git a/tests/aio/test_credentials.py b/tests/aio/test_credentials.py index 5000541c..796d895b 100644 --- a/tests/aio/test_credentials.py +++ b/tests/aio/test_credentials.py @@ -5,6 +5,8 @@ import tempfile import os import json +import asyncio +from unittest.mock import patch, AsyncMock import tests.auth.test_credentials import tests.oauth2_token_exchange @@ -112,3 +114,101 @@ def serve(s): except Exception: os.remove(cfg_file_name) raise + + +@pytest.mark.asyncio +async def test_token_lazy_refresh(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + mock_response = {"access_token": "token_v1", "expires_in": 3600} + credentials._make_token_request = AsyncMock(return_value=mock_response) + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + token1 = await credentials.token() + assert token1 == "token_v1" + assert credentials._make_token_request.call_count == 1 + + token2 = await credentials.token() + assert token2 == "token_v1" + assert credentials._make_token_request.call_count == 1 + + mock_time.return_value = 2000 + credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600} + + token3 = await credentials.token() + assert token3 == "token_v2" + assert credentials._make_token_request.call_count == 2 + + +@pytest.mark.asyncio +async def test_token_double_check_locking(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + call_count = 0 + + async def mock_make_request(): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) + return {"access_token": f"token_v{call_count}", "expires_in": 3600} + + credentials._make_token_request = mock_make_request + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + tasks = [credentials.token() for _ in range(10)] + results = await asyncio.gather(*tasks) + + assert len(set(results)) == 1 + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_token_expiration_calculation(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + credentials._make_token_request = AsyncMock(return_value={"access_token": "token", "expires_in": 3600}) + + await credentials.token() + + expected_expires = 1000 + min(1800, 3600 / 4) + assert credentials._expires_in == expected_expires + + +@pytest.mark.asyncio +async def test_token_refresh_error_handling(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + credentials._make_token_request = AsyncMock(side_effect=Exception("Network error")) + + with pytest.raises(Exception) as exc_info: + await credentials.token() + + assert "Network error" in str(exc_info.value) + assert credentials.last_error == "Network error" diff --git a/tests/auth/test_static_credentials.py b/tests/auth/test_static_credentials.py index a9239f2a..19e2f9a4 100644 --- a/tests/auth/test_static_credentials.py +++ b/tests/auth/test_static_credentials.py @@ -1,5 +1,6 @@ import pytest import ydb +from unittest.mock import patch, MagicMock USERNAME = "root" @@ -45,3 +46,86 @@ def test_static_credentials_wrong_creds(endpoint, database): with pytest.raises(ydb.ConnectionFailure): with ydb.Driver(driver_config=driver_config) as driver: driver.wait(5, fail_fast=True) + + +def test_token_lazy_refresh(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + + mock_response = {"access_token": "token_v1", "expires_in": 3600} + credentials._make_token_request = MagicMock(return_value=mock_response) + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + token1 = credentials.token + assert token1 == "token_v1" + assert credentials._make_token_request.call_count == 1 + + token2 = credentials.token + assert token2 == "token_v1" + assert credentials._make_token_request.call_count == 1 + + mock_time.return_value = 2000 + credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600} + + token3 = credentials.token + assert token3 == "token_v2" + assert credentials._make_token_request.call_count == 2 + + +def test_token_double_check_locking(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + + call_count = 0 + + def mock_make_request(): + nonlocal call_count + call_count += 1 + return {"access_token": f"token_v{call_count}", "expires_in": 3600} + + credentials._make_token_request = mock_make_request + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + import threading + + results = [] + + def get_token(): + results.append(credentials.token) + + threads = [threading.Thread(target=get_token) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(set(results)) == 1 + assert call_count == 1 + + +def test_token_expiration_calculation(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + credentials._make_token_request = MagicMock(return_value={"access_token": "token", "expires_in": 3600}) + + credentials.token + + expected_expires = 1000 + min(1800, 3600 / 4) + assert credentials._expires_in == expected_expires + + +def test_token_refresh_error_handling(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + + credentials._make_token_request = MagicMock(side_effect=Exception("Network error")) + + with pytest.raises(ydb.ConnectionError) as exc_info: + credentials.token + + assert "Network error" in str(exc_info.value) + assert credentials.last_error == "Network error" diff --git a/ydb/aio/credentials.py b/ydb/aio/credentials.py index 03c96a37..2665efed 100644 --- a/ydb/aio/credentials.py +++ b/ydb/aio/credentials.py @@ -10,57 +10,10 @@ YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket" -class _OneToManyValue(object): - def __init__(self): - self._value = None - self._condition = asyncio.Condition() - - async def consume(self, timeout=3): - async with self._condition: - if self._value is None: - try: - await asyncio.wait_for(self._condition.wait(), timeout=timeout) - except Exception: - return self._value - return self._value - - async def update(self, n_value): - async with self._condition: - prev_value = self._value - self._value = n_value - if prev_value is None: - self._condition.notify_all() - - -class _AtMostOneExecution(object): - def __init__(self): - self._can_schedule = True - self._lock = asyncio.Lock() # Lock to guarantee only one execution - - async def _wrapped_execution(self, callback): - await self._lock.acquire() - try: - res = callback() - if asyncio.iscoroutine(res): - await res - except Exception: - pass - - finally: - self._lock.release() - self._can_schedule = True - - def submit(self, callback): - if self._can_schedule: - self._can_schedule = False - asyncio.ensure_future(self._wrapped_execution(callback)) - - class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self): super(AbstractExpiringTokenCredentials, self).__init__() - self._tp = _AtMostOneExecution() - self._cached_token = _OneToManyValue() + self._token_lock = asyncio.Lock() @abc.abstractmethod async def _make_token_request(self): @@ -72,51 +25,40 @@ async def get_auth_token(self) -> str: return token return "" - async def _refresh(self): + async def _refresh_token(self): current_time = time.time() - self._log_refresh_start(current_time) try: - auth_metadata = await self._make_token_request() - await self._cached_token.update(auth_metadata["access_token"]) - self._update_expiration_info(auth_metadata) - self.logger.info( - "Token refresh successful. current_time %s, refresh_in %s", + self.logger.debug( + "Refreshing token async, current_time: %s, expires_in: %s", current_time, - self._refresh_in, + self._expires_in, ) - except (KeyboardInterrupt, SystemExit): - return + token_response = await self._make_token_request() + self._update_token_info(token_response, current_time) - except Exception as e: - self.last_error = str(e) - await asyncio.sleep(1) - self._tp.submit(self._refresh) + self.logger.info("Token refreshed successfully async, expires_in: %s", self._expires_in) + self.last_error = None - except BaseException as e: + except Exception as e: self.last_error = str(e) - raise - - async def token(self): - current_time = time.time() - if current_time > self._refresh_in: - self._tp.submit(self._refresh) - - cached_token = await self._cached_token.consume(timeout=3) - if cached_token is None: - if self.last_error is None: - raise issues.ConnectionError( - "%s: timeout occurred while waiting for token.\n%s" - % ( - self.__class__.__name__, - self.extra_error_message, - ) - ) + self.logger.error("Failed to refresh token async: %s", e) raise issues.ConnectionError( "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) ) - return cached_token + + async def token(self): + if self._is_token_valid(): + return self._cached_token + + async with self._token_lock: + if self._is_token_valid(): + return self._cached_token + + await self._refresh_token() + + return self._cached_token async def auth_metadata(self): return [(credentials.YDB_AUTH_TICKET_HEADER, await self.token())] diff --git a/ydb/aio/iam.py b/ydb/aio/iam.py index 5a2a29f6..27417743 100644 --- a/ydb/aio/iam.py +++ b/ydb/aio/iam.py @@ -102,7 +102,6 @@ def __init__(self, metadata_url=None): super(MetadataUrlCredentials, self).__init__() assert aiohttp is not None, "Install aiohttp library to use metadata credentials provider" self._metadata_url = auth.DEFAULT_METADATA_URL if metadata_url is None else metadata_url - self._tp.submit(self._refresh) self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud." async def _make_token_request(self): diff --git a/ydb/credentials.py b/ydb/credentials.py index ab721d0b..1adca547 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -5,7 +5,6 @@ from . import tracing, issues, connection from . import settings as settings_impl import threading -from concurrent import futures import logging import time @@ -49,130 +48,65 @@ def _update_driver_config(self, driver_config): pass -class OneToManyValue(object): - def __init__(self): - self._value = None - self._condition = threading.Condition() - - def consume(self, timeout=3): - with self._condition: - if self._value is None: - self._condition.wait(timeout=timeout) - return self._value - - def update(self, n_value): - with self._condition: - prev_value = self._value - self._value = n_value - if prev_value is None: - self._condition.notify_all() - - -class AtMostOneExecution(object): - def __init__(self): - self._can_schedule = True - self._lock = threading.Lock() - self._tp = futures.ThreadPoolExecutor(1) - - def wrapped_execution(self, callback): - try: - callback() - except Exception: - pass - - finally: - self.cleanup() - - def submit(self, callback): - with self._lock: - if self._can_schedule: - self._tp.submit(self.wrapped_execution, callback) - self._can_schedule = False - - def cleanup(self): - with self._lock: - self._can_schedule = True - - class AbstractExpiringTokenCredentials(Credentials): def __init__(self, tracer=None): super(AbstractExpiringTokenCredentials, self).__init__(tracer) self._expires_in = 0 - self._refresh_in = 0 - self._hour = 60 * 60 - self._cached_token = OneToManyValue() - self._tp = AtMostOneExecution() + self._cached_token = None + self._token_lock = threading.Lock() self.logger = logger.getChild(self.__class__.__name__) self.last_error = None self.extra_error_message = "" + self._hour = 60 * 60 @abc.abstractmethod def _make_token_request(self): pass - def _log_refresh_start(self, current_time): - self.logger.debug("Start refresh token from metadata") - if current_time > self._refresh_in: - self.logger.info( - "Cached token reached refresh_in deadline, current time %s, deadline %s", - current_time, - self._refresh_in, - ) - - if current_time > self._expires_in and self._expires_in > 0: - self.logger.error( - "Cached token reached expires_in deadline, current time %s, deadline %s", - current_time, - self._expires_in, - ) + def _is_token_valid(self): + current_time = time.time() + return self._cached_token is not None and current_time <= self._expires_in - def _update_expiration_info(self, auth_metadata): - self._expires_in = time.time() + min(self._hour, auth_metadata["expires_in"] / 2) - self._refresh_in = time.time() + min(self._hour / 2, auth_metadata["expires_in"] / 4) + def _update_token_info(self, token_response, current_time): + self._expires_in = current_time + min(self._hour / 2, token_response["expires_in"] / 4) + self._cached_token = token_response["access_token"] - def _refresh(self): + def _refresh_token(self): current_time = time.time() - self._log_refresh_start(current_time) + try: + self.logger.debug("Refreshing token, current_time: %s, expires_in: %s", current_time, self._expires_in) + token_response = self._make_token_request() - self._cached_token.update(token_response["access_token"]) - self._update_expiration_info(token_response) - self.logger.info( - "Token refresh successful. current_time %s, refresh_in %s", - current_time, - self._refresh_in, - ) + self._update_token_info(token_response, current_time) - except (KeyboardInterrupt, SystemExit): - return + self.logger.info("Token refreshed successfully, expires_in: %s", self._expires_in) + self.last_error = None except Exception as e: self.last_error = str(e) - time.sleep(1) - self._tp.submit(self._refresh) + self.logger.error("Failed to refresh token: %s", e) + raise issues.ConnectionError( + "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) + ) @property @tracing.with_trace() def token(self): - current_time = time.time() - if current_time > self._refresh_in: + if self._is_token_valid(): + tracing.trace(self.tracer, {"consumed": True}) + return self._cached_token + + with self._token_lock: + if self._is_token_valid(): + tracing.trace(self.tracer, {"consumed": True}) + return self._cached_token + tracing.trace(self.tracer, {"refresh": True}) - self._tp.submit(self._refresh) - cached_token = self._cached_token.consume(timeout=3) + self._refresh_token() + tracing.trace(self.tracer, {"consumed": True}) - if cached_token is None: - if self.last_error is None: - raise issues.ConnectionError( - "%s: timeout occurred while waiting for token.\n%s" - % ( - self.__class__.__name__, - self.extra_error_message, - ) - ) - raise issues.ConnectionError( - "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) - ) - return cached_token + return self._cached_token def auth_metadata(self): return [(YDB_AUTH_TICKET_HEADER, self.token)] diff --git a/ydb/iam/auth.py b/ydb/iam/auth.py index 688deded..21ce9529 100644 --- a/ydb/iam/auth.py +++ b/ydb/iam/auth.py @@ -185,7 +185,6 @@ def __init__(self, metadata_url=None, tracer=None): "Check that metadata service configured properly since we failed to fetch it from metadata_url." ) self._metadata_url = DEFAULT_METADATA_URL if metadata_url is None else metadata_url - self._tp.submit(self._refresh) @tracing.with_trace() def _make_token_request(self): From f1712bc0971b19cc273fb92b6a63a25b69bcd55f Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Thu, 30 Oct 2025 19:35:50 +0300 Subject: [PATCH 2/2] use hybrid approach --- tests/aio/test_credentials.py | 57 ++++++++++++++++++++++++-- tests/auth/test_static_credentials.py | 59 +++++++++++++++++++++++---- ydb/aio/credentials.py | 39 ++++++++++++++---- ydb/aio/iam.py | 1 + ydb/credentials.py | 54 ++++++++++++++++++++---- 5 files changed, 184 insertions(+), 26 deletions(-) diff --git a/tests/aio/test_credentials.py b/tests/aio/test_credentials.py index 796d895b..e73f3522 100644 --- a/tests/aio/test_credentials.py +++ b/tests/aio/test_credentials.py @@ -6,7 +6,7 @@ import os import json import asyncio -from unittest.mock import patch, AsyncMock +from unittest.mock import patch, AsyncMock, MagicMock import tests.auth.test_credentials import tests.oauth2_token_exchange @@ -125,6 +125,8 @@ async def test_token_lazy_refresh(): "localhost:0", ) + credentials._tp.submit = MagicMock() + mock_response = {"access_token": "token_v1", "expires_in": 3600} credentials._make_token_request = AsyncMock(return_value=mock_response) @@ -139,7 +141,7 @@ async def test_token_lazy_refresh(): assert token2 == "token_v1" assert credentials._make_token_request.call_count == 1 - mock_time.return_value = 2000 + mock_time.return_value = 1000 + 3600 - 30 + 1 credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600} token3 = await credentials.token() @@ -156,6 +158,8 @@ async def test_token_double_check_locking(): "localhost:0", ) + credentials._tp.submit = MagicMock() + call_count = 0 async def mock_make_request(): @@ -185,6 +189,8 @@ async def test_token_expiration_calculation(): "localhost:0", ) + credentials._tp.submit = MagicMock() + with patch("time.time") as mock_time: mock_time.return_value = 1000 @@ -192,7 +198,7 @@ async def test_token_expiration_calculation(): await credentials.token() - expected_expires = 1000 + min(1800, 3600 / 4) + expected_expires = 1000 + 3600 - 30 assert credentials._expires_in == expected_expires @@ -205,6 +211,8 @@ async def test_token_refresh_error_handling(): "localhost:0", ) + credentials._tp.submit = MagicMock() + credentials._make_token_request = AsyncMock(side_effect=Exception("Network error")) with pytest.raises(Exception) as exc_info: @@ -212,3 +220,46 @@ async def test_token_refresh_error_handling(): assert "Network error" in str(exc_info.value) assert credentials.last_error == "Network error" + + +@pytest.mark.asyncio +async def test_hybrid_background_and_sync_refresh(): + credentials = ServiceAccountCredentialsForTest( + tests.auth.test_credentials.SERVICE_ACCOUNT_ID, + tests.auth.test_credentials.ACCESS_KEY_ID, + tests.auth.test_credentials.PRIVATE_KEY, + "localhost:0", + ) + + call_count = 0 + background_calls = [] + + async def mock_make_request(): + nonlocal call_count + call_count += 1 + return {"access_token": f"token_v{call_count}", "expires_in": 3600} + + def mock_submit(callback): + background_calls.append(callback) + + credentials._make_token_request = mock_make_request + credentials._tp.submit = mock_submit + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + token1 = await credentials.token() + assert token1 == "token_v1" + assert call_count == 1 + assert len(background_calls) == 0 + + mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1 + token2 = await credentials.token() + assert token2 == "token_v1" + assert call_count == 1 + assert len(background_calls) == 1 + + mock_time.return_value = 1000 + 3600 - 30 + 1 + token3 = await credentials.token() + assert token3 == "token_v2" + assert call_count == 2 diff --git a/tests/auth/test_static_credentials.py b/tests/auth/test_static_credentials.py index 19e2f9a4..1e2938c9 100644 --- a/tests/auth/test_static_credentials.py +++ b/tests/auth/test_static_credentials.py @@ -51,6 +51,8 @@ def test_static_credentials_wrong_creds(endpoint, database): def test_token_lazy_refresh(): credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + credentials._tp.submit = MagicMock() + mock_response = {"access_token": "token_v1", "expires_in": 3600} credentials._make_token_request = MagicMock(return_value=mock_response) @@ -65,7 +67,7 @@ def test_token_lazy_refresh(): assert token2 == "token_v1" assert credentials._make_token_request.call_count == 1 - mock_time.return_value = 2000 + mock_time.return_value = 1000 + 3600 - 30 + 1 credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600} token3 = credentials.token @@ -75,6 +77,7 @@ def test_token_lazy_refresh(): def test_token_double_check_locking(): credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + credentials._tp.submit = MagicMock() call_count = 0 @@ -108,6 +111,8 @@ def get_token(): def test_token_expiration_calculation(): credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + credentials._tp.submit = MagicMock() + with patch("time.time") as mock_time: mock_time.return_value = 1000 @@ -115,17 +120,57 @@ def test_token_expiration_calculation(): credentials.token - expected_expires = 1000 + min(1800, 3600 / 4) + expected_expires = 1000 + 3600 - 30 assert credentials._expires_in == expected_expires def test_token_refresh_error_handling(): credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) - + credentials._tp.submit = MagicMock() credentials._make_token_request = MagicMock(side_effect=Exception("Network error")) - with pytest.raises(ydb.ConnectionError) as exc_info: - credentials.token + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + 3600 + + with pytest.raises(ydb.ConnectionError) as exc_info: + credentials.token + + assert "Network error" in str(exc_info.value) + assert credentials.last_error == "Network error" + + +def test_hybrid_background_and_sync_refresh(): + credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD) + + call_count = 0 + background_calls = [] + + def mock_make_request(): + nonlocal call_count + call_count += 1 + return {"access_token": f"token_v{call_count}", "expires_in": 3600} - assert "Network error" in str(exc_info.value) - assert credentials.last_error == "Network error" + def mock_submit(callback): + background_calls.append(callback) + + credentials._make_token_request = mock_make_request + credentials._tp.submit = mock_submit + + with patch("time.time") as mock_time: + mock_time.return_value = 1000 + + token1 = credentials.token + assert token1 == "token_v1" + assert call_count == 1 + assert len(background_calls) == 0 + + mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1 + token2 = credentials.token + assert token2 == "token_v1" + assert call_count == 1 + assert len(background_calls) == 1 + + mock_time.return_value = 1000 + 3600 - 30 + 1 + token3 = credentials.token + assert token3 == "token_v2" + assert call_count == 2 diff --git a/ydb/aio/credentials.py b/ydb/aio/credentials.py index 2665efed..6a1d6333 100644 --- a/ydb/aio/credentials.py +++ b/ydb/aio/credentials.py @@ -10,10 +10,31 @@ YDB_AUTH_TICKET_HEADER = "x-ydb-auth-ticket" +class AtMostOneExecution(object): + def __init__(self): + self._can_schedule = True + self._lock = asyncio.Lock() + + async def wrapped_execution(self, callback): + async with self._lock: + try: + await callback() + except Exception: + pass + finally: + self._can_schedule = True + + def submit(self, callback): + if self._can_schedule: + self._can_schedule = False + asyncio.create_task(self.wrapped_execution(callback)) + + class AbstractExpiringTokenCredentials(credentials.AbstractExpiringTokenCredentials): def __init__(self): super(AbstractExpiringTokenCredentials, self).__init__() self._token_lock = asyncio.Lock() + self._tp = AtMostOneExecution() @abc.abstractmethod async def _make_token_request(self): @@ -25,14 +46,12 @@ async def get_auth_token(self) -> str: return token return "" - async def _refresh_token(self): + async def _refresh_token(self, should_raise=False): current_time = time.time() try: self.logger.debug( - "Refreshing token async, current_time: %s, expires_in: %s", - current_time, - self._expires_in, + "Refreshing token async, current_time: %s, expires_in: %s", current_time, self._expires_in ) token_response = await self._make_token_request() @@ -44,19 +63,23 @@ async def _refresh_token(self): except Exception as e: self.last_error = str(e) self.logger.error("Failed to refresh token async: %s", e) - raise issues.ConnectionError( - "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) - ) + if should_raise: + raise issues.ConnectionError( + "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) + ) async def token(self): if self._is_token_valid(): + if self._should_refresh(): + self._tp.submit(self._refresh_token) + return self._cached_token async with self._token_lock: if self._is_token_valid(): return self._cached_token - await self._refresh_token() + await self._refresh_token(should_raise=True) return self._cached_token diff --git a/ydb/aio/iam.py b/ydb/aio/iam.py index 27417743..6c7f762c 100644 --- a/ydb/aio/iam.py +++ b/ydb/aio/iam.py @@ -103,6 +103,7 @@ def __init__(self, metadata_url=None): assert aiohttp is not None, "Install aiohttp library to use metadata credentials provider" self._metadata_url = auth.DEFAULT_METADATA_URL if metadata_url is None else metadata_url self.extra_error_message = "Check that metadata service configured properly and application deployed in VM or function at Yandex.Cloud." + self._tp.submit(self._refresh_token) async def _make_token_request(self): timeout = aiohttp.ClientTimeout(total=2) diff --git a/ydb/credentials.py b/ydb/credentials.py index 1adca547..c7e1cec2 100644 --- a/ydb/credentials.py +++ b/ydb/credentials.py @@ -4,6 +4,7 @@ from . import tracing, issues, connection from . import settings as settings_impl +from concurrent import futures import threading import logging import time @@ -21,6 +22,32 @@ logger = logging.getLogger(__name__) +class AtMostOneExecution(object): + def __init__(self): + self._can_schedule = True + self._lock = threading.Lock() + self._tp = futures.ThreadPoolExecutor(1) + + def wrapped_execution(self, callback): + try: + callback() + except Exception: + pass + + finally: + self.cleanup() + + def submit(self, callback): + with self._lock: + if self._can_schedule: + self._tp.submit(self.wrapped_execution, callback) + self._can_schedule = False + + def cleanup(self): + with self._lock: + self._can_schedule = True + + class AbstractCredentials(abc.ABC): """ An abstract class that provides auth metadata @@ -51,6 +78,7 @@ def _update_driver_config(self, driver_config): class AbstractExpiringTokenCredentials(Credentials): def __init__(self, tracer=None): super(AbstractExpiringTokenCredentials, self).__init__(tracer) + self._refresh_in = 0 self._expires_in = 0 self._cached_token = None self._token_lock = threading.Lock() @@ -58,20 +86,25 @@ def __init__(self, tracer=None): self.last_error = None self.extra_error_message = "" self._hour = 60 * 60 + self._tp = AtMostOneExecution() + self._time_shift_protection_seconds = 30 @abc.abstractmethod def _make_token_request(self): pass def _is_token_valid(self): - current_time = time.time() - return self._cached_token is not None and current_time <= self._expires_in + return self._cached_token is not None and time.time() <= self._expires_in + + def _should_refresh(self): + return time.time() >= self._refresh_in def _update_token_info(self, token_response, current_time): - self._expires_in = current_time + min(self._hour / 2, token_response["expires_in"] / 4) + self._refresh_in = current_time + min(self._hour / 2, token_response["expires_in"] / 10) + self._expires_in = current_time + token_response["expires_in"] - self._time_shift_protection_seconds self._cached_token = token_response["access_token"] - def _refresh_token(self): + def _refresh_token(self, should_raise=False): current_time = time.time() try: @@ -86,14 +119,19 @@ def _refresh_token(self): except Exception as e: self.last_error = str(e) self.logger.error("Failed to refresh token: %s", e) - raise issues.ConnectionError( - "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) - ) + if should_raise: + raise issues.ConnectionError( + "%s: %s.\n%s" % (self.__class__.__name__, self.last_error, self.extra_error_message) + ) @property @tracing.with_trace() def token(self): if self._is_token_valid(): + if self._should_refresh(): + tracing.trace(self.tracer, {"refresh": True}) + self._tp.submit(self._refresh_token) + tracing.trace(self.tracer, {"consumed": True}) return self._cached_token @@ -103,7 +141,7 @@ def token(self): return self._cached_token tracing.trace(self.tracer, {"refresh": True}) - self._refresh_token() + self._refresh_token(should_raise=True) tracing.trace(self.tracer, {"consumed": True}) return self._cached_token