Skip to content

Commit ab357dc

Browse files
committed
Add Auth service to RetryingClientSession class
1 parent f769555 commit ab357dc

File tree

4 files changed

+180
-25
lines changed

4 files changed

+180
-25
lines changed

reportportal_client/_internal/aio/http.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@
2929
from aenum import Enum
3030
from aiohttp import ClientResponse, ClientResponseError, ClientSession, ServerConnectionError
3131

32+
from reportportal_client._internal.services.auth import AuthAsync
33+
3234
DEFAULT_RETRY_NUMBER: int = 5
3335
DEFAULT_RETRY_DELAY: float = 0.005
3436
THROTTLING_STATUSES: set = {425, 429}
3537
RETRY_STATUSES: set = {408, 500, 502, 503, 507}.union(THROTTLING_STATUSES)
38+
AUTH_PROBLEM_STATUSES: set = {401, 403}
3639

3740

3841
class RetryClass(int, Enum):
@@ -49,28 +52,32 @@ class RetryingClientSession:
4952
_client: ClientSession
5053
__retry_number: int
5154
__retry_delay: float
55+
__auth: Optional[AuthAsync]
5256

5357
def __init__(
5458
self,
5559
*args,
5660
max_retry_number: int = DEFAULT_RETRY_NUMBER,
5761
base_retry_delay: float = DEFAULT_RETRY_DELAY,
62+
auth: Optional[AuthAsync] = None,
5863
**kwargs,
5964
):
6065
"""Initialize an instance of the session with arguments.
6166
6267
To obtain the full list of arguments please see aiohttp.ClientSession.__init__() method. This class
63-
just bypass everything to the base method, except two local arguments 'max_retry_number' and
64-
'base_retry_delay'.
68+
just bypass everything to the base method, except three local arguments 'max_retry_number',
69+
'base_retry_delay', and 'auth'.
6570
6671
:param max_retry_number: the maximum number of the request retries if it was unsuccessful
6772
:param base_retry_delay: base value for retry delay, determine how much time the class will wait after
6873
an error. Real value highly depends on Retry Class and Retry attempt number,
6974
since retries are performed in exponential delay manner
75+
:param auth: authentication instance to use for requests
7076
"""
7177
self._client = ClientSession(*args, **kwargs)
7278
self.__retry_number = max_retry_number
7379
self.__retry_delay = base_retry_delay
80+
self.__auth = auth
7481

7582
async def __nothing(self):
7683
pass
@@ -89,12 +96,24 @@ async def __request(self, method: Callable, url, **kwargs: Any) -> ClientRespons
8996
400 Bad Request it just returns result, for cases where it's reasonable to retry it does it in
9097
exponential manner.
9198
"""
99+
# Clone kwargs and add Authorization header if auth is configured
100+
request_kwargs = kwargs.copy()
101+
if self.__auth:
102+
auth_header = await self.__auth.get()
103+
if auth_header:
104+
if "headers" not in request_kwargs:
105+
request_kwargs["headers"] = {}
106+
else:
107+
request_kwargs["headers"] = request_kwargs["headers"].copy()
108+
request_kwargs["headers"]["Authorization"] = auth_header
109+
92110
result = None
93111
exceptions = []
112+
94113
for i in range(self.__retry_number + 1): # add one for the first attempt, which is not a retry
95114
retry_factor = None
96115
try:
97-
result = await method(url, **kwargs)
116+
result = await method(url, **request_kwargs)
98117
except Exception as exc:
99118
exceptions.append(exc)
100119
if isinstance(exc, ServerConnectionError) or isinstance(exc, ClientResponseError):
@@ -104,6 +123,15 @@ async def __request(self, method: Callable, url, **kwargs: Any) -> ClientRespons
104123
raise exc
105124

106125
if result:
126+
# Check for authentication errors first
127+
if result.status in AUTH_PROBLEM_STATUSES and self.__auth:
128+
refreshed_header = await self.__auth.refresh()
129+
if refreshed_header:
130+
# Retry with new auth header
131+
request_kwargs["headers"] = request_kwargs.get("headers", {}).copy()
132+
request_kwargs["headers"]["Authorization"] = refreshed_header
133+
result = await method(url, **request_kwargs)
134+
107135
if result.ok or result.status not in RETRY_STATUSES:
108136
return result
109137

reportportal_client/_internal/services/auth.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929

3030
# noinspection PyAbstractClass
3131
class Auth(metaclass=AbstractBaseClass):
32-
"""Abstract base class for authentication.
32+
"""Abstract base class for synchronous authentication.
3333
34-
This class defines the interface for all authentication methods.
34+
This class defines the interface for all synchronous authentication methods.
3535
"""
3636

3737
__metaclass__ = AbstractBaseClass
@@ -53,6 +53,32 @@ def refresh(self) -> Optional[str]:
5353
raise NotImplementedError('"refresh" method is not implemented!')
5454

5555

56+
# noinspection PyAbstractClass
57+
class AuthAsync(metaclass=AbstractBaseClass):
58+
"""Abstract base class for asynchronous authentication.
59+
60+
This class defines the interface for all asynchronous authentication methods.
61+
"""
62+
63+
__metaclass__ = AbstractBaseClass
64+
65+
@abstractmethod
66+
async def get(self) -> Optional[str]:
67+
"""Get valid Authorization header value.
68+
69+
:return: Authorization header value or None if authentication failed.
70+
"""
71+
raise NotImplementedError('"get" method is not implemented!')
72+
73+
@abstractmethod
74+
async def refresh(self) -> Optional[str]:
75+
"""Refresh the access token and return Authorization header value.
76+
77+
:return: Authorization header value or None if refresh failed.
78+
"""
79+
raise NotImplementedError('"refresh" method is not implemented!')
80+
81+
5682
class ApiKeyAuthSync(Auth):
5783
"""Synchronous API key authentication.
5884
@@ -86,7 +112,7 @@ def refresh(self) -> None:
86112
return None
87113

88114

89-
class ApiKeyAuthAsync(Auth):
115+
class ApiKeyAuthAsync(AuthAsync):
90116
"""Asynchronous API key authentication.
91117
92118
This class provides simple key-based authentication that always returns
@@ -120,15 +146,14 @@ async def refresh(self) -> None:
120146

121147

122148
# noinspection PyAbstractClass
123-
class OAuthPasswordGrant(Auth):
124-
"""Abstract base class for OAuth 2.0 password grant authentication.
149+
class OAuthPasswordGrant:
150+
"""Base class for OAuth 2.0 password grant authentication.
125151
126152
This class provides common logic for obtaining and refreshing access tokens using
127-
the OAuth 2.0 password grant flow.
153+
the OAuth 2.0 password grant flow. This class should not be used directly, use
154+
OAuthPasswordGrantSync or OAuthPasswordGrantAsync instead.
128155
"""
129156

130-
__metaclass__ = AbstractBaseClass
131-
132157
oauth_uri: str
133158
username: str
134159
password: str
@@ -244,7 +269,7 @@ def _build_token_request_data(self, grant_type: str, **extra_params) -> dict:
244269
return data
245270

246271

247-
class OAuthPasswordGrantSync(OAuthPasswordGrant):
272+
class OAuthPasswordGrantSync(OAuthPasswordGrant, Auth):
248273
"""Synchronous implementation of OAuth 2.0 password grant authentication."""
249274

250275
_session: Optional[requests.Session]
@@ -370,7 +395,7 @@ def close(self) -> None:
370395
self._session.close()
371396

372397

373-
class OAuthPasswordGrantAsync(OAuthPasswordGrant):
398+
class OAuthPasswordGrantAsync(OAuthPasswordGrant, AuthAsync):
374399
"""Asynchronous implementation of OAuth 2.0 password grant authentication."""
375400

376401
_session: Optional[aiohttp.ClientSession]

tests/_internal/aio/test_http.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
# noinspection PyProtectedMember
3333
from reportportal_client._internal.aio.http import RetryingClientSession
34+
from reportportal_client._internal.services.auth import ApiKeyAuthAsync
3435

3536
HTTP_TIMEOUT_TIME = 1.2
3637

@@ -75,6 +76,21 @@ def do_GET(self):
7576
self.wfile.flush()
7677

7778

79+
class UnauthorizedHttpHandler(http.server.BaseHTTPRequestHandler):
80+
def do_GET(self):
81+
auth_header = self.headers.get("Authorization")
82+
if auth_header == "Bearer test_api_key":
83+
self.send_response(200)
84+
self.send_header("Content-Type", "application/json")
85+
self.end_headers()
86+
self.wfile.write("{}\n\n".encode("utf-8"))
87+
else:
88+
self.send_response(401, "Unauthorized")
89+
self.end_headers()
90+
self.wfile.write("Unauthorized\n\n".encode("utf-8"))
91+
self.wfile.flush()
92+
93+
7894
SERVER_PORT = 8000
7995
SERVER_ADDRESS = ("", SERVER_PORT)
8096
SERVER_CLASS = socketserver.TCPServer
@@ -163,3 +179,93 @@ async def test_no_retry_on_not_retryable_error():
163179
assert result is None
164180
assert async_mock.call_count == 1
165181
assert total_time < 1
182+
183+
184+
@pytest.mark.asyncio
185+
async def test_auth_header_added_to_request():
186+
"""Test that auth header is added to requests when auth is configured."""
187+
port = 8006
188+
retry_number = 5
189+
auth = ApiKeyAuthAsync("test_api_key")
190+
timeout = aiohttp.ClientTimeout(connect=1.0, sock_read=1.0)
191+
connector = aiohttp.TCPConnector(force_close=True)
192+
session = RetryingClientSession(
193+
f"http://localhost:{port}",
194+
timeout=timeout,
195+
max_retry_number=retry_number,
196+
base_retry_delay=0.01,
197+
auth=auth,
198+
connector=connector,
199+
)
200+
201+
with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)):
202+
async with session:
203+
result = await session.get("/")
204+
assert result.ok
205+
assert result.status == 200
206+
207+
208+
@pytest.mark.asyncio
209+
async def test_auth_refresh_on_401():
210+
"""Test that 401 response triggers auth refresh."""
211+
port = 8007
212+
retry_number = 5
213+
214+
# Create a mock auth that fails first, then succeeds
215+
auth = mock.AsyncMock()
216+
auth.get = mock.AsyncMock(side_effect=["Bearer invalid_token", "Bearer test_api_key"])
217+
auth.refresh = mock.AsyncMock(return_value="Bearer test_api_key")
218+
219+
timeout = aiohttp.ClientTimeout(connect=1.0, sock_read=1.0)
220+
connector = aiohttp.TCPConnector(force_close=True)
221+
session = RetryingClientSession(
222+
f"http://localhost:{port}",
223+
timeout=timeout,
224+
max_retry_number=retry_number,
225+
base_retry_delay=0.01,
226+
auth=auth,
227+
connector=connector,
228+
)
229+
230+
with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)):
231+
async with session:
232+
result = await session.get("/")
233+
# First call to get() returns invalid token, which causes 401
234+
# Then refresh() is called and returns valid token
235+
# Request is retried with valid token and succeeds
236+
assert result.ok
237+
assert result.status == 200
238+
assert auth.get.call_count == 1
239+
assert auth.refresh.call_count == 1
240+
241+
242+
@pytest.mark.asyncio
243+
async def test_auth_refresh_only_once():
244+
"""Test that auth refresh is only performed once per request."""
245+
port = 8008
246+
retry_number = 5
247+
248+
# Create a mock auth that always fails
249+
auth = mock.AsyncMock()
250+
auth.get = mock.AsyncMock(return_value="Bearer invalid_token")
251+
auth.refresh = mock.AsyncMock(return_value="Bearer still_invalid_token")
252+
253+
timeout = aiohttp.ClientTimeout(connect=1.0, sock_read=1.0)
254+
connector = aiohttp.TCPConnector(force_close=True)
255+
session = RetryingClientSession(
256+
f"http://localhost:{port}",
257+
timeout=timeout,
258+
max_retry_number=retry_number,
259+
base_retry_delay=0.01,
260+
auth=auth,
261+
connector=connector,
262+
)
263+
264+
with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)):
265+
async with session:
266+
result = await session.get("/")
267+
# Auth refresh should only be attempted once
268+
assert not result.ok
269+
assert result.status == 401
270+
assert auth.get.call_count == 1
271+
assert auth.refresh.call_count == 1

tests/_internal/services/test_auth.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -542,26 +542,24 @@ def test_get_returns_token(self):
542542

543543
assert result == f"Bearer {api_token}"
544544

545-
def test_refresh_returns_token(self):
546-
"""Test that refresh() returns the API token."""
545+
def test_refresh_returns_none(self):
546+
"""Test that refresh() returns None (API keys don't have refresh mechanism)."""
547547
api_token = "test_api_token_67890"
548548
auth = ApiKeyAuthSync(api_token)
549549
result = auth.refresh()
550550

551-
assert result == f"Bearer {api_token}"
551+
assert result is None
552552

553553
def test_multiple_calls_return_same_token(self):
554-
"""Test that multiple calls return the same token."""
554+
"""Test that multiple get() calls return the same token."""
555555
api_token = "test_api_token_stable"
556556
auth = ApiKeyAuthSync(api_token)
557557

558558
result1 = auth.get()
559559
result2 = auth.get()
560-
result3 = auth.refresh()
561560

562561
assert result1 == f"Bearer {api_token}"
563562
assert result2 == f"Bearer {api_token}"
564-
assert result3 == f"Bearer {api_token}"
565563

566564

567565
class TestApiTokenAuthAsync:
@@ -577,24 +575,22 @@ async def test_get_returns_token(self):
577575
assert result == f"Bearer {api_token}"
578576

579577
@pytest.mark.asyncio
580-
async def test_refresh_returns_token(self):
581-
"""Test that refresh() returns the API token."""
578+
async def test_refresh_returns_none(self):
579+
"""Test that refresh() returns None (API keys don't have refresh mechanism)."""
582580
api_token = "test_api_token_async_67890"
583581
auth = ApiKeyAuthAsync(api_token)
584582
result = await auth.refresh()
585583

586-
assert result == f"Bearer {api_token}"
584+
assert result is None
587585

588586
@pytest.mark.asyncio
589587
async def test_multiple_calls_return_same_token(self):
590-
"""Test that multiple calls return the same token."""
588+
"""Test that multiple get() calls return the same token."""
591589
api_token = "test_api_token_async_stable"
592590
auth = ApiKeyAuthAsync(api_token)
593591

594592
result1 = await auth.get()
595593
result2 = await auth.get()
596-
result3 = await auth.refresh()
597594

598595
assert result1 == f"Bearer {api_token}"
599596
assert result2 == f"Bearer {api_token}"
600-
assert result3 == f"Bearer {api_token}"

0 commit comments

Comments
 (0)