From c9b4b2de02b0b717f6d5ebb0f23fd22b541ce57c Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Fri, 3 Apr 2026 18:24:00 -0400 Subject: [PATCH 1/9] add capability policy and necessary subclasses plus tests --- pyrit/prompt_target/__init__.py | 10 +- .../common/target_capabilities.py | 111 +++++++++++++++++- tests/unit/target/test_target_capabilities.py | 95 ++++++++++++++- 3 files changed, 213 insertions(+), 3 deletions(-) diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 05af2d67d8..045b217e8a 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -12,7 +12,12 @@ from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.prompt_target import PromptTarget -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityName, + CapabilityHandlingPolicy, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.crucible_target import CrucibleTarget from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget @@ -67,7 +72,10 @@ "PromptShieldTarget", "PromptTarget", "RealtimeTarget", + "CapabilityName", + "CapabilityHandlingPolicy", "TargetCapabilities", + "UnsupportedCapabilityBehavior", "TextTarget", "WebSocketCopilotTarget", ] diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index e6ced6a1a2..d450f0c727 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -1,12 +1,109 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import dataclass +from dataclasses import dataclass, field +from enum import Enum from typing import Optional, cast from pyrit.models import PromptDataType +class CapabilityName(str, Enum): + """ + Canonical identifiers for target capabilities. + + This keeps capability identity in one place so policy, requirements, and + normalization code do not duplicate string field names. + """ + + MULTI_TURN = "supports_multi_turn" + MULTI_MESSAGE_PIECES = "supports_multi_message_pieces" + JSON_SCHEMA = "supports_json_schema" + JSON_OUTPUT = "supports_json_output" + EDITABLE_HISTORY = "supports_editable_history" + SYSTEM_PROMPT = "supports_system_prompt" + + +NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset( + { + CapabilityName.SYSTEM_PROMPT, + CapabilityName.MULTI_TURN, + } +) + + +class UnsupportedCapabilityBehavior(str, Enum): + """ + Defines what happens when a caller requires a capability the target does not support. + + ADAPT: apply a normalization step to work around the missing capability. + RAISE: fail immediately with an error. + """ + + ADAPT = "adapt" + RAISE = "raise" + + + +@dataclass(frozen=True) +class CapabilityHandlingPolicy: + """ + Per-capability policy consulted only when a capability is **missing**. + + Design invariants + ----------------- + * The policy is never consulted if the capability is already supported. + * Non-adaptable capabilities (e.g. ``supports_editable_history``) are not + represented here; requesting them on a target that lacks them always + raises immediately. + """ + + behaviors: dict[CapabilityName, UnsupportedCapabilityBehavior] = field( + default_factory=lambda: { + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + + def get_behavior(self, *, capability: CapabilityName) -> UnsupportedCapabilityBehavior: + """ + Return the configured handling behavior for a capability. + + Args: + capability: The capability to look up. + + Returns: + UnsupportedCapabilityBehavior: The configured behavior. + + Raises: + AttributeError: If no policy exists for the capability. + """ + try: + return self.behaviors[capability] + except KeyError as exc: + raise AttributeError(capability.value) from exc + + def __getattr__(self, name: str) -> UnsupportedCapabilityBehavior: + """ + Guard against accessing policies for non-adaptable or unknown capabilities. + + Raises: + AttributeError: If the capability is not part of this policy. + """ + for capability in CapabilityName: + if capability.value == name: + supported_names = ", ".join(sorted(cap.value for cap in self.behaviors)) + raise AttributeError( + f"'{type(self).__name__}' has no policy for '{name}'. " + f"Only the following capabilities have handling policies: " + f"{supported_names}." + ) + + raise AttributeError(name) + + @dataclass(frozen=True) class TargetCapabilities: """ @@ -47,6 +144,18 @@ class attribute. Users can override individual capabilities per instance # The output modalities supported by the target (e.g., "text", "image"). output_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])}) + def supports(self, *, capability: CapabilityName) -> bool: + """ + Return whether this target supports the given capability. + + Args: + capability: The capability to check. + + Returns: + bool: True if supported, otherwise False. + """ + return bool(getattr(self, capability.value)) + @staticmethod def get_known_capabilities(underlying_model: str) -> "Optional[TargetCapabilities]": """ diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py index 48b62297c3..b77f84af94 100644 --- a/tests/unit/target/test_target_capabilities.py +++ b/tests/unit/target/test_target_capabilities.py @@ -5,7 +5,100 @@ import pytest -from pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityName, + CapabilityHandlingPolicy, + NORMALIZABLE_CAPABILITIES, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) + + +class TestCapabilityHandlingPolicy: + """Test behavior and defaults of capability handling policy classes.""" + + def test_capability_name_values(self): + assert CapabilityName.MULTI_TURN.value == "supports_multi_turn" + assert CapabilityName.MULTI_MESSAGE_PIECES.value == "supports_multi_message_pieces" + assert CapabilityName.JSON_SCHEMA.value == "supports_json_schema" + assert CapabilityName.JSON_OUTPUT.value == "supports_json_output" + assert CapabilityName.EDITABLE_HISTORY.value == "supports_editable_history" + assert CapabilityName.SYSTEM_PROMPT.value == "supports_system_prompt" + + def test_unsupported_capability_behavior_values(self): + assert UnsupportedCapabilityBehavior.ADAPT.value == "adapt" + assert UnsupportedCapabilityBehavior.RAISE.value == "raise" + + def test_capability_handling_policy_defaults(self): + policy = CapabilityHandlingPolicy() + assert policy.behaviors == { + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + + def test_capability_handling_policy_custom_values(self): + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + + assert policy.behaviors[CapabilityName.MULTI_TURN] is UnsupportedCapabilityBehavior.ADAPT + assert policy.behaviors[CapabilityName.SYSTEM_PROMPT] is UnsupportedCapabilityBehavior.RAISE + + def test_capability_handling_policy_get_behavior(self): + policy = CapabilityHandlingPolicy() + + assert policy.get_behavior(capability=CapabilityName.MULTI_TURN) is UnsupportedCapabilityBehavior.RAISE + assert ( + policy.get_behavior(capability=CapabilityName.SYSTEM_PROMPT) is UnsupportedCapabilityBehavior.RAISE + ) + + def test_capability_handling_policy_get_behavior_for_all_supported_policy_keys(self): + policy = CapabilityHandlingPolicy() + + assert policy.get_behavior(capability=CapabilityName.JSON_SCHEMA) is UnsupportedCapabilityBehavior.RAISE + assert policy.get_behavior(capability=CapabilityName.JSON_OUTPUT) is UnsupportedCapabilityBehavior.RAISE + + def test_capability_handling_policy_rejects_capability_without_policy(self): + policy = CapabilityHandlingPolicy() + + with pytest.raises(AttributeError, match="supports_editable_history"): + policy.get_behavior(capability=CapabilityName.EDITABLE_HISTORY) + + with pytest.raises(AttributeError, match="supports_editable_history"): + _ = policy.supports_editable_history + + def test_capability_handling_policy_rejects_unknown_attribute(self): + policy = CapabilityHandlingPolicy() + + with pytest.raises(AttributeError, match="totally_unknown_attribute"): + _ = policy.totally_unknown_attribute + + def test_normalizable_capabilities(self): + assert NORMALIZABLE_CAPABILITIES == frozenset( + { + CapabilityName.MULTI_TURN, + CapabilityName.SYSTEM_PROMPT, + } + ) + + def test_target_capabilities_supports_helper(self): + capabilities = TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_json_output=True, + ) + + assert capabilities.supports(capability=CapabilityName.MULTI_TURN) is True + assert capabilities.supports(capability=CapabilityName.SYSTEM_PROMPT) is False + assert capabilities.supports(capability=CapabilityName.JSON_OUTPUT) is True + assert capabilities.supports(capability=CapabilityName.EDITABLE_HISTORY) is False # Env vars that may leak from .env files loaded by other tests in parallel workers. # Clear them so that targets use _DEFAULT_CAPABILITIES instead of _KNOWN_CAPABILITIES. From 40b012e5f9c98c28d30746a2a6f07fd7343d90b3 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Mon, 6 Apr 2026 13:01:50 -0400 Subject: [PATCH 2/9] add normalization pipeline --- pyrit/message_normalizer/__init__.py | 2 + .../history_squash_normalizer.py | 67 +++++ .../common/normalization_pipeline.py | 153 ++++++++++ .../test_history_squash_normalizer.py | 98 +++++++ .../target/test_normalization_pipeline.py | 273 ++++++++++++++++++ 5 files changed, 593 insertions(+) create mode 100644 pyrit/message_normalizer/history_squash_normalizer.py create mode 100644 pyrit/prompt_target/common/normalization_pipeline.py create mode 100644 tests/unit/message_normalizer/test_history_squash_normalizer.py create mode 100644 tests/unit/target/test_normalization_pipeline.py diff --git a/pyrit/message_normalizer/__init__.py b/pyrit/message_normalizer/__init__.py index 3e46288bae..a9c9515dd0 100644 --- a/pyrit/message_normalizer/__init__.py +++ b/pyrit/message_normalizer/__init__.py @@ -8,6 +8,7 @@ from pyrit.message_normalizer.chat_message_normalizer import ChatMessageNormalizer from pyrit.message_normalizer.conversation_context_normalizer import ConversationContextNormalizer from pyrit.message_normalizer.generic_system_squash import GenericSystemSquashNormalizer +from pyrit.message_normalizer.history_squash_normalizer import HistorySquashNormalizer from pyrit.message_normalizer.message_normalizer import ( MessageListNormalizer, MessageStringNormalizer, @@ -18,6 +19,7 @@ "MessageListNormalizer", "MessageStringNormalizer", "GenericSystemSquashNormalizer", + "HistorySquashNormalizer", "TokenizerTemplateNormalizer", "ConversationContextNormalizer", "ChatMessageNormalizer", diff --git a/pyrit/message_normalizer/history_squash_normalizer.py b/pyrit/message_normalizer/history_squash_normalizer.py new file mode 100644 index 0000000000..7cd79a727f --- /dev/null +++ b/pyrit/message_normalizer/history_squash_normalizer.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pyrit.message_normalizer.message_normalizer import MessageListNormalizer +from pyrit.models import Message + + +class HistorySquashNormalizer(MessageListNormalizer[Message]): + """ + Squashes a multi-turn conversation into a single user message. + + Previous turns are formatted as labeled context and prepended to the + latest message. Used by the normalization pipeline to adapt prompts + for targets that do not support multi-turn conversations. + """ + + async def normalize_async(self, messages: list[Message]) -> list[Message]: + """ + Combine all messages into a single user message. + + When there is only one message it is returned unchanged. Otherwise + all prior turns are formatted as ``Role: content`` lines under a + ``[Conversation History]`` header and the last message's content + appears under a ``[Current Message]`` header. + + Args: + messages: The conversation messages to squash. + + Returns: + list[Message]: A single-element list containing the squashed message. + + Raises: + ValueError: If the messages list is empty. + """ + if not messages: + raise ValueError("Messages list cannot be empty") + + if len(messages) == 1: + return list(messages) + + history_lines = self._format_history(messages=messages[:-1]) + current_parts = [piece.converted_value for piece in messages[-1].message_pieces] + + combined = ( + "[Conversation History]\n" + + "\n".join(history_lines) + + "\n\n[Current Message]\n" + + "\n".join(current_parts) + ) + + return [Message.from_prompt(prompt=combined, role="user")] + + def _format_history(self, *, messages: list[Message]) -> list[str]: + """ + Format prior messages as ``Role: content`` lines. + + Args: + messages: The history messages to format. + + Returns: + list[str]: One line per message piece. + """ + lines: list[str] = [] + for msg in messages: + for piece in msg.message_pieces: + lines.append(f"{piece.api_role.capitalize()}: {piece.converted_value}") + return lines diff --git a/pyrit/prompt_target/common/normalization_pipeline.py b/pyrit/prompt_target/common/normalization_pipeline.py new file mode 100644 index 0000000000..b7dbb0f008 --- /dev/null +++ b/pyrit/prompt_target/common/normalization_pipeline.py @@ -0,0 +1,153 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging +from typing import ClassVar + +from pyrit.message_normalizer import ( + GenericSystemSquashNormalizer, + HistorySquashNormalizer, + MessageListNormalizer, +) +from pyrit.models import Message +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + NORMALIZABLE_CAPABILITIES, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) + +logger = logging.getLogger(__name__) + + +def _default_normalizer_factory() -> dict[CapabilityName, MessageListNormalizer[Message]]: + """ + Build the default normalizer for every normalizable capability. + + Returns: + dict[CapabilityName, MessageListNormalizer[Message]]: Mapping from + capability to its default normalizer instance. + """ + return { + CapabilityName.SYSTEM_PROMPT: GenericSystemSquashNormalizer(), + CapabilityName.MULTI_TURN: HistorySquashNormalizer(), + } + + +class ConversationNormalizationPipeline: + """ + Ordered sequence of message normalizers that adapt conversations when + the target lacks certain capabilities. + + The pipeline is constructed via ``from_capabilities``, which resolves + capabilities and policy into a concrete, ordered tuple of normalizers. + ``normalize_async`` then simply executes that tuple in order. + """ + + PIPELINE_ORDER: ClassVar[list[CapabilityName]] = [ + CapabilityName.SYSTEM_PROMPT, + CapabilityName.MULTI_TURN, + ] + + def __init__(self, normalizers: tuple[MessageListNormalizer[Message], ...] = ()) -> None: + """ + Initialize the normalization pipeline with an ordered sequence of normalizers. + + Args: + normalizers (tuple[MessageListNormalizer[Message], ...]): + Ordered normalizers to apply during ``normalize_async``. + Defaults to an empty tuple (pass-through). + """ + self._normalizers = normalizers + + @classmethod + def from_capabilities( + cls, + *, + capabilities: TargetCapabilities, + policy: CapabilityHandlingPolicy, + normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None, + ) -> "ConversationNormalizationPipeline": + """ + Resolve capabilities and policy into a concrete pipeline of normalizers. + + For each capability in ``PIPELINE_ORDER``: + + * If the target already supports the capability, no normalizer is added. + * If the capability is missing and the policy is ``ADAPT``, the + corresponding normalizer (from overrides or defaults) is added. + * If the capability is missing and the policy is ``RAISE``, a + ``ValueError`` is raised immediately. + + Args: + capabilities (TargetCapabilities): The target's declared capabilities. + policy (CapabilityHandlingPolicy): How to handle each missing capability. + normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None): + Optional overrides for specific capability normalizers. + Falls back to the defaults from ``_default_normalizer_factory``. + + Returns: + ConversationNormalizationPipeline: A pipeline with the resolved + ordered tuple of normalizers. + + Raises: + ValueError: If a required capability is missing and the policy is RAISE, + or if a capability is not normalizable, or if no normalizer is + available for an ADAPT policy. + """ + defaults = _default_normalizer_factory() + overrides = normalizer_overrides or {} + normalizers: list[MessageListNormalizer[Message]] = [] + + for capability in cls.PIPELINE_ORDER: + if capabilities.supports(capability=capability): + continue + + if capability not in NORMALIZABLE_CAPABILITIES: + raise ValueError( + f"Target does not support '{capability.value}' and this capability cannot be adapted." + ) + + behavior = policy.get_behavior(capability=capability) + + if behavior == UnsupportedCapabilityBehavior.RAISE: + raise ValueError( + f"Target does not support '{capability.value}' and the handling policy is RAISE." + ) + + normalizer = overrides.get(capability) or defaults.get(capability) + if normalizer is None: + raise ValueError( + f"Target does not support '{capability.value}' and the policy is ADAPT, " + f"but no normalizer is available for this capability." + ) + + normalizers.append(normalizer) + + return cls(normalizers=tuple(normalizers)) + + async def normalize_async(self, *, messages: list[Message]) -> list[Message]: + """ + Run the pre-resolved normalizer sequence over the messages. + + Args: + messages (list[Message]): The full conversation to normalize. + + Returns: + list[Message]: The (possibly adapted) message list. + """ + result = list(messages) + for normalizer in self._normalizers: + result = await normalizer.normalize_async(result) + return result + + @property + def normalizers(self) -> tuple[MessageListNormalizer[Message], ...]: + """ + The ordered normalizers in this pipeline. + + Returns: + tuple[MessageListNormalizer[Message], ...]: The normalizer sequence. + """ + return self._normalizers diff --git a/tests/unit/message_normalizer/test_history_squash_normalizer.py b/tests/unit/message_normalizer/test_history_squash_normalizer.py new file mode 100644 index 0000000000..b077fa2aca --- /dev/null +++ b/tests/unit/message_normalizer/test_history_squash_normalizer.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.message_normalizer import HistorySquashNormalizer +from pyrit.models import Message, MessagePiece +from pyrit.models.literals import ChatMessageRole + + +def _make_message(role: ChatMessageRole, content: str) -> Message: + return Message(message_pieces=[MessagePiece(role=role, original_value=content)]) + + +@pytest.mark.asyncio +async def test_history_squash_empty_raises(): + with pytest.raises(ValueError, match="cannot be empty"): + await HistorySquashNormalizer().normalize_async(messages=[]) + + +@pytest.mark.asyncio +async def test_history_squash_single_message_returns_unchanged(): + messages = [_make_message("user", "hello")] + result = await HistorySquashNormalizer().normalize_async(messages) + assert len(result) == 1 + assert result[0].get_value() == "hello" + assert result[0].api_role == "user" + + +@pytest.mark.asyncio +async def test_history_squash_two_turns(): + messages = [ + _make_message("user", "hello"), + _make_message("assistant", "hi there"), + _make_message("user", "how are you?"), + ] + result = await HistorySquashNormalizer().normalize_async(messages) + + assert len(result) == 1 + assert result[0].api_role == "user" + + text = result[0].get_value() + assert "[Conversation History]" in text + assert "User: hello" in text + assert "Assistant: hi there" in text + assert "[Current Message]" in text + assert "how are you?" in text + + +@pytest.mark.asyncio +async def test_history_squash_includes_system_in_history(): + messages = [ + _make_message("system", "You are helpful"), + _make_message("user", "hello"), + _make_message("assistant", "hi"), + _make_message("user", "bye"), + ] + result = await HistorySquashNormalizer().normalize_async(messages) + + assert len(result) == 1 + text = result[0].get_value() + assert "System: You are helpful" in text + assert "User: hello" in text + assert "Assistant: hi" in text + assert "[Current Message]" in text + assert "bye" in text + + +@pytest.mark.asyncio +async def test_history_squash_multi_piece_message(): + """Multi-piece last message has all pieces joined in [Current Message].""" + conversation_id = "test-conv-id" + pieces = [ + MessagePiece(role="user", original_value="part1", conversation_id=conversation_id), + MessagePiece(role="user", original_value="part2", conversation_id=conversation_id), + ] + messages = [ + _make_message("assistant", "hi"), + Message(message_pieces=pieces), + ] + result = await HistorySquashNormalizer().normalize_async(messages) + + text = result[0].get_value() + assert "part1" in text + assert "part2" in text + + +@pytest.mark.asyncio +async def test_history_squash_preserves_original_list(): + """Normalize should not mutate the input list.""" + messages = [ + _make_message("user", "hello"), + _make_message("assistant", "hi"), + _make_message("user", "bye"), + ] + original_len = len(messages) + await HistorySquashNormalizer().normalize_async(messages) + assert len(messages) == original_len diff --git a/tests/unit/target/test_normalization_pipeline.py b/tests/unit/target/test_normalization_pipeline.py new file mode 100644 index 0000000000..001043f48d --- /dev/null +++ b/tests/unit/target/test_normalization_pipeline.py @@ -0,0 +1,273 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from pyrit.message_normalizer import GenericSystemSquashNormalizer, HistorySquashNormalizer, MessageListNormalizer +from pyrit.models import Message, MessagePiece +from pyrit.models.literals import ChatMessageRole +from pyrit.prompt_target.common.normalization_pipeline import ConversationNormalizationPipeline +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) + + +_ADAPT_ALL = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } +) + +_RAISE_ALL = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } +) + + +def _make_message(role: ChatMessageRole, content: str) -> Message: + return Message(message_pieces=[MessagePiece(role=role, original_value=content)]) + + +# --------------------------------------------------------------------------- +# Construction — from_capabilities +# --------------------------------------------------------------------------- + + +def test_from_capabilities_all_supported_empty_tuple(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + pipeline = ConversationNormalizationPipeline.from_capabilities( + capabilities=caps, policy=_ADAPT_ALL + ) + assert pipeline.normalizers == () + + +def test_from_capabilities_none_supported_has_two_normalizers(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + pipeline = ConversationNormalizationPipeline.from_capabilities( + capabilities=caps, policy=_ADAPT_ALL + ) + assert len(pipeline.normalizers) == 2 + assert isinstance(pipeline.normalizers[0], GenericSystemSquashNormalizer) + assert isinstance(pipeline.normalizers[1], HistorySquashNormalizer) + + +def test_from_capabilities_missing_system_prompt_only(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=False) + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + pipeline = ConversationNormalizationPipeline.from_capabilities( + capabilities=caps, policy=policy + ) + assert len(pipeline.normalizers) == 1 + assert isinstance(pipeline.normalizers[0], GenericSystemSquashNormalizer) + + +def test_from_capabilities_missing_multi_turn_only(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True) + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + pipeline = ConversationNormalizationPipeline.from_capabilities( + capabilities=caps, policy=policy + ) + assert len(pipeline.normalizers) == 1 + assert isinstance(pipeline.normalizers[0], HistorySquashNormalizer) + + +def test_from_capabilities_normalizers_is_tuple(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + pipeline = ConversationNormalizationPipeline.from_capabilities( + capabilities=caps, policy=_ADAPT_ALL + ) + assert isinstance(pipeline.normalizers, tuple) + + +# --------------------------------------------------------------------------- +# from_capabilities — RAISE policy +# --------------------------------------------------------------------------- + + +def test_from_capabilities_raises_when_system_prompt_missing_and_policy_raise(): + caps = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True) + with pytest.raises(ValueError, match="RAISE"): + ConversationNormalizationPipeline.from_capabilities( + capabilities=caps, policy=_RAISE_ALL + ) + + +def test_from_capabilities_raises_when_multi_turn_missing_and_policy_raise(): + caps = TargetCapabilities(supports_system_prompt=True, supports_multi_turn=False) + with pytest.raises(ValueError, match="RAISE"): + ConversationNormalizationPipeline.from_capabilities( + capabilities=caps, policy=_RAISE_ALL + ) + + +# --------------------------------------------------------------------------- +# from_capabilities — custom overrides +# --------------------------------------------------------------------------- + + +def test_from_capabilities_uses_override_normalizer(): + mock_normalizer = MagicMock(spec=MessageListNormalizer) + caps = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True) + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + pipeline = ConversationNormalizationPipeline.from_capabilities( + capabilities=caps, + policy=policy, + normalizer_overrides={CapabilityName.SYSTEM_PROMPT: mock_normalizer}, + ) + assert len(pipeline.normalizers) == 1 + assert pipeline.normalizers[0] is mock_normalizer + + +# --------------------------------------------------------------------------- +# normalize_async — pass-through +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normalize_passthrough_when_empty_pipeline(): + pipeline = ConversationNormalizationPipeline() + messages = [_make_message("system", "sys"), _make_message("user", "hi")] + result = await pipeline.normalize_async(messages=messages) + + assert len(result) == 2 + assert result[0].get_value() == "sys" + assert result[1].get_value() == "hi" + + +# --------------------------------------------------------------------------- +# normalize_async — ADAPT system prompt +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normalize_adapts_system_prompt(): + caps = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True) + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=policy) + + messages = [_make_message("system", "be nice"), _make_message("user", "hello")] + result = await pipeline.normalize_async(messages=messages) + + assert len(result) == 1 + assert result[0].api_role == "user" + assert "be nice" in result[0].get_value() + assert "hello" in result[0].get_value() + + +# --------------------------------------------------------------------------- +# normalize_async — ADAPT multi-turn +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normalize_adapts_multi_turn(): + caps = TargetCapabilities(supports_system_prompt=True, supports_multi_turn=False) + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=policy) + + messages = [ + _make_message("user", "hello"), + _make_message("assistant", "hi"), + _make_message("user", "how are you?"), + ] + result = await pipeline.normalize_async(messages=messages) + + assert len(result) == 1 + assert result[0].api_role == "user" + text = result[0].get_value() + assert "hello" in text + assert "hi" in text + assert "how are you?" in text + + +# --------------------------------------------------------------------------- +# normalize_async — both adapts in order +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normalize_adapts_system_then_multi_turn(): + """System squash runs first, then history squash.""" + caps = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=False) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_ADAPT_ALL) + + messages = [ + _make_message("system", "be nice"), + _make_message("user", "hello"), + _make_message("assistant", "hi"), + _make_message("user", "bye"), + ] + result = await pipeline.normalize_async(messages=messages) + + assert len(result) == 1 + assert result[0].api_role == "user" + text = result[0].get_value() + assert "be nice" in text + assert "bye" in text + + +# --------------------------------------------------------------------------- +# normalize_async — custom normalizer via mock +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normalize_uses_custom_normalizer(): + mock_normalizer = MagicMock(spec=MessageListNormalizer) + expected = [_make_message("user", "custom")] + mock_normalizer.normalize_async = AsyncMock(return_value=expected) + + pipeline = ConversationNormalizationPipeline(normalizers=(mock_normalizer,)) + + messages = [_make_message("system", "sys"), _make_message("user", "hi")] + result = await pipeline.normalize_async(messages=messages) + + assert result == expected + mock_normalizer.normalize_async.assert_called_once() From 761166dc86bf9648e1585aed837d33c43c0be0a9 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Mon, 6 Apr 2026 15:01:09 -0400 Subject: [PATCH 3/9] add target configuration and clean up --- pyrit/prompt_target/__init__.py | 4 + ...=> conversation_normalization_pipeline.py} | 61 +++--- .../common/target_capabilities.py | 19 +- .../common/target_configuration.py | 129 ++++++++++++ ...st_conversation_normalization_pipeline.py} | 2 +- tests/unit/target/test_target_capabilities.py | 2 +- .../unit/target/test_target_configuration.py | 184 ++++++++++++++++++ 7 files changed, 368 insertions(+), 33 deletions(-) rename pyrit/prompt_target/common/{normalization_pipeline.py => conversation_normalization_pipeline.py} (72%) create mode 100644 pyrit/prompt_target/common/target_configuration.py rename tests/unit/target/{test_normalization_pipeline.py => test_conversation_normalization_pipeline.py} (99%) create mode 100644 tests/unit/target/test_target_configuration.py diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 045b217e8a..895b8fd794 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -18,6 +18,8 @@ TargetCapabilities, UnsupportedCapabilityBehavior, ) +from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline +from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.crucible_target import CrucibleTarget from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget @@ -48,6 +50,7 @@ "AzureBlobStorageTarget", "AzureMLChatTarget", "CopilotType", + "ConversationNormalizationPipeline", "CrucibleTarget", "GandalfLevel", "GandalfTarget", @@ -74,6 +77,7 @@ "RealtimeTarget", "CapabilityName", "CapabilityHandlingPolicy", + "TargetConfiguration", "TargetCapabilities", "UnsupportedCapabilityBehavior", "TextTarget", diff --git a/pyrit/prompt_target/common/normalization_pipeline.py b/pyrit/prompt_target/common/conversation_normalization_pipeline.py similarity index 72% rename from pyrit/prompt_target/common/normalization_pipeline.py rename to pyrit/prompt_target/common/conversation_normalization_pipeline.py index b7dbb0f008..3d9f10acb4 100644 --- a/pyrit/prompt_target/common/normalization_pipeline.py +++ b/pyrit/prompt_target/common/conversation_normalization_pipeline.py @@ -2,7 +2,7 @@ # Licensed under the MIT license. import logging -from typing import ClassVar +from dataclasses import dataclass from pyrit.message_normalizer import ( GenericSystemSquashNormalizer, @@ -13,7 +13,6 @@ from pyrit.prompt_target.common.target_capabilities import ( CapabilityHandlingPolicy, CapabilityName, - NORMALIZABLE_CAPABILITIES, TargetCapabilities, UnsupportedCapabilityBehavior, ) @@ -21,18 +20,40 @@ logger = logging.getLogger(__name__) -def _default_normalizer_factory() -> dict[CapabilityName, MessageListNormalizer[Message]]: +@dataclass(frozen=True) +class _NormalizerRegistryEntry: + """Single entry in the normalizer registry.""" + + order: int + normalizer_factory: type[MessageListNormalizer[Message]] + + +# --------------------------------------------------------------------------- +# Single registry: add new normalizable capabilities here and nowhere else. +# --------------------------------------------------------------------------- +_NORMALIZER_REGISTRY: dict[CapabilityName, _NormalizerRegistryEntry] = { + CapabilityName.SYSTEM_PROMPT: _NormalizerRegistryEntry(order=0, normalizer_factory=GenericSystemSquashNormalizer), + CapabilityName.MULTI_TURN: _NormalizerRegistryEntry(order=1, normalizer_factory=HistorySquashNormalizer), +} + +# Derived constants — no manual maintenance required. +NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset(_NORMALIZER_REGISTRY) + +_PIPELINE_ORDER: list[CapabilityName] = sorted( + _NORMALIZER_REGISTRY, + key=lambda cap: _NORMALIZER_REGISTRY[cap].order, +) + + +def _default_normalizers() -> dict[CapabilityName, MessageListNormalizer[Message]]: """ - Build the default normalizer for every normalizable capability. + Build a fresh default normalizer instance for every registered capability. Returns: dict[CapabilityName, MessageListNormalizer[Message]]: Mapping from - capability to its default normalizer instance. + capability to a new default normalizer instance. """ - return { - CapabilityName.SYSTEM_PROMPT: GenericSystemSquashNormalizer(), - CapabilityName.MULTI_TURN: HistorySquashNormalizer(), - } + return {cap: entry.normalizer_factory() for cap, entry in _NORMALIZER_REGISTRY.items()} class ConversationNormalizationPipeline: @@ -43,12 +64,11 @@ class ConversationNormalizationPipeline: The pipeline is constructed via ``from_capabilities``, which resolves capabilities and policy into a concrete, ordered tuple of normalizers. ``normalize_async`` then simply executes that tuple in order. - """ - PIPELINE_ORDER: ClassVar[list[CapabilityName]] = [ - CapabilityName.SYSTEM_PROMPT, - CapabilityName.MULTI_TURN, - ] + To add a new normalizable capability, add a single entry to + ``_NORMALIZER_REGISTRY``. ``NORMALIZABLE_CAPABILITIES``, + pipeline ordering, and default normalizers are all derived from it. + """ def __init__(self, normalizers: tuple[MessageListNormalizer[Message], ...] = ()) -> None: """ @@ -96,19 +116,14 @@ def from_capabilities( or if a capability is not normalizable, or if no normalizer is available for an ADAPT policy. """ - defaults = _default_normalizer_factory() + defaults = _default_normalizers() overrides = normalizer_overrides or {} normalizers: list[MessageListNormalizer[Message]] = [] - for capability in cls.PIPELINE_ORDER: + for capability in _PIPELINE_ORDER: if capabilities.supports(capability=capability): continue - if capability not in NORMALIZABLE_CAPABILITIES: - raise ValueError( - f"Target does not support '{capability.value}' and this capability cannot be adapted." - ) - behavior = policy.get_behavior(capability=capability) if behavior == UnsupportedCapabilityBehavior.RAISE: @@ -116,7 +131,9 @@ def from_capabilities( f"Target does not support '{capability.value}' and the handling policy is RAISE." ) - normalizer = overrides.get(capability) or defaults.get(capability) + normalizer = overrides.get(capability) + if normalizer is None: + normalizer = defaults.get(capability) if normalizer is None: raise ValueError( f"Target does not support '{capability.value}' and the policy is ADAPT, " diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index d450f0c727..6ec17bc3ae 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -6,6 +6,9 @@ from typing import Optional, cast from pyrit.models import PromptDataType +from types import MappingProxyType +from collections.abc import Mapping + class CapabilityName(str, Enum): @@ -24,14 +27,6 @@ class CapabilityName(str, Enum): SYSTEM_PROMPT = "supports_system_prompt" -NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset( - { - CapabilityName.SYSTEM_PROMPT, - CapabilityName.MULTI_TURN, - } -) - - class UnsupportedCapabilityBehavior(str, Enum): """ Defines what happens when a caller requires a capability the target does not support. @@ -58,7 +53,7 @@ class CapabilityHandlingPolicy: raises immediately. """ - behaviors: dict[CapabilityName, UnsupportedCapabilityBehavior] = field( + behaviors: Mapping[CapabilityName, UnsupportedCapabilityBehavior] = field( default_factory=lambda: { CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, @@ -103,6 +98,12 @@ def __getattr__(self, name: str) -> UnsupportedCapabilityBehavior: raise AttributeError(name) + def __post_init__(self) -> None: + # Defensive copy + read-only wrapper. object.__setattr__ is required + # because the dataclass is frozen. + object.__setattr__(self, "behaviors", MappingProxyType(dict(self.behaviors))) + + @dataclass(frozen=True) class TargetCapabilities: diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py new file mode 100644 index 0000000000..f4be7870f0 --- /dev/null +++ b/pyrit/prompt_target/common/target_configuration.py @@ -0,0 +1,129 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import logging + +from pyrit.message_normalizer import MessageListNormalizer +from pyrit.models import Message +from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) + +logger = logging.getLogger(__name__) + +# Default policy: RAISE on all adaptable capabilities. +_DEFAULT_POLICY = CapabilityHandlingPolicy() + + +class TargetConfiguration: + """ + Unified configuration that describes what a target supports, what to do + when it doesn't, and how to adapt. + + Composes three concerns into a single object: + + * **TargetCapabilities** — declarative, immutable description of what the + target natively supports. + * **CapabilityHandlingPolicy** — per-capability behavior (ADAPT or RAISE) + when a capability is missing. + * **ConversationNormalizationPipeline** — ordered sequence of normalizers + built from the gap between capabilities and policy. + + Each target defines defaults; callers can override policy or individual + normalizers at creation time. + """ + + def __init__( + self, + *, + capabilities: TargetCapabilities, + policy: CapabilityHandlingPolicy | None = None, + normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None, + ) -> None: + """ + Build a target configuration and resolve the normalization pipeline. + + Args: + capabilities (TargetCapabilities): The target's declared capabilities. + policy (CapabilityHandlingPolicy | None): How to handle each missing + capability. Defaults to RAISE for all adaptable capabilities. + normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None): + Optional overrides for specific capability normalizers. + + Raises: + ValueError: If a required capability is missing and the policy is RAISE. + """ + self._capabilities = capabilities + self._policy = policy or _DEFAULT_POLICY + self._pipeline = ConversationNormalizationPipeline.from_capabilities( + capabilities=self._capabilities, + policy=self._policy, + normalizer_overrides=normalizer_overrides, + ) + + @property + def capabilities(self) -> TargetCapabilities: + """The target's declared capabilities.""" + return self._capabilities + + @property + def policy(self) -> CapabilityHandlingPolicy: + """The handling policy for missing capabilities.""" + return self._policy + + @property + def pipeline(self) -> ConversationNormalizationPipeline: + """The resolved normalization pipeline.""" + return self._pipeline + + def supports(self, *, capability: CapabilityName) -> bool: + """ + Check whether the target supports the given capability. + + Args: + capability (CapabilityName): The capability to check. + + Returns: + bool: True if the target supports it natively. + """ + return self._capabilities.supports(capability=capability) + + def requires(self, *, capability: CapabilityName) -> None: + """ + Validate that the target either supports the capability natively or + has an ADAPT policy for it. + + Intended for use by consumers (attacks, converters, scorers) at + construction time. + + Args: + capability (CapabilityName): The required capability. + + Raises: + ValueError: If the capability is missing and the policy is RAISE + or no normalizer is available. + """ + if self._capabilities.supports(capability=capability): + return + + behavior = self._policy.get_behavior(capability=capability) + if behavior == UnsupportedCapabilityBehavior.RAISE: + raise ValueError( + f"Target does not support '{capability.value}' and the handling policy is RAISE." + ) + + async def normalize_async(self, *, messages: list[Message]) -> list[Message]: + """ + Run the normalization pipeline over the given messages. + + Args: + messages (list[Message]): The full conversation to normalize. + + Returns: + list[Message]: The (possibly adapted) message list. + """ + return await self._pipeline.normalize_async(messages=messages) diff --git a/tests/unit/target/test_normalization_pipeline.py b/tests/unit/target/test_conversation_normalization_pipeline.py similarity index 99% rename from tests/unit/target/test_normalization_pipeline.py rename to tests/unit/target/test_conversation_normalization_pipeline.py index 001043f48d..1efc761234 100644 --- a/tests/unit/target/test_normalization_pipeline.py +++ b/tests/unit/target/test_conversation_normalization_pipeline.py @@ -8,7 +8,7 @@ from pyrit.message_normalizer import GenericSystemSquashNormalizer, HistorySquashNormalizer, MessageListNormalizer from pyrit.models import Message, MessagePiece from pyrit.models.literals import ChatMessageRole -from pyrit.prompt_target.common.normalization_pipeline import ConversationNormalizationPipeline +from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline from pyrit.prompt_target.common.target_capabilities import ( CapabilityHandlingPolicy, CapabilityName, diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py index b77f84af94..fd52039285 100644 --- a/tests/unit/target/test_target_capabilities.py +++ b/tests/unit/target/test_target_capabilities.py @@ -5,10 +5,10 @@ import pytest +from pyrit.prompt_target.common.conversation_normalization_pipeline import NORMALIZABLE_CAPABILITIES from pyrit.prompt_target.common.target_capabilities import ( CapabilityName, CapabilityHandlingPolicy, - NORMALIZABLE_CAPABILITIES, TargetCapabilities, UnsupportedCapabilityBehavior, ) diff --git a/tests/unit/target/test_target_configuration.py b/tests/unit/target/test_target_configuration.py new file mode 100644 index 0000000000..51c53a0ea3 --- /dev/null +++ b/tests/unit/target/test_target_configuration.py @@ -0,0 +1,184 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.message_normalizer import GenericSystemSquashNormalizer, HistorySquashNormalizer +from pyrit.models import Message, MessagePiece +from pyrit.models.literals import ChatMessageRole +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration + + +_ADAPT_ALL = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } +) + + +def _make_message(role: ChatMessageRole, content: str) -> Message: + return Message(message_pieces=[MessagePiece(role=role, original_value=content)]) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_init_with_defaults_uses_raise_policy(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + # Default policy is RAISE for all adaptable capabilities + assert config.policy.get_behavior(capability=CapabilityName.MULTI_TURN) == UnsupportedCapabilityBehavior.RAISE + + +def test_init_with_explicit_policy(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + assert config.policy is _ADAPT_ALL + + +def test_init_all_supported_empty_pipeline(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + assert config.pipeline.normalizers == () + + +def test_init_missing_capability_adapt_builds_pipeline(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + assert len(config.pipeline.normalizers) == 2 + assert isinstance(config.pipeline.normalizers[0], GenericSystemSquashNormalizer) + assert isinstance(config.pipeline.normalizers[1], HistorySquashNormalizer) + + +def test_init_missing_capability_raise_policy_raises(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True) + with pytest.raises(ValueError, match="RAISE"): + TargetConfiguration(capabilities=caps) + + +# --------------------------------------------------------------------------- +# Properties +# --------------------------------------------------------------------------- + + +def test_capabilities_property(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + assert config.capabilities is caps + + +# --------------------------------------------------------------------------- +# supports +# --------------------------------------------------------------------------- + + +def test_supports_returns_true_when_supported(): + caps = TargetCapabilities(supports_multi_turn=True) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + assert config.supports(capability=CapabilityName.MULTI_TURN) is True + + +def test_supports_returns_false_when_unsupported(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + assert config.supports(capability=CapabilityName.MULTI_TURN) is False + + +# --------------------------------------------------------------------------- +# requires +# --------------------------------------------------------------------------- + + +def test_requires_passes_when_supported(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + # Should not raise + config.requires(capability=CapabilityName.MULTI_TURN) + + +def test_requires_passes_when_adapt(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + # ADAPT policy → should not raise + config.requires(capability=CapabilityName.MULTI_TURN) + + +def test_requires_raises_when_raise_policy(): + # Build with ADAPT so construction succeeds, then test requires() on a RAISE capability. + # JSON_SCHEMA is RAISE and unsupported — but it's not normalizable, so construction + # doesn't try to build a normalizer for it. Use a custom policy where system_prompt + # is ADAPT (so pipeline builds), but then call requires() on JSON_OUTPUT which is RAISE. + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=False) + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + config = TargetConfiguration(capabilities=caps, policy=policy) + # system_prompt is missing + ADAPT → requires passes + config.requires(capability=CapabilityName.SYSTEM_PROMPT) + # json_output is missing + RAISE → requires raises + with pytest.raises(ValueError, match="RAISE"): + config.requires(capability=CapabilityName.JSON_OUTPUT) + + +# --------------------------------------------------------------------------- +# normalize_async +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normalize_async_passthrough_when_all_supported(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + msgs = [_make_message("user", "hello")] + result = await config.normalize_async(messages=msgs) + assert len(result) == 1 + assert result[0].message_pieces[0].converted_value == "hello" + + +@pytest.mark.asyncio +async def test_normalize_async_adapts_system_prompt(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + + msgs = [ + _make_message("system", "you are helpful"), + _make_message("user", "hello"), + ] + result = await config.normalize_async(messages=msgs) + # System squash merges system into user messages — no system role left + for msg in result: + for piece in msg.message_pieces: + assert piece.api_role != "system" + + +@pytest.mark.asyncio +async def test_normalize_async_adapts_multi_turn(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + + msgs = [ + _make_message("user", "turn 1"), + _make_message("assistant", "reply 1"), + _make_message("user", "turn 2"), + ] + result = await config.normalize_async(messages=msgs) + # History squash collapses into a single message + assert len(result) == 1 + assert "[Conversation History]" in result[0].message_pieces[0].converted_value + assert "turn 2" in result[0].message_pieces[0].converted_value From 9518a069d338a3639de29f1889af91b8f2f075ad Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Mon, 6 Apr 2026 15:53:39 -0400 Subject: [PATCH 4/9] add tests and fix missing capabilities --- .../history_squash_normalizer.py | 5 +--- pyrit/prompt_target/__init__.py | 10 ++++---- .../conversation_normalization_pipeline.py | 4 +++ .../common/target_capabilities.py | 9 +++---- ...est_conversation_normalization_pipeline.py | 4 +++ tests/unit/target/test_target_capabilities.py | 25 ++++++++++++------- .../unit/target/test_target_configuration.py | 2 ++ 7 files changed, 36 insertions(+), 23 deletions(-) diff --git a/pyrit/message_normalizer/history_squash_normalizer.py b/pyrit/message_normalizer/history_squash_normalizer.py index 7cd79a727f..6b1f9966b7 100644 --- a/pyrit/message_normalizer/history_squash_normalizer.py +++ b/pyrit/message_normalizer/history_squash_normalizer.py @@ -42,10 +42,7 @@ async def normalize_async(self, messages: list[Message]) -> list[Message]: current_parts = [piece.converted_value for piece in messages[-1].message_pieces] combined = ( - "[Conversation History]\n" - + "\n".join(history_lines) - + "\n\n[Current Message]\n" - + "\n".join(current_parts) + "[Conversation History]\n" + "\n".join(history_lines) + "\n\n[Current Message]\n" + "\n".join(current_parts) ) return [Message.from_prompt(prompt=combined, role="user")] diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 895b8fd794..1517f8a51b 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -10,15 +10,15 @@ from pyrit.prompt_target.azure_blob_storage_target import AzureBlobStorageTarget from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget +from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import ( - CapabilityName, CapabilityHandlingPolicy, + CapabilityName, TargetCapabilities, UnsupportedCapabilityBehavior, ) -from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.crucible_target import CrucibleTarget @@ -49,6 +49,8 @@ __all__ = [ "AzureBlobStorageTarget", "AzureMLChatTarget", + "CapabilityName", + "CapabilityHandlingPolicy", "CopilotType", "ConversationNormalizationPipeline", "CrucibleTarget", @@ -75,10 +77,8 @@ "PromptShieldTarget", "PromptTarget", "RealtimeTarget", - "CapabilityName", - "CapabilityHandlingPolicy", - "TargetConfiguration", "TargetCapabilities", + "TargetConfiguration", "UnsupportedCapabilityBehavior", "TextTarget", "WebSocketCopilotTarget", diff --git a/pyrit/prompt_target/common/conversation_normalization_pipeline.py b/pyrit/prompt_target/common/conversation_normalization_pipeline.py index 3d9f10acb4..5430f0f2d6 100644 --- a/pyrit/prompt_target/common/conversation_normalization_pipeline.py +++ b/pyrit/prompt_target/common/conversation_normalization_pipeline.py @@ -100,6 +100,10 @@ def from_capabilities( * If the capability is missing and the policy is ``RAISE``, a ``ValueError`` is raised immediately. + NOTE: Normalizers are only valid when the capability can be overridden with a normalizer (which is indicated + by its presence in the registry), so we only iterate over valid capabilities in this function and add normalizers + only when the capability can support normalization. + Args: capabilities (TargetCapabilities): The target's declared capabilities. policy (CapabilityHandlingPolicy): How to handle each missing capability. diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 6ec17bc3ae..66b298765a 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -1,14 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from collections.abc import Mapping from dataclasses import dataclass, field from enum import Enum +from types import MappingProxyType from typing import Optional, cast from pyrit.models import PromptDataType -from types import MappingProxyType -from collections.abc import Mapping - class CapabilityName(str, Enum): @@ -39,7 +38,6 @@ class UnsupportedCapabilityBehavior(str, Enum): RAISE = "raise" - @dataclass(frozen=True) class CapabilityHandlingPolicy: """ @@ -59,6 +57,8 @@ class CapabilityHandlingPolicy: CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, } ) @@ -104,7 +104,6 @@ def __post_init__(self) -> None: object.__setattr__(self, "behaviors", MappingProxyType(dict(self.behaviors))) - @dataclass(frozen=True) class TargetCapabilities: """ diff --git a/tests/unit/target/test_conversation_normalization_pipeline.py b/tests/unit/target/test_conversation_normalization_pipeline.py index 1efc761234..02a06ced79 100644 --- a/tests/unit/target/test_conversation_normalization_pipeline.py +++ b/tests/unit/target/test_conversation_normalization_pipeline.py @@ -23,6 +23,8 @@ CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, } ) @@ -32,6 +34,8 @@ CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, } ) diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py index fd52039285..ca448abb98 100644 --- a/tests/unit/target/test_target_capabilities.py +++ b/tests/unit/target/test_target_capabilities.py @@ -7,8 +7,8 @@ from pyrit.prompt_target.common.conversation_normalization_pipeline import NORMALIZABLE_CAPABILITIES from pyrit.prompt_target.common.target_capabilities import ( - CapabilityName, CapabilityHandlingPolicy, + CapabilityName, TargetCapabilities, UnsupportedCapabilityBehavior, ) @@ -36,6 +36,8 @@ def test_capability_handling_policy_defaults(self): CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, } def test_capability_handling_policy_custom_values(self): @@ -55,9 +57,7 @@ def test_capability_handling_policy_get_behavior(self): policy = CapabilityHandlingPolicy() assert policy.get_behavior(capability=CapabilityName.MULTI_TURN) is UnsupportedCapabilityBehavior.RAISE - assert ( - policy.get_behavior(capability=CapabilityName.SYSTEM_PROMPT) is UnsupportedCapabilityBehavior.RAISE - ) + assert policy.get_behavior(capability=CapabilityName.SYSTEM_PROMPT) is UnsupportedCapabilityBehavior.RAISE def test_capability_handling_policy_get_behavior_for_all_supported_policy_keys(self): policy = CapabilityHandlingPolicy() @@ -66,13 +66,19 @@ def test_capability_handling_policy_get_behavior_for_all_supported_policy_keys(s assert policy.get_behavior(capability=CapabilityName.JSON_OUTPUT) is UnsupportedCapabilityBehavior.RAISE def test_capability_handling_policy_rejects_capability_without_policy(self): - policy = CapabilityHandlingPolicy() + # Use a custom partial policy that deliberately omits EDITABLE_HISTORY + partial_policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ) with pytest.raises(AttributeError, match="supports_editable_history"): - policy.get_behavior(capability=CapabilityName.EDITABLE_HISTORY) + partial_policy.get_behavior(capability=CapabilityName.EDITABLE_HISTORY) with pytest.raises(AttributeError, match="supports_editable_history"): - _ = policy.supports_editable_history + _ = partial_policy.supports_editable_history def test_capability_handling_policy_rejects_unknown_attribute(self): policy = CapabilityHandlingPolicy() @@ -81,12 +87,12 @@ def test_capability_handling_policy_rejects_unknown_attribute(self): _ = policy.totally_unknown_attribute def test_normalizable_capabilities(self): - assert NORMALIZABLE_CAPABILITIES == frozenset( + assert frozenset( { CapabilityName.MULTI_TURN, CapabilityName.SYSTEM_PROMPT, } - ) + ) == NORMALIZABLE_CAPABILITIES def test_target_capabilities_supports_helper(self): capabilities = TargetCapabilities( @@ -100,6 +106,7 @@ def test_target_capabilities_supports_helper(self): assert capabilities.supports(capability=CapabilityName.JSON_OUTPUT) is True assert capabilities.supports(capability=CapabilityName.EDITABLE_HISTORY) is False + # Env vars that may leak from .env files loaded by other tests in parallel workers. # Clear them so that targets use _DEFAULT_CAPABILITIES instead of _KNOWN_CAPABILITIES. _CLEAN_UNDERLYING_MODEL_ENV = { diff --git a/tests/unit/target/test_target_configuration.py b/tests/unit/target/test_target_configuration.py index 51c53a0ea3..11203a622a 100644 --- a/tests/unit/target/test_target_configuration.py +++ b/tests/unit/target/test_target_configuration.py @@ -21,6 +21,8 @@ CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, } ) From c3a89a81df440a19ed1b71eaa79a5454cb6fc4c6 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Mon, 6 Apr 2026 15:57:24 -0400 Subject: [PATCH 5/9] precommit --- .../conversation_normalization_pipeline.py | 4 +-- .../common/target_configuration.py | 4 +-- ...est_conversation_normalization_pipeline.py | 29 +++++-------------- tests/unit/target/test_target_capabilities.py | 15 ++++++---- .../unit/target/test_target_configuration.py | 1 - 5 files changed, 18 insertions(+), 35 deletions(-) diff --git a/pyrit/prompt_target/common/conversation_normalization_pipeline.py b/pyrit/prompt_target/common/conversation_normalization_pipeline.py index 5430f0f2d6..920e00e127 100644 --- a/pyrit/prompt_target/common/conversation_normalization_pipeline.py +++ b/pyrit/prompt_target/common/conversation_normalization_pipeline.py @@ -131,9 +131,7 @@ def from_capabilities( behavior = policy.get_behavior(capability=capability) if behavior == UnsupportedCapabilityBehavior.RAISE: - raise ValueError( - f"Target does not support '{capability.value}' and the handling policy is RAISE." - ) + raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.") normalizer = overrides.get(capability) if normalizer is None: diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index f4be7870f0..98c727ff72 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -112,9 +112,7 @@ def requires(self, *, capability: CapabilityName) -> None: behavior = self._policy.get_behavior(capability=capability) if behavior == UnsupportedCapabilityBehavior.RAISE: - raise ValueError( - f"Target does not support '{capability.value}' and the handling policy is RAISE." - ) + raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.") async def normalize_async(self, *, messages: list[Message]) -> list[Message]: """ diff --git a/tests/unit/target/test_conversation_normalization_pipeline.py b/tests/unit/target/test_conversation_normalization_pipeline.py index 02a06ced79..57d96c0539 100644 --- a/tests/unit/target/test_conversation_normalization_pipeline.py +++ b/tests/unit/target/test_conversation_normalization_pipeline.py @@ -16,7 +16,6 @@ UnsupportedCapabilityBehavior, ) - _ADAPT_ALL = CapabilityHandlingPolicy( behaviors={ CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, @@ -51,17 +50,13 @@ def _make_message(role: ChatMessageRole, content: str) -> Message: def test_from_capabilities_all_supported_empty_tuple(): caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - pipeline = ConversationNormalizationPipeline.from_capabilities( - capabilities=caps, policy=_ADAPT_ALL - ) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_ADAPT_ALL) assert pipeline.normalizers == () def test_from_capabilities_none_supported_has_two_normalizers(): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - pipeline = ConversationNormalizationPipeline.from_capabilities( - capabilities=caps, policy=_ADAPT_ALL - ) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_ADAPT_ALL) assert len(pipeline.normalizers) == 2 assert isinstance(pipeline.normalizers[0], GenericSystemSquashNormalizer) assert isinstance(pipeline.normalizers[1], HistorySquashNormalizer) @@ -77,9 +72,7 @@ def test_from_capabilities_missing_system_prompt_only(): CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, } ) - pipeline = ConversationNormalizationPipeline.from_capabilities( - capabilities=caps, policy=policy - ) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=policy) assert len(pipeline.normalizers) == 1 assert isinstance(pipeline.normalizers[0], GenericSystemSquashNormalizer) @@ -94,18 +87,14 @@ def test_from_capabilities_missing_multi_turn_only(): CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, } ) - pipeline = ConversationNormalizationPipeline.from_capabilities( - capabilities=caps, policy=policy - ) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=policy) assert len(pipeline.normalizers) == 1 assert isinstance(pipeline.normalizers[0], HistorySquashNormalizer) def test_from_capabilities_normalizers_is_tuple(): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - pipeline = ConversationNormalizationPipeline.from_capabilities( - capabilities=caps, policy=_ADAPT_ALL - ) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_ADAPT_ALL) assert isinstance(pipeline.normalizers, tuple) @@ -117,17 +106,13 @@ def test_from_capabilities_normalizers_is_tuple(): def test_from_capabilities_raises_when_system_prompt_missing_and_policy_raise(): caps = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True) with pytest.raises(ValueError, match="RAISE"): - ConversationNormalizationPipeline.from_capabilities( - capabilities=caps, policy=_RAISE_ALL - ) + ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_RAISE_ALL) def test_from_capabilities_raises_when_multi_turn_missing_and_policy_raise(): caps = TargetCapabilities(supports_system_prompt=True, supports_multi_turn=False) with pytest.raises(ValueError, match="RAISE"): - ConversationNormalizationPipeline.from_capabilities( - capabilities=caps, policy=_RAISE_ALL - ) + ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_RAISE_ALL) # --------------------------------------------------------------------------- diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py index ca448abb98..1ee906ab7a 100644 --- a/tests/unit/target/test_target_capabilities.py +++ b/tests/unit/target/test_target_capabilities.py @@ -87,12 +87,15 @@ def test_capability_handling_policy_rejects_unknown_attribute(self): _ = policy.totally_unknown_attribute def test_normalizable_capabilities(self): - assert frozenset( - { - CapabilityName.MULTI_TURN, - CapabilityName.SYSTEM_PROMPT, - } - ) == NORMALIZABLE_CAPABILITIES + assert ( + frozenset( + { + CapabilityName.MULTI_TURN, + CapabilityName.SYSTEM_PROMPT, + } + ) + == NORMALIZABLE_CAPABILITIES + ) def test_target_capabilities_supports_helper(self): capabilities = TargetCapabilities( diff --git a/tests/unit/target/test_target_configuration.py b/tests/unit/target/test_target_configuration.py index 11203a622a..43de2261d9 100644 --- a/tests/unit/target/test_target_configuration.py +++ b/tests/unit/target/test_target_configuration.py @@ -14,7 +14,6 @@ ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration - _ADAPT_ALL = CapabilityHandlingPolicy( behaviors={ CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, From 3a3b430d3a1d77d62bbc988420552066308cb731 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Mon, 6 Apr 2026 17:26:17 -0400 Subject: [PATCH 6/9] remove unnormalizable capabilities from policy and update tests --- .../common/target_capabilities.py | 8 ++----- tests/unit/target/test_target_capabilities.py | 23 +++++-------------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 66b298765a..a4d990c43e 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -30,7 +30,7 @@ class UnsupportedCapabilityBehavior(str, Enum): """ Defines what happens when a caller requires a capability the target does not support. - ADAPT: apply a normalization step to work around the missing capability. + ADAPT: apply a normalization step to work around the unsupported capability. RAISE: fail immediately with an error. """ @@ -41,7 +41,7 @@ class UnsupportedCapabilityBehavior(str, Enum): @dataclass(frozen=True) class CapabilityHandlingPolicy: """ - Per-capability policy consulted only when a capability is **missing**. + Per-capability policy consulted only when a capability is unsupported. Design invariants ----------------- @@ -55,10 +55,6 @@ class CapabilityHandlingPolicy: default_factory=lambda: { CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, } ) diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py index e2488faaa4..c8e2c91120 100644 --- a/tests/unit/target/test_target_capabilities.py +++ b/tests/unit/target/test_target_capabilities.py @@ -34,10 +34,6 @@ def test_capability_handling_policy_defaults(self): assert policy.behaviors == { CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, } def test_capability_handling_policy_custom_values(self): @@ -59,26 +55,19 @@ def test_capability_handling_policy_get_behavior(self): assert policy.get_behavior(capability=CapabilityName.MULTI_TURN) is UnsupportedCapabilityBehavior.RAISE assert policy.get_behavior(capability=CapabilityName.SYSTEM_PROMPT) is UnsupportedCapabilityBehavior.RAISE - def test_capability_handling_policy_get_behavior_for_all_supported_policy_keys(self): + def test_capability_handling_policy_get_behavior_for_all_default_keys(self): policy = CapabilityHandlingPolicy() - - assert policy.get_behavior(capability=CapabilityName.JSON_SCHEMA) is UnsupportedCapabilityBehavior.RAISE - assert policy.get_behavior(capability=CapabilityName.JSON_OUTPUT) is UnsupportedCapabilityBehavior.RAISE + for cap in policy.behaviors: + assert policy.get_behavior(capability=cap) is UnsupportedCapabilityBehavior.RAISE def test_capability_handling_policy_rejects_capability_without_policy(self): - # Use a custom partial policy that deliberately omits EDITABLE_HISTORY - partial_policy = CapabilityHandlingPolicy( - behaviors={ - CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, - } - ) + policy = CapabilityHandlingPolicy() with pytest.raises(AttributeError, match="supports_editable_history"): - partial_policy.get_behavior(capability=CapabilityName.EDITABLE_HISTORY) + policy.get_behavior(capability=CapabilityName.EDITABLE_HISTORY) with pytest.raises(AttributeError, match="supports_editable_history"): - _ = partial_policy.supports_editable_history + _ = policy.supports_editable_history def test_capability_handling_policy_rejects_unknown_attribute(self): policy = CapabilityHandlingPolicy() From a9575f0f0feeabc1d38b50eb5c92fde9529146c7 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Tue, 7 Apr 2026 16:15:14 -0400 Subject: [PATCH 7/9] update naming and pre-commit --- .../history_squash_normalizer.py | 3 +- .../conversation_normalization_pipeline.py | 6 ++-- .../common/target_capabilities.py | 6 ++-- .../common/target_configuration.py | 10 +++--- tests/unit/target/test_target_capabilities.py | 10 +++--- .../unit/target/test_target_configuration.py | 32 +++++++++---------- 6 files changed, 33 insertions(+), 34 deletions(-) diff --git a/pyrit/message_normalizer/history_squash_normalizer.py b/pyrit/message_normalizer/history_squash_normalizer.py index 6b1f9966b7..f143e3ca01 100644 --- a/pyrit/message_normalizer/history_squash_normalizer.py +++ b/pyrit/message_normalizer/history_squash_normalizer.py @@ -59,6 +59,5 @@ def _format_history(self, *, messages: list[Message]) -> list[str]: """ lines: list[str] = [] for msg in messages: - for piece in msg.message_pieces: - lines.append(f"{piece.api_role.capitalize()}: {piece.converted_value}") + lines.extend(f"{piece.api_role.capitalize()}: {piece.converted_value}" for piece in msg.message_pieces) return lines diff --git a/pyrit/prompt_target/common/conversation_normalization_pipeline.py b/pyrit/prompt_target/common/conversation_normalization_pipeline.py index 920e00e127..4fdc492e18 100644 --- a/pyrit/prompt_target/common/conversation_normalization_pipeline.py +++ b/pyrit/prompt_target/common/conversation_normalization_pipeline.py @@ -101,8 +101,8 @@ def from_capabilities( ``ValueError`` is raised immediately. NOTE: Normalizers are only valid when the capability can be overridden with a normalizer (which is indicated - by its presence in the registry), so we only iterate over valid capabilities in this function and add normalizers - only when the capability can support normalization. + by its presence in the registry), so we only iterate over valid capabilities in this function and add + normalizers only when the capability can support normalization. Args: capabilities (TargetCapabilities): The target's declared capabilities. @@ -125,7 +125,7 @@ def from_capabilities( normalizers: list[MessageListNormalizer[Message]] = [] for capability in _PIPELINE_ORDER: - if capabilities.supports(capability=capability): + if capabilities.includes(capability=capability): continue behavior = policy.get_behavior(capability=capability) diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index a4d990c43e..068f4ec32a 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -95,8 +95,8 @@ def __getattr__(self, name: str) -> UnsupportedCapabilityBehavior: raise AttributeError(name) def __post_init__(self) -> None: - # Defensive copy + read-only wrapper. object.__setattr__ is required - # because the dataclass is frozen. + """Create a defensive read-only copy of the behaviors mapping.""" + # object.__setattr__ is required because the dataclass is frozen. object.__setattr__(self, "behaviors", MappingProxyType(dict(self.behaviors))) @@ -140,7 +140,7 @@ class attribute. Users can override individual capabilities per instance # The output modalities supported by the target (e.g., "text", "image"). output_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])}) - def supports(self, *, capability: CapabilityName) -> bool: + def includes(self, *, capability: CapabilityName) -> bool: """ Return whether this target supports the given capability. diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index 98c727ff72..ae9094d7d9 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -80,9 +80,9 @@ def pipeline(self) -> ConversationNormalizationPipeline: """The resolved normalization pipeline.""" return self._pipeline - def supports(self, *, capability: CapabilityName) -> bool: + def includes(self, *, capability: CapabilityName) -> bool: """ - Check whether the target supports the given capability. + Check whether the target includes support for the given capability. Args: capability (CapabilityName): The capability to check. @@ -90,9 +90,9 @@ def supports(self, *, capability: CapabilityName) -> bool: Returns: bool: True if the target supports it natively. """ - return self._capabilities.supports(capability=capability) + return self._capabilities.includes(capability=capability) - def requires(self, *, capability: CapabilityName) -> None: + def ensure_can_handle(self, *, capability: CapabilityName) -> None: """ Validate that the target either supports the capability natively or has an ADAPT policy for it. @@ -107,7 +107,7 @@ def requires(self, *, capability: CapabilityName) -> None: ValueError: If the capability is missing and the policy is RAISE or no normalizer is available. """ - if self._capabilities.supports(capability=capability): + if self._capabilities.includes(capability=capability): return behavior = self._policy.get_behavior(capability=capability) diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py index c8e2c91120..eff36042c1 100644 --- a/tests/unit/target/test_target_capabilities.py +++ b/tests/unit/target/test_target_capabilities.py @@ -86,17 +86,17 @@ def test_normalizable_capabilities(self): == NORMALIZABLE_CAPABILITIES ) - def test_target_capabilities_supports_helper(self): + def test_target_capabilities_includes_helper(self): capabilities = TargetCapabilities( supports_multi_turn=True, supports_system_prompt=False, supports_json_output=True, ) - assert capabilities.supports(capability=CapabilityName.MULTI_TURN) is True - assert capabilities.supports(capability=CapabilityName.SYSTEM_PROMPT) is False - assert capabilities.supports(capability=CapabilityName.JSON_OUTPUT) is True - assert capabilities.supports(capability=CapabilityName.EDITABLE_HISTORY) is False + assert capabilities.includes(capability=CapabilityName.MULTI_TURN) is True + assert capabilities.includes(capability=CapabilityName.SYSTEM_PROMPT) is False + assert capabilities.includes(capability=CapabilityName.JSON_OUTPUT) is True + assert capabilities.includes(capability=CapabilityName.EDITABLE_HISTORY) is False # Env vars that may leak from .env files loaded by other tests in parallel workers. diff --git a/tests/unit/target/test_target_configuration.py b/tests/unit/target/test_target_configuration.py index 43de2261d9..e7408324d8 100644 --- a/tests/unit/target/test_target_configuration.py +++ b/tests/unit/target/test_target_configuration.py @@ -84,42 +84,42 @@ def test_capabilities_property(): # --------------------------------------------------------------------------- -def test_supports_returns_true_when_supported(): +def test_includes_returns_true_when_supported(): caps = TargetCapabilities(supports_multi_turn=True) config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) - assert config.supports(capability=CapabilityName.MULTI_TURN) is True + assert config.includes(capability=CapabilityName.MULTI_TURN) is True -def test_supports_returns_false_when_unsupported(): +def test_includes_returns_false_when_unsupported(): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) - assert config.supports(capability=CapabilityName.MULTI_TURN) is False + assert config.includes(capability=CapabilityName.MULTI_TURN) is False # --------------------------------------------------------------------------- -# requires +# ensure_can_handle # --------------------------------------------------------------------------- -def test_requires_passes_when_supported(): +def test_ensure_can_handle_passes_when_supported(): caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) config = TargetConfiguration(capabilities=caps) # Should not raise - config.requires(capability=CapabilityName.MULTI_TURN) + config.ensure_can_handle(capability=CapabilityName.MULTI_TURN) -def test_requires_passes_when_adapt(): +def test_ensure_can_handle_passes_when_adapt(): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) # ADAPT policy → should not raise - config.requires(capability=CapabilityName.MULTI_TURN) + config.ensure_can_handle(capability=CapabilityName.MULTI_TURN) -def test_requires_raises_when_raise_policy(): - # Build with ADAPT so construction succeeds, then test requires() on a RAISE capability. +def test_ensure_can_handle_raises_when_raise_policy(): + # Build with ADAPT so construction succeeds, then test ensure_can_handle() on a RAISE capability. # JSON_SCHEMA is RAISE and unsupported — but it's not normalizable, so construction # doesn't try to build a normalizer for it. Use a custom policy where system_prompt - # is ADAPT (so pipeline builds), but then call requires() on JSON_OUTPUT which is RAISE. + # is ADAPT (so pipeline builds), but then call ensure_can_handle() on JSON_OUTPUT which is RAISE. caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=False) policy = CapabilityHandlingPolicy( behaviors={ @@ -130,11 +130,11 @@ def test_requires_raises_when_raise_policy(): } ) config = TargetConfiguration(capabilities=caps, policy=policy) - # system_prompt is missing + ADAPT → requires passes - config.requires(capability=CapabilityName.SYSTEM_PROMPT) - # json_output is missing + RAISE → requires raises + # system_prompt is missing + ADAPT → ensure_can_handle passes + config.ensure_can_handle(capability=CapabilityName.SYSTEM_PROMPT) + # json_output is missing + RAISE → ensure_can_handle raises with pytest.raises(ValueError, match="RAISE"): - config.requires(capability=CapabilityName.JSON_OUTPUT) + config.ensure_can_handle(capability=CapabilityName.JSON_OUTPUT) # --------------------------------------------------------------------------- From 0ad8df0555883efe45ef197170e5029fc81e3923 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 8 Apr 2026 12:44:13 -0400 Subject: [PATCH 8/9] PR comments --- .../conversation_normalization_pipeline.py | 62 ++-------- .../common/target_capabilities.py | 12 +- ...est_conversation_normalization_pipeline.py | 108 ++++++++++-------- tests/unit/target/test_target_capabilities.py | 2 +- .../unit/target/test_target_configuration.py | 81 +++++++------ 5 files changed, 123 insertions(+), 142 deletions(-) diff --git a/pyrit/prompt_target/common/conversation_normalization_pipeline.py b/pyrit/prompt_target/common/conversation_normalization_pipeline.py index 4fdc492e18..d81d7c97ae 100644 --- a/pyrit/prompt_target/common/conversation_normalization_pipeline.py +++ b/pyrit/prompt_target/common/conversation_normalization_pipeline.py @@ -2,7 +2,6 @@ # Licensed under the MIT license. import logging -from dataclasses import dataclass from pyrit.message_normalizer import ( GenericSystemSquashNormalizer, @@ -20,40 +19,17 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class _NormalizerRegistryEntry: - """Single entry in the normalizer registry.""" - - order: int - normalizer_factory: type[MessageListNormalizer[Message]] - - # --------------------------------------------------------------------------- # Single registry: add new normalizable capabilities here and nowhere else. +# Order in the list determines pipeline execution order. # --------------------------------------------------------------------------- -_NORMALIZER_REGISTRY: dict[CapabilityName, _NormalizerRegistryEntry] = { - CapabilityName.SYSTEM_PROMPT: _NormalizerRegistryEntry(order=0, normalizer_factory=GenericSystemSquashNormalizer), - CapabilityName.MULTI_TURN: _NormalizerRegistryEntry(order=1, normalizer_factory=HistorySquashNormalizer), -} +_NORMALIZER_REGISTRY: list[tuple[CapabilityName, MessageListNormalizer[Message]]] = [ + (CapabilityName.SYSTEM_PROMPT, GenericSystemSquashNormalizer()), + (CapabilityName.MULTI_TURN, HistorySquashNormalizer()), +] -# Derived constants — no manual maintenance required. -NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset(_NORMALIZER_REGISTRY) - -_PIPELINE_ORDER: list[CapabilityName] = sorted( - _NORMALIZER_REGISTRY, - key=lambda cap: _NORMALIZER_REGISTRY[cap].order, -) - - -def _default_normalizers() -> dict[CapabilityName, MessageListNormalizer[Message]]: - """ - Build a fresh default normalizer instance for every registered capability. - - Returns: - dict[CapabilityName, MessageListNormalizer[Message]]: Mapping from - capability to a new default normalizer instance. - """ - return {cap: entry.normalizer_factory() for cap, entry in _NORMALIZER_REGISTRY.items()} +# Derived constant — no manual maintenance required. +NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset(cap for cap, _ in _NORMALIZER_REGISTRY) class ConversationNormalizationPipeline: @@ -92,7 +68,7 @@ def from_capabilities( """ Resolve capabilities and policy into a concrete pipeline of normalizers. - For each capability in ``PIPELINE_ORDER``: + For each capability in ``_NORMALIZER_REGISTRY`` (in order): * If the target already supports the capability, no normalizer is added. * If the capability is missing and the policy is ``ADAPT``, the @@ -100,31 +76,24 @@ def from_capabilities( * If the capability is missing and the policy is ``RAISE``, a ``ValueError`` is raised immediately. - NOTE: Normalizers are only valid when the capability can be overridden with a normalizer (which is indicated - by its presence in the registry), so we only iterate over valid capabilities in this function and add - normalizers only when the capability can support normalization. - Args: capabilities (TargetCapabilities): The target's declared capabilities. policy (CapabilityHandlingPolicy): How to handle each missing capability. normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None): Optional overrides for specific capability normalizers. - Falls back to the defaults from ``_default_normalizer_factory``. + Falls back to the defaults from ``_NORMALIZER_REGISTRY``. Returns: ConversationNormalizationPipeline: A pipeline with the resolved ordered tuple of normalizers. Raises: - ValueError: If a required capability is missing and the policy is RAISE, - or if a capability is not normalizable, or if no normalizer is - available for an ADAPT policy. + ValueError: If a required capability is missing and the policy is RAISE. """ - defaults = _default_normalizers() overrides = normalizer_overrides or {} normalizers: list[MessageListNormalizer[Message]] = [] - for capability in _PIPELINE_ORDER: + for capability, default_normalizer in _NORMALIZER_REGISTRY: if capabilities.includes(capability=capability): continue @@ -133,14 +102,7 @@ def from_capabilities( if behavior == UnsupportedCapabilityBehavior.RAISE: raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.") - normalizer = overrides.get(capability) - if normalizer is None: - normalizer = defaults.get(capability) - if normalizer is None: - raise ValueError( - f"Target does not support '{capability.value}' and the policy is ADAPT, " - f"but no normalizer is available for this capability." - ) + normalizer = overrides.get(capability, default_normalizer) normalizers.append(normalizer) diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 068f4ec32a..2e9a460c3d 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import Enum from types import MappingProxyType -from typing import Optional, cast +from typing import NoReturn, Optional, cast from pyrit.models import PromptDataType @@ -69,14 +69,16 @@ def get_behavior(self, *, capability: CapabilityName) -> UnsupportedCapabilityBe UnsupportedCapabilityBehavior: The configured behavior. Raises: - AttributeError: If no policy exists for the capability. + KeyError: If no behavior exists for the capability. This occurs for + non-adaptable capabilities (e.g., supports_editable_history). """ try: return self.behaviors[capability] - except KeyError as exc: - raise AttributeError(capability.value) from exc + except KeyError: + supported = ", ".join(sorted(cap.value for cap in self.behaviors)) + raise KeyError(f"No policy for capability '{capability.value}'. Supported capabilities: {supported}.") - def __getattr__(self, name: str) -> UnsupportedCapabilityBehavior: + def __getattr__(self, name: str) -> NoReturn: """ Guard against accessing policies for non-adaptable or unknown capabilities. diff --git a/tests/unit/target/test_conversation_normalization_pipeline.py b/tests/unit/target/test_conversation_normalization_pipeline.py index 57d96c0539..77d69a7e4b 100644 --- a/tests/unit/target/test_conversation_normalization_pipeline.py +++ b/tests/unit/target/test_conversation_normalization_pipeline.py @@ -16,31 +16,41 @@ UnsupportedCapabilityBehavior, ) -_ADAPT_ALL = CapabilityHandlingPolicy( - behaviors={ - CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, - } -) -_RAISE_ALL = CapabilityHandlingPolicy( - behaviors={ - CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, - } -) +@pytest.fixture +def adapt_all_policy(): + return CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, + } + ) + + +@pytest.fixture +def raise_all_policy(): + return CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, + } + ) + +@pytest.fixture +def make_message(): + def _make(role: ChatMessageRole, content: str) -> Message: + return Message(message_pieces=[MessagePiece(role=role, original_value=content)]) -def _make_message(role: ChatMessageRole, content: str) -> Message: - return Message(message_pieces=[MessagePiece(role=role, original_value=content)]) + return _make # --------------------------------------------------------------------------- @@ -48,15 +58,15 @@ def _make_message(role: ChatMessageRole, content: str) -> Message: # --------------------------------------------------------------------------- -def test_from_capabilities_all_supported_empty_tuple(): +def test_from_capabilities_all_supported_empty_tuple(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_ADAPT_ALL) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=adapt_all_policy) assert pipeline.normalizers == () -def test_from_capabilities_none_supported_has_two_normalizers(): +def test_from_capabilities_none_supported_has_two_normalizers(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_ADAPT_ALL) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=adapt_all_policy) assert len(pipeline.normalizers) == 2 assert isinstance(pipeline.normalizers[0], GenericSystemSquashNormalizer) assert isinstance(pipeline.normalizers[1], HistorySquashNormalizer) @@ -92,9 +102,9 @@ def test_from_capabilities_missing_multi_turn_only(): assert isinstance(pipeline.normalizers[0], HistorySquashNormalizer) -def test_from_capabilities_normalizers_is_tuple(): +def test_from_capabilities_normalizers_is_tuple(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_ADAPT_ALL) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=adapt_all_policy) assert isinstance(pipeline.normalizers, tuple) @@ -103,16 +113,16 @@ def test_from_capabilities_normalizers_is_tuple(): # --------------------------------------------------------------------------- -def test_from_capabilities_raises_when_system_prompt_missing_and_policy_raise(): +def test_from_capabilities_raises_when_system_prompt_missing_and_policy_raise(raise_all_policy): caps = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True) with pytest.raises(ValueError, match="RAISE"): - ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_RAISE_ALL) + ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=raise_all_policy) -def test_from_capabilities_raises_when_multi_turn_missing_and_policy_raise(): +def test_from_capabilities_raises_when_multi_turn_missing_and_policy_raise(raise_all_policy): caps = TargetCapabilities(supports_system_prompt=True, supports_multi_turn=False) with pytest.raises(ValueError, match="RAISE"): - ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_RAISE_ALL) + ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=raise_all_policy) # --------------------------------------------------------------------------- @@ -146,9 +156,9 @@ def test_from_capabilities_uses_override_normalizer(): @pytest.mark.asyncio -async def test_normalize_passthrough_when_empty_pipeline(): +async def test_normalize_passthrough_when_empty_pipeline(make_message): pipeline = ConversationNormalizationPipeline() - messages = [_make_message("system", "sys"), _make_message("user", "hi")] + messages = [make_message("system", "sys"), make_message("user", "hi")] result = await pipeline.normalize_async(messages=messages) assert len(result) == 2 @@ -162,7 +172,7 @@ async def test_normalize_passthrough_when_empty_pipeline(): @pytest.mark.asyncio -async def test_normalize_adapts_system_prompt(): +async def test_normalize_adapts_system_prompt(make_message): caps = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True) policy = CapabilityHandlingPolicy( behaviors={ @@ -174,7 +184,7 @@ async def test_normalize_adapts_system_prompt(): ) pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=policy) - messages = [_make_message("system", "be nice"), _make_message("user", "hello")] + messages = [make_message("system", "be nice"), make_message("user", "hello")] result = await pipeline.normalize_async(messages=messages) assert len(result) == 1 @@ -189,7 +199,7 @@ async def test_normalize_adapts_system_prompt(): @pytest.mark.asyncio -async def test_normalize_adapts_multi_turn(): +async def test_normalize_adapts_multi_turn(make_message): caps = TargetCapabilities(supports_system_prompt=True, supports_multi_turn=False) policy = CapabilityHandlingPolicy( behaviors={ @@ -202,9 +212,9 @@ async def test_normalize_adapts_multi_turn(): pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=policy) messages = [ - _make_message("user", "hello"), - _make_message("assistant", "hi"), - _make_message("user", "how are you?"), + make_message("user", "hello"), + make_message("assistant", "hi"), + make_message("user", "how are you?"), ] result = await pipeline.normalize_async(messages=messages) @@ -222,16 +232,16 @@ async def test_normalize_adapts_multi_turn(): @pytest.mark.asyncio -async def test_normalize_adapts_system_then_multi_turn(): +async def test_normalize_adapts_system_then_multi_turn(adapt_all_policy, make_message): """System squash runs first, then history squash.""" caps = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=False) - pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=_ADAPT_ALL) + pipeline = ConversationNormalizationPipeline.from_capabilities(capabilities=caps, policy=adapt_all_policy) messages = [ - _make_message("system", "be nice"), - _make_message("user", "hello"), - _make_message("assistant", "hi"), - _make_message("user", "bye"), + make_message("system", "be nice"), + make_message("user", "hello"), + make_message("assistant", "hi"), + make_message("user", "bye"), ] result = await pipeline.normalize_async(messages=messages) @@ -248,14 +258,14 @@ async def test_normalize_adapts_system_then_multi_turn(): @pytest.mark.asyncio -async def test_normalize_uses_custom_normalizer(): +async def test_normalize_uses_custom_normalizer(make_message): mock_normalizer = MagicMock(spec=MessageListNormalizer) - expected = [_make_message("user", "custom")] + expected = [make_message("user", "custom")] mock_normalizer.normalize_async = AsyncMock(return_value=expected) pipeline = ConversationNormalizationPipeline(normalizers=(mock_normalizer,)) - messages = [_make_message("system", "sys"), _make_message("user", "hi")] + messages = [make_message("system", "sys"), make_message("user", "hi")] result = await pipeline.normalize_async(messages=messages) assert result == expected diff --git a/tests/unit/target/test_target_capabilities.py b/tests/unit/target/test_target_capabilities.py index eff36042c1..df33a4f073 100644 --- a/tests/unit/target/test_target_capabilities.py +++ b/tests/unit/target/test_target_capabilities.py @@ -63,7 +63,7 @@ def test_capability_handling_policy_get_behavior_for_all_default_keys(self): def test_capability_handling_policy_rejects_capability_without_policy(self): policy = CapabilityHandlingPolicy() - with pytest.raises(AttributeError, match="supports_editable_history"): + with pytest.raises(KeyError, match="No policy for capability 'supports_editable_history'"): policy.get_behavior(capability=CapabilityName.EDITABLE_HISTORY) with pytest.raises(AttributeError, match="supports_editable_history"): diff --git a/tests/unit/target/test_target_configuration.py b/tests/unit/target/test_target_configuration.py index e7408324d8..8fe14b6e1e 100644 --- a/tests/unit/target/test_target_configuration.py +++ b/tests/unit/target/test_target_configuration.py @@ -14,20 +14,27 @@ ) from pyrit.prompt_target.common.target_configuration import TargetConfiguration -_ADAPT_ALL = CapabilityHandlingPolicy( - behaviors={ - CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, - CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, - CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, - } -) +@pytest.fixture +def adapt_all_policy(): + return CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, + } + ) + + +@pytest.fixture +def make_message(): + def _make(role: ChatMessageRole, content: str) -> Message: + return Message(message_pieces=[MessagePiece(role=role, original_value=content)]) -def _make_message(role: ChatMessageRole, content: str) -> Message: - return Message(message_pieces=[MessagePiece(role=role, original_value=content)]) + return _make # --------------------------------------------------------------------------- @@ -42,21 +49,21 @@ def test_init_with_defaults_uses_raise_policy(): assert config.policy.get_behavior(capability=CapabilityName.MULTI_TURN) == UnsupportedCapabilityBehavior.RAISE -def test_init_with_explicit_policy(): +def test_init_with_explicit_policy(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) - assert config.policy is _ADAPT_ALL + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + assert config.policy is adapt_all_policy -def test_init_all_supported_empty_pipeline(): +def test_init_all_supported_empty_pipeline(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) assert config.pipeline.normalizers == () -def test_init_missing_capability_adapt_builds_pipeline(): +def test_init_missing_capability_adapt_builds_pipeline(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) assert len(config.pipeline.normalizers) == 2 assert isinstance(config.pipeline.normalizers[0], GenericSystemSquashNormalizer) assert isinstance(config.pipeline.normalizers[1], HistorySquashNormalizer) @@ -84,15 +91,15 @@ def test_capabilities_property(): # --------------------------------------------------------------------------- -def test_includes_returns_true_when_supported(): +def test_includes_returns_true_when_supported(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=True) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) assert config.includes(capability=CapabilityName.MULTI_TURN) is True -def test_includes_returns_false_when_unsupported(): +def test_includes_returns_false_when_unsupported(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) assert config.includes(capability=CapabilityName.MULTI_TURN) is False @@ -108,9 +115,9 @@ def test_ensure_can_handle_passes_when_supported(): config.ensure_can_handle(capability=CapabilityName.MULTI_TURN) -def test_ensure_can_handle_passes_when_adapt(): +def test_ensure_can_handle_passes_when_adapt(adapt_all_policy): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) # ADAPT policy → should not raise config.ensure_can_handle(capability=CapabilityName.MULTI_TURN) @@ -143,23 +150,23 @@ def test_ensure_can_handle_raises_when_raise_policy(): @pytest.mark.asyncio -async def test_normalize_async_passthrough_when_all_supported(): +async def test_normalize_async_passthrough_when_all_supported(adapt_all_policy, make_message): caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) - msgs = [_make_message("user", "hello")] + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + msgs = [make_message("user", "hello")] result = await config.normalize_async(messages=msgs) assert len(result) == 1 assert result[0].message_pieces[0].converted_value == "hello" @pytest.mark.asyncio -async def test_normalize_async_adapts_system_prompt(): +async def test_normalize_async_adapts_system_prompt(adapt_all_policy, make_message): caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=False) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) msgs = [ - _make_message("system", "you are helpful"), - _make_message("user", "hello"), + make_message("system", "you are helpful"), + make_message("user", "hello"), ] result = await config.normalize_async(messages=msgs) # System squash merges system into user messages — no system role left @@ -169,14 +176,14 @@ async def test_normalize_async_adapts_system_prompt(): @pytest.mark.asyncio -async def test_normalize_async_adapts_multi_turn(): +async def test_normalize_async_adapts_multi_turn(adapt_all_policy, make_message): caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True) - config = TargetConfiguration(capabilities=caps, policy=_ADAPT_ALL) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) msgs = [ - _make_message("user", "turn 1"), - _make_message("assistant", "reply 1"), - _make_message("user", "turn 2"), + make_message("user", "turn 1"), + make_message("assistant", "reply 1"), + make_message("user", "turn 2"), ] result = await config.normalize_async(messages=msgs) # History squash collapses into a single message From 66e9f5654d0218adc23429a8b572f04ffa2db1ca Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 8 Apr 2026 13:10:07 -0400 Subject: [PATCH 9/9] pre-commit and catch --- pyrit/prompt_target/common/target_capabilities.py | 4 +++- pyrit/prompt_target/common/target_configuration.py | 7 ++++++- tests/unit/target/test_target_configuration.py | 7 +++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/pyrit/prompt_target/common/target_capabilities.py b/pyrit/prompt_target/common/target_capabilities.py index 2e9a460c3d..7a34222803 100644 --- a/pyrit/prompt_target/common/target_capabilities.py +++ b/pyrit/prompt_target/common/target_capabilities.py @@ -76,7 +76,9 @@ def get_behavior(self, *, capability: CapabilityName) -> UnsupportedCapabilityBe return self.behaviors[capability] except KeyError: supported = ", ".join(sorted(cap.value for cap in self.behaviors)) - raise KeyError(f"No policy for capability '{capability.value}'. Supported capabilities: {supported}.") + raise KeyError( + f"No policy for capability '{capability.value}'. Supported capabilities: {supported}." + ) from None def __getattr__(self, name: str) -> NoReturn: """ diff --git a/pyrit/prompt_target/common/target_configuration.py b/pyrit/prompt_target/common/target_configuration.py index ae9094d7d9..47abdb55d5 100644 --- a/pyrit/prompt_target/common/target_configuration.py +++ b/pyrit/prompt_target/common/target_configuration.py @@ -110,7 +110,12 @@ def ensure_can_handle(self, *, capability: CapabilityName) -> None: if self._capabilities.includes(capability=capability): return - behavior = self._policy.get_behavior(capability=capability) + try: + behavior = self._policy.get_behavior(capability=capability) + except KeyError: + raise ValueError( + f"Target does not support '{capability.value}' and no handling policy exists for it." + ) from None if behavior == UnsupportedCapabilityBehavior.RAISE: raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.") diff --git a/tests/unit/target/test_target_configuration.py b/tests/unit/target/test_target_configuration.py index 8fe14b6e1e..df0dbe3d62 100644 --- a/tests/unit/target/test_target_configuration.py +++ b/tests/unit/target/test_target_configuration.py @@ -144,6 +144,13 @@ def test_ensure_can_handle_raises_when_raise_policy(): config.ensure_can_handle(capability=CapabilityName.JSON_OUTPUT) +def test_ensure_can_handle_raises_valueerror_for_non_normalizable_capability(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True, supports_editable_history=False) + config = TargetConfiguration(capabilities=caps) + with pytest.raises(ValueError, match="no handling policy"): + config.ensure_can_handle(capability=CapabilityName.EDITABLE_HISTORY) + + # --------------------------------------------------------------------------- # normalize_async # ---------------------------------------------------------------------------