Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/apps/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"cryptography>=3.4.0",
"pyjwt[crypto]>=2.10.0",
"dependency-injector>=4.48.1",
"msal>=1.33.0",
]

[project.optional-dependencies]
Expand Down
2 changes: 0 additions & 2 deletions packages/apps/src/microsoft/teams/apps/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,8 @@ def __init__(self, **options: Unpack[AppOptions]):
self.credentials = self._init_credentials()

self._token_manager = TokenManager(
http_client=self.http_client,
credentials=self.credentials,
logger=self.log,
default_connection_name=self.options.default_connection_name,
)

self.container = Container()
Expand Down
121 changes: 68 additions & 53 deletions packages/apps/src/microsoft/teams/apps/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,52 @@
Licensed under the MIT License.
"""

import asyncio
import logging
from typing import Optional
from inspect import isawaitable
from typing import Any, Optional

from microsoft.teams.api import (
BotTokenClient,
ClientCredentials,
Credentials,
JsonWebToken,
TokenProtocol,
)
from microsoft.teams.common import Client, ConsoleLogger, LocalStorage, LocalStorageOptions
from microsoft.teams.api.auth.credentials import TokenCredentials
from microsoft.teams.common import ConsoleLogger
from msal import ConfidentialClientApplication # pyright: ignore[reportMissingTypeStubs]

BOT_TOKEN_SCOPE = "https://api.botframework.com/.default"
GRAPH_TOKEN_SCOPE = "https://graph.microsoft.com/.default"
DEFAULT_TENANT_FOR_BOT_TOKEN = "botframework.com"
DEFAULT_TENANT_FOR_GRAPH_TOKEN = "common"
DEFAULT_TOKEN_AUTHORITY = "https://login.microsoftonline.com/{tenant_id}"


class TokenManager:
"""Manages authentication tokens for the Teams application."""

def __init__(
self,
http_client: Client,
credentials: Optional[Credentials],
logger: Optional[logging.Logger] = None,
default_connection_name: Optional[str] = None,
):
self._bot_token_client = BotTokenClient(http_client.clone())
self._credentials = credentials
self._default_connection_name = default_connection_name

if not logger:
self._logger = ConsoleLogger().create_logger("TokenManager")
else:
self._logger = logger.getChild("TokenManager")

self._bot_token: Optional[TokenProtocol] = None

# Key: tenant_id (empty string "" for default app graph token)
self._graph_tokens: LocalStorage[TokenProtocol] = LocalStorage({}, LocalStorageOptions(max=20000))
self._msal_clients_by_tenantId: dict[str, ConfidentialClientApplication] = {}

async def get_bot_token(self, force: bool = False) -> Optional[TokenProtocol]:
async def get_bot_token(self) -> Optional[TokenProtocol]:
"""Refresh the bot authentication token."""
if not self._credentials:
self._logger.warning("No credentials provided, skipping bot token refresh")
return None

if not force and self._bot_token and not self._bot_token.is_expired():
return self._bot_token

if self._bot_token:
self._logger.debug("Refreshing bot token")
else:
self._logger.debug("Retrieving bot token")

token_response = await self._bot_token_client.get(self._credentials)
self._bot_token = JsonWebToken(token_response.access_token)
self._logger.debug("Bot token refreshed successfully")
return self._bot_token
return await self._get_token(
BOT_TOKEN_SCOPE, tenant_id=self._resolve_tenant_id(None, DEFAULT_TENANT_FOR_BOT_TOKEN)
)

async def get_graph_token(self, tenant_id: Optional[str] = None, force: bool = False) -> Optional[TokenProtocol]:
async def get_graph_token(self, tenant_id: Optional[str] = None) -> Optional[TokenProtocol]:
"""
Get or refresh a Graph API token.

Expand All @@ -70,29 +59,55 @@ async def get_graph_token(self, tenant_id: Optional[str] = None, force: bool = F
Returns:
The graph token or None if not available
"""
if not self._credentials:
self._logger.debug("No credentials provided for graph token refresh")
return await self._get_token(
GRAPH_TOKEN_SCOPE, tenant_id=self._resolve_tenant_id(tenant_id, DEFAULT_TENANT_FOR_GRAPH_TOKEN)
)

async def _get_token(
self, scope: str | list[str], tenant_id: str, *, caller_name: str | None = None
) -> Optional[TokenProtocol]:
credentials = self._credentials
if self._credentials is None:
if caller_name:
self._logger.debug(f"No credentials provided for {caller_name}")
return None

# Use empty string as key for default graph token
key = tenant_id or ""

cached = self._graph_tokens.get(key)
if not force and cached and not cached.is_expired():
return cached

creds = self._credentials
if tenant_id and isinstance(self._credentials, ClientCredentials):
creds = ClientCredentials(
client_id=self._credentials.client_id,
client_secret=self._credentials.client_secret,
tenant_id=tenant_id,
if isinstance(credentials, ClientCredentials):
msal_client = self._get_msal_client_for_tenant(tenant_id)
token_res: dict[str, Any] | None = await asyncio.to_thread(
lambda: msal_client.acquire_token_for_client(scope if isinstance(scope, list) else [scope])
)

response = await self._bot_token_client.get_graph(creds)
token = JsonWebToken(response.access_token)
self._graph_tokens.set(key, token)

self._logger.debug(f"Refreshed graph token tenant_id={tenant_id}")

return token
if token_res.get("access_token", None):
access_token = token_res["access_token"]
return JsonWebToken(access_token)
else:
self._logger.debug(f"TokenRes: {token_res}")
error = token_res.get("error", ValueError("Error retrieving token"))
error_description = token_res.get("error_description", "Error retrieving token from MSAL")
self._logger.error(error_description)
raise error
elif isinstance(credentials, TokenCredentials):
token = credentials.token(scope, tenant_id)
if isawaitable(token):
access_token = await token
else:
access_token = token

return JsonWebToken(access_token)

def _get_msal_client_for_tenant(self, tenant_id: str) -> ConfidentialClientApplication:
credentials = self._credentials
assert isinstance(credentials, ClientCredentials), (
f"MSAL clients are only eligible for client credentials,but current credentials is {type(credentials)}"
)
cached_client = self._msal_clients_by_tenantId.setdefault(
tenant_id,
ConfidentialClientApplication(
credentials.client_id,
client_credential=credentials.client_secret if credentials else None,
authority=DEFAULT_TOKEN_AUTHORITY.format(tenant_id=tenant_id),
),
)
return cached_client

def _resolve_tenant_id(self, tenant_id: str | None, default_tenant_id: str):
return tenant_id or (self._credentials.tenant_id if self._credentials else False) or default_tenant_id
Loading