Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions tests/aio/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import tempfile
import os
import json
import asyncio
from unittest.mock import patch, AsyncMock, MagicMock

import tests.auth.test_credentials
import tests.oauth2_token_exchange
Expand Down Expand Up @@ -112,3 +114,152 @@ def serve(s):
except Exception:
os.remove(cfg_file_name)
raise


@pytest.mark.asyncio
async def test_token_lazy_refresh():
credentials = ServiceAccountCredentialsForTest(
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
tests.auth.test_credentials.ACCESS_KEY_ID,
tests.auth.test_credentials.PRIVATE_KEY,
"localhost:0",
)

credentials._tp.submit = MagicMock()

mock_response = {"access_token": "token_v1", "expires_in": 3600}
credentials._make_token_request = AsyncMock(return_value=mock_response)

with patch("time.time") as mock_time:
mock_time.return_value = 1000

token1 = await credentials.token()
assert token1 == "token_v1"
assert credentials._make_token_request.call_count == 1

token2 = await credentials.token()
assert token2 == "token_v1"
assert credentials._make_token_request.call_count == 1

mock_time.return_value = 1000 + 3600 - 30 + 1
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}

token3 = await credentials.token()
assert token3 == "token_v2"
assert credentials._make_token_request.call_count == 2


@pytest.mark.asyncio
async def test_token_double_check_locking():
credentials = ServiceAccountCredentialsForTest(
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
tests.auth.test_credentials.ACCESS_KEY_ID,
tests.auth.test_credentials.PRIVATE_KEY,
"localhost:0",
)

credentials._tp.submit = MagicMock()

call_count = 0

async def mock_make_request():
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01)
return {"access_token": f"token_v{call_count}", "expires_in": 3600}

credentials._make_token_request = mock_make_request

with patch("time.time") as mock_time:
mock_time.return_value = 1000

tasks = [credentials.token() for _ in range(10)]
results = await asyncio.gather(*tasks)

assert len(set(results)) == 1
assert call_count == 1


@pytest.mark.asyncio
async def test_token_expiration_calculation():
credentials = ServiceAccountCredentialsForTest(
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
tests.auth.test_credentials.ACCESS_KEY_ID,
tests.auth.test_credentials.PRIVATE_KEY,
"localhost:0",
)

credentials._tp.submit = MagicMock()

with patch("time.time") as mock_time:
mock_time.return_value = 1000

credentials._make_token_request = AsyncMock(return_value={"access_token": "token", "expires_in": 3600})

await credentials.token()

expected_expires = 1000 + 3600 - 30
assert credentials._expires_in == expected_expires


@pytest.mark.asyncio
async def test_token_refresh_error_handling():
credentials = ServiceAccountCredentialsForTest(
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
tests.auth.test_credentials.ACCESS_KEY_ID,
tests.auth.test_credentials.PRIVATE_KEY,
"localhost:0",
)

credentials._tp.submit = MagicMock()

credentials._make_token_request = AsyncMock(side_effect=Exception("Network error"))

with pytest.raises(Exception) as exc_info:
await credentials.token()

assert "Network error" in str(exc_info.value)
assert credentials.last_error == "Network error"


@pytest.mark.asyncio
async def test_hybrid_background_and_sync_refresh():
credentials = ServiceAccountCredentialsForTest(
tests.auth.test_credentials.SERVICE_ACCOUNT_ID,
tests.auth.test_credentials.ACCESS_KEY_ID,
tests.auth.test_credentials.PRIVATE_KEY,
"localhost:0",
)

call_count = 0
background_calls = []

async def mock_make_request():
nonlocal call_count
call_count += 1
return {"access_token": f"token_v{call_count}", "expires_in": 3600}

def mock_submit(callback):
background_calls.append(callback)

credentials._make_token_request = mock_make_request
credentials._tp.submit = mock_submit

with patch("time.time") as mock_time:
mock_time.return_value = 1000

token1 = await credentials.token()
assert token1 == "token_v1"
assert call_count == 1
assert len(background_calls) == 0

mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1
token2 = await credentials.token()
assert token2 == "token_v1"
assert call_count == 1
assert len(background_calls) == 1

mock_time.return_value = 1000 + 3600 - 30 + 1
token3 = await credentials.token()
assert token3 == "token_v2"
assert call_count == 2
129 changes: 129 additions & 0 deletions tests/auth/test_static_credentials.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import ydb
from unittest.mock import patch, MagicMock


USERNAME = "root"
Expand Down Expand Up @@ -45,3 +46,131 @@ def test_static_credentials_wrong_creds(endpoint, database):
with pytest.raises(ydb.ConnectionFailure):
with ydb.Driver(driver_config=driver_config) as driver:
driver.wait(5, fail_fast=True)


def test_token_lazy_refresh():
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)

credentials._tp.submit = MagicMock()

mock_response = {"access_token": "token_v1", "expires_in": 3600}
credentials._make_token_request = MagicMock(return_value=mock_response)

with patch("time.time") as mock_time:
mock_time.return_value = 1000

token1 = credentials.token
assert token1 == "token_v1"
assert credentials._make_token_request.call_count == 1

token2 = credentials.token
assert token2 == "token_v1"
assert credentials._make_token_request.call_count == 1

mock_time.return_value = 1000 + 3600 - 30 + 1
credentials._make_token_request.return_value = {"access_token": "token_v2", "expires_in": 3600}

token3 = credentials.token
assert token3 == "token_v2"
assert credentials._make_token_request.call_count == 2


def test_token_double_check_locking():
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
credentials._tp.submit = MagicMock()

call_count = 0

def mock_make_request():
nonlocal call_count
call_count += 1
return {"access_token": f"token_v{call_count}", "expires_in": 3600}

credentials._make_token_request = mock_make_request

with patch("time.time") as mock_time:
mock_time.return_value = 1000

import threading

results = []

def get_token():
results.append(credentials.token)

threads = [threading.Thread(target=get_token) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()

assert len(set(results)) == 1
assert call_count == 1


def test_token_expiration_calculation():
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)

credentials._tp.submit = MagicMock()

with patch("time.time") as mock_time:
mock_time.return_value = 1000

credentials._make_token_request = MagicMock(return_value={"access_token": "token", "expires_in": 3600})

credentials.token

expected_expires = 1000 + 3600 - 30
assert credentials._expires_in == expected_expires


def test_token_refresh_error_handling():
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)
credentials._tp.submit = MagicMock()
credentials._make_token_request = MagicMock(side_effect=Exception("Network error"))

with patch("time.time") as mock_time:
mock_time.return_value = 1000 + 3600

with pytest.raises(ydb.ConnectionError) as exc_info:
credentials.token

assert "Network error" in str(exc_info.value)
assert credentials.last_error == "Network error"


def test_hybrid_background_and_sync_refresh():
credentials = ydb.StaticCredentials.from_user_password(USERNAME, PASSWORD)

call_count = 0
background_calls = []

def mock_make_request():
nonlocal call_count
call_count += 1
return {"access_token": f"token_v{call_count}", "expires_in": 3600}

def mock_submit(callback):
background_calls.append(callback)

credentials._make_token_request = mock_make_request
credentials._tp.submit = mock_submit

with patch("time.time") as mock_time:
mock_time.return_value = 1000

token1 = credentials.token
assert token1 == "token_v1"
assert call_count == 1
assert len(background_calls) == 0

mock_time.return_value = 1000 + min(1800, 3600 / 10) + 1
token2 = credentials.token
assert token2 == "token_v1"
assert call_count == 1
assert len(background_calls) == 1

mock_time.return_value = 1000 + 3600 - 30 + 1
token3 = credentials.token
assert token3 == "token_v2"
assert call_count == 2
Loading
Loading