import asyncio

from crawlee.crawlers import BasicCrawler, BasicCrawlingContext
from crawlee.request_loaders import ThrottlingRequestManager
from crawlee.storages import RequestQueue


async def main() -> None:
    # Open the default request queue that will back the throttler.
    request_queue = await RequestQueue.open()

    # Wrap the queue in a ThrottlingRequestManager; only the listed domains
    # are throttled, and their sub-queues share the same storage backend
    # as the wrapped queue.
    throttle_manager = ThrottlingRequestManager(
        request_queue,
        domains=['api.example.com', 'slow-site.org'],
    )

    # Hand the throttler to the crawler in place of a plain request queue.
    crawler = BasicCrawler(request_manager=throttle_manager)

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        context.log.info(f'Processing {context.request.url}')

    # Enqueue some URLs. Requests for the configured domains are routed
    # straight into their throttled sub-queues; everything else lands in
    # the main queue.
    await throttle_manager.add_requests(
        [
            'https://api.example.com/data',
            'https://api.example.com/users',
            'https://slow-site.org/page1',
            'https://fast-site.com/page1',  # Not throttled
        ]
    )

    await crawler.run()


if __name__ == '__main__':
    asyncio.run(main())
+--- + +import ApiLink from '@site/src/components/ApiLink'; +import RunnableCodeBlock from '@site/src/components/RunnableCodeBlock'; + +import ThrottlingExample from '!!raw-loader!roa-loader!./code_examples/request_throttling/throttling_example.py'; + +When crawling websites that enforce rate limits (HTTP 429) or specify `crawl-delay` in their `robots.txt`, you need a way to throttle requests per domain without blocking unrelated domains. The `ThrottlingRequestManager` provides exactly this. + +## Overview + +The `ThrottlingRequestManager` wraps a `RequestQueue` and manages per-domain throttling. You specify which domains to throttle at initialization, and the manager automatically: + +- **Routes requests** for listed domains into dedicated sub-queues at insertion time. +- **Enforces delays** from HTTP 429 responses (exponential backoff) and `robots.txt` crawl-delay directives. +- **Schedules fairly** by fetching from the domain that has been waiting the longest. +- **Sleeps intelligently** when all configured domains are throttled, instead of busy-waiting. + +Requests for domains **not** in the configured list pass through to the main queue without any throttling. + +## Basic usage + +To use request throttling, create a `ThrottlingRequestManager` with the domains you want to throttle and pass it as the `request_manager` to your crawler: + + + {ThrottlingExample} + + +## How it works + +1. **Insertion-time routing**: When you add requests via `add_request` or `add_requests`, each request is checked against the configured domain list. Matching requests go directly into a per-domain sub-queue; all others go to the main queue. This eliminates request duplication entirely. + +2. **429 backoff**: When the crawler detects an HTTP 429 response, the `ThrottlingRequestManager` records an exponential backoff delay for that domain (starting at 2s, doubling up to 60s). If the response includes a `Retry-After` header, that value takes priority. + +3. 
**Crawl-delay**: If `robots.txt` specifies a `crawl-delay`, the manager enforces a minimum interval between requests to that domain. + +4. **Fair scheduling**: `fetch_next_request` sorts available sub-queues by how long each domain has been waiting, ensuring no domain is starved. + +:::tip + +The `ThrottlingRequestManager` is an opt-in feature. If you don't pass it to your crawler, requests are processed normally without any per-domain throttling. + +::: diff --git a/src/crawlee/_utils/http.py b/src/crawlee/_utils/http.py new file mode 100644 index 0000000000..be7a1fa5e4 --- /dev/null +++ b/src/crawlee/_utils/http.py @@ -0,0 +1,41 @@ +"""HTTP utility functions for Crawlee.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + + +def parse_retry_after_header(value: str | None) -> timedelta | None: + """Parse the Retry-After HTTP header value. + + The header can contain either a number of seconds or an HTTP-date. + See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + + Args: + value: The raw Retry-After header value. + + Returns: + A timedelta representing the delay, or None if the header is missing or unparsable. + """ + if not value: + return None + + # Try parsing as integer seconds first. + try: + seconds = int(value) + return timedelta(seconds=seconds) + except ValueError: + pass + + # Try parsing as HTTP-date (e.g., "Wed, 21 Oct 2015 07:28:00 GMT"). 
+ from email.utils import parsedate_to_datetime # noqa: PLC0415 + + try: + retry_date = parsedate_to_datetime(value) + delay = retry_date - datetime.now(retry_date.tzinfo or timezone.utc) + if delay.total_seconds() > 0: + return delay + except (ValueError, TypeError): + pass + + return None diff --git a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py index 7aafa49e2e..13bf2bd1b1 100644 --- a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +++ b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py @@ -279,7 +279,12 @@ async def _handle_status_code_response( """ status_code = context.http_response.status_code if self._retry_on_blocked: - self._raise_for_session_blocked_status_code(context.session, status_code) + self._raise_for_session_blocked_status_code( + context.session, + status_code, + request_url=context.request.url, + retry_after_header=context.http_response.headers.get('retry-after'), + ) self._raise_for_error_status_code(status_code) yield context diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 6451d59461..b858c31720 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -45,6 +45,7 @@ ) from crawlee._utils.docs import docs_group from crawlee._utils.file import atomic_write, export_csv_to_stream, export_json_to_stream +from crawlee._utils.http import parse_retry_after_header from crawlee._utils.recurring_task import RecurringTask from crawlee._utils.robots import RobotsTxtFile from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute @@ -63,6 +64,7 @@ ) from crawlee.events._types import Event, EventCrawlerStatusData from crawlee.http_clients import ImpitHttpClient +from crawlee.request_loaders import ThrottlingRequestManager from crawlee.router import Router from crawlee.sessions import SessionPool from crawlee.statistics 
import Statistics, StatisticsState @@ -707,12 +709,23 @@ async def run( await self._session_pool.reset_store() request_manager = await self.get_request_manager() - if purge_request_queue and isinstance(request_manager, RequestQueue): - await request_manager.drop() - self._request_manager = await RequestQueue.open( - storage_client=self._service_locator.get_storage_client(), - configuration=self._service_locator.get_configuration(), - ) + if purge_request_queue: + if isinstance(request_manager, RequestQueue): + await request_manager.drop() + self._request_manager = await RequestQueue.open( + storage_client=self._service_locator.get_storage_client(), + configuration=self._service_locator.get_configuration(), + ) + elif isinstance(request_manager, ThrottlingRequestManager): + domains = list(request_manager._domains) # noqa: SLF001 + await request_manager.drop() + inner = await RequestQueue.open( + storage_client=self._service_locator.get_storage_client(), + configuration=self._service_locator.get_configuration(), + ) + self._request_manager = ThrottlingRequestManager( + inner, domains=domains, service_locator=self._service_locator + ) if requests is not None: await self.add_requests(requests) @@ -1442,6 +1455,10 @@ async def __run_task_function(self) -> None: await self._mark_request_as_handled(request) + # Record successful request to reset rate limit backoff for this domain. 
+ if isinstance(request_manager, ThrottlingRequestManager): + request_manager.record_success(request.url) + if session and session.is_usable: session.mark_good() @@ -1542,16 +1559,36 @@ def _raise_for_error_status_code(self, status_code: int) -> None: if is_status_code_server_error(status_code) and not is_ignored_status: raise HttpStatusCodeError('Error status code returned', status_code) - def _raise_for_session_blocked_status_code(self, session: Session | None, status_code: int) -> None: + def _raise_for_session_blocked_status_code( + self, + session: Session | None, + status_code: int, + *, + request_url: str = '', + retry_after_header: str | None = None, + ) -> None: """Raise an exception if the given status code indicates the session is blocked. + If the status code is 429 (Too Many Requests), the domain is recorded as + rate-limited in the `ThrottlingRequestManager` for per-domain backoff. + Args: session: The session used for the request. If None, no check is performed. status_code: The HTTP status code to check. + request_url: The request URL, used for per-domain rate limit tracking. + retry_after_header: The value of the Retry-After response header, if present. Raises: SessionError: If the status code indicates the session is blocked. """ + if status_code == 429 and request_url: # noqa: PLR2004 + retry_after = parse_retry_after_header(retry_after_header) + + # _request_manager might not be initialized yet if called directly or early, + # but usually it's set in get_request_manager(). 
+ if isinstance(self._request_manager, ThrottlingRequestManager): + self._request_manager.record_domain_delay(request_url, retry_after=retry_after) + if session is not None and session.is_blocked_status_code( status_code=status_code, ignore_http_error_status_codes=self._ignore_http_error_status_codes, @@ -1582,7 +1619,16 @@ async def _is_allowed_based_on_robots_txt_file(self, url: str) -> bool: if not self._respect_robots_txt_file: return True robots_txt_file = await self._get_robots_txt_file_for_url(url) - return not robots_txt_file or robots_txt_file.is_allowed(url) + if not robots_txt_file: + return True + + # Wire robots.txt crawl-delay into ThrottlingRequestManager + if isinstance(self._request_manager, ThrottlingRequestManager): + crawl_delay = robots_txt_file.get_crawl_delay() + if crawl_delay is not None: + self._request_manager.set_crawl_delay(url, crawl_delay) + + return robots_txt_file.is_allowed(url) async def _get_robots_txt_file_for_url(self, url: str) -> RobotsTxtFile | None: """Get the RobotsTxtFile for a given URL. 
diff --git a/src/crawlee/crawlers/_playwright/_playwright_crawler.py b/src/crawlee/crawlers/_playwright/_playwright_crawler.py index 6f4b2b0e9d..9db340c976 100644 --- a/src/crawlee/crawlers/_playwright/_playwright_crawler.py +++ b/src/crawlee/crawlers/_playwright/_playwright_crawler.py @@ -459,7 +459,13 @@ async def _handle_status_code_response( """ status_code = context.response.status if self._retry_on_blocked: - self._raise_for_session_blocked_status_code(context.session, status_code) + retry_after_header = context.response.headers.get('retry-after') + self._raise_for_session_blocked_status_code( + context.session, + status_code, + request_url=context.request.url, + retry_after_header=retry_after_header, + ) self._raise_for_error_status_code(status_code) yield context diff --git a/src/crawlee/request_loaders/__init__.py b/src/crawlee/request_loaders/__init__.py index c04d9aa810..6dd8cccfab 100644 --- a/src/crawlee/request_loaders/__init__.py +++ b/src/crawlee/request_loaders/__init__.py @@ -3,5 +3,13 @@ from ._request_manager import RequestManager from ._request_manager_tandem import RequestManagerTandem from ._sitemap_request_loader import SitemapRequestLoader +from ._throttling_request_manager import ThrottlingRequestManager -__all__ = ['RequestList', 'RequestLoader', 'RequestManager', 'RequestManagerTandem', 'SitemapRequestLoader'] +__all__ = [ + 'RequestList', + 'RequestLoader', + 'RequestManager', + 'RequestManagerTandem', + 'SitemapRequestLoader', + 'ThrottlingRequestManager', +] diff --git a/src/crawlee/request_loaders/_throttling_request_manager.py b/src/crawlee/request_loaders/_throttling_request_manager.py new file mode 100644 index 0000000000..7d8029ede9 --- /dev/null +++ b/src/crawlee/request_loaders/_throttling_request_manager.py @@ -0,0 +1,407 @@ +"""A request manager wrapper that enforces per-domain delays. 
+ +Handles both HTTP 429 backoff and robots.txt crawl-delay at the scheduling layer, +routing requests for explicitly configured domains into dedicated sub-queues and +applying intelligent delay-aware scheduling. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from logging import getLogger +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from typing_extensions import override + +from crawlee._utils.docs import docs_group +from crawlee.request_loaders._request_manager import RequestManager +from crawlee.storages import RequestQueue + +if TYPE_CHECKING: + from collections.abc import Sequence + + from crawlee._request import Request + from crawlee._service_locator import ServiceLocator + from crawlee.storage_clients.models import ProcessedRequest + +logger = getLogger(__name__) + + +@dataclass +class _DomainState: + """Tracks delay state for a single domain.""" + + domain: str + """The domain being tracked.""" + + throttled_until: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + """Earliest time the next request to this domain is allowed.""" + + consecutive_429_count: int = 0 + """Number of consecutive 429 responses (for exponential backoff).""" + + crawl_delay: timedelta | None = None + """Minimum interval between requests, used to push `throttled_until` on dispatch.""" + + +@docs_group('Request loaders') +class ThrottlingRequestManager(RequestManager): + """A request manager that wraps another and enforces per-domain delays. + + Requests for explicitly configured domains are routed into dedicated sub-queues + at insertion time — each request lives in exactly one queue, eliminating + duplication and simplifying deduplication. + + When `fetch_next_request()` is called, it returns requests from the sub-queue + whose domain has been waiting the longest. 
If all configured domains are + throttled, it sleeps until the earliest cooldown expires. + + Delay sources: + - HTTP 429 responses (via `record_domain_delay`) + - robots.txt crawl-delay directives (via `set_crawl_delay`) + + Example: + ```python + from crawlee.storages import RequestQueue + from crawlee.request_loaders import ThrottlingRequestManager + + queue = await RequestQueue.open() + throttler = ThrottlingRequestManager( + queue, + domains=['api.example.com', 'slow-site.org'], + ) + crawler = BasicCrawler(request_manager=throttler) + ``` + """ + + _BASE_DELAY = timedelta(seconds=2) + """Initial delay after the first 429 response from a domain.""" + + _MAX_DELAY = timedelta(seconds=60) + """Maximum delay between requests to a rate-limited domain.""" + + def __init__( + self, + inner: RequestManager, + *, + domains: Sequence[str], + service_locator: ServiceLocator | None = None, + ) -> None: + """Initialize the throttling manager. + + Args: + inner: The underlying request manager to wrap (typically a RequestQueue). + Requests for non-throttled domains are stored here. + domains: Explicit list of domain hostnames to throttle. Only requests + matching these domains will be routed to per-domain sub-queues. + service_locator: Optional service locator for creating sub-queues. + If provided, sub-queues will use its storage client and configuration, + ensuring consistency with the crawler's storage backend. 
+ """ + self._inner = inner + self._domains = set(domains) + self._service_locator = service_locator + self._domain_states: dict[str, _DomainState] = {} + self._sub_queues: dict[str, RequestQueue] = {} + + @staticmethod + def _extract_domain(url: str) -> str: + """Extract the domain (hostname) from a URL.""" + parsed = urlparse(url) + return parsed.hostname or '' + + def _get_url_from_request(self, request: str | Request) -> str: + """Extract URL string from a request that may be a string or Request object.""" + if isinstance(request, str): + return request + return request.url + + async def _get_or_create_sub_queue(self, domain: str) -> RequestQueue: + """Get or create a per-domain sub-queue.""" + if domain not in self._sub_queues: + if self._service_locator: + self._sub_queues[domain] = await RequestQueue.open( + alias=f'throttled-{domain}', + storage_client=self._service_locator.get_storage_client(), + configuration=self._service_locator.get_configuration(), + ) + else: + self._sub_queues[domain] = await RequestQueue.open(alias=f'throttled-{domain}') + return self._sub_queues[domain] + + def _is_domain_throttled(self, domain: str) -> bool: + """Check if a domain is currently throttled.""" + state = self._domain_states.get(domain) + if state is None: + return False + return datetime.now(timezone.utc) < state.throttled_until + + def _get_earliest_available_time(self) -> datetime: + """Get the earliest time any throttled domain becomes available.""" + now = datetime.now(timezone.utc) + earliest = now + self._MAX_DELAY # Fallback upper bound. + + for domain in self._domains: + state = self._domain_states.get(domain) + if state and state.throttled_until > now and state.throttled_until < earliest: + earliest = state.throttled_until + + return earliest + + def record_domain_delay(self, url: str, *, retry_after: timedelta | None = None) -> None: + """Record a 429 Too Many Requests response for the domain of the given URL. 
+ + Increments the consecutive 429 count and calculates the next allowed + request time using exponential backoff or the Retry-After value. + + Args: + url: The URL that received a 429 response. + retry_after: Optional delay from the Retry-After header. If provided, + it takes priority over the calculated exponential backoff. + """ + domain = self._extract_domain(url) + if not domain: + return + + now = datetime.now(timezone.utc) + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] + state.consecutive_429_count += 1 + + # Calculate delay: use Retry-After if provided, otherwise exponential backoff. + delay = retry_after if retry_after is not None else self._BASE_DELAY * (2 ** (state.consecutive_429_count - 1)) + + # Cap the delay. + delay = min(delay, self._MAX_DELAY) + + state.throttled_until = now + delay + + logger.info( + f'Rate limit (429) detected for domain "{domain}" ' + f'(consecutive: {state.consecutive_429_count}, delay: {delay.total_seconds():.1f}s)' + ) + + def record_success(self, url: str) -> None: + """Record a successful request, resetting the backoff state for that domain. + + Args: + url: The URL that received a successful response. + """ + domain = self._extract_domain(url) + state = self._domain_states.get(domain) + + if state is not None and state.consecutive_429_count > 0: + logger.debug(f'Resetting rate limit state for domain "{domain}" after successful request') + state.consecutive_429_count = 0 + + def set_crawl_delay(self, url: str, delay_seconds: int) -> None: + """Set the robots.txt crawl-delay for a domain. + + Args: + url: A URL from the domain to throttle. + delay_seconds: The crawl-delay value in seconds. 
+ """ + domain = self._extract_domain(url) + if not domain: + return + + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] + state.crawl_delay = timedelta(seconds=delay_seconds) + + logger.debug(f'Set crawl-delay for domain "{domain}" to {delay_seconds}s') + + def _mark_domain_dispatched(self, url: str) -> None: + """Record that a request to this domain was just dispatched. + + If a crawl-delay is configured, push throttled_until forward by that amount. + """ + domain = self._extract_domain(url) + if not domain: + return + + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] + + # If crawl-delay is set, enforce minimum interval by pushing throttled_until. + if state.crawl_delay is not None: + state.throttled_until = datetime.now(timezone.utc) + state.crawl_delay + + # ────────────────────────────────────────────────────── + # RequestManager interface delegation + smart scheduling + # ────────────────────────────────────────────────────── + + @override + async def drop(self) -> None: + await self._inner.drop() + for sq in self._sub_queues.values(): + await sq.drop() + self._sub_queues.clear() + + @override + async def add_request(self, request: str | Request, *, forefront: bool = False) -> ProcessedRequest: + """Add a request, routing it to the appropriate queue. + + Requests for explicitly configured domains are routed directly to their + per-domain sub-queue. All other requests go to the inner queue. 
+ """ + url = self._get_url_from_request(request) + domain = self._extract_domain(url) + + if domain in self._domains: + sq = await self._get_or_create_sub_queue(domain) + return await sq.add_request(request, forefront=forefront) + + return await self._inner.add_request(request, forefront=forefront) + + @override + async def add_requests( + self, + requests: Sequence[str | Request], + *, + forefront: bool = False, + batch_size: int = 1000, + wait_time_between_batches: timedelta = timedelta(seconds=1), + wait_for_all_requests_to_be_added: bool = False, + wait_for_all_requests_to_be_added_timeout: timedelta | None = None, + ) -> None: + """Add multiple requests, routing each to the appropriate queue.""" + inner_requests: list[str | Request] = [] + domain_requests: dict[str, list[str | Request]] = {} + + for request in requests: + url = self._get_url_from_request(request) + domain = self._extract_domain(url) + + if domain in self._domains: + domain_requests.setdefault(domain, []).append(request) + else: + inner_requests.append(request) + + # Add non-throttled requests to inner queue. + if inner_requests: + await self._inner.add_requests( + inner_requests, + forefront=forefront, + batch_size=batch_size, + wait_time_between_batches=wait_time_between_batches, + wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, + wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout, + ) + + # Add throttled requests to their respective sub-queues. 
+ for domain, reqs in domain_requests.items(): + sq = await self._get_or_create_sub_queue(domain) + await sq.add_requests( + reqs, + forefront=forefront, + batch_size=batch_size, + wait_time_between_batches=wait_time_between_batches, + wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, + wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout, + ) + + @override + async def reclaim_request(self, request: Request, *, forefront: bool = False) -> ProcessedRequest | None: + domain = self._extract_domain(request.url) + if domain in self._domains and domain in self._sub_queues: + return await self._sub_queues[domain].reclaim_request(request, forefront=forefront) + return await self._inner.reclaim_request(request, forefront=forefront) + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + domain = self._extract_domain(request.url) + if domain in self._domains and domain in self._sub_queues: + return await self._sub_queues[domain].mark_request_as_handled(request) + return await self._inner.mark_request_as_handled(request) + + @override + async def get_handled_count(self) -> int: + count = await self._inner.get_handled_count() + for sq in self._sub_queues.values(): + count += await sq.get_handled_count() + return count + + @override + async def get_total_count(self) -> int: + count = await self._inner.get_total_count() + for sq in self._sub_queues.values(): + count += await sq.get_total_count() + return count + + @override + async def is_empty(self) -> bool: + if not await self._inner.is_empty(): + return False + for sq in self._sub_queues.values(): + if not await sq.is_empty(): + return False + return True + + @override + async def is_finished(self) -> bool: + if not await self._inner.is_finished(): + return False + for sq in self._sub_queues.values(): + if not await sq.is_finished(): + return False + return True + + @override + async def fetch_next_request(self) -> Request | 
None: + """Fetch the next request, respecting per-domain delays. + + Sub-queues are checked in order of longest-overdue domain first + (sorted by `throttled_until` ascending). If all configured domains are + throttled, falls back to the inner queue. If everything is throttled, + sleeps until the earliest domain becomes available. + """ + # Collect unthrottled domains and sort by throttled_until (longest-overdue first). + available_domains = [ + domain for domain in self._domains if domain in self._sub_queues and not self._is_domain_throttled(domain) + ] + min_time = datetime.min.replace(tzinfo=timezone.utc) + available_domains.sort( + key=lambda d: self._domain_states[d].throttled_until if d in self._domain_states else min_time, + ) + + for domain in available_domains: + sq = self._sub_queues[domain] + req = await sq.fetch_next_request() + if req: + self._mark_domain_dispatched(req.url) + return req + + # Try fetching from the inner queue (non-throttled domains). + request = await self._inner.fetch_next_request() + if request is not None: + return request + + # No requests in inner queue. Check if any sub-queues still have requests. + have_sq_requests = False + for sq in self._sub_queues.values(): + if not await sq.is_empty(): + have_sq_requests = True + break + + if have_sq_requests: + # Requests exist but all domains are throttled. Sleep and retry. + earliest = self._get_earliest_available_time() + sleep_duration = max( + (earliest - datetime.now(timezone.utc)).total_seconds(), + 0.1, # Minimum sleep to avoid tight loops. + ) + logger.debug( + f'All configured domains are throttled. ' + f'Sleeping {sleep_duration:.1f}s until earliest domain is available.' 
+ ) + await asyncio.sleep(sleep_duration) + return await self.fetch_next_request() + + return None diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 23ca3c1eca..2b3ed72ccd 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -1238,7 +1238,8 @@ async def test_crawler_uses_default_storages(tmp_path: Path) -> None: assert dataset is await crawler.get_dataset() assert kvs is await crawler.get_key_value_store() - assert rq is await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert manager is rq async def test_crawler_can_use_other_storages(tmp_path: Path) -> None: @@ -1256,7 +1257,8 @@ async def test_crawler_can_use_other_storages(tmp_path: Path) -> None: assert dataset is not await crawler.get_dataset() assert kvs is not await crawler.get_key_value_store() - assert rq is not await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert manager is not rq async def test_crawler_can_use_other_storages_of_same_type(tmp_path: Path) -> None: @@ -1293,7 +1295,8 @@ async def test_crawler_can_use_other_storages_of_same_type(tmp_path: Path) -> No # Assert that the storages are different assert dataset is not await crawler.get_dataset() assert kvs is not await crawler.get_key_value_store() - assert rq is not await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert manager is not rq # Assert that all storages exists on the filesystem for path in expected_paths: diff --git a/tests/unit/test_throttling_request_manager.py b/tests/unit/test_throttling_request_manager.py new file mode 100644 index 0000000000..9a0b6a67ef --- /dev/null +++ b/tests/unit/test_throttling_request_manager.py @@ -0,0 +1,471 @@ +"""Tests for ThrottlingRequestManager - per-domain delay scheduling.""" + +from __future__ import annotations + +from datetime import datetime, 
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, cast
from unittest.mock import AsyncMock, patch

if TYPE_CHECKING:
    from collections.abc import Iterator

import pytest

from crawlee._request import Request
from crawlee._utils.http import parse_retry_after_header
from crawlee.request_loaders._throttling_request_manager import ThrottlingRequestManager

THROTTLED_DOMAIN = 'throttled.com'
NON_THROTTLED_DOMAIN = 'free.com'
TEST_DOMAINS = [THROTTLED_DOMAIN]


def _make_queue_mock() -> AsyncMock:
    """Build an AsyncMock exposing the RequestManager interface with empty-queue defaults.

    Used both for the wrapped inner manager and for sub-queues produced by the
    mocked `RequestQueue.open`, which previously duplicated this setup inline.
    """
    mock = AsyncMock()
    mock.fetch_next_request = AsyncMock(return_value=None)
    mock.add_request = AsyncMock()
    mock.add_requests = AsyncMock()
    mock.reclaim_request = AsyncMock()
    mock.mark_request_as_handled = AsyncMock()
    mock.get_handled_count = AsyncMock(return_value=0)
    mock.get_total_count = AsyncMock(return_value=0)
    mock.is_empty = AsyncMock(return_value=True)
    mock.is_finished = AsyncMock(return_value=True)
    mock.drop = AsyncMock()
    return mock


@pytest.fixture
def mock_inner() -> AsyncMock:
    """Create a mock RequestManager to wrap."""
    return _make_queue_mock()


@pytest.fixture
def manager(mock_inner: AsyncMock) -> ThrottlingRequestManager:
    """Create a ThrottlingRequestManager wrapping the mock with test domains."""
    return ThrottlingRequestManager(mock_inner, domains=TEST_DOMAINS)


@pytest.fixture(autouse=True)
def mock_request_queue_open() -> Iterator[AsyncMock]:
    """Mock RequestQueue.open to avoid hitting real storage during tests."""
    target = 'crawlee.request_loaders._throttling_request_manager.RequestQueue.open'
    with patch(target, new_callable=AsyncMock) as mocked:

        async def mock_open(*_args: Any, **_kwargs: Any) -> AsyncMock:
            # Each call yields a fresh, fully-stubbed sub-queue double.
            return _make_queue_mock()

        mocked.side_effect = mock_open
        yield mocked


def _make_request(url: str) -> Request:
    """Helper to create a Request object."""
    return Request.from_url(url)


# ── Request Routing Tests ─────────────────────────────────


@pytest.mark.asyncio
async def test_add_request_routes_listed_domain_to_sub_queue(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
    mock_request_queue_open: AsyncMock,
) -> None:
    """Requests for listed domains should be routed to their sub-queue, not inner."""
    request = _make_request(f'https://{THROTTLED_DOMAIN}/page1')
    await manager.add_request(request)

    mock_request_queue_open.assert_called_once()
    assert THROTTLED_DOMAIN in manager._sub_queues
    sq = manager._sub_queues[THROTTLED_DOMAIN]
    cast('AsyncMock', sq.add_request).assert_called_once_with(request, forefront=False)
    mock_inner.add_request.assert_not_called()


@pytest.mark.asyncio
async def test_add_request_routes_non_listed_domain_to_inner(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """Requests for non-listed domains should go to the inner queue."""
    request = _make_request(f'https://{NON_THROTTLED_DOMAIN}/page1')
    await manager.add_request(request)

    mock_inner.add_request.assert_called_once_with(request, forefront=False)
    assert NON_THROTTLED_DOMAIN not in manager._sub_queues


@pytest.mark.asyncio
async def test_add_request_with_string_url(
    manager: ThrottlingRequestManager,
    mock_request_queue_open: AsyncMock,
) -> None:
    """add_request should also work when given a plain URL string."""
    url = f'https://{THROTTLED_DOMAIN}/page1'
    await manager.add_request(url)

    mock_request_queue_open.assert_called_once()
    sq = manager._sub_queues[THROTTLED_DOMAIN]
    cast('AsyncMock', sq.add_request).assert_called_once_with(url, forefront=False)


@pytest.mark.asyncio
async def test_add_requests_routes_mixed_domains(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """add_requests should split requests by domain and route them correctly."""
    throttled_req = _make_request(f'https://{THROTTLED_DOMAIN}/page1')
    free_req = _make_request(f'https://{NON_THROTTLED_DOMAIN}/page1')

    await manager.add_requests([throttled_req, free_req])

    # Inner gets only the non-listed domain request
    mock_inner.add_requests.assert_called_once()
    inner_call_args = mock_inner.add_requests.call_args
    assert free_req in inner_call_args[0][0]

    # Sub-queue gets the listed domain request
    assert THROTTLED_DOMAIN in manager._sub_queues


# ── Core Throttling Tests ─────────────────────────────────


@pytest.mark.asyncio
async def test_429_triggers_domain_delay(manager: ThrottlingRequestManager) -> None:
    """After record_domain_delay(), the domain should be throttled."""
    manager.record_domain_delay(f'https://{THROTTLED_DOMAIN}/page1')
    assert manager._is_domain_throttled(THROTTLED_DOMAIN)


@pytest.mark.asyncio
async def test_different_domains_independent(manager: ThrottlingRequestManager) -> None:
    """Throttling one domain should NOT affect other domains."""
    manager.record_domain_delay(f'https://{THROTTLED_DOMAIN}/page1')
    assert manager._is_domain_throttled(THROTTLED_DOMAIN)
    assert not manager._is_domain_throttled(NON_THROTTLED_DOMAIN)


@pytest.mark.asyncio
async def test_exponential_backoff(manager: ThrottlingRequestManager) -> None:
    """Consecutive 429s should increase delay exponentially."""
    url = f'https://{THROTTLED_DOMAIN}/page1'

    manager.record_domain_delay(url)
    state = manager._domain_states[THROTTLED_DOMAIN]
    first_until = state.throttled_until

    manager.record_domain_delay(url)
    second_until = state.throttled_until

    assert second_until > first_until
    assert state.consecutive_429_count == 2


@pytest.mark.asyncio
async def test_max_delay_cap(manager: ThrottlingRequestManager) -> None:
    """Backoff should cap at _MAX_DELAY (60s)."""
    url = f'https://{THROTTLED_DOMAIN}/page1'

    for _ in range(20):
        manager.record_domain_delay(url)

    state = manager._domain_states[THROTTLED_DOMAIN]
    now = datetime.now(timezone.utc)
    actual_delay = state.throttled_until - now

    # One-second slack absorbs clock movement between record and assert.
    assert actual_delay <= manager._MAX_DELAY + timedelta(seconds=1)


@pytest.mark.asyncio
async def test_retry_after_header_priority(manager: ThrottlingRequestManager) -> None:
    """Explicit Retry-After should override exponential backoff."""
    url = f'https://{THROTTLED_DOMAIN}/page1'

    manager.record_domain_delay(url, retry_after=timedelta(seconds=30))

    state = manager._domain_states[THROTTLED_DOMAIN]
    now = datetime.now(timezone.utc)
    actual_delay = state.throttled_until - now

    assert actual_delay > timedelta(seconds=28)
    assert actual_delay <= timedelta(seconds=31)


@pytest.mark.asyncio
async def test_success_resets_backoff(manager: ThrottlingRequestManager) -> None:
    """Successful request should reset the consecutive 429 count."""
    url = f'https://{THROTTLED_DOMAIN}/page1'

    manager.record_domain_delay(url)
    manager.record_domain_delay(url)
    assert manager._domain_states[THROTTLED_DOMAIN].consecutive_429_count == 2

    manager.record_success(url)
    assert manager._domain_states[THROTTLED_DOMAIN].consecutive_429_count == 0


# ── Crawl-Delay Integration Tests ─────────────────────────


@pytest.mark.asyncio
async def test_crawl_delay_integration(manager: ThrottlingRequestManager) -> None:
    """set_crawl_delay() should record the delay for the domain."""
    url = f'https://{THROTTLED_DOMAIN}/page1'
    manager.set_crawl_delay(url, 5)

    state = manager._domain_states[THROTTLED_DOMAIN]
    assert state.crawl_delay == timedelta(seconds=5)


@pytest.mark.asyncio
async def test_crawl_delay_throttles_after_dispatch(manager: ThrottlingRequestManager) -> None:
    """After dispatching a request, crawl-delay should throttle the next one."""
    url = f'https://{THROTTLED_DOMAIN}/page1'
    manager.set_crawl_delay(url, 5)

    manager._mark_domain_dispatched(url)

    assert manager._is_domain_throttled(THROTTLED_DOMAIN)


# ── Fetch Scheduling Tests ────────────────────────────


@pytest.mark.asyncio
async def test_fetch_from_unthrottled_sub_queue(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """fetch_next_request should return from an unthrottled sub-queue."""
    request = _make_request(f'https://{THROTTLED_DOMAIN}/page1')

    sq = AsyncMock()
    sq.fetch_next_request = AsyncMock(return_value=request)
    manager._sub_queues[THROTTLED_DOMAIN] = sq

    result = await manager.fetch_next_request()

    assert result is not None
    assert result.url == f'https://{THROTTLED_DOMAIN}/page1'
    mock_inner.fetch_next_request.assert_not_called()


@pytest.mark.asyncio
async def test_fetch_falls_back_to_inner(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """When sub-queues are empty, should return from inner queue."""
    request = _make_request(f'https://{NON_THROTTLED_DOMAIN}/page1')
    mock_inner.fetch_next_request.return_value = request

    result = await manager.fetch_next_request()

    assert result is not None
    assert result.url == f'https://{NON_THROTTLED_DOMAIN}/page1'


@pytest.mark.asyncio
async def test_fetch_skips_throttled_sub_queue(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """Should skip throttled sub-queues and fall through to inner."""
    manager.record_domain_delay(f'https://{THROTTLED_DOMAIN}/page1')

    sq = AsyncMock()
    sq.fetch_next_request = AsyncMock(return_value=_make_request(f'https://{THROTTLED_DOMAIN}/page1'))
    sq.is_empty = AsyncMock(return_value=False)
    manager._sub_queues[THROTTLED_DOMAIN] = sq

    inner_req = _make_request(f'https://{NON_THROTTLED_DOMAIN}/page1')
    mock_inner.fetch_next_request.return_value = inner_req

    result = await manager.fetch_next_request()

    assert result is not None
    assert result.url == f'https://{NON_THROTTLED_DOMAIN}/page1'


@pytest.mark.asyncio
async def test_sleep_when_all_throttled(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
    """When all domains are throttled and inner is empty, should sleep and retry."""
    request = _make_request(f'https://{THROTTLED_DOMAIN}/page1')
    manager.record_domain_delay(f'https://{THROTTLED_DOMAIN}/page1', retry_after=timedelta(seconds=0.2))

    sq = AsyncMock()
    sq.is_empty = AsyncMock(return_value=False)
    sq.fetch_next_request = AsyncMock(return_value=request)
    manager._sub_queues[THROTTLED_DOMAIN] = sq

    mock_inner.fetch_next_request.return_value = None

    target = 'crawlee.request_loaders._throttling_request_manager.asyncio.sleep'
    with patch(target, new_callable=AsyncMock) as mock_sleep:

        async def sleep_side_effect(*_args: Any, **_kwargs: Any) -> None:
            # Simulate the throttle window elapsing while "sleeping".
            manager._domain_states[THROTTLED_DOMAIN].throttled_until = datetime.now(timezone.utc)

        mock_sleep.side_effect = sleep_side_effect

        result = await manager.fetch_next_request()

    mock_sleep.assert_called_once()
    assert result is not None
    assert result.url == f'https://{THROTTLED_DOMAIN}/page1'


# ── Delegation Tests ──────────────────────────────────────


@pytest.mark.asyncio
async def test_reclaim_request_routes_to_sub_queue(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """reclaim_request should route to sub-queue for listed domains."""
    request = _make_request(f'https://{THROTTLED_DOMAIN}/page1')
    sq = AsyncMock()
    manager._sub_queues[THROTTLED_DOMAIN] = sq

    await manager.reclaim_request(request)

    sq.reclaim_request.assert_called_once_with(request, forefront=False)
    mock_inner.reclaim_request.assert_not_called()


@pytest.mark.asyncio
async def test_reclaim_request_routes_to_inner(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """reclaim_request should route to inner for non-listed domains."""
    request = _make_request(f'https://{NON_THROTTLED_DOMAIN}/page1')

    await manager.reclaim_request(request)

    mock_inner.reclaim_request.assert_called_once_with(request, forefront=False)


@pytest.mark.asyncio
async def test_mark_request_as_handled_routes_to_sub_queue(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """mark_request_as_handled should route to sub-queue for listed domains."""
    request = _make_request(f'https://{THROTTLED_DOMAIN}/page1')
    sq = AsyncMock()
    manager._sub_queues[THROTTLED_DOMAIN] = sq

    await manager.mark_request_as_handled(request)

    sq.mark_request_as_handled.assert_called_once_with(request)
    mock_inner.mark_request_as_handled.assert_not_called()


@pytest.mark.asyncio
async def test_mark_request_as_handled_routes_to_inner(
    manager: ThrottlingRequestManager,
    mock_inner: AsyncMock,
) -> None:
    """mark_request_as_handled should route to inner for non-listed domains."""
    request = _make_request(f'https://{NON_THROTTLED_DOMAIN}/page1')

    await manager.mark_request_as_handled(request)

    mock_inner.mark_request_as_handled.assert_called_once_with(request)


@pytest.mark.asyncio
async def test_get_handled_count_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
    """get_handled_count should sum inner and all sub-queues."""
    mock_inner.get_handled_count.return_value = 42

    sq = AsyncMock()
    sq.get_handled_count.return_value = 10
    manager._sub_queues[THROTTLED_DOMAIN] = sq

    assert await manager.get_handled_count() == 52


@pytest.mark.asyncio
async def test_get_total_count_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
    """get_total_count should sum inner and all sub-queues."""
    mock_inner.get_total_count.return_value = 100

    sq = AsyncMock()
    sq.get_total_count.return_value = 20
    manager._sub_queues[THROTTLED_DOMAIN] = sq

    assert await manager.get_total_count() == 120


@pytest.mark.asyncio
async def test_is_empty_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
    """A non-empty sub-queue makes the manager report non-empty overall."""
    # Baseline: inner empty and no sub-queues -> manager is empty.
    mock_inner.is_empty.return_value = True
    assert await manager.is_empty() is True

    # Register a sub-queue that still holds requests.
    sub_queue = AsyncMock()
    sub_queue.is_empty.return_value = False
    manager._sub_queues[THROTTLED_DOMAIN] = sub_queue

    assert await manager.is_empty() is False


@pytest.mark.asyncio
async def test_is_finished_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
    """An unfinished sub-queue makes the manager report unfinished overall."""
    # Baseline: inner finished and no sub-queues -> manager is finished.
    mock_inner.is_finished.return_value = True
    assert await manager.is_finished() is True

    # Register a sub-queue that is still processing.
    sub_queue = AsyncMock()
    sub_queue.is_finished.return_value = False
    manager._sub_queues[THROTTLED_DOMAIN] = sub_queue

    assert await manager.is_finished() is False


@pytest.mark.asyncio
async def test_drop_clears_all(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
    """Dropping the manager drops inner, every sub-queue, and the sub-queue map."""
    sub_queue = AsyncMock()
    manager._sub_queues[THROTTLED_DOMAIN] = sub_queue

    await manager.drop()

    mock_inner.drop.assert_called_once()
    sub_queue.drop.assert_called_once()
    assert len(manager._sub_queues) == 0


# ── Utility Tests ──────────────────────────────────────


def test_parse_retry_after_none_value() -> None:
    """A missing header yields None."""
    assert parse_retry_after_header(None) is None


def test_parse_retry_after_empty_string() -> None:
    """An empty header value yields None."""
    assert parse_retry_after_header('') is None


def test_parse_retry_after_integer_seconds() -> None:
    """A numeric header value is interpreted as seconds."""
    assert parse_retry_after_header('120') == timedelta(seconds=120)


def test_parse_retry_after_invalid_value() -> None:
    """An unparseable header value yields None."""
    assert parse_retry_after_header('not-a-date-or-number') is None