diff --git a/py/selenium/webdriver/common/bidi/speculation.py b/py/selenium/webdriver/common/bidi/speculation.py new file mode 100644 index 0000000000000..7d617bcf112fc --- /dev/null +++ b/py/selenium/webdriver/common/bidi/speculation.py @@ -0,0 +1,198 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import threading +from collections.abc import Callable + +from selenium.webdriver.common.bidi.session import Session + + +class PreloadingStatus: + """Represents the different types of preloading statuses. + + This status is shared by prefetches and prerenders. + """ + + PENDING = "pending" + READY = "ready" + SUCCESS = "success" + FAILURE = "failure" + + VALID_STATUSES = {PENDING, READY, SUCCESS, FAILURE} + + +class PrefetchStatusUpdatedParams: + """Parameters for the speculation.prefetchStatusUpdated event.""" + + def __init__(self, context: str, url: str, status: str): + self.context = context + self.url = url + self.status = status + + @classmethod + def from_json(cls, json: dict) -> PrefetchStatusUpdatedParams: + """Creates a PrefetchStatusUpdatedParams instance from a dictionary. + + Args: + json: A dictionary containing the prefetch status updated parameters. + + Returns: + A new instance of PrefetchStatusUpdatedParams. + + Raises: + ValueError: If required fields are missing or have invalid types. + """ + context = json.get("context") + if context is None or not isinstance(context, str): + raise ValueError("context is required and must be a string") + + url = json.get("url") + if url is None or not isinstance(url, str): + raise ValueError("url is required and must be a string") + + status = json.get("status") + if status is None or not isinstance(status, str): + raise ValueError("status is required and must be a string") + + if status not in PreloadingStatus.VALID_STATUSES: + raise ValueError(f"Invalid status: {status}. Must be one of {PreloadingStatus.VALID_STATUSES}") + + return cls( + context=context, + url=url, + status=status, + ) + + +class PrefetchStatusUpdated: + """Event class for speculation.prefetchStatusUpdated event.""" + + event_class = "speculation.prefetchStatusUpdated" + + @classmethod + def from_json(cls, json: dict): + if isinstance(json, PrefetchStatusUpdatedParams): + return json + return PrefetchStatusUpdatedParams.from_json(json) + + +class Speculation: + """BiDi implementation of the speculation module. + + The speculation module contains commands for managing the remote end + behavior for prefetches, prerenders, and speculation rules. + """ + + # Maps Python event names to (bidi_event_name, event_class) + _EVENT_MAP: dict[str, tuple[str, type]] = { + "prefetch_status_updated": ("speculation.prefetchStatusUpdated", PrefetchStatusUpdated), + } + + # Reverse mapping: BiDi event name to event class + _BIDI_TO_CLASS: dict[str, type] = { + "speculation.prefetchStatusUpdated": PrefetchStatusUpdated, + } + + def __init__(self, conn): + self.conn = conn + self._session = Session(conn) + self.subscriptions: dict[str, list[int]] = {} + self._subscription_lock = threading.Lock() + + def _validate_event(self, event: str) -> tuple[str, type]: + """Validate and resolve an event name to its BiDi event name and class. + + Args: + event: The Python event name (e.g., "prefetch_status_updated"). + + Returns: + A tuple of (bidi_event_name, event_class). + + Raises: + ValueError: If the event name is not recognized. + """ + entry = self._EVENT_MAP.get(event) + if not entry: + available = ", ".join(sorted(self._EVENT_MAP.keys())) + raise ValueError(f"Event '{event}' not found. Available events: {available}") + return entry + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler for a speculation event. + + Args: + event: The event to subscribe to (e.g., "prefetch_status_updated"). + callback: The callback function to execute on event. + contexts: Optional browsing context IDs to subscribe to. + + Returns: + Callback id for later removal. + + Raises: + ValueError: If the event name is not recognized. + """ + bidi_event, event_class = self._validate_event(event) + + callback_id = self.conn.add_callback(event_class, callback) + + with self._subscription_lock: + if bidi_event not in self.subscriptions: + self.conn.execute(self._session.subscribe(bidi_event, browsing_contexts=contexts)) + self.subscriptions[bidi_event] = [] + self.subscriptions[bidi_event].append(callback_id) + + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler for a speculation event. + + Args: + event: The event to unsubscribe from. + callback_id: The callback id to remove. + + Raises: + ValueError: If the event name is not recognized. + """ + bidi_event, event_class = self._validate_event(event) + + self.conn.remove_callback(event_class, callback_id) + + with self._subscription_lock: + callback_list = self.subscriptions.get(bidi_event) + if callback_list and callback_id in callback_list: + callback_list.remove(callback_id) + + if callback_list is not None and not callback_list: + self.conn.execute(self._session.unsubscribe(bidi_event)) + del self.subscriptions[bidi_event] + + def clear_event_handlers(self) -> None: + """Clear all event handlers from the speculation module.""" + with self._subscription_lock: + if not self.subscriptions: + return + + for bidi_event, callback_ids in list(self.subscriptions.items()): + event_class = self._BIDI_TO_CLASS.get(bidi_event) + if event_class: + for callback_id in callback_ids: + self.conn.remove_callback(event_class, callback_id) + self.conn.execute(self._session.unsubscribe(bidi_event)) + + self.subscriptions.clear() diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 898eb8d547aa6..2ede2bf633f43 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -47,6 +47,7 @@ from selenium.webdriver.common.bidi.permissions import Permissions from selenium.webdriver.common.bidi.script import Script from selenium.webdriver.common.bidi.session import Session +from selenium.webdriver.common.bidi.speculation import Speculation from selenium.webdriver.common.bidi.storage import Storage from selenium.webdriver.common.bidi.webextension import WebExtension from selenium.webdriver.common.by import By @@ -277,6 +278,7 @@ def __init__( self._browser: Browser | None = None self._bidi_session: Session | None = None self._browsing_context: BrowsingContext | None = None + self._speculation: Speculation | None = None self._storage: Storage | None = None self._webextension: WebExtension | None = None self._permissions: Permissions | None = None @@ -1196,6 +1198,35 @@ def browsing_context(self) -> BrowsingContext: return self._browsing_context + @property + def speculation(self) -> Speculation: + """Returns a speculation module object for BiDi speculation commands. + + The speculation module contains commands for managing the remote end + behavior for prefetches, prerenders, and speculation rules. + + Returns: + An object containing access to BiDi speculation events. + + Examples: + ``` + from selenium.webdriver.common.bidi.speculation import PreloadingStatus + + events = [] + callback_id = driver.speculation.add_event_handler("prefetch_status_updated", events.append) + # ... trigger prefetch ... + driver.speculation.remove_event_handler("prefetch_status_updated", callback_id) + ``` + """ + if not self._websocket_connection: + self._start_bidi() + + assert self._websocket_connection is not None + if self._speculation is None: + self._speculation = Speculation(self._websocket_connection) + + return self._speculation + @property def storage(self) -> Storage: """Returns a storage module object for BiDi storage commands. diff --git a/py/test/selenium/webdriver/common/bidi_speculation_tests.py b/py/test/selenium/webdriver/common/bidi_speculation_tests.py new file mode 100644 index 0000000000000..741969f6b0eb6 --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_speculation_tests.py @@ -0,0 +1,194 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from selenium.webdriver.common.bidi.browsing_context import ReadinessState +from selenium.webdriver.common.bidi.speculation import PreloadingStatus +from selenium.webdriver.support.ui import WebDriverWait + + +def _add_speculation_rules_and_link(driver, prefetch_url): + driver.execute_script(f"addSpeculationRulesAndLink('{prefetch_url}')") + + +def test_speculation_module_initialized(driver): + assert driver.speculation is not None + + +@pytest.mark.xfail_firefox +def test_prefetch_status_updated_with_pending_and_ready_events(driver, pages): + """Test that prefetch status updated events are received with pending and ready statuses.""" + events_received = [] + + def on_prefetch_status_updated(event): + events_received.append(event) + + callback_id = driver.speculation.add_event_handler("prefetch_status_updated", on_prefetch_status_updated) + + try: + url = pages.url("bidi/speculationRules.html") + prefetch_url = pages.url("bidi/emptyPage.html") + driver.browsing_context.navigate( + context=driver.current_window_handle, + url=url, + wait=ReadinessState.COMPLETE, + ) + + _add_speculation_rules_and_link(driver, prefetch_url) + + # Wait for at least two events (pending + ready) + WebDriverWait(driver, 10).until(lambda _: len(events_received) >= 2) + + statuses = {event.status for event in events_received} + assert PreloadingStatus.PENDING in statuses + assert PreloadingStatus.READY in statuses + + # Verify event fields + for event in events_received: + assert event.context == driver.current_window_handle + assert prefetch_url in event.url + assert event.status in PreloadingStatus.VALID_STATUSES + finally: + driver.speculation.remove_event_handler("prefetch_status_updated", callback_id) + + +@pytest.mark.xfail_firefox +def test_prefetch_status_updated_with_navigation_and_success(driver, pages): + """Test that navigating to a prefetched page via link click generates a success status event.""" + events_received = [] + + def on_prefetch_status_updated(event): + events_received.append(event) + + callback_id = driver.speculation.add_event_handler("prefetch_status_updated", on_prefetch_status_updated) + + try: + url = pages.url("bidi/speculationRules.html") + prefetch_url = pages.url("bidi/emptyPage.html") + driver.browsing_context.navigate( + context=driver.current_window_handle, + url=url, + wait=ReadinessState.COMPLETE, + ) + + _add_speculation_rules_and_link(driver, prefetch_url) + + # Wait for prefetch to be ready + WebDriverWait(driver, 10).until( + lambda _: any(event.status == PreloadingStatus.READY for event in events_received) + ) + + # Click the prefetch link to activate the prefetched resource + driver.execute_script("document.getElementById('prefetch-link').click()") + + WebDriverWait(driver, 10).until( + lambda _: any(event.status == PreloadingStatus.SUCCESS for event in events_received) + ) + + statuses = {event.status for event in events_received} + assert PreloadingStatus.SUCCESS in statuses + + success_event = next(e for e in events_received if e.status == PreloadingStatus.SUCCESS) + assert prefetch_url in success_event.url + assert success_event.context == driver.current_window_handle + finally: + driver.speculation.remove_event_handler("prefetch_status_updated", callback_id) + + +@pytest.mark.xfail_firefox +def test_prefetch_status_updated_with_failure_events(driver, pages): + """Test that a failed prefetch generates failure status events.""" + events_received = [] + + def on_prefetch_status_updated(event): + events_received.append(event) + + callback_id = driver.speculation.add_event_handler("prefetch_status_updated", on_prefetch_status_updated) + + try: + url = pages.url("bidi/speculationRules.html") + # Target a non-existent page to trigger failure + prefetch_url = pages.url("nonexistent_page_404") + driver.browsing_context.navigate( + context=driver.current_window_handle, + url=url, + wait=ReadinessState.COMPLETE, + ) + + _add_speculation_rules_and_link(driver, prefetch_url) + + # Wait for failure or pending event + WebDriverWait(driver, 10).until(lambda _: len(events_received) >= 1) + + statuses = {event.status for event in events_received} + assert statuses.issubset({PreloadingStatus.PENDING, PreloadingStatus.FAILURE}) + finally: + driver.speculation.remove_event_handler("prefetch_status_updated", callback_id) + + +@pytest.mark.xfail_firefox +def test_can_unsubscribe_from_prefetch_status_updated(driver, pages): + """Test that events are no longer received after removing the handler.""" + events_received = [] + + def on_prefetch_status_updated(event): + events_received.append(event) + + callback_id = driver.speculation.add_event_handler("prefetch_status_updated", on_prefetch_status_updated) + + try: + url = pages.url("bidi/speculationRules.html") + prefetch_url = pages.url("bidi/emptyPage.html") + driver.browsing_context.navigate( + context=driver.current_window_handle, + url=url, + wait=ReadinessState.COMPLETE, + ) + + _add_speculation_rules_and_link(driver, prefetch_url) + + # Wait for initial events + WebDriverWait(driver, 10).until(lambda _: len(events_received) >= 1) + + initial_count = len(events_received) + + # Unsubscribe from events + driver.speculation.remove_event_handler("prefetch_status_updated", callback_id) + + # Reload and trigger new speculation rules with different target + driver.browsing_context.navigate( + context=driver.current_window_handle, + url=url, + wait=ReadinessState.COMPLETE, + ) + + second_prefetch_url = pages.url("blank.html") + _add_speculation_rules_and_link(driver, second_prefetch_url) + + assert len(events_received) == initial_count + except Exception: + # Only try to remove if we haven't already + driver.speculation.clear_event_handlers() + raise + + +@pytest.mark.xfail_firefox +def test_invalid_event_raises_error(driver): + """Test that subscribing to an invalid event name raises ValueError.""" + with pytest.raises(ValueError, match="Event 'invalid_event' not found"): + driver.speculation.add_event_handler("invalid_event", lambda e: None)