diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a621c47..0c619ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,17 +28,17 @@ repos: - --in-place - --remove-all-unused-imports - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 7.3.0 hooks: - id: flake8 additional_dependencies: - - pycodestyle==2.10.0 - - pyflakes==3.0.1 + - pycodestyle + - pyflakes # - flake8-docstrings==1.6.0 # - pydocstyle==6.2.3 - - flake8-comprehensions==3.10.1 - - flake8-noqa==1.3.0 - - mccabe==0.7.0 + - flake8-comprehensions + - flake8-noqa + - mccabe files: ^(custom_components)/.+\.py$ - repo: https://github.com/PyCQA/bandit rev: 1.7.7 diff --git a/CHANGELOG.md b/CHANGELOG.md index 50820d0..b394f1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # Changelog +## v1.0.24 + +- Implement exponential backoff for reconnection attempts to prevent aggressive retries + - Initial retry after 1 minute, increasing to 2, 4, 8, and maximum 15 minutes + - Reduces server load during extended connection issues +- Add AWS IoT credential caching with 2-hour validity + - Credentials are cached and reused for 1h50m after successful fetch + - Eliminates redundant token API calls when credentials are still valid +- Implement rate limiting for token endpoint calls + - Minimum 5-minute interval between token API requests + - Protects against API rate limiting during recurring disconnection issues +- Reduce token API calls from ~20-30 per hour during outages to ~1 per hour maximum +- Fix excessive API calls to Maytronics token endpoint during recurring disconnections +- Fix integration-version header value + ## v1.0.23 - Remove navigate service (`mydolphin_plus.navigate`) diff --git a/custom_components/mydolphin_plus/common/consts.py b/custom_components/mydolphin_plus/common/consts.py index a7019f8..743ad26 100644 --- a/custom_components/mydolphin_plus/common/consts.py +++ b/custom_components/mydolphin_plus/common/consts.py @@ -126,6 +126,23 @@ 
class IntegrationInfo:
    """Process-wide holder for the integration's name, version and User-Agent.

    The class is a singleton: every ``IntegrationInfo()`` call returns the same
    object, so the lookup performed by :meth:`initialize` is done once and
    shared by every caller.
    """

    # The single shared instance created lazily by __new__.
    _instance: IntegrationInfo | None = None
    # Values resolved by initialize(); "unknown" after a failed/skipped lookup.
    _name: str | None = None
    _version: str | None = None
    _user_agent: str | None = None
    # True only after name, version and User-Agent were successfully resolved.
    _initialized: bool = False

    def __new__(cls):
        """Singleton pattern - return the same instance."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    async def initialize(self, hass: HomeAssistant | None) -> None:
        """Resolve integration name/version once and cache them on the instance.

        Args:
            hass: Home Assistant instance used for the integration lookup.
                When ``None`` the lookup is skipped and all values stay
                ``"unknown"``.

        The instance is only marked as initialized when both a name and a
        version were obtained; otherwise a later call retries the lookup and
        :meth:`set_user_agent` leaves headers untouched.
        """
        if self._initialized:
            return  # Already resolved; nothing to refresh.

        name = "unknown"
        version = "unknown"
        user_agent = "unknown"
        is_initialized = False

        if hass is not None:
            try:
                integration = await async_get_integration(hass, "mydolphin_plus")
                manifest_name = integration.name
                manifest_version = integration.version

                if manifest_name and manifest_version is not None:
                    name = manifest_name
                    version = manifest_version
                    # Spaces are replaced so the product token stays a single word.
                    clean_name = name.replace(" ", "-")
                    user_agent = f"HA-{clean_name}/{version}"
                    is_initialized = True
                else:
                    # Avoid advertising a bogus "HA-<name>/None" User-Agent.
                    _LOGGER.warning(
                        "Integration manifest is missing name or version; "
                        "User-Agent header will not be set"
                    )
            except Exception as ex:
                # Deliberately best-effort: a failed lookup must not prevent
                # the API from starting, it only disables the User-Agent.
                _LOGGER.warning("Failed to get integration info: %s", ex)

        self._name = name
        self._version = version
        self._user_agent = user_agent
        self._initialized = is_initialized

    def get_version(self) -> str | None:
        """Return the cached version.

        ``None`` before :meth:`initialize` has run, ``"unknown"`` when the
        lookup was skipped or failed.
        """
        return self._version

    def set_user_agent(self, headers: dict) -> None:
        """Set User-Agent header if name and version are available.

        Args:
            headers: Headers dict to modify in place. ``User-Agent`` is only
                written after a successful :meth:`initialize`; otherwise the
                dict is left unchanged.
        """
        if self._initialized:
            headers["User-Agent"] = self._user_agent
async def _handle_connection_failure(self):
    """Terminate the AWS client, wait with exponential backoff, then re-initialize the API.

    The wait doubles with every consecutive failure (1, 2, 4, 8 minutes) and
    is capped at ``RECONNECT_BACKOFF_MAX``. ``self._reconnection_attempts`` is
    incremented on every call.

    NOTE(review): both the API and the AWS-client status handlers reset
    ``self._reconnection_attempts`` to 0 on ``CONNECTED``. If the API
    reconnects while the AWS client keeps failing, the backoff may restart at
    1 minute on every cycle - confirm this is intended.
    """
    await self._aws_client.terminate()

    # Calculate exponential backoff: 1min, 2min, 4min, 8min, 15min (max)
    backoff_minutes = min(
        2**self._reconnection_attempts, RECONNECT_BACKOFF_MAX.total_seconds() / 60
    )
    backoff_interval = timedelta(minutes=backoff_minutes)

    self._reconnection_attempts += 1

    # %g renders the int (below the cap) and the float (at the cap) the same
    # way, so the message reads "15 minute(s)" rather than "15.0 minute(s)".
    _LOGGER.warning(
        "Connection failure - reconnection attempt #%s, "
        "waiting %g minute(s) before retry",
        self._reconnection_attempts,
        backoff_minutes,
    )

    await sleep(backoff_interval.total_seconds())
    await self._api.initialize()
b/custom_components/mydolphin_plus/managers/rest_api.py @@ -1,8 +1,9 @@ from __future__ import annotations -from asyncio import sleep -from base64 import b64encode +from base64 import urlsafe_b64encode +from datetime import datetime import hashlib +import json import logging import secrets import sys @@ -31,6 +32,7 @@ API_RESPONSE_STATUS_SUCCESS, API_RESPONSE_UNIT_SERIAL_NUMBER, API_TOKEN_FIELDS, + AWS_CREDENTIALS_TTL, BLOCK_SIZE, DATA_ROBOT_DETAILS, DEFAULT_NAME, @@ -38,12 +40,14 @@ FORGOT_PASSWORD_URL, LOGIN_HEADERS, LOGIN_URL, + MIN_TOKEN_FETCH_INTERVAL, ROBOT_DETAILS_BY_SN_URL, ROBOT_DETAILS_URL, SIGNAL_API_STATUS, SIGNAL_DEVICE_NEW, TOKEN_URL, ) +from ..common.integration_info import IntegrationInfo from ..models.config_data import ConfigData from .config_manager import ConfigManager @@ -58,6 +62,7 @@ class RestAPI: _status: ConnectivityStatus | None _session: ClientSession | None _config_manager: ConfigManager + _integration_info: IntegrationInfo _device_loaded: bool @@ -69,6 +74,8 @@ def __init__(self, hass: HomeAssistant | None, config_manager: ConfigManager): self._config_manager = config_manager + self._integration_info = IntegrationInfo() + self._status = None self._session = None @@ -109,6 +116,8 @@ def _is_home_assistant(self): async def initialize(self): _LOGGER.info("Initializing MyDolphin API") + await self._integration_info.initialize(self._hass) + await self._initialize_session() await self._login() @@ -145,10 +154,14 @@ async def _async_post(self, url, headers: dict, request_data: str | dict | None) result = None try: + # Copy headers and set User-Agent if available + headers = headers.copy() if headers else {} + self._integration_info.set_user_agent(headers) + async with self._session.post( url, headers=headers, data=request_data, ssl=False ) as response: - _LOGGER.debug(f"Status of {url}: {response.status}") + _LOGGER.debug(f"Status of POST request to {url}: {response.status}") response.raise_for_status() @@ -173,8 +186,11 @@ async def 
_async_get(self, url, headers: dict): result = None try: + headers = headers.copy() if headers else {} + self._integration_info.set_user_agent(headers) + async with self._session.get(url, headers=headers, ssl=False) as response: - _LOGGER.debug(f"Status of {url}: {response.status}") + _LOGGER.debug(f"Status of GET request to {url}: {response.status}") response.raise_for_status() @@ -409,8 +425,43 @@ async def _generate_aws_token(self): await self._config_manager.update_aws_token(aws_token) + # Check if cached AWS IoT credentials are still valid + if await self._are_cached_credentials_valid(): + _LOGGER.info("Using cached AWS IoT credentials (still valid)") + self._set_status(ConnectivityStatus.CONNECTED) + return + + # Check rate limiting + now = datetime.now().timestamp() + last_fetch = self._config_manager.last_token_fetch + time_since_last = now - last_fetch + + if time_since_last < MIN_TOKEN_FETCH_INTERVAL.total_seconds(): + wait_time = MIN_TOKEN_FETCH_INTERVAL.total_seconds() - time_since_last + _LOGGER.warning( + f"Token fetch rate limited. Last fetch was {time_since_last:.0f}s ago. " + f"Need to wait {wait_time:.0f}s more. Using cached credentials if available." 
+ ) + + # Try to use cached credentials even if expired, better than nothing + if self._has_cached_credentials(): + _LOGGER.info( + "Using potentially expired cached credentials due to rate limit" + ) + self._set_status(ConnectivityStatus.CONNECTED) + return + else: + _LOGGER.warning( + "No cached credentials available, will attempt fetch despite rate limit" + ) + + # Make API call to get fresh credentials request_data = f"{API_REQUEST_SERIAL_NUMBER}={aws_token}" + _LOGGER.info("Fetching fresh AWS IoT credentials from token endpoint") + _LOGGER.debug(f"Request data: {json.dumps(request_data)}") + _LOGGER.debug(f"Headers: {json.dumps(headers)}") + payload = await self._async_post(TOKEN_URL, headers, request_data) if self._status == ConnectivityStatus.TEMPORARY_CONNECTED: @@ -422,6 +473,18 @@ async def _generate_aws_token(self): for field in API_TOKEN_FIELDS: self.data[field] = data.get(field) + # Update timestamps + now = datetime.now().timestamp() + expiry = now + AWS_CREDENTIALS_TTL.total_seconds() + + await self._config_manager.update_last_token_fetch(now) + await self._config_manager.update_aws_credentials_expiry(expiry) + + _LOGGER.info( + f"Successfully fetched AWS IoT credentials. 
" + f"Valid until {datetime.fromtimestamp(expiry).isoformat()}" + ) + self._set_status(ConnectivityStatus.CONNECTED) else: @@ -439,6 +502,30 @@ async def _generate_aws_token(self): self._set_status(ConnectivityStatus.FAILED, message) + async def _are_cached_credentials_valid(self) -> bool: + """Check if cached AWS IoT credentials are still valid.""" + if not self._has_cached_credentials(): + return False + + expiry = self._config_manager.aws_credentials_expiry + now = datetime.now().timestamp() + + is_valid = expiry > now + + if is_valid: + remaining_hours = (expiry - now) / 3600 + _LOGGER.debug( + f"Cached credentials valid for {remaining_hours:.1f} more hours" + ) + else: + _LOGGER.debug("Cached credentials have expired") + + return is_valid + + def _has_cached_credentials(self) -> bool: + """Check if AWS IoT credentials exist in cache.""" + return all(self.data.get(field) is not None for field in API_TOKEN_FIELDS) + async def _load_details(self): if self._status != ConnectivityStatus.CONNECTED: return @@ -485,30 +572,24 @@ async def _get_aws_token(self) -> str | None: f"ENCRYPT: Motor Unit Serial: {self._config_manager.motor_unit_serial}" ) - for i in range(0, 10): - backend = default_backend() - iv = secrets.token_bytes(BLOCK_SIZE) - mode = modes.CBC(iv) - aes_key = self._get_aes_key() + backend = default_backend() + iv = secrets.token_bytes(BLOCK_SIZE) + mode = modes.CBC(iv) + aes_key = self._get_aes_key() - aes = algorithms.AES(aes_key) - cipher = Cipher(aes, mode, backend=backend) + aes = algorithms.AES(aes_key) + cipher = Cipher(aes, mode, backend=backend) - encryptor = cipher.encryptor() + encryptor = cipher.encryptor() - data = self._pad(self._config_manager.motor_unit_serial).encode() - ct = encryptor.update(data) + encryptor.finalize() + data = self._pad(self._config_manager.motor_unit_serial).encode() + ct = encryptor.update(data) + encryptor.finalize() - result_b64 = iv + ct + result_b64 = iv + ct - result = b64encode(result_b64).decode() + result = 
urlsafe_b64encode(result_b64).decode() - if "+" not in result: - return result - - await sleep(0.5) - - raise ValueError("Invalid AWS Token generated") + return result @staticmethod def _pad(text) -> str: diff --git a/custom_components/mydolphin_plus/manifest.json b/custom_components/mydolphin_plus/manifest.json index 273872e..241bb23 100644 --- a/custom_components/mydolphin_plus/manifest.json +++ b/custom_components/mydolphin_plus/manifest.json @@ -10,5 +10,5 @@ "iot_class": "cloud_push", "issue_tracker": "https://github.com/sh00t2kill/dolphin-robot/issues", "requirements": ["awsiotsdk"], - "version": "1.0.23" + "version": "1.0.24" } diff --git a/tests/genereate_aws_token_test.py b/tests/genereate_aws_token_test.py index c5971fb..3178e4c 100644 --- a/tests/genereate_aws_token_test.py +++ b/tests/genereate_aws_token_test.py @@ -1,7 +1,7 @@ """Generate AWS token test file.""" from __future__ import annotations -from base64 import b64decode, b64encode +from base64 import urlsafe_b64decode, urlsafe_b64encode import hashlib import os import secrets @@ -33,7 +33,7 @@ def __init__(self, key: str): f"Mode: {self._mode.name}, initialization_vector: {self._mode.initialization_vector}" ) - self._aes_key = self._get_key(email) + self._aes_key = self._get_key(key) print("") @staticmethod @@ -99,7 +99,7 @@ def encrypt(self, sn: str) -> str | None: result_b64 = self._iv + ct print(f"result_b64: {result_b64}, Length: {len(result_b64)}") - result = b64encode(result_b64).decode() + result = urlsafe_b64encode(result_b64).decode() print(f"result: {result}, Length: {len(result)}") print("") @@ -110,7 +110,7 @@ def decrypt(self, encrypted_data: str) -> str | None: """Do decryption of input string, Returns non encrypted data.""" print(f"DECRYPT: {encrypted_data}") - encrypted_value = b64decode(encrypted_data.encode())[BLOCK_SIZE:] + encrypted_value = urlsafe_b64decode(encrypted_data.encode())[BLOCK_SIZE:] print(f"encrypted_value: {encrypted_value}") cipher = self._get_cipher()