33Licensed under the MIT License.
44"""
55
6- from typing import Optional
6+ from typing import Literal , cast
77from unittest .mock import MagicMock , create_autospec , patch
88
99import pytest
10- from microsoft .teams .api import ClientCredentials , JsonWebToken , ManagedIdentityCredentials
11- from microsoft .teams .api .auth .credentials import FederatedIdentityCredentials
10+ from microsoft .teams .api import (
11+ ClientCredentials ,
12+ FederatedIdentityCredentials ,
13+ JsonWebToken ,
14+ ManagedIdentityCredentials ,
15+ )
1216from microsoft .teams .apps .token_manager import TokenManager
1317from msal import ManagedIdentityClient # pyright: ignore[reportMissingTypeStubs]
1418
@@ -208,7 +212,9 @@ async def test_get_token_with_managed_identity(self, get_token_method: str, expe
208212
209213 manager = TokenManager (credentials = mock_credentials )
210214
215+ # Patch _get_managed_identity_client to return our mock
211216 with patch .object (manager , "_get_managed_identity_client" , return_value = mock_msal_client ):
217+ # Call the method dynamically
212218 token = await getattr (manager , get_token_method )()
213219
214220 assert token is not None
@@ -219,41 +225,6 @@ async def test_get_token_with_managed_identity(self, get_token_method: str, expe
219225 # and without /.default suffix
220226 mock_msal_client .acquire_token_for_client .assert_called_once_with (resource = expected_resource )
221227
222- @pytest .mark .asyncio
223- async def test_get_graph_token_with_managed_identity_and_tenant (self ):
224- """Test getting tenant-specific graph token with ManagedIdentityCredentials."""
225- mock_credentials = ManagedIdentityCredentials (
226- client_id = "test-managed-identity-client-id" ,
227- tenant_id = "original-tenant-id" ,
228- )
229-
230- # Create a mock that will pass isinstance checks
231- mock_msal_client = create_autospec (ManagedIdentityClient , instance = True )
232- mock_msal_client .acquire_token_for_client .return_value = {"access_token" : VALID_TEST_TOKEN }
233-
234- manager = TokenManager (credentials = mock_credentials )
235-
236- # Track calls to _get_managed_identity_client
237- get_managed_identity_client_calls : list [str ] = []
238-
239- def track_get_managed_identity_client (
240- credentials : ManagedIdentityCredentials | FederatedIdentityCredentials , tenant_id : Optional [str ] = None
241- ) -> ManagedIdentityClient :
242- if tenant_id :
243- get_managed_identity_client_calls .append (tenant_id )
244- return mock_msal_client
245-
246- # Patch _get_managed_identity_client to track calls
247- with patch .object (manager , "_get_managed_identity_client" , side_effect = track_get_managed_identity_client ):
248- # Request token for different tenant
249- token = await manager .get_graph_token ("different-tenant-id" )
250-
251- assert token is not None
252- assert isinstance (token , JsonWebToken )
253-
254- # Note: ManagedIdentityClient is tenant-agnostic and cached, so it won't be called again
255- assert len (get_managed_identity_client_calls ) >= 0
256-
257228 @pytest .mark .asyncio
258229 async def test_get_token_error_handling_with_managed_identity (self ):
259230 """Test error handling when token acquisition fails with ManagedIdentityCredentials."""
@@ -278,3 +249,92 @@ async def test_get_token_error_handling_with_managed_identity(self):
278249 await manager .get_bot_token ()
279250
280251 assert "invalid_client" in str (exc_info .value )
252+
253+ @pytest .mark .asyncio
254+ @pytest .mark .parametrize (
255+ "mi_type,mi_client_id,description" ,
256+ [
257+ ("system" , None , "system-assigned managed identity" ),
258+ ("user" , "test-user-mi-client-id" , "user-assigned managed identity" ),
259+ ],
260+ )
261+ async def test_get_token_with_federated_identity (self , mi_type : str , mi_client_id : str | None , description : str ):
262+ """Test token retrieval using FederatedIdentityCredentials (two-step flow)."""
263+ mock_credentials = FederatedIdentityCredentials (
264+ client_id = "test-app-client-id" ,
265+ managed_identity_type = cast (Literal ["system" , "user" ], mi_type ),
266+ managed_identity_client_id = mi_client_id ,
267+ tenant_id = "test-tenant-id" ,
268+ )
269+
270+ manager = TokenManager (credentials = mock_credentials )
271+
272+ # Mock the managed identity token acquisition (step 1)
273+ mi_token = "mi_token_from_step_1"
274+ with patch .object (manager , "_acquire_managed_identity_token" , return_value = mi_token ):
275+ # Mock ConfidentialClientApplication for step 2
276+ with patch ("microsoft.teams.apps.token_manager.ConfidentialClientApplication" ) as mock_confidential_app :
277+ mock_app_instance = MagicMock ()
278+ mock_app_instance .acquire_token_for_client .return_value = {"access_token" : VALID_TEST_TOKEN }
279+ mock_confidential_app .return_value = mock_app_instance
280+
281+ token = await manager .get_bot_token ()
282+
283+ assert token is not None , f"Failed for: { description } "
284+ assert isinstance (token , JsonWebToken ), f"Failed for: { description } "
285+ assert str (token ) == VALID_TEST_TOKEN , f"Failed for: { description } "
286+
287+ # Verify ConfidentialClientApplication was called with MI token as client_assertion
288+ mock_confidential_app .assert_called_once ()
289+ call_kwargs = mock_confidential_app .call_args [1 ]
290+ assert call_kwargs ["client_credential" ] == {"client_assertion" : mi_token }, f"Failed for: { description } "
291+
292+ @pytest .mark .asyncio
293+ async def test_get_token_with_federated_identity_step1_failure (self ):
294+ """Test error handling when step 1 (MI token acquisition) fails."""
295+ mock_credentials = FederatedIdentityCredentials (
296+ client_id = "test-app-client-id" ,
297+ managed_identity_type = "user" ,
298+ managed_identity_client_id = "test-mi-client-id" ,
299+ tenant_id = "test-tenant-id" ,
300+ )
301+
302+ manager = TokenManager (credentials = mock_credentials )
303+
304+ # Mock step 1 to fail
305+ with patch .object (
306+ manager , "_acquire_managed_identity_token" , side_effect = ValueError ("MI token acquisition failed" )
307+ ):
308+ with pytest .raises (ValueError ) as exc_info :
309+ await manager .get_bot_token ()
310+
311+ assert "MI token acquisition failed" in str (exc_info .value )
312+
313+ @pytest .mark .asyncio
314+ async def test_get_token_with_federated_identity_step2_failure (self ):
315+ """Test error handling when step 2 (final token acquisition) fails."""
316+ mock_credentials = FederatedIdentityCredentials (
317+ client_id = "test-app-client-id" ,
318+ managed_identity_type = "user" ,
319+ managed_identity_client_id = "test-mi-client-id" ,
320+ tenant_id = "test-tenant-id" ,
321+ )
322+
323+ manager = TokenManager (credentials = mock_credentials )
324+
325+ # Mock step 1 to succeed
326+ mi_token = "mi_token_from_step_1"
327+ with patch .object (manager , "_acquire_managed_identity_token" , return_value = mi_token ):
328+ # Mock step 2 to fail
329+ with patch ("microsoft.teams.apps.token_manager.ConfidentialClientApplication" ) as mock_confidential_app :
330+ mock_app_instance = MagicMock ()
331+ mock_app_instance .acquire_token_for_client .return_value = {
332+ "error" : "invalid_grant" ,
333+ "error_description" : "FIC Step 2 failed" ,
334+ }
335+ mock_confidential_app .return_value = mock_app_instance
336+
337+ with pytest .raises (ValueError ) as exc_info :
338+ await manager .get_bot_token ()
339+
340+ assert "invalid_grant" in str (exc_info .value )
0 commit comments