From 4429d82931ee1da80195fecbb2a56240ff5be8b1 Mon Sep 17 00:00:00 2001 From: nanhe Date: Thu, 9 Jan 2025 13:45:48 +0800 Subject: [PATCH] refactor credentials providers --- .github/workflows/testPython.yml | 8 +- alibabacloud_credentials/client.py | 204 ++++-- alibabacloud_credentials/credentials.py | 3 + alibabacloud_credentials/http/__init__.py | 5 + alibabacloud_credentials/http/_options.py | 9 + alibabacloud_credentials/models.py | 203 +++--- alibabacloud_credentials/provider/__init__.py | 25 + .../provider/cli_profile.py | 164 +++++ alibabacloud_credentials/provider/default.py | 89 +++ .../provider/ecs_ram_role.py | 251 ++++++++ alibabacloud_credentials/provider/env.py | 33 + alibabacloud_credentials/provider/oidc.py | 206 ++++++ alibabacloud_credentials/provider/profile.py | 144 +++++ .../provider/ram_role_arn.py | 234 +++++++ .../provider/refreshable.py | 304 +++++++++ .../provider/rsa_key_pair.py | 186 ++++++ .../provider/static_ak.py | 32 + .../provider/static_sts.py | 37 ++ alibabacloud_credentials/provider/uri.py | 135 ++++ alibabacloud_credentials/providers.py | 6 +- alibabacloud_credentials/utils/auth_util.py | 19 +- setup.py | 19 +- tests/provider/test_cli_profile.py | 385 ++++++++++++ tests/provider/test_default.py | 494 +++++++++++++++ tests/provider/test_ecs_ram_role.py | 489 +++++++++++++++ tests/provider/test_env.py | 175 ++++++ tests/provider/test_oidc.py | 592 ++++++++++++++++++ tests/provider/test_profile.py | 374 +++++++++++ tests/provider/test_ram_role_arn.py | 513 +++++++++++++++ tests/provider/test_refreshable.py | 465 ++++++++++++++ tests/provider/test_rsa_key_pair.py | 459 ++++++++++++++ tests/provider/test_static_ak.py | 183 ++++++ tests/provider/test_static_sts.py | 222 +++++++ tests/provider/test_uri.py | 374 +++++++++++ tests/test_client.py | 60 +- tests/test_credentials.py | 228 ++++++- tests/test_model.py | 67 -- tests/test_models.py | 269 ++++++++ tests/test_providers.py | 6 +- 39 files changed, 7454 insertions(+), 217 deletions(-) create mode 100644 alibabacloud_credentials/http/__init__.py create mode 100644 alibabacloud_credentials/http/_options.py create mode 100644 alibabacloud_credentials/provider/__init__.py create mode 100644 alibabacloud_credentials/provider/cli_profile.py create mode 100644 alibabacloud_credentials/provider/default.py create mode 100644 alibabacloud_credentials/provider/ecs_ram_role.py create mode 100644 alibabacloud_credentials/provider/env.py create mode 100644 alibabacloud_credentials/provider/oidc.py create mode 100644 alibabacloud_credentials/provider/profile.py create mode 100644 alibabacloud_credentials/provider/ram_role_arn.py create mode 100644 alibabacloud_credentials/provider/refreshable.py create mode 100644 alibabacloud_credentials/provider/rsa_key_pair.py create mode 100644 alibabacloud_credentials/provider/static_ak.py create mode 100644 alibabacloud_credentials/provider/static_sts.py create mode 100644 alibabacloud_credentials/provider/uri.py create mode 100644 tests/provider/test_cli_profile.py create mode 100644 tests/provider/test_default.py create mode 100644 tests/provider/test_ecs_ram_role.py create mode 100644 tests/provider/test_env.py create mode 100644 tests/provider/test_oidc.py create mode 100644 tests/provider/test_profile.py create mode 100644 tests/provider/test_ram_role_arn.py create mode 100644 tests/provider/test_refreshable.py create mode 100644 tests/provider/test_rsa_key_pair.py create mode 100644 tests/provider/test_static_ak.py create mode 100644 tests/provider/test_static_sts.py create mode 100644 tests/provider/test_uri.py delete mode 100644 tests/test_model.py create mode 100644 tests/test_models.py diff --git a/.github/workflows/testPython.yml b/.github/workflows/testPython.yml index b5a43b6..752c845 100644 --- a/.github/workflows/testPython.yml +++ b/.github/workflows/testPython.yml @@ -11,19 +11,19 @@ permissions: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: matrix: - python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11", "3.12" ] + python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] fail-fast: false steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: pip install alibabacloud-tea coverage pytest + run: pip install alibabacloud-tea coverage pytest alibabacloud_credentials_api APScheduler aiofiles - name: Setup OIDC run: npm install @actions/core@1.6.0 @actions/http-client - name: Get Id Token diff --git a/alibabacloud_credentials/client.py b/alibabacloud_credentials/client.py index 80bd909..302dea6 100644 --- a/alibabacloud_credentials/client.py +++ b/alibabacloud_credentials/client.py @@ -1,6 +1,18 @@ from functools import wraps -from alibabacloud_credentials import credentials, providers, models +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials import credentials +from alibabacloud_credentials.exceptions import CredentialException +from alibabacloud_credentials.models import Config, CredentialModel +from alibabacloud_credentials.http import HttpOptions +from alibabacloud_credentials.provider import (StaticAKCredentialsProvider, + StaticSTSCredentialsProvider, + RamRoleArnCredentialsProvider, + OIDCRoleArnCredentialsProvider, + RsaKeyPairCredentialsProvider, + EcsRamRoleCredentialsProvider, + URLCredentialsProvider, + DefaultCredentialsProvider) from alibabacloud_credentials.utils import auth_constant as ac from Tea.decorators import deprecated @@ -16,24 +28,85 @@ def i(*args, **kwargs): return i +class _CredentialsProviderWrap: + + def __init__(self, + *, + type_name: str = None, + provider: ICredentialsProvider = None): + self.type_name = type_name + self.provider = provider + + def get_access_key_id(self) -> str: + credential = self.provider.get_credentials() + return credential.get_access_key_id() + + async def get_access_key_id_async(self) -> str: + credential = await self.provider.get_credentials_async() + return credential.get_access_key_id() + + def get_access_key_secret(self) -> str: + credential = self.provider.get_credentials() + return credential.get_access_key_secret() + + async def get_access_key_secret_async(self) -> str: + credential = await self.provider.get_credentials_async() + return credential.get_access_key_secret() + + def get_security_token(self): + credential = self.provider.get_credentials() + return credential.get_security_token() + + async def get_security_token_async(self): + credential = await self.provider.get_credentials_async() + return credential.get_security_token() + + def get_credential(self) -> CredentialModel: + credential = self.provider.get_credentials() + return CredentialModel( + access_key_id=credential.get_access_key_id(), + access_key_secret=credential.get_access_key_secret(), + security_token=credential.get_security_token(), + type=self.type_name, + provider_name=credential.get_provider_name(), + ) + + async def get_credential_async(self) -> CredentialModel: + credential = await self.provider.get_credentials_async() + return CredentialModel( + access_key_id=credential.get_access_key_id(), + access_key_secret=credential.get_access_key_secret(), + security_token=credential.get_security_token(), + type=self.type_name, + provider_name=credential.get_provider_name(), + ) + + def get_type(self) -> str: + return self.type_name + + class Client: cloud_credential = None - def __init__(self, config=None): - if config is None: - provider = providers.DefaultCredentialsProvider() - self.cloud_credential = provider.get_credentials() - return - self.cloud_credential = Client.get_credentials(config) - - def get_credential(self) -> models.CredentialModel: + def __init__(self, + config: Config = None, + provider: ICredentialsProvider = None): + if provider is not None: + self.cloud_credential = _CredentialsProviderWrap(provider=provider) + elif config is None: + provider = DefaultCredentialsProvider() + self.cloud_credential = _CredentialsProviderWrap(type_name='default', provider=provider) + else: + self.cloud_credential = Client.get_credentials(config) + + def get_credential(self) -> CredentialModel: """ Get credential @return: the whole credential """ return self.cloud_credential.get_credential() - async def get_credential_async(self) -> models.CredentialModel: + async def get_credential_async(self) -> CredentialModel: """ Get credential @return: the whole credential @@ -43,44 +116,99 @@ async def get_credential_async(self) -> models.CredentialModel: @staticmethod def get_credentials(config): if config.type == ac.ACCESS_KEY: - return credentials.AccessKeyCredential(config.access_key_id, config.access_key_secret) + provider = StaticAKCredentialsProvider( + access_key_id=config.access_key_id, + access_key_secret=config.access_key_secret, + ) + return _CredentialsProviderWrap(type_name='access_key', provider=provider) elif config.type == ac.STS: - return credentials.StsCredential(config.access_key_id, config.access_key_secret, config.security_token) + provider = StaticSTSCredentialsProvider( + access_key_id=config.access_key_id, + access_key_secret=config.access_key_secret, + security_token=config.security_token, + ) + return _CredentialsProviderWrap(type_name='sts', provider=provider) elif config.type == ac.BEARER: return credentials.BearerTokenCredential(config.bearer_token) elif config.type == ac.ECS_RAM_ROLE: - return credentials.EcsRamRoleCredential( - config.access_key_id, - config.access_key_secret, - config.security_token, - 0, - providers.EcsRamRoleCredentialProvider(config=config) + provider = EcsRamRoleCredentialsProvider( + role_name=config.role_name, + disable_imds_v1=config.disable_imds_v1, + http_options=HttpOptions( + read_timeout=config.timeout, + connect_timeout=config.connect_timeout, + proxy=config.proxy, + ), ) + return _CredentialsProviderWrap(type_name='ecs_ram_role', provider=provider) elif config.type == ac.CREDENTIALS_URI: - return credentials.CredentialsURICredential(config.credentials_uri) + provider = URLCredentialsProvider( + uri=config.credentials_uri, + http_options=HttpOptions( + read_timeout=config.timeout, + connect_timeout=config.connect_timeout, + proxy=config.proxy, + ), + ) + return _CredentialsProviderWrap(type_name='credentials_uri', provider=provider) elif config.type == ac.RAM_ROLE_ARN: - return credentials.RamRoleArnCredential( - config.access_key_id, - config.access_key_secret, - config.security_token, - 0, - providers.RamRoleArnCredentialProvider(config=config) + if config.security_token is not None and config.security_token != '': + previous_provider = StaticSTSCredentialsProvider( + access_key_id=config.access_key_id, + access_key_secret=config.access_key_secret, + security_token=config.security_token, + ) + else: + previous_provider = StaticAKCredentialsProvider( + access_key_id=config.access_key_id, + access_key_secret=config.access_key_secret, + ) + provider = RamRoleArnCredentialsProvider( + credentials_provider=previous_provider, + role_arn=config.role_arn, + role_session_name=config.role_session_name, + duration_seconds=config.role_session_expiration, + policy=config.policy, + external_id=config.external_id, + sts_endpoint=config.sts_endpoint, + http_options=HttpOptions( + read_timeout=config.timeout, + connect_timeout=config.connect_timeout, + proxy=config.proxy, + ), ) + return _CredentialsProviderWrap(type_name='ram_role_arn', provider=provider) elif config.type == ac.RSA_KEY_PAIR: - return credentials.RsaKeyPairCredential( - config.access_key_id, - config.access_key_secret, - 0, - providers.RsaKeyPairCredentialProvider(config=config) + provider = RsaKeyPairCredentialsProvider( + public_key_id=config.public_key_id, + private_key_file=config.private_key_file, + duration_seconds=config.role_session_expiration, + sts_endpoint=config.sts_endpoint, + http_options=HttpOptions( + read_timeout=config.timeout, + connect_timeout=config.connect_timeout, + proxy=config.proxy, + ), ) + return _CredentialsProviderWrap(type_name='rsa_key_pair', provider=provider) elif config.type == ac.OIDC_ROLE_ARN: - return credentials.OIDCRoleArnCredential( - config.access_key_id, - config.access_key_secret, - config.security_token, - 0, - providers.OIDCRoleArnCredentialProvider(config=config)) - return providers.DefaultCredentialsProvider().get_credentials() + provider = OIDCRoleArnCredentialsProvider( + role_arn=config.role_arn, + oidc_provider_arn=config.oidc_provider_arn, + oidc_token_file_path=config.oidc_token_file_path, + role_session_name=config.role_session_name, + duration_seconds=config.role_session_expiration, + policy=config.policy, + sts_endpoint=config.sts_endpoint, + http_options=HttpOptions( + read_timeout=config.timeout, + connect_timeout=config.connect_timeout, + proxy=config.proxy, + ), + ) + return _CredentialsProviderWrap(type_name='oidc_role_arn', provider=provider) + raise CredentialException( + 'invalid type option, support: access_key, sts, bearer, ecs_ram_role, ram_role_arn, rsa_key_pair, oidc_role_arn, credentials_uri') @deprecated("Use 'get_credential().access_key_id' instead") def get_access_key_id(self): @@ -109,7 +237,7 @@ async def get_security_token_async(self): @deprecated("Use 'get_credential().type' instead") @attribute_error_return_none def get_type(self): - return self.cloud_credential.credential_type + return self.cloud_credential.get_type() @deprecated("Use 'get_credential().bearer_token' instead") @attribute_error_return_none diff --git a/alibabacloud_credentials/credentials.py b/alibabacloud_credentials/credentials.py index 536321c..543f533 100644 --- a/alibabacloud_credentials/credentials.py +++ b/alibabacloud_credentials/credentials.py @@ -116,6 +116,9 @@ async def get_credential_async(self): type=ac.BEARER ) + def get_type(self) -> str: + return self.credential_type + class EcsRamRoleCredential(Credential, _AutomaticallyRefreshCredentials): """EcsRamRoleCredential""" diff --git a/alibabacloud_credentials/http/__init__.py b/alibabacloud_credentials/http/__init__.py new file mode 100644 index 0000000..7a37211 --- /dev/null +++ b/alibabacloud_credentials/http/__init__.py @@ -0,0 +1,5 @@ +from ._options import HttpOptions + +__all__ = [ + 'HttpOptions' +] diff --git a/alibabacloud_credentials/http/_options.py b/alibabacloud_credentials/http/_options.py new file mode 100644 index 0000000..d3b38f7 --- /dev/null +++ b/alibabacloud_credentials/http/_options.py @@ -0,0 +1,9 @@ +class HttpOptions: + def __init__(self, + *, + proxy: str = None, + connect_timeout: int = None, + read_timeout: int = None): + self.proxy = proxy + self.connect_timeout = connect_timeout + self.read_timeout = read_timeout diff --git a/alibabacloud_credentials/models.py b/alibabacloud_credentials/models.py index 98dcb37..28c3cbb 100644 --- a/alibabacloud_credentials/models.py +++ b/alibabacloud_credentials/models.py @@ -5,83 +5,128 @@ class Config(TeaModel): """ - Model for initing credential + Model for initializing credential """ def __init__( self, - access_key_id: str = '', - access_key_secret: str = '', - security_token: str = '', - bearer_token: str = '', - duration_seconds: int = '', - role_arn: str = '', - oidc_provider_arn: str = '', - oidc_token_file_path: str = '', - policy: str = '', - role_session_expiration: int = '', - role_session_name: str = '', - public_key_id: str = '', - private_key_file: str = '', - role_name: str = '', - type: str = '', - host: str = '', - timeout: int = 1000, - connect_timeout: int = 1000, - proxy: str = '', - credentials_uri: str = '', - disable_imds_v1: bool = False, - enable_imds_v2: bool = False, - metadata_token_duration: int = 21600, - sts_endpoint: str = None + *, + type: str = None, + access_key_id: str = None, + access_key_secret: str = None, + security_token: str = None, + bearer_token: str = None, + duration_seconds: int = None, + role_arn: str = None, + oidc_provider_arn: str = None, + oidc_token_file_path: str = None, + role_session_name: str = None, + role_session_expiration: int = None, + policy: str = None, + external_id: str = None, + sts_endpoint: str = None, + public_key_id: str = None, + private_key_file: str = None, + role_name: str = None, + enable_imds_v2: bool = None, + disable_imds_v1: bool = None, + metadata_token_duration: int = None, + credentials_uri: str = None, + host: str = None, + timeout: int = None, + connect_timeout: int = None, + proxy: str = None, ): - # accesskey id + """ + Initialize the credential object. + + ### Parameters + + #### General Parameters + - `type` (str): Credential type, including `access_key`, `sts`, `bearer`, `ecs_ram_role`, `ram_role_arn`, `rsa_key_pair`, `oidc_role_arn`, `credentials_uri`. + + #### Access Key Type + - `access_key_id` (str): Access Key ID. + - `access_key_secret` (str): Access Key Secret. + - `security_token` (str, optional): Security token. + + #### Bearer Token Type + - `bearer_token` (str): Bearer token. + + #### RAM Role ARN and OIDC Role ARN Types + - `role_arn` (str): Role ARN. + - `oidc_provider_arn` (str, for `oidc_role_arn` only): OIDC provider ARN. + - `oidc_token_file_path` (str, for `oidc_role_arn` only): Path to the OIDC token file. + - `role_session_name` (str): Role session name. + - `role_session_expiration` (int, optional): Role session expiration time in seconds. + - `policy` (str, optional): Policy. + - `external_id` (str, optional): External ID. + - `sts_endpoint` (str, optional): STS endpoint. + - `duration_seconds`: deprecated + + #### RSA Key Pair Type + - `public_key_id` (str): Public key ID. + - `private_key_file` (str): Path to the private key file. + + #### ECS RAM Role Type + - `role_name` (str): Role name. + - `disable_imds_v1` (bool, optional): Whether to disable IMDS v1. Default is `False`. + - `enable_imds_v2` (bool, optional): Whether to enable IMDS v2. Default is `None`. + - `metadata_token_duration` (int, optional): Metadata token expiration time in seconds. Default is `None`. + + #### Credentials URI Type + - `credentials_uri` (str): Credentials URI. + + #### HTTP Options + - `host` (str, optional): Host address. + - `timeout` (int, optional): Read timeout in milliseconds. Default values: + - `ecs_ram_role`: 1000ms + - `ram_role_arn`: 5000ms + - `oidc_role_arn`: 5000ms + - `connect_timeout` (int, optional): Connection timeout in milliseconds. Default values: + - `ecs_ram_role`: 1000ms + - `ram_role_arn`: 10000ms + - `oidc_role_arn`: 10000ms + - `proxy` (str, optional): HTTP or HTTPS proxy. + + #### Other Parameters + - `duration_seconds` (int, optional): Duration in seconds, mainly used for `sts` type credentials. + + Note: Some parameters are only valid for specific credential types. Please use them according to your actual needs. + """ + self.type = type self.access_key_id = access_key_id - # accesskey secret self.access_key_secret = access_key_secret - # security token self.security_token = security_token - # bearer token self.bearer_token = bearer_token - # duration seconds self.duration_seconds = duration_seconds - # role arn self.role_arn = role_arn - # oidc provider arn self.oidc_provider_arn = oidc_provider_arn - # oidc token file path self.oidc_token_file_path = oidc_token_file_path - # policy - self.policy = policy - # role session expiration - self.role_session_expiration = role_session_expiration - # role session name self.role_session_name = role_session_name - # publicKey id + self.role_session_expiration = role_session_expiration + self.policy = policy + self.external_id = external_id + self.sts_endpoint = sts_endpoint self.public_key_id = public_key_id - # privateKey file self.private_key_file = private_key_file - # role name self.role_name = role_name self.disable_imds_v1 = disable_imds_v1 self.enable_imds_v2 = enable_imds_v2 self.metadata_token_duration = metadata_token_duration - # credential type - self.type = type + self.credentials_uri = credentials_uri self.host = host self.timeout = timeout self.connect_timeout = connect_timeout self.proxy = proxy - # credentials uri - self.credentials_uri = credentials_uri - # STS Endpoint - self.sts_endpoint = sts_endpoint def validate(self): pass def to_map(self): result = dict() + if self.type is not None: + result['type'] = self.type if self.access_key_id is not None: result['accessKeyId'] = self.access_key_id if self.access_key_secret is not None: @@ -98,12 +143,16 @@ def to_map(self): result['oidcProviderArn'] = self.oidc_provider_arn if self.oidc_token_file_path is not None: result['oidcTokenFilePath'] = self.oidc_token_file_path - if self.policy is not None: - result['policy'] = self.policy - if self.role_session_expiration is not None: - result['roleSessionExpiration'] = self.role_session_expiration if self.role_session_name is not None: result['roleSessionName'] = self.role_session_name + if self.role_session_expiration is not None: + result['roleSessionExpiration'] = self.role_session_expiration + if self.policy is not None: + result['policy'] = self.policy + if self.external_id is not None: + result['externalId'] = self.external_id + if self.sts_endpoint is not None: + result['stsEndpoint'] = self.sts_endpoint if self.public_key_id is not None: result['publicKeyId'] = self.public_key_id if self.private_key_file is not None: @@ -116,8 +165,8 @@ def to_map(self): result['enableIMDSv2'] = self.enable_imds_v2 if self.metadata_token_duration is not None: result['metadataTokenDuration'] = self.metadata_token_duration - if self.type is not None: - result['type'] = self.type + if self.credentials_uri is not None: + result['credentialsUri'] = self.credentials_uri if self.host is not None: result['host'] = self.host if self.timeout is not None: @@ -126,14 +175,12 @@ def to_map(self): result['connectTimeout'] = self.connect_timeout if self.proxy is not None: result['proxy'] = self.proxy - if self.credentials_uri is not None: - result['credentialsUri'] = self.credentials_uri - if self.sts_endpoint is not None: - result['stsEndpoint'] = self.sts_endpoint return result def from_map(self, m: dict = None): m = m or dict() + if m.get('type') is not None: + self.type = m.get('type') if m.get('accessKeyId') is not None: self.access_key_id = m.get('accessKeyId') if m.get('accessKeySecret') is not None: @@ -150,12 +197,16 @@ def from_map(self, m: dict = None): self.oidc_provider_arn = m.get('oidcProviderArn') if m.get('oidcTokenFilePath') is not None: self.oidc_token_file_path = m.get('oidcTokenFilePath') - if m.get('policy') is not None: - self.policy = m.get('policy') - if m.get('roleSessionExpiration') is not None: - self.role_session_expiration = m.get('roleSessionExpiration') if m.get('roleSessionName') is not None: self.role_session_name = m.get('roleSessionName') + if m.get('roleSessionExpiration') is not None: + self.role_session_expiration = m.get('roleSessionExpiration') + if m.get('policy') is not None: + self.policy = m.get('policy') + if m.get('externalId') is not None: + self.external_id = m.get('externalId') + if m.get('stsEndpoint') is not None: + self.sts_endpoint = m.get('stsEndpoint') if m.get('publicKeyId') is not None: self.public_key_id = m.get('publicKeyId') if m.get('privateKeyFile') is not None: @@ -168,8 +219,8 @@ def from_map(self, m: dict = None): self.enable_imds_v2 = m.get('enableIMDSv2') if m.get('metadataTokenDuration') is not None: self.metadata_token_duration = m.get('metadataTokenDuration') - if m.get('type') is not None: - self.type = m.get('type') + if m.get('credentialsUri') is not None: + self.credentials_uri = m.get('credentialsUri') if m.get('host') is not None: self.host = m.get('host') if m.get('timeout') is not None: @@ -178,10 +229,6 @@ def from_map(self, m: dict = None): self.connect_timeout = m.get('connectTimeout') if m.get('proxy') is not None: self.proxy = m.get('proxy') - if m.get('credentialsUri') is not None: - self.credentials_uri = m.get('credentials_uri') - if m.get('stsEndpoint') is not None: - self.sts_endpoint = m.get('stsEndpoint') return self @@ -193,6 +240,7 @@ def __init__( security_token: str = None, bearer_token: str = None, type: str = None, + provider_name: str = None, ): # accesskey id self.access_key_id = access_key_id @@ -204,11 +252,13 @@ def __init__( self.bearer_token = bearer_token # type self.type = type + # provider name + self.provider_name = provider_name def validate(self): pass - def to_map(self): + def to_map(self) -> dict: _map = super().to_map() if _map is not None: return _map @@ -224,6 +274,8 @@ def to_map(self): result['bearerToken'] = self.bearer_token if self.type is not None: result['type'] = self.type + if self.provider_name is not None: + result['providerName'] = self.provider_name return result def from_map(self, m: dict = None): @@ -238,19 +290,24 @@ def from_map(self, m: dict = None): self.bearer_token = m.get('bearerToken') if m.get('type') is not None: self.type = m.get('type') + if m.get('providerName') is not None: + self.provider_name = m.get('providerName') return self - def get_access_key_id(self): + def get_access_key_id(self) -> str: return self.access_key_id - def get_access_key_secret(self): + def get_access_key_secret(self) -> str: return self.access_key_secret - def get_security_token(self): + def get_security_token(self) -> str: return self.security_token - def get_bearer_token(self): + def get_bearer_token(self) -> str: return self.bearer_token - def get_type(self): + def get_type(self) -> str: return self.type + + def get_provider_name(self) -> str: + return self.provider_name diff --git a/alibabacloud_credentials/provider/__init__.py b/alibabacloud_credentials/provider/__init__.py new file mode 100644 index 0000000..bc33809 --- /dev/null +++ b/alibabacloud_credentials/provider/__init__.py @@ -0,0 +1,25 @@ +from .static_ak import StaticAKCredentialsProvider +from .static_sts import StaticSTSCredentialsProvider +from .env import EnvironmentVariableCredentialsProvider +from .ecs_ram_role import EcsRamRoleCredentialsProvider +from .ram_role_arn import RamRoleArnCredentialsProvider +from .oidc import OIDCRoleArnCredentialsProvider +from .rsa_key_pair import RsaKeyPairCredentialsProvider +from .uri import URLCredentialsProvider +from .cli_profile import CLIProfileCredentialsProvider +from .profile import ProfileCredentialsProvider +from .default import DefaultCredentialsProvider + +__all__ = [ + 'StaticAKCredentialsProvider', + 'StaticSTSCredentialsProvider', + 'EnvironmentVariableCredentialsProvider', + 'EcsRamRoleCredentialsProvider', + 'RamRoleArnCredentialsProvider', + 'OIDCRoleArnCredentialsProvider', + 'RsaKeyPairCredentialsProvider', + 'URLCredentialsProvider', + 'CLIProfileCredentialsProvider', + 'ProfileCredentialsProvider', + 'DefaultCredentialsProvider' +] diff --git a/alibabacloud_credentials/provider/cli_profile.py b/alibabacloud_credentials/provider/cli_profile.py new file mode 100644 index 0000000..f996d3c --- /dev/null +++ b/alibabacloud_credentials/provider/cli_profile.py @@ -0,0 +1,164 @@ +import os +import json +from typing import Any, Dict + +import aiofiles + +from alibabacloud_credentials.provider import StaticAKCredentialsProvider, EcsRamRoleCredentialsProvider, \ + RamRoleArnCredentialsProvider, OIDCRoleArnCredentialsProvider, RsaKeyPairCredentialsProvider +from .refreshable import Credentials +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_constant as ac +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.exceptions import CredentialException + + +async def _load_config_async(file_path: str) -> Any: + async with aiofiles.open(file_path, mode='r') as f: + content = await f.read() + return json.loads(content) + + +def _load_config(file_path: str) -> Any: + with open(file_path, mode='r') as f: + content = f.read() + return json.loads(content) + + +class CLIProfileCredentialsProvider(ICredentialsProvider): + + def __init__(self, *, + profile_name: str = None): + self._profile_file = os.path.join(ac.HOME, "/.aliyun/config.json") + self._profile_name = profile_name or au.environment_profile_name + self.__innerProvider = None + + def _should_reload_credentials_provider(self) -> bool: + if self.__innerProvider is None: + return True + return False + + def get_credentials(self) -> Credentials: + if au.environment_cli_profile_disabled.lower() == "true": + raise CredentialException('cli credentials file is disabled') + + if self._should_reload_credentials_provider(): + if not os.path.exists(self._profile_file) or not os.path.isfile(self._profile_file): + raise CredentialException(f'unable to open credentials file: {self._profile_file}') + try: + config = _load_config(self._profile_file) + except Exception as e: + raise CredentialException( + f'failed to parse credential form cli credentials file: {self._profile_file}') + if config is None: + raise CredentialException( + f'failed to parse credential form cli credentials file: {self._profile_file}') + + profile_name = self._profile_name + if self._profile_name is None or self._profile_name == '': + profile_name = config.get('current') + self.__innerProvider = self._get_credentials_provider(config, profile_name) + + cre = self.__innerProvider.get_credentials() + credentials = Credentials( + access_key_id=cre.get_access_key_id(), + access_key_secret=cre.get_access_key_secret(), + security_token=cre.get_security_token(), + provider_name=f'{self.get_provider_name()}/{cre.get_provider_name()}' + ) + return credentials + + async def get_credentials_async(self) -> Credentials: + if au.environment_cli_profile_disabled.lower() == "true": + raise CredentialException('cli credentials file is disabled') + + if self._should_reload_credentials_provider(): + if not os.path.exists(self._profile_file) or not os.path.isfile(self._profile_file): + raise CredentialException(f'unable to open credentials file: {self._profile_file}') + try: + config = await _load_config_async(self._profile_file) + except Exception as e: + raise CredentialException( + f'failed to parse credential form cli credentials file: {self._profile_file}') + if config is None: + raise CredentialException( + f'failed to parse credential form cli credentials file: {self._profile_file}') + + profile_name = self._profile_name + if self._profile_name is None or self._profile_name == '': + profile_name = config.get('current') + self.__innerProvider = self._get_credentials_provider(config, profile_name) + + cre = await self.__innerProvider.get_credentials_async() + credentials = Credentials( + access_key_id=cre.get_access_key_id(), + access_key_secret=cre.get_access_key_secret(), + security_token=cre.get_security_token(), + provider_name=f'{self.get_provider_name()}/{cre.get_provider_name()}' + ) + return credentials + + def _get_credentials_provider(self, config: Dict, profile_name: str) -> ICredentialsProvider: + if profile_name is None or profile_name == '': + raise CredentialException('invalid profile name') + + profiles = config.get('profiles', []) + + if not profiles: + raise CredentialException(f"unable to get profile with '{profile_name}' form cli credentials file.") + + for profile in profiles: + if profile.get('name') is not None and profile['name'] == profile_name: + mode = profile.get('mode') + if mode == "AK": + return StaticAKCredentialsProvider( + access_key_id=profile.get('access_key_id'), + access_key_secret=profile.get('access_key_secret') + ) + elif mode == "RamRoleArn": + pre_provider = StaticAKCredentialsProvider( + access_key_id=profile.get('access_key_id'), + access_key_secret=profile.get('access_key_secret') + ) + return RamRoleArnCredentialsProvider( + credentials_provider=pre_provider, + role_arn=profile.get('ram_role_arn'), + role_session_name=profile.get('ram_session_name'), + duration_seconds=profile.get('expired_seconds'), + policy=profile.get('policy'), + external_id=profile.get('external_id'), + sts_region_id=profile.get('sts_region'), + enable_vpc=profile.get('enable_vpc'), + ) + elif mode == "EcsRamRole": + return EcsRamRoleCredentialsProvider( + role_name=profile.get('ram_role_name') + ) + elif mode == "OIDC": + return OIDCRoleArnCredentialsProvider( + role_arn=profile.get('ram_role_arn'), + oidc_provider_arn=profile.get('oidc_provider_arn'), + oidc_token_file_path=profile.get('oidc_token_file'), + role_session_name=profile.get('role_session_name'), + duration_seconds=profile.get('expired_seconds'), + policy=profile.get('policy'), + sts_region_id=profile.get('sts_region'), + enable_vpc=profile.get('enable_vpc'), + ) + elif mode == "ChainableRamRoleArn": + previous_provider = self._get_credentials_provider(config, profile.get('source_profile')) + return RamRoleArnCredentialsProvider( + credentials_provider=previous_provider, + role_arn=profile.get('ram_role_arn'), + role_session_name=profile.get('ram_session_name'), + duration_seconds=profile.get('expired_seconds'), + policy=profile.get('policy'), + external_id=profile.get('external_id'), + sts_region_id=profile.get('sts_region'), + enable_vpc=profile.get('enable_vpc'), + ) + else: + raise CredentialException(f"unsupported profile mode '{mode}' form cli credentials file.") + + def get_provider_name(self) -> str: + return 'cli_profile' diff --git a/alibabacloud_credentials/provider/default.py b/alibabacloud_credentials/provider/default.py new file mode 100644 index 0000000..a2aa73a --- /dev/null +++ b/alibabacloud_credentials/provider/default.py @@ -0,0 +1,89 @@ +from . import EnvironmentVariableCredentialsProvider, EcsRamRoleCredentialsProvider, \ + OIDCRoleArnCredentialsProvider, URLCredentialsProvider, CLIProfileCredentialsProvider, ProfileCredentialsProvider + +from alibabacloud_credentials.provider.refreshable import Credentials +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.exceptions import CredentialException + + +class DefaultCredentialsProvider(ICredentialsProvider): + + def __init__(self, *, + reuse_last_provider_enabled: bool = True): + + self.__reuse_last_provider_enabled = reuse_last_provider_enabled + self.__last_used_provider = None + + self.__providers_chain = [ + EnvironmentVariableCredentialsProvider() + ] + if au.enable_oidc_credential: + self.__providers_chain.append(OIDCRoleArnCredentialsProvider()) + + self.__providers_chain.append(CLIProfileCredentialsProvider()) + self.__providers_chain.append(ProfileCredentialsProvider()) + if au.environment_ecs_metadata_disabled.lower() != 'true': + self.__providers_chain.append(EcsRamRoleCredentialsProvider()) + + if au.environment_credentials_uri is not None and au.environment_credentials_uri != '': + self.__providers_chain.append(URLCredentialsProvider()) + + def get_credentials(self) -> Credentials: + if self.__reuse_last_provider_enabled and self.__last_used_provider is not None: + credentials = self.__last_used_provider.get_credentials() + return Credentials( + access_key_id=credentials.get_access_key_id(), + access_key_secret=credentials.get_access_key_secret(), + security_token=credentials.get_security_token(), + provider_name=f'{self.get_provider_name()}/{credentials.get_provider_name()}' + ) + + error_messages = [] + for provider in self.__providers_chain: + try: + credentials = provider.get_credentials() + if credentials is not None: + self.__last_used_provider = provider + return Credentials( + access_key_id=credentials.get_access_key_id(), + access_key_secret=credentials.get_access_key_secret(), + security_token=credentials.get_security_token(), + provider_name=f'{self.get_provider_name()}/{credentials.get_provider_name()}' + ) + except Exception as e: + error_messages.append(f'{type(provider).__name__}: {str(e)}') + + raise CredentialException( + f'unable to load credentials from any of the providers in the chain: {error_messages}') + + async def get_credentials_async(self) -> Credentials: + if self.__reuse_last_provider_enabled and self.__last_used_provider is not None: + credentials = await self.__last_used_provider.get_credentials_async() + return Credentials( + access_key_id=credentials.get_access_key_id(), + access_key_secret=credentials.get_access_key_secret(), + security_token=credentials.get_security_token(), + provider_name=f'{self.get_provider_name()}/{credentials.get_provider_name()}' + ) + + error_messages = [] + for provider in self.__providers_chain: + try: + credentials = await provider.get_credentials_async() + if credentials is not None: + self.__last_used_provider = provider + return Credentials( + access_key_id=credentials.get_access_key_id(), + access_key_secret=credentials.get_access_key_secret(), + security_token=credentials.get_security_token(), + provider_name=f'{self.get_provider_name()}/{credentials.get_provider_name()}' + ) + except Exception as e: + error_messages.append(f'{type(provider).__name__}: {str(e)}') + + raise CredentialException( + f'unable to load credentials from any of the providers in the chain: {error_messages}') + + def get_provider_name(self) -> str: + return 'default' diff --git a/alibabacloud_credentials/provider/ecs_ram_role.py b/alibabacloud_credentials/provider/ecs_ram_role.py new file mode 100644 index 0000000..89c5480 --- /dev/null +++ b/alibabacloud_credentials/provider/ecs_ram_role.py @@ -0,0 +1,251 @@ +import calendar +import json +import time +import signal +import logging + +from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, StaleValueBehavior, \ + RefreshCachedSupplier, NonBlocking +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaCore +from apscheduler.schedulers.background import BackgroundScheduler +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.utils import parameter_helper as ph +from alibabacloud_credentials.exceptions import CredentialException + +logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger(__name__) + + +class EcsRamRoleCredentialsProvider(ICredentialsProvider): + DEFAULT_METADATA_TOKEN_DURATION = 21600 + DEFAULT_CONNECT_TIMEOUT = 1000 + DEFAULT_READ_TIMEOUT = 1000 + + def __init__(self, *, + role_name: str = None, + disable_imds_v1: bool = None, + http_options: HttpOptions = None, + async_update_enabled: bool = True): + + if au.environment_ecs_metadata_disabled.lower() == 'true': + raise ValueError('IMDS credentials is disabled') + + self.__url_in_ecs_metadata = '/latest/meta-data/ram/security-credentials/' + self.__url_in_ecs_metadata_token = '/latest/api/token' + self.__ecs_metadata_fetch_error_msg = 'Failed to get RAM session credentials from ECS metadata service.' + self.__ecs_metadata_token_fetch_error_msg = 'Failed to get token from ECS Metadata Service.' + self.__metadata_service_host = '100.100.100.200' + self._should_refresh = False + + self._role_name = role_name if role_name is not None else au.environment_ecs_metadata + self._disable_imds_v1 = disable_imds_v1 if disable_imds_v1 is not None else au.environment_imds_v1_disabled.lower() == 'true' + self._http_options = http_options if http_options is not None else HttpOptions() + self._runtime_options = { + 'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else EcsRamRoleCredentialsProvider.DEFAULT_CONNECT_TIMEOUT, + 'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else EcsRamRoleCredentialsProvider.DEFAULT_READ_TIMEOUT, + 'httpProxy': self._http_options.proxy + } + + if async_update_enabled: + self._credentials_cache = RefreshCachedSupplier( + refresh_callable=self._refresh_credentials, + refresh_callable_async=self._refresh_credentials_async, + stale_value_behavior=StaleValueBehavior.ALLOW, + prefetch_strategy=NonBlocking() + ) + + scheduler = BackgroundScheduler() + + def refresh_task(): + if self._should_refresh: + log.debug(f'Begin checking or refreshing credentials asynchronously') + self.get_credentials() + + scheduler.add_job(refresh_task, 'interval', minutes=1) + scheduler.start() + + def shutdown_handler(signum, frame): + log.debug(f'Shutting down scheduler...') + scheduler.shutdown(wait=False) + + signal.signal(signal.SIGINT, shutdown_handler) + signal.signal(signal.SIGTERM, shutdown_handler) + + else: + self._credentials_cache = RefreshCachedSupplier( + refresh_callable=self._refresh_credentials, + refresh_callable_async=self._refresh_credentials_async, + stale_value_behavior=StaleValueBehavior.ALLOW + ) + + def get_credentials(self) -> Credentials: + return self._credentials_cache() + + async def get_credentials_async(self) -> Credentials: + return await self._credentials_cache() + + def _get_role_name(self, url: str = None) -> str: + tea_request = ph.get_new_request() + tea_request.headers['host'] = url if url else self.__metadata_service_host + metadata_token = self._get_metadata_token(url) + if metadata_token is not None: + tea_request.headers['X-aliyun-ecs-metadata-token'] = metadata_token + if not url: + tea_request.pathname = self.__url_in_ecs_metadata + response = TeaCore.do_action(tea_request, self._runtime_options) + if response.status_code != 200: + raise CredentialException(self.__ecs_metadata_fetch_error_msg + ' HttpCode=' + str(response.status_code)) + return response.body.decode('utf-8') + + async def _get_role_name_async(self, url: str = None) -> str: + tea_request = ph.get_new_request() + tea_request.headers['host'] = url if url else self.__metadata_service_host + metadata_token = await self._get_metadata_token_async(url) + if metadata_token is not None: + tea_request.headers['X-aliyun-ecs-metadata-token'] = metadata_token + if not url: + tea_request.pathname = self.__url_in_ecs_metadata + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + if response.status_code != 200: + raise CredentialException(self.__ecs_metadata_fetch_error_msg + ' HttpCode=' + str(response.status_code)) + return response.body.decode('utf-8') + + def _get_metadata_token(self, url: str = None) -> str: + tea_request = ph.get_new_request() + tea_request.method = 'PUT' + tea_request.headers['host'] = url if url else self.__metadata_service_host + tea_request.headers['X-aliyun-ecs-metadata-token-ttl-seconds'] = str( + EcsRamRoleCredentialsProvider.DEFAULT_METADATA_TOKEN_DURATION) + if not url: + tea_request.pathname = self.__url_in_ecs_metadata_token + try: + response = TeaCore.do_action(tea_request, self._runtime_options) + if response.status_code != 200: + raise CredentialException( + self.__ecs_metadata_token_fetch_error_msg + ' HttpCode=' + str(response.status_code)) + return response.body.decode('utf-8') + except Exception as e: + if self._disable_imds_v1: + raise e + return None + + async def _get_metadata_token_async(self, url: str = None) -> str: + tea_request = ph.get_new_request() + tea_request.method = 'PUT' + tea_request.headers['host'] = url if url else self.__metadata_service_host + tea_request.headers['X-aliyun-ecs-metadata-token-ttl-seconds'] = str( + EcsRamRoleCredentialsProvider.DEFAULT_METADATA_TOKEN_DURATION) + if not url: + tea_request.pathname = self.__url_in_ecs_metadata_token + try: + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + if response.status_code != 200: + raise CredentialException( + self.__ecs_metadata_token_fetch_error_msg + ' HttpCode=' + str(response.status_code)) + return response.body.decode('utf-8') + except Exception as e: + if self._disable_imds_v1: + raise e + return None + + def _refresh_credentials(self, url: str = None) -> RefreshResult[Credentials]: + role_name = self._role_name + if self._role_name is None or self._role_name == '': + role_name = self._get_role_name(url) + tea_request = ph.get_new_request() + tea_request.headers['host'] = url if url else self.__metadata_service_host + metadata_token = self._get_metadata_token(url) + if metadata_token is not None: + tea_request.headers['X-aliyun-ecs-metadata-token'] = metadata_token + if not url: + tea_request.pathname = self.__url_in_ecs_metadata + role_name + # request + response = TeaCore.do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException(self.__ecs_metadata_fetch_error_msg + ' HttpCode=' + str(response.status_code)) + + dic = json.loads(response.body.decode('utf-8')) + content_code = dic.get('Code') + content_access_key_id = dic.get('AccessKeyId') + content_access_key_secret = dic.get('AccessKeySecret') + content_security_token = dic.get('SecurityToken') + content_expiration = dic.get('Expiration') + + if content_code != 'Success': + raise CredentialException(self.__ecs_metadata_fetch_error_msg) + + # 先转换为时间数组 + time_array = time.strptime(content_expiration, '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=content_access_key_id, + access_key_secret=content_access_key_secret, + security_token=content_security_token, + expiration=expiration, + provider_name=self.get_provider_name() + ) + self._should_refresh = True + return RefreshResult(value=credentials, + stale_time=self._get_stale_time(expiration), + prefetch_time=self._get_prefetch_time(expiration)) + + async def _refresh_credentials_async(self, url: str = None) -> RefreshResult[Credentials]: + role_name = self._role_name + if self._role_name is None: + role_name = await self._get_role_name_async(url) + tea_request = ph.get_new_request() + tea_request.headers['host'] = url if url else self.__metadata_service_host + metadata_token = await self._get_metadata_token_async(url) + if metadata_token is not None: + tea_request.headers['X-aliyun-ecs-metadata-token'] = metadata_token + if not url: + tea_request.pathname = self.__url_in_ecs_metadata + role_name + + # request + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException(self.__ecs_metadata_fetch_error_msg + ' HttpCode=' + str(response.status_code)) + + dic = json.loads(response.body.decode('utf-8')) + content_code = dic.get('Code') + content_access_key_id = dic.get('AccessKeyId') + content_access_key_secret = dic.get('AccessKeySecret') + content_security_token = dic.get('SecurityToken') + content_expiration = dic.get('Expiration') + + if content_code != 'Success': + raise CredentialException(self.__ecs_metadata_fetch_error_msg) + + # 先转换为时间数组 + time_array = time.strptime(content_expiration, '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=content_access_key_id, + access_key_secret=content_access_key_secret, + security_token=content_security_token, + expiration=expiration, + provider_name=self.get_provider_name() + ) + self._should_refresh = True + return RefreshResult(value=credentials, + stale_time=self._get_stale_time(expiration), + prefetch_time=self._get_prefetch_time(expiration)) + + def _get_stale_time(self, expiration: int) -> int: + if expiration < 0: + return int(time.mktime(time.localtime())) + 60 * 60 + return expiration - 15 * 60 + + def _get_prefetch_time(self, expiration: int) -> int: + if expiration < 0: + return int(time.mktime(time.localtime())) + 5 * 60 + return int(time.mktime(time.localtime())) + 60 * 60 + + def get_provider_name(self) -> str: + return 'ecs_ram_role' diff --git a/alibabacloud_credentials/provider/env.py b/alibabacloud_credentials/provider/env.py new file mode 100644 index 0000000..a669480 --- /dev/null +++ b/alibabacloud_credentials/provider/env.py @@ -0,0 +1,33 @@ + +from alibabacloud_credentials.provider.refreshable import Credentials +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util +from alibabacloud_credentials.exceptions import CredentialException + + +class EnvironmentVariableCredentialsProvider(ICredentialsProvider): + + def get_credentials(self) -> Credentials: + + access_key_id = auth_util.environment_access_key_id + access_key_secret = auth_util.environment_access_key_secret + security_token = auth_util.environment_security_token + + if access_key_id is None or len(access_key_id) == 0: + raise CredentialException("Environment variable accessKeyId cannot be empty") + + if access_key_secret is None or len(access_key_secret) == 0: + raise CredentialException("Environment variable accessKeySecret cannot be empty") + + return Credentials( + access_key_id=access_key_id, + access_key_secret=access_key_secret, + security_token=security_token, + provider_name=self.get_provider_name() + ) + + async def get_credentials_async(self) -> Credentials: + return self.get_credentials() + + def get_provider_name(self) -> str: + return 'env' diff --git a/alibabacloud_credentials/provider/oidc.py b/alibabacloud_credentials/provider/oidc.py new file mode 100644 index 0000000..804ba99 --- /dev/null +++ b/alibabacloud_credentials/provider/oidc.py @@ -0,0 +1,206 @@ +import calendar +import json +import time +import aiofiles + +from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, RefreshCachedSupplier +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaCore +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.utils import parameter_helper as ph +from alibabacloud_credentials.exceptions import CredentialException + + +async def _get_token_async(file_path: str) -> str: + async with aiofiles.open(file_path, mode='r') as file: + token = await file.read() + return token + + +def _get_token(file_path: str) -> str: + with open(file_path, mode='r') as file: + token = file.read() + return token + + +def _get_stale_time(expiration: int) -> int: + if expiration < 0: + return int(time.mktime(time.localtime())) + 60 * 60 + return expiration - 15 * 60 + + +class OIDCRoleArnCredentialsProvider(ICredentialsProvider): + DEFAULT_DURATION_SECONDS = 3600 + DEFAULT_CONNECT_TIMEOUT = 5000 + DEFAULT_READ_TIMEOUT = 10000 + + def __init__(self, *, + role_arn: str = None, + oidc_provider_arn: str = None, + oidc_token_file_path: str = None, + role_session_name: str = None, + duration_seconds: int = DEFAULT_DURATION_SECONDS, + policy: str = None, + sts_region_id: str = None, + sts_endpoint: str = None, + enable_vpc: bool = None, + http_options: HttpOptions = None): + + self._role_arn = role_arn or au.environment_role_arn + self._oidc_provider_arn = oidc_provider_arn or au.environment_oidc_provider_arn + self._oidc_token_file_path = oidc_token_file_path or au.environment_oidc_token_file + self._role_session_name = role_session_name or au.environment_role_session_name + self._duration_seconds = duration_seconds + self._policy = policy + + if self._role_session_name is None or self._role_session_name == '': + self._role_session_name = f'credentials-python-{str(int(time.mktime(time.localtime())))}' + if self._duration_seconds is None: + self._duration_seconds = self.DEFAULT_DURATION_SECONDS + if self._duration_seconds < 900: + raise ValueError('session duration should be in the range of 900s - max session duration') + if self._role_arn is None or self._role_arn == '': + raise ValueError('role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty') + if self._oidc_provider_arn is None or self._oidc_provider_arn == '': + raise ValueError( + 'oidc_provider_arn or environment variable ALIBABA_CLOUD_OIDC_PROVIDER_ARN cannot be empty') + if self._oidc_token_file_path is None or self._oidc_token_file_path == '': + raise ValueError( + 'oidc_token_file_path or environment variable ALIBABA_CLOUD_OIDC_TOKEN_FILE cannot be empty') + + if sts_endpoint is not None and sts_endpoint != '': + self._sts_endpoint = sts_endpoint + else: + if enable_vpc is not None: + prefix = 'sts-vpc' if enable_vpc else 'sts' + else: + prefix = 'sts-vpc' if au.environment_enable_vpc.lower() == 'true' else 'sts' + if sts_region_id is not None and sts_region_id != '': + self._sts_endpoint = f'{prefix}.{sts_region_id}.aliyuncs.com' + elif au.environment_sts_region is not None and au.environment_sts_region != '': + self._sts_endpoint = f'{prefix}.{au.environment_sts_region}.aliyuncs.com' + else: + self._sts_endpoint = 'sts.aliyuncs.com' + + self._http_options = http_options if http_options is not None else HttpOptions() + self._runtime_options = { + 'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else OIDCRoleArnCredentialsProvider.DEFAULT_CONNECT_TIMEOUT, + 'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else OIDCRoleArnCredentialsProvider.DEFAULT_READ_TIMEOUT, + 'httpsProxy': self._http_options.proxy + } + self._credentials_cache = RefreshCachedSupplier( + refresh_callable=self._refresh_credentials, + refresh_callable_async=self._refresh_credentials_async, + ) + + def get_credentials(self) -> Credentials: + return self._credentials_cache() + + async def get_credentials_async(self) -> Credentials: + return await self._credentials_cache() + + def _refresh_credentials(self) -> RefreshResult[Credentials]: + token = _get_token(self._oidc_token_file_path) + tea_request = ph.get_new_request() + tea_request.query = { + 'Action': 'AssumeRoleWithOIDC', + 'Format': 'JSON', + 'Version': '2015-04-01', + 'DurationSeconds': str(self._duration_seconds), + 'RoleArn': self._role_arn, + 'OIDCProviderArn': self._oidc_provider_arn, + 'OIDCToken': token, + 'RoleSessionName': self._role_session_name, + 'Timestamp': ph.get_iso_8061_date() + } + + if self._policy is not None and self._policy != '': + tea_request.query['Policy'] = self._policy + + tea_request.protocol = 'https' + tea_request.headers['host'] = self._sts_endpoint + + response = TeaCore.do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from oidc_role_arn, http_code: {response.status_code}, result: {response.body.decode("utf-8")}') + + dic = json.loads(response.body.decode('utf-8')) + if 'Credentials' not in dic: + raise CredentialException( + f'error retrieving credentials from oidc_role_arn result: {response.body.decode("utf-8")}') + + cre = dic.get('Credentials') + if 'AccessKeyId' not in cre or 'AccessKeySecret' not in cre or 'SecurityToken' not in cre: + raise CredentialException( + f'error retrieving credentials from oidc_role_arn result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(cre.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=cre.get('AccessKeyId'), + access_key_secret=cre.get('AccessKeySecret'), + security_token=cre.get('SecurityToken'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + async def _refresh_credentials_async(self) -> RefreshResult[Credentials]: + token = await _get_token_async(self._oidc_token_file_path) + tea_request = ph.get_new_request() + tea_request.query = { + 'Action': 'AssumeRoleWithOIDC', + 'Format': 'JSON', + 'Version': '2015-04-01', + 'DurationSeconds': str(self._duration_seconds), + 'RoleArn': self._role_arn, + 'OIDCProviderArn': self._oidc_provider_arn, + 'OIDCToken': token, + 'RoleSessionName': self._role_session_name, + 'Timestamp': ph.get_iso_8061_date() + } + + if self._policy is not None and self._policy != '': + tea_request.query['Policy'] = self._policy + + tea_request.protocol = 'https' + tea_request.headers['host'] = self._sts_endpoint + + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from oidc_role_arn, http_code: {response.status_code}, result: {response.body.decode("utf-8")}') + + dic = json.loads(response.body.decode('utf-8')) + if 'Credentials' not in dic: + raise CredentialException( + f'error retrieving credentials from oidc_role_arn result: {response.body.decode("utf-8")}') + + cre = dic.get('Credentials') + if 'AccessKeyId' not in cre or 'AccessKeySecret' not in cre or 'SecurityToken' not in cre: + raise CredentialException( + f'error retrieving credentials from oidc_role_arn result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(cre.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=cre.get('AccessKeyId'), + access_key_secret=cre.get('AccessKeySecret'), + security_token=cre.get('SecurityToken'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + def get_provider_name(self) -> str: + return 'oidc_role_arn' diff --git a/alibabacloud_credentials/provider/profile.py b/alibabacloud_credentials/provider/profile.py new file mode 100644 index 0000000..86f43e2 --- /dev/null +++ b/alibabacloud_credentials/provider/profile.py @@ -0,0 +1,144 @@ +import os +import configparser +from typing import Dict + +import aiofiles + +from alibabacloud_credentials.provider import StaticAKCredentialsProvider, EcsRamRoleCredentialsProvider, \ + RamRoleArnCredentialsProvider, OIDCRoleArnCredentialsProvider, RsaKeyPairCredentialsProvider +from alibabacloud_credentials.provider.refreshable import Credentials +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_constant as ac +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.exceptions import CredentialException + + +async def _load_ini_async(file_path: str) -> Dict[str, Dict[str, str]]: + config = configparser.ConfigParser() + async with aiofiles.open(file_path, mode='r') as f: + content = await f.read() + config.read_string(content) + ini_map = {} + for section in config.sections(): + option = {} + for key, value in config.items(section): + if '#' in value: + option[key] = value.split('#')[0].strip() + else: + option[key] = value.strip() + ini_map[section] = option + return ini_map + + +def _load_ini(file_path: str) -> Dict[str, Dict[str, str]]: + config = configparser.ConfigParser() + config.read(file_path, encoding='utf-8') + ini_map = {} + for section in config.sections(): + option = {} + for key, value in config.items(section): + if '#' in value: + option[key] = value.split('#')[0].strip() + else: + option[key] = value.strip() + ini_map[section] = option + return ini_map + + +def _get_default_file() -> str: + return os.path.join(ac.HOME, "/.alibabacloud/credentials.ini") + + +class ProfileCredentialsProvider(ICredentialsProvider): + + def __init__(self, *, + profile_file: str = None, + profile_name: str = None): + self._profile_file = profile_file or au.environment_credentials_file + self._profile_name = profile_name or au.client_type + self.__innerProvider = None + + if self._profile_file is None or self._profile_file == '': + self._profile_file = _get_default_file() + + def _should_reload_credentials_provider(self) -> bool: + if self.__innerProvider is None: + return True + return False + + def get_credentials(self) -> Credentials: + if self._should_reload_credentials_provider(): + ini_map = _load_ini(self._profile_file) + section = ini_map.get(self._profile_name) + if section is None: + raise CredentialException(f'failed to get credential from credentials file: ${self._profile_file}') + self.__innerProvider = self._get_credentials_provider(section) + + cre = self.__innerProvider.get_credentials() + credentials = Credentials( + access_key_id=cre.get_access_key_id(), + access_key_secret=cre.get_access_key_secret(), + security_token=cre.get_security_token(), + provider_name=f'{self.get_provider_name()}/{cre.get_provider_name()}' + ) + return credentials + + async def get_credentials_async(self) -> Credentials: + if self._should_reload_credentials_provider(): + ini_map = await _load_ini_async(self._profile_file) + section = ini_map.get(self._profile_name) + if section is None: + raise CredentialException(f'failed to get credential from credentials file: ${self._profile_file}') + self.__innerProvider = self._get_credentials_provider(section) + + cre = await self.__innerProvider.get_credentials_async() + credentials = Credentials( + access_key_id=cre.get_access_key_id(), + access_key_secret=cre.get_access_key_secret(), + security_token=cre.get_security_token(), + provider_name=f'{self.get_provider_name()}/{cre.get_provider_name()}' + ) + return credentials + + def _get_credentials_provider(self, section: Dict) -> ICredentialsProvider: + + config_type = section.get(ac.INI_TYPE) + if 'access_key' == config_type: + return StaticAKCredentialsProvider( + access_key_id=section.get('access_key_id'), + access_key_secret=section.get('access_key_secret') + ) + elif 'ram_role_arn' == config_type: + pre_provider = StaticAKCredentialsProvider( + access_key_id=section.get('access_key_id'), + access_key_secret=section.get('access_key_secret') + ) + return RamRoleArnCredentialsProvider( + credentials_provider=pre_provider, + role_arn=section.get('role_arn'), + role_session_name=section.get('role_session_name'), + policy=section.get('policy') + ) + elif 'oidc_role_arn' == config_type: + return OIDCRoleArnCredentialsProvider( + role_arn=section.get('role_arn'), + oidc_provider_arn=section.get('oidc_provider_arn'), + oidc_token_file_path=section.get('oidc_token_file_path'), + role_session_name=section.get('role_session_name'), + policy=section.get('policy') + ) + elif 'ecs_ram_role' == config_type: + return EcsRamRoleCredentialsProvider( + role_name=section.get('role_name') + ) + elif 'rsa_key_pair' == config_type: + return RsaKeyPairCredentialsProvider( + public_key_id=section.get('public_key_id'), + private_key_file=section.get('private_key_file') + ) + else: + raise CredentialException( + f'unsupported credential type {config_type} from credentials file {self._profile_file}') + + def get_provider_name(self) -> str: + return 'profile' diff --git a/alibabacloud_credentials/provider/ram_role_arn.py b/alibabacloud_credentials/provider/ram_role_arn.py new file mode 100644 index 0000000..b5d8d3c --- /dev/null +++ b/alibabacloud_credentials/provider/ram_role_arn.py @@ -0,0 +1,234 @@ +import calendar +import json +import time + +from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, RefreshCachedSupplier +from alibabacloud_credentials.provider import StaticAKCredentialsProvider, StaticSTSCredentialsProvider +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaCore +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.utils import parameter_helper as ph +from alibabacloud_credentials.exceptions import CredentialException + + +def _get_stale_time(expiration: int) -> int: + if expiration < 0: + return int(time.mktime(time.localtime())) + 60 * 60 + return expiration - 15 * 60 + + +class RamRoleArnCredentialsProvider(ICredentialsProvider): + DEFAULT_DURATION_SECONDS = 3600 + DEFAULT_CONNECT_TIMEOUT = 5000 + DEFAULT_READ_TIMEOUT = 10000 + + def __init__(self, *, + access_key_id: str = None, + access_key_secret: str = None, + security_token: str = None, + credentials_provider: ICredentialsProvider = None, + role_arn: str = None, + role_session_name: str = None, + duration_seconds: int = DEFAULT_DURATION_SECONDS, + policy: str = None, + external_id: str = None, + sts_region_id: str = None, + sts_endpoint: str = None, + enable_vpc: bool = None, + http_options: HttpOptions = None): + + if credentials_provider is not None: + self._credentials_provider = credentials_provider + elif security_token is not None and security_token != '': + self._credentials_provider = StaticSTSCredentialsProvider( + access_key_id=access_key_id, + access_key_secret=access_key_secret, + security_token=security_token + ) + else: + self._credentials_provider = StaticAKCredentialsProvider( + access_key_id=access_key_id, + access_key_secret=access_key_secret, + ) + + self._role_arn = role_arn or au.environment_role_arn + self._role_session_name = role_session_name or au.environment_role_session_name + self._duration_seconds = duration_seconds + self._policy = policy + self._external_id = external_id + + if self._role_session_name is None or self._role_session_name == '': + self._role_session_name = f'credentials-python-{str(int(time.mktime(time.localtime())))}' + if self._duration_seconds is None: + self._duration_seconds = self.DEFAULT_DURATION_SECONDS + if self._duration_seconds < 900: + raise ValueError('session duration should be in the range of 900s - max session duration') + if self._role_arn is None or self._role_arn == '': + raise ValueError('role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty') + + if sts_endpoint is not None and sts_endpoint != '': + self._sts_endpoint = sts_endpoint + else: + if enable_vpc is not None: + prefix = 'sts-vpc' if enable_vpc else 'sts' + else: + prefix = 'sts-vpc' if au.environment_enable_vpc.lower() == 'true' else 'sts' + if sts_region_id is not None and sts_region_id != '': + self._sts_endpoint = f'{prefix}.{sts_region_id}.aliyuncs.com' + elif au.environment_sts_region is not None and au.environment_sts_region != '': + self._sts_endpoint = f'{prefix}.{au.environment_sts_region}.aliyuncs.com' + else: + self._sts_endpoint = 'sts.aliyuncs.com' + + self._http_options = http_options if http_options is not None else HttpOptions() + self._runtime_options = { + 'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else RamRoleArnCredentialsProvider.DEFAULT_CONNECT_TIMEOUT, + 'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else RamRoleArnCredentialsProvider.DEFAULT_READ_TIMEOUT, + 'httpsProxy': self._http_options.proxy + } + self._credentials_cache = RefreshCachedSupplier( + refresh_callable=self._refresh_credentials, + refresh_callable_async=self._refresh_credentials_async, + ) + + def get_credentials(self) -> Credentials: + return self._credentials_cache() + + async def get_credentials_async(self) -> Credentials: + return await self._credentials_cache() + + def _refresh_credentials(self) -> RefreshResult[Credentials]: + tea_request = ph.get_new_request() + tea_request.query = { + 'Action': 'AssumeRole', + 'Format': 'JSON', + 'Version': '2015-04-01', + 'DurationSeconds': str(self._duration_seconds), + 'RoleArn': self._role_arn, + 'RoleSessionName': self._role_session_name, + 'SignatureMethod': 'HMAC-SHA1', + 'SignatureVersion': '1.0', + 'Timestamp': ph.get_iso_8061_date(), + 'SignatureNonce': ph.get_uuid() + } + + if self._policy is not None and self._policy != '': + tea_request.query['Policy'] = self._policy + + if self._external_id is not None and self._external_id != '': + tea_request.query['ExternalId'] = self._external_id + + pre_credentials = self._credentials_provider.get_credentials() + if pre_credentials is None: + raise CredentialException('unable to load original credentials from the provider in RAM role arn') + + tea_request.query['AccessKeyId'] = pre_credentials.get_access_key_id() + security_token = pre_credentials.get_security_token() + if security_token is not None and security_token != '': + tea_request.query['SecurityToken'] = security_token + + string_to_sign = ph.compose_string_to_sign('GET', tea_request.query) + signature = ph.sign_string(string_to_sign, pre_credentials.get_access_key_secret() + '&') + tea_request.query['Signature'] = signature + tea_request.protocol = 'https' + tea_request.headers['host'] = self._sts_endpoint + + response = TeaCore.do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from ram_role_arn, http_code: {response.status_code}, result: {response.body.decode("utf-8")}') + + dic = json.loads(response.body.decode('utf-8')) + if 'Credentials' not in dic: + raise CredentialException( + f'error retrieving credentials from ram_role_arn result: {response.body.decode("utf-8")}') + + cre = dic.get('Credentials') + if 'AccessKeyId' not in cre or 'AccessKeySecret' not in cre or 'SecurityToken' not in cre: + raise CredentialException( + f'error retrieving credentials from ram_role_arn result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(cre.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=cre.get('AccessKeyId'), + access_key_secret=cre.get('AccessKeySecret'), + security_token=cre.get('SecurityToken'), + expiration=expiration, + provider_name=f'{self.get_provider_name()}/{pre_credentials.get_provider_name()}' + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + async def _refresh_credentials_async(self) -> RefreshResult[Credentials]: + tea_request = ph.get_new_request() + tea_request.query = { + 'Action': 'AssumeRole', + 'Format': 'JSON', + 'Version': '2015-04-01', + 'DurationSeconds': str(self._duration_seconds), + 'RoleArn': self._role_arn, + 'RoleSessionName': self._role_session_name, + 'SignatureMethod': 'HMAC-SHA1', + 'SignatureVersion': '1.0', + 'Timestamp': ph.get_iso_8061_date(), + 'SignatureNonce': ph.get_uuid() + } + + if self._policy is not None and self._policy != '': + tea_request.query['Policy'] = self._policy + + if self._external_id is not None and self._external_id != '': + tea_request.query['ExternalId'] = self._external_id + + pre_credentials = await self._credentials_provider.get_credentials_async() + if pre_credentials is None: + raise CredentialException('unable to load original credentials from the provider in RAM role arn') + + tea_request.query['AccessKeyId'] = pre_credentials.get_access_key_id() + security_token = pre_credentials.get_security_token() + if security_token is not None and security_token != '': + tea_request.query['SecurityToken'] = security_token + + string_to_sign = ph.compose_string_to_sign('GET', tea_request.query) + signature = ph.sign_string(string_to_sign, pre_credentials.get_access_key_secret() + '&') + tea_request.query['Signature'] = signature + tea_request.protocol = 'https' + tea_request.headers['host'] = self._sts_endpoint + + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from ram_role_arn, http_code: {response.status_code}, result: {response.body.decode("utf-8")}') + + dic = json.loads(response.body.decode('utf-8')) + if 'Credentials' not in dic: + raise CredentialException( + f'error retrieving credentials from ram_role_arn result: {response.body.decode("utf-8")}') + + cre = dic.get('Credentials') + if 'AccessKeyId' not in cre or 'AccessKeySecret' not in cre or 'SecurityToken' not in cre: + raise CredentialException( + f'error retrieving credentials from ram_role_arn result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(cre.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=cre.get('AccessKeyId'), + access_key_secret=cre.get('AccessKeySecret'), + security_token=cre.get('SecurityToken'), + expiration=expiration, + provider_name=f'{self.get_provider_name()}/{pre_credentials.get_provider_name()}' + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + def get_provider_name(self) -> str: + return 'ram_role_arn' diff --git a/alibabacloud_credentials/provider/refreshable.py b/alibabacloud_credentials/provider/refreshable.py new file mode 100644 index 0000000..de42768 --- /dev/null +++ b/alibabacloud_credentials/provider/refreshable.py @@ -0,0 +1,304 @@ +import random +import asyncio +import threading +import weakref +import logging +import time +from datetime import datetime +from enum import Enum +from typing import Callable, Generic, TypeVar, Coroutine, Any +from threading import Semaphore +from concurrent.futures.thread import ThreadPoolExecutor, _worker, _base, _threads_queues + +from alibabacloud_credentials.exceptions import CredentialException +from alibabacloud_credentials_api import ICredentials + +logging.basicConfig(level=logging.DEBUG) +log = logging.getLogger(__name__) + +T = TypeVar('T') +INT64_MAX = 2 ** 63 - 1 +MAX_CONCURRENT_REFRESHES = 100 +CONCURRENT_REFRESH_LEASES = Semaphore(MAX_CONCURRENT_REFRESHES) + + +class _DaemonThreadPoolExecutor(ThreadPoolExecutor): + def _adjust_thread_count(self): + # if idle threads are available, don't spin new threads + if self._idle_semaphore.acquire(timeout=0): + return + + # When the executor gets lost, the weakref callback will wake up + # the worker threads. + def weakref_cb(_, q=self._work_queue): + q.put(None) + + num_threads = len(self._threads) + if num_threads < self._max_workers: + thread_name = '%s_%d' % (self._thread_name_prefix or self, + num_threads) + t = threading.Thread(target=_worker, + name=thread_name, + args=(weakref.ref(self, weakref_cb), + self._work_queue, + self._initializer, + self._initargs), + daemon=True) # Set thread as daemon + t.start() + self._threads.add(t) + _threads_queues[t] = self._work_queue + + +EXECUTOR = _DaemonThreadPoolExecutor(max_workers=INT64_MAX, thread_name_prefix='non-blocking-refresh') + + +def _jitter_time(now: int, jitter_start: int, jitter_end: int) -> int: + jitter_amount = random.randint(jitter_start, jitter_end) + return now + jitter_amount + + +def _max_stale_failure_jitter(num_failures: int) -> int: + backoff_millis = max(10 * 1000, (1 << num_failures - 1) * 100) + return backoff_millis + + +class Credentials(ICredentials): + def __init__(self, *, + access_key_id: str = None, + access_key_secret: str = None, + security_token: str = None, + expiration: int = None, + provider_name: str = None): + self._access_key_id = access_key_id + self._access_key_secret = access_key_secret + self._security_token = security_token + self._expiration = expiration + self._provider_name = provider_name + + def get_access_key_id(self) -> str: + return self._access_key_id + + def get_access_key_secret(self) -> str: + return self._access_key_secret + + def get_security_token(self) -> str: + return self._security_token + + def get_expiration(self) -> int: + return self._expiration + + def get_provider_name(self) -> str: + return self._provider_name + + +class StaleValueBehavior(Enum): + """ + Strictly treat the stale time. Never return a stale cached value (except when the supplier returns an expired + value, in which case the supplier will return the value but only for a very short period of time to prevent + overloading the underlying supplier). + """ + STRICT = 0 + """ + Allow stale values to be returned from the cache. Value retrieval will never fail, as long as the cache has + succeeded when calling the underlying supplier at least once. + """ + ALLOW = 1 + + +class RefreshResult(Generic[T]): + def __init__(self, *, + value: T, + stale_time: int = INT64_MAX, + prefetch_time: int = INT64_MAX): + self._value = value + self._stale_time = stale_time + self._prefetch_time = prefetch_time + + def value(self) -> T: + return self._value + + def stale_time(self) -> int: + return self._stale_time + + def prefetch_time(self) -> int: + return self._prefetch_time + + +class PrefetchStrategy: + def prefetch(self, action: Callable): + raise NotImplementedError + + async def prefetch_async(self, action: Callable): + raise NotImplementedError + + +class NonBlocking(PrefetchStrategy): + + def prefetch(self, action: Callable): + if not CONCURRENT_REFRESH_LEASES.acquire(False): + log.warning('Skipping a background refresh task because there are too many other tasks running.') + return + + try: + EXECUTOR.submit(action) + except Exception as t: + log.warning(f'Exception occurred when submitting background task.', exc_info=True) + finally: + CONCURRENT_REFRESH_LEASES.release() + + async def prefetch_async(self, action: Callable): + def run_asyncio_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(action()) + loop.close() + + self.prefetch(run_asyncio_loop) + + +class OneCallerBlocks(PrefetchStrategy): + def prefetch(self, action: Callable): + action() + + async def prefetch_async(self, action: Callable): + await action() + + +class RefreshCachedSupplier(Generic[T]): + STALE_TIME = 15 * 60 # seconds + REFRESH_BLOCKING_MAX_WAIT = 5 # seconds + + def __init__(self, refresh_callable: Callable[[], RefreshResult[T]], + refresh_callable_async: Callable[[], Coroutine[Any, Any, RefreshResult[T]]], + stale_value_behavior: StaleValueBehavior = StaleValueBehavior.STRICT, + prefetch_strategy: PrefetchStrategy = OneCallerBlocks()): + + self._refresh_callable = refresh_callable + self._refresh_callable_async = refresh_callable_async + self._stale_value_behavior = stale_value_behavior + self._prefetch_strategy = prefetch_strategy + self._consecutive_refresh_failures = 0 + self._cached_value = None + self._refresh_lock = threading.Lock() + self._loop = asyncio.get_event_loop() + + def __call__(self): + if self._loop.is_running(): + return self._async_call() + else: + return self._sync_call() + + def _sync_call(self) -> T: + if self._cache_is_stale(): + log.debug('Refreshing synchronously') + self._refresh_cache() + elif self._should_initiate_cache_prefetch(): + log.debug(f'Prefetching using strategy: {self._prefetch_strategy.__class__.__name__}') + self._prefetch_cache() + return self._cached_value.value() + + async def _async_call(self) -> T: + if self._cache_is_stale(): + log.debug('Refreshing synchronously') + await self._refresh_cache_async() + elif self._should_initiate_cache_prefetch(): + log.debug(f'Prefetching using strategy: {self._prefetch_strategy.__class__.__name__}') + await self._prefetch_cache_async() + return self._cached_value.value() + + def _cache_is_stale(self) -> bool: + if self._cached_value is None: + return True + return int(time.mktime(time.localtime())) >= self._cached_value.stale_time() + + def _should_initiate_cache_prefetch(self) -> bool: + if self._cached_value is None: + return True + return int(time.mktime(time.localtime())) >= self._cached_value.prefetch_time() + + def _prefetch_cache(self): + self._prefetch_strategy.prefetch(self._refresh_cache) + + def _refresh_cache(self): + acquired = self._refresh_lock.acquire(timeout=RefreshCachedSupplier.REFRESH_BLOCKING_MAX_WAIT) + try: + if self._cache_is_stale() or self._should_initiate_cache_prefetch(): + try: + self._cached_value = self._handle_fetched_success(self._refresh_callable()) + except Exception as ex: + self._cached_value = self._handle_fetched_failure(ex) + finally: + if acquired: + self._refresh_lock.release() + + async def _prefetch_cache_async(self): + await self._prefetch_strategy.prefetch_async(self._refresh_cache_async) + + async def _refresh_cache_async(self): + acquired = self._refresh_lock.acquire(timeout=RefreshCachedSupplier.REFRESH_BLOCKING_MAX_WAIT) + try: + if self._cache_is_stale() or self._should_initiate_cache_prefetch(): + try: + self._cached_value = self._handle_fetched_success(await self._refresh_callable_async()) + except Exception as ex: + self._cached_value = self._handle_fetched_failure(ex) + finally: + if acquired: + self._refresh_lock.release() + + def _handle_fetched_success(self, value: RefreshResult[T]) -> RefreshResult[T]: + log.debug(f'Refresh credentials successfully, retrieved value is {value}, cached value is {self._cached_value}') + self._consecutive_refresh_failures = 0 + now = int(time.mktime(time.localtime())) + # 过期时间大于15分钟,不用管 + if now < value.stale_time(): + log.debug( + f'Retrieved value stale time is {datetime.fromtimestamp(value.stale_time())}. Using staleTime of {datetime.fromtimestamp(value.stale_time())}') + return value + # 不足或等于15分钟,但未过期,下次会再次刷新 + if now < value.stale_time() + RefreshCachedSupplier.STALE_TIME: + log.warning( + f'Retrieved value stale time is in the past ({datetime.fromtimestamp(value.stale_time())}). Using staleTime of {datetime.fromtimestamp(now)}') + return RefreshResult(value=value.value(), stale_time=now, prefetch_time=value.prefetch_time()) + + log.warning( + f'Retrieved value expiration time of the credential is in the past ({datetime.fromtimestamp(value.stale_time() + RefreshCachedSupplier.STALE_TIME)}). Trying use the cached value.') + # 已过期,看缓存,缓存若大于15分钟,返回缓存,若小于15分钟,则根据策略判断是立刻重试还是稍后重试 + if self._cached_value is None: + raise CredentialException('No cached value was found.') + elif now < self._cached_value.stale_time(): + log.warning( + f'Cached value staleTime is {datetime.fromtimestamp(self._cached_value.stale_time())}. Using staleTime of {datetime.fromtimestamp(self._cached_value.stale_time())}') + return self._cached_value + elif self._stale_value_behavior == StaleValueBehavior.STRICT: + log.warning( + f'Cached value expiration is in the past ({datetime.fromtimestamp(self._cached_value.stale_time())}). Using expiration of {datetime.fromtimestamp(now + 1)}') + return RefreshResult(value=self._cached_value.value(), stale_time=now + 1, + prefetch_time=self._cached_value.prefetch_time()) + else: # ALLOW + extended_stale_time = now + int((50 * 1000 + random.randint(0, 20 * 1000 + 1)) / 1000) + log.warning( + f'Cached value expiration has been extended to {datetime.fromtimestamp(extended_stale_time)} because the downstream service returned a time in the past: {datetime.fromtimestamp(self._cached_value.stale_time())}') + return RefreshResult(value=self._cached_value.value(), stale_time=extended_stale_time, + prefetch_time=self._cached_value.prefetch_time()) + + def _handle_fetched_failure(self, exception: Exception) -> RefreshResult[T]: + log.warning(f'Refresh credentials failed, cached value is {self._cached_value}, error: {exception}') + if not self._cached_value: + log.exception(exception) + raise exception + now = int(time.mktime(time.localtime())) + if now < self._cached_value.stale_time(): + return self._cached_value + + self._consecutive_refresh_failures += 1 + if self._stale_value_behavior == StaleValueBehavior.STRICT: + log.exception(exception) + raise exception + else: # ALLOW + new_stale_time = int( + _jitter_time(now * 1000, 1000, _max_stale_failure_jitter(self._consecutive_refresh_failures)) / 1000) + log.warning( + f'Cached value expiration has been extended to {datetime.fromtimestamp(new_stale_time)} because calling the downstream service failed (consecutive failures: {self._consecutive_refresh_failures}).') + return RefreshResult(value=self._cached_value.value(), stale_time=new_stale_time, + prefetch_time=self._cached_value.prefetch_time()) diff --git a/alibabacloud_credentials/provider/rsa_key_pair.py b/alibabacloud_credentials/provider/rsa_key_pair.py new file mode 100644 index 0000000..b638ce5 --- /dev/null +++ b/alibabacloud_credentials/provider/rsa_key_pair.py @@ -0,0 +1,186 @@ +import calendar +import json +import time + +from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, RefreshCachedSupplier +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaCore +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.utils import parameter_helper as ph +from alibabacloud_credentials.exceptions import CredentialException + + +def _get_stale_time(expiration: int) -> int: + if expiration < 0: + return int(time.mktime(time.localtime())) + 60 * 60 + return expiration - 15 * 60 + + +def _get_content(file_path: str) -> str: + with open(file_path, mode='r') as file: + content = file.read() + return content + + +class RsaKeyPairCredentialsProvider(ICredentialsProvider): + DEFAULT_DURATION_SECONDS = 3600 + DEFAULT_CONNECT_TIMEOUT = 5000 + DEFAULT_READ_TIMEOUT = 10000 + + def __init__(self, *, + public_key_id: str = None, + private_key_file: str = None, + duration_seconds: int = DEFAULT_DURATION_SECONDS, + sts_region_id: str = None, + sts_endpoint: str = None, + enable_vpc: bool = None, + http_options: HttpOptions = None): + + self._public_key_id = public_key_id + self._private_key_file = private_key_file + self._duration_seconds = duration_seconds + + if self._duration_seconds is None: + self._duration_seconds = self.DEFAULT_DURATION_SECONDS + if self._duration_seconds < 900: + raise ValueError('session duration should be in the range of 900s - max session duration') + if self._public_key_id is None or self._public_key_id == '': + raise ValueError('public_key_id cannot be empty') + if self._private_key_file is None or self._private_key_file == '': + raise ValueError('private_key_file cannot be empty') + self._private_key = _get_content(self._private_key_file) + if self._private_key is None or self._private_key == '': + raise ValueError('private_key cannot be empty') + + if sts_endpoint is not None and sts_endpoint != '': + self._sts_endpoint = sts_endpoint + else: + if enable_vpc is not None: + prefix = 'sts-vpc' if enable_vpc else 'sts' + else: + prefix = 'sts-vpc' if au.environment_enable_vpc.lower() == 'true' else 'sts' + if sts_region_id is not None and sts_region_id != '': + self._sts_endpoint = f'{prefix}.{sts_region_id}.aliyuncs.com' + elif au.environment_sts_region is not None and au.environment_sts_region != '': + self._sts_endpoint = f'{prefix}.{au.environment_sts_region}.aliyuncs.com' + else: + self._sts_endpoint = 'sts.ap-northeast-1.aliyuncs.com' + + self._http_options = http_options if http_options is not None else HttpOptions() + self._runtime_options = { + 'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else RsaKeyPairCredentialsProvider.DEFAULT_CONNECT_TIMEOUT, + 'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else RsaKeyPairCredentialsProvider.DEFAULT_READ_TIMEOUT, + 'httpsProxy': self._http_options.proxy + } + self._credentials_cache = RefreshCachedSupplier( + refresh_callable=self._refresh_credentials, + refresh_callable_async=self._refresh_credentials_async, + ) + + def get_credentials(self) -> Credentials: + return self._credentials_cache() + + async def get_credentials_async(self) -> Credentials: + return await self._credentials_cache() + + def _refresh_credentials(self) -> RefreshResult[Credentials]: + tea_request = ph.get_new_request() + tea_request.query = { + 'Action': 'GenerateSessionAccessKey', + 'Format': 'JSON', + 'Version': '2015-04-01', + 'DurationSeconds': str(self._duration_seconds), + 'SignatureMethod': 'HMAC-SHA1', + 'SignatureVersion': '1.0', + 'Timestamp': ph.get_iso_8061_date(), + 'SignatureNonce': ph.get_uuid(), + 'AccessKeyId': self._public_key_id, + } + + string_to_sign = ph.compose_string_to_sign('GET', tea_request.query) + signature = ph.sign_string(string_to_sign, self._private_key + '&') + tea_request.query['Signature'] = signature + tea_request.protocol = 'https' + tea_request.headers['host'] = self._sts_endpoint + + response = TeaCore.do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from rsa_key_pair, http_code: {response.status_code}, result: {response.body.decode("utf-8")}') + + dic = json.loads(response.body.decode('utf-8')) + if 'SessionAccessKey' not in dic: + raise CredentialException( + f'error retrieving credentials from rsa_key_pair result: {response.body.decode("utf-8")}') + + cre = dic.get('SessionAccessKey') + if 'SessionAccessKeyId' not in cre or 'SessionAccessKeySecret' not in cre: + raise CredentialException( + f'error retrieving credentials from rsa_key_pair result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(cre.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=cre.get('SessionAccessKeyId'), + access_key_secret=cre.get('SessionAccessKeySecret'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + async def _refresh_credentials_async(self) -> RefreshResult[Credentials]: + tea_request = ph.get_new_request() + tea_request.query = { + 'Action': 'GenerateSessionAccessKey', + 'Format': 'JSON', + 'Version': '2015-04-01', + 'DurationSeconds': str(self._duration_seconds), + 'SignatureMethod': 'HMAC-SHA1', + 'SignatureVersion': '1.0', + 'Timestamp': ph.get_iso_8061_date(), + 'SignatureNonce': ph.get_uuid(), + 'AccessKeyId': self._public_key_id, + } + + string_to_sign = ph.compose_string_to_sign('GET', tea_request.query) + signature = ph.sign_string(string_to_sign, self._private_key + '&') + tea_request.query['Signature'] = signature + tea_request.protocol = 'https' + tea_request.headers['host'] = self._sts_endpoint + + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from rsa_key_pair, http_code: {response.status_code}, result: {response.body.decode("utf-8")}') + + dic = json.loads(response.body.decode('utf-8')) + if 'SessionAccessKey' not in dic: + raise CredentialException( + f'error retrieving credentials from rsa_key_pair result: {response.body.decode("utf-8")}') + + cre = dic.get('SessionAccessKey') + if 'SessionAccessKeyId' not in cre or 'SessionAccessKeySecret' not in cre: + raise CredentialException( + f'error retrieving credentials from rsa_key_pair result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(cre.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=cre.get('SessionAccessKeyId'), + access_key_secret=cre.get('SessionAccessKeySecret'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + def get_provider_name(self) -> str: + return 'rsa_key_pair' diff --git a/alibabacloud_credentials/provider/static_ak.py b/alibabacloud_credentials/provider/static_ak.py new file mode 100644 index 0000000..6aa8c31 --- /dev/null +++ b/alibabacloud_credentials/provider/static_ak.py @@ -0,0 +1,32 @@ +from alibabacloud_credentials.provider.refreshable import Credentials +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util + + +class StaticAKCredentialsProvider(ICredentialsProvider): + + def __init__(self, *, + access_key_id: str = None, + access_key_secret: str = None): + + self.access_key_id = access_key_id or auth_util.environment_access_key_id + self.access_key_secret = access_key_secret or auth_util.environment_access_key_secret + + if self.access_key_id is None or self.access_key_id == '': + raise ValueError('the access key id is empty') + if self.access_key_secret is None or self.access_key_secret == '': + raise ValueError('the access key secret is empty') + + def get_credentials(self) -> Credentials: + + return Credentials( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + provider_name=self.get_provider_name() + ) + + async def get_credentials_async(self) -> Credentials: + return self.get_credentials() + + def get_provider_name(self) -> str: + return 'static_ak' diff --git a/alibabacloud_credentials/provider/static_sts.py b/alibabacloud_credentials/provider/static_sts.py new file mode 100644 index 0000000..a4b3c50 --- /dev/null +++ b/alibabacloud_credentials/provider/static_sts.py @@ -0,0 +1,37 @@ +from alibabacloud_credentials.provider.refreshable import Credentials +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util + + +class StaticSTSCredentialsProvider(ICredentialsProvider): + + def __init__(self, *, + access_key_id: str = None, + access_key_secret: str = None, + security_token: str = None): + + self.access_key_id = access_key_id or auth_util.environment_access_key_id + self.access_key_secret = access_key_secret or auth_util.environment_access_key_secret + self.security_token = security_token or auth_util.environment_security_token + + if self.access_key_id is None or self.access_key_id == '': + raise ValueError('the access key id is empty') + if self.access_key_secret is None or self.access_key_secret == '': + raise ValueError('the access key secret is empty') + if self.security_token is None or self.security_token == '': + raise ValueError('the security token is empty') + + def get_credentials(self) -> Credentials: + + return Credentials( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + provider_name=self.get_provider_name() + ) + + async def get_credentials_async(self) -> Credentials: + return self.get_credentials() + + def get_provider_name(self) -> str: + return 'static_sts' diff --git a/alibabacloud_credentials/provider/uri.py b/alibabacloud_credentials/provider/uri.py new file mode 100644 index 0000000..4b437fd --- /dev/null +++ b/alibabacloud_credentials/provider/uri.py @@ -0,0 +1,135 @@ +import calendar +import json +import time +from urllib.parse import urlparse, parse_qs + +from alibabacloud_credentials.provider.refreshable import Credentials, RefreshResult, RefreshCachedSupplier +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaCore +from alibabacloud_credentials_api import ICredentialsProvider +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.utils import parameter_helper as ph +from alibabacloud_credentials.exceptions import CredentialException + + +def _get_stale_time(expiration: int) -> int: + if expiration < 0: + return int(time.mktime(time.localtime())) + 60 * 60 + return expiration - 15 * 60 + + +class URLCredentialsProvider(ICredentialsProvider): + DEFAULT_CONNECT_TIMEOUT = 5000 + DEFAULT_READ_TIMEOUT = 10000 + + def __init__(self, *, + uri: str = None, + protocol: str = 'http', + http_options: HttpOptions = None): + + self._uri = uri or au.environment_credentials_uri + if self._uri is None or self._uri == '': + raise ValueError('uri or environment variable ALIBABA_CLOUD_CREDENTIALS_URI cannot be empty') + self._protocol = protocol + + self._http_options = http_options if http_options is not None else HttpOptions() + self._runtime_options = { + 'connectTimeout': self._http_options.connect_timeout if self._http_options.connect_timeout is not None else URLCredentialsProvider.DEFAULT_CONNECT_TIMEOUT, + 'readTimeout': self._http_options.read_timeout if self._http_options.read_timeout is not None else URLCredentialsProvider.DEFAULT_READ_TIMEOUT, + 'httpsProxy': self._http_options.proxy + } + self._credentials_cache = RefreshCachedSupplier( + refresh_callable=self._refresh_credentials, + refresh_callable_async=self._refresh_credentials_async, + ) + + def get_credentials(self) -> Credentials: + return self._credentials_cache() + + async def get_credentials_async(self) -> Credentials: + return await self._credentials_cache() + + def _refresh_credentials(self) -> RefreshResult[Credentials]: + r = urlparse(self._uri) + tea_request = ph.get_new_request() + tea_request.headers['host'] = r.hostname + tea_request.port = r.port + tea_request.protocol = self._protocol + tea_request.method = 'GET' + tea_request.pathname = r.path + for key, values in parse_qs(r.query).items(): + for value in values: + tea_request.query[key] = value + + response = TeaCore.do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from {self._uri}, http_code={str(response.status_code)}, result: {response.body.decode("utf-8")}') + + body = response.body.decode('utf-8') + + dic = json.loads(body) + content_code = dic.get('Code') + + if content_code != "Success" or 'AccessKeyId' not in dic or 'AccessKeySecret' not in dic or 'SecurityToken' not in dic or 'Expiration' not in dic: + raise CredentialException( + f'error retrieving credentials from {self._uri} result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(dic.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=dic.get('AccessKeyId'), + access_key_secret=dic.get('AccessKeySecret'), + security_token=dic.get('SecurityToken'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + async def _refresh_credentials_async(self) -> RefreshResult[Credentials]: + r = urlparse(self._uri) + tea_request = ph.get_new_request() + tea_request.headers['host'] = r.hostname + tea_request.port = r.port + tea_request.protocol = r.scheme or self._protocol or 'http' + tea_request.method = 'GET' + tea_request.pathname = r.path + for key, values in parse_qs(r.query).items(): + for value in values: + tea_request.query[key] = value + + response = await TeaCore.async_do_action(tea_request, self._runtime_options) + + if response.status_code != 200: + raise CredentialException( + f'error refreshing credentials from {self._uri}, http_code={str(response.status_code)}, result: {response.body.decode("utf-8")}') + + body = response.body.decode('utf-8') + + dic = json.loads(body) + content_code = dic.get('Code') + + if content_code != "Success" or 'AccessKeyId' not in dic or 'AccessKeySecret' not in dic or 'SecurityToken' not in dic or 'Expiration' not in dic: + raise CredentialException( + f'error retrieving credentials from {self._uri} result: {response.body.decode("utf-8")}') + + # 先转换为时间数组 + time_array = time.strptime(dic.get('Expiration'), '%Y-%m-%dT%H:%M:%SZ') + # 转换为时间戳 + expiration = calendar.timegm(time_array) + credentials = Credentials( + access_key_id=dic.get('AccessKeyId'), + access_key_secret=dic.get('AccessKeySecret'), + security_token=dic.get('SecurityToken'), + expiration=expiration, + provider_name=self.get_provider_name() + ) + return RefreshResult(value=credentials, + stale_time=_get_stale_time(expiration)) + + def get_provider_name(self) -> str: + return 'credential_uri' diff --git a/alibabacloud_credentials/providers.py b/alibabacloud_credentials/providers.py index ca2073a..9bfa16e 100644 --- a/alibabacloud_credentials/providers.py +++ b/alibabacloud_credentials/providers.py @@ -18,7 +18,7 @@ class AlibabaCloudCredentialsProvider: """BaseProvider class""" duration_seconds = 3600 - timeout = 2000 + timeout = 3000 def __init__(self, config=None): if isinstance(config, Config): @@ -36,8 +36,8 @@ def __init__(self, config=None): self.bearer_token = config.bearer_token self.security_token = config.security_token self.host = config.host - self.timeout = config.timeout + config.connect_timeout - self.connect_timeout = config.connect_timeout + self.timeout = config.timeout or AlibabaCloudCredentialsProvider.timeout + self.connect_timeout = config.connect_timeout or AlibabaCloudCredentialsProvider.timeout self.proxy = config.proxy self.sts_endpoint = config.sts_endpoint diff --git a/alibabacloud_credentials/utils/auth_util.py b/alibabacloud_credentials/utils/auth_util.py index 713f058..2249a44 100644 --- a/alibabacloud_credentials/utils/auth_util.py +++ b/alibabacloud_credentials/utils/auth_util.py @@ -1,22 +1,33 @@ import os client_type = os.environ.get('ALIBABA_CLOUD_PROFILE', 'default') + environment_access_key_id = os.environ.get('ALIBABA_CLOUD_ACCESS_KEY_ID') environment_access_key_secret = os.environ.get('ALIBABA_CLOUD_ACCESS_KEY_SECRET') environment_security_token = os.environ.get('ALIBABA_CLOUD_SECURITY_TOKEN') + environment_ECSMeta_data = os.environ.get('ALIBABA_CLOUD_ECS_METADATA') -environment_imds_v1_disabled = os.environ.get('ALIBABA_CLOUD_IMDSV1_DISABLED') +environment_ecs_metadata = os.environ.get('ALIBABA_CLOUD_ECS_METADATA') +environment_imds_v1_disabled = os.environ.get('ALIBABA_CLOUD_IMDSV1_DISABLED', 'false') +environment_ecs_metadata_disabled = os.environ.get('ALIBABA_CLOUD_ECS_METADATA_DISABLED', 'false') + environment_credentials_file = os.environ.get('ALIBABA_CLOUD_CREDENTIALS_FILE') +environment_profile_name = os.environ.get('ALIBABA_CLOUD_PROFILE') environment_oidc_token_file = os.environ.get('ALIBABA_CLOUD_OIDC_TOKEN_FILE') environment_role_arn = os.environ.get('ALIBABA_CLOUD_ROLE_ARN') environment_oidc_provider_arn = os.environ.get('ALIBABA_CLOUD_OIDC_PROVIDER_ARN') environment_role_session_name = os.environ.get('ALIBABA_CLOUD_ROLE_SESSION_NAME') +environment_credentials_uri = os.environ.get('ALIBABA_CLOUD_CREDENTIALS_URI') + +environment_cli_profile_disabled = os.environ.get('ALIBABA_CLOUD_CLI_PROFILE_DISABLED', 'false') + environment_sts_region = os.environ.get('ALIBABA_CLOUD_STS_REGION') +environment_enable_vpc = os.environ.get('ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED', 'false') -enable_oidc_credential = environment_oidc_token_file is not None \ - and environment_role_arn is not None \ - and environment_oidc_provider_arn is not None +enable_oidc_credential = environment_oidc_token_file is not None and environment_oidc_token_file != '' \ + and environment_role_arn is not None and environment_role_arn != '' \ + and environment_oidc_provider_arn is not None and environment_oidc_provider_arn != '' private_key = None diff --git a/setup.py b/setup.py index 5076ad8..09a5014 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,7 @@ """ import os +import sys from setuptools import setup, find_packages """ @@ -35,6 +36,18 @@ with open("README.md", encoding="utf-8") as fp: LONG_DESCRIPTION = fp.read() +install_requires = [ + 'alibabacloud-tea>=0.4.0', + 'alibabacloud_credentials_api>=1.0.0, <2.0.0' +] + +if sys.version_info.minor <= 7: + install_requires.append('APScheduler>=3.10.0, <3.11.0') + install_requires.append('aiofiles>=22.1.0, <24.0.0') +else: + install_requires.append('APScheduler>=3.10.0, <4.0.0') + install_requires.append('aiofiles>=22.1.0, <25.0.0') + setup_args = { 'version': VERSION, 'description': DESCRIPTION, @@ -47,8 +60,8 @@ 'keywords': ["alibabacloud", "sdk", "tea"], 'packages': find_packages(exclude=["tests*"]), 'platforms': 'any', - 'python_requires': '>=3.6', - 'install_requires': ['alibabacloud-tea>=0.3.9'], + 'python_requires': '>=3.7', + 'install_requires': install_requires, 'classifiers': ( 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', @@ -64,4 +77,4 @@ ) } -setup(name='alibabacloud_credentials', **setup_args) +setup(name='alibabacloud-credentials', **setup_args) diff --git a/tests/provider/test_cli_profile.py b/tests/provider/test_cli_profile.py new file mode 100644 index 0000000..5e4091f --- /dev/null +++ b/tests/provider/test_cli_profile.py @@ -0,0 +1,385 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import os +import json +from alibabacloud_credentials.provider.cli_profile import ( + CLIProfileCredentialsProvider, + CredentialException, + _load_config_async, + _load_config +) +from alibabacloud_credentials.provider import ( + StaticAKCredentialsProvider, + RamRoleArnCredentialsProvider, + EcsRamRoleCredentialsProvider, + OIDCRoleArnCredentialsProvider +) +from alibabacloud_credentials.utils import auth_constant as ac + + +class TestCLIProfileCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.profile_name = "test_profile" + self.profile_file = os.path.join(ac.HOME, "/.aliyun/config.json") + self.config = { + "current": "test_profile", + "profiles": [ + { + "name": "test_profile", + "mode": "AK", + "access_key_id": "test_access_key_id", + "access_key_secret": "test_access_key_secret" + }, + { + "name": "ram_role_profile", + "mode": "RamRoleArn", + "access_key_id": "test_access_key_id", + "access_key_secret": "test_access_key_secret", + "ram_role_arn": "test_ram_role_arn", + "ram_session_name": "test_ram_session_name", + "expired_seconds": 7200, + "policy": "test_policy", + "external_id": "test_external_id", + "sts_region": "test_sts_region", + "enable_vpc": True + }, + { + "name": "ecs_ram_role_profile", + "mode": "EcsRamRole", + "ram_role_name": "test_ram_role_name" + }, + { + "name": "oidc_profile", + "mode": "OIDC", + "ram_role_arn": "test_ram_role_arn", + "oidc_provider_arn": "test_oidc_provider_arn", + "oidc_token_file": "test_oidc_token_file", + "role_session_name": "test_role_session_name", + "expired_seconds": 7200, + "policy": "test_policy", + "sts_region": "test_sts_region", + "enable_vpc": True + }, + { + "name": "chainable_ram_role_profile", + "mode": "ChainableRamRoleArn", + "source_profile": "test_profile", + "ram_role_arn": "test_ram_role_arn", + "ram_session_name": "test_ram_session_name", + "expired_seconds": 7200, + "policy": "test_policy", + "external_id": "test_external_id", + "sts_region": "test_sts_region", + "enable_vpc": True + } + ] + } + self.access_key_id = "test_access_key_id" + self.access_key_secret = "test_access_key_secret" + self.security_token = "test_security_token" + self.expiration = "2023-12-31T23:59:59Z" + self.response_body = json.dumps({ + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token, + "Expiration": self.expiration + }) + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided parameters + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_profile_name', self.profile_name): + provider = CLIProfileCredentialsProvider() + + self.assertEqual(provider._profile_name, self.profile_name) + self.assertEqual(provider._profile_file, os.path.join(ac.HOME, "/.aliyun/config.json")) + + def test_get_credentials_valid_ak(self): + """ + Test case 2: Valid input, successfully retrieves credentials for AK mode + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertIsNone(credentials.get_security_token()) + self.assertEqual(credentials.get_provider_name(), "cli_profile/static_ak") + + def test_get_credentials_valid_ram_role_arn(self): + """ + Test case 3: Valid input, successfully retrieves credentials for RamRoleArn mode + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name="ram_role_profile") + + credentials_provider = provider._get_credentials_provider(config=self.config, + profile_name="ram_role_profile") + + self.assertIsInstance(credentials_provider, RamRoleArnCredentialsProvider) + + self.assertEqual( + credentials_provider._credentials_provider.get_credentials().get_access_key_id(), + self.access_key_id) + self.assertEqual( + credentials_provider._credentials_provider.get_credentials().get_access_key_secret(), + self.access_key_secret) + self.assertIsNone( + credentials_provider._credentials_provider.get_credentials().get_security_token()) + self.assertEqual( + credentials_provider._credentials_provider.get_credentials().get_provider_name(), + "static_ak") + self.assertEqual(credentials_provider._role_arn, 'test_ram_role_arn') + self.assertEqual(credentials_provider._role_session_name, 'test_ram_session_name') + self.assertEqual(credentials_provider._duration_seconds, 7200) + self.assertEqual(credentials_provider._policy, 'test_policy') + self.assertEqual(credentials_provider._external_id, 'test_external_id') + self.assertEqual(credentials_provider._sts_endpoint, 'sts-vpc.test_sts_region.aliyuncs.com') + + def test_get_credentials_valid_ecs_ram_role(self): + """ + Test case 4: Valid input, successfully retrieves credentials for EcsRamRole mode + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', False): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name="ecs_ram_role_profile") + + credentials_provider = provider._get_credentials_provider(config=self.config, + profile_name="ecs_ram_role_profile") + + self.assertIsInstance(credentials_provider, EcsRamRoleCredentialsProvider) + + self.assertEqual(credentials_provider._role_name, 'test_ram_role_name') + + def test_get_credentials_valid_oidc(self): + """ + Test case 5: Valid input, successfully retrieves credentials for OIDC mode + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', False): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name="oidc_profile") + + credentials_provider = provider._get_credentials_provider(config=self.config, + profile_name="oidc_profile") + + self.assertIsInstance(credentials_provider, OIDCRoleArnCredentialsProvider) + + self.assertEqual(credentials_provider._role_arn, 'test_ram_role_arn') + self.assertEqual(credentials_provider._oidc_provider_arn, 'test_oidc_provider_arn') + self.assertEqual(credentials_provider._role_session_name, 'test_role_session_name') + self.assertEqual(credentials_provider._duration_seconds, 7200) + self.assertEqual(credentials_provider._policy, 'test_policy') + self.assertEqual(credentials_provider._sts_endpoint, 'sts-vpc.test_sts_region.aliyuncs.com') + + def test_get_credentials_valid_chainable_ram_role_arn(self): + """ + Test case 6: Valid input, successfully retrieves credentials for ChainableRamRoleArn mode + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', False): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value=self.config): + provider = CLIProfileCredentialsProvider(profile_name="chainable_ram_role_profile") + + credentials_provider = provider._get_credentials_provider(config=self.config, + profile_name="chainable_ram_role_profile") + + self.assertIsInstance(credentials_provider, RamRoleArnCredentialsProvider) + + self.assertEqual( + credentials_provider._credentials_provider.get_credentials().get_access_key_id(), + self.access_key_id) + self.assertEqual( + credentials_provider._credentials_provider.get_credentials().get_access_key_secret(), + self.access_key_secret) + self.assertIsNone( + credentials_provider._credentials_provider.get_credentials().get_security_token()) + self.assertEqual( + credentials_provider._credentials_provider.get_credentials().get_provider_name(), + "static_ak") + self.assertEqual(credentials_provider._role_arn, 'test_ram_role_arn') + self.assertEqual(credentials_provider._role_session_name, 'test_ram_session_name') + self.assertEqual(credentials_provider._duration_seconds, 7200) + self.assertEqual(credentials_provider._policy, 'test_policy') + self.assertEqual(credentials_provider._external_id, 'test_external_id') + self.assertEqual(credentials_provider._sts_endpoint, 'sts-vpc.test_sts_region.aliyuncs.com') + + def test_get_credentials_cli_profile_disabled(self): + """ + Test case 7: CLI profile disabled raises CredentialException + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'True'): + provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("cli credentials file is disabled", str(context.exception)) + + def test_get_credentials_profile_file_not_exists(self): + """ + Test case 8: Profile file does not exist raises CredentialException + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=False): + provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'unable to open credentials file: {self.profile_file}', str(context.exception)) + + def test_get_credentials_profile_file_not_file(self): + """ + Test case 9: Profile file is not a file raises CredentialException + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=False): + provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'unable to open credentials file: {self.profile_file}', str(context.exception)) + + def test_get_credentials_invalid_json_format(self): + """ + Test case 10: Invalid JSON format in profile file raises CredentialException + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', + side_effect=json.JSONDecodeError('Invalid JSON', '', 0)): + provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'failed to parse credential form cli credentials file: {self.profile_file}', + str(context.exception)) + + def test_get_credentials_empty_json(self): + """ + Test case 11: Empty JSON in profile file raises CredentialException + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', return_value={}): + provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("unable to get profile with 'test_profile' form cli credentials file.", + str(context.exception)) + + def test_get_credentials_missing_profiles(self): + """ + Test case 12: Missing profiles in JSON raises CredentialException + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', + return_value={"current": "test_profile"}): + provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f"unable to get profile with 'test_profile' form cli credentials file.", + str(context.exception)) + + def test_get_credentials_invalid_profile_mode(self): + """ + Test case 13: Invalid profile mode raises CredentialException + """ + invalid_config = { + "current": "invalid_profile", + "profiles": [ + { + "name": "invalid_profile", + "mode": "InvalidMode", + "access_key_id": "test_access_key_id", + "access_key_secret": "test_access_key_secret" + } + ] + } + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config', + return_value=invalid_config): + provider = CLIProfileCredentialsProvider(profile_name="invalid_profile") + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f"unsupported profile mode 'InvalidMode' form cli credentials file.", + str(context.exception)) + + def test_get_credentials_async_valid_ak(self): + """ + Test case 14: Valid input, successfully retrieves credentials for AK mode + """ + with patch('alibabacloud_credentials.provider.cli_profile.au.environment_cli_profile_disabled', 'False'): + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.cli_profile._load_config_async', + AsyncMock(return_value=self.config)): + provider = CLIProfileCredentialsProvider(profile_name=self.profile_name) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertIsNone(credentials.get_security_token()) + self.assertEqual(credentials.get_provider_name(), "cli_profile/static_ak") + + @patch('builtins.open', new_callable=MagicMock) + def test_load_config_file_not_found(self, mock_open): + """ + Test case 15: File not found raises FileNotFoundError + """ + mock_open.side_effect = FileNotFoundError(f"No such file or directory: '{self.profile_file}'") + + with self.assertRaises(FileNotFoundError) as context: + _load_config(self.profile_file) + + self.assertIn(f"No such file or directory: '{self.profile_file}'", str(context.exception)) + + @patch('builtins.open', new_callable=MagicMock) + def test_load_config_invalid_json(self, mock_open): + """ + Test case 16: Invalid JSON format raises json.JSONDecodeError + """ + invalid_json = "invalid json content" + mock_open.return_value.__enter__.return_value.read.return_value = invalid_json + + with self.assertRaises(json.JSONDecodeError) as context: + _load_config(self.profile_file) + + self.assertIn("Expecting value: line 1 column 1", str(context.exception)) diff --git a/tests/provider/test_default.py b/tests/provider/test_default.py new file mode 100644 index 0000000..3fd2643 --- /dev/null +++ b/tests/provider/test_default.py @@ -0,0 +1,494 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +from alibabacloud_credentials.provider.default import DefaultCredentialsProvider +from alibabacloud_credentials.provider import ( + EnvironmentVariableCredentialsProvider, + OIDCRoleArnCredentialsProvider, + CLIProfileCredentialsProvider, + ProfileCredentialsProvider, + EcsRamRoleCredentialsProvider, + URLCredentialsProvider +) +from alibabacloud_credentials.provider.refreshable import Credentials +from alibabacloud_credentials.exceptions import CredentialException + + +class TestDefaultCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.access_key_id = "test_access_key_id" + self.access_key_secret = "test_access_key_secret" + self.security_token = "test_security_token" + self.credentials = Credentials( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + provider_name="test_provider" + ) + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', False) + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_with_environment_variable_provider(self): + """ + Test case 1: Successfully retrieves credentials from EnvironmentVariableCredentialsProvider + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + provider = DefaultCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', True) + @patch('alibabacloud_credentials.provider.oidc.au.environment_role_arn', 'test_role_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_provider_arn', 'test_oidc_provider_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_token_file', 'test_token_file') + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_with_oidc_provider(self): + """ + Test case 2: Successfully retrieves credentials from OIDCRoleArnCredentialsProvider + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + oidc_provider = OIDCRoleArnCredentialsProvider() + oidc_provider.get_credentials = MagicMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.OIDCRoleArnCredentialsProvider', + return_value=oidc_provider): + provider = DefaultCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', True) + @patch('alibabacloud_credentials.provider.oidc.au.environment_role_arn', 'test_role_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_provider_arn', 'test_oidc_provider_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_token_file', 'test_token_file') + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_with_cli_profile_provider(self): + """ + Test case 3: Successfully retrieves credentials from CLIProfileCredentialsProvider + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + oidc_provider = OIDCRoleArnCredentialsProvider() + oidc_provider.get_credentials = MagicMock( + side_effect=CredentialException("OIDCRoleArnCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials = MagicMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.OIDCRoleArnCredentialsProvider', + return_value=oidc_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + provider = DefaultCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', True) + @patch('alibabacloud_credentials.provider.oidc.au.environment_role_arn', 'test_role_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_provider_arn', 'test_oidc_provider_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_token_file', 'test_token_file') + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_with_profile_provider(self): + """ + Test case 4: Successfully retrieves credentials from ProfileCredentialsProvider + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + oidc_provider = OIDCRoleArnCredentialsProvider() + oidc_provider.get_credentials = MagicMock( + side_effect=CredentialException("OIDCRoleArnCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials = MagicMock( + side_effect=CredentialException("CLIProfileCredentialsProvider failed")) + + profile_provider = ProfileCredentialsProvider() + profile_provider.get_credentials = MagicMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.OIDCRoleArnCredentialsProvider', + return_value=oidc_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + with patch('alibabacloud_credentials.provider.default.ProfileCredentialsProvider', + return_value=profile_provider): + provider = DefaultCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', True) + @patch('alibabacloud_credentials.provider.oidc.au.environment_role_arn', 'test_role_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_provider_arn', 'test_oidc_provider_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_token_file', 'test_token_file') + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_with_ecs_ram_role_provider(self): + """ + Test case 5: Successfully retrieves credentials from EcsRamRoleCredentialsProvider + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + oidc_provider = OIDCRoleArnCredentialsProvider() + oidc_provider.get_credentials = MagicMock( + side_effect=CredentialException("OIDCRoleArnCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials = MagicMock( + side_effect=CredentialException("CLIProfileCredentialsProvider failed")) + + profile_provider = ProfileCredentialsProvider() + profile_provider.get_credentials = MagicMock( + side_effect=CredentialException("ProfileCredentialsProvider failed")) + + ecs_provider = EcsRamRoleCredentialsProvider() + ecs_provider.get_credentials = MagicMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.OIDCRoleArnCredentialsProvider', + return_value=oidc_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + with patch('alibabacloud_credentials.provider.default.ProfileCredentialsProvider', + return_value=profile_provider): + with patch('alibabacloud_credentials.provider.default.EcsRamRoleCredentialsProvider', + return_value=ecs_provider): + provider = DefaultCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', True) + @patch('alibabacloud_credentials.provider.oidc.au.environment_role_arn', 'test_role_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_provider_arn', 'test_oidc_provider_arn') + @patch('alibabacloud_credentials.provider.oidc.au.environment_oidc_token_file', 'test_token_file') + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', "http://example.com/credentials") + def test_get_credentials_with_url_provider(self): + """ + Test case 6: Successfully retrieves credentials from URLCredentialsProvider + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + oidc_provider = OIDCRoleArnCredentialsProvider() + oidc_provider.get_credentials = MagicMock( + side_effect=CredentialException("OIDCRoleArnCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials = MagicMock( + side_effect=CredentialException("CLIProfileCredentialsProvider failed")) + + profile_provider = ProfileCredentialsProvider() + profile_provider.get_credentials = MagicMock( + side_effect=CredentialException("ProfileCredentialsProvider failed")) + + ecs_provider = EcsRamRoleCredentialsProvider() + ecs_provider.get_credentials = MagicMock( + side_effect=CredentialException("EcsRamRoleCredentialsProvider failed")) + + url_provider = URLCredentialsProvider() + url_provider.get_credentials = MagicMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.OIDCRoleArnCredentialsProvider', + return_value=oidc_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + with patch('alibabacloud_credentials.provider.default.ProfileCredentialsProvider', + return_value=profile_provider): + with patch('alibabacloud_credentials.provider.default.EcsRamRoleCredentialsProvider', + return_value=ecs_provider): + with patch('alibabacloud_credentials.provider.default.URLCredentialsProvider', + return_value=url_provider): + provider = DefaultCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', "http://example.com/credentials") + def test_get_credentials_no_valid_provider(self): + """ + Test case 7: No valid provider raises CredentialException + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials = MagicMock( + side_effect=CredentialException("CLIProfileCredentialsProvider failed")) + + profile_provider = ProfileCredentialsProvider() + profile_provider.get_credentials = MagicMock( + side_effect=CredentialException("ProfileCredentialsProvider failed")) + + ecs_provider = EcsRamRoleCredentialsProvider() + ecs_provider.get_credentials = MagicMock( + side_effect=CredentialException("EcsRamRoleCredentialsProvider failed")) + + url_provider = URLCredentialsProvider() + url_provider.get_credentials = MagicMock(side_effect=CredentialException("URLCredentialsProvider failed")) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + with patch('alibabacloud_credentials.provider.default.ProfileCredentialsProvider', + return_value=profile_provider): + with patch('alibabacloud_credentials.provider.default.EcsRamRoleCredentialsProvider', + return_value=ecs_provider): + with patch('alibabacloud_credentials.provider.default.URLCredentialsProvider', + return_value=url_provider): + provider = DefaultCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("unable to load credentials from any of the providers in the chain", + str(context.exception)) + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', False) + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_async_with_environment_variable_provider(self): + """ + Test case 8: Successfully retrieves credentials asynchronously from EnvironmentVariableCredentialsProvider + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials_async = AsyncMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + provider = DefaultCredentialsProvider() + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', False) + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_reuse_last_provider_enabled(self): + """ + Test case 8: Reuse last provider when reuse_last_provider_enabled is True + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials = MagicMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + provider = DefaultCredentialsProvider() + + # First call to get_credentials + credentials = provider.get_credentials() + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + # Second call to get_credentials should reuse the last provider + credentials = provider.get_credentials() + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + # Ensure get_credentials was only called once on the provider + env_provider.get_credentials.assert_called_once() + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', False) + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_reuse_last_provider_disabled(self): + """ + Test case 9: Do not reuse last provider when reuse_last_provider_enabled is False + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials = MagicMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials = MagicMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + provider = DefaultCredentialsProvider(reuse_last_provider_enabled=False) + + # First call to get_credentials + credentials = provider.get_credentials() + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + # Second call to get_credentials should not reuse the last provider + credentials = provider.get_credentials() + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + # Ensure get_credentials was called twice on the provider + self.assertEqual(env_provider.get_credentials.call_count, 2) + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', False) + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_async_reuse_last_provider_enabled(self): + """ + Test case 8: Reuse last provider when reuse_last_provider_enabled is True + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials_async = AsyncMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials_async = AsyncMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + provider = DefaultCredentialsProvider() + + # First call to get_credentials + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + # Second call to get_credentials should reuse the last provider + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + # Ensure get_credentials was only called once on the provider + env_provider.get_credentials_async.assert_called_once() + + @patch('alibabacloud_credentials.provider.default.au.enable_oidc_credential', False) + @patch('alibabacloud_credentials.provider.default.au.environment_ecs_metadata_disabled', 'false') + @patch('alibabacloud_credentials.provider.default.au.environment_credentials_uri', None) + def test_get_credentials_async_reuse_last_provider_disabled(self): + """ + Test case 9: Do not reuse last provider when reuse_last_provider_enabled is False + """ + env_provider = EnvironmentVariableCredentialsProvider() + env_provider.get_credentials_async = AsyncMock( + side_effect=CredentialException("EnvironmentVariableCredentialsProvider failed")) + + cli_provider = CLIProfileCredentialsProvider() + cli_provider.get_credentials_async = AsyncMock(return_value=self.credentials) + + with patch('alibabacloud_credentials.provider.default.EnvironmentVariableCredentialsProvider', + return_value=env_provider): + with patch('alibabacloud_credentials.provider.default.CLIProfileCredentialsProvider', + return_value=cli_provider): + provider = DefaultCredentialsProvider(reuse_last_provider_enabled=False) + + # First call to get_credentials + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + # Second call to get_credentials should not reuse the last provider + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.get_security_token(), self.security_token) + self.assertEqual(credentials.get_provider_name(), "default/test_provider") + + # Ensure get_credentials was called twice on the provider + self.assertEqual(env_provider.get_credentials_async.call_count, 2) diff --git a/tests/provider/test_ecs_ram_role.py b/tests/provider/test_ecs_ram_role.py new file mode 100644 index 0000000..27b0e93 --- /dev/null +++ b/tests/provider/test_ecs_ram_role.py @@ -0,0 +1,489 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import calendar +import time +import json +from alibabacloud_credentials.provider.ecs_ram_role import ( + EcsRamRoleCredentialsProvider, + CredentialException +) +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaResponse + + +class TestEcsRamRoleCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.role_name = "test_role_name" + self.disable_imds_v1 = False + self.http_options = HttpOptions(connect_timeout=5000, read_timeout=10000, proxy="test_proxy") + self.metadata_service_host = '100.100.100.200' + self.metadata_token_duration = 21600 + self.metadata_token = "test_metadata_token" + self.access_key_id = "test_access_key_id" + self.access_key_secret = "test_access_key_secret" + self.security_token = "test_security_token" + self.expiration = "2023-12-31T23:59:59Z" + self.response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token, + "Expiration": self.expiration + }) + self.response = TeaResponse() + self.response.status_code = 200 + self.response.body = self.response_body.encode('utf-8') + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided parameters + """ + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options, + async_update_enabled=False + ) + + self.assertEqual(provider._role_name, self.role_name) + self.assertEqual(provider._disable_imds_v1, self.disable_imds_v1) + self.assertEqual(provider._http_options, self.http_options) + self.assertEqual(provider._runtime_options['connectTimeout'], self.http_options.connect_timeout) + self.assertEqual(provider._runtime_options['readTimeout'], self.http_options.read_timeout) + self.assertEqual(provider._runtime_options['httpProxy'], self.http_options.proxy) + + def test_init_missing_role_name(self): + """ + Test case 2: Missing role_name raises ValueError + """ + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'true'): + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata', ''): + with self.assertRaises(ValueError) as context: + EcsRamRoleCredentialsProvider( + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + self.assertIn("IMDS credentials is disabled", str(context.exception)) + + def test_init_disable_metadata(self): + """ + Test case 4: Disable metadata + """ + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'true'): + with self.assertRaises(ValueError) as context: + EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=True, + http_options=self.http_options + ) + + self.assertIn("IMDS credentials is disabled", str(context.exception)) + + def test_get_credentials_valid_input(self): + """ + Test case 5: Valid input, successfully retrieves credentials + """ + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata', self.role_name): + with patch( + 'alibabacloud_credentials.provider.ecs_ram_role.EcsRamRoleCredentialsProvider._get_metadata_token', + return_value=self.metadata_token): + with patch('Tea.core.TeaCore.do_action', return_value=self.response): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + credentials = provider._refresh_credentials() + + self.assertEqual(credentials.value().get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.value().get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.value().get_security_token(), self.security_token) + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm(time.strptime(self.expiration, '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "ecs_ram_role") + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("No cached value was found.", str(context.exception)) + + def test_get_credentials_http_request_error(self): + """ + Test case 6: HTTP request error raises CredentialException + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata', self.role_name): + with patch( + 'alibabacloud_credentials.provider.ecs_ram_role.EcsRamRoleCredentialsProvider._get_metadata_token', + return_value=self.metadata_token): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn( + "Failed to get RAM session credentials from ECS metadata service. HttpCode=400", + str(context.exception)) + + def test_get_credentials_response_format_error(self): + """ + Test case 7: Response format error raises CredentialException + """ + response_body = json.dumps({ + "Code": "Failure", + "Message": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata', self.role_name): + with patch( + 'alibabacloud_credentials.provider.ecs_ram_role.EcsRamRoleCredentialsProvider._get_metadata_token', + return_value=self.metadata_token): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn('Failed to get RAM session credentials from ECS metadata service.', + str(context.exception)) + + def test_get_credentials_async_valid_input(self): + """ + Test case 8: Valid input, successfully retrieves credentials asynchronously + """ + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata', self.role_name): + with patch( + 'alibabacloud_credentials.provider.ecs_ram_role.EcsRamRoleCredentialsProvider._get_metadata_token_async', + AsyncMock(return_value=self.metadata_token)): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=self.response)): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._refresh_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.value().get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.value().get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.value().get_security_token(), self.security_token) + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm(time.strptime(self.expiration, '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "ecs_ram_role") + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn("No cached value was found.", str(context.exception)) + + def test_get_credentials_async_http_request_error(self): + """ + Test case 9: HTTP request error raises CredentialException asynchronously + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata', self.role_name): + with patch( + 'alibabacloud_credentials.provider.ecs_ram_role.EcsRamRoleCredentialsProvider._get_metadata_token_async', + AsyncMock(return_value=self.metadata_token)): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn( + "Failed to get RAM session credentials from ECS metadata service. HttpCode=400", + str(context.exception)) + + def test_get_credentials_async_response_format_error(self): + """ + Test case 10: Response format error raises CredentialException asynchronously + """ + response_body = json.dumps({ + "Code": "Failure", + "Message": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata', self.role_name): + with patch( + 'alibabacloud_credentials.provider.ecs_ram_role.EcsRamRoleCredentialsProvider._get_metadata_token_async', + AsyncMock(return_value=self.metadata_token)): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn('Failed to get RAM session credentials from ECS metadata service.', + str(context.exception)) + + def test_get_metadata_token_valid_input(self): + """ + Test case 11: Valid input, successfully retrieves metadata token + """ + response_body = self.metadata_token + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_imds_v1_disabled', 'true'): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + http_options=self.http_options + ) + + metadata_token = provider._get_metadata_token() + + self.assertEqual(metadata_token, self.metadata_token) + + def test_get_metadata_token_http_request_error(self): + """ + Test case 12: HTTP request error raises CredentialException + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_imds_v1_disabled', 'true'): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider._get_metadata_token() + + self.assertIn( + "Failed to get token from ECS Metadata Service. HttpCode=400", + str(context.exception)) + + def test_get_metadata_token_async_valid_input(self): + """ + Test case 13: Valid input, successfully retrieves metadata token asynchronously + """ + response_body = self.metadata_token + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_imds_v1_disabled', 'true'): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + http_options=self.http_options + ) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._get_metadata_token_async() + ) + loop.run_until_complete(task) + metadata_token = task.result() + + self.assertEqual(metadata_token, self.metadata_token) + + def test_get_metadata_token_async_http_request_error(self): + """ + Test case 14: HTTP request error raises CredentialException asynchronously + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_imds_v1_disabled', 'true'): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._get_metadata_token_async() + ) + loop.run_until_complete(task) + + self.assertIn( + "Failed to get token from ECS Metadata Service. HttpCode=400", + str(context.exception)) + + def test_get_role_name_valid_input(self): + """ + Test case 15: Valid input, successfully retrieves role name + """ + response_body = self.role_name + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + role_name = provider._get_role_name() + + self.assertEqual(role_name, self.role_name) + + def test_get_role_name_http_request_error(self): + """ + Test case 16: HTTP request error raises CredentialException + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider._get_role_name() + + self.assertIn( + "Failed to get RAM session credentials from ECS metadata service. HttpCode=400", + str(context.exception)) + + def test_get_role_name_async_valid_input(self): + """ + Test case 17: Valid input, successfully retrieves role name asynchronously + """ + response_body = self.role_name + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._get_role_name_async() + ) + loop.run_until_complete(task) + role_name = task.result() + + self.assertEqual(role_name, self.role_name) + + def test_get_role_name_async_http_request_error(self): + """ + Test case 18: HTTP request error raises CredentialException asynchronously + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('alibabacloud_credentials.provider.ecs_ram_role.au.environment_ecs_metadata_disabled', 'false'): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = EcsRamRoleCredentialsProvider( + role_name=self.role_name, + disable_imds_v1=self.disable_imds_v1, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._get_role_name_async() + ) + loop.run_until_complete(task) + + self.assertIn( + "Failed to get RAM session credentials from ECS metadata service. HttpCode=400", + str(context.exception)) + + def test_ecs_ram_role_provider_methods(self): + provider = EcsRamRoleCredentialsProvider() + + # Test _get_stale_time + stale_time = provider._get_stale_time(1000) + self.assertEqual(stale_time, 1000 - 15 * 60) + + stale_time = provider._get_stale_time(-1) + current_time = int(time.mktime(time.localtime())) + self.assertAlmostEqual(stale_time, current_time + 60 * 60, delta=1) + + # Test _get_prefetch_time + prefetch_time = provider._get_prefetch_time(1000) + current_time = int(time.mktime(time.localtime())) + self.assertAlmostEqual(prefetch_time, current_time + 60 * 60, delta=1) + + prefetch_time = provider._get_prefetch_time(-1) + current_time = int(time.mktime(time.localtime())) + self.assertAlmostEqual(prefetch_time, current_time + 5 * 60, delta=1) + + # Test get_provider_name + provider_name = provider.get_provider_name() + self.assertEqual(provider_name, 'ecs_ram_role') diff --git a/tests/provider/test_env.py b/tests/provider/test_env.py new file mode 100644 index 0000000..53761eb --- /dev/null +++ b/tests/provider/test_env.py @@ -0,0 +1,175 @@ +import unittest +from unittest.mock import patch, MagicMock +import asyncio +from alibabacloud_credentials.provider import EnvironmentVariableCredentialsProvider +from alibabacloud_credentials.exceptions import CredentialException + + +class TestEnvironmentVariableCredentialsProvider(unittest.TestCase): + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_valid_input(self, mock_auth_util): + """ + Test case 1: Valid input, successfully retrieves credentials + """ + # Set mock object return values + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.get_security_token(), "test_security_token") + self.assertEqual(credentials.get_provider_name(), "env") + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_missing_access_key_id(self, mock_auth_util): + """ + Test case 2: Missing environment variable accessKeyId raises CredentialException + """ + mock_auth_util.environment_access_key_id = None + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("Environment variable accessKeyId cannot be empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_empty_access_key_id(self, mock_auth_util): + """ + Test case 3: Empty environment variable accessKeyId raises CredentialException + """ + mock_auth_util.environment_access_key_id = "" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("Environment variable accessKeyId cannot be empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_missing_access_key_secret(self, mock_auth_util): + """ + Test case 4: Missing environment variable accessKeySecret raises CredentialException + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = None + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("Environment variable accessKeySecret cannot be empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_empty_access_key_secret(self, mock_auth_util): + """ + Test case 5: Empty environment variable accessKeySecret raises CredentialException + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "" + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("Environment variable accessKeySecret cannot be empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_async_valid_input(self, mock_auth_util): + """ + Test case 6: Valid input, successfully retrieves credentials asynchronously + """ + # Set mock object return values + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + # Use asyncio.run to execute the async function + credentials = asyncio.run(provider.get_credentials_async()) + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.get_security_token(), "test_security_token") + self.assertEqual(credentials.get_provider_name(), "env") + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_async_missing_access_key_id(self, mock_auth_util): + """ + Test case 7: Missing environment variable accessKeyId raises CredentialException asynchronously + """ + mock_auth_util.environment_access_key_id = None + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + asyncio.run(provider.get_credentials_async()) + + self.assertIn("Environment variable accessKeyId cannot be empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_async_empty_access_key_id(self, mock_auth_util): + """ + Test case 8: Empty environment variable accessKeyId raises CredentialException asynchronously + """ + mock_auth_util.environment_access_key_id = "" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + asyncio.run(provider.get_credentials_async()) + + self.assertIn("Environment variable accessKeyId cannot be empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_async_missing_access_key_secret(self, mock_auth_util): + """ + Test case 9: Missing environment variable accessKeySecret raises CredentialException asynchronously + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = None + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + asyncio.run(provider.get_credentials_async()) + + self.assertIn("Environment variable accessKeySecret cannot be empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.env.auth_util') + def test_get_credentials_async_empty_access_key_secret(self, mock_auth_util): + """ + Test case 10: Empty environment variable accessKeySecret raises CredentialException asynchronously + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "" + mock_auth_util.environment_security_token = "test_security_token" + + provider = EnvironmentVariableCredentialsProvider() + + with self.assertRaises(CredentialException) as context: + asyncio.run(provider.get_credentials_async()) + + self.assertIn("Environment variable accessKeySecret cannot be empty", str(context.exception)) diff --git a/tests/provider/test_oidc.py b/tests/provider/test_oidc.py new file mode 100644 index 0000000..50cd4c5 --- /dev/null +++ b/tests/provider/test_oidc.py @@ -0,0 +1,592 @@ +import unittest +from unittest.mock import patch, AsyncMock +import asyncio +import calendar +import time +import json +from alibabacloud_credentials.provider.oidc import ( + OIDCRoleArnCredentialsProvider, + CredentialException +) +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaResponse + + +class TestOIDCRoleArnCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.role_arn = "test_role_arn" + self.oidc_provider_arn = "test_oidc_provider_arn" + self.oidc_token_file_path = "test_oidc_token_file_path" + self.role_session_name = "test_role_session_name" + self.duration_seconds = 3600 + self.policy = "test_policy" + self.sts_region_id = "test_sts_region_id" + self.sts_endpoint = "test_sts_endpoint" + self.enable_vpc = True + self.http_options = HttpOptions(connect_timeout=5000, read_timeout=10000, proxy="test_proxy") + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided parameters + """ + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_region_id=self.sts_region_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + self.assertEqual(provider._role_arn, self.role_arn) + self.assertEqual(provider._oidc_provider_arn, self.oidc_provider_arn) + self.assertEqual(provider._oidc_token_file_path, self.oidc_token_file_path) + self.assertEqual(provider._role_session_name, self.role_session_name) + self.assertEqual(provider._duration_seconds, self.duration_seconds) + self.assertEqual(provider._policy, self.policy) + self.assertEqual(provider._sts_endpoint, self.sts_endpoint) + self.assertEqual(provider._http_options, self.http_options) + self.assertEqual(provider._runtime_options['connectTimeout'], self.http_options.connect_timeout) + self.assertEqual(provider._runtime_options['readTimeout'], self.http_options.read_timeout) + self.assertEqual(provider._runtime_options['httpsProxy'], self.http_options.proxy) + + @patch('alibabacloud_credentials.provider.oidc.au') + def test_init_valid_environment_variables(self, mock_auth_util): + """ + Test case 2: Valid input, successfully initializes with environment variables + """ + mock_auth_util.environment_role_arn = self.role_arn + mock_auth_util.environment_oidc_provider_arn = self.oidc_provider_arn + mock_auth_util.environment_oidc_token_file = self.oidc_token_file_path + mock_auth_util.environment_role_session_name = self.role_session_name + mock_auth_util.environment_enable_vpc = str(self.enable_vpc) + mock_auth_util.environment_sts_region = self.sts_region_id + + provider = OIDCRoleArnCredentialsProvider() + + self.assertEqual(provider._role_arn, self.role_arn) + self.assertEqual(provider._oidc_provider_arn, self.oidc_provider_arn) + self.assertEqual(provider._oidc_token_file_path, self.oidc_token_file_path) + self.assertEqual(provider._role_session_name, self.role_session_name) + self.assertEqual(provider._duration_seconds, OIDCRoleArnCredentialsProvider.DEFAULT_DURATION_SECONDS) + self.assertIsNone(provider._policy) + self.assertEqual(provider._sts_endpoint, f'sts-vpc.{self.sts_region_id}.aliyuncs.com') + self.assertEqual(provider._runtime_options['connectTimeout'], + OIDCRoleArnCredentialsProvider.DEFAULT_CONNECT_TIMEOUT) + self.assertEqual(provider._runtime_options['readTimeout'], OIDCRoleArnCredentialsProvider.DEFAULT_READ_TIMEOUT) + self.assertIsNone(provider._runtime_options['httpsProxy']) + + def test_init_missing_role_arn(self): + """ + Test case 3: Missing role_arn raises ValueError + """ + with self.assertRaises(ValueError) as context: + OIDCRoleArnCredentialsProvider( + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path + ) + + self.assertIn("role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty", str(context.exception)) + + def test_init_empty_role_arn(self): + """ + Test case 4: Empty role_arn raises ValueError + """ + with self.assertRaises(ValueError) as context: + OIDCRoleArnCredentialsProvider( + role_arn="", + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path + ) + + self.assertIn("role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty", str(context.exception)) + + def test_init_missing_oidc_provider_arn(self): + """ + Test case 5: Missing oidc_provider_arn raises ValueError + """ + with self.assertRaises(ValueError) as context: + OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_token_file_path=self.oidc_token_file_path + ) + + self.assertIn("oidc_provider_arn or environment variable ALIBABA_CLOUD_OIDC_PROVIDER_ARN cannot be empty", + str(context.exception)) + + def test_init_empty_oidc_provider_arn(self): + """ + Test case 6: Empty oidc_provider_arn raises ValueError + """ + with self.assertRaises(ValueError) as context: + OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn="", + oidc_token_file_path=self.oidc_token_file_path + ) + + self.assertIn("oidc_provider_arn or environment variable ALIBABA_CLOUD_OIDC_PROVIDER_ARN cannot be empty", + str(context.exception)) + + def test_init_missing_oidc_token_file_path(self): + """ + Test case 7: Missing oidc_token_file_path raises ValueError + """ + with self.assertRaises(ValueError) as context: + OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn + ) + + self.assertIn("oidc_token_file_path or environment variable ALIBABA_CLOUD_OIDC_TOKEN_FILE cannot be empty", + str(context.exception)) + + def test_init_empty_oidc_token_file_path(self): + """ + Test case 8: Empty oidc_token_file_path raises ValueError + """ + with self.assertRaises(ValueError) as context: + OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path="" + ) + + self.assertIn("oidc_token_file_path or environment variable ALIBABA_CLOUD_OIDC_TOKEN_FILE cannot be empty", + str(context.exception)) + + def test_init_duration_seconds_too_short(self): + """ + Test case 9: Duration seconds less than 900 raises ValueError + """ + with self.assertRaises(ValueError) as context: + OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + duration_seconds=800 + ) + + self.assertIn("session duration should be in the range of 900s - max session duration", str(context.exception)) + + @patch('alibabacloud_credentials.provider.oidc.au') + def test_init_default_values(self, mock_auth_util): + """ + Test case 10: Initializes with default values + """ + mock_auth_util.environment_role_arn = self.role_arn + mock_auth_util.environment_oidc_provider_arn = self.oidc_provider_arn + mock_auth_util.environment_oidc_token_file = self.oidc_token_file_path + mock_auth_util.environment_role_session_name = None + mock_auth_util.environment_enable_vpc = 'false' + mock_auth_util.environment_sts_region = None + + provider = OIDCRoleArnCredentialsProvider() + + self.assertEqual(provider._role_arn, self.role_arn) + self.assertEqual(provider._oidc_provider_arn, self.oidc_provider_arn) + self.assertEqual(provider._oidc_token_file_path, self.oidc_token_file_path) + self.assertTrue(provider._role_session_name.startswith('credentials-python-')) + self.assertEqual(provider._duration_seconds, OIDCRoleArnCredentialsProvider.DEFAULT_DURATION_SECONDS) + self.assertIsNone(provider._policy) + self.assertEqual(provider._sts_endpoint, 'sts.aliyuncs.com') + self.assertEqual(provider._runtime_options['connectTimeout'], + OIDCRoleArnCredentialsProvider.DEFAULT_CONNECT_TIMEOUT) + self.assertEqual(provider._runtime_options['readTimeout'], OIDCRoleArnCredentialsProvider.DEFAULT_READ_TIMEOUT) + self.assertIsNone(provider._runtime_options['httpsProxy']) + + def test_get_credentials_valid_input(self): + """ + Test case 11: Valid input, successfully retrieves credentials + """ + token = "test_token" + response_body = json.dumps({ + "Credentials": { + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2023-12-31T23:59:59Z" + } + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.oidc._get_token', return_value=token): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + credentials = provider._refresh_credentials() + + self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.value().get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.value().get_security_token(), "test_security_token") + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm(time.strptime("2023-12-31T23:59:59Z", '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "oidc_role_arn") + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("No cached value was found.", str(context.exception)) + + def test_get_credentials_file_read_error(self): + """ + Test case 12: File read error raises CredentialException + """ + with patch('alibabacloud_credentials.provider.oidc._get_token', side_effect=FileNotFoundError): + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(FileNotFoundError) as context: + provider.get_credentials() + + def test_get_credentials_http_request_error(self): + """ + Test case 13: HTTP request error raises CredentialException + """ + token = "test_token" + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('alibabacloud_credentials.provider.oidc._get_token', return_value=token): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn( + "error refreshing credentials from oidc_role_arn, http_code: 400, result: HTTP request failed", + str(context.exception)) + + def test_get_credentials_response_format_error(self): + """ + Test case 14: Response format error raises CredentialException + """ + token = "test_token" + response_body = json.dumps({ + "Error": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.oidc._get_token', return_value=token): + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn('error retrieving credentials from oidc_role_arn result: {"Error": "Invalid request"}', + str(context.exception)) + + def test_get_credentials_async_valid_input(self): + """ + Test case 15: Valid input, successfully retrieves credentials asynchronously + """ + token = "test_token" + response_body = json.dumps({ + "Credentials": { + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2023-12-31T23:59:59Z" + } + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.oidc._get_token_async', AsyncMock(return_value=token)): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._refresh_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.value().get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.value().get_security_token(), "test_security_token") + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm(time.strptime("2023-12-31T23:59:59Z", '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "oidc_role_arn") + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn("No cached value was found.", str(context.exception)) + + def test_get_credentials_async_file_read_error(self): + """ + Test case 16: File read error raises CredentialException asynchronously + """ + with patch('alibabacloud_credentials.provider.oidc._get_token_async', AsyncMock(side_effect=FileNotFoundError)): + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(FileNotFoundError) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + def test_get_credentials_async_http_request_error(self): + """ + Test case 17: HTTP request error raises CredentialException asynchronously + """ + token = "test_token" + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('alibabacloud_credentials.provider.oidc._get_token_async', AsyncMock(return_value=token)): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn( + "error refreshing credentials from oidc_role_arn, http_code: 400, result: HTTP request failed", + str(context.exception)) + + def test_get_credentials_async_response_format_error(self): + """ + Test case 18: Response format error raises CredentialException asynchronously + """ + token = "test_token" + response_body = json.dumps({ + "Error": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('alibabacloud_credentials.provider.oidc._get_token_async', AsyncMock(return_value=token)): + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn('error retrieving credentials from oidc_role_arn result: {"Error": "Invalid request"}', + str(context.exception)) + + @patch('alibabacloud_credentials.provider.oidc.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.oidc.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_sts_region_id_and_enable_vpc_true(self): + """ + Test case 19: sts_region_id is provided and enable_vpc is True + """ + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_region_id=self.sts_region_id, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts-vpc.{self.sts_region_id}.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.oidc.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.oidc.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_sts_region_id_and_enable_vpc_false(self): + """ + Test case 20: sts_region_id is provided and enable_vpc is False + """ + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_region_id=self.sts_region_id, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts.{self.sts_region_id}.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.oidc.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.oidc.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_environment_sts_region_and_enable_vpc_true(self): + """ + Test case 21: sts_region_id is not provided, environment_sts_region is provided, and enable_vpc is True + """ + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts-vpc.test_env_sts_region.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.oidc.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.oidc.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_environment_sts_region_and_enable_vpc_false(self): + """ + Test case 22: sts_region_id is not provided, environment_sts_region is provided, and enable_vpc is False + """ + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts.test_env_sts_region.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.oidc.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.oidc.au.environment_sts_region', None) + def test_sts_endpoint_with_no_sts_region_id_or_environment_sts_region_and_enable_vpc_true(self): + """ + Test case 23: sts_region_id and environment_sts_region are not provided, and enable_vpc is True + """ + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, 'sts.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.oidc.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.oidc.au.environment_sts_region', None) + def test_sts_endpoint_with_no_sts_region_id_or_environment_sts_region_and_enable_vpc_false(self): + """ + Test case 24: sts_region_id and environment_sts_region are not provided, and enable_vpc is False + """ + provider = OIDCRoleArnCredentialsProvider( + role_arn=self.role_arn, + oidc_provider_arn=self.oidc_provider_arn, + oidc_token_file_path=self.oidc_token_file_path, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, 'sts.aliyuncs.com') diff --git a/tests/provider/test_profile.py b/tests/provider/test_profile.py new file mode 100644 index 0000000..d9e9d33 --- /dev/null +++ b/tests/provider/test_profile.py @@ -0,0 +1,374 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import os +from alibabacloud_credentials.provider.profile import ( + ProfileCredentialsProvider, + CredentialException +) +from alibabacloud_credentials.utils import auth_util as au +from alibabacloud_credentials.utils import auth_constant as ac + + +class TestProfileCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.profile_name = "default" + self.profile_file = os.path.join(ac.HOME, "/.alibabacloud/credentials.ini") + self.config = {'default': { + 'type': 'access_key', + 'access_key_id': 'test_access_key_id', + 'access_key_secret': 'test_access_key_secret' + }, 'ram_role_profile': { + 'type': 'ram_role_arn', + 'access_key_id': 'test_access_key_id', + 'access_key_secret': 'test_access_key_secret', + 'role_arn': 'test_ram_role_arn', + 'role_session_name': 'test_ram_session_name', + 'policy': 'test_policy' + }, 'oidc_profile': { + 'type': 'oidc_role_arn', + 'role_arn': 'test_ram_role_arn', + 'oidc_provider_arn': 'test_oidc_provider_arn', + 'oidc_token_file_path': 'test_oidc_token_file_path', + 'role_session_name': 'test_role_session_name', + 'policy': 'test_policy' + }, 'ecs_ram_role_profile': { + 'type': 'ecs_ram_role', + 'role_name': 'test_ram_role_name' + }, 'rsa_key_pair_profile': { + 'type': 'rsa_key_pair', + 'public_key_id': 'test_public_key_id', + 'private_key_file': 'test_private_key_file' + }} + self.access_key_id = "test_access_key_id" + self.access_key_secret = "test_access_key_secret" + self.security_token = "test_security_token" + self.expiration = "2023-12-31T23:59:59Z" + self.response_body = { + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token, + "Expiration": self.expiration + } + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided parameters + """ + with patch('alibabacloud_credentials.provider.profile.au.environment_credentials_file', self.profile_file): + provider = ProfileCredentialsProvider() + + self.assertEqual(provider._profile_file, self.profile_file) + self.assertEqual(provider._profile_name, au.client_type) + + def test_get_credentials_valid_access_key(self): + """ + Test case 2: Valid input, successfully retrieves credentials for access_key type + """ + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=self.config): + provider = ProfileCredentialsProvider(profile_name=self.profile_name) + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertIsNone(credentials.get_security_token()) + self.assertEqual(credentials.get_provider_name(), "profile/static_ak") + + def test_get_credentials_valid_ram_role_arn(self): + """ + Test case 3: Valid input, successfully retrieves credentials for ram_role_arn type + """ + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=self.config): + provider = ProfileCredentialsProvider(profile_name="ram_role_profile") + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error refreshing credentials from ram_role_arn", str(context.exception)) + + def test_get_credentials_valid_oidc_role_arn(self): + """ + Test case 4: Valid input, successfully retrieves credentials for oidc_role_arn type + """ + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=self.config): + with patch('alibabacloud_credentials.provider.oidc._get_token', return_value='test_token'): + provider = ProfileCredentialsProvider(profile_name="oidc_profile") + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error refreshing credentials from oidc_role_arn", str(context.exception)) + + def test_get_credentials_valid_ecs_ram_role(self): + """ + Test case 5: Valid input, successfully retrieves credentials for ecs_ram_role type + """ + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=self.config): + provider = ProfileCredentialsProvider(profile_name="ecs_ram_role_profile") + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("Failed to get RAM session credentials from ECS metadata service", + str(context.exception)) + + def test_get_credentials_valid_rsa_key_pair(self): + """ + Test case 6: Valid input, successfully retrieves credentials for rsa_key_pair type + """ + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=self.config): + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value='test_content'): + provider = ProfileCredentialsProvider(profile_name="rsa_key_pair_profile") + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("error refreshing credentials from rsa_key_pair", str(context.exception)) + + def test_get_credentials_profile_file_not_exists(self): + """ + Test case 7: Profile file does not exist raises CredentialException + """ + with patch('os.path.exists', return_value=False): + provider = ProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'failed to get credential from credentials file: ${self.profile_file}', + str(context.exception)) + + def test_get_credentials_profile_file_not_file(self): + """ + Test case 8: Profile file is not a file raises CredentialException + """ + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=False): + provider = ProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'failed to get credential from credentials file: ${self.profile_file}', + str(context.exception)) + + def test_get_credentials_invalid_config_type(self): + """ + Test case 9: Invalid config type raises CredentialException + """ + invalid_config = {'default': { + 'type': 'invalid_type', + 'access_key_id': 'test_access_key_id', + 'access_key_secret': 'test_access_key_secret' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=invalid_config): + provider = ProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'unsupported credential type invalid_type from credentials file {self.profile_file}', + str(context.exception)) + + def test_get_credentials_missing_access_key_id(self): + """ + Test case 10: Missing access_key_id raises CredentialException + """ + missing_access_key_id_config = {'default': { + 'type': 'access_key', + 'access_key_secret': 'test_access_key_secret' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', + return_value=missing_access_key_id_config): + provider = ProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(ValueError) as context: + provider.get_credentials() + + self.assertIn('the access key id is empty', str(context.exception)) + + def test_get_credentials_missing_access_key_secret(self): + """ + Test case 11: Missing access_key_secret raises CredentialException + """ + missing_access_key_secret_config = {'default': { + 'type': 'access_key', + 'access_key_id': 'test_access_key_id' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', + return_value=missing_access_key_secret_config): + provider = ProfileCredentialsProvider(profile_name=self.profile_name) + + with self.assertRaises(ValueError) as context: + provider.get_credentials() + + self.assertIn('the access key secret is empty', + str(context.exception)) + + def test_get_credentials_missing_role_arn(self): + """ + Test case 12: Missing role_arn raises CredentialException + """ + missing_role_arn_config = {'ram_role_profile': { + 'type': 'ram_role_arn', + 'access_key_id': 'test_access_key_id', + 'access_key_secret': 'test_access_key_secret', + 'role_session_name': 'test_ram_session_name', + 'policy': 'test_policy' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', return_value=missing_role_arn_config): + provider = ProfileCredentialsProvider(profile_name="ram_role_profile") + + with self.assertRaises(ValueError) as context: + provider.get_credentials() + + self.assertIn('role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty', + str(context.exception)) + + def test_get_credentials_missing_oidc_provider_arn(self): + """ + Test case 13: Missing oidc_provider_arn raises CredentialException + """ + missing_oidc_provider_arn_config = {'oidc_profile': { + 'type': 'oidc_role_arn', + 'role_arn': 'test_ram_role_arn', + 'oidc_token_file_path': 'test_oidc_token_file_path', + 'role_session_name': 'test_role_session_name', + 'policy': 'test_policy' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', + return_value=missing_oidc_provider_arn_config): + provider = ProfileCredentialsProvider(profile_name="oidc_profile") + + with self.assertRaises(ValueError) as context: + provider.get_credentials() + + self.assertIn( + 'oidc_provider_arn or environment variable ALIBABA_CLOUD_OIDC_PROVIDER_ARN cannot be empty', + str(context.exception)) + + def test_get_credentials_missing_oidc_token_file_path(self): + """ + Test case 14: Missing oidc_token_file_path raises CredentialException + """ + missing_oidc_token_file_path_config = {'oidc_profile': { + 'type': 'oidc_role_arn', + 'role_arn': 'test_ram_role_arn', + 'oidc_provider_arn': 'test_oidc_provider_arn', + 'role_session_name': 'test_role_session_name', + 'policy': 'test_policy' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', + return_value=missing_oidc_token_file_path_config): + provider = ProfileCredentialsProvider(profile_name="oidc_profile") + + with self.assertRaises(ValueError) as context: + provider.get_credentials() + + self.assertIn( + 'oidc_token_file_path or environment variable ALIBABA_CLOUD_OIDC_TOKEN_FILE cannot be empty', + str(context.exception)) + + def test_get_credentials_missing_role_name(self): + """ + Test case 15: Missing role_name raises CredentialException + """ + missing_role_name_config = {'ecs_ram_role_profile': { + 'type': 'ecs_ram_role' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', + return_value=missing_role_name_config): + provider = ProfileCredentialsProvider(profile_name="ecs_ram_role_profile") + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn('Failed to get RAM session credentials from ECS metadata service', + str(context.exception)) + + def test_get_credentials_missing_public_key_id(self): + """ + Test case 16: Missing public_key_id raises CredentialException + """ + missing_public_key_id_config = {'rsa_key_pair_profile': { + 'type': 'rsa_key_pair', + 'private_key_file': 'test_private_key_file' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', + return_value=missing_public_key_id_config): + provider = ProfileCredentialsProvider(profile_name="rsa_key_pair_profile") + + with self.assertRaises(ValueError) as context: + provider.get_credentials() + + self.assertIn('public_key_id cannot be empty', + str(context.exception)) + + def test_get_credentials_missing_private_key_file(self): + """ + Test case 17: Missing private_key_file raises CredentialException + """ + missing_private_key_file_config = {'rsa_key_pair_profile': { + 'type': 'rsa_key_pair', + 'public_key_id': 'test_public_key_id' + }} + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini', + return_value=missing_private_key_file_config): + provider = ProfileCredentialsProvider(profile_name="rsa_key_pair_profile") + + with self.assertRaises(ValueError) as context: + provider.get_credentials() + + self.assertIn('private_key_file cannot be empty', + str(context.exception)) + + def test_get_credentials_async_valid_access_key(self): + """ + Test case 18: Valid input, successfully retrieves credentials for access_key type asynchronously + """ + with patch('os.path.exists', return_value=True): + with patch('os.path.isfile', return_value=True): + with patch('alibabacloud_credentials.provider.profile._load_ini_async', + AsyncMock(return_value=self.config)): + provider = ProfileCredentialsProvider(profile_name=self.profile_name) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.get_access_key_secret(), self.access_key_secret) + self.assertIsNone(credentials.get_security_token()) diff --git a/tests/provider/test_ram_role_arn.py b/tests/provider/test_ram_role_arn.py new file mode 100644 index 0000000..90117c1 --- /dev/null +++ b/tests/provider/test_ram_role_arn.py @@ -0,0 +1,513 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import calendar +import time +import json +from alibabacloud_credentials.provider.ram_role_arn import ( + RamRoleArnCredentialsProvider, + CredentialException +) +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaResponse + + +class TestRamRoleArnCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.access_key_id = "test_access_key_id" + self.access_key_secret = "test_access_key_secret" + self.security_token = "test_security_token" + self.role_arn = "test_role_arn" + self.role_session_name = "test_role_session_name" + self.duration_seconds = 3600 + self.policy = "test_policy" + self.external_id = "test_external_id" + self.sts_region_id = "test_sts_region_id" + self.sts_endpoint = "test_sts_endpoint" + self.enable_vpc = True + self.http_options = HttpOptions(connect_timeout=5000, read_timeout=10000, proxy="test_proxy") + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided parameters + """ + provider = RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn=self.role_arn, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + external_id=self.external_id, + sts_region_id=self.sts_region_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + self.assertEqual(provider._credentials_provider.get_credentials().get_access_key_id(), self.access_key_id) + self.assertEqual(provider._credentials_provider.get_credentials().get_access_key_secret(), + self.access_key_secret) + self.assertEqual(provider._credentials_provider.get_credentials().get_security_token(), self.security_token) + self.assertEqual(provider._role_arn, self.role_arn) + self.assertEqual(provider._role_session_name, self.role_session_name) + self.assertEqual(provider._duration_seconds, self.duration_seconds) + self.assertEqual(provider._policy, self.policy) + self.assertEqual(provider._external_id, self.external_id) + self.assertEqual(provider._sts_endpoint, self.sts_endpoint) + self.assertEqual(provider._http_options, self.http_options) + self.assertEqual(provider._runtime_options['connectTimeout'], self.http_options.connect_timeout) + self.assertEqual(provider._runtime_options['readTimeout'], self.http_options.read_timeout) + self.assertEqual(provider._runtime_options['httpsProxy'], self.http_options.proxy) + + @patch('alibabacloud_credentials.provider.static_ak.auth_util') + @patch('alibabacloud_credentials.provider.ram_role_arn.au') + def test_init_valid_environment_variables(self, mock_ram_util, mock_ak_util): + """ + Test case 2: Valid input, successfully initializes with environment variables + """ + mock_ak_util.environment_access_key_id = self.access_key_id + mock_ak_util.environment_access_key_secret = self.access_key_secret + mock_ram_util.environment_role_arn = self.role_arn + mock_ram_util.environment_role_session_name = self.role_session_name + mock_ram_util.environment_enable_vpc = str(self.enable_vpc) + mock_ram_util.environment_sts_region = self.sts_region_id + + provider = RamRoleArnCredentialsProvider() + + self.assertEqual(provider._credentials_provider.get_credentials().get_access_key_id(), self.access_key_id) + self.assertEqual(provider._credentials_provider.get_credentials().get_access_key_secret(), + self.access_key_secret) + self.assertEqual(provider._role_arn, self.role_arn) + self.assertEqual(provider._role_session_name, self.role_session_name) + self.assertEqual(provider._duration_seconds, RamRoleArnCredentialsProvider.DEFAULT_DURATION_SECONDS) + self.assertIsNone(provider._policy) + self.assertIsNone(provider._external_id) + self.assertEqual(provider._sts_endpoint, f'sts-vpc.{self.sts_region_id}.aliyuncs.com') + self.assertEqual(provider._runtime_options['connectTimeout'], + RamRoleArnCredentialsProvider.DEFAULT_CONNECT_TIMEOUT) + self.assertEqual(provider._runtime_options['readTimeout'], RamRoleArnCredentialsProvider.DEFAULT_READ_TIMEOUT) + self.assertIsNone(provider._runtime_options['httpsProxy']) + + def test_init_missing_role_arn(self): + """ + Test case 3: Missing role_arn raises ValueError + """ + with self.assertRaises(ValueError) as context: + RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token + ) + + self.assertIn("role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty", str(context.exception)) + + def test_init_empty_role_arn(self): + """ + Test case 4: Empty role_arn raises ValueError + """ + with self.assertRaises(ValueError) as context: + RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn="" + ) + + self.assertIn("role_arn or environment variable ALIBABA_CLOUD_ROLE_ARN cannot be empty", str(context.exception)) + + def test_init_duration_seconds_too_short(self): + """ + Test case 5: Duration seconds less than 900 raises ValueError + """ + with self.assertRaises(ValueError) as context: + RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn=self.role_arn, + duration_seconds=800 + ) + + self.assertIn("session duration should be in the range of 900s - max session duration", str(context.exception)) + + @patch('alibabacloud_credentials.provider.static_ak.auth_util') + @patch('alibabacloud_credentials.provider.ram_role_arn.au') + def test_init_default_values(self, mock_ram_util, mock_ak_util): + """ + Test case 6: Initializes with default values + """ + mock_ak_util.environment_access_key_id = self.access_key_id + mock_ak_util.environment_access_key_secret = self.access_key_secret + mock_ram_util.environment_role_arn = self.role_arn + mock_ram_util.environment_role_session_name = None + mock_ram_util.environment_enable_vpc = 'false' + mock_ram_util.environment_sts_region = None + + provider = RamRoleArnCredentialsProvider() + + self.assertEqual(provider._credentials_provider.get_credentials().get_access_key_id(), self.access_key_id) + self.assertEqual(provider._credentials_provider.get_credentials().get_access_key_secret(), + self.access_key_secret) + self.assertEqual(provider._role_arn, self.role_arn) + self.assertTrue(provider._role_session_name.startswith('credentials-python-')) + self.assertEqual(provider._duration_seconds, RamRoleArnCredentialsProvider.DEFAULT_DURATION_SECONDS) + self.assertIsNone(provider._policy) + self.assertIsNone(provider._external_id) + self.assertEqual(provider._sts_endpoint, 'sts.aliyuncs.com') + self.assertEqual(provider._runtime_options['connectTimeout'], + RamRoleArnCredentialsProvider.DEFAULT_CONNECT_TIMEOUT) + self.assertEqual(provider._runtime_options['readTimeout'], RamRoleArnCredentialsProvider.DEFAULT_READ_TIMEOUT) + self.assertIsNone(provider._runtime_options['httpsProxy']) + + def test_get_credentials_valid_input(self): + """ + Test case 7: Valid input, successfully retrieves credentials + """ + response_body = json.dumps({ + "Credentials": { + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2023-12-31T23:59:59Z" + } + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn=self.role_arn, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + external_id=self.external_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + credentials = provider._refresh_credentials() + + self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.value().get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.value().get_security_token(), "test_security_token") + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm(time.strptime("2023-12-31T23:59:59Z", '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "ram_role_arn/static_sts") + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn("No cached value was found.", str(context.exception)) + + def test_get_credentials_http_request_error(self): + """ + Test case 8: HTTP request error raises CredentialException + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn=self.role_arn, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + external_id=self.external_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn( + "error refreshing credentials from ram_role_arn, http_code: 400, result: HTTP request failed", + str(context.exception)) + + def test_get_credentials_response_format_error(self): + """ + Test case 9: Response format error raises CredentialException + """ + response_body = json.dumps({ + "Error": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn=self.role_arn, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + external_id=self.external_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn( + 'error retrieving credentials from ram_role_arn result: {"Error": "Invalid request"}', + str(context.exception)) + + def test_get_credentials_async_valid_input(self): + """ + Test case 10: Valid input, successfully retrieves credentials asynchronously + """ + response_body = json.dumps({ + "Credentials": { + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2023-12-31T23:59:59Z" + } + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn=self.role_arn, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + external_id=self.external_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._refresh_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.value().get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.value().get_security_token(), "test_security_token") + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm( + time.strptime("2023-12-31T23:59:59Z", '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "ram_role_arn/static_sts") + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn("No cached value was found.", str(context.exception)) + + def test_get_credentials_async_http_request_error(self): + """ + Test case 11: HTTP request error raises CredentialException asynchronously + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn=self.role_arn, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + external_id=self.external_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn( + "error refreshing credentials from ram_role_arn, http_code: 400, result: HTTP request failed", + str(context.exception)) + + def test_get_credentials_async_response_format_error(self): + """ + Test case 12: Response format error raises CredentialException asynchronously + """ + response_body = json.dumps({ + "Error": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = RamRoleArnCredentialsProvider( + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + security_token=self.security_token, + role_arn=self.role_arn, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + external_id=self.external_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn( + 'error retrieving credentials from ram_role_arn result: {"Error": "Invalid request"}', + str(context.exception)) + + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_sts_region_id_and_enable_vpc_true(self): + """ + Test case 13: sts_region_id is provided and enable_vpc is True + """ + provider = RamRoleArnCredentialsProvider( + role_arn=self.role_arn, + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_region_id=self.sts_region_id, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts-vpc.{self.sts_region_id}.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_sts_region_id_and_enable_vpc_false(self): + """ + Test case 14: sts_region_id is provided and enable_vpc is False + """ + provider = RamRoleArnCredentialsProvider( + role_arn=self.role_arn, + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + sts_region_id=self.sts_region_id, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts.{self.sts_region_id}.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_environment_sts_region_and_enable_vpc_true(self): + """ + Test case 15: sts_region_id is not provided, environment_sts_region is provided, and enable_vpc is True + """ + provider = RamRoleArnCredentialsProvider( + role_arn=self.role_arn, + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts-vpc.test_env_sts_region.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_environment_sts_region_and_enable_vpc_false(self): + """ + Test case 16: sts_region_id is not provided, environment_sts_region is provided, and enable_vpc is False + """ + provider = RamRoleArnCredentialsProvider( + role_arn=self.role_arn, + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts.test_env_sts_region.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_sts_region', None) + def test_sts_endpoint_with_no_sts_region_id_or_environment_sts_region_and_enable_vpc_true(self): + """ + Test case 17: sts_region_id and environment_sts_region are not provided, and enable_vpc is True + """ + provider = RamRoleArnCredentialsProvider( + role_arn=self.role_arn, + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, 'sts.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.ram_role_arn.au.environment_sts_region', None) + def test_sts_endpoint_with_no_sts_region_id_or_environment_sts_region_and_enable_vpc_false(self): + """ + Test case 18: sts_region_id and environment_sts_region are not provided, and enable_vpc is False + """ + provider = RamRoleArnCredentialsProvider( + role_arn=self.role_arn, + access_key_id=self.access_key_id, + access_key_secret=self.access_key_secret, + role_session_name=self.role_session_name, + duration_seconds=self.duration_seconds, + policy=self.policy, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, 'sts.aliyuncs.com') diff --git a/tests/provider/test_refreshable.py b/tests/provider/test_refreshable.py new file mode 100644 index 0000000..7370fed --- /dev/null +++ b/tests/provider/test_refreshable.py @@ -0,0 +1,465 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import threading +import time +from datetime import datetime +from alibabacloud_credentials.provider.refreshable import ( + Credentials, + RefreshResult, + NonBlocking, + OneCallerBlocks, + RefreshCachedSupplier, + StaleValueBehavior, + CredentialException +) + + +class TestCredentials(unittest.TestCase): + + def test_credentials_initialization(self): + """ + Test case 1: Test initialization of Credentials class + """ + cred = Credentials( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret", + security_token="test_security_token", + expiration=1672531199, + provider_name="test_provider" + ) + + self.assertEqual(cred.get_access_key_id(), "test_access_key_id") + self.assertEqual(cred.get_access_key_secret(), "test_access_key_secret") + self.assertEqual(cred.get_security_token(), "test_security_token") + self.assertEqual(cred.get_expiration(), 1672531199) + self.assertEqual(cred.get_provider_name(), "test_provider") + + +class TestRefreshResult(unittest.TestCase): + + def test_refresh_result_initialization(self): + """ + Test case 2: Test initialization of RefreshResult class + """ + value = Credentials( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret", + security_token="test_security_token", + expiration=1672531199, + provider_name="test_provider" + ) + refresh_result = RefreshResult( + value=value, + stale_time=1672531199 + 900, + prefetch_time=1672531199 + 1800 + ) + + self.assertEqual(refresh_result.value().get_access_key_id(), "test_access_key_id") + self.assertEqual(refresh_result.stale_time(), 1672531199 + 900) + self.assertEqual(refresh_result.prefetch_time(), 1672531199 + 1800) + + +class TestNonBlocking(unittest.TestCase): + + def setUp(self): + self.non_blocking = NonBlocking() + + @patch('alibabacloud_credentials.provider.refreshable.EXECUTOR.submit') + @patch('alibabacloud_credentials.provider.refreshable.CONCURRENT_REFRESH_LEASES.acquire') + def test_prefetch_success(self, mock_acquire, mock_submit): + """ + Test case 3: Test prefetch success in NonBlocking class + """ + mock_acquire.return_value = True + action = MagicMock() + + self.non_blocking.prefetch(action) + + mock_acquire.assert_called_once() + mock_submit.assert_called_once_with(action) + + @patch('alibabacloud_credentials.provider.refreshable.EXECUTOR.submit') + @patch('alibabacloud_credentials.provider.refreshable.CONCURRENT_REFRESH_LEASES.acquire') + def test_prefetch_failure(self, mock_acquire, mock_submit): + """ + Test case 4: Test prefetch failure in NonBlocking class + """ + mock_acquire.return_value = False + action = MagicMock() + + self.non_blocking.prefetch(action) + + mock_acquire.assert_called_once() + mock_submit.assert_not_called() + + @patch('alibabacloud_credentials.provider.refreshable.EXECUTOR.submit') + @patch('alibabacloud_credentials.provider.refreshable.CONCURRENT_REFRESH_LEASES.acquire') + def test_prefetch_exception(self, mock_acquire, mock_submit): + """ + Test case 5: Test prefetch exception in NonBlocking class + """ + mock_acquire.return_value = True + mock_submit.side_effect = Exception("Test exception") + action = MagicMock() + + self.non_blocking.prefetch(action) + + mock_acquire.assert_called_once() + mock_submit.assert_called_once_with(action) + + @patch('alibabacloud_credentials.provider.refreshable.NonBlocking.prefetch') + def test_prefetch_async(self, mock_prefetch): + """ + Test case 6: Test prefetch_async in NonBlocking class + """ + + action = AsyncMock() + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + self.non_blocking.prefetch_async(action) + ) + loop.run_until_complete(task) + + mock_prefetch.assert_called_once() + + +class TestOneCallerBlocks(unittest.TestCase): + + def setUp(self): + self.one_caller_blocks = OneCallerBlocks() + + def test_prefetch(self): + """ + Test case 7: Test prefetch in OneCallerBlocks class + """ + action = MagicMock() + + self.one_caller_blocks.prefetch(action) + + action.assert_called_once() + + @patch('alibabacloud_credentials.provider.refreshable.OneCallerBlocks.prefetch') + def test_prefetch_async(self, mock_prefetch): + """ + Test case 8: Test prefetch_async in OneCallerBlocks class + """ + action = AsyncMock() + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + self.one_caller_blocks.prefetch_async(action) + ) + loop.run_until_complete(task) + + action.assert_called_once() + + +class TestRefreshCachedSupplier(unittest.TestCase): + + def setUp(self): + self.refresh_callable = MagicMock() + self.refresh_callable_async = AsyncMock() + self.refresh_result = RefreshResult( + value=Credentials( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret", + security_token="test_security_token", + expiration=int(time.mktime(time.localtime())) + 3600, + provider_name="test_provider" + ), + stale_time=int(time.mktime(time.localtime())) + 1800, + prefetch_time=int(time.mktime(time.localtime())) + 3600 + ) + self.refresh_cached_supplier = RefreshCachedSupplier( + refresh_callable=self.refresh_callable, + refresh_callable_async=self.refresh_callable_async, + stale_value_behavior=StaleValueBehavior.STRICT, + prefetch_strategy=OneCallerBlocks() + ) + + def test_sync_call_cache_not_stale(self): + """ + Test case 9: Test sync_call when cache is not stale + """ + self.refresh_cached_supplier._cached_value = self.refresh_result + + result = self.refresh_cached_supplier._sync_call() + + self.assertEqual(result.get_access_key_id(), "test_access_key_id") + self.refresh_callable.assert_not_called() + + def test_sync_call_cache_stale(self): + """ + Test case 10: Test sync_call when cache is stale + """ + self.refresh_cached_supplier._cached_value = self.refresh_result + self.refresh_cached_supplier._cached_value._stale_time = int(time.mktime(time.localtime())) - 1800 + self.refresh_callable.return_value = self.refresh_result + + result = self.refresh_cached_supplier._sync_call() + + self.assertEqual(result.get_access_key_id(), "test_access_key_id") + self.refresh_callable.assert_called_once() + + def test_async_call_cache_not_stale(self): + """ + Test case 11: Test async_call when cache is not stale + """ + self.refresh_cached_supplier._cached_value = self.refresh_result + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + self.refresh_cached_supplier._async_call() + ) + loop.run_until_complete(task) + result = task.result() + + self.assertEqual(result.get_access_key_id(), "test_access_key_id") + self.refresh_callable_async.assert_not_called() + + @patch('alibabacloud_credentials.provider.refreshable.RefreshCachedSupplier._refresh_cache_async') + def test_async_call_cache_stale(self, mock_refresh_cache_async): + """ + Test case 12: Test async_call when cache is stale + """ + self.refresh_cached_supplier._cached_value = self.refresh_result + self.refresh_cached_supplier._cached_value._stale_time = int(time.mktime(time.localtime())) - 1800 + mock_refresh_cache_async.return_value = self.refresh_result + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + self.refresh_cached_supplier._async_call() + ) + loop.run_until_complete(task) + result = task.result() + + self.assertEqual(result.get_access_key_id(), "test_access_key_id") + mock_refresh_cache_async.assert_called_once() + + def test_cache_is_stale(self): + """ + Test case 13: Test cache_is_stale method + """ + self.refresh_cached_supplier._cached_value = None + self.assertTrue(self.refresh_cached_supplier._cache_is_stale()) + + self.refresh_cached_supplier._cached_value = self.refresh_result + self.refresh_cached_supplier._cached_value._stale_time = int(time.mktime(time.localtime())) + 1800 + self.assertFalse(self.refresh_cached_supplier._cache_is_stale()) + + self.refresh_cached_supplier._cached_value._stale_time = int(time.mktime(time.localtime())) - 1800 + self.assertTrue(self.refresh_cached_supplier._cache_is_stale()) + + def test_should_initiate_cache_prefetch(self): + """ + Test case 14: Test should_initiate_cache_prefetch method + """ + self.refresh_cached_supplier._cached_value = None + self.assertTrue(self.refresh_cached_supplier._should_initiate_cache_prefetch()) + + self.refresh_cached_supplier._cached_value = self.refresh_result + self.refresh_cached_supplier._cached_value._prefetch_time = int(time.mktime(time.localtime())) + 3600 + self.assertFalse(self.refresh_cached_supplier._should_initiate_cache_prefetch()) + + self.refresh_cached_supplier._cached_value._prefetch_time = int(time.mktime(time.localtime())) - 3600 + self.assertTrue(self.refresh_cached_supplier._should_initiate_cache_prefetch()) + + def test_prefetch_cache(self): + """ + Test case 15: Test prefetch_cache method + """ + self.refresh_cached_supplier._prefetch_strategy.prefetch = MagicMock() + + self.refresh_cached_supplier._prefetch_cache() + + self.refresh_cached_supplier._prefetch_strategy.prefetch.assert_called_once_with( + self.refresh_cached_supplier._refresh_cache) + + @patch('alibabacloud_credentials.provider.refreshable.RefreshCachedSupplier._refresh_cache_async') + def test_prefetch_cache_async(self, mock_refresh_cache): + """ + Test case 16: Test prefetch_cache_async method + """ + self.refresh_cached_supplier._prefetch_strategy.prefetch_async = AsyncMock() + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + self.refresh_cached_supplier._prefetch_cache_async() + ) + loop.run_until_complete(task) + + self.refresh_cached_supplier._prefetch_strategy.prefetch_async.assert_called_once_with(mock_refresh_cache) + + @patch('alibabacloud_credentials.provider.refreshable.RefreshCachedSupplier._handle_fetched_success') + def test_refresh_cache_success(self, mock_handle_fetched_success): + """ + Test case 17: Test refresh_cache method on success + """ + self.refresh_callable.return_value = self.refresh_result + mock_handle_fetched_success.return_value = self.refresh_result + + self.refresh_cached_supplier._refresh_cache() + + self.refresh_callable.assert_called_once() + mock_handle_fetched_success.assert_called_once_with(self.refresh_result) + self.refresh_cached_supplier._refresh_lock.release.assert_called_once() + + @patch('alibabacloud_credentials.provider.refreshable.RefreshCachedSupplier._handle_fetched_failure') + def test_refresh_cache_failure(self, mock_handle_fetched_failure): + """ + Test case 18: Test refresh_cache method on failure + """ + self.refresh_callable.side_effect = Exception("Test exception") + mock_handle_fetched_failure.return_value = self.refresh_result + + self.refresh_cached_supplier._refresh_cache() + + self.refresh_callable.assert_called_once() + mock_handle_fetched_failure.assert_called_once() + + @patch('alibabacloud_credentials.provider.refreshable.RefreshCachedSupplier._handle_fetched_success') + async def test_refresh_cache_async_success(self, mock_handle_fetched_success): + """ + Test case 19: Test refresh_cache_async method on success + """ + self.refresh_callable_async.return_value = self.refresh_result + mock_handle_fetched_success.return_value = self.refresh_result + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + self.refresh_cached_supplier._prefetch_cache_async() + ) + loop.run_until_complete(task) + + self.refresh_callable_async.assert_called_once() + mock_handle_fetched_success.assert_called_once_with(self.refresh_result) + + @patch('alibabacloud_credentials.provider.refreshable.RefreshCachedSupplier._handle_fetched_failure') + async def test_refresh_cache_async_failure(self, mock_handle_fetched_failure): + """ + Test case 20: Test refresh_cache_async method on failure + """ + self.refresh_callable_async.side_effect = Exception("Test exception") + mock_handle_fetched_failure.return_value = self.refresh_result + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + self.refresh_cached_supplier._refresh_cache_async() + ) + loop.run_until_complete(task) + + self.refresh_callable_async.assert_called_once() + mock_handle_fetched_failure.assert_called_once_with(Exception("Test exception")) + self.refresh_cached_supplier._refresh_lock.release.assert_called_once() + + def test_handle_fetched_success(self): + """ + Test case 21: Test handle_fetched_success method + """ + now = int(time.mktime(time.localtime())) + self.refresh_result._stale_time = now + 1800 + + result = self.refresh_cached_supplier._handle_fetched_success(self.refresh_result) + + self.assertEqual(result.value().get_access_key_id(), "test_access_key_id") + self.assertEqual(result.stale_time(), now + 1800) + self.assertEqual(result.prefetch_time(), now + 3600) + + def test_handle_fetched_success_stale_time_in_past(self): + """ + Test case 22: Test handle_fetched_success method when stale time is in the past + """ + now = int(time.mktime(time.localtime())) + self.refresh_result._stale_time = now - 1800 + self.refresh_cached_supplier_stale_value_behavior = StaleValueBehavior.ALLOW + + with self.assertRaises(CredentialException) as context: + self.refresh_cached_supplier._handle_fetched_success(self.refresh_result) + + self.assertIn("No cached value was found.", str(context.exception)) + + def test_handle_fetched_success_expired(self): + """ + Test case 23: Test handle_fetched_success method when credential is expired + """ + now = int(time.mktime(time.localtime())) + self.refresh_result._stale_time = now - 1800 + self.refresh_cached_supplier._cached_value = self.refresh_result + + result = self.refresh_cached_supplier._handle_fetched_success(self.refresh_result) + + self.assertEqual(result.value().get_access_key_id(), "test_access_key_id") + self.assertGreaterEqual(result.stale_time(), now + 1) + self.assertGreaterEqual(result.prefetch_time(), now + 3600) + + def test_handle_fetched_success_expired_allow_stale(self): + """ + Test case 24: Test handle_fetched_success method when credential is expired and stale value behavior is ALLOW + """ + now = int(time.mktime(time.localtime())) + self.refresh_result._stale_time = now - 1800 + self.refresh_cached_supplier._cached_value = self.refresh_result + self.refresh_cached_supplier._stale_value_behavior = StaleValueBehavior.ALLOW + + result = self.refresh_cached_supplier._handle_fetched_success(self.refresh_result) + + self.assertEqual(result.value().get_access_key_id(), "test_access_key_id") + self.assertGreaterEqual(result.stale_time(), now) + self.assertGreaterEqual(result.prefetch_time(), now + 3600) + + def test_handle_fetched_failure(self): + """ + Test case 25: Test handle_fetched_failure method + """ + now = int(time.mktime(time.localtime())) + self.refresh_cached_supplier._cached_value = self.refresh_result + + result = self.refresh_cached_supplier._handle_fetched_failure(Exception("Test exception")) + + self.assertEqual(result.value().get_access_key_id(), "test_access_key_id") + self.assertGreaterEqual(result.stale_time(), now + 1) + self.assertGreaterEqual(result.prefetch_time(), now + 3600) + + def test_handle_fetched_failure_no_cached_value(self): + """ + Test case 26: Test handle_fetched_failure method when no cached value is available + """ + self.refresh_cached_supplier._cached_value = None + self.refresh_cached_supplier_stale_value_behavior = StaleValueBehavior.ALLOW + + with self.assertRaises(CredentialException) as context: + self.refresh_cached_supplier._handle_fetched_failure(CredentialException("Test exception")) + + self.assertIn("Test exception", str(context.exception)) + + def test_handle_fetched_failure_expired(self): + """ + Test case 27: Test handle_fetched_failure method when cached value is expired + """ + now = int(time.mktime(time.localtime())) + self.refresh_cached_supplier._cached_value = self.refresh_result + self.refresh_cached_supplier._cached_value._stale_time = now - 1800 + self.refresh_cached_supplier._stale_value_behavior = StaleValueBehavior.ALLOW + + result = self.refresh_cached_supplier._handle_fetched_failure(Exception("Test exception")) + + self.assertEqual(result.value().get_access_key_id(), "test_access_key_id") + self.assertGreaterEqual(result.stale_time(), now) + self.assertGreaterEqual(result.prefetch_time(), now + 3600) + + def test_handle_fetched_failure_expired_allow_stale(self): + """ + Test case 28: Test handle_fetched_failure method when cached value is expired and stale value behavior is ALLOW + """ + now = int(time.mktime(time.localtime())) + self.refresh_cached_supplier._cached_value = self.refresh_result + self.refresh_cached_supplier._cached_value._stale_time = now - 1800 + self.refresh_cached_supplier._stale_value_behavior = StaleValueBehavior.ALLOW + + result = self.refresh_cached_supplier._handle_fetched_failure(Exception("Test exception")) + + self.assertEqual(result.value().get_access_key_id(), "test_access_key_id") + self.assertGreaterEqual(result.stale_time(), now) + self.assertGreaterEqual(result.prefetch_time(), now + 3600) diff --git a/tests/provider/test_rsa_key_pair.py b/tests/provider/test_rsa_key_pair.py new file mode 100644 index 0000000..02fd988 --- /dev/null +++ b/tests/provider/test_rsa_key_pair.py @@ -0,0 +1,459 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import calendar +import time +import json +from alibabacloud_credentials.provider.rsa_key_pair import ( + RsaKeyPairCredentialsProvider, + CredentialException +) +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaResponse + + +class TestRsaKeyPairCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.public_key_id = "test_public_key_id" + self.private_key_file = "test_private_key_file" + self.duration_seconds = 3600 + self.sts_region_id = "test_sts_region_id" + self.sts_endpoint = "test_sts_endpoint" + self.enable_vpc = True + self.http_options = HttpOptions(connect_timeout=5000, read_timeout=10000, proxy="test_proxy") + self.private_key_content = "test_private_key_content" + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided parameters + """ + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_region_id=self.sts_region_id, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + self.assertEqual(provider._public_key_id, self.public_key_id) + self.assertEqual(provider._private_key_file, self.private_key_file) + self.assertEqual(provider._duration_seconds, self.duration_seconds) + self.assertEqual(provider._private_key, self.private_key_content) + self.assertEqual(provider._sts_endpoint, self.sts_endpoint) + self.assertEqual(provider._http_options, self.http_options) + self.assertEqual(provider._runtime_options['connectTimeout'], self.http_options.connect_timeout) + self.assertEqual(provider._runtime_options['readTimeout'], self.http_options.read_timeout) + self.assertEqual(provider._runtime_options['httpsProxy'], self.http_options.proxy) + + def test_init_missing_public_key_id(self): + """ + Test case 2: Missing public_key_id raises ValueError + """ + with self.assertRaises(ValueError) as context: + RsaKeyPairCredentialsProvider( + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds + ) + + self.assertIn("public_key_id cannot be empty", str(context.exception)) + + def test_init_empty_public_key_id(self): + """ + Test case 3: Empty public_key_id raises ValueError + """ + with self.assertRaises(ValueError) as context: + RsaKeyPairCredentialsProvider( + public_key_id="", + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds + ) + + self.assertIn("public_key_id cannot be empty", str(context.exception)) + + def test_init_missing_private_key_file(self): + """ + Test case 4: Missing private_key_file raises ValueError + """ + with self.assertRaises(ValueError) as context: + RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + duration_seconds=self.duration_seconds + ) + + self.assertIn("private_key_file cannot be empty", str(context.exception)) + + def test_init_empty_private_key_file(self): + """ + Test case 5: Empty private_key_file raises ValueError + """ + with self.assertRaises(ValueError) as context: + RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file="", + duration_seconds=self.duration_seconds + ) + + def test_init_private_key_file_read_error(self): + """ + Test case 6: Private key file read error raises ValueError + """ + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', side_effect=FileNotFoundError): + with self.assertRaises(FileNotFoundError) as context: + RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds + ) + + def test_init_duration_seconds_too_short(self): + """ + Test case 7: Duration seconds less than 900 raises ValueError + """ + with self.assertRaises(ValueError) as context: + RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=800 + ) + + self.assertIn("session duration should be in the range of 900s - max session duration", str(context.exception)) + + @patch('alibabacloud_credentials.provider.rsa_key_pair.au') + def test_init_default_values(self, mock_auth_util): + """ + Test case 8: Initializes with default values + """ + mock_auth_util.environment_enable_vpc = 'false' + mock_auth_util.environment_sts_region = None + + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file + ) + + self.assertEqual(provider._public_key_id, self.public_key_id) + self.assertEqual(provider._private_key_file, self.private_key_file) + self.assertEqual(provider._duration_seconds, RsaKeyPairCredentialsProvider.DEFAULT_DURATION_SECONDS) + self.assertEqual(provider._private_key, self.private_key_content) + self.assertEqual(provider._sts_endpoint, 'sts.ap-northeast-1.aliyuncs.com') + self.assertEqual(provider._runtime_options['connectTimeout'], + RsaKeyPairCredentialsProvider.DEFAULT_CONNECT_TIMEOUT) + self.assertEqual(provider._runtime_options['readTimeout'], + RsaKeyPairCredentialsProvider.DEFAULT_READ_TIMEOUT) + self.assertIsNone(provider._runtime_options['httpsProxy']) + + def test_get_credentials_valid_input(self): + """ + Test case 9: Valid input, successfully retrieves credentials + """ + response_body = json.dumps({ + "SessionAccessKey": { + "SessionAccessKeyId": "test_access_key_id", + "SessionAccessKeySecret": "test_access_key_secret", + "Expiration": "2023-12-31T23:59:59Z" + } + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + credentials = provider._refresh_credentials() + + self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.value().get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm(time.strptime("2023-12-31T23:59:59Z", '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "rsa_key_pair") + + def test_get_credentials_http_request_error(self): + """ + Test case 10: HTTP request error raises CredentialException + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('Tea.core.TeaCore.do_action', return_value=response): + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn( + "error refreshing credentials from rsa_key_pair, http_code: 400, result: HTTP request failed", + str(context.exception)) + + def test_get_credentials_response_format_error(self): + """ + Test case 11: Response format error raises CredentialException + """ + response_body = json.dumps({ + "Error": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn( + 'error retrieving credentials from rsa_key_pair result: {"Error": "Invalid request"}', + str(context.exception)) + + def test_get_credentials_async_valid_input(self): + """ + Test case 12: Valid input, successfully retrieves credentials asynchronously + """ + response_body = json.dumps({ + "SessionAccessKey": { + "SessionAccessKeyId": "test_access_key_id", + "SessionAccessKeySecret": "test_access_key_secret", + "Expiration": "2023-12-31T23:59:59Z" + } + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._refresh_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.value().get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.value().get_access_key_secret(), + "test_access_key_secret") + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm( + time.strptime("2023-12-31T23:59:59Z", '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "rsa_key_pair") + + def test_get_credentials_async_http_request_error(self): + """ + Test case 13: HTTP request error raises CredentialException asynchronously + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn( + "error refreshing credentials from rsa_key_pair, http_code: 400, result: HTTP request failed", + str(context.exception)) + + def test_get_credentials_async_response_format_error(self): + """ + Test case 14: Response format error raises CredentialException asynchronously + """ + response_body = json.dumps({ + "Error": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_endpoint=self.sts_endpoint, + enable_vpc=self.enable_vpc, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn( + 'error retrieving credentials from rsa_key_pair result: {"Error": "Invalid request"}', + str(context.exception)) + + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_sts_region_id_and_enable_vpc_true(self): + """ + Test case 15: sts_region_id is provided and enable_vpc is True + """ + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_region_id=self.sts_region_id, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts-vpc.{self.sts_region_id}.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_sts_region_id_and_enable_vpc_false(self): + """ + Test case 16: sts_region_id is provided and enable_vpc is False + """ + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + sts_region_id=self.sts_region_id, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts.{self.sts_region_id}.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_environment_sts_region_and_enable_vpc_true(self): + """ + Test case 17: sts_region_id is not provided, environment_sts_region is provided, and enable_vpc is True + """ + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts-vpc.test_env_sts_region.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_sts_region', 'test_env_sts_region') + def test_sts_endpoint_with_environment_sts_region_and_enable_vpc_false(self): + """ + Test case 18: sts_region_id is not provided, environment_sts_region is provided, and enable_vpc is False + """ + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, f'sts.test_env_sts_region.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_enable_vpc', 'true') + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_sts_region', None) + def test_sts_endpoint_with_no_sts_region_id_or_environment_sts_region_and_enable_vpc_true(self): + """ + Test case 19: sts_region_id and environment_sts_region are not provided, and enable_vpc is True + """ + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + enable_vpc=True, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, 'sts.ap-northeast-1.aliyuncs.com') + + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_enable_vpc', 'false') + @patch('alibabacloud_credentials.provider.rsa_key_pair.au.environment_sts_region', None) + def test_sts_endpoint_with_no_sts_region_id_or_environment_sts_region_and_enable_vpc_false(self): + """ + Test case 20: sts_region_id and environment_sts_region are not provided, and enable_vpc is False + """ + with patch('alibabacloud_credentials.provider.rsa_key_pair._get_content', + return_value=self.private_key_content): + provider = RsaKeyPairCredentialsProvider( + public_key_id=self.public_key_id, + private_key_file=self.private_key_file, + duration_seconds=self.duration_seconds, + enable_vpc=False, + http_options=self.http_options + ) + + self.assertEqual(provider._sts_endpoint, 'sts.ap-northeast-1.aliyuncs.com') diff --git a/tests/provider/test_static_ak.py b/tests/provider/test_static_ak.py new file mode 100644 index 0000000..306f5f0 --- /dev/null +++ b/tests/provider/test_static_ak.py @@ -0,0 +1,183 @@ +import unittest +from unittest.mock import patch, MagicMock +import asyncio +from alibabacloud_credentials.provider.static_ak import StaticAKCredentialsProvider +from alibabacloud_credentials.exceptions import CredentialException + + +class TestStaticAKCredentialsProvider(unittest.TestCase): + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided access_key_id and access_key_secret + """ + provider = StaticAKCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret" + ) + + self.assertEqual(provider.access_key_id, "test_access_key_id") + self.assertEqual(provider.access_key_secret, "test_access_key_secret") + + @patch('alibabacloud_credentials.provider.static_ak.auth_util') + def test_init_valid_environment_variables(self, mock_auth_util): + """ + Test case 2: Valid input, successfully initializes with environment variables + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + + provider = StaticAKCredentialsProvider() + + self.assertEqual(provider.access_key_id, "test_access_key_id") + self.assertEqual(provider.access_key_secret, "test_access_key_secret") + + def test_init_missing_access_key_id(self): + """ + Test case 3: Missing access_key_id raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticAKCredentialsProvider( + access_key_secret="test_access_key_secret" + ) + + self.assertIn("the access key id is empty", str(context.exception)) + + def test_init_empty_access_key_id(self): + """ + Test case 4: Empty access_key_id raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticAKCredentialsProvider( + access_key_id="", + access_key_secret="test_access_key_secret" + ) + + self.assertIn("the access key id is empty", str(context.exception)) + + def test_init_missing_access_key_secret(self): + """ + Test case 5: Missing access_key_secret raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticAKCredentialsProvider( + access_key_id="test_access_key_id" + ) + + self.assertIn("the access key secret is empty", str(context.exception)) + + def test_init_empty_access_key_secret(self): + """ + Test case 6: Empty access_key_secret raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticAKCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="" + ) + + self.assertIn("the access key secret is empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.static_ak.auth_util') + def test_init_missing_environment_variables(self, mock_auth_util): + """ + Test case 7: Missing environment variables raises ValueError + """ + mock_auth_util.environment_access_key_id = None + mock_auth_util.environment_access_key_secret = None + + with self.assertRaises(ValueError) as context: + StaticAKCredentialsProvider() + + self.assertIn("the access key id is empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.static_ak.auth_util') + def test_init_empty_environment_variables(self, mock_auth_util): + """ + Test case 8: Empty environment variables raises ValueError + """ + mock_auth_util.environment_access_key_id = "" + mock_auth_util.environment_access_key_secret = "" + + with self.assertRaises(ValueError) as context: + StaticAKCredentialsProvider() + + self.assertIn("the access key id is empty", str(context.exception)) + + def test_get_credentials_valid_input(self): + """ + Test case 9: Valid input, successfully retrieves credentials + """ + provider = StaticAKCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret" + ) + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertIsNone(credentials.get_security_token()) + self.assertEqual(credentials.get_provider_name(), "static_ak") + + @patch('alibabacloud_credentials.provider.static_ak.auth_util') + def test_get_credentials_valid_environment_variables(self, mock_auth_util): + """ + Test case 10: Valid input, successfully retrieves credentials from environment variables + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + + provider = StaticAKCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertIsNone(credentials.get_security_token()) + self.assertEqual(credentials.get_provider_name(), "static_ak") + + def test_get_credentials_async_valid_input(self): + """ + Test case 11: Valid input, successfully retrieves credentials asynchronously + """ + provider = StaticAKCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret" + ) + + # Use asyncio.run to execute the async function + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertIsNone(credentials.get_security_token()) + self.assertEqual(credentials.get_provider_name(), "static_ak") + + @patch('alibabacloud_credentials.provider.static_ak.auth_util') + def test_get_credentials_async_valid_environment_variables(self, mock_auth_util): + """ + Test case 12: Valid input, successfully retrieves credentials asynchronously from environment variables + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + + provider = StaticAKCredentialsProvider() + + # Use asyncio.run to execute the async function + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertIsNone(credentials.get_security_token()) + self.assertEqual(credentials.get_provider_name(), "static_ak") diff --git a/tests/provider/test_static_sts.py b/tests/provider/test_static_sts.py new file mode 100644 index 0000000..78e6816 --- /dev/null +++ b/tests/provider/test_static_sts.py @@ -0,0 +1,222 @@ +import unittest +from unittest.mock import patch, MagicMock +import asyncio +from alibabacloud_credentials.provider.static_sts import StaticSTSCredentialsProvider +from alibabacloud_credentials.exceptions import CredentialException + + +class TestStaticSTSCredentialsProvider(unittest.TestCase): + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided access_key_id, access_key_secret, and security_token + """ + provider = StaticSTSCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret", + security_token="test_security_token" + ) + + self.assertEqual(provider.access_key_id, "test_access_key_id") + self.assertEqual(provider.access_key_secret, "test_access_key_secret") + self.assertEqual(provider.security_token, "test_security_token") + + @patch('alibabacloud_credentials.provider.static_sts.auth_util') + def test_init_valid_environment_variables(self, mock_auth_util): + """ + Test case 2: Valid input, successfully initializes with environment variables + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = StaticSTSCredentialsProvider() + + self.assertEqual(provider.access_key_id, "test_access_key_id") + self.assertEqual(provider.access_key_secret, "test_access_key_secret") + self.assertEqual(provider.security_token, "test_security_token") + + def test_init_missing_access_key_id(self): + """ + Test case 3: Missing access_key_id raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticSTSCredentialsProvider( + access_key_secret="test_access_key_secret", + security_token="test_security_token" + ) + + self.assertIn("the access key id is empty", str(context.exception)) + + def test_init_empty_access_key_id(self): + """ + Test case 4: Empty access_key_id raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticSTSCredentialsProvider( + access_key_id="", + access_key_secret="test_access_key_secret", + security_token="test_security_token" + ) + + self.assertIn("the access key id is empty", str(context.exception)) + + def test_init_missing_access_key_secret(self): + """ + Test case 5: Missing access_key_secret raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticSTSCredentialsProvider( + access_key_id="test_access_key_id", + security_token="test_security_token" + ) + + self.assertIn("the access key secret is empty", str(context.exception)) + + def test_init_empty_access_key_secret(self): + """ + Test case 6: Empty access_key_secret raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticSTSCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="", + security_token="test_security_token" + ) + + self.assertIn("the access key secret is empty", str(context.exception)) + + def test_init_missing_security_token(self): + """ + Test case 7: Missing security_token raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticSTSCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret" + ) + + self.assertIn("the security token is empty", str(context.exception)) + + def test_init_empty_security_token(self): + """ + Test case 8: Empty security_token raises ValueError + """ + with self.assertRaises(ValueError) as context: + StaticSTSCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret", + security_token="" + ) + + self.assertIn("the security token is empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.static_sts.auth_util') + def test_init_missing_environment_variables(self, mock_auth_util): + """ + Test case 9: Missing environment variables raises ValueError + """ + mock_auth_util.environment_access_key_id = None + mock_auth_util.environment_access_key_secret = None + mock_auth_util.environment_security_token = None + + with self.assertRaises(ValueError) as context: + StaticSTSCredentialsProvider() + + self.assertIn("the access key id is empty", str(context.exception)) + + @patch('alibabacloud_credentials.provider.static_sts.auth_util') + def test_init_empty_environment_variables(self, mock_auth_util): + """ + Test case 10: Empty environment variables raises ValueError + """ + mock_auth_util.environment_access_key_id = "" + mock_auth_util.environment_access_key_secret = "" + mock_auth_util.environment_security_token = "" + + with self.assertRaises(ValueError) as context: + StaticSTSCredentialsProvider() + + self.assertIn("the access key id is empty", str(context.exception)) + + def test_get_credentials_valid_input(self): + """ + Test case 11: Valid input, successfully retrieves credentials + """ + provider = StaticSTSCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret", + security_token="test_security_token" + ) + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.get_security_token(), "test_security_token") + self.assertEqual(credentials.get_provider_name(), "static_sts") + + @patch('alibabacloud_credentials.provider.static_sts.auth_util') + def test_get_credentials_valid_environment_variables(self, mock_auth_util): + """ + Test case 12: Valid input, successfully retrieves credentials from environment variables + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = StaticSTSCredentialsProvider() + + credentials = provider.get_credentials() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.get_security_token(), "test_security_token") + self.assertEqual(credentials.get_provider_name(), "static_sts") + + def test_get_credentials_async_valid_input(self): + """ + Test case 13: Valid input, successfully retrieves credentials asynchronously + """ + provider = StaticSTSCredentialsProvider( + access_key_id="test_access_key_id", + access_key_secret="test_access_key_secret", + security_token="test_security_token" + ) + + # Use asyncio.run to execute the async function + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.get_security_token(), "test_security_token") + self.assertEqual(credentials.get_provider_name(), "static_sts") + + @patch('alibabacloud_credentials.provider.static_sts.auth_util') + def test_get_credentials_async_valid_environment_variables(self, mock_auth_util): + """ + Test case 14: Valid input, successfully retrieves credentials asynchronously from environment variables + """ + mock_auth_util.environment_access_key_id = "test_access_key_id" + mock_auth_util.environment_access_key_secret = "test_access_key_secret" + mock_auth_util.environment_security_token = "test_security_token" + + provider = StaticSTSCredentialsProvider() + + # Use asyncio.run to execute the async function + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.get_access_key_id(), "test_access_key_id") + self.assertEqual(credentials.get_access_key_secret(), "test_access_key_secret") + self.assertEqual(credentials.get_security_token(), "test_security_token") + self.assertEqual(credentials.get_provider_name(), "static_sts") diff --git a/tests/provider/test_uri.py b/tests/provider/test_uri.py new file mode 100644 index 0000000..f65f0a2 --- /dev/null +++ b/tests/provider/test_uri.py @@ -0,0 +1,374 @@ +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio +import calendar +import time +import json +from alibabacloud_credentials.provider.uri import ( + URLCredentialsProvider, + CredentialException +) +from alibabacloud_credentials.http import HttpOptions +from Tea.core import TeaResponse + + +class TestURLCredentialsProvider(unittest.TestCase): + + def setUp(self): + self.uri = "http://example.com/credentials" + self.protocol = "http" + self.http_options = HttpOptions(connect_timeout=5000, read_timeout=10000, proxy="test_proxy") + self.access_key_id = "test_access_key_id" + self.access_key_secret = "test_access_key_secret" + self.security_token = "test_security_token" + self.expiration = "2023-12-31T23:59:59Z" + self.response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token, + "Expiration": self.expiration + }) + self.response = TeaResponse() + self.response.status_code = 200 + self.response.body = self.response_body.encode('utf-8') + + def test_init_valid_input(self): + """ + Test case 1: Valid input, successfully initializes with provided parameters + """ + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + self.assertEqual(provider._uri, self.uri) + self.assertEqual(provider._protocol, self.protocol) + self.assertEqual(provider._http_options, self.http_options) + self.assertEqual(provider._runtime_options['connectTimeout'], self.http_options.connect_timeout) + self.assertEqual(provider._runtime_options['readTimeout'], self.http_options.read_timeout) + self.assertEqual(provider._runtime_options['httpsProxy'], self.http_options.proxy) + + def test_init_missing_uri(self): + """ + Test case 2: Missing uri raises ValueError + """ + with self.assertRaises(ValueError) as context: + URLCredentialsProvider( + protocol=self.protocol, + http_options=self.http_options + ) + + self.assertIn("uri or environment variable ALIBABA_CLOUD_CREDENTIALS_URI cannot be empty", + str(context.exception)) + + def test_init_empty_uri(self): + """ + Test case 3: Empty uri raises ValueError + """ + with self.assertRaises(ValueError) as context: + URLCredentialsProvider( + uri="", + protocol=self.protocol, + http_options=self.http_options + ) + + self.assertIn("uri or environment variable ALIBABA_CLOUD_CREDENTIALS_URI cannot be empty", + str(context.exception)) + + @patch('alibabacloud_credentials.provider.uri.au.environment_credentials_uri', "http://example.com/credentials") + def test_init_valid_environment_variables(self): + """ + Test case 4: Valid input, successfully initializes with environment variables + """ + provider = URLCredentialsProvider() + + self.assertEqual(provider._uri, "http://example.com/credentials") + self.assertEqual(provider._protocol, "http") + self.assertEqual(provider._runtime_options['connectTimeout'], URLCredentialsProvider.DEFAULT_CONNECT_TIMEOUT) + self.assertEqual(provider._runtime_options['readTimeout'], URLCredentialsProvider.DEFAULT_READ_TIMEOUT) + self.assertIsNone(provider._runtime_options['httpsProxy']) + + def test_get_credentials_valid_input(self): + """ + Test case 5: Valid input, successfully retrieves credentials + """ + with patch('Tea.core.TeaCore.do_action', return_value=self.response): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + credentials = provider._refresh_credentials() + + self.assertEqual(credentials.value().get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.value().get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.value().get_security_token(), self.security_token) + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm(time.strptime(self.expiration, '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "credential_uri") + + def test_get_credentials_http_request_error(self): + """ + Test case 6: HTTP request error raises CredentialException + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn( + f'error refreshing credentials from {self.uri}, http_code=400, result: HTTP request failed', + str(context.exception)) + + def test_get_credentials_response_format_error(self): + """ + Test case 7: Response format error raises CredentialException + """ + response_body = json.dumps({ + "Code": "Failure", + "Message": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'error retrieving credentials from {self.uri} result: {response_body}', + str(context.exception)) + + def test_get_credentials_async_valid_input(self): + """ + Test case 8: Valid input, successfully retrieves credentials asynchronously + """ + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=self.response)): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider._refresh_credentials_async() + ) + loop.run_until_complete(task) + credentials = task.result() + + self.assertEqual(credentials.value().get_access_key_id(), self.access_key_id) + self.assertEqual(credentials.value().get_access_key_secret(), self.access_key_secret) + self.assertEqual(credentials.value().get_security_token(), self.security_token) + self.assertEqual(credentials.value().get_expiration(), + calendar.timegm(time.strptime(self.expiration, '%Y-%m-%dT%H:%M:%SZ'))) + self.assertEqual(credentials.value().get_provider_name(), "credential_uri") + + def test_get_credentials_async_http_request_error(self): + """ + Test case 9: HTTP request error raises CredentialException asynchronously + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn( + f'error refreshing credentials from {self.uri}, http_code=400, result: HTTP request failed', + str(context.exception)) + + def test_get_credentials_async_response_format_error(self): + """ + Test case 10: Response format error raises CredentialException asynchronously + """ + response_body = json.dumps({ + "Code": "Failure", + "Message": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + provider.get_credentials_async() + ) + loop.run_until_complete(task) + + self.assertIn(f'error retrieving credentials from {self.uri} result: {response_body}', + str(context.exception)) + + def test_get_credentials_missing_access_key_id(self): + """ + Test case 11: Missing AccessKeyId in response raises CredentialException + """ + response_body = json.dumps({ + "Code": "Success", + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token, + "Expiration": self.expiration + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'error retrieving credentials from {self.uri} result: {response_body}', + str(context.exception)) + + def test_get_credentials_missing_access_key_secret(self): + """ + Test case 12: Missing AccessKeySecret in response raises CredentialException + """ + response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": self.access_key_id, + "SecurityToken": self.security_token, + "Expiration": self.expiration + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'error retrieving credentials from {self.uri} result: {response_body}', + str(context.exception)) + + def test_get_credentials_missing_security_token(self): + """ + Test case 13: Missing SecurityToken in response raises CredentialException + """ + response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "Expiration": self.expiration + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'error retrieving credentials from {self.uri} result: {response_body}', + str(context.exception)) + + def test_get_credentials_missing_expiration(self): + """ + Test case 14: Missing Expiration in response raises CredentialException + """ + response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'error retrieving credentials from {self.uri} result: {response_body}', + str(context.exception)) + + def test_get_credentials_invalid_code(self): + """ + Test case 15: Invalid Code in response raises CredentialException + """ + response_body = json.dumps({ + "Code": "Failure", + "AccessKeyId": self.access_key_id, + "AccessKeySecret": self.access_key_secret, + "SecurityToken": self.security_token, + "Expiration": self.expiration + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + provider = URLCredentialsProvider( + uri=self.uri, + protocol=self.protocol, + http_options=self.http_options + ) + + with self.assertRaises(CredentialException) as context: + provider.get_credentials() + + self.assertIn(f'error retrieving credentials from {self.uri} result: {response_body}', + str(context.exception)) diff --git a/tests/test_client.py b/tests/test_client.py index 8aed346..edd8b07 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,10 @@ import asyncio import unittest +from . import txt_file from alibabacloud_credentials.models import Config from alibabacloud_credentials.utils import auth_constant -from alibabacloud_credentials.client import Client +from alibabacloud_credentials.client import Client, _CredentialsProviderWrap from alibabacloud_credentials import credentials from alibabacloud_credentials.utils import auth_util @@ -19,19 +20,31 @@ def test_client_ak(self): self.assertEqual('654321', cred.get_access_key_secret()) self.assertEqual(auth_constant.ACCESS_KEY, cred.get_type()) self.assertIsNone(cred.get_security_token()) + + model = cred.get_credential() + self.assertEqual('123456', model.get_access_key_id()) + self.assertEqual('654321', model.get_access_key_secret()) + self.assertIsNone(model.get_security_token()) + self.assertEqual(auth_constant.ACCESS_KEY, model.get_type()) + self.assertEqual('static_ak', model.get_provider_name()) + enable_oidc_credential = auth_util.enable_oidc_credential auth_util.enable_oidc_credential = False try: cred = Client() - cred.get_access_key_id() + cred.get_credential() except Exception as ex: - self.assertEqual('not found credentials', str(ex)) + self.assertTrue(str(ex).startswith('unable to load credentials from any of the providers in the chain')) auth_util.enable_oidc_credential = enable_oidc_credential def test_client_sts(self): - conf = Config(type='sts') + conf = Config( + type='sts', + access_key_id='123456', + access_key_secret='654321', + security_token='token', ) cred = Client(conf) - self.assertIsInstance(cred.cloud_credential, credentials.StsCredential) + self.assertIsInstance(cred.cloud_credential, _CredentialsProviderWrap) def test_client_bearer(self): conf = Config(type='bearer') @@ -40,23 +53,38 @@ def test_client_bearer(self): def test_client_ecs_ram_role(self): conf = Config(type='ecs_ram_role') - self.assertIsInstance(Client.get_credentials(conf), credentials.EcsRamRoleCredential) + self.assertIsInstance(Client.get_credentials(conf), _CredentialsProviderWrap) def test_client_credentials_uri(self): - conf = Config(type='credentials_uri') - self.assertIsInstance(Client.get_credentials(conf), credentials.CredentialsURICredential) + conf = Config( + type='credentials_uri', + credentials_uri='http://localhost:8080') + self.assertIsInstance(Client.get_credentials(conf), _CredentialsProviderWrap) def test_client_ram_role_arn(self): - conf = Config(type='ram_role_arn') - self.assertIsInstance(Client.get_credentials(conf), credentials.RamRoleArnCredential) + conf = Config( + type='ram_role_arn', + access_key_id='123456', + access_key_secret='654321', + role_arn='arn:aws:iam::123456789012:role/role-name', + ) + self.assertIsInstance(Client.get_credentials(conf), _CredentialsProviderWrap) def test_client_oidc_role_arn(self): - conf = Config(type='oidc_role_arn', oidc_token_file_path='oidc_token_file_path') - self.assertIsInstance(Client.get_credentials(conf), credentials.OIDCRoleArnCredential) + conf = Config( + type='oidc_role_arn', + role_arn='arn:aws:iam::123456789012:role/role-name', + oidc_provider_arn='arn:aws:iam::123456789012:role/role-name', + oidc_token_file_path='oidc_token_file_path') + self.assertIsInstance(Client.get_credentials(conf), _CredentialsProviderWrap) def test_client_rsa_key_pair(self): - conf = Config(type='rsa_key_pair') - self.assertIsInstance(Client.get_credentials(conf), credentials.RsaKeyPairCredential) + conf = Config( + type='rsa_key_pair', + private_key_file=txt_file, + public_key_id='test', + ) + self.assertIsInstance(Client.get_credentials(conf), _CredentialsProviderWrap) def test_async_call(self): conf = Config( @@ -75,3 +103,7 @@ def test_async_call(self): task = asyncio.ensure_future(client.get_access_key_secret_async()) loop.run_until_complete(task) self.assertEqual('sk1', task.result()) + task = asyncio.ensure_future(client.get_credential_async()) + loop.run_until_complete(task) + credential = task.result() + self.assertEqual('ak1', credential.access_key_id) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 9c3c3d2..bc302e4 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -1,6 +1,10 @@ import unittest - +import json +from unittest.mock import patch, MagicMock, AsyncMock +import asyncio from alibabacloud_credentials import credentials, providers +from alibabacloud_credentials.exceptions import CredentialException +from Tea.core import TeaResponse class TestCredentials(unittest.TestCase): @@ -21,8 +25,7 @@ def get_credentials(self): class TestRsaKeyPairProvider: def get_credentials(self): - return credentials.RsaKeyPairCredential("accessKeyId", "accessKeySecret", 100000000000, - None) + return credentials.RsaKeyPairCredential("accessKeyId", "accessKeySecret", 100000000000, None) def test_EcsRamRoleCredential(self): provider = providers.EcsRamRoleCredentialProvider("roleName") @@ -128,8 +131,8 @@ def test_RamRoleArnCredential(self): self.assertEqual(100000000000, cred.expiration) self.assertEqual('accessKeyId', cred.get_access_key_id()) - self.assertEqual('accessKeySecret', cred.access_key_secret) - self.assertEqual('securityToken', cred.security_token) + self.assertEqual('accessKeySecret', cred.get_access_key_secret()) + self.assertEqual('securityToken', cred.get_security_token()) self.assertEqual(100000000000, cred.expiration) self.assertEqual('ram_role_arn', cred.credential_type) self.assertIsInstance(cred.provider, self.TestRamRoleArnProvider) @@ -173,8 +176,8 @@ def test_OIDCRoleArnCredential(self): self.assertEqual(100000000000, cred.expiration) self.assertEqual('accessKeyId', cred.get_access_key_id()) - self.assertEqual('accessKeySecret', cred.access_key_secret) - self.assertEqual('securityToken', cred.security_token) + self.assertEqual('accessKeySecret', cred.get_access_key_secret()) + self.assertEqual('securityToken', cred.get_security_token()) self.assertEqual(100000000000, cred.expiration) self.assertEqual('oidc_role_arn', cred.credential_type) self.assertIsInstance(cred.provider, self.TestOIDCRoleArnProvider) @@ -215,7 +218,7 @@ def test_RsaKeyPairCredential(self): self.assertEqual(100000000000, cred.expiration) self.assertEqual('accessKeyId', cred.get_access_key_id()) - self.assertEqual('accessKeySecret', cred.access_key_secret) + self.assertEqual('accessKeySecret', cred.get_access_key_secret()) self.assertEqual(100000000000, cred.expiration) def test_CredentialsURICredential(self): @@ -245,3 +248,212 @@ def test_StsCredential(self): self.assertEqual('access_key_secret', model.access_key_secret) self.assertEqual('security_token', model.security_token) self.assertEqual('sts', model.type) + + def test_CredentialsURICredential_normal(self): + """ + Test case 1: Successfully retrieves credentials from URI + """ + response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2023-12-31T23:59:59Z" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + credentials_uri = 'http://localhost:6666/test' + cred = credentials.CredentialsURICredential(credentials_uri) + + model = cred.get_credential() + self.assertEqual('test_access_key_id', model.access_key_id) + self.assertEqual('test_access_key_secret', model.access_key_secret) + self.assertEqual('test_security_token', model.security_token) + self.assertEqual('credentials_uri', model.type) + + def test_CredentialsURICredential_refresh(self): + """ + Test case 2: Refreshes credentials when expired + """ + response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2023-12-31T23:59:59Z" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + credentials_uri = 'http://localhost:6666/test' + cred = credentials.CredentialsURICredential(credentials_uri) + + # Set expiration to a past time to trigger refresh + cred.expiration = 1 + + model = cred.get_credential() + self.assertEqual('test_access_key_id', model.access_key_id) + self.assertEqual('test_access_key_secret', model.access_key_secret) + self.assertEqual('test_security_token', model.security_token) + self.assertEqual('credentials_uri', model.type) + + def test_CredentialsURICredential_http_request_error(self): + """ + Test case 3: HTTP request error raises CredentialException + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('Tea.core.TeaCore.do_action', return_value=response): + credentials_uri = 'http://localhost:6666/test' + cred = credentials.CredentialsURICredential(credentials_uri) + + with self.assertRaises(CredentialException) as context: + cred.get_credential() + + self.assertIn( + "Get credentials from http://localhost:6666/test failed, HttpCode=400", + str(context.exception)) + + def test_CredentialsURICredential_response_format_error(self): + """ + Test case 4: Response format error raises CredentialException + """ + response_body = json.dumps({ + "Code": "Failure", + "Message": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.do_action', return_value=response): + credentials_uri = 'http://localhost:6666/test' + cred = credentials.CredentialsURICredential(credentials_uri) + + with self.assertRaises(CredentialException) as context: + cred.get_credential() + + self.assertIn( + "Get credentials from http://localhost:6666/test failed, Code is Failure", + str(context.exception)) + + def test_CredentialsURICredential_async_normal(self): + """ + Test case 5: Successfully retrieves credentials from URI asynchronously + """ + response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2023-12-31T23:59:59Z" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + credentials_uri = 'http://localhost:6666/test' + cred = credentials.CredentialsURICredential(credentials_uri) + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + cred.get_credential_async() + ) + loop.run_until_complete(task) + model = task.result() + + self.assertEqual('test_access_key_id', model.access_key_id) + self.assertEqual('test_access_key_secret', model.access_key_secret) + self.assertEqual('test_security_token', model.security_token) + self.assertEqual('credentials_uri', model.type) + + def test_CredentialsURICredential_async_refresh(self): + """ + Test case 6: Refreshes credentials when expired + """ + response_body = json.dumps({ + "Code": "Success", + "AccessKeyId": "test_access_key_id", + "AccessKeySecret": "test_access_key_secret", + "SecurityToken": "test_security_token", + "Expiration": "2023-12-31T23:59:59Z" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + credentials_uri = 'http://localhost:6666/test' + cred = credentials.CredentialsURICredential(credentials_uri) + + # Set expiration to a past time to trigger refresh + cred.expiration = 1 + + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + cred.get_credential_async() + ) + loop.run_until_complete(task) + model = task.result() + self.assertEqual('test_access_key_id', model.access_key_id) + self.assertEqual('test_access_key_secret', model.access_key_secret) + self.assertEqual('test_security_token', model.security_token) + self.assertEqual('credentials_uri', model.type) + + def test_CredentialsURICredential_async_http_request_error(self): + """ + Test case 7: HTTP request error raises CredentialException + """ + response = TeaResponse() + response.status_code = 400 + response.body = b'HTTP request failed' + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + credentials_uri = 'http://localhost:6666/test' + cred = credentials.CredentialsURICredential(credentials_uri) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + cred.get_credential_async() + ) + loop.run_until_complete(task) + + self.assertIn( + "Get credentials from http://localhost:6666/test failed, HttpCode=400", + str(context.exception)) + + def test_CredentialsURICredential_async_response_format_error(self): + """ + Test case 8: Response format error raises CredentialException + """ + response_body = json.dumps({ + "Code": "Failure", + "Message": "Invalid request" + }) + response = TeaResponse() + response.status_code = 200 + response.body = response_body.encode('utf-8') + + with patch('Tea.core.TeaCore.async_do_action', AsyncMock(return_value=response)): + credentials_uri = 'http://localhost:6666/test' + cred = credentials.CredentialsURICredential(credentials_uri) + + with self.assertRaises(CredentialException) as context: + loop = asyncio.get_event_loop() + task = asyncio.ensure_future( + cred.get_credential_async() + ) + loop.run_until_complete(task) + + self.assertIn( + "Get credentials from http://localhost:6666/test failed, Code is Failure", + str(context.exception)) diff --git a/tests/test_model.py b/tests/test_model.py deleted file mode 100644 index bc20c16..0000000 --- a/tests/test_model.py +++ /dev/null @@ -1,67 +0,0 @@ -import unittest -from alibabacloud_credentials.models import Config, CredentialModel - - -class TestModel(unittest.TestCase): - def test_model_config(self): - conf1 = Config() - self.assertEqual('', conf1.access_key_id) - self.assertEqual('', conf1.access_key_secret) - self.assertEqual('', conf1.role_name) - self.assertEqual(1000, conf1.timeout) - self.assertEqual(1000, conf1.connect_timeout) - self.assertFalse(conf1.disable_imds_v1) - self.assertIsNone(conf1.sts_endpoint) - - conf1.timeout = 0 - conf1.access_key_id = 'access_key_id' - self.assertEqual('access_key_id', conf1.access_key_id) - self.assertEqual(0, conf1.timeout) - - conf2 = Config( - access_key_id='access_key_id', - access_key_secret='access_key_secret' - ) - self.assertEqual('access_key_id', conf2.access_key_id) - self.assertEqual('access_key_secret', conf2.access_key_secret) - - def test_model_credential(self): - cred = CredentialModel() - self.assertIsNone(cred.access_key_id) - self.assertIsNone(cred.access_key_secret) - self.assertIsNone(cred.security_token) - self.assertIsNone(cred.bearer_token) - self.assertIsNone(cred.type) - - cred = CredentialModel( - access_key_id='access_key_id', - access_key_secret='access_key_secret', - security_token='security_token', - bearer_token='bearer_token', - type='type', - ) - self.assertEqual('access_key_id', cred.access_key_id) - self.assertEqual('access_key_secret', cred.access_key_secret) - self.assertEqual('security_token', cred.security_token) - self.assertEqual('bearer_token', cred.bearer_token) - self.assertEqual('type', cred.type) - - cred_map = cred.to_map() - self.assertEqual('access_key_id', cred_map['accessKeyId']) - self.assertEqual('access_key_secret', cred_map['accessKeySecret']) - self.assertEqual('security_token', cred_map['securityToken']) - self.assertEqual('bearer_token', cred_map['bearerToken']) - self.assertEqual('type', cred_map['type']) - - cred = CredentialModel() - cred.from_map(cred_map) - self.assertEqual('access_key_id', cred.access_key_id) - self.assertEqual('access_key_secret', cred.access_key_secret) - self.assertEqual('security_token', cred.security_token) - self.assertEqual('bearer_token', cred.bearer_token) - self.assertEqual('type', cred.type) - self.assertEqual('access_key_id', cred.get_access_key_id()) - self.assertEqual('access_key_secret', cred.get_access_key_secret()) - self.assertEqual('security_token', cred.get_security_token()) - self.assertEqual('bearer_token', cred.get_bearer_token()) - self.assertEqual('type', cred.get_type()) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..f165890 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,269 @@ +import unittest +from alibabacloud_credentials.models import Config, CredentialModel + +class TestModel(unittest.TestCase): + def test_config_default_values(self): + conf = Config() + self.assertIsNone(conf.type) + self.assertIsNone(conf.access_key_id) + self.assertIsNone(conf.access_key_secret) + self.assertIsNone(conf.security_token) + self.assertIsNone(conf.bearer_token) + self.assertIsNone(conf.duration_seconds) + self.assertIsNone(conf.role_arn) + self.assertIsNone(conf.oidc_provider_arn) + self.assertIsNone(conf.oidc_token_file_path) + self.assertIsNone(conf.role_session_name) + self.assertIsNone(conf.role_session_expiration) + self.assertIsNone(conf.policy) + self.assertIsNone(conf.external_id) + self.assertIsNone(conf.sts_endpoint) + self.assertIsNone(conf.public_key_id) + self.assertIsNone(conf.private_key_file) + self.assertIsNone(conf.role_name) + self.assertIsNone(conf.enable_imds_v2) + self.assertFalse(conf.disable_imds_v1) + self.assertIsNone(conf.metadata_token_duration) + self.assertIsNone(conf.credentials_uri) + self.assertIsNone(conf.host) + self.assertEqual(None, conf.timeout) + self.assertEqual(None, conf.connect_timeout) + self.assertIsNone(conf.proxy) + + def test_config_custom_values(self): + conf = Config( + type='access_key', + access_key_id='access_key_id', + access_key_secret='access_key_secret', + security_token='security_token', + bearer_token='bearer_token', + duration_seconds=3600, + role_arn='role_arn', + oidc_provider_arn='oidc_provider_arn', + oidc_token_file_path='oidc_token_file_path', + role_session_name='role_session_name', + role_session_expiration=3600, + policy='policy', + external_id='external_id', + sts_endpoint='sts_endpoint', + public_key_id='public_key_id', + private_key_file='private_key_file', + role_name='role_name', + enable_imds_v2=True, + disable_imds_v1=False, + metadata_token_duration=3600, + credentials_uri='credentials_uri', + host='host', + timeout=5000, + connect_timeout=10000, + proxy='proxy' + ) + self.assertEqual('access_key', conf.type) + self.assertEqual('access_key_id', conf.access_key_id) + self.assertEqual('access_key_secret', conf.access_key_secret) + self.assertEqual('security_token', conf.security_token) + self.assertEqual('bearer_token', conf.bearer_token) + self.assertEqual(3600, conf.duration_seconds) + self.assertEqual('role_arn', conf.role_arn) + self.assertEqual('oidc_provider_arn', conf.oidc_provider_arn) + self.assertEqual('oidc_token_file_path', conf.oidc_token_file_path) + self.assertEqual('role_session_name', conf.role_session_name) + self.assertEqual(3600, conf.role_session_expiration) + self.assertEqual('policy', conf.policy) + self.assertEqual('external_id', conf.external_id) + self.assertEqual('sts_endpoint', conf.sts_endpoint) + self.assertEqual('public_key_id', conf.public_key_id) + self.assertEqual('private_key_file', conf.private_key_file) + self.assertEqual('role_name', conf.role_name) + self.assertTrue(conf.enable_imds_v2) + self.assertFalse(conf.disable_imds_v1) + self.assertEqual(3600, conf.metadata_token_duration) + self.assertEqual('credentials_uri', conf.credentials_uri) + self.assertEqual('host', conf.host) + self.assertEqual(5000, conf.timeout) + self.assertEqual(10000, conf.connect_timeout) + self.assertEqual('proxy', conf.proxy) + + def test_config_to_map(self): + conf = Config( + type='access_key', + access_key_id='access_key_id', + access_key_secret='access_key_secret', + security_token='security_token', + bearer_token='bearer_token', + duration_seconds=3600, + role_arn='role_arn', + oidc_provider_arn='oidc_provider_arn', + oidc_token_file_path='oidc_token_file_path', + role_session_name='role_session_name', + role_session_expiration=3600, + policy='policy', + external_id='external_id', + sts_endpoint='sts_endpoint', + public_key_id='public_key_id', + private_key_file='private_key_file', + role_name='role_name', + enable_imds_v2=True, + disable_imds_v1=False, + metadata_token_duration=3600, + credentials_uri='credentials_uri', + host='host', + timeout=5000, + connect_timeout=10000, + proxy='proxy' + ) + conf_map = conf.to_map() + self.assertEqual('access_key', conf_map['type']) + self.assertEqual('access_key_id', conf_map['accessKeyId']) + self.assertEqual('access_key_secret', conf_map['accessKeySecret']) + self.assertEqual('security_token', conf_map['securityToken']) + self.assertEqual('bearer_token', conf_map['bearerToken']) + self.assertEqual(3600, conf_map['durationSeconds']) + self.assertEqual('role_arn', conf_map['roleArn']) + self.assertEqual('oidc_provider_arn', conf_map['oidcProviderArn']) + self.assertEqual('oidc_token_file_path', conf_map['oidcTokenFilePath']) + self.assertEqual('role_session_name', conf_map['roleSessionName']) + self.assertEqual(3600, conf_map['roleSessionExpiration']) + self.assertEqual('policy', conf_map['policy']) + self.assertEqual('external_id', conf_map['externalId']) + self.assertEqual('sts_endpoint', conf_map['stsEndpoint']) + self.assertEqual('public_key_id', conf_map['publicKeyId']) + self.assertEqual('private_key_file', conf_map['privateKeyFile']) + self.assertEqual('role_name', conf_map['roleName']) + self.assertTrue(conf_map['enableIMDSv2']) + self.assertFalse(conf_map['disableIMDSv1']) + self.assertEqual(3600, conf_map['metadataTokenDuration']) + self.assertEqual('credentials_uri', conf_map['credentialsUri']) + self.assertEqual('host', conf_map['host']) + self.assertEqual(5000, conf_map['timeout']) + self.assertEqual(10000, conf_map['connectTimeout']) + self.assertEqual('proxy', conf_map['proxy']) + + def test_config_from_map(self): + conf_map = { + 'type': 'access_key', + 'accessKeyId': 'access_key_id', + 'accessKeySecret': 'access_key_secret', + 'securityToken': 'security_token', + 'bearerToken': 'bearer_token', + 'durationSeconds': 3600, + 'roleArn': 'role_arn', + 'oidcProviderArn': 'oidc_provider_arn', + 'oidcTokenFilePath': 'oidc_token_file_path', + 'roleSessionName': 'role_session_name', + 'roleSessionExpiration': 3600, + 'policy': 'policy', + 'externalId': 'external_id', + 'stsEndpoint': 'sts_endpoint', + 'publicKeyId': 'public_key_id', + 'privateKeyFile': 'private_key_file', + 'roleName': 'role_name', + 'enableIMDSv2': True, + 'disableIMDSv1': False, + 'metadataTokenDuration': 3600, + 'credentialsUri': 'credentials_uri', + 'host': 'host', + 'timeout': 5000, + 'connectTimeout': 10000, + 'proxy': 'proxy' + } + conf = Config().from_map(conf_map) + self.assertEqual('access_key', conf.type) + self.assertEqual('access_key_id', conf.access_key_id) + self.assertEqual('access_key_secret', conf.access_key_secret) + self.assertEqual('security_token', conf.security_token) + self.assertEqual('bearer_token', conf.bearer_token) + self.assertEqual(3600, conf.duration_seconds) + self.assertEqual('role_arn', conf.role_arn) + self.assertEqual('oidc_provider_arn', conf.oidc_provider_arn) + self.assertEqual('oidc_token_file_path', conf.oidc_token_file_path) + self.assertEqual('role_session_name', conf.role_session_name) + self.assertEqual(3600, conf.role_session_expiration) + self.assertEqual('policy', conf.policy) + self.assertEqual('external_id', conf.external_id) + self.assertEqual('sts_endpoint', conf.sts_endpoint) + self.assertEqual('public_key_id', conf.public_key_id) + self.assertEqual('private_key_file', conf.private_key_file) + self.assertEqual('role_name', conf.role_name) + self.assertTrue(conf.enable_imds_v2) + self.assertFalse(conf.disable_imds_v1) + self.assertEqual(3600, conf.metadata_token_duration) + self.assertEqual('credentials_uri', conf.credentials_uri) + self.assertEqual('host', conf.host) + self.assertEqual(5000, conf.timeout) + self.assertEqual(10000, conf.connect_timeout) + self.assertEqual('proxy', conf.proxy) + + def test_credential_model_default_values(self): + cred = CredentialModel() + self.assertIsNone(cred.access_key_id) + self.assertIsNone(cred.access_key_secret) + self.assertIsNone(cred.security_token) + self.assertIsNone(cred.bearer_token) + self.assertIsNone(cred.type) + + def test_credential_model_custom_values(self): + cred = CredentialModel( + access_key_id='access_key_id', + access_key_secret='access_key_secret', + security_token='security_token', + bearer_token='bearer_token', + type='type', + provider_name='provider_name', + ) + self.assertEqual('access_key_id', cred.access_key_id) + self.assertEqual('access_key_secret', cred.access_key_secret) + self.assertEqual('security_token', cred.security_token) + self.assertEqual('bearer_token', cred.bearer_token) + self.assertEqual('type', cred.type) + self.assertEqual('provider_name', cred.provider_name) + + def test_credential_model_to_map(self): + cred = CredentialModel( + access_key_id='access_key_id', + access_key_secret='access_key_secret', + security_token='security_token', + bearer_token='bearer_token', + type='type', + provider_name='provider_name', + ) + cred_map = cred.to_map() + self.assertEqual('access_key_id', cred_map['accessKeyId']) + self.assertEqual('access_key_secret', cred_map['accessKeySecret']) + self.assertEqual('security_token', cred_map['securityToken']) + self.assertEqual('bearer_token', cred_map['bearerToken']) + self.assertEqual('type', cred_map['type']) + self.assertEqual('provider_name', cred_map['providerName']) + + def test_credential_model_from_map(self): + cred_map = { + 'accessKeyId': 'access_key_id', + 'accessKeySecret': 'access_key_secret', + 'securityToken': 'security_token', + 'bearerToken': 'bearer_token', + 'type': 'type', + 'providerName': 'provider_name', + } + cred = CredentialModel().from_map(cred_map) + self.assertEqual('access_key_id', cred.access_key_id) + self.assertEqual('access_key_secret', cred.access_key_secret) + self.assertEqual('security_token', cred.security_token) + self.assertEqual('bearer_token', cred.bearer_token) + self.assertEqual('type', cred.type) + self.assertEqual('provider_name', cred.provider_name) + + def test_credential_model_getters(self): + cred = CredentialModel( + access_key_id='access_key_id', + access_key_secret='access_key_secret', + security_token='security_token', + bearer_token='bearer_token', + type='type', + provider_name='provider_name', + ) + self.assertEqual('access_key_id', cred.get_access_key_id()) + self.assertEqual('access_key_secret', cred.get_access_key_secret()) + self.assertEqual('security_token', cred.get_security_token()) + self.assertEqual('bearer_token', cred.get_bearer_token()) + self.assertEqual('type', cred.get_type()) + self.assertEqual('provider_name', cred.get_provider_name()) \ No newline at end of file diff --git a/tests/test_providers.py b/tests/test_providers.py index 4fbc6c2..1323ab7 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -101,7 +101,7 @@ def test_EcsRamRoleCredentialProvider(self): prov = providers.EcsRamRoleCredentialProvider(config=cfg) self.assertIsNotNone(prov) self.assertEqual("roleNameConfig", prov.role_name) - self.assertEqual(2300, prov.timeout) + self.assertEqual(1100, prov.timeout) token = prov._get_metadata_token(url='127.0.0.1:8888') self.assertEqual('token', token) @@ -131,7 +131,7 @@ def test_EcsRamRoleCredentialProvider(self): self.assertIsNotNone(prov) self.assertTrue(prov.disable_imds_v1) self.assertEqual("roleNameConfig", prov.role_name) - self.assertEqual(2300, prov.timeout) + self.assertEqual(1100, prov.timeout) prov._get_metadata_token(url='127.0.0.1:8888') cred = prov._create_credential(url='127.0.0.1:8888') self.assertEqual('ak', cred.access_key_id) @@ -170,7 +170,7 @@ async def main(): prov = providers.EcsRamRoleCredentialProvider(config=cfg) self.assertIsNotNone(prov) self.assertEqual("roleNameConfig", prov.role_name) - self.assertEqual(2300, prov.timeout) + self.assertEqual(1100, prov.timeout) cred = await prov._create_credential_async(url='127.0.0.1:8888') self.assertEqual('ak', cred.access_key_id)