33Licensed under the MIT License.
44"""
55
6- from typing import Optional
6+ from typing import Literal , cast
77from unittest .mock import MagicMock , create_autospec , patch
88
99import pytest
10- from microsoft .teams .api import ClientCredentials , JsonWebToken , ManagedIdentityCredentials
11- from microsoft .teams .api .auth .credentials import FederatedIdentityCredentials
10+ from microsoft .teams .api import (
11+ ClientCredentials ,
12+ FederatedIdentityCredentials ,
13+ JsonWebToken ,
14+ ManagedIdentityCredentials ,
15+ )
1216from microsoft .teams .apps .token_manager import TokenManager
1317from msal import ManagedIdentityClient # pyright: ignore[reportMissingTypeStubs]
1418
@@ -130,7 +134,9 @@ async def test_get_token_with_managed_identity(self, get_token_method: str, expe
130134
131135 manager = TokenManager (credentials = mock_credentials )
132136
137+ # Patch _get_managed_identity_client to return our mock
133138 with patch .object (manager , "_get_managed_identity_client" , return_value = mock_msal_client ):
139+ # Call the method dynamically
134140 token = await getattr (manager , get_token_method )()
135141
136142 assert token is not None
@@ -141,41 +147,6 @@ async def test_get_token_with_managed_identity(self, get_token_method: str, expe
141147 # and without /.default suffix
142148 mock_msal_client .acquire_token_for_client .assert_called_once_with (resource = expected_resource )
143149
144- @pytest .mark .asyncio
145- async def test_get_graph_token_with_managed_identity_and_tenant (self ):
146- """Test getting tenant-specific graph token with ManagedIdentityCredentials."""
147- mock_credentials = ManagedIdentityCredentials (
148- client_id = "test-managed-identity-client-id" ,
149- tenant_id = "original-tenant-id" ,
150- )
151-
152- # Create a mock that will pass isinstance checks
153- mock_msal_client = create_autospec (ManagedIdentityClient , instance = True )
154- mock_msal_client .acquire_token_for_client .return_value = {"access_token" : VALID_TEST_TOKEN }
155-
156- manager = TokenManager (credentials = mock_credentials )
157-
158- # Track calls to _get_managed_identity_client
159- get_managed_identity_client_calls : list [str ] = []
160-
161- def track_get_managed_identity_client (
162- credentials : ManagedIdentityCredentials | FederatedIdentityCredentials , tenant_id : Optional [str ] = None
163- ) -> ManagedIdentityClient :
164- if tenant_id :
165- get_managed_identity_client_calls .append (tenant_id )
166- return mock_msal_client
167-
168- # Patch _get_managed_identity_client to track calls
169- with patch .object (manager , "_get_managed_identity_client" , side_effect = track_get_managed_identity_client ):
170- # Request token for different tenant
171- token = await manager .get_graph_token ("different-tenant-id" )
172-
173- assert token is not None
174- assert isinstance (token , JsonWebToken )
175-
176- # Note: ManagedIdentityClient is tenant-agnostic and cached, so it won't be called again
177- assert len (get_managed_identity_client_calls ) >= 0
178-
179150 @pytest .mark .asyncio
180151 async def test_get_token_error_handling_with_managed_identity (self ):
181152 """Test error handling when token acquisition fails with ManagedIdentityCredentials."""
@@ -200,3 +171,92 @@ async def test_get_token_error_handling_with_managed_identity(self):
200171 await manager .get_bot_token ()
201172
202173 assert "invalid_client" in str (exc_info .value )
174+
175+ @pytest .mark .asyncio
176+ @pytest .mark .parametrize (
177+ "mi_type,mi_client_id,description" ,
178+ [
179+ ("system" , None , "system-assigned managed identity" ),
180+ ("user" , "test-user-mi-client-id" , "user-assigned managed identity" ),
181+ ],
182+ )
183+ async def test_get_token_with_federated_identity (self , mi_type : str , mi_client_id : str | None , description : str ):
184+ """Test token retrieval using FederatedIdentityCredentials (two-step flow)."""
185+ mock_credentials = FederatedIdentityCredentials (
186+ client_id = "test-app-client-id" ,
187+ managed_identity_type = cast (Literal ["system" , "user" ], mi_type ),
188+ managed_identity_client_id = mi_client_id ,
189+ tenant_id = "test-tenant-id" ,
190+ )
191+
192+ manager = TokenManager (credentials = mock_credentials )
193+
194+ # Mock the managed identity token acquisition (step 1)
195+ mi_token = "mi_token_from_step_1"
196+ with patch .object (manager , "_acquire_managed_identity_token" , return_value = mi_token ):
197+ # Mock ConfidentialClientApplication for step 2
198+ with patch ("microsoft.teams.apps.token_manager.ConfidentialClientApplication" ) as mock_confidential_app :
199+ mock_app_instance = MagicMock ()
200+ mock_app_instance .acquire_token_for_client .return_value = {"access_token" : VALID_TEST_TOKEN }
201+ mock_confidential_app .return_value = mock_app_instance
202+
203+ token = await manager .get_bot_token ()
204+
205+ assert token is not None , f"Failed for: { description } "
206+ assert isinstance (token , JsonWebToken ), f"Failed for: { description } "
207+ assert str (token ) == VALID_TEST_TOKEN , f"Failed for: { description } "
208+
209+ # Verify ConfidentialClientApplication was called with MI token as client_assertion
210+ mock_confidential_app .assert_called_once ()
211+ call_kwargs = mock_confidential_app .call_args [1 ]
212+ assert call_kwargs ["client_credential" ] == {"client_assertion" : mi_token }, f"Failed for: { description } "
213+
214+ @pytest .mark .asyncio
215+ async def test_get_token_with_federated_identity_step1_failure (self ):
216+ """Test error handling when step 1 (MI token acquisition) fails."""
217+ mock_credentials = FederatedIdentityCredentials (
218+ client_id = "test-app-client-id" ,
219+ managed_identity_type = "user" ,
220+ managed_identity_client_id = "test-mi-client-id" ,
221+ tenant_id = "test-tenant-id" ,
222+ )
223+
224+ manager = TokenManager (credentials = mock_credentials )
225+
226+ # Mock step 1 to fail
227+ with patch .object (
228+ manager , "_acquire_managed_identity_token" , side_effect = ValueError ("MI token acquisition failed" )
229+ ):
230+ with pytest .raises (ValueError ) as exc_info :
231+ await manager .get_bot_token ()
232+
233+ assert "MI token acquisition failed" in str (exc_info .value )
234+
235+ @pytest .mark .asyncio
236+ async def test_get_token_with_federated_identity_step2_failure (self ):
237+ """Test error handling when step 2 (final token acquisition) fails."""
238+ mock_credentials = FederatedIdentityCredentials (
239+ client_id = "test-app-client-id" ,
240+ managed_identity_type = "user" ,
241+ managed_identity_client_id = "test-mi-client-id" ,
242+ tenant_id = "test-tenant-id" ,
243+ )
244+
245+ manager = TokenManager (credentials = mock_credentials )
246+
247+ # Mock step 1 to succeed
248+ mi_token = "mi_token_from_step_1"
249+ with patch .object (manager , "_acquire_managed_identity_token" , return_value = mi_token ):
250+ # Mock step 2 to fail
251+ with patch ("microsoft.teams.apps.token_manager.ConfidentialClientApplication" ) as mock_confidential_app :
252+ mock_app_instance = MagicMock ()
253+ mock_app_instance .acquire_token_for_client .return_value = {
254+ "error" : "invalid_grant" ,
255+ "error_description" : "FIC Step 2 failed" ,
256+ }
257+ mock_confidential_app .return_value = mock_app_instance
258+
259+ with pytest .raises (ValueError ) as exc_info :
260+ await manager .get_bot_token ()
261+
262+ assert "invalid_grant" in str (exc_info .value )
0 commit comments