Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/apps/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"cryptography>=3.4.0",
"pyjwt[crypto]>=2.10.0",
"dependency-injector>=4.48.1",
"msal>=1.33.0",
]

[project.optional-dependencies]
Expand Down
1 change: 0 additions & 1 deletion packages/apps/src/microsoft/teams/apps/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(self, **options: Unpack[AppOptions]):
self.credentials = self._init_credentials()

self._token_manager = TokenManager(
http_client=self.http_client,
credentials=self.credentials,
logger=self.log,
default_connection_name=self.options.default_connection_name,
Expand Down
110 changes: 59 additions & 51 deletions packages/apps/src/microsoft/teams/apps/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,31 @@
Licensed under the MIT License.
"""

import asyncio
import logging
from typing import Optional
from inspect import isawaitable
from typing import Any, Optional, reveal_type

from microsoft.teams.api import (
BotTokenClient,
ClientCredentials,
Credentials,
JsonWebToken,
TokenProtocol,
)
from microsoft.teams.common import Client, ConsoleLogger, LocalStorage, LocalStorageOptions
from microsoft.teams.api.auth.credentials import TokenCredentials
from microsoft.teams.common import ConsoleLogger
from msal import ConfidentialClientApplication # pyright: ignore[reportMissingTypeStubs]


class TokenManager:
"""Manages authentication tokens for the Teams application."""

def __init__(
self,
http_client: Client,
credentials: Optional[Credentials],
logger: Optional[logging.Logger] = None,
default_connection_name: Optional[str] = None,
):
self._bot_token_client = BotTokenClient(http_client.clone())
self._credentials = credentials
self._default_connection_name = default_connection_name

Expand All @@ -35,31 +36,13 @@ def __init__(
else:
self._logger = logger.getChild("TokenManager")

self._bot_token: Optional[TokenProtocol] = None
self._msal_clients_by_tenantId: dict[str, ConfidentialClientApplication] = {}

# Key: tenant_id (empty string "" for default app graph token)
self._graph_tokens: LocalStorage[TokenProtocol] = LocalStorage({}, LocalStorageOptions(max=20000))

async def get_bot_token(self, force: bool = False) -> Optional[TokenProtocol]:
async def get_bot_token(self) -> Optional[TokenProtocol]:
"""Refresh the bot authentication token."""
if not self._credentials:
self._logger.warning("No credentials provided, skipping bot token refresh")
return None

if not force and self._bot_token and not self._bot_token.is_expired():
return self._bot_token

if self._bot_token:
self._logger.debug("Refreshing bot token")
else:
self._logger.debug("Retrieving bot token")

token_response = await self._bot_token_client.get(self._credentials)
self._bot_token = JsonWebToken(token_response.access_token)
self._logger.debug("Bot token refreshed successfully")
return self._bot_token
return await self._get_token("https://api.botframework.com/.default")

async def get_graph_token(self, tenant_id: Optional[str] = None, force: bool = False) -> Optional[TokenProtocol]:
async def get_graph_token(self, tenant_id: Optional[str] = None) -> Optional[TokenProtocol]:
"""
Get or refresh a Graph API token.

Expand All @@ -70,29 +53,54 @@ async def get_graph_token(self, tenant_id: Optional[str] = None, force: bool = F
Returns:
The graph token or None if not available
"""
if not self._credentials:
self._logger.debug("No credentials provided for graph token refresh")
return await self._get_token("https://graph.microsoft.com/.default", tenant_id)

async def _get_token(
self, scope: str | list[str], tenant_id: str | None = None, *, caller_name: str | None = None
) -> Optional[TokenProtocol]:
credentials = self._credentials
if self._credentials is None:
if caller_name:
self._logger.debug(f"No credentials provided for {caller_name}")
return None

# Use empty string as key for default graph token
key = tenant_id or ""

cached = self._graph_tokens.get(key)
if not force and cached and not cached.is_expired():
return cached

creds = self._credentials
if tenant_id and isinstance(self._credentials, ClientCredentials):
creds = ClientCredentials(
client_id=self._credentials.client_id,
client_secret=self._credentials.client_secret,
tenant_id=tenant_id,
if isinstance(credentials, ClientCredentials):
tenant_id_param = tenant_id or credentials.tenant_id or "botframework.com"
msal_client = self._get_msal_client_for_tenant(tenant_id_param)
token_res: dict[str, Any] | None = await asyncio.to_thread(
lambda: msal_client.acquire_token_for_client(scope if isinstance(scope, list) else [scope])
)

response = await self._bot_token_client.get_graph(creds)
token = JsonWebToken(response.access_token)
self._graph_tokens.set(key, token)

self._logger.debug(f"Refreshed graph token tenant_id={tenant_id}")

return token
if token_res.get("access_token", None):
access_token = token_res["access_token"]
return JsonWebToken(access_token)
else:
self._logger.debug(f"TokenRes: {token_res}")
error = token_res.get("error", ValueError("Error retrieving token"))
error_description = token_res.get("error_description", "Error retrieving token from MSAL")
self._logger.error(error_description)
raise error
elif isinstance(credentials, TokenCredentials):
tenant_id_param = tenant_id or credentials.tenant_id or "botframework.com"
tenant_id = tenant_id or "botframework.com"
token = credentials.token(scope, tenant_id)
if isawaitable(token):
access_token = await token
else:
access_token = token

return JsonWebToken(access_token)

def _get_msal_client_for_tenant(self, tenant_id: str) -> ConfidentialClientApplication:
credentials = self._credentials
assert isinstance(credentials, ClientCredentials), (
"MSAL clients are only eligible for client credentials,"
f"but current credentials is {reveal_type(credentials)}"
)
cached_client = self._msal_clients_by_tenantId.setdefault(
tenant_id,
ConfidentialClientApplication(
credentials.client_id,
client_credential=credentials.client_secret if credentials else None,
authority=f"https://login.microsoftonline.com/{tenant_id}",
),
)
return cached_client
Loading