diff --git a/alibabacloud_credentials/provider/__init__.py b/alibabacloud_credentials/provider/__init__.py index bc33809..60eaf60 100644 --- a/alibabacloud_credentials/provider/__init__.py +++ b/alibabacloud_credentials/provider/__init__.py @@ -9,6 +9,8 @@ from .cli_profile import CLIProfileCredentialsProvider from .profile import ProfileCredentialsProvider from .default import DefaultCredentialsProvider +from .cloud_sso import CloudSSOCredentialsProvider +from .oauth import OAuthCredentialsProvider __all__ = [ 'StaticAKCredentialsProvider', @@ -21,5 +23,7 @@ 'URLCredentialsProvider', 'CLIProfileCredentialsProvider', 'ProfileCredentialsProvider', - 'DefaultCredentialsProvider' + 'DefaultCredentialsProvider', + 'CloudSSOCredentialsProvider', + 'OAuthCredentialsProvider' ] diff --git a/alibabacloud_credentials/provider/cli_profile.py b/alibabacloud_credentials/provider/cli_profile.py index c43260f..1f1450b 100644 --- a/alibabacloud_credentials/provider/cli_profile.py +++ b/alibabacloud_credentials/provider/cli_profile.py @@ -1,11 +1,35 @@ import os import json +import threading +import platform from typing import Any, Dict import aiofiles -from alibabacloud_credentials.provider import StaticAKCredentialsProvider, EcsRamRoleCredentialsProvider, \ - RamRoleArnCredentialsProvider, OIDCRoleArnCredentialsProvider, StaticSTSCredentialsProvider +# 跨平台文件锁支持 +if platform.system() == 'Windows': + # Windows平台使用msvcrt + import msvcrt + + HAS_MSVCRT = True + HAS_FCNTL = False +else: + # 其他平台尝试使用fcntl,如果不可用则不设文件锁 + HAS_MSVCRT = False + try: + import fcntl + + HAS_FCNTL = True + except ImportError: + HAS_FCNTL = False + +from .static_ak import StaticAKCredentialsProvider +from .ecs_ram_role import EcsRamRoleCredentialsProvider +from .ram_role_arn import RamRoleArnCredentialsProvider +from .oidc import OIDCRoleArnCredentialsProvider +from .static_sts import StaticSTSCredentialsProvider +from .cloud_sso import CloudSSOCredentialsProvider +from .oauth import OAuthCredentialsProvider, OAuthTokenUpdateCallback, OAuthTokenUpdateCallbackAsync from .refreshable import Credentials from alibabacloud_credentials_api import ICredentialsProvider from alibabacloud_credentials.utils import auth_constant as ac @@ -28,10 +52,15 @@ def _load_config(file_path: str) -> Any: class CLIProfileCredentialsProvider(ICredentialsProvider): def __init__(self, *, - profile_name: str = None): - self._profile_file = os.path.join(ac.HOME, ".aliyun/config.json") + profile_name: str = None, + profile_file: str = None, + allow_config_force_rewrite: bool = False): + self._profile_file = profile_file or os.path.join(ac.HOME, ".aliyun/config.json") self._profile_name = profile_name or au.environment_profile_name + self._allow_config_force_rewrite = allow_config_force_rewrite self.__innerProvider = None + # 文件锁,用于并发安全 + self._file_lock = threading.RLock() def _should_reload_credentials_provider(self) -> bool: if self.__innerProvider is None: @@ -163,6 +192,42 @@ def _get_credentials_provider(self, config: Dict, profile_name: str) -> ICredent sts_region_id=profile.get('sts_region'), enable_vpc=profile.get('enable_vpc'), ) + elif mode == "CloudSSO": + return CloudSSOCredentialsProvider( + sign_in_url=profile.get('cloud_sso_sign_in_url'), + account_id=profile.get('cloud_sso_account_id'), + access_config=profile.get('cloud_sso_access_config'), + access_token=profile.get('access_token'), + access_token_expire=profile.get('cloud_sso_access_token_expire'), + ) + elif mode == "OAuth": + # 获取 OAuth 配置 + site_type = profile.get('oauth_site_type', 'CN') + oauth_base_url_map = { + 'CN': 'https://oauth.aliyun.com', + 'INTL': 'https://oauth.alibabacloud.com' + } + sign_in_url = oauth_base_url_map.get(site_type.upper()) + if not sign_in_url: + raise CredentialException('Invalid OAuth site type, support CN or INTL') + + oauth_client_map = { + 'CN': '4038181954557748008', + 'INTL': '4103531455503354461' + } + client_id = oauth_client_map.get(site_type.upper()) + if not client_id: + raise CredentialException('Invalid OAuth site type, support CN or INTL') + + return OAuthCredentialsProvider( + client_id=client_id, + sign_in_url=sign_in_url, + access_token=profile.get('oauth_access_token'), + access_token_expire=profile.get('oauth_access_token_expire'), + refresh_token=profile.get('oauth_refresh_token'), + token_update_callback=self._get_oauth_token_update_callback(), + token_update_callback_async=self._get_oauth_token_update_callback_async(), + ) else: raise CredentialException(f"unsupported profile mode '{mode}' form cli credentials file.") @@ -170,3 +235,355 @@ def _get_credentials_provider(self, config: Dict, profile_name: str) -> ICredent def get_provider_name(self) -> str: return 'cli_profile' + + def _update_oauth_tokens(self, refresh_token: str, access_token: str, access_key: str, secret: str, + security_token: str, access_token_expire: int, sts_expire: int) -> None: + """更新 OAuth 令牌并写回配置文件""" + + with self._file_lock: + try: + # 读取现有配置 + config = _load_config(self._profile_file) + + # 找到当前 profile 并更新 OAuth 令牌 + profile_name = self._profile_name + if not profile_name: + profile_name = config.get('current') + profiles = config.get('profiles', []) + profile_tag = False + for profile in profiles: + if profile.get('name') == profile_name: + profile_tag = True + # 更新 OAuth 令牌 + profile['oauth_refresh_token'] = refresh_token + profile['oauth_access_token'] = access_token + profile['oauth_access_token_expire'] = access_token_expire + # 更新 STS 凭据 + profile['access_key_id'] = access_key + profile['access_key_secret'] = secret + profile['sts_token'] = security_token + profile['sts_expiration'] = sts_expire + break + + # 写回配置文件 + if not profile_tag: + raise CredentialException(f"unable to get profile with '{profile_name}' form cli credentials file.") + + self._write_configuration_to_file_with_lock(self._profile_file, config) + + except Exception as e: + raise CredentialException(f"failed to update OAuth tokens in config file: {e}") + + def _write_configuration_to_file(self, config_path: str, config: Dict) -> None: + """将配置写入文件,使用原子写入确保数据完整性""" + # 获取原文件权限(如果存在) + file_mode = 0o644 + if os.path.exists(config_path): + file_mode = os.stat(config_path).st_mode + + # 创建唯一临时文件 + import time + temp_file = config_path + '.tmp-' + str(int(time.time() * 1000000)) # 微秒级时间戳 + backup_file = None + + try: + # 写入临时文件 + self._write_config_file(temp_file, file_mode, config) + + # 原子性重命名,Windows下需要特殊处理 + if platform.system() == 'Windows' and self._allow_config_force_rewrite: + # Windows下需要先删除目标文件,使用备份机制确保数据安全 + if os.path.exists(config_path): + backup_file = config_path + '.backup' + # 创建备份 + import shutil + shutil.copy2(config_path, backup_file) + # 删除原文件 + os.remove(config_path) + + os.rename(temp_file, config_path) + + # 成功后删除备份 + if backup_file and os.path.exists(backup_file): + os.remove(backup_file) + + except Exception as e: + # 恢复原文件(如果存在备份) + if backup_file and os.path.exists(backup_file): + try: + if not os.path.exists(config_path): + os.rename(backup_file, config_path) + except Exception as restore_error: + raise CredentialException( + f"Failed to restore original file after write error: {restore_error}. Original error: {e}") + + raise e + + def _write_config_file(self, filename: str, file_mode: int, config: Dict) -> None: + try: + with open(filename, 'w', encoding='utf-8') as f: + json.dump(config, f, indent=4, ensure_ascii=False) + + # 设置文件权限 + os.chmod(filename, file_mode) + + except Exception as e: + raise CredentialException(f"Failed to write config file: {e}") + + def _write_configuration_to_file_with_lock(self, config_path: str, config: Dict) -> None: + """使用操作系统级别的文件锁写入配置文件""" + # 获取原文件权限(如果存在) + file_mode = 0o644 + if os.path.exists(config_path): + file_mode = os.stat(config_path).st_mode + + backup_file = None + + try: + # 确保文件存在 + if not os.path.exists(config_path): + # 创建空文件 + with open(config_path, 'w') as f: + json.dump({}, f) + + # 在获取文件锁之前创建备份(Windows下需要) + if platform.system() == 'Windows' and self._allow_config_force_rewrite and os.path.exists(config_path): + backup_file = config_path + '.backup' + import shutil + shutil.copy2(config_path, backup_file) + + # 打开文件用于锁定 + with open(config_path, 'r+') as f: + # 获取独占锁(阻塞其他进程) + if HAS_MSVCRT: + # Windows使用msvcrt + msvcrt.locking(f.fileno(), msvcrt.LK_NBLCK, 1) + elif HAS_FCNTL: + # Unix/Linux使用fcntl + fcntl.flock(f.fileno(), fcntl.LOCK_EX) + # 如果都不支持,则跳过文件锁(仅进程内保护) + + try: + if platform.system() == 'Windows' and self._allow_config_force_rewrite: + # Windows下直接在锁定的文件中写入 + f.seek(0) + f.truncate() # 清空文件内容 + json.dump(config, f, indent=4, ensure_ascii=False) + f.flush() + else: + # 其他环境使用临时文件+rename(在文件锁内部进行原子操作) + import time + temp_file = config_path + '.tmp-' + str(int(time.time() * 1000000)) + self._write_config_file(temp_file, file_mode, config) + # 在文件锁内部进行原子重命名 + os.rename(temp_file, config_path) + + finally: + # 释放锁 + try: + if HAS_MSVCRT: + msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1) + elif HAS_FCNTL: + fcntl.flock(f.fileno(), fcntl.LOCK_UN) + except (OSError, PermissionError): + # 在Windows下,如果文件被重命名,文件句柄可能已经无效 + # 这种情况下锁会自动释放,所以忽略错误 + pass + + # 成功后删除备份 + if backup_file and os.path.exists(backup_file): + os.remove(backup_file) + + except Exception as e: + # 恢复原文件(如果存在备份) + if backup_file and os.path.exists(backup_file): + try: + if not os.path.exists(config_path): + os.rename(backup_file, config_path) + except Exception as restore_error: + raise CredentialException( + f"Failed to restore original file after write error: {restore_error}. Original error: {e}") + + raise e + + def _get_oauth_token_update_callback(self) -> OAuthTokenUpdateCallback: + """获取 OAuth 令牌更新回调函数""" + return lambda refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire: self._update_oauth_tokens( + refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire + ) + + async def _write_configuration_to_file_async(self, config_path: str, config: Dict) -> None: + """异步将配置写入文件,使用原子写入确保数据完整性""" + # 获取原文件权限(如果存在) + file_mode = 0o644 + if os.path.exists(config_path): + file_mode = os.stat(config_path).st_mode + + # 创建唯一临时文件 + import time + temp_file = config_path + '.tmp-' + str(int(time.time() * 1000000)) # 微秒级时间戳 + backup_file = None + + try: + # 异步写入临时文件 + await self._write_config_file_async(temp_file, file_mode, config) + + # 原子性重命名,Windows下需要特殊处理 + if platform.system() == 'Windows' and self._allow_config_force_rewrite: + # Windows下需要先删除目标文件,使用备份机制确保数据安全 + if os.path.exists(config_path): + backup_file = config_path + '.backup' + # 创建备份 + import shutil + shutil.copy2(config_path, backup_file) + # 删除原文件 + os.remove(config_path) + + os.rename(temp_file, config_path) + + # 成功后删除备份 + if backup_file and os.path.exists(backup_file): + os.remove(backup_file) + + except Exception as e: + # 恢复原文件(如果存在备份) + if backup_file and os.path.exists(backup_file): + try: + if not os.path.exists(config_path): + os.rename(backup_file, config_path) + except Exception as restore_error: + raise CredentialException( + f"Failed to restore original file after write error: {restore_error}. Original error: {e}") + + raise e + + async def _write_config_file_async(self, filename: str, file_mode: int, config: Dict) -> None: + try: + async with aiofiles.open(filename, 'w', encoding='utf-8') as f: + await f.write(json.dumps(config, indent=4, ensure_ascii=False)) + + # 设置文件权限 + os.chmod(filename, file_mode) + + except Exception as e: + raise CredentialException(f"Failed to write config file: {e}") + + async def _write_configuration_to_file_with_lock_async(self, config_path: str, config: Dict) -> None: + """异步使用操作系统级别的文件锁写入配置文件""" + # 获取原文件权限(如果存在) + file_mode = 0o644 + if os.path.exists(config_path): + file_mode = os.stat(config_path).st_mode + + backup_file = None + + try: + # 确保文件存在 + if not os.path.exists(config_path): + # 创建空文件 + with open(config_path, 'w') as f: + json.dump({}, f) + + # 在获取文件锁之前创建备份(Windows下需要) + if platform.system() == 'Windows' and self._allow_config_force_rewrite and os.path.exists(config_path): + backup_file = config_path + '.backup' + import shutil + shutil.copy2(config_path, backup_file) + + # 打开文件用于锁定 + with open(config_path, 'r+') as f: + # 获取独占锁(阻塞其他进程) + if HAS_MSVCRT: + # Windows使用msvcrt + msvcrt.locking(f.fileno(), msvcrt.LK_NBLCK, 1) + elif HAS_FCNTL: + # Unix/Linux使用fcntl + fcntl.flock(f.fileno(), fcntl.LOCK_EX) + # 如果都不支持,则跳过文件锁(仅进程内保护) + + try: + if platform.system() == 'Windows' and self._allow_config_force_rewrite: + # Windows下直接在锁定的文件中写入 + f.seek(0) + f.truncate() # 清空文件内容 + json.dump(config, f, indent=4, ensure_ascii=False) + f.flush() + else: + # 其他环境使用临时文件+rename(在文件锁内部进行原子操作) + import time + temp_file = config_path + '.tmp-' + str(int(time.time() * 1000000)) + await self._write_config_file_async(temp_file, file_mode, config) + # 在文件锁内部进行原子重命名 + os.rename(temp_file, config_path) + + finally: + # 释放锁 + try: + if HAS_MSVCRT: + msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1) + elif HAS_FCNTL: + fcntl.flock(f.fileno(), fcntl.LOCK_UN) + except (OSError, PermissionError): + # 在Windows下,如果文件被重命名,文件句柄可能已经无效 + # 这种情况下锁会自动释放,所以忽略错误 + pass + + # 成功后删除备份 + if backup_file and os.path.exists(backup_file): + os.remove(backup_file) + + except Exception as e: + # 恢复原文件(如果存在备份) + if backup_file and os.path.exists(backup_file): + try: + if not os.path.exists(config_path): + os.rename(backup_file, config_path) + except Exception as restore_error: + raise CredentialException( + f"Failed to restore original file after write error: {restore_error}. Original error: {e}") + + raise e + + async def _update_oauth_tokens_async(self, refresh_token: str, access_token: str, access_key: str, secret: str, + security_token: str, access_token_expire: int, sts_expire: int) -> None: + """异步更新 OAuth 令牌并写回配置文件""" + + try: + with self._file_lock: + cfg_path = self._profile_file + conf = await _load_config_async(cfg_path) + + # 找到当前 profile 并更新 OAuth 令牌 + profile_name = self._profile_name + if not profile_name: + profile_name = conf.get('current') + profiles = conf.get('profiles', []) + profile_tag = False + for profile in profiles: + if profile.get('name') == profile_name: + profile_tag = True + # 更新 OAuth 相关字段 + profile['oauth_refresh_token'] = refresh_token + profile['oauth_access_token'] = access_token + profile['oauth_access_token_expire'] = access_token_expire + # 更新 STS 凭据 + profile['access_key_id'] = access_key + profile['access_key_secret'] = secret + profile['sts_token'] = security_token + profile['sts_expiration'] = sts_expire + break + + if not profile_tag: + raise CredentialException(f"Profile '{profile_name}' not found in config file") + + # 异步写回配置文件 + await self._write_configuration_to_file_with_lock_async(cfg_path, conf) + + except Exception as e: + raise CredentialException(f"failed to update OAuth tokens in config file: {e}") + + def _get_oauth_token_update_callback_async(self) -> OAuthTokenUpdateCallbackAsync: + """获取异步 OAuth 令牌更新回调函数""" + return lambda refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire: self._update_oauth_tokens_async( + refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire + ) diff --git a/alibabacloud_credentials/provider/cloud_sso.py b/alibabacloud_credentials/provider/cloud_sso.py new file mode 100644 index 0000000..77ff6f0 --- /dev/null +++ b/alibabacloud_credentials/provider/cloud_sso.py @@ -0,0 +1,160 @@ +import calendar +import json +import time +from urllib.parse import urlparse + +from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, RefreshCachedSupplier +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaCore +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import parameter_helper as ph +from alibabacloud_credentials.exceptions import CredentialException + + +def _get_stale_time(expiration: int) -> int: + if expiration < 0: + return int(time.mktime(time.localtime())) + 60 * 60 + return expiration - 15 * 60 + + +class CloudSSOCredentialsProvider(ICredentialsProvider): + DEFAULT_CONNECT_TIMEOUT = 5000 + DEFAULT_READ_TIMEOUT = 10000 + + def __init__(self, *, + sign_in_url: str = None, + account_id: str = None, + access_config: str = None, + access_token: str = None, + access_token_expire: int = 0, + http_options: HttpOptions = None): + + self._sign_in_url = sign_in_url + self._account_id = account_id + self._access_config = access_config + self._access_token = access_token + self._access_token_expire = access_token_expire + + if self._access_token is None or self._access_token_expire == 0 or self._access_token_expire - int( + time.mktime(time.localtime())) <= 0: + raise ValueError( + 'CloudSSO access token is empty or expired, please re-login with cli') + if self._sign_in_url is None or self._account_id is None or self._access_config is None: + raise ValueError( + 'CloudSSO sign in url or account id or access config is empty') + + self._http_options = http_options if http_options is not None else HttpOptions() + self._runtime_options = { + 'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else CloudSSOCredentialsProvider.DEFAULT_CONNECT_TIMEOUT, + 'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else CloudSSOCredentialsProvider.DEFAULT_READ_TIMEOUT, + 'httpsProxy': self._http_options.proxy + } + self._credentials_cache = RefreshCachedSupplier( + refresh_callable=self._refresh_credentials, + refresh_callable_async=self._refresh_credentials_async, + ) + + def get_credentials(self) -> Credentials: + return self._credentials_cache._sync_call() + + async def get_credentials_async(self) -> Credentials: + return await self._credentials_cache._async_call() + + def _refresh_credentials(self) -> RefreshResult[Credentials]: + r = urlparse(self._sign_in_url) + tea_request = ph.get_new_request() + tea_request.headers['host'] = r.hostname + tea_request.port = r.port + tea_request.protocol = r.scheme + tea_request.method = 'POST' + tea_request.pathname = '/cloud-credentials' + + tea_request.body = json.dumps({ + 'AccountId': self._account_id, + 'AccessConfigurationId': self._access_config, + }) + + tea_request.headers['Accept'] = 'application/json' + tea_request.headers['Content-Type'] = 'application/json' + tea_request.headers['Authorization'] = f'Bearer {self._access_token}' + + response = TeaCore.do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from sso, http_code: {response.status_code}, result: {response.body.decode("utf-8")}') + + dic = json.loads(response.body.decode('utf-8')) + if 'CloudCredential' not in dic: + raise CredentialException( + f'error retrieving credentials from sso result: {response.body.decode("utf-8")}') + + cre = dic.get('CloudCredential') + if 'AccessKeyId' not in cre or 'AccessKeySecret' not in cre or 'SecurityToken' not in cre: + raise CredentialException( + f'error retrieving credentials from sso result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(cre.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=cre.get('AccessKeyId'), + access_key_secret=cre.get('AccessKeySecret'), + security_token=cre.get('SecurityToken'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + async def _refresh_credentials_async(self) -> RefreshResult[Credentials]: + r = urlparse(self._sign_in_url) + tea_request = ph.get_new_request() + tea_request.headers['host'] = r.hostname + tea_request.port = r.port + tea_request.protocol = r.scheme + tea_request.method = 'POST' + tea_request.pathname = '/cloud-credentials' + + tea_request.body = json.dumps({ + 'AccountId': self._account_id, + 'AccessConfigurationId': self._access_config, + }) + + tea_request.headers['Accept'] = 'application/json' + tea_request.headers['Content-Type'] = 'application/json' + tea_request.headers['Authorization'] = f'Bearer {self._access_token}' + + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from sso, http_code: {response.status_code}, result: {response.body.decode("utf-8")}') + + dic = json.loads(response.body.decode('utf-8')) + if 'CloudCredential' not in dic: + raise CredentialException( + f'error retrieving credentials from sso result: {response.body.decode("utf-8")}') + + cre = dic.get('CloudCredential') + if 'AccessKeyId' not in cre or 'AccessKeySecret' not in cre or 'SecurityToken' not in cre: + raise CredentialException( + f'error retrieving credentials from sso result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(cre.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=cre.get('AccessKeyId'), + access_key_secret=cre.get('AccessKeySecret'), + security_token=cre.get('SecurityToken'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + def get_provider_name(self) -> str: + return 'cloud_sso' diff --git a/alibabacloud_credentials/provider/oauth.py b/alibabacloud_credentials/provider/oauth.py new file mode 100644 index 0000000..81eb03e --- /dev/null +++ b/alibabacloud_credentials/provider/oauth.py @@ -0,0 +1,287 @@ +import calendar +import json +import logging +import time +from urllib.parse import urlparse, urlencode +from typing import Callable, Optional + +from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, RefreshCachedSupplier +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaCore +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import parameter_helper as ph +from alibabacloud_credentials.exceptions import CredentialException + +log = logging.getLogger('credentials') +log.setLevel(logging.INFO) +ch = logging.StreamHandler() +log.addHandler(ch) + +# OAuth 令牌更新回调函数类型 +OAuthTokenUpdateCallback = Callable[[str, str, str, str, str, int, int], None] +OAuthTokenUpdateCallbackAsync = Callable[[str, str, str, str, str, int, int], None] + + +def _get_stale_time(expiration: int) -> int: + if expiration < 0: + return int(time.mktime(time.localtime())) + 60 * 60 + return expiration - 15 * 60 + + +class OAuthCredentialsProvider(ICredentialsProvider): + DEFAULT_CONNECT_TIMEOUT = 5000 + DEFAULT_READ_TIMEOUT = 10000 + + def __init__(self, *, + client_id: str = None, + sign_in_url: str = None, + access_token: str = None, + access_token_expire: int = 0, + refresh_token: str = None, + http_options: HttpOptions = None, + token_update_callback: Optional[OAuthTokenUpdateCallback] = None, + token_update_callback_async: Optional[OAuthTokenUpdateCallbackAsync] = None): + + if not client_id: + raise ValueError('the ClientId is empty') + + if not sign_in_url: + raise ValueError('the url for sign-in is empty') + + if not refresh_token: + raise ValueError('OAuth access token is empty or expired, please re-login with cli') + + self._client_id = client_id + self._sign_in_url = sign_in_url + self._access_token = access_token + self._access_token_expire = access_token_expire + self._refresh_token = refresh_token + self._token_update_callback = token_update_callback + self._token_update_callback_async = token_update_callback_async + + self._http_options = http_options if http_options is not None else HttpOptions() + self._runtime_options = { + 'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else OAuthCredentialsProvider.DEFAULT_CONNECT_TIMEOUT, + 'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else OAuthCredentialsProvider.DEFAULT_READ_TIMEOUT, + 'httpsProxy': self._http_options.proxy + } + self._credentials_cache = RefreshCachedSupplier( + refresh_callable=self._refresh_credentials, + refresh_callable_async=self._refresh_credentials_async, + ) + + def get_credentials(self) -> Credentials: + return self._credentials_cache._sync_call() + + async def get_credentials_async(self) -> Credentials: + return await self._credentials_cache._async_call() + + def _try_refresh_oauth_token(self) -> None: + current_time = int(time.mktime(time.localtime())) + # 构建刷新令牌请求 + r = urlparse(self._sign_in_url) + tea_request = ph.get_new_request() + tea_request.headers['host'] = r.hostname + tea_request.port = r.port + tea_request.protocol = r.scheme + tea_request.method = 'POST' + tea_request.pathname = '/v1/token' + + # 设置请求体 + body_data = { + 'grant_type': 'refresh_token', + 'refresh_token': self._refresh_token, + 'client_id': self._client_id, + 'Timestamp': ph.get_iso_8061_date() + } + tea_request.body = urlencode(body_data) + tea_request.headers['Content-Type'] = 'application/x-www-form-urlencoded' + + response = TeaCore.do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException(f"failed to refresh OAuth token, status code: {response.status_code}, response: {response.body.decode('utf-8')}") + + # 解析响应 + dic = json.loads(response.body.decode('utf-8')) + if 'access_token' not in dic or 'refresh_token' not in dic: + raise CredentialException(f"failed to refresh OAuth token: {response.body.decode('utf-8')}") + + # 更新令牌 + new_access_token = dic.get('access_token') + new_refresh_token = dic.get('refresh_token') + expires_in = dic.get('expires_in', 3600) + new_access_token_expire = current_time + expires_in + + self._access_token = new_access_token + self._refresh_token = new_refresh_token + self._access_token_expire = new_access_token_expire + + async def _try_refresh_oauth_token_async(self) -> None: + current_time = int(time.mktime(time.localtime())) + # 构建刷新令牌请求 + r = urlparse(self._sign_in_url) + tea_request = ph.get_new_request() + tea_request.headers['host'] = r.hostname + tea_request.port = r.port + tea_request.protocol = r.scheme + tea_request.method = 'POST' + tea_request.pathname = '/v1/token' + + # 设置请求体 + body_data = { + 'grant_type': 'refresh_token', + 'refresh_token': self._refresh_token, + 'client_id': self._client_id, + 'Timestamp': ph.get_iso_8061_date() + } + tea_request.body = urlencode(body_data) + tea_request.headers['Content-Type'] = 'application/x-www-form-urlencoded' + + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException(f"failed to refresh OAuth token, status code: {response.status_code}, response: {response.body.decode('utf-8')}") + + # 解析响应 + dic = json.loads(response.body.decode('utf-8')) + if 'access_token' not in dic or 'refresh_token' not in dic: + raise CredentialException(f"failed to refresh OAuth token: {response.body.decode('utf-8')}") + + # 更新令牌 + new_access_token = dic.get('access_token') + new_refresh_token = dic.get('refresh_token') + expires_in = dic.get('expires_in', 3600) + new_access_token_expire = current_time + expires_in + + self._access_token = new_access_token + self._refresh_token = new_refresh_token + self._access_token_expire = new_access_token_expire + + def _refresh_credentials(self) -> RefreshResult[Credentials]: + if self._access_token is None or self._access_token_expire <= 0 or self._access_token_expire - int( + time.mktime(time.localtime())) <= 180: + self._try_refresh_oauth_token() + + r = urlparse(self._sign_in_url) + tea_request = ph.get_new_request() + tea_request.headers['host'] = r.hostname + tea_request.port = r.port + tea_request.protocol = r.scheme + tea_request.method = 'POST' + tea_request.pathname = '/v1/exchange' + + tea_request.headers['Content-Type'] = 'application/json' + tea_request.headers['Authorization'] = f'Bearer {self._access_token}' + + response = TeaCore.do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f"error refreshing credentials from OAuth, http_code: {response.status_code}, result: {response.body.decode('utf-8')}") + + dic = json.loads(response.body.decode('utf-8')) + if 'error' in dic: + raise CredentialException( + f"error retrieving credentials from OAuth result: {response.body.decode('utf-8')}") + + if 'AccessKeyId' not in dic or 'AccessKeySecret' not in dic or 'SecurityToken' not in dic: + raise CredentialException( + f"error retrieving credentials from OAuth result: {response.body.decode('utf-8')}") + + # 先转换为时间数组 + time_array = time.strptime(dic.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=dic.get('AccessKeyId'), + access_key_secret=dic.get('AccessKeySecret'), + security_token=dic.get('SecurityToken'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + + # 调用令牌更新回调函数 + if self._token_update_callback: + try: + self._token_update_callback( + self._refresh_token, + self._access_token, + credentials.get_access_key_id(), + credentials.get_access_key_secret(), + credentials.get_security_token(), + self._access_token_expire, + expiration + ) + except Exception as e: + log.warning(f'failed to update OAuth tokens in config file: {e}') + + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + async def _refresh_credentials_async(self) -> RefreshResult[Credentials]: + if self._access_token is None or self._access_token_expire <= 0 or self._access_token_expire - int( + time.mktime(time.localtime())) <= 180: + await self._try_refresh_oauth_token_async() + + r = urlparse(self._sign_in_url) + tea_request = ph.get_new_request() + tea_request.headers['host'] = r.hostname + tea_request.port = r.port + tea_request.protocol = r.scheme + tea_request.method = 'POST' + tea_request.pathname = '/v1/exchange' + + tea_request.headers['Content-Type'] = 'application/json' + tea_request.headers['Authorization'] = f'Bearer {self._access_token}' + + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f"error refreshing credentials from OAuth, http_code: {response.status_code}, result: {response.body.decode('utf-8')}") + + dic = json.loads(response.body.decode('utf-8')) + if 'error' in dic: + raise CredentialException( + f"error retrieving credentials from OAuth result: {response.body.decode('utf-8')}") + + if 'AccessKeyId' not in dic or 'AccessKeySecret' not in dic or 'SecurityToken' not in dic: + raise CredentialException( + f"error retrieving credentials from OAuth result: {response.body.decode('utf-8')}") + + # 先转换为时间数组 + time_array = time.strptime(dic.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=dic.get('AccessKeyId'), + access_key_secret=dic.get('AccessKeySecret'), + security_token=dic.get('SecurityToken'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + + if self._token_update_callback_async: + try: + await self._token_update_callback_async( + self._refresh_token, + self._access_token, + credentials.get_access_key_id(), + credentials.get_access_key_secret(), + credentials.get_security_token(), + self._access_token_expire, + expiration + ) + except Exception as e: + log.warning(f'failed to update OAuth tokens in config file: {e}') + + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + def _get_client_id(self) -> str: + """获取客户端ID""" + return self._client_id + + def get_provider_name(self) -> str: + return 'oauth' diff --git a/tests/provider/__init__.py b/tests/provider/__init__.py new file mode 100644 index 0000000..987c148 --- /dev/null +++ b/tests/provider/__init__.py @@ -0,0 +1 @@ +# Provider tests package diff --git a/tests/provider/test_cli_profile.py b/tests/provider/test_cli_profile.py index 0732489..ffbd802 100644 --- a/tests/provider/test_cli_profile.py +++ b/tests/provider/test_cli_profile.py @@ -3,6 +3,7 @@ import asyncio import os import json +import time from alibabacloud_credentials.provider.cli_profile import ( CLIProfileCredentialsProvider, CredentialException, @@ -13,7 +14,9 @@ StaticAKCredentialsProvider, RamRoleArnCredentialsProvider, EcsRamRoleCredentialsProvider, - OIDCRoleArnCredentialsProvider + OIDCRoleArnCredentialsProvider, + CloudSSOCredentialsProvider, + OAuthCredentialsProvider ) from alibabacloud_credentials.utils import auth_constant as ac @@ -21,6 +24,8 @@ class TestCLIProfileCredentialsProvider(unittest.TestCase): def setUp(self): + # 设置时区环境变量以避免调度器初始化问题 + os.environ['TZ'] = 'UTC' self.profile_name = "test_profile" self.profile_file = os.path.join(ac.HOME, ".aliyun/config.json") self.config = { @@ -80,6 +85,23 @@ def setUp(self): "external_id": "test_external_id", "sts_region": "test_sts_region", "enable_vpc": True + }, + { + "name": "cloud_sso_profile", + "mode": "CloudSSO", + "cloud_sso_sign_in_url": "https://sso.example.com", + "cloud_sso_account_id": "test_account_id", + "cloud_sso_access_config": "test_access_config", + "access_token": "test_access_token", + "cloud_sso_access_token_expire": int(time.mktime(time.localtime())) + 1000 + }, + { + "name": "oauth_profile", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "test_refresh_token", + "oauth_access_token": "test_oauth_access_token", + "oauth_access_token_expire": int(time.mktime(time.localtime())) + 1000 } ] } @@ -243,9 +265,50 @@ def test_get_credentials_valid_chainable_ram_role_arn(self): self.assertEqual(credentials_provider._external_id, 'test_external_id') self.assertEqual(credentials_provider._sts_endpoint, 'sts-vpc.test_sts_region.aliyuncs.com') + def test_get_credentials_valid_cloud_sso(self): + """ + Test case 7: Valid input, successfully retrieves credentials for CloudSSO mode + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', False): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name="cloud_sso_profile") + + credentials_provider = provider._get_credentials_provider(config=self.config, + profile_name="cloud_sso_profile") + + self.assertIsInstance(credentials_provider, CloudSSOCredentialsProvider) + + self.assertEqual(credentials_provider._sign_in_url, 'https://sso.example.com') + self.assertEqual(credentials_provider._account_id, 'test_account_id') + self.assertEqual(credentials_provider._access_config, 'test_access_config') + self.assertEqual(credentials_provider._access_token, 'test_access_token') + self.assertTrue(credentials_provider._access_token_expire > int(time.mktime(time.localtime()))) + + def test_get_credentials_valid_oauth(self): + """ + Test case 8: Valid input, successfully retrieves credentials for OAuth mode + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', False): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name="oauth_profile") + + credentials_provider = provider._get_credentials_provider(config=self.config, + profile_name="oauth_profile") + + self.assertIsInstance(credentials_provider, OAuthCredentialsProvider) + + self.assertEqual(credentials_provider._client_id, '4038181954557748008') + self.assertEqual(credentials_provider._sign_in_url, 'https://oauth.aliyun.com') + self.assertEqual(credentials_provider._access_token, 'test_oauth_access_token') + self.assertTrue(credentials_provider._access_token_expire > int(time.mktime(time.localtime()))) + def test_get_credentials_cli_profile_disabled(self): """ - Test case 7: CLI profile disabled raises CredentialException + Test case 9: CLI profile disabled raises CredentialException """ with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'True'): provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) @@ -257,19 +320,21 @@ def test_get_credentials_cli_profile_disabled(self): def test_get_credentials_profile_name_not_exists(self): """ - Test case 8: Profile file does not exist raises CredentialException + Test case 10: Profile file does not exist raises CredentialException """ with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): - provider = CLIProfileCredentialsProvider(profile_name='not_exists') - - with self.assertRaises(CredentialException) as context: - provider.get_credentials() - - self.assertIn(f"unable to get profile with 'not_exists' form cli credentials file.", str(context.exception)) + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name='not_exists') + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + self.assertIn(f"unable to get profile with 'not_exists' form cli credentials file.", + str(context.exception)) def test_get_credentials_profile_file_not_exists(self): """ - Test case 8: Profile file does not exist raises CredentialException + Test case 11: Profile file does not exist raises CredentialException """ with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): with patch('os.path.exists', return_value=False): @@ -282,7 +347,7 @@ def test_get_credentials_profile_file_not_exists(self): def test_get_credentials_profile_file_not_file(self): """ - Test case 9: Profile file is not a file raises CredentialException + Test case 12: Profile file is not a file raises CredentialException """ with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): with patch('os.path.exists', return_value=True): @@ -296,7 +361,7 @@ def test_get_credentials_profile_file_not_file(self): def test_get_credentials_invalid_json_format(self): """ - Test case 10: Invalid JSON format in profile file raises CredentialException + Test case 13: Invalid JSON format in profile file raises CredentialException """ with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): with patch('os.path.exists', return_value=True): @@ -313,7 +378,7 @@ def test_get_credentials_invalid_json_format(self): def test_get_credentials_empty_json(self): """ - Test case 11: Empty JSON in profile file raises CredentialException + Test case 14: Empty JSON in profile file raises CredentialException """ with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): with patch('os.path.exists', return_value=True): @@ -329,7 +394,7 @@ def test_get_credentials_empty_json(self): def test_get_credentials_missing_profiles(self): """ - Test case 12: Missing profiles in JSON raises CredentialException + Test case 15: Missing profiles in JSON raises CredentialException """ with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): with patch('os.path.exists', return_value=True): @@ -346,7 +411,7 @@ def test_get_credentials_missing_profiles(self): def test_get_credentials_invalid_profile_mode(self): """ - Test case 13: Invalid profile mode raises CredentialException + Test case 16: Invalid profile mode raises CredentialException """ invalid_config = { "current": "invalid_profile", @@ -374,7 +439,7 @@ def test_get_credentials_invalid_profile_mode(self): def test_get_credentials_async_valid_ak(self): """ - Test case 14: Valid input, successfully retrieves credentials for AK mode + Test case 17: Valid input, successfully retrieves credentials for AK mode """ with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): with patch('os.path.exists', return_value=True): @@ -383,12 +448,11 @@ def test_get_credentials_async_valid_ak(self): AsyncMock(return_value=self.config)): provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), self.access_key_id) self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) @@ -398,7 +462,7 @@ def test_get_credentials_async_valid_ak(self): @patch('builtins.open', new_callable=MagicMock) def test_load_config_file_not_found(self, mock_open): """ - Test case 15: File not found raises FileNotFoundError + Test case 18: File not found raises FileNotFoundError """ mock_open.side_effect = FileNotFoundError(f"No such file or directory: '{self.profile_file}'") @@ -410,7 +474,7 @@ def test_load_config_file_not_found(self, mock_open): @patch('builtins.open', new_callable=MagicMock) def test_load_config_invalid_json(self, mock_open): """ - Test case 16: Invalid JSON format raises json.JSONDecodeError + Test case 19: Invalid JSON format raises json.JSONDecodeError """ invalid_json = "invalid json content" mock_open.return_value.__enter__.return_value.read.return_value = invalid_json @@ -419,3 +483,966 @@ def test_load_config_invalid_json(self, mock_open): _load_config(self.profile_file) self.assertIn("Expecting value: line 1 column 1", str(context.exception)) + + def test_oauth_token_update_callback(self): + """测试 OAuth 令牌更新回调功能""" + import tempfile + import json + import time + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "oauth_test", + "profiles": [ + { + "name": "oauth_test", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="oauth_test", + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 更新令牌 + new_refresh_token = "new_refresh_token" + new_access_token = "new_access_token" + new_access_key = "new_access_key" + new_secret = "new_secret" + new_security_token = "new_security_token" + new_expire_time = int(time.time()) + 7200 + new_sts_expire = int(time.time()) + 10800 + + provider._update_oauth_tokens(new_refresh_token, new_access_token, new_access_key, new_secret, + new_security_token, new_expire_time, new_sts_expire) + + # 验证配置文件已更新 + with open(config_path, 'r') as f: + updated_config = json.load(f) + + profile = updated_config['profiles'][0] + self.assertEqual(profile['oauth_refresh_token'], new_refresh_token) + self.assertEqual(profile['oauth_access_token'], new_access_token) + self.assertEqual(profile['access_key_id'], new_access_key) + self.assertEqual(profile['access_key_secret'], new_secret) + self.assertEqual(profile['sts_token'], new_security_token) + self.assertEqual(profile['oauth_access_token_expire'], new_expire_time) + self.assertEqual(profile['sts_expiration'], new_sts_expire) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_oauth_callback_integration(self): + """测试 OAuth 回调集成""" + import tempfile + import json + import time + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "oauth_test", + "profiles": [ + { + "name": "oauth_test", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="oauth_test", + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 获取回调函数 + callback = provider._get_oauth_token_update_callback() + + # 调用回调函数 + new_refresh_token = "callback_refresh_token" + new_access_token = "callback_access_token" + new_access_key = "callback_access_key" + new_secret = "callback_secret" + new_security_token = "callback_security_token" + new_expire_time = int(time.time()) + 3600 + new_sts_expire = int(time.time()) + 7200 + + callback(new_refresh_token, new_access_token, new_access_key, new_secret, new_security_token, + new_expire_time, new_sts_expire) + + # 验证配置文件已更新 + with open(config_path, 'r') as f: + updated_config = json.load(f) + + profile = updated_config['profiles'][0] + self.assertEqual(profile['oauth_refresh_token'], new_refresh_token) + self.assertEqual(profile['oauth_access_token'], new_access_token) + self.assertEqual(profile['access_key_id'], new_access_key) + self.assertEqual(profile['access_key_secret'], new_secret) + self.assertEqual(profile['sts_token'], new_security_token) + self.assertEqual(profile['oauth_access_token_expire'], new_expire_time) + self.assertEqual(profile['sts_expiration'], new_sts_expire) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_concurrent_token_update(self): + """测试并发令牌更新""" + import tempfile + import json + import time + import threading + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "oauth_test", + "profiles": [ + { + "name": "oauth_test", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="oauth_test", + profile_file=config_path + ) + + results = [] + errors = [] + + def update_tokens(index): + try: + refresh_token = f"refresh_token_{index}" + access_token = f"access_token_{index}" + access_key = f"access_key_{index}" + secret = f"secret_{index}" + security_token = f"security_token_{index}" + expire_time = int(time.time()) + 3600 + index + sts_expire = int(time.time()) + 7200 + index + + provider._update_oauth_tokens(refresh_token, access_token, access_key, secret, security_token, + expire_time, sts_expire) + results.append(index) + except Exception as e: + errors.append(e) + + # 并发更新 + threads = [] + for i in range(10): + thread = threading.Thread(target=update_tokens, args=(i,)) + threads.append(thread) + thread.start() + + # 等待所有线程完成 + for thread in threads: + thread.join() + + # 验证最终配置文件仍然有效 + with open(config_path, 'r') as f: + final_config = json.load(f) + + self.assertIsNotNone(final_config) + self.assertIn('profiles', final_config) + self.assertEqual(len(final_config['profiles']), 1) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_file_lock_safety(self): + """测试文件锁安全性""" + import tempfile + import json + import time + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "oauth_test", + "profiles": [ + { + "name": "oauth_test", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="oauth_test", + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 测试文件锁是否正常工作 + with provider._file_lock: + # 在锁内执行操作 + provider._update_oauth_tokens("locked_token", "locked_access", "locked_key", "locked_secret", + "locked_sts", int(time.time()) + 3600, int(time.time()) + 7200) + + # 验证操作成功 + with open(config_path, 'r') as f: + config = json.load(f) + + profile = config['profiles'][0] + self.assertEqual(profile['oauth_refresh_token'], "locked_token") + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_write_configuration_to_file(self): + """测试基本文件写入功能""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "test", + "profiles": [{"name": "test", "mode": "AK"}] + } + + try: + provider = CLIProfileCredentialsProvider() + provider._write_configuration_to_file(config_path, test_config) + + # 验证文件已写入 + self.assertTrue(os.path.exists(config_path)) + + with open(config_path, 'r') as f: + loaded_config = json.load(f) + + self.assertEqual(loaded_config, test_config) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_write_configuration_to_file_error(self): + """测试写入只读目录时的错误处理""" + import tempfile + import stat + import platform + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "readonly", "config.json") + + # 创建只读目录 + readonly_dir = os.path.join(temp_dir, "readonly") + os.makedirs(readonly_dir) + + # Windows和Unix的权限处理不同 + if platform.system() == 'Windows': + # Windows下使用更严格的权限设置 + try: + import win32security + import win32api + import win32con + + # 获取目录的安全描述符 + sd = win32security.GetFileSecurity(readonly_dir, win32security.DACL_SECURITY_INFORMATION) + dacl = win32security.ACL() + + # 创建拒绝所有访问的ACE + everyone, domain, type = win32security.LookupAccountName("", "Everyone") + dacl.AddAccessDeniedAce(win32security.ACL_REVISION, win32con.FILE_ALL_ACCESS, everyone) + + # 应用安全描述符 + sd.SetSecurityDescriptorDacl(1, dacl, 0) + win32security.SetFileSecurity(readonly_dir, win32security.DACL_SECURITY_INFORMATION, sd) + + test_config = {"current": "test"} + provider = CLIProfileCredentialsProvider( + allow_config_force_rewrite=True, + ) + + with self.assertRaises(Exception): + provider._write_configuration_to_file(config_path, test_config) + + except ImportError: + # 如果没有pywin32,跳过这个测试 + self.skipTest("pywin32 not available for Windows permission test") + else: + # Unix-like系统使用chmod + os.chmod(readonly_dir, stat.S_IRUSR | stat.S_IXUSR) # 只读 + test_config = {"current": "test"} + provider = CLIProfileCredentialsProvider() + + with self.assertRaises(Exception): + provider._write_configuration_to_file(config_path, test_config) + + try: + pass # 测试逻辑在上面 + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_write_configuration_to_file_with_lock(self): + """测试带文件锁的写入功能""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "test", + "profiles": [{"name": "test", "mode": "AK"}] + } + + try: + # 先创建文件,因为_write_configuration_to_file_with_lock需要文件存在 + with open(config_path, 'w') as f: + json.dump({}, f) + + provider = CLIProfileCredentialsProvider( + allow_config_force_rewrite=True, + ) + provider._write_configuration_to_file_with_lock(config_path, test_config) + + # 验证文件已写入 + self.assertTrue(os.path.exists(config_path)) + + with open(config_path, 'r') as f: + loaded_config = json.load(f) + + self.assertEqual(loaded_config, test_config) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_write_configuration_to_file_with_lock_error(self): + """测试带文件锁写入时的错误处理""" + import tempfile + import stat + import platform + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "readonly", "config.json") + + # 创建只读目录 + readonly_dir = os.path.join(temp_dir, "readonly") + os.makedirs(readonly_dir) + + # Windows和Unix的权限处理不同 + if platform.system() == 'Windows': + # Windows下使用更严格的权限设置 + try: + import win32security + import win32api + import win32con + + # 获取目录的安全描述符 + sd = win32security.GetFileSecurity(readonly_dir, win32security.DACL_SECURITY_INFORMATION) + dacl = win32security.ACL() + + # 创建拒绝所有访问的ACE + everyone, domain, type = win32security.LookupAccountName("", "Everyone") + dacl.AddAccessDeniedAce(win32security.ACL_REVISION, win32con.FILE_ALL_ACCESS, everyone) + + # 应用安全描述符 + sd.SetSecurityDescriptorDacl(1, dacl, 0) + win32security.SetFileSecurity(readonly_dir, win32security.DACL_SECURITY_INFORMATION, sd) + + test_config = {"current": "test"} + provider = CLIProfileCredentialsProvider() + + with self.assertRaises(Exception): + provider._write_configuration_to_file_with_lock(config_path, test_config) + + except ImportError: + # 如果没有pywin32,跳过这个测试 + self.skipTest("pywin32 not available for Windows permission test") + else: + # Unix-like系统使用chmod + os.chmod(readonly_dir, stat.S_IRUSR | stat.S_IXUSR) # 只读 + test_config = {"current": "test"} + provider = CLIProfileCredentialsProvider() + + with self.assertRaises(Exception): + provider._write_configuration_to_file_with_lock(config_path, test_config) + + try: + pass # 测试逻辑在上面 + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_get_oauth_token_update_callback(self): + """测试获取OAuth令牌更新回调函数""" + provider = CLIProfileCredentialsProvider() + callback = provider._get_oauth_token_update_callback() + + self.assertIsNotNone(callback) + self.assertTrue(callable(callback)) + + def test_update_oauth_tokens_error(self): + """测试更新OAuth令牌时的错误处理""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + # 创建无效的配置文件 + with open(config_path, 'w') as f: + f.write("invalid json") + + try: + provider = CLIProfileCredentialsProvider( + profile_name="test", + profile_file=config_path + ) + + # 应该抛出CredentialException异常 + with self.assertRaises(CredentialException) as context: + provider._update_oauth_tokens("token", "access", "key", "secret", "sts", 123, 456) + + self.assertIn("failed to update OAuth tokens in config file", str(context.exception)) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_update_oauth_tokens_profile_not_found(self): + """测试更新不存在的profile""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "test", + "profiles": [{"name": "test", "mode": "AK"}] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="nonexistent", + profile_file=config_path + ) + + # 应该抛出CredentialException异常 + with self.assertRaises(CredentialException) as context: + provider._update_oauth_tokens("token", "access", "key", "secret", "sts", 123, 456) + + self.assertIn("failed to update OAuth tokens in config file", str(context.exception)) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_get_credentials_with_oauth_profile(self): + """测试使用OAuth profile获取凭据""" + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name="oauth_profile") + + # 模拟OAuth provider的get_credentials方法 + with patch( + 'alibabacloud_credentials.provider.oauth.OAuthCredentialsProvider.get_credentials') as mock_get_creds: + from alibabacloud_credentials.provider.refreshable import Credentials + mock_creds = Credentials( + access_key_id="test_ak", + access_key_secret="test_sk", + security_token="test_token", + provider_name="oauth" + ) + mock_get_creds.return_value = mock_creds + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), "test_ak") + self.assertEqual(credentials.get_access_key_secret(), "test_sk") + self.assertEqual(credentials.get_security_token(), "test_token") + self.assertEqual(credentials.get_provider_name(), "cli_profile/oauth") + + def test_file_lock_concurrent_access(self): + """测试文件锁的并发访问""" + import tempfile + import json + import threading + import time + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "oauth_test", + "profiles": [ + { + "name": "oauth_test", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_token", + "oauth_access_token": "initial_access", + "oauth_access_token_expire": int(time.time()) + 3600 + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="oauth_test", + profile_file=config_path + ) + + results = [] + errors = [] + + def update_tokens(index): + try: + provider._update_oauth_tokens( + f"token_{index}", f"access_{index}", f"key_{index}", + f"secret_{index}", f"sts_{index}", + int(time.time()) + 3600 + index, int(time.time()) + 7200 + index + ) + results.append(index) + except Exception as e: + errors.append(e) + + # 并发更新 + threads = [] + for i in range(5): + thread = threading.Thread(target=update_tokens, args=(i,)) + threads.append(thread) + thread.start() + + # 等待所有线程完成 + for thread in threads: + thread.join() + + # 验证最终配置文件仍然有效 + with open(config_path, 'r') as f: + final_config = json.load(f) + + self.assertIsNotNone(final_config) + self.assertIn('profiles', final_config) + self.assertEqual(len(final_config['profiles']), 1) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_edge_cases(self): + """测试边界情况""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + # 测试空配置 + empty_config = {"current": "test", "profiles": []} + + with open(config_path, 'w') as f: + json.dump(empty_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="test", + profile_file=config_path + ) + + # 应该抛出CredentialException异常 + with self.assertRaises(CredentialException) as context: + provider._update_oauth_tokens("token", "access", "key", "secret", "sts", 123, 456) + + self.assertIn("failed to update OAuth tokens in config file", str(context.exception)) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_profile_name_empty(self): + """测试空profile名称的情况""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "test", + "profiles": [ + { + "name": "test", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_token", + "oauth_access_token": "initial_access", + "oauth_access_token_expire": int(time.time()) + 3600 + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="", # 空名称 + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 应该使用current profile + provider._update_oauth_tokens("new_token", "new_access", "new_key", "new_secret", "new_sts", 123, 456) + + # 验证配置文件已更新 + with open(config_path, 'r') as f: + updated_config = json.load(f) + + profile = updated_config['profiles'][0] + self.assertEqual(profile['oauth_refresh_token'], "new_token") + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_oauth_token_update_callback_async(self): + """测试异步OAuth令牌更新回调功能""" + import tempfile + import json + import time + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "oauth_test", + "profiles": [ + { + "name": "oauth_test", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="oauth_test", + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 更新令牌 + new_refresh_token = "new_refresh_token" + new_access_token = "new_access_token" + new_access_key = "new_access_key" + new_secret = "new_secret" + new_security_token = "new_security_token" + new_expire_time = int(time.time()) + 7200 + new_sts_expire = int(time.time()) + 10800 + + async def run_test(): + await provider._update_oauth_tokens_async(new_refresh_token, new_access_token, new_access_key, + new_secret, new_security_token, new_expire_time, + new_sts_expire) + + # 使用 asyncio.run() 替代 get_event_loop() + asyncio.run(run_test()) + + # 验证配置文件已更新 + with open(config_path, 'r') as f: + updated_config = json.load(f) + + profile = updated_config['profiles'][0] + self.assertEqual(profile['oauth_refresh_token'], new_refresh_token) + self.assertEqual(profile['oauth_access_token'], new_access_token) + self.assertEqual(profile['access_key_id'], new_access_key) + self.assertEqual(profile['access_key_secret'], new_secret) + self.assertEqual(profile['sts_token'], new_security_token) + self.assertEqual(profile['oauth_access_token_expire'], new_expire_time) + self.assertEqual(profile['sts_expiration'], new_sts_expire) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_oauth_callback_async_integration(self): + """测试异步OAuth回调集成""" + import tempfile + import json + import time + + # 创建临时配置文件 + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "oauth_test", + "profiles": [ + { + "name": "oauth_test", + "mode": "OAuth", + "oauth_site_type": "CN", + "oauth_refresh_token": "initial_refresh_token", + "oauth_access_token": "initial_access_token", + "oauth_access_token_expire": int(time.time()) + 3600 + } + ] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="oauth_test", + profile_file=config_path, + allow_config_force_rewrite=True + ) + + # 获取异步回调函数 + callback = provider._get_oauth_token_update_callback_async() + + # 调用异步回调函数 + new_refresh_token = "callback_refresh_token" + new_access_token = "callback_access_token" + new_access_key = "callback_access_key" + new_secret = "callback_secret" + new_security_token = "callback_security_token" + new_expire_time = int(time.time()) + 3600 + new_sts_expire = int(time.time()) + 7200 + + async def run_test(): + await callback(new_refresh_token, new_access_token, new_access_key, new_secret, new_security_token, + new_expire_time, new_sts_expire) + + # 使用 asyncio.run() 替代 get_event_loop() + asyncio.run(run_test()) + + # 验证配置文件已更新 + with open(config_path, 'r') as f: + updated_config = json.load(f) + + profile = updated_config['profiles'][0] + self.assertEqual(profile['oauth_refresh_token'], new_refresh_token) + self.assertEqual(profile['oauth_access_token'], new_access_token) + self.assertEqual(profile['access_key_id'], new_access_key) + self.assertEqual(profile['access_key_secret'], new_secret) + self.assertEqual(profile['sts_token'], new_security_token) + self.assertEqual(profile['oauth_access_token_expire'], new_expire_time) + self.assertEqual(profile['sts_expiration'], new_sts_expire) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_write_configuration_to_file_async(self): + """测试异步文件写入功能""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "test", + "profiles": [{"name": "test", "mode": "AK"}] + } + + try: + provider = CLIProfileCredentialsProvider() + + async def run_test(): + await provider._write_configuration_to_file_async(config_path, test_config) + + # 使用 asyncio.run() 替代 get_event_loop() + asyncio.run(run_test()) + + # 验证文件已写入 + self.assertTrue(os.path.exists(config_path)) + + with open(config_path, 'r') as f: + loaded_config = json.load(f) + + self.assertEqual(loaded_config, test_config) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_write_configuration_to_file_with_lock_async(self): + """测试异步带文件锁的写入功能""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "test", + "profiles": [{"name": "test", "mode": "AK"}] + } + + try: + # 先创建文件,因为_write_configuration_to_file_with_lock_async需要文件存在 + with open(config_path, 'w') as f: + json.dump({}, f) + + provider = CLIProfileCredentialsProvider( + allow_config_force_rewrite=True, + ) + + async def run_test(): + await provider._write_configuration_to_file_with_lock_async(config_path, test_config) + + # 使用 asyncio.run() 替代 get_event_loop() + asyncio.run(run_test()) + + # 验证文件已写入 + self.assertTrue(os.path.exists(config_path)) + + with open(config_path, 'r') as f: + loaded_config = json.load(f) + + self.assertEqual(loaded_config, test_config) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_get_oauth_token_update_callback_async(self): + """测试获取异步OAuth令牌更新回调函数""" + provider = CLIProfileCredentialsProvider() + callback = provider._get_oauth_token_update_callback_async() + + self.assertIsNotNone(callback) + self.assertTrue(callable(callback)) + + def test_update_oauth_tokens_async_error(self): + """测试异步更新OAuth令牌时的错误处理""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + # 创建无效的配置文件 + with open(config_path, 'w') as f: + f.write("invalid json") + + try: + provider = CLIProfileCredentialsProvider( + profile_name="test", + profile_file=config_path + ) + + async def run_test(): + with self.assertRaises(CredentialException) as context: + await provider._update_oauth_tokens_async("token", "access", "key", "secret", "sts", 123, 456) + + self.assertIn("failed to update OAuth tokens in config file", str(context.exception)) + + # 使用 asyncio.run() 替代 get_event_loop() + asyncio.run(run_test()) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + def test_update_oauth_tokens_async_profile_not_found(self): + """测试异步更新不存在的profile""" + import tempfile + import json + + temp_dir = tempfile.mkdtemp() + config_path = os.path.join(temp_dir, "config.json") + + test_config = { + "current": "test", + "profiles": [{"name": "test", "mode": "AK"}] + } + + with open(config_path, 'w') as f: + json.dump(test_config, f, indent=4) + + try: + provider = CLIProfileCredentialsProvider( + profile_name="nonexistent", + profile_file=config_path + ) + + async def run_test(): + with self.assertRaises(CredentialException) as context: + await provider._update_oauth_tokens_async("token", "access", "key", "secret", "sts", 123, 456) + + self.assertIn("failed to update OAuth tokens in config file", str(context.exception)) + + # 使用 asyncio.run() 替代 get_event_loop() + asyncio.run(run_test()) + + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) diff --git a/tests/provider/test_cloud_sso.py b/tests/provider/test_cloud_sso.py new file mode 100644 index 0000000..a2bf9a2 --- /dev/null +++ b/tests/provider/test_cloud_sso.py @@ -0,0 +1,449 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import json +import time +import calendar +from alibabacloud_credentials.provider.cloud_sso import CloudSSOCredentialsProvider, _get_stale_time +from alibabacloud_credentials.exceptions import CredentialException +from alibabacloud_credentials.http import HttpOptions + + +class TestCloudSSOCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.sign_in_url = "https://sso.example.com" + self.account_id = "test_account_id" + self.access_config = "test_access_config" + self.access_token = "test_access_token" + self.access_token_expire = int(time.mktime(time.localtime())) + 3600 # 1 hour from now + self.http_options = HttpOptions(connect_timeout=5000, read_timeout=10000) + + self.access_key_id = "test_access_key_id" + self.access_key_secret = "test_access_key_secret" + self.security_token = "test_security_token" + self.expiration = "2030-12-31T23:59:59Z" + + # Mock response data + self.response_data = { + "CloudCredential": { + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token, + "Expiration": self.expiration + } + } + + # Mock Tea response + self.mock_response = MagicMock() + self.mock_response.status_code = 200 + self.mock_response.body = json.dumps(self.response_data).encode('utf-8') + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided parameters + """ + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire, + http_options=self.http_options + ) + + self.assertEqual(provider._sign_in_url, self.sign_in_url) + self.assertEqual(provider._account_id, self.account_id) + self.assertEqual(provider._access_config, self.access_config) + self.assertEqual(provider._access_token, self.access_token) + self.assertEqual(provider._access_token_expire, self.access_token_expire) + self.assertEqual(provider._http_options, self.http_options) + + def test_init_missing_sign_in_url(self): + """ + Test case 2: Missing sign_in_url raises ValueError + """ + with self.assertRaises(ValueError) as context: + CloudSSOCredentialsProvider( + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertIn("CloudSSO sign in url or account id or access config is empty", str(context.exception)) + + def test_init_missing_account_id(self): + """ + Test case 3: Missing account_id raises ValueError + """ + with self.assertRaises(ValueError) as context: + CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertIn("CloudSSO sign in url or account id or access config is empty", str(context.exception)) + + def test_init_missing_access_config(self): + """ + Test case 4: Missing access_config raises ValueError + """ + with self.assertRaises(ValueError) as context: + CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertIn("CloudSSO sign in url or account id or access config is empty", str(context.exception)) + + def test_init_missing_access_token(self): + """ + Test case 5: Missing access_token raises ValueError + """ + with self.assertRaises(ValueError) as context: + CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token_expire=self.access_token_expire + ) + + self.assertIn("CloudSSO access token is empty or expired, please re-login with cli", str(context.exception)) + + def test_init_expired_access_token(self): + """ + Test case 6: Expired access_token raises ValueError + """ + expired_time = int(time.mktime(time.localtime())) - 3600 # 1 hour ago + + with self.assertRaises(ValueError) as context: + CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=expired_time + ) + + self.assertIn("CloudSSO access token is empty or expired, please re-login with cli", str(context.exception)) + + def test_init_zero_access_token_expire(self): + """ + Test case 7: Zero access_token_expire raises ValueError + """ + with self.assertRaises(ValueError) as context: + CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=0 + ) + + self.assertIn("CloudSSO access token is empty or expired, please re-login with cli", str(context.exception)) + + def test_init_default_http_options(self): + """ + Test case 8: Initializes with default http_options when not provided + """ + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertIsInstance(provider._http_options, HttpOptions) + self.assertEqual(provider._runtime_options['connectTimeout'], CloudSSOCredentialsProvider.DEFAULT_CONNECT_TIMEOUT) + self.assertEqual(provider._runtime_options['readTimeout'], CloudSSOCredentialsProvider.DEFAULT_READ_TIMEOUT) + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.do_action') + def test_get_credentials_success(self, mock_do_action): + """ + Test case 9: Valid input, successfully retrieves credentials + """ + mock_do_action.return_value = self.mock_response + + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "cloud_sso") + + # Verify the request was made correctly + mock_do_action.assert_called_once() + call_args = mock_do_action.call_args + tea_request = call_args[0][0] + + self.assertEqual(tea_request.method, 'POST') + self.assertEqual(tea_request.pathname, '/cloud-credentials') + self.assertEqual(tea_request.headers['Authorization'], f'Bearer {self.access_token}') + self.assertEqual(tea_request.headers['Content-Type'], 'application/json') + + request_body = json.loads(tea_request.body) + self.assertEqual(request_body['AccountId'], self.account_id) + self.assertEqual(request_body['AccessConfigurationId'], self.access_config) + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.async_do_action') + def test_get_credentials_async_success(self, mock_async_do_action): + """ + Test case 10: Valid input, successfully retrieves credentials asynchronously + """ + mock_async_do_action.return_value = self.mock_response + + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + async def run_test(): + credentials = await provider.get_credentials_async() + return credentials + + # 使用 asyncio.run() 替代 get_event_loop() + credentials = asyncio.run(run_test()) + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "cloud_sso") + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.do_action') + def test_get_credentials_http_error(self, mock_do_action): + """ + Test case 11: HTTP error response raises CredentialException + """ + error_response = MagicMock() + error_response.status_code = 400 + error_response.body = b'{"error": "Bad Request"}' + mock_do_action.return_value = error_response + + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error refreshing credentials from sso, http_code: 400", str(context.exception)) + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.do_action') + def test_get_credentials_missing_cloud_credential(self, mock_do_action): + """ + Test case 12: Missing CloudCredential in response raises CredentialException + """ + invalid_response = MagicMock() + invalid_response.status_code = 200 + invalid_response.body = json.dumps({"error": "No credentials"}).encode('utf-8') + mock_do_action.return_value = invalid_response + + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error retrieving credentials from sso result", str(context.exception)) + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.do_action') + def test_get_credentials_missing_required_fields(self, mock_do_action): + """ + Test case 13: Missing required fields in CloudCredential raises CredentialException + """ + incomplete_response = MagicMock() + incomplete_response.status_code = 200 + incomplete_response.body = json.dumps({ + "CloudCredential": { + "AccessKeyId": self.access_key_id, + # Missing AccessKeySecret and SecurityToken + } + }).encode('utf-8') + mock_do_action.return_value = incomplete_response + + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error retrieving credentials from sso result", str(context.exception)) + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.do_action') + def test_get_credentials_invalid_json(self, mock_do_action): + """ + Test case 14: Invalid JSON response raises JSONDecodeError + """ + invalid_json_response = MagicMock() + invalid_json_response.status_code = 200 + invalid_json_response.body = b'invalid json' + mock_do_action.return_value = invalid_json_response + + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + with self.assertRaises(json.JSONDecodeError): + provider.get_credentials() + + def test_get_provider_name(self): + """ + Test case 15: Returns correct provider name + """ + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertEqual(provider.get_provider_name(), "cloud_sso") + + def test_get_stale_time_positive_expiration(self): + """ + Test case 16: _get_stale_time with positive expiration returns expiration - 15 minutes + """ + expiration = 1672531199 # 2023-01-01 00:00:00 UTC + expected_stale_time = expiration - 15 * 60 # 15 minutes before + + stale_time = _get_stale_time(expiration) + + self.assertEqual(stale_time, expected_stale_time) + + def test_get_stale_time_negative_expiration(self): + """ + Test case 17: _get_stale_time with negative expiration returns current time + 1 hour + """ + with patch('time.mktime') as mock_mktime: + mock_mktime.return_value = 1672531199 # Mock current time + + stale_time = _get_stale_time(-1) + + expected_stale_time = 1672531199 + 60 * 60 # current time + 1 hour + self.assertEqual(stale_time, expected_stale_time) + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.do_action') + def test_credentials_caching(self, mock_do_action): + """ + Test case 18: Credentials are cached and not refreshed on subsequent calls + """ + mock_do_action.return_value = self.mock_response + + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + # First call + credentials1 = provider.get_credentials() + + # Second call should use cached credentials + credentials2 = provider.get_credentials() + + # Both should return the same credentials + self.assertEqual(credentials1.get_access_key_id(), credentials2.get_access_key_id()) + self.assertEqual(credentials1.get_access_key_secret(), credentials2.get_access_key_secret()) + self.assertEqual(credentials1.get_security_token(), credentials2.get_security_token()) + + # But TeaCore.do_action should only be called once due to caching + self.assertEqual(mock_do_action.call_count, 1) + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.do_action') + def test_url_parsing(self, mock_do_action): + """ + Test case 19: URL parsing works correctly for different URL formats + """ + mock_do_action.return_value = self.mock_response + + # Test with different URL formats + test_urls = [ + "https://sso.example.com", + "https://sso.example.com:8080", + "http://sso.example.com" + ] + + for url in test_urls: + provider = CloudSSOCredentialsProvider( + sign_in_url=url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + credentials = provider.get_credentials() + self.assertIsNotNone(credentials) + + # Verify the request was made with correct host + call_args = mock_do_action.call_args + tea_request = call_args[0][0] + + if ":8080" in url: + self.assertEqual(tea_request.port, 8080) + else: + self.assertEqual(tea_request.port, 80) + + if url.startswith("https"): + self.assertEqual(tea_request.protocol, "https") + else: + self.assertEqual(tea_request.protocol, "http") + + @patch('alibabacloud_credentials.provider.cloud_sso.TeaCore.do_action') + def test_expiration_time_parsing(self, mock_do_action): + """ + Test case 20: Expiration time is correctly parsed from ISO format + """ + mock_do_action.return_value = self.mock_response + + provider = CloudSSOCredentialsProvider( + sign_in_url=self.sign_in_url, + account_id=self.account_id, + access_config=self.access_config, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + credentials = provider.get_credentials() + + # The expiration should be parsed correctly from "2023-12-31T23:59:59Z" + expected_expiration = calendar.timegm(time.strptime(self.expiration, '%Y-%m-%dT%H:%M:%SZ')) + self.assertEqual(credentials.get_expiration(), expected_expiration) diff --git a/tests/provider/test_default.py b/tests/provider/test_default.py index 3fd2643..5d3027e 100644 --- a/tests/provider/test_default.py +++ b/tests/provider/test_default.py @@ -268,6 +268,14 @@ def test_get_credentials_no_valid_provider(self): env_provider.get_credentials = MagicMock( side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + oidc_provider = OIDCRoleArnCredentialsProvider( + role_arn='role_arn', + oidc_provider_arn='oidc_provider_arn', + oidc_token_file_path='oidc_token_file_path', + ) + oidc_provider.get_credentials = MagicMock( + side_effect=CredentialException("OIDCRoleArnCredentialsProvider failed")) + cli_provider = CLIProfileCredentialsProvider() cli_provider.get_credentials = MagicMock( side_effect=CredentialException("CLIProfileCredentialsProvider failed")) @@ -285,21 +293,23 @@ def test_get_credentials_no_valid_provider(self): with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', return_value=env_provider): - with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', - return_value=cli_provider): - with patch('alibabacloud_credentials.provider.default.ProfileCredentialsProvider', - return_value=profile_provider): - with patch('alibabacloud_credentials.provider.default.EcsRamRoleCredentialsProvider', - return_value=ecs_provider): - with patch('alibabacloud_credentials.provider.default.URLCredentialsProvider', - return_value=url_provider): - provider = DefaultCredentialsProvider() + with patch('alibabacloud_credentials.provider.default.OIDCRoleArnCredentialsProvider', + return_value=oidc_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + with patch('alibabacloud_credentials.provider.default.ProfileCredentialsProvider', + return_value=profile_provider): + with patch('alibabacloud_credentials.provider.default.EcsRamRoleCredentialsProvider', + return_value=ecs_provider): + with patch('alibabacloud_credentials.provider.default.URLCredentialsProvider', + return_value=url_provider): + provider = DefaultCredentialsProvider() - with self.assertRaises(CredentialException) as context: - provider.get_credentials() + with self.assertRaises(CredentialException) as context: + provider.get_credentials() - self.assertIn("unable to load credentials from any of the providers in the chain", - str(context.exception)) + self.assertIn("unable to load credentials from any of the providers in the chain", + str(context.exception)) @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', False) @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') @@ -315,12 +325,11 @@ def test_get_credentials_async_with_environment_variable_provider(self): return_value=env_provider): provider = DefaultCredentialsProvider() - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), self.access_key_id) self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) @@ -422,23 +431,21 @@ def test_get_credentials_async_reuse_last_provider_enabled(self): provider = DefaultCredentialsProvider() # First call to get_credentials - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), self.access_key_id) self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) self.assertEqual(credentials.get_security_token(), self.security_token) self.assertEqual(credentials.get_provider_name(), "default/test_provider") # Second call to get_credentials should reuse the last provider - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + async def run_test1(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test1()) self.assertEqual(credentials.get_access_key_id(), self.access_key_id) self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) self.assertEqual(credentials.get_security_token(), self.security_token) @@ -468,23 +475,21 @@ def test_get_credentials_async_reuse_last_provider_disabled(self): provider = DefaultCredentialsProvider(reuse_last_provider_enabled=False) # First call to get_credentials - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), self.access_key_id) self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) self.assertEqual(credentials.get_security_token(), self.security_token) self.assertEqual(credentials.get_provider_name(), "default/test_provider") # Second call to get_credentials should not reuse the last provider - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + async def run_test1(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test1()) self.assertEqual(credentials.get_access_key_id(), self.access_key_id) self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) self.assertEqual(credentials.get_security_token(), self.security_token) diff --git a/tests/provider/test_ecs_ram_role.py b/tests/provider/test_ecs_ram_role.py index 27b0e93..bbadbc5 100644 --- a/tests/provider/test_ecs_ram_role.py +++ b/tests/provider/test_ecs_ram_role.py @@ -186,12 +186,11 @@ def test_get_credentials_async_valid_input(self): http_options=self.http_options ) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._refresh_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._refresh_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.value().get_access_key_id(), self.access_key_id) self.assertEqual(credentials.value().get_access_key_secret(), self.access_key_secret) @@ -201,11 +200,11 @@ def test_get_credentials_async_valid_input(self): self.assertEqual(credentials.value().get_provider_name(), "ecs_ram_role") with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn("No cached value was found.", str(context.exception)) @@ -230,11 +229,11 @@ def test_get_credentials_async_http_request_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn( "Failed to get RAM session credentials from ECS metadata service. HttpCode=400", @@ -265,11 +264,11 @@ def test_get_credentials_async_response_format_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn('Failed to get RAM session credentials from ECS metadata service.', str(context.exception)) @@ -332,12 +331,11 @@ def test_get_metadata_token_async_valid_input(self): http_options=self.http_options ) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._get_metadata_token_async() - ) - loop.run_until_complete(task) - metadata_token = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._get_metadata_token_async() + + metadata_token = asyncio.run(run_test()) self.assertEqual(metadata_token, self.metadata_token) @@ -357,11 +355,11 @@ def test_get_metadata_token_async_http_request_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._get_metadata_token_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._get_metadata_token_async() + + asyncio.run(run_test()) self.assertIn( "Failed to get token from ECS Metadata Service. HttpCode=400", @@ -428,12 +426,11 @@ def test_get_role_name_async_valid_input(self): http_options=self.http_options ) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._get_role_name_async() - ) - loop.run_until_complete(task) - role_name = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._get_role_name_async() + + role_name = asyncio.run(run_test()) self.assertEqual(role_name, self.role_name) @@ -454,11 +451,11 @@ def test_get_role_name_async_http_request_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._get_role_name_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._get_role_name_async() + + asyncio.run(run_test()) self.assertIn( "Failed to get RAM session credentials from ECS metadata service. HttpCode=400", diff --git a/tests/provider/test_oauth.py b/tests/provider/test_oauth.py new file mode 100644 index 0000000..aa8907e --- /dev/null +++ b/tests/provider/test_oauth.py @@ -0,0 +1,1035 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import json +import time +import calendar +from alibabacloud_credentials.provider.oauth import OAuthCredentialsProvider, _get_stale_time +from alibabacloud_credentials.exceptions import CredentialException +from alibabacloud_credentials.http import HttpOptions + + +class TestOAuthCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.site_type = "CN" + self.refresh_token = "test_refresh_token" + self.access_token = "test_access_token" + self.access_token_expire = int(time.mktime(time.localtime())) + 3600 # 1 hour from now + self.http_options = HttpOptions(connect_timeout=5000, read_timeout=10000) + + self.access_key_id = "test_access_key_id" + self.access_key_secret = "test_access_key_secret" + self.security_token = "test_security_token" + self.expiration = "2030-12-31T23:59:59Z" + + # Mock response data + self.response_data = { + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token, + "Expiration": self.expiration + } + + # Mock Tea response + self.mock_response = MagicMock() + self.mock_response.status_code = 200 + self.mock_response.body = json.dumps(self.response_data).encode('utf-8') + + def test_init_valid_input_cn(self): + """ + Test case 1: Valid input with CN client ID and sign-in URL, successfully initializes with provided parameters + """ + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire, + http_options=self.http_options + ) + + self.assertEqual(provider._client_id, "123") + self.assertEqual(provider._sign_in_url, "https://oauth.aliyun.com") + self.assertEqual(provider._access_token, self.access_token) + self.assertEqual(provider._access_token_expire, self.access_token_expire) + self.assertEqual(provider._http_options, self.http_options) + + def test_init_valid_input_intl(self): + """ + Test case 2: Valid input with INTL client ID and sign-in URL, successfully initializes with provided parameters + """ + provider = OAuthCredentialsProvider( + client_id="456", + sign_in_url="https://oauth.alibabacloud.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire, + http_options=self.http_options + ) + + self.assertEqual(provider._client_id, "456") + self.assertEqual(provider._sign_in_url, "https://oauth.alibabacloud.com") + self.assertEqual(provider._access_token, self.access_token) + self.assertEqual(provider._access_token_expire, self.access_token_expire) + self.assertEqual(provider._http_options, self.http_options) + + def test_init_missing_client_id(self): + """ + Test case 3: Missing client_id raises ValueError + """ + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + sign_in_url="https://oauth.aliyun.com", + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertIn("the ClientId is empty", str(context.exception)) + + def test_init_missing_sign_in_url(self): + """ + Test case 4: Missing sign_in_url raises ValueError + """ + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertIn("the url for sign-in is empty", str(context.exception)) + + def test_init_missing_access_token(self): + """ + Test case 5: Missing access_token raises ValueError + """ + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token_expire=self.access_token_expire + ) + + self.assertIn("OAuth access token is empty or expired, please re-login with cli", str(context.exception)) + + def test_init_missing_refresh_token(self): + """ + Test case 6: Missing refresh_token raises ValueError + """ + expired_time = int(time.mktime(time.localtime())) - 3600 # 1 hour ago + + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token=self.access_token, + access_token_expire=expired_time + ) + + self.assertIn("OAuth access token is empty or expired, please re-login with cli", str(context.exception)) + + def test_init_default_http_options(self): + """ + Test case 8: Initializes with default http_options when not provided + """ + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertIsInstance(provider._http_options, HttpOptions) + self.assertEqual(provider._runtime_options['connectTimeout'], OAuthCredentialsProvider.DEFAULT_CONNECT_TIMEOUT) + self.assertEqual(provider._runtime_options['readTimeout'], OAuthCredentialsProvider.DEFAULT_READ_TIMEOUT) + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_get_credentials_success(self, mock_do_action): + """ + Test case 9: Valid input, successfully retrieves credentials + """ + mock_do_action.return_value = self.mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "oauth") + + # Verify the request was made correctly + mock_do_action.assert_called_once() + call_args = mock_do_action.call_args + tea_request = call_args[0][0] + + self.assertEqual(tea_request.method, 'POST') + self.assertEqual(tea_request.pathname, '/v1/exchange') + self.assertEqual(tea_request.headers['Authorization'], f'Bearer {self.access_token}') + self.assertEqual(tea_request.headers['Content-Type'], 'application/json') + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.async_do_action') + def test_get_credentials_async_success(self, mock_async_do_action): + """ + Test case 10: Valid input, successfully retrieves credentials asynchronously + """ + mock_async_do_action.return_value = self.mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + async def run_test(): + credentials = await provider.get_credentials_async() + return credentials + + # 使用 asyncio.run() 替代 get_event_loop() + credentials = asyncio.run(run_test()) + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "oauth") + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_get_credentials_http_error(self, mock_do_action): + """ + Test case 11: HTTP error response raises CredentialException + """ + error_response = MagicMock() + error_response.status_code = 400 + error_response.body = b'{"error": "Bad Request"}' + mock_do_action.return_value = error_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error refreshing credentials from OAuth, http_code: 400", str(context.exception)) + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_get_credentials_error_in_response(self, mock_do_action): + """ + Test case 12: Error field in response raises CredentialException + """ + error_response = MagicMock() + error_response.status_code = 200 + error_response.body = json.dumps({"error": "Invalid token"}).encode('utf-8') + mock_do_action.return_value = error_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error retrieving credentials from OAuth result", str(context.exception)) + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_get_credentials_missing_required_fields(self, mock_do_action): + """ + Test case 13: Missing required fields in response raises CredentialException + """ + incomplete_response = MagicMock() + incomplete_response.status_code = 200 + incomplete_response.body = json.dumps({ + "AccessKeyId": self.access_key_id, + # Missing AccessKeySecret and SecurityToken + }).encode('utf-8') + mock_do_action.return_value = incomplete_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error retrieving credentials from OAuth result", str(context.exception)) + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_get_credentials_invalid_json(self, mock_do_action): + """ + Test case 14: Invalid JSON response raises JSONDecodeError + """ + invalid_json_response = MagicMock() + invalid_json_response.status_code = 200 + invalid_json_response.body = b'invalid json' + mock_do_action.return_value = invalid_json_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + with self.assertRaises(json.JSONDecodeError): + provider.get_credentials() + + def test_get_provider_name(self): + """ + Test case 15: Returns correct provider name + """ + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + self.assertEqual(provider.get_provider_name(), "oauth") + + def test_get_stale_time_positive_expiration(self): + """ + Test case 16: _get_stale_time with positive expiration returns expiration - 15 minutes + """ + expiration = 1672531199 # 2023-01-01 00:00:00 UTC + expected_stale_time = expiration - 15 * 60 # 15 minutes before + + stale_time = _get_stale_time(expiration) + + self.assertEqual(stale_time, expected_stale_time) + + def test_get_stale_time_negative_expiration(self): + """ + Test case 17: _get_stale_time with negative expiration returns current time + 1 hour + """ + with patch('time.mktime') as mock_mktime: + mock_mktime.return_value = 1672531199 # Mock current time + + stale_time = _get_stale_time(-1) + + expected_stale_time = 1672531199 + 60 * 60 # current time + 1 hour + self.assertEqual(stale_time, expected_stale_time) + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_credentials_caching(self, mock_do_action): + """ + Test case 18: Credentials are cached and not refreshed on subsequent calls + """ + mock_do_action.return_value = self.mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + # First call + credentials1 = provider.get_credentials() + + # Second call should use cached credentials + credentials2 = provider.get_credentials() + + # Both should return the same credentials + self.assertEqual(credentials1.get_access_key_id(), credentials2.get_access_key_id()) + self.assertEqual(credentials1.get_access_key_secret(), credentials2.get_access_key_secret()) + self.assertEqual(credentials1.get_security_token(), credentials2.get_security_token()) + + # But TeaCore.do_action should only be called once due to caching + self.assertEqual(mock_do_action.call_count, 1) + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_url_parsing_cn(self, mock_do_action): + """ + Test case 19: URL parsing works correctly for CN site type + """ + mock_do_action.return_value = self.mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + credentials = provider.get_credentials() + self.assertIsNotNone(credentials) + + # Verify the request was made with correct host + call_args = mock_do_action.call_args + tea_request = call_args[0][0] + + self.assertEqual(tea_request.protocol, "https") + self.assertEqual(tea_request.headers['host'], "oauth.aliyun.com") + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_url_parsing_intl(self, mock_do_action): + """ + Test case 20: URL parsing works correctly for INTL site type + """ + mock_do_action.return_value = self.mock_response + + provider = OAuthCredentialsProvider( + client_id="456", + sign_in_url="https://oauth.alibabacloud.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + credentials = provider.get_credentials() + self.assertIsNotNone(credentials) + + # Verify the request was made with correct host + call_args = mock_do_action.call_args + tea_request = call_args[0][0] + + self.assertEqual(tea_request.protocol, "https") + self.assertEqual(tea_request.headers['host'], "oauth.alibabacloud.com") + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_expiration_time_parsing(self, mock_do_action): + """ + Test case 21: Expiration time is correctly parsed from ISO format + """ + mock_do_action.return_value = self.mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + credentials = provider.get_credentials() + + # The expiration should be parsed correctly from "2030-12-31T23:59:59Z" + expected_expiration = calendar.timegm(time.strptime(self.expiration, '%Y-%m-%dT%H:%M:%SZ')) + self.assertEqual(credentials.get_expiration(), expected_expiration) + + def test_client_id_and_sign_in_url(self): + """ + Test case 22: Client ID and sign-in URL are correctly set + """ + # Test CN configuration + provider_cn = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + self.assertEqual(provider_cn._client_id, "123") + self.assertEqual(provider_cn._sign_in_url, "https://oauth.aliyun.com") + + # Test INTL configuration + provider_intl = OAuthCredentialsProvider( + client_id="456", + sign_in_url="https://oauth.alibabacloud.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + self.assertEqual(provider_intl._client_id, "456") + self.assertEqual(provider_intl._sign_in_url, "https://oauth.alibabacloud.com") + + @patch('alibabacloud_credentials.provider.oauth.TeaCore.do_action') + def test_request_body_empty(self, mock_do_action): + """ + Test case 23: Request body should be empty for OAuth exchange + """ + mock_do_action.return_value = self.mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + refresh_token=self.refresh_token, + access_token=self.access_token, + access_token_expire=self.access_token_expire + ) + + provider.get_credentials() + + # Verify the request body is empty + call_args = mock_do_action.call_args + tea_request = call_args[0][0] + + self.assertIsNone(tea_request.body) + + @patch('Tea.core.TeaCore.do_action') + def test_oauth_token_refresh_success(self, mock_do_action): + """测试 OAuth 令牌刷新成功""" + # 模拟成功的令牌刷新响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.body = json.dumps({ + "AccessKeyId": "test", + "AccessKeySecret": "test", + "SecurityToken": "test", + "Expiration": "2021-10-20T04:27:09Z", + }).encode('utf-8') + mock_do_action.return_value = mock_response + + callback_called = False + + def test_callback(refresh_token, access_token, access_key, secret, security_token, access_token_expire, + sts_expire): + nonlocal callback_called + callback_called = True + self.assertEqual(refresh_token, "old_refresh_token") + self.assertEqual(access_token, "old_access_token") + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="old_access_token", + access_token_expire=int(time.time()) + 3600, # 未过期 + refresh_token="old_refresh_token", + token_update_callback=test_callback + ) + + # 执行令牌刷新 + provider._refresh_credentials() + + # 验证回调被调用 + self.assertTrue(callback_called) + + @patch('Tea.core.TeaCore.do_action') + def test_oauth_callback_in_credentials_refresh(self, mock_do_action): + """测试在凭据刷新时调用回调函数""" + # 模拟成功的凭据交换响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.body = json.dumps({ + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2030-12-31T23:59:59Z" + }).encode('utf-8') + mock_do_action.return_value = mock_response + + callback_called = False + callback_data = None + + def test_callback(refresh_token, access_token, access_key, secret, security_token, access_token_expire, + sts_expire): + nonlocal callback_called, callback_data + callback_called = True + callback_data = (refresh_token, access_token, access_key, secret, security_token, access_token_expire, + sts_expire) + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + token_update_callback=test_callback + ) + + # 获取凭据,这会触发回调 + credentials = provider.get_credentials() + + # 验证回调被调用 + self.assertTrue(callback_called) + self.assertIsNotNone(callback_data) + self.assertEqual(callback_data[0], "test_refresh_token") # refresh_token + self.assertEqual(callback_data[1], "test_access_token") # access_token + self.assertEqual(callback_data[2], "test_access_key_id") # access_key + self.assertEqual(callback_data[3], "test_access_key_secret") # secret + self.assertEqual(callback_data[4], "test_security_token") # security_token + + @patch('Tea.core.TeaCore.do_action') + def test_oauth_callback_error_handling(self, mock_do_action): + """测试回调函数错误处理""" + # 模拟成功的凭据交换响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.body = json.dumps({ + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2030-12-31T23:59:59Z" + }).encode('utf-8') + mock_do_action.return_value = mock_response + + def error_callback(refresh_token, access_token, access_key, secret, security_token, access_token_expire, + sts_expire): + raise Exception("Callback error") + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + token_update_callback=error_callback + ) + + # 获取凭据,即使回调出错也应该成功 + credentials = provider.get_credentials() + + # 验证凭据仍然成功获取 + self.assertIsNotNone(credentials) + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + + def test_oauth_provider_without_refresh_token(self): + """测试没有refresh_token的OAuth提供者""" + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600 + ) + + self.assertIn("OAuth access token is empty or expired, please re-login with cli", str(context.exception)) + + def test_oauth_provider_with_empty_refresh_token(self): + """测试空refresh_token的OAuth提供者""" + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="" + ) + + self.assertIn("OAuth access token is empty or expired, please re-login with cli", str(context.exception)) + + @patch('Tea.core.TeaCore.do_action') + def test_oauth_token_refresh_failure(self, mock_do_action): + """测试OAuth令牌刷新失败""" + # 模拟失败的令牌刷新响应 + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.body = b'{"error": "Invalid refresh token"}' + mock_do_action.return_value = mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="old_access_token", + access_token_expire=int(time.time()) - 100, # 已过期 + refresh_token="invalid_refresh_token" + ) + + # 执行令牌刷新,应该静默失败 + with self.assertRaises(CredentialException): + provider._try_refresh_oauth_token() + + # 验证令牌没有被更新 + self.assertEqual(provider._access_token, "old_access_token") + self.assertEqual(provider._refresh_token, "invalid_refresh_token") + + + @patch('Tea.core.TeaCore.do_action') + def test_oauth_token_refresh_network_error(self, mock_do_action): + """测试OAuth令牌刷新时网络错误""" + # 模拟网络错误 + mock_do_action.side_effect = Exception("Network error") + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="old_access_token", + access_token_expire=int(time.time()) - 100, # 已过期 + refresh_token="test_refresh_token" + ) + + # 执行令牌刷新 + with self.assertRaises(Exception): + provider._try_refresh_oauth_token() + + # 验证令牌没有被更新 + self.assertEqual(provider._access_token, "old_access_token") + + def test_oauth_provider_http_options(self): + """测试OAuth提供者的HTTP选项""" + custom_http_options = HttpOptions(connect_timeout=10000, read_timeout=20000) + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + http_options=custom_http_options + ) + + self.assertEqual(provider._http_options, custom_http_options) + self.assertEqual(provider._runtime_options['connectTimeout'], 10000) + self.assertEqual(provider._runtime_options['readTimeout'], 20000) + + def test_oauth_provider_runtime_options_with_proxy(self): + """测试OAuth提供者的运行时选项包含代理""" + custom_http_options = HttpOptions(proxy="http://proxy.example.com:8080") + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + http_options=custom_http_options + ) + + self.assertEqual(provider._runtime_options['httpsProxy'], "http://proxy.example.com:8080") + + @patch('Tea.core.TeaCore.do_action') + def test_oauth_credentials_refresh_with_callback(self, mock_do_action): + """测试OAuth凭据刷新时调用回调""" + # 模拟成功的凭据交换响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.body = json.dumps({ + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2030-12-31T23:59:59Z" + }).encode('utf-8') + mock_do_action.return_value = mock_response + + callback_called = False + callback_data = None + + def test_callback(refresh_token, access_token, access_key, secret, security_token, access_token_expire, + sts_expire): + nonlocal callback_called, callback_data + callback_called = True + callback_data = (refresh_token, access_token, access_key, secret, security_token, access_token_expire, + sts_expire) + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + token_update_callback=test_callback + ) + + # 获取凭据,这会触发回调 + credentials = provider.get_credentials() + + # 验证回调被调用 + self.assertTrue(callback_called) + self.assertIsNotNone(callback_data) + self.assertEqual(callback_data[0], "test_refresh_token") # refresh_token + self.assertEqual(callback_data[1], "test_access_token") # access_token + self.assertEqual(callback_data[2], "test_access_key_id") # access_key + self.assertEqual(callback_data[3], "test_access_key_secret") # secret + self.assertEqual(callback_data[4], "test_security_token") # security_token + + @patch('Tea.core.TeaCore.async_do_action') + def test_oauth_credentials_refresh_async_with_callback(self, mock_async_do_action): + """测试OAuth异步凭据刷新时调用回调""" + # 模拟成功的凭据交换响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.body = json.dumps({ + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2030-12-31T23:59:59Z" + }).encode('utf-8') + mock_async_do_action.return_value = mock_response + + callback_called = False + callback_data = None + + async def test_callback(refresh_token, access_token, access_key, secret, security_token, access_token_expire, + sts_expire): + nonlocal callback_called, callback_data + callback_called = True + callback_data = (refresh_token, access_token, access_key, secret, security_token, access_token_expire, + sts_expire) + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + token_update_callback_async=test_callback + ) + + async def run_test(): + return await provider.get_credentials_async() + + # 使用 asyncio.run() 替代 get_event_loop() + credentials = asyncio.run(run_test()) + + # 验证回调被调用 + self.assertTrue(callback_called) + self.assertIsNotNone(callback_data) + self.assertEqual(callback_data[0], "test_refresh_token") # refresh_token + self.assertEqual(callback_data[1], "test_access_token") # access_token + self.assertEqual(callback_data[2], "test_access_key_id") # access_key + self.assertEqual(callback_data[3], "test_access_key_secret") # secret + self.assertEqual(callback_data[4], "test_security_token") # security_token + + def test_oauth_provider_validation_edge_cases(self): + """测试OAuth提供者验证的边界情况""" + # 测试空字符串client_id + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token" + ) + self.assertIn("the ClientId is empty", str(context.exception)) + + # 测试None client_id + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id=None, + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token" + ) + self.assertIn("the ClientId is empty", str(context.exception)) + + # 测试空字符串sign_in_url + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + sign_in_url="", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token" + ) + self.assertIn("the url for sign-in is empty", str(context.exception)) + + # 测试None sign_in_url + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + sign_in_url=None, + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token" + ) + self.assertIn("the url for sign-in is empty", str(context.exception)) + + def test_oauth_provider_refresh_token_validation(self): + """测试OAuth提供者refresh_token验证""" + # 测试None refresh_token + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token=None + ) + self.assertIn("OAuth access token is empty or expired, please re-login with cli", str(context.exception)) + + # 测试空字符串refresh_token + with self.assertRaises(ValueError) as context: + OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="" + ) + self.assertIn("OAuth access token is empty or expired, please re-login with cli", str(context.exception)) + + def test_oauth_provider_with_async_callback(self): + """测试带有异步回调的OAuth提供者""" + callback_called = False + callback_data = None + + async def test_async_callback(refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire): + nonlocal callback_called, callback_data + callback_called = True + callback_data = (refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire) + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + token_update_callback_async=test_async_callback + ) + + self.assertIsNotNone(provider._token_update_callback_async) + self.assertEqual(provider._client_id, "123") + self.assertEqual(provider._sign_in_url, "https://oauth.aliyun.com") + + @patch('Tea.core.TeaCore.async_do_action') + def test_oauth_async_token_refresh_success(self, mock_async_do_action): + """测试异步OAuth令牌刷新成功""" + # 模拟成功的令牌刷新响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.body = json.dumps({ + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "expires_in": 3600 + }).encode('utf-8') + mock_async_do_action.return_value = mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="old_access_token", + access_token_expire=int(time.time()) - 100, # 已过期 + refresh_token="old_refresh_token" + ) + + async def run_test(): + await provider._try_refresh_oauth_token_async() + return provider._access_token, provider._refresh_token + + # 使用 asyncio.run() 替代 get_event_loop() + new_access_token, new_refresh_token = asyncio.run(run_test()) + + # 验证令牌被更新 + self.assertEqual(new_access_token, "new_access_token") + self.assertEqual(new_refresh_token, "new_refresh_token") + + @patch('Tea.core.TeaCore.async_do_action') + def test_oauth_async_token_refresh_failure(self, mock_async_do_action): + """测试异步OAuth令牌刷新失败""" + # 模拟失败的令牌刷新响应 + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.body = b'{"error": "Invalid refresh token"}' + mock_async_do_action.return_value = mock_response + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="old_access_token", + access_token_expire=int(time.time()) - 100, # 已过期 + refresh_token="invalid_refresh_token" + ) + + async def run_test(): + await provider._try_refresh_oauth_token_async() + return provider._access_token, provider._refresh_token + + # 使用 asyncio.run() 替代 get_event_loop() + with self.assertRaises(CredentialException): + new_access_token, new_refresh_token = asyncio.run(run_test()) + + + # 验证令牌没有被更新 + self.assertEqual(provider._access_token, "old_access_token") + self.assertEqual(provider._refresh_token, "invalid_refresh_token") + + @patch('Tea.core.TeaCore.async_do_action') + def test_oauth_async_credentials_refresh_with_async_callback(self, mock_async_do_action): + """测试异步凭据刷新时调用异步回调函数""" + # 模拟成功的凭据交换响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.body = json.dumps({ + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2030-12-31T23:59:59Z" + }).encode('utf-8') + mock_async_do_action.return_value = mock_response + + callback_called = False + callback_data = None + + async def test_async_callback(refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire): + nonlocal callback_called, callback_data + callback_called = True + callback_data = (refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire) + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + token_update_callback_async=test_async_callback + ) + + async def run_test(): + return await provider.get_credentials_async() + + # 使用 asyncio.run() 替代 get_event_loop() + credentials = asyncio.run(run_test()) + + # 验证回调被调用 + self.assertTrue(callback_called) + self.assertIsNotNone(callback_data) + self.assertEqual(callback_data[0], "test_refresh_token") # refresh_token + self.assertEqual(callback_data[1], "test_access_token") # access_token + self.assertEqual(callback_data[2], "test_access_key_id") # access_key + self.assertEqual(callback_data[3], "test_access_key_secret") # secret + self.assertEqual(callback_data[4], "test_security_token") # security_token + + @patch('Tea.core.TeaCore.async_do_action') + def test_oauth_async_credentials_refresh_with_async_callback_error(self, mock_async_do_action): + """测试异步回调函数错误处理""" + # 模拟成功的凭据交换响应 + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.body = json.dumps({ + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2030-12-31T23:59:59Z" + }).encode('utf-8') + mock_async_do_action.return_value = mock_response + + async def error_async_callback(refresh_token, access_token, access_key, secret, security_token, access_token_expire, sts_expire): + raise Exception("Async callback error") + + provider = OAuthCredentialsProvider( + client_id="123", + sign_in_url="https://oauth.aliyun.com", + access_token="test_access_token", + access_token_expire=int(time.time()) + 3600, + refresh_token="test_refresh_token", + token_update_callback_async=error_async_callback + ) + + async def run_test(): + return await provider.get_credentials_async() + + # 使用 asyncio.run() 替代 get_event_loop() + credentials = asyncio.run(run_test()) + + # 验证凭据仍然成功获取 + self.assertIsNotNone(credentials) + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") diff --git a/tests/provider/test_oidc.py b/tests/provider/test_oidc.py index 50cd4c5..ed1af75 100644 --- a/tests/provider/test_oidc.py +++ b/tests/provider/test_oidc.py @@ -81,10 +81,12 @@ def test_init_valid_environment_variables(self, mock_auth_util): self.assertEqual(provider._runtime_options['readTimeout'], OIDCRoleArnCredentialsProvider.DEFAULT_READ_TIMEOUT) self.assertIsNone(provider._runtime_options['httpsProxy']) - def test_init_missing_role_arn(self): + @patch('alibabacloud_credentials.provider.oidc.au') + def test_init_missing_role_arn(self, mock_auth_util): """ Test case 3: Missing role_arn raises ValueError """ + mock_auth_util.environment_role_arn = None with self.assertRaises(ValueError) as context: OIDCRoleArnCredentialsProvider( oidc_provider_arn=self.oidc_provider_arn, @@ -93,10 +95,12 @@ def test_init_missing_role_arn(self): self.assertIn("role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty", str(context.exception)) - def test_init_empty_role_arn(self): + @patch('alibabacloud_credentials.provider.oidc.au') + def test_init_empty_role_arn(self, mock_auth_util): """ Test case 4: Empty role_arn raises ValueError """ + mock_auth_util.environment_role_arn = None with self.assertRaises(ValueError) as context: OIDCRoleArnCredentialsProvider( role_arn="", @@ -106,10 +110,12 @@ def test_init_empty_role_arn(self): self.assertIn("role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty", str(context.exception)) - def test_init_missing_oidc_provider_arn(self): + @patch('alibabacloud_credentials.provider.oidc.au') + def test_init_missing_oidc_provider_arn(self, mock_auth_util): """ Test case 5: Missing oidc_provider_arn raises ValueError """ + mock_auth_util.environment_oidc_provider_arn = None with self.assertRaises(ValueError) as context: OIDCRoleArnCredentialsProvider( role_arn=self.role_arn, @@ -119,10 +125,12 @@ def test_init_missing_oidc_provider_arn(self): self.assertIn("oidc_provider_arn or environment variable ALIBABA_CLOUD_OIDC_PROVIDER_ARN cannot be empty", str(context.exception)) - def test_init_empty_oidc_provider_arn(self): + @patch('alibabacloud_credentials.provider.oidc.au') + def test_init_empty_oidc_provider_arn(self, mock_auth_util): """ Test case 6: Empty oidc_provider_arn raises ValueError """ + mock_auth_util.environment_oidc_provider_arn = None with self.assertRaises(ValueError) as context: OIDCRoleArnCredentialsProvider( role_arn=self.role_arn, @@ -133,10 +141,12 @@ def test_init_empty_oidc_provider_arn(self): self.assertIn("oidc_provider_arn or environment variable ALIBABA_CLOUD_OIDC_PROVIDER_ARN cannot be empty", str(context.exception)) - def test_init_missing_oidc_token_file_path(self): + @patch('alibabacloud_credentials.provider.oidc.au') + def test_init_missing_oidc_token_file_path(self, mock_auth_util): """ Test case 7: Missing oidc_token_file_path raises ValueError """ + mock_auth_util.environment_oidc_token_file = None with self.assertRaises(ValueError) as context: OIDCRoleArnCredentialsProvider( role_arn=self.role_arn, @@ -146,10 +156,12 @@ def test_init_missing_oidc_token_file_path(self): self.assertIn("oidc_token_file_path or environment variable ALIBABA_CLOUD_OIDC_TOKEN_FILE cannot be empty", str(context.exception)) - def test_init_empty_oidc_token_file_path(self): + @patch('alibabacloud_credentials.provider.oidc.au') + def test_init_empty_oidc_token_file_path(self, mock_auth_util): """ Test case 8: Empty oidc_token_file_path raises ValueError """ + mock_auth_util.environment_oidc_token_file = None with self.assertRaises(ValueError) as context: OIDCRoleArnCredentialsProvider( role_arn=self.role_arn, @@ -358,12 +370,11 @@ def test_get_credentials_async_valid_input(self): http_options=self.http_options ) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._refresh_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._refresh_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") self.assertEqual(credentials.value().get_access_key_secret(), "test_access_key_secret") @@ -373,11 +384,11 @@ def test_get_credentials_async_valid_input(self): self.assertEqual(credentials.value().get_provider_name(), "oidc_role_arn") with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn("No cached value was found.", str(context.exception)) @@ -399,11 +410,11 @@ def test_get_credentials_async_file_read_error(self): ) with self.assertRaises(FileNotFoundError) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) def test_get_credentials_async_http_request_error(self): """ @@ -429,11 +440,11 @@ def test_get_credentials_async_http_request_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn( "error refreshing credentials from oidc_role_arn, http_code: 400, result: HTTP request failed", @@ -466,11 +477,11 @@ def test_get_credentials_async_response_format_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn('error retrieving credentials from oidc_role_arn result: {"Error": "Invalid request"}', str(context.exception)) diff --git a/tests/provider/test_profile.py b/tests/provider/test_profile.py index d509385..2a96cdb 100644 --- a/tests/provider/test_profile.py +++ b/tests/provider/test_profile.py @@ -107,13 +107,18 @@ def test_get_credentials_valid_oidc_role_arn(self): self.assertIn("error refreshing credentials from oidc_role_arn", str(context.exception)) - def test_get_credentials_valid_ecs_ram_role(self): + @patch('Tea.core.TeaCore.do_action') + def test_get_credentials_valid_ecs_ram_role(self, mock_do_action): """ Test case 5: Valid input, successfully retrieves credentials for ecs_ram_role type """ with patch('os.path.exists', return_value=True): with patch('os.path.isfile', return_value=True): with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=self.config): + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.body = b'{"error": "Invalid"}' + mock_do_action.return_value = mock_response provider = ProfileCredentialsProvider(profile_name="ecs_ram_role_profile") with self.assertRaises(CredentialException) as context: @@ -223,7 +228,8 @@ def test_get_credentials_missing_access_key_secret(self): self.assertIn('the access key secret is empty', str(context.exception)) - def test_get_credentials_missing_role_arn(self): + @patch('alibabacloud_credentials.provider.ram_role_arn.au') + def test_get_credentials_missing_role_arn(self, mock_au): """ Test case 12: Missing role_arn raises CredentialException """ @@ -234,6 +240,7 @@ def test_get_credentials_missing_role_arn(self): 'role_session_name': 'test_ram_session_name', 'policy': 'test_policy' }} + mock_au.environment_role_arn = None with patch('os.path.exists', return_value=True): with patch('os.path.isfile', return_value=True): with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=missing_role_arn_config): @@ -245,7 +252,8 @@ def test_get_credentials_missing_role_arn(self): self.assertIn('role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty', str(context.exception)) - def test_get_credentials_missing_oidc_provider_arn(self): + @patch('alibabacloud_credentials.provider.oidc.au') + def test_get_credentials_missing_oidc_provider_arn(self, mock_auth_util): """ Test case 13: Missing oidc_provider_arn raises CredentialException """ @@ -256,6 +264,7 @@ def test_get_credentials_missing_oidc_provider_arn(self): 'role_session_name': 'test_role_session_name', 'policy': 'test_policy' }} + mock_auth_util.environment_oidc_provider_arn = None with patch('os.path.exists', return_value=True): with patch('os.path.isfile', return_value=True): with patch('alibabacloud_credentials.provider.profile._load_ini', @@ -269,7 +278,8 @@ def test_get_credentials_missing_oidc_provider_arn(self): 'oidc_provider_arn or environment variable ALIBABA_CLOUD_OIDC_PROVIDER_ARN cannot be empty', str(context.exception)) - def test_get_credentials_missing_oidc_token_file_path(self): + @patch('alibabacloud_credentials.provider.oidc.au') + def test_get_credentials_missing_oidc_token_file_path(self, mock_auth_util): """ Test case 14: Missing oidc_token_file_path raises CredentialException """ @@ -280,6 +290,7 @@ def test_get_credentials_missing_oidc_token_file_path(self): 'role_session_name': 'test_role_session_name', 'policy': 'test_policy' }} + mock_auth_util.environment_oidc_token_file = None with patch('os.path.exists', return_value=True): with patch('os.path.isfile', return_value=True): with patch('alibabacloud_credentials.provider.profile._load_ini', @@ -293,7 +304,8 @@ def test_get_credentials_missing_oidc_token_file_path(self): 'oidc_token_file_path or environment variable ALIBABA_CLOUD_OIDC_TOKEN_FILE cannot be empty', str(context.exception)) - def test_get_credentials_missing_role_name(self): + @patch('Tea.core.TeaCore.do_action') + def test_get_credentials_missing_role_name(self, mock_do_action): """ Test case 15: Missing role_name raises CredentialException """ @@ -304,6 +316,10 @@ def test_get_credentials_missing_role_name(self): with patch('os.path.isfile', return_value=True): with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=missing_role_name_config): + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.body = b'{"error": "Invalid"}' + mock_do_action.return_value = mock_response provider = ProfileCredentialsProvider(profile_name="ecs_ram_role_profile") with self.assertRaises(CredentialException) as context: @@ -362,12 +378,11 @@ def test_get_credentials_async_valid_access_key(self): AsyncMock(return_value=self.config)): provider = ProfileCredentialsProvider(profile_name=self.profile_name) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), self.access_key_id) self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) diff --git a/tests/provider/test_ram_role_arn.py b/tests/provider/test_ram_role_arn.py index 90117c1..417df5f 100644 --- a/tests/provider/test_ram_role_arn.py +++ b/tests/provider/test_ram_role_arn.py @@ -91,10 +91,12 @@ def test_init_valid_environment_variables(self, mock_ram_util, mock_ak_util): self.assertEqual(provider._runtime_options['readTimeout'], RamRoleArnCredentialsProvider.DEFAULT_READ_TIMEOUT) self.assertIsNone(provider._runtime_options['httpsProxy']) - def test_init_missing_role_arn(self): + @patch('alibabacloud_credentials.provider.ram_role_arn.au') + def test_init_missing_role_arn(self, mock_au): """ Test case 3: Missing role_arn raises ValueError """ + mock_au.environment_role_arn = None with self.assertRaises(ValueError) as context: RamRoleArnCredentialsProvider( access_key_id=self.access_key_id, @@ -104,10 +106,12 @@ def test_init_missing_role_arn(self): self.assertIn("role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty", str(context.exception)) - def test_init_empty_role_arn(self): + @patch('alibabacloud_credentials.provider.ram_role_arn.au') + def test_init_empty_role_arn(self, mock_au): """ Test case 4: Empty role_arn raises ValueError """ + mock_au.environment_role_arn = None with self.assertRaises(ValueError) as context: RamRoleArnCredentialsProvider( access_key_id=self.access_key_id, @@ -301,12 +305,11 @@ def test_get_credentials_async_valid_input(self): http_options=self.http_options ) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._refresh_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._refresh_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") self.assertEqual(credentials.value().get_access_key_secret(), "test_access_key_secret") @@ -317,11 +320,11 @@ def test_get_credentials_async_valid_input(self): self.assertEqual(credentials.value().get_provider_name(), "ram_role_arn/static_sts") with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn("No cached value was found.", str(context.exception)) @@ -349,11 +352,11 @@ def test_get_credentials_async_http_request_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn( "error refreshing credentials from ram_role_arn, http_code: 400, result: HTTP request failed", @@ -386,11 +389,11 @@ def test_get_credentials_async_response_format_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn( 'error retrieving credentials from ram_role_arn result: {"Error": "Invalid request"}', diff --git a/tests/provider/test_refreshable.py b/tests/provider/test_refreshable.py index 80e062a..1985106 100644 --- a/tests/provider/test_refreshable.py +++ b/tests/provider/test_refreshable.py @@ -114,11 +114,11 @@ def test_prefetch_async(self, mock_prefetch): action = AsyncMock() - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - self.non_blocking.prefetch_async(action) - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await self.non_blocking.prefetch_async(action) + + asyncio.run(run_test()) mock_prefetch.assert_called_once() @@ -145,11 +145,11 @@ def test_prefetch_async(self, mock_prefetch): """ action = AsyncMock() - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - self.one_caller_blocks.prefetch_async(action) - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await self.one_caller_blocks.prefetch_async(action) + + asyncio.run(run_test()) action.assert_called_once() @@ -207,12 +207,11 @@ def test_async_call_cache_not_stale(self): """ self.refresh_cached_supplier._cached_value = self.refresh_result - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - self.refresh_cached_supplier._async_call() - ) - loop.run_until_complete(task) - result = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await self.refresh_cached_supplier._async_call() + + result = asyncio.run(run_test()) self.assertEqual(result.get_access_key_id(), "test_access_key_id") self.refresh_callable_async.assert_not_called() @@ -226,12 +225,11 @@ def test_async_call_cache_stale(self, mock_refresh_cache_async): self.refresh_cached_supplier._cached_value._stale_time = int(time.mktime(time.localtime())) - 1800 mock_refresh_cache_async.return_value = self.refresh_result - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - self.refresh_cached_supplier._async_call() - ) - loop.run_until_complete(task) - result = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await self.refresh_cached_supplier._async_call() + + result = asyncio.run(run_test()) self.assertEqual(result.get_access_key_id(), "test_access_key_id") mock_refresh_cache_async.assert_called_once() @@ -282,11 +280,11 @@ def test_prefetch_cache_async(self, mock_refresh_cache): """ self.refresh_cached_supplier._prefetch_strategy.prefetch_async = AsyncMock() - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - self.refresh_cached_supplier._prefetch_cache_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await self.refresh_cached_supplier._prefetch_cache_async() + + asyncio.run(run_test()) self.refresh_cached_supplier._prefetch_strategy.prefetch_async.assert_called_once_with(mock_refresh_cache) @@ -297,12 +295,16 @@ def test_refresh_cache_success(self, mock_handle_fetched_success): """ self.refresh_callable.return_value = self.refresh_result mock_handle_fetched_success.return_value = self.refresh_result + + # Mock the lock object + mock_lock = MagicMock() + self.refresh_cached_supplier._refresh_lock = mock_lock self.refresh_cached_supplier._refresh_cache() self.refresh_callable.assert_called_once() mock_handle_fetched_success.assert_called_once_with(self.refresh_result) - self.refresh_cached_supplier._refresh_lock.release.assert_called_once() + mock_lock.release.assert_called_once() @patch('alibabacloud_credentials.provider.refreshable.RefreshCachedSupplier._handle_fetched_failure') def test_refresh_cache_failure(self, mock_handle_fetched_failure): @@ -325,32 +327,37 @@ async def test_refresh_cache_async_success(self, mock_handle_fetched_success): self.refresh_callable_async.return_value = self.refresh_result mock_handle_fetched_success.return_value = self.refresh_result - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - self.refresh_cached_supplier._prefetch_cache_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await self.refresh_cached_supplier._prefetch_cache_async() + + asyncio.run(run_test()) self.refresh_callable_async.assert_called_once() mock_handle_fetched_success.assert_called_once_with(self.refresh_result) @patch('alibabacloud_credentials.provider.refreshable.RefreshCachedSupplier._handle_fetched_failure') - async def test_refresh_cache_async_failure(self, mock_handle_fetched_failure): + def test_refresh_cache_async_failure(self, mock_handle_fetched_failure): """ Test case 20: Test refresh_cache_async method on failure """ - self.refresh_callable_async.side_effect = Exception("Test exception") + test_exception = Exception("Test exception") + self.refresh_callable_async.side_effect = test_exception mock_handle_fetched_failure.return_value = self.refresh_result + + # Mock the lock object + mock_lock = MagicMock() + self.refresh_cached_supplier._refresh_lock = mock_lock - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - self.refresh_cached_supplier._refresh_cache_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await self.refresh_cached_supplier._refresh_cache_async() + + asyncio.run(run_test()) self.refresh_callable_async.assert_called_once() - mock_handle_fetched_failure.assert_called_once_with(Exception("Test exception")) - self.refresh_cached_supplier._refresh_lock.release.assert_called_once() + mock_handle_fetched_failure.assert_called_once_with(test_exception) + mock_lock.release.assert_called_once() def test_handle_fetched_success(self): """ diff --git a/tests/provider/test_rsa_key_pair.py b/tests/provider/test_rsa_key_pair.py index 02fd988..35410e6 100644 --- a/tests/provider/test_rsa_key_pair.py +++ b/tests/provider/test_rsa_key_pair.py @@ -268,12 +268,11 @@ def test_get_credentials_async_valid_input(self): http_options=self.http_options ) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._refresh_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._refresh_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") self.assertEqual(credentials.value().get_access_key_secret(), @@ -304,11 +303,11 @@ def test_get_credentials_async_http_request_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn( "error refreshing credentials from rsa_key_pair, http_code: 400, result: HTTP request failed", @@ -338,11 +337,11 @@ def test_get_credentials_async_response_format_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn( 'error retrieving credentials from rsa_key_pair result: {"Error": "Invalid request"}', diff --git a/tests/provider/test_static_ak.py b/tests/provider/test_static_ak.py index 306f5f0..2f2ee2a 100644 --- a/tests/provider/test_static_ak.py +++ b/tests/provider/test_static_ak.py @@ -146,13 +146,11 @@ def test_get_credentials_async_valid_input(self): access_key_secret="test_access_key_secret" ) - # Use asyncio.run to execute the async function - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") @@ -169,13 +167,11 @@ def test_get_credentials_async_valid_environment_variables(self, mock_auth_util) provider = StaticAKCredentialsProvider() - # Use asyncio.run to execute the async function - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") diff --git a/tests/provider/test_static_sts.py b/tests/provider/test_static_sts.py index 78e6816..16fab2f 100644 --- a/tests/provider/test_static_sts.py +++ b/tests/provider/test_static_sts.py @@ -184,13 +184,11 @@ def test_get_credentials_async_valid_input(self): security_token="test_security_token" ) - # Use asyncio.run to execute the async function - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") @@ -208,13 +206,11 @@ def test_get_credentials_async_valid_environment_variables(self, mock_auth_util) provider = StaticSTSCredentialsProvider() - # Use asyncio.run to execute the async function - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") diff --git a/tests/provider/test_uri.py b/tests/provider/test_uri.py index f65f0a2..d4051b8 100644 --- a/tests/provider/test_uri.py +++ b/tests/provider/test_uri.py @@ -168,12 +168,11 @@ def test_get_credentials_async_valid_input(self): http_options=self.http_options ) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider._refresh_credentials_async() - ) - loop.run_until_complete(task) - credentials = task.result() + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider._refresh_credentials_async() + + credentials = asyncio.run(run_test()) self.assertEqual(credentials.value().get_access_key_id(), self.access_key_id) self.assertEqual(credentials.value().get_access_key_secret(), self.access_key_secret) @@ -198,11 +197,11 @@ def test_get_credentials_async_http_request_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn( f'error refreshing credentials from {self.uri}, http_code=400, result: HTTP request failed', @@ -228,11 +227,11 @@ def test_get_credentials_async_response_format_error(self): ) with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - provider.get_credentials_async() - ) - loop.run_until_complete(task) + # 使用 asyncio.run() 替代 get_event_loop() + async def run_test(): + return await provider.get_credentials_async() + + asyncio.run(run_test()) self.assertIn(f'error retrieving credentials from {self.uri} result: {response_body}', str(context.exception)) diff --git a/tests/test_client.py b/tests/test_client.py index edd8b07..e5b382f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -93,17 +93,27 @@ def test_async_call(self): type='access_key' ) client = Client(conf) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future(client.get_security_token_async()) - loop.run_until_complete(task) - self.assertIsNone(task.result()) - task = asyncio.ensure_future(client.get_access_key_id_async()) - loop.run_until_complete(task) - self.assertEqual('ak1', task.result()) - task = asyncio.ensure_future(client.get_access_key_secret_async()) - loop.run_until_complete(task) - self.assertEqual('sk1', task.result()) - task = asyncio.ensure_future(client.get_credential_async()) - loop.run_until_complete(task) - credential = task.result() + + async def get_security_token_async(): + return await client.get_security_token_async() + + result = asyncio.run(get_security_token_async()) + self.assertIsNone(result) + + async def get_access_key_id_async(): + return await client.get_access_key_id_async() + + result = asyncio.run(get_access_key_id_async()) + self.assertEqual('ak1', result) + + async def get_access_key_secret_async(): + return await client.get_access_key_secret_async() + + result = asyncio.run(get_access_key_secret_async()) + self.assertEqual('sk1', result) + + async def get_credential_async(): + return await client.get_credential_async() + + credential = asyncio.run(get_credential_async()) self.assertEqual('ak1', credential.access_key_id) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index bc302e4..79ea874 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -363,12 +363,10 @@ def test_CredentialsURICredential_async_normal(self): credentials_uri = 'http://localhost:6666/test' cred = credentials.CredentialsURICredential(credentials_uri) - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - cred.get_credential_async() - ) - loop.run_until_complete(task) - model = task.result() + async def run_test(): + return await cred.get_credential_async() + + model = asyncio.run(run_test()) self.assertEqual('test_access_key_id', model.access_key_id) self.assertEqual('test_access_key_secret', model.access_key_secret) @@ -397,12 +395,10 @@ def test_CredentialsURICredential_async_refresh(self): # Set expiration to a past time to trigger refresh cred.expiration = 1 - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - cred.get_credential_async() - ) - loop.run_until_complete(task) - model = task.result() + async def run_test(): + return await cred.get_credential_async() + + model = asyncio.run(run_test()) self.assertEqual('test_access_key_id', model.access_key_id) self.assertEqual('test_access_key_secret', model.access_key_secret) self.assertEqual('test_security_token', model.security_token) @@ -420,12 +416,11 @@ def test_CredentialsURICredential_async_http_request_error(self): credentials_uri = 'http://localhost:6666/test' cred = credentials.CredentialsURICredential(credentials_uri) + async def run_test(): + return await cred.get_credential_async() + with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - cred.get_credential_async() - ) - loop.run_until_complete(task) + asyncio.run(run_test()) self.assertIn( "Get credentials from http://localhost:6666/test failed, HttpCode=400", @@ -447,12 +442,11 @@ def test_CredentialsURICredential_async_response_format_error(self): credentials_uri = 'http://localhost:6666/test' cred = credentials.CredentialsURICredential(credentials_uri) + async def run_test(): + return await cred.get_credential_async() + with self.assertRaises(CredentialException) as context: - loop = asyncio.get_event_loop() - task = asyncio.ensure_future( - cred.get_credential_async() - ) - loop.run_until_complete(task) + asyncio.run(run_test()) self.assertIn( "Get credentials from http://localhost:6666/test failed, Code is Failure",