diff --git a/automated_api.py b/automated_api.py
index 329c6c858..a1d0da9e9 100644
--- a/automated_api.py
+++ b/automated_api.py
@@ -22,11 +22,20 @@
 
 # Fake modules to avoid import errors
 requests = type(sys)("requests")
+requests_adapters = type(sys)("requests.adapters")
 requests.__dict__["Response"] = type(
     "Response", (), {"__module__": "requests"}
 )
+requests.__dict__["adapters"] = requests_adapters
+requests_adapters.__dict__["HTTPAdapter"] = type(
+    "HTTPAdapter", (), {"__module__": "requests.adapters"}
+)
+requests_adapters.__dict__["Retry"] = type(
+    "Retry", (), {"__module__": "requests.adapters"}
+)
 
 sys.modules["requests"] = requests
+sys.modules["requests.adapters"] = requests_adapters
 sys.modules["unidecode"] = type(sys)("unidecode")
 
 import ayon_api  # noqa: E402
diff --git a/ayon_api/server_api.py b/ayon_api/server_api.py
index 87f939e45..95c9e35b6 100644
--- a/ayon_api/server_api.py
+++ b/ayon_api/server_api.py
@@ -25,6 +25,7 @@
     HTTPStatus = None
 
 import requests
+from requests.adapters import HTTPAdapter, Retry
 try:
     # This should be used if 'requests' have it available
     from requests.exceptions import JSONDecodeError as RequestsJSONDecodeError
@@ -476,6 +477,9 @@
         if not base_url:
             raise ValueError(f"Invalid server URL {str(base_url)}")
 
+        self._session = None
+        self._session_handlers = {}
+
         base_url = base_url.rstrip("/")
         self._base_url: str = base_url
         self._rest_url: str = f"{base_url}/api"
@@ -522,17 +526,6 @@
 
         self._graphql_allows_data_in_query = None
 
-        self._session = None
-
-        self._base_functions_mapping = {
-            RequestTypes.get: requests.get,
-            RequestTypes.post: requests.post,
-            RequestTypes.put: requests.put,
-            RequestTypes.patch: requests.patch,
-            RequestTypes.delete: requests.delete
-        }
-        self._session_functions_mapping = {}
-
         # Attributes cache
         self._attributes_schema = None
         self._entity_type_attributes_cache = {}
@@ -674,7 +667,14 @@ def set_max_retries(self, max_retries: Optional[int]):
         """
         if max_retries is None:
            max_retries = self.get_default_max_retries()
-        self._max_retries = int(max_retries)
+        max_retries = int(max_retries)
+        if max_retries < 0:
+            max_retries = 0
+        if max_retries == self._max_retries:
+            return
+        self._max_retries = max_retries
+        for handler in self._session_handlers.values():
+            handler.max_retries = Retry.from_int(max_retries)
 
     timeout = property(get_timeout, set_timeout)
     max_retries = property(get_max_retries, set_max_retries)
@@ -996,19 +996,10 @@ def create_session(
 
         # Validate token before session creation
         self.validate_token()
 
-        session = requests.Session()
-        session.cert = self._cert
-        session.verify = self._ssl_verify
-        session.headers.update(self.get_headers())
-
-        self._session_functions_mapping = {
-            RequestTypes.get: session.get,
-            RequestTypes.post: session.post,
-            RequestTypes.put: session.put,
-            RequestTypes.patch: session.patch,
-            RequestTypes.delete: session.delete
-        }
+        session, handlers = self._create_new_session()
+        self._session = session
+        self._session_handlers = handlers
 
     def close_session(self):
         if self._session is None:
@@ -1016,7 +1007,7 @@
 
         session = self._session
         self._session = None
-        self._session_functions_mapping = {}
+        self._session_handlers = {}
         session.close()
 
     def _update_session_headers(self):
@@ -1340,12 +1331,35 @@ def logout(self, soft: bool = False):
     def _logout(self):
         logout_from_server(self._base_url, self._access_token)
 
-    def _do_rest_request(self, function, url, **kwargs):
+    def _do_rest_request(self, request_type, url, **kwargs):
+        """Send a REST request to the server.
+
+        Args:
+            request_type (RequestType): Request type.
+            url (str): Request url.
+            max_retries (Optional[int]): Affects only connection issues,
+                or requests made while a session is not created.
+            **kwargs: Keyword arguments passed to the request call.
+
+        Returns:
+            RestApiResponse: Response.
+
+        Raises:
+            ConnectionRefusedError: When the connection is refused.
+            requests.exceptions.Timeout: When the connection times out.
+            requests.exceptions.ConnectionError: When a connection error
+                happens.
+
+        """
         kwargs.setdefault("timeout", self.timeout)
-        max_retries = kwargs.get("max_retries", self.max_retries)
-        if max_retries < 1:
-            max_retries = 1
-        if self._session is None:
+
+        close_session = False
+        session = self._session
+        max_retries = kwargs.pop("max_retries", None)
+        if max_retries is None:
+            max_retries = self.max_retries
+
+        if session is None:
             # Validate token if was not yet validated
             # - ignore validation if we're in middle of
             #   validation
@@ -1355,65 +1369,49 @@
             ):
                 self.validate_token()
 
-            if "headers" not in kwargs:
-                kwargs["headers"] = self.get_headers()
-
-            if isinstance(function, RequestType):
-                function = self._base_functions_mapping[function]
-
-        elif isinstance(function, RequestType):
-            function = self._session_functions_mapping[function]
+            headers = kwargs.get("headers")
+            close_session = True
+            session, _ = self._create_new_session(
+                max_retries=max_retries, headers=headers
+            )
 
         response = None
         new_response = None
-        for retry_idx in reversed(range(max_retries)):
-            try:
-                response = function(url, **kwargs)
-                break
-
-            except ConnectionRefusedError:
-                if retry_idx == 0:
-                    self.log.warning(
-                        "Connection error happened.", exc_info=True
-                    )
-
-                # Server may be restarting
-                new_response = RestApiResponse(
-                    None,
-                    {
-                        "detail": (
-                            "Unable to connect the server. Connection refused"
-                        )
-                    }
-                )
-
-            except requests.exceptions.Timeout:
-                # Connection timed out
-                new_response = RestApiResponse(
-                    None,
-                    {"detail": "Connection timed out."}
-                )
+        if max_retries < 1:
+            max_retries = 1
 
-            except requests.exceptions.ConnectionError:
-                # Log warning only on last attempt
-                if retry_idx == 0:
-                    self.log.warning(
-                        "Connection error happened.", exc_info=True
+        try:
+            for retry_idx in reversed(range(max_retries)):
+                try:
+                    response = session.request(
+                        request_type.name, url, **kwargs
                     )
-
-                new_response = RestApiResponse(
-                    None,
-                    {
-                        "detail": (
-                            "Unable to connect the server. Connection error"
+                    break
+
+                except (
+                    # These are 'ConnectionError' but it doesn't make sense
+                    # to retry
+                    requests.exceptions.ProxyError,
+                    requests.exceptions.SSLError,
+                ):
+                    raise
+
+                except (
+                    ConnectionRefusedError,
+                    requests.exceptions.ConnectionError
+                ):
+                    # Log warning only on last attempt
+                    if retry_idx == 0:
+                        self.log.warning(
+                            "Connection error happened.", exc_info=True
                         )
-                    }
-                )
+                        raise
 
-            time.sleep(0.1)
+                    time.sleep(0.1)
 
-        if new_response is not None:
-            return new_response
+        finally:
+            if close_session:
+                session.close()
 
         content_type = response.headers.get("Content-Type")
         if content_type == "application/json":
@@ -1434,6 +1432,28 @@
         self.log.debug(f"Response {str(new_response)}")
         return new_response
 
+    def _create_new_session(self, max_retries=None, headers=None):
+        if max_retries is None:
+            max_retries = self.max_retries
+        if max_retries < 0:
+            max_retries = 0
+
+        if headers is None:
+            headers = self.get_headers()
+
+        session = requests.Session()
+        session.cert = self._cert
+        session.verify = self._ssl_verify
+        session.headers.update(headers)
+        handlers = {
+            "http://": HTTPAdapter(max_retries=max_retries),
+            "https://": HTTPAdapter(max_retries=max_retries),
+        }
+        for prefix, adapter in handlers.items():
+            session.mount(prefix, adapter)
+
+        return session, handlers
+
     def raw_post(self, entrypoint: str, **kwargs):
         url = self._endpoint_to_url(entrypoint)
         self.log.debug(f"Executing [POST] {url}")
@@ -2142,18 +2162,22 @@
         self, url: str, stream, chunk_size, progress
     ):
         kwargs = {"stream": True}
-        if self._session is None:
-            kwargs["headers"] = self.get_headers()
-            get_func = self._base_functions_mapping[RequestTypes.get]
-        else:
-            get_func = self._session_functions_mapping[RequestTypes.get]
+        session = self._session
+        close_session = False
+        if session is None:
+            close_session = True
+            session, _ = self._create_new_session()
 
-        with get_func(url, **kwargs) as response:
-            response.raise_for_status()
-            progress.set_content_size(response.headers["Content-length"])
-            for chunk in response.iter_content(chunk_size=chunk_size):
-                stream.write(chunk)
-                progress.add_transferred_chunk(len(chunk))
+        try:
+            with session.request("GET", url, **kwargs) as response:
+                response.raise_for_status()
+                progress.set_content_size(response.headers["Content-length"])
+                for chunk in response.iter_content(chunk_size=chunk_size):
+                    stream.write(chunk)
+                    progress.add_transferred_chunk(len(chunk))
+        finally:
+            if close_session:
+                session.close()
 
     def download_file_to_stream(
         self,
@@ -2317,25 +2341,29 @@
 
         """
         if request_type is None:
-            request_type = RequestTypes.put
+            request_type = "PUT"
+        elif isinstance(request_type, RequestType):
+            request_type = request_type.name
 
+        session = self._session
+        close_session = False
         if self._session is None:
-            headers = kwargs.setdefault("headers", {})
-            for key, value in self.get_headers().items():
-                if key not in headers:
-                    headers[key] = value
-            post_func = self._base_functions_mapping[request_type]
-        else:
-            post_func = self._session_functions_mapping[request_type]
+            close_session = True
+            session, _ = self._create_new_session()
 
         if not chunk_size:
            chunk_size = self.default_upload_chunk_size
 
-        response = post_func(
-            url,
-            data=self._upload_chunks_iter(stream, progress, chunk_size),
-            **kwargs
-        )
+        try:
+            response = session.request(
+                request_type,
+                url,
+                data=self._upload_chunks_iter(stream, progress, chunk_size),
+                **kwargs
+            )
+        finally:
+            if close_session:
+                session.close()
 
         response.raise_for_status()
         return response