Skip to content

Commit a4ec900

Browse files
authored
Merge pull request #262 from reportportal/EPMRPP-109228-oauth_password_grant
[EPMRPP-109228] OAuth password grant
2 parents 2200b19 + e93f503 commit a4ec900

File tree

23 files changed

+2460
-152
lines changed

23 files changed

+2460
-152
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,5 @@ dmypy.json
114114
# Pyre type checker
115115
.pyre/
116116

117-
# End of https://www.gitignore.io/api/python
117+
AGENTS.md
118+
PROMPTS.md

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,9 @@ profile = "black"
1515
[tool.black]
1616
line-length = 119
1717
target-version = ["py310"]
18+
19+
[tool.pytest.ini_options]
20+
minversion = "6.0"
21+
required_plugins = "pytest-cov"
22+
testpaths = ["tests"]
23+
asyncio_default_fixture_loop_scope = "session"

reportportal_client/__init__.py

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,15 @@
1212
# limitations under the License
1313

1414
"""This package is the base package for ReportPortal client."""
15-
import typing
16-
import warnings
15+
16+
import sys
17+
from typing import Optional, Tuple, TypedDict, Union
18+
19+
# noinspection PyUnreachableCode
20+
if sys.version_info >= (3, 11):
21+
from typing import Unpack
22+
else:
23+
from typing_extensions import Unpack
1724

1825
import aenum
1926

@@ -34,74 +41,94 @@ class ClientType(aenum.Enum):
3441
ASYNC_BATCHED = aenum.auto()
3542

3643

44+
class _ClientOptions(TypedDict, total=False):
45+
client_type: ClientType
46+
endpoint: str
47+
project: str
48+
api_key: Optional[str]
49+
# OAuth 2.0 parameters
50+
oauth_uri: Optional[str]
51+
oauth_username: Optional[str]
52+
oauth_password: Optional[str]
53+
oauth_client_id: Optional[str]
54+
oauth_client_secret: Optional[str]
55+
oauth_scope: Optional[str]
56+
# Common client parameters
57+
launch_uuid: Optional[str]
58+
is_skipped_an_issue: bool
59+
verify_ssl: Union[bool, str]
60+
retries: int
61+
max_pool_size: int
62+
http_timeout: Union[float, Tuple[float, float]]
63+
mode: str
64+
launch_uuid_print: bool
65+
print_output: OutputType
66+
truncate_attributes: bool
67+
log_batch_size: int
68+
log_batch_payload_limit: int
69+
# Async client specific parameters
70+
keepalive_timeout: float
71+
# Async threaded/batched client specific parameters
72+
task_timeout: float
73+
shutdown_timeout: float
74+
# Async batched client specific parameters
75+
trigger_num: int
76+
trigger_interval: float
77+
78+
3779
# noinspection PyIncorrectDocstring
3880
def create_client(
39-
client_type: ClientType, endpoint: str, project: str, *, api_key: str = None, **kwargs: typing.Any
40-
) -> typing.Optional[RP]:
81+
client_type: ClientType, endpoint: str, project: str, **kwargs: Unpack[_ClientOptions]
82+
) -> Optional[RP]:
4183
"""Create and ReportPortal Client based on the type and arguments provided.
4284
4385
:param client_type: Type of the Client to create.
44-
:type client_type: ClientType
4586
:param endpoint: Endpoint of the ReportPortal service.
46-
:type endpoint: str
4787
:param project: Project name to report to.
48-
:type project: str
4988
:param api_key: Authorization API key.
50-
:type api_key: str
89+
:param oauth_uri: OAuth 2.0 token endpoint URI (for OAuth authentication).
90+
:param oauth_username: Username for OAuth 2.0 authentication.
91+
:param oauth_password: Password for OAuth 2.0 authentication.
92+
:param oauth_client_id: OAuth 2.0 client ID.
93+
:param oauth_client_secret: OAuth 2.0 client secret (optional).
94+
:param oauth_scope: OAuth 2.0 scope (optional).
5195
:param launch_uuid: A launch UUID to use instead of starting own one.
52-
:type launch_uuid: str
5396
:param is_skipped_an_issue: Option to mark skipped tests as not 'To Investigate' items on the server
5497
side.
55-
:type is_skipped_an_issue: bool
5698
:param verify_ssl: Option to skip ssl verification.
57-
:type verify_ssl: typing.Union[bool, str]
5899
:param retries: Number of retry attempts to make in case of connection / server
59100
errors.
60-
:type retries: int
61101
:param max_pool_size: Option to set the maximum number of connections to save the pool.
62-
:type max_pool_size: int
63102
:param http_timeout : A float in seconds for connect and read timeout. Use a Tuple to
64103
specific connect and read separately.
65-
:type http_timeout: Tuple[float, float]
66104
:param mode: Launch mode, all Launches started by the client will be in that mode.
67-
:type mode: str
68105
:param launch_uuid_print: Print Launch UUID into passed TextIO or by default to stdout.
69-
:type launch_uuid_print: bool
70106
:param print_output: Set output stream for Launch UUID printing.
71-
:type print_output: OutputType
72107
:param truncate_attributes: Truncate test item attributes to default maximum length.
73-
:type truncate_attributes: bool
74108
:param log_batch_size: Option to set the maximum number of logs that can be processed in one
75109
batch.
76-
:type log_batch_size: int
77110
:param log_batch_payload_limit: Maximum size in bytes of logs that can be processed in one batch.
78-
:type log_batch_payload_limit: int
79111
:param keepalive_timeout: For Async Clients only. Maximum amount of idle time in seconds before
80112
force connection closing.
81-
:type keepalive_timeout: int
82113
:param task_timeout: For Async Threaded and Batched Clients only. Time limit in seconds for a
83114
Task processing.
84-
:type task_timeout: float
85115
:param shutdown_timeout: For Async Threaded and Batched Clients only. Time limit in seconds for
86116
shutting down internal Tasks.
87-
:type shutdown_timeout: float
88117
:param trigger_num: For Async Batched Client only. Number of tasks which triggers Task batch
89118
execution.
90-
:type trigger_num: int
91119
:param trigger_interval: For Async Batched Client only. Time limit which triggers Task batch
92120
execution.
93-
:type trigger_interval: float
94121
:return: ReportPortal Client instance.
95122
"""
96123
if client_type is ClientType.SYNC:
97-
return RPClient(endpoint, project, api_key=api_key, **kwargs)
124+
return RPClient(endpoint, project, **kwargs)
98125
if client_type is ClientType.ASYNC:
99-
return AsyncRPClient(endpoint, project, api_key=api_key, **kwargs)
126+
return AsyncRPClient(endpoint, project, **kwargs)
100127
if client_type is ClientType.ASYNC_THREAD:
101-
return ThreadedRPClient(endpoint, project, api_key=api_key, **kwargs)
128+
return ThreadedRPClient(endpoint, project, **kwargs)
102129
if client_type is ClientType.ASYNC_BATCHED:
103-
return BatchedRPClient(endpoint, project, api_key=api_key, **kwargs)
104-
warnings.warn(f"Unknown ReportPortal Client type requested: {client_type}", RuntimeWarning, stacklevel=2)
130+
return BatchedRPClient(endpoint, project, **kwargs)
131+
raise ValueError(f"Unknown ReportPortal Client type requested: {client_type}")
105132

106133

107134
__all__ = [

reportportal_client/_internal/aio/http.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,20 @@
2424
import asyncio
2525
import sys
2626
from types import TracebackType
27-
from typing import Any, Callable, Coroutine, Optional, Type
27+
from typing import Any, Callable, Coroutine, Optional, Type, Union
2828

2929
from aenum import Enum
30-
from aiohttp import ClientResponse, ClientResponseError, ClientSession, ServerConnectionError
30+
from aiohttp import ClientResponse, ClientResponseError
31+
from aiohttp import ClientSession as AioHttpClientSession
32+
from aiohttp import ServerConnectionError
33+
34+
from reportportal_client._internal.services.auth import AuthAsync
3135

3236
DEFAULT_RETRY_NUMBER: int = 5
3337
DEFAULT_RETRY_DELAY: float = 0.005
3438
THROTTLING_STATUSES: set = {425, 429}
3539
RETRY_STATUSES: set = {408, 500, 502, 503, 507}.union(THROTTLING_STATUSES)
40+
AUTH_PROBLEM_STATUSES: set = {401, 403}
3641

3742

3843
class RetryClass(int, Enum):
@@ -46,7 +51,7 @@ class RetryClass(int, Enum):
4651
class RetryingClientSession:
4752
"""Class uses aiohttp.ClientSession.request method and adds request retry logic."""
4853

49-
_client: ClientSession
54+
_client: AioHttpClientSession
5055
__retry_number: int
5156
__retry_delay: float
5257

@@ -68,7 +73,7 @@ def __init__(
6873
an error. Real value highly depends on Retry Class and Retry attempt number,
6974
since retries are performed in exponential delay manner
7075
"""
71-
self._client = ClientSession(*args, **kwargs)
76+
self._client = AioHttpClientSession(*args, **kwargs)
7277
self.__retry_number = max_retry_number
7378
self.__retry_delay = base_retry_delay
7479

@@ -91,8 +96,12 @@ async def __request(self, method: Callable, url, **kwargs: Any) -> ClientRespons
9196
"""
9297
result = None
9398
exceptions = []
99+
94100
for i in range(self.__retry_number + 1): # add one for the first attempt, which is not a retry
95101
retry_factor = None
102+
if result is not None:
103+
# Release previous result to return connection to pool
104+
await result.release()
96105
try:
97106
result = await method(url, **kwargs)
98107
except Exception as exc:
@@ -157,3 +166,84 @@ async def __aexit__(
157166
) -> None:
158167
"""Auxiliary method which controls what `async with` construction does on block exit."""
159168
await self.close()
169+
170+
171+
class ClientSession:
172+
"""Class wraps aiohttp.ClientSession or RetryingClientSession and adds authentication support."""
173+
174+
_client: Union[AioHttpClientSession, RetryingClientSession]
175+
__auth: Optional[AuthAsync]
176+
177+
def __init__(
178+
self,
179+
wrapped: Union[AioHttpClientSession, RetryingClientSession],
180+
auth: Optional[AuthAsync] = None,
181+
):
182+
"""Initialize an instance of the session with arguments.
183+
184+
:param wrapped: aiohttp.ClientSession or RetryingClientSession instance to wrap
185+
:param auth: authentication instance to use for requests
186+
"""
187+
self._client = wrapped
188+
self.__auth = auth
189+
190+
async def __request(self, method: Callable, url: str, **kwargs: Any) -> ClientResponse:
191+
"""Make a request with authentication support.
192+
193+
The method adds Authorization header if auth is configured and handles auth refresh
194+
on 401/403 responses.
195+
"""
196+
# Clone kwargs and add Authorization header if auth is configured
197+
request_kwargs = kwargs.copy()
198+
if self.__auth:
199+
auth_header = await self.__auth.get()
200+
if auth_header:
201+
if "headers" not in request_kwargs:
202+
request_kwargs["headers"] = {}
203+
else:
204+
request_kwargs["headers"] = request_kwargs["headers"].copy()
205+
request_kwargs["headers"]["Authorization"] = auth_header
206+
207+
result = await method(url, **request_kwargs)
208+
209+
# Check for authentication errors
210+
if result.status in AUTH_PROBLEM_STATUSES and self.__auth:
211+
refreshed_header = await self.__auth.refresh()
212+
if refreshed_header:
213+
# Release previous result to return connection to pool
214+
await result.release()
215+
# Retry with new auth header
216+
request_kwargs["headers"] = request_kwargs.get("headers", {}).copy()
217+
request_kwargs["headers"]["Authorization"] = refreshed_header
218+
result = await method(url, **request_kwargs)
219+
220+
return result
221+
222+
def get(self, url: str, *, allow_redirects: bool = True, **kwargs: Any) -> Coroutine[Any, Any, ClientResponse]:
223+
"""Perform HTTP GET request."""
224+
return self.__request(self._client.get, url, allow_redirects=allow_redirects, **kwargs)
225+
226+
def post(self, url: str, *, data: Any = None, **kwargs: Any) -> Coroutine[Any, Any, ClientResponse]:
227+
"""Perform HTTP POST request."""
228+
return self.__request(self._client.post, url, data=data, **kwargs)
229+
230+
def put(self, url: str, *, data: Any = None, **kwargs: Any) -> Coroutine[Any, Any, ClientResponse]:
231+
"""Perform HTTP PUT request."""
232+
return self.__request(self._client.put, url, data=data, **kwargs)
233+
234+
def close(self) -> Coroutine:
235+
"""Gracefully close internal session instance."""
236+
return self._client.close()
237+
238+
async def __aenter__(self) -> "ClientSession":
239+
"""Auxiliary method which controls what `async with` construction does on block enter."""
240+
return self
241+
242+
async def __aexit__(
243+
self,
244+
exc_type: Optional[Type[BaseException]],
245+
exc_val: Optional[BaseException],
246+
exc_tb: Optional[TracebackType],
247+
) -> None:
248+
"""Auxiliary method which controls what `async with` construction does on block exit."""
249+
await self.close()

reportportal_client/_internal/aio/tasks.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,24 @@ def append(self, value: _T) -> Optional[List[_T]]:
177177
:return: a batch or None
178178
"""
179179
self.__task_list.append(value)
180-
if self.__ready_to_run():
181-
tasks = self.__task_list
182-
self.__task_list = []
183-
return tasks
180+
if not self.__ready_to_run():
181+
return None
182+
183+
tasks = self.__task_list
184+
self.__task_list = []
185+
return tasks
184186

185187
def flush(self) -> Optional[List[_T]]:
186188
"""Immediately return everything what's left in the internal batch.
187189
188190
:return: a batch or None
189191
"""
190-
if len(self.__task_list) > 0:
191-
tasks = self.__task_list
192-
self.__task_list = []
193-
return tasks
192+
if len(self.__task_list) <= 0:
193+
return None
194+
195+
tasks = self.__task_list
196+
self.__task_list = []
197+
return tasks
194198

195199

196200
class BackgroundTaskList(Generic[_T]):
@@ -224,7 +228,9 @@ def flush(self) -> Optional[List[_T]]:
224228
:return: a batch or None
225229
"""
226230
self.__remove_finished()
227-
if len(self.__task_list) > 0:
228-
tasks = self.__task_list
229-
self.__task_list = []
230-
return tasks
231+
if len(self.__task_list) <= 0:
232+
return None
233+
234+
tasks = self.__task_list
235+
self.__task_list = []
236+
return tasks

0 commit comments

Comments
 (0)