From 6ec9e809e7188af4956c48cb4bc92a0718da08b4 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 1 Apr 2026 13:47:31 -0700 Subject: [PATCH 01/30] Introduce IdentifierFilters to allow generic DB queries on identifier properties --- pyrit/memory/__init__.py | 9 + pyrit/memory/azure_sql_memory.py | 247 +++++++---------- pyrit/memory/identifier_filters.py | 95 +++++++ pyrit/memory/memory_interface.py | 262 +++++++++++++----- pyrit/memory/sqlite_memory.py | 202 ++++++-------- .../test_interface_attack_results.py | 53 ++++ .../test_interface_prompts.py | 115 ++++++++ .../test_interface_scenario_results.py | 114 ++++++++ .../memory_interface/test_interface_scores.py | 74 +++++ 9 files changed, 822 insertions(+), 349 deletions(-) create mode 100644 pyrit/memory/identifier_filters.py diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 9f10860130..102a1f8607 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -7,6 +7,7 @@ This package defines the core `MemoryInterface` and concrete implementations for different storage backends. 
""" +from pyrit.memory.identifier_filters import AttackIdentifierFilter, AttackIdentifierProperty, ConverterIdentifierFilter, ConverterIdentifierProperty, ScorerIdentifierFilter, ScorerIdentifierProperty, TargetIdentifierFilter, TargetIdentifierProperty from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory from pyrit.memory.memory_embedding import MemoryEmbedding @@ -17,6 +18,10 @@ __all__ = [ "AttackResultEntry", + "AttackIdentifierFilter", + "AttackIdentifierProperty", + "ConverterIdentifierFilter", + "ConverterIdentifierProperty", "AzureSQLMemory", "CentralMemory", "SQLiteMemory", @@ -25,5 +30,9 @@ "MemoryEmbedding", "MemoryExporter", "PromptMemoryEntry", + "ScorerIdentifierFilter", + "ScorerIdentifierProperty", "SeedEntry", + "TargetIdentifierFilter", + "TargetIdentifierProperty", ] diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index c9f349c0d9..48ae2c5df2 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -250,22 +250,6 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()}) return [condition] - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Generate SQL condition for filtering message pieces by attack ID. - - Uses JSON_VALUE() function specific to SQL Azure to query the attack identifier. - - Args: - attack_id (str): The attack identifier to filter by. - - Returns: - Any: SQLAlchemy text condition with bound parameter. - """ - return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id").bindparams( - json_id=str(attack_id) - ) - def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]: """ Generate SQL conditions for filtering by prompt metadata. 
@@ -321,6 +305,99 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
         """
         return self._get_metadata_conditions(prompt_metadata=metadata)[0]
 
+    def _get_condition_json_property_match(
+        self,
+        *,
+        json_column: Any,
+        property_path: str,
+        value_to_match: str,
+        partial_match: bool = False,
+    ) -> Any:
+        table_name = json_column.class_.__tablename__
+        column_name = json_column.key
+
+        json_value = f'JSON_VALUE("{table_name}".{column_name}, :property_path)'
+        # Partial matches lowercase both sides; exact matches preserve the stored casing.
+        predicate = f"LOWER({json_value}) LIKE :match_property_value" if partial_match else f"{json_value} = :match_property_value"  # noqa: E501
+        return text(
+            f"""ISJSON("{table_name}".{column_name}) = 1 AND {predicate}"""
+        ).bindparams(
+            property_path=property_path,
+            match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match,
+        )
+
+    def _get_condition_json_array_match(
+        self,
+        *,
+        json_column: Any,
+        property_path: str,
+        array_to_match: Sequence[str],
+        case_insensitive: bool = False,
+    ) -> Any:
+        table_name = json_column.class_.__tablename__
+        column_name = json_column.key
+        if len(array_to_match) == 0:
+            return text(
+                f"""("{table_name}".{column_name} IS NULL
+                OR JSON_QUERY("{table_name}".{column_name}, :property_path) IS NULL
+                OR JSON_QUERY("{table_name}".{column_name}, :property_path) = '[]')"""
+            ).bindparams(property_path=property_path)
+
+        value_expression = "JSON_VALUE(value, '$.class_name')"
+        if case_insensitive:
+            value_expression = f"LOWER({value_expression})"
+
+        conditions = []
+        bindparams_dict: dict[str, str] = {"property_path": property_path}
+
+        for index, match_value in enumerate(array_to_match):
+            param_name = f"match_value_{index}"
+            conditions.append(
+                f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("{table_name}".{column_name},
+                :property_path))
+                WHERE {value_expression} = :{param_name})"""
+            )
+            bindparams_dict[param_name] = match_value.lower() if case_insensitive else match_value
+
+        combined = " AND ".join(conditions)
+        return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(
+            **bindparams_dict
+        )
+
+    def _get_unique_json_array_values(
+        self,
+        *,
+        json_column: Any,
+        path_to_array: str,
+        sub_path: str | None = None,
+    ) -> list[str]:
+        table_name = json_column.class_.__tablename__
+        column_name = json_column.key
+        with closing(self.get_session()) as session:
+            if sub_path is None:
+                rows = session.execute(
+                    text(
+                        f"""SELECT DISTINCT JSON_VALUE("{table_name}".{column_name}, :path_to_array) AS value
+                        FROM "{table_name}"
+                        WHERE ISJSON("{table_name}".{column_name}) = 1
+                        AND JSON_VALUE("{table_name}".{column_name}, :path_to_array) IS NOT NULL"""
+                    ).bindparams(path_to_array=path_to_array)
+                ).fetchall()
+            else:
+                rows = session.execute(
+                    text(
+                        f"""SELECT DISTINCT JSON_VALUE(items.value, :sub_path) AS value
+                        FROM "{table_name}"
+                        CROSS APPLY OPENJSON(JSON_QUERY("{table_name}".{column_name}, :path_to_array)) AS items
+                        WHERE ISJSON("{table_name}".{column_name}) = 1
+                        AND JSON_VALUE(items.value, :sub_path) IS NOT NULL"""
+                    ).bindparams(
+                        path_to_array=path_to_array,
+                        sub_path=sub_path,
+                    )
+                ).fetchall()
+            return sorted(row[0] for row in rows)
+
     def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
         """
         Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.
@@ -388,110 +465,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
             )
         )
 
-    def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any:
-        """
-        Azure SQL implementation for filtering AttackResults by attack class.
-        Uses JSON_VALUE() on the atomic_attack_identifier JSON column.
-
-        Args:
-            attack_class (str): Exact attack class name to match.
-
-        Returns:
-            Any: SQLAlchemy text condition with bound parameter.
- """ - return text( - """ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 - AND JSON_VALUE("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.class_name') = :attack_class""" - ).bindparams(attack_class=attack_class) - - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - Azure SQL implementation for filtering AttackResults by converter classes. - - Uses JSON_VALUE()/JSON_QUERY()/OPENJSON() on the atomic_attack_identifier - JSON column. - - When converter_classes is empty, matches attacks with no converters. - When non-empty, uses OPENJSON() to check all specified classes are present - (AND logic, case-insensitive). - - Args: - converter_classes (Sequence[str]): List of converter class names. Empty list means no converters. - - Returns: - Any: SQLAlchemy combined condition with bound parameters. - """ - if len(converter_classes) == 0: - # Explicitly "no converters": match attacks where the converter list - # is absent, null, or empty in the stored JSON. 
- return text( - """("AttackResultEntries".atomic_attack_identifier IS NULL - OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') IS NULL - OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') = '[]')""" - ) - - conditions = [] - bindparams_dict: dict[str, str] = {} - for i, cls in enumerate(converter_classes): - param_name = f"conv_cls_{i}" - conditions.append( - f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters')) - WHERE LOWER(JSON_VALUE(value, '$.class_name')) = :{param_name})""" - ) - bindparams_dict[param_name] = cls.lower() - - combined = " AND ".join(conditions) - return text(f"""ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 AND {combined}""").bindparams( - **bindparams_dict - ) - - def get_unique_attack_class_names(self) -> list[str]: - """ - Azure SQL implementation: extract unique class_name values from - the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique attack class name strings. - """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT JSON_VALUE(atomic_attack_identifier, - '$.children.attack.class_name') AS cls - FROM "AttackResultEntries" - WHERE ISJSON(atomic_attack_identifier) = 1 - AND JSON_VALUE(atomic_attack_identifier, - '$.children.attack.class_name') IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - - def get_unique_converter_class_names(self) -> list[str]: - """ - Azure SQL implementation: extract unique converter class_name values - from the children.attack.children.request_converters array - in the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique converter class name strings. 
- """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT JSON_VALUE(c.value, '$.class_name') AS cls - FROM "AttackResultEntries" - CROSS APPLY OPENJSON(JSON_QUERY(atomic_attack_identifier, - '$.children.attack.children.request_converters')) AS c - WHERE ISJSON(atomic_attack_identifier) = 1 - AND JSON_VALUE(c.value, '$.class_name') IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: """ Azure SQL implementation: lightweight aggregate stats per conversation. @@ -593,40 +566,6 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any conditions.append(condition) return and_(*conditions) - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> TextClause: - """ - Get the SQL Azure implementation for filtering ScenarioResults by target endpoint. - - Uses JSON_VALUE() function specific to SQL Azure. - - Args: - endpoint (str): The endpoint URL substring to filter by (case-insensitive). - - Returns: - Any: SQLAlchemy text condition with bound parameter. - """ - return text( - """ISJSON(objective_target_identifier) = 1 - AND LOWER(JSON_VALUE(objective_target_identifier, '$.endpoint')) LIKE :endpoint""" - ).bindparams(endpoint=f"%{endpoint.lower()}%") - - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> TextClause: - """ - Get the SQL Azure implementation for filtering ScenarioResults by target model name. - - Uses JSON_VALUE() function specific to SQL Azure. - - Args: - model_name (str): The model name substring to filter by (case-insensitive). - - Returns: - Any: SQLAlchemy text condition with bound parameter. 
- """ - return text( - """ISJSON(objective_target_identifier) = 1 - AND LOWER(JSON_VALUE(objective_target_identifier, '$.model_name')) LIKE :model_name""" - ).bindparams(model_name=f"%{model_name.lower()}%") - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py new file mode 100644 index 0000000000..8792f03241 --- /dev/null +++ b/pyrit/memory/identifier_filters.py @@ -0,0 +1,95 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from abc import ABC +from dataclasses import dataclass +from enum import Enum +from typing import Generic, TypeVar + + +# TODO: if/when we move to python 3.11+, we can replace this with StrEnum +class _StrEnum(str, Enum): + """Base class that mimics StrEnum behavior for Python < 3.11.""" + + def __str__(self) -> str: + return self.value + + +T = TypeVar("T", bound=_StrEnum) + + +class IdentifierProperty(_StrEnum): + """Allowed JSON paths for identifier filtering.""" + + HASH = "$.hash" + CLASS_NAME = "$.class_name" + + +@dataclass(frozen=True) +class IdentifierFilter(ABC, Generic[T]): + """Immutable filter definition for matching JSON-backed identifier properties.""" + + property_path: T | str + value_to_match: str + partial_match: bool = False + + def __post_init__(self) -> None: + """Normalize and validate the configured property path.""" + object.__setattr__(self, "property_path", str(self.property_path)) + + +class AttackIdentifierProperty(_StrEnum): + """Allowed JSON paths for attack identifier filtering.""" + + HASH = "$.hash" + ATTACK_CLASS_NAME = "$.children.attack.class_name" + REQUEST_CONVERTERS = "$.children.attack.children.request_converters" + + +class TargetIdentifierProperty(_StrEnum): + """Allowed JSON paths for target identifier filtering.""" + + HASH = "$.hash" + ENDPOINT = "$.endpoint" + MODEL_NAME = 
"$.model_name" + + +class ConverterIdentifierProperty(_StrEnum): + """Allowed JSON paths for converter identifier filtering.""" + + HASH = "$.hash" + CLASS_NAME = "$.class_name" + + +class ScorerIdentifierProperty(_StrEnum): + """Allowed JSON paths for scorer identifier filtering.""" + + HASH = "$.hash" + CLASS_NAME = "$.class_name" + + +@dataclass(frozen=True) +class AttackIdentifierFilter(IdentifierFilter[AttackIdentifierProperty]): + """ + Immutable filter definition for matching JSON-backed attack identifier properties. + + Args: + property_path: The JSON path of the property to filter on. + value_to_match: The value to match against the property. + partial_match: Whether to allow partial matches (default: False). + """ + + +@dataclass(frozen=True) +class TargetIdentifierFilter(IdentifierFilter[TargetIdentifierProperty]): + """Immutable filter definition for matching JSON-backed target identifier properties.""" + + +@dataclass(frozen=True) +class ConverterIdentifierFilter(IdentifierFilter[ConverterIdentifierProperty]): + """Immutable filter definition for matching JSON-backed converter identifier properties.""" + + +@dataclass(frozen=True) +class ScorerIdentifierFilter(IdentifierFilter[ScorerIdentifierProperty]): + """Immutable filter definition for matching JSON-backed scorer identifier properties.""" diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 90322ebec4..5bc1f4ad3e 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -19,6 +19,14 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from pyrit.common.path import DB_DATA_PATH +from pyrit.memory.identifier_filters import ( + AttackIdentifierFilter, + AttackIdentifierProperty, + ConverterIdentifierProperty, + ScorerIdentifierFilter, + TargetIdentifierFilter, + TargetIdentifierProperty, +) from pyrit.memory.memory_embedding import ( MemoryEmbedding, default_memory_embedding_factory, @@ -113,6 +121,77 @@ def 
disable_embedding(self) -> None: """ self.memory_embedding = None + @abc.abstractmethod + def _get_condition_json_property_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + value_to_match: str, + partial_match: bool = False, + ) -> Any: + """ + Return a database-specific condition for matching a value at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. + property_path (str): The JSON path for the property to match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a case-insensitive substring match. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. + """ + + @abc.abstractmethod + def _get_condition_json_array_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + array_to_match: Sequence[str], + case_insensitive: bool = False, + ) -> Any: + """ + Return a database-specific condition for matching an array at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. + property_path (str): The JSON path for the target array. + array_to_match (Sequence[str]): The array that must match the extracted JSON array values. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. + case_insensitive (bool): Whether string comparison should ignore casing. + + Returns: + Any: A database-specific SQLAlchemy condition. + """ + + @abc.abstractmethod + def _get_unique_json_array_values( + self, + *, + json_column: Any, + path_to_array: str, + sub_path: str | None = None, + ) -> list[str]: + """ + Return sorted unique values in an array located at a given path within a JSON object. 
+ + This method performs a database-level query to extract distinct values from a + an array within a JSON-type column. When ``sub_path`` is provided, the distinct values are + extracted from each array item using the sub-path. + + Args: + json_column (Any): The JSON-backed model field to query. + path_to_array (str): The JSON path to the array whose unique values are extracted. + sub_path (str | None): Optional JSON path applied to each array + item before collecting distinct values. + + Returns: + list[str]: A sorted list of unique values in the array. + """ + @abc.abstractmethod def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: """ @@ -155,12 +234,6 @@ def _get_message_pieces_prompt_metadata_conditions( list: A list of conditions for filtering memory entries based on prompt metadata. """ - @abc.abstractmethod - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Return a condition to retrieve based on attack ID. - """ - @abc.abstractmethod def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: """ @@ -289,41 +362,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: Database-specific SQLAlchemy condition. """ - @abc.abstractmethod - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - Return a database-specific condition for filtering AttackResults by attack class - (class_name in the attack_identifier JSON column). - - Args: - attack_class: Exact attack class name to match. - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - Return a database-specific condition for filtering AttackResults by converter classes - in the request_converter_identifiers array within attack_identifier JSON column. 
- - This method is only called when converter filtering is requested (converter_classes - is not None). The caller handles the None-vs-list distinction: - - - ``len(converter_classes) == 0``: return a condition matching attacks with NO converters. - - ``len(converter_classes) > 0``: return a condition requiring ALL specified converter - class names to be present (AND logic, case-insensitive). - - Args: - converter_classes: Converter class names to require. An empty sequence means - "match only attacks that have no converters". - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod def get_unique_attack_class_names(self) -> list[str]: """ Return sorted unique attack class names from all stored attack results. @@ -334,8 +372,11 @@ def get_unique_attack_class_names(self) -> list[str]: Returns: Sorted list of unique attack class name strings. """ + return self._get_unique_json_array_values( + json_column=AttackResultEntry.atomic_attack_identifier, + path_to_array=AttackIdentifierProperty.ATTACK_CLASS_NAME + ) - @abc.abstractmethod def get_unique_converter_class_names(self) -> list[str]: """ Return sorted unique converter class names used across all attack results. @@ -346,6 +387,11 @@ def get_unique_converter_class_names(self) -> list[str]: Returns: Sorted list of unique converter class name strings. """ + return self._get_unique_json_array_values( + json_column=AttackResultEntry.atomic_attack_identifier, + path_to_array=AttackIdentifierProperty.REQUEST_CONVERTERS, + sub_path=ConverterIdentifierProperty.CLASS_NAME, + ) @abc.abstractmethod def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, "ConversationStats"]: @@ -377,30 +423,6 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any Database-specific SQLAlchemy condition. 
""" - @abc.abstractmethod - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: - """ - Return a database-specific condition for filtering ScenarioResults by target endpoint. - - Args: - endpoint: Endpoint substring to search for (case-insensitive). - - Returns: - Database-specific SQLAlchemy condition. - """ - - @abc.abstractmethod - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: - """ - Return a database-specific condition for filtering ScenarioResults by target model name. - - Args: - model_name: Model name substring to search for (case-insensitive). - - Returns: - Database-specific SQLAlchemy condition. - """ - def add_scores_to_memory(self, *, scores: Sequence[Score]) -> None: """ Insert a list of scores into the memory storage. @@ -425,6 +447,7 @@ def get_scores( score_category: Optional[str] = None, sent_after: Optional[datetime] = None, sent_before: Optional[datetime] = None, + scorer_identifier_filter: Optional[ScorerIdentifierFilter] = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -435,6 +458,8 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. + scorer_identifier_filter (Optional[ScorerIdentifierFilter]): A ScorerIdentifierFilter object that + allows filtering by various scorer identifier JSON properties. Defaults to None. Returns: Sequence[Score]: A list of Score objects that match the specified filters. 
@@ -451,6 +476,15 @@ def get_scores( conditions.append(ScoreEntry.timestamp >= sent_after) if sent_before: conditions.append(ScoreEntry.timestamp <= sent_before) + if scorer_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=ScoreEntry.scorer_class_identifier, + property_path=scorer_identifier_filter.property_path, + value_to_match=scorer_identifier_filter.value_to_match, + partial_match=scorer_identifier_filter.partial_match, + ) + ) if not conditions: return [] @@ -581,6 +615,8 @@ def get_message_pieces( data_type: Optional[str] = None, not_data_type: Optional[str] = None, converted_value_sha256: Optional[Sequence[str]] = None, + attack_identifier_filter: Optional[AttackIdentifierFilter] = None, + prompt_target_identifier_filter: Optional[TargetIdentifierFilter] = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -602,6 +638,12 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. + attack_identifier_filter (Optional[AttackIdentifierFilter], optional): + An AttackIdentifierFilter object that + allows filtering by various attack identifier JSON properties. Defaults to None. + prompt_target_identifier_filter (Optional[TargetIdentifierFilter], optional): + A TargetIdentifierFilter object that + allows filtering by various target identifier JSON properties. Defaults to None. Returns: Sequence[MessagePiece]: A list of MessagePiece objects that match the specified filters. 
@@ -612,7 +654,13 @@ def get_message_pieces( """ conditions = [] if attack_id: - conditions.append(self._get_message_pieces_attack_conditions(attack_id=str(attack_id))) + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.attack_identifier, + property_path=AttackIdentifierProperty.HASH, + value_to_match=str(attack_id) + ) + ) if role: conditions.append(PromptMemoryEntry.role == role) if conversation_id: @@ -638,6 +686,24 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if converted_value_sha256: conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) + if attack_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.attack_identifier, + property_path=attack_identifier_filter.property_path, + value_to_match=attack_identifier_filter.value_to_match, + partial_match=attack_identifier_filter.partial_match, + ) + ) + if prompt_target_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=PromptMemoryEntry.prompt_target_identifier, + property_path=prompt_target_identifier_filter.property_path, + value_to_match=prompt_target_identifier_filter.value_to_match, + partial_match=prompt_target_identifier_filter.partial_match, + ) + ) try: memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( @@ -1365,6 +1431,7 @@ def get_attack_results( converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, + attack_identifier_filter: Optional[AttackIdentifierFilter] = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -1392,6 +1459,9 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. 
These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. + attack_identifier_filter (Optional[AttackIdentifierFilter], optional): + An AttackIdentifierFilter object that allows filtering by various attack identifier + JSON properties. Defaults to None. Returns: Sequence[AttackResult]: A list of AttackResult objects that match the specified filters. @@ -1415,12 +1485,25 @@ def get_attack_results( if attack_class: # Use database-specific JSON query method - conditions.append(self._get_attack_result_attack_class_condition(attack_class=attack_class)) + conditions.append( + self._get_condition_json_property_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path=AttackIdentifierProperty.ATTACK_CLASS_NAME, + value_to_match=attack_class, + ) + ) if converter_classes is not None: # converter_classes=[] means "only attacks with no converters" # converter_classes=["A","B"] means "must have all listed converters" - conditions.append(self._get_attack_result_converter_classes_condition(converter_classes=converter_classes)) + conditions.append( + self._get_condition_json_array_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path=AttackIdentifierProperty.REQUEST_CONVERTERS, + array_to_match=converter_classes, + case_insensitive=True, + ) + ) if targeted_harm_categories: # Use database-specific JSON query method @@ -1432,6 +1515,16 @@ def get_attack_results( # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) + if attack_identifier_filter: + conditions.append( + self._get_condition_json_property_match( + json_column=AttackResultEntry.atomic_attack_identifier, + property_path=attack_identifier_filter.property_path, + value_to_match=attack_identifier_filter.value_to_match, + partial_match=attack_identifier_filter.partial_match, + ) + ) + try: entries: Sequence[AttackResultEntry] = self._query_entries( 
AttackResultEntry, conditions=and_(*conditions) if conditions else None @@ -1612,6 +1705,7 @@ def get_scenario_results( labels: Optional[dict[str, str]] = None, objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, + objective_target_identifier_filter: Optional[TargetIdentifierFilter] = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. @@ -1635,6 +1729,8 @@ def get_scenario_results( objective_target_model_name (Optional[str], optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. + objective_target_identifier_filter (Optional[TargetIdentifierFilter], optional): + A TargetIdentifierFilter object that allows filtering by various target identifier JSON properties. Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. @@ -1672,11 +1768,35 @@ def get_scenario_results( if objective_target_endpoint: # Use database-specific JSON query method - conditions.append(self._get_scenario_result_target_endpoint_condition(endpoint=objective_target_endpoint)) + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path=TargetIdentifierProperty.ENDPOINT, + value_to_match=objective_target_endpoint, + partial_match=True, + ) + ) if objective_target_model_name: # Use database-specific JSON query method - conditions.append(self._get_scenario_result_target_model_condition(model_name=objective_target_model_name)) + conditions.append( + self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path=TargetIdentifierProperty.MODEL_NAME, + value_to_match=objective_target_model_name, + partial_match=True, + ) + ) + + if objective_target_identifier_filter: + conditions.append( + 
self._get_condition_json_property_match( + json_column=ScenarioResultEntry.objective_target_identifier, + property_path=objective_target_identifier_filter.property_path, + value_to_match=objective_target_identifier_filter.value_to_match, + partial_match=objective_target_identifier_filter.partial_match, + ) + ) try: entries: Sequence[ScenarioResultEntry] = self._query_entries( diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 7bd05b4f82..a41dbffc90 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -177,15 +177,6 @@ def _get_message_pieces_prompt_metadata_conditions( condition = text(json_conditions).bindparams(**{key: str(value) for key, value in prompt_metadata.items()}) return [condition] - def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: - """ - Generate SQLAlchemy filter conditions for filtering by attack ID. - - Returns: - Any: A SQLAlchemy text condition with bound parameters. - """ - return text("JSON_EXTRACT(attack_identifier, '$.hash') = :attack_id").bindparams(attack_id=str(attack_id)) - def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: """ Generate SQLAlchemy filter conditions for filtering seed prompts by metadata. 
@@ -199,6 +190,84 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) # Note: We do NOT convert values to string here, to allow integer comparison in JSON return text(json_conditions).bindparams(**dict(metadata.items())) + def _get_condition_json_property_match( + self, + *, + json_column: Any, + property_path: str, + value_to_match: str, + partial_match: bool = False, + ) -> Any: + extracted_value = func.json_extract(json_column, property_path) + if partial_match: + return func.lower(extracted_value).like(f"%{value_to_match.lower()}%") + return extracted_value == value_to_match + + def _get_condition_json_array_match( + self, + *, + json_column: Any, + property_path: str, + array_to_match: Sequence[str], + case_insensitive: bool = False, + ) -> Any: + array_expr = func.json_extract(json_column, property_path) + if len(array_to_match) == 0: + return or_( + json_column.is_(None), + array_expr.is_(None), + array_expr == "[]", + ) + + table_name = json_column.class_.__tablename__ + column_name = json_column.key + value_expression = "json_extract(value, '$.class_name')" + if case_insensitive: + value_expression = f"LOWER({value_expression})" + + conditions = [] + for index, match_value in enumerate(array_to_match): + param_name = f"match_value_{index}" + bind_params: dict[str, str] = { + "property_path": property_path, + param_name: match_value.lower() if case_insensitive else match_value, + } + conditions.append( + text( + f'''EXISTS(SELECT 1 FROM json_each( + json_extract("{table_name}".{column_name}, :property_path)) + WHERE {value_expression} = :{param_name})''' + ).bindparams(**bind_params) + ) + return and_(*conditions) + + def _get_unique_json_array_values( + self, + *, + json_column: Any, + path_to_array: str, + sub_path: str | None = None, + ) -> list[str]: + with closing(self.get_session()) as session: + if sub_path is None: + property_expr = func.json_extract(json_column, path_to_array) + rows = 
session.query(property_expr).filter(property_expr.isnot(None)).distinct().all() + else: + table_name = json_column.class_.__tablename__ + column_name = json_column.key + rows = session.execute( + text( + f'''SELECT DISTINCT json_extract(j.value, :sub_path) AS value + FROM "{table_name}", + json_each(json_extract("{table_name}".{column_name}, :path_to_array)) AS j + WHERE json_extract(j.value, :sub_path) IS NOT NULL''' + ).bindparams( + path_to_array=path_to_array, + sub_path=sub_path, + ) + ).fetchall() + return sorted(row[0] for row in rows) + def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. @@ -526,97 +595,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) return labels_subquery # noqa: RET504 - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: - """ - SQLite implementation for filtering AttackResults by attack class. - Uses json_extract() on the atomic_attack_identifier JSON column. - - Returns: - Any: A SQLAlchemy condition for filtering by attack class. - """ - return ( - func.json_extract(AttackResultEntry.atomic_attack_identifier, "$.children.attack.class_name") - == attack_class - ) - - def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: - """ - SQLite implementation for filtering AttackResults by converter classes. - - Uses json_extract() on the atomic_attack_identifier JSON column. - - When converter_classes is empty, matches attacks with no converters - (children.attack.children.request_converters is absent or null in the JSON). - When non-empty, uses json_each() to check all specified classes are present - (AND logic, case-insensitive). - - Returns: - Any: A SQLAlchemy condition for filtering by converter classes. 
- """ - if len(converter_classes) == 0: - # Explicitly "no converters": match attacks where the converter list - # is absent, null, or empty in the stored JSON. - converter_json = func.json_extract( - AttackResultEntry.atomic_attack_identifier, - "$.children.attack.children.request_converters", - ) - return or_( - AttackResultEntry.atomic_attack_identifier.is_(None), - converter_json.is_(None), - converter_json == "[]", - ) - - conditions = [] - for i, cls in enumerate(converter_classes): - param_name = f"conv_cls_{i}" - conditions.append( - text( - f"""EXISTS(SELECT 1 FROM json_each( - json_extract("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters')) - WHERE LOWER(json_extract(value, '$.class_name')) = :{param_name})""" - ).bindparams(**{param_name: cls.lower()}) - ) - return and_(*conditions) - - def get_unique_attack_class_names(self) -> list[str]: - """ - SQLite implementation: extract unique class_name values from - the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique attack class name strings. - """ - with closing(self.get_session()) as session: - class_name_expr = func.json_extract( - AttackResultEntry.atomic_attack_identifier, "$.children.attack.class_name" - ) - rows = session.query(class_name_expr).filter(class_name_expr.isnot(None)).distinct().all() - return sorted(row[0] for row in rows) - - def get_unique_converter_class_names(self) -> list[str]: - """ - SQLite implementation: extract unique converter class_name values - from the children.attack.children.request_converters array in the - atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique converter class name strings. 
- """ - with closing(self.get_session()) as session: - rows = session.execute( - text( - """SELECT DISTINCT json_extract(j.value, '$.class_name') AS cls - FROM "AttackResultEntries", - json_each( - json_extract("AttackResultEntries".atomic_attack_identifier, - '$.children.attack.children.request_converters') - ) AS j - WHERE cls IS NOT NULL""" - ) - ).fetchall() - return sorted(row[0] for row in rows) - def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: """ SQLite implementation: lightweight aggregate stats per conversation. @@ -710,27 +688,3 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any return and_( *[func.json_extract(ScenarioResultEntry.labels, f"$.{key}") == value for key, value in labels.items()] ) - - def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> Any: - """ - SQLite implementation for filtering ScenarioResults by target endpoint. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by target endpoint. - """ - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.endpoint")).like( - f"%{endpoint.lower()}%" - ) - - def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any: - """ - SQLite implementation for filtering ScenarioResults by target model name. - Uses json_extract() function specific to SQLite. - - Returns: - Any: A SQLAlchemy subquery for filtering by target model name. 
- """ - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.model_name")).like( - f"%{model_name.lower()}%" - ) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 91367c3a1c..de238952f4 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -9,6 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import AttackIdentifierFilter, AttackIdentifierProperty from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( AttackOutcome, @@ -1352,3 +1353,55 @@ def test_get_unique_converter_class_names_skips_no_converters(sqlite_instance: M result = sqlite_instance.get_unique_converter_class_names() assert result == ["Base64Converter"] + + +def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: MemoryInterface): + """Test filtering attack results by AttackIdentifierFilter with hash.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) + + # Filter by hash of ar1's attack identifier + results = sqlite_instance.get_attack_results( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.HASH, + value_to_match=ar1.atomic_attack_identifier.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].conversation_id == "conv_1" + + +def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instance: MemoryInterface): + """Test filtering attack results by AttackIdentifierFilter with 
class_name.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") + ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) + + # Filter by partial attack class name + results = sqlite_instance.get_attack_results( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.ATTACK_CLASS_NAME, + value_to_match="Crescendo", + partial_match=True, + ), + ) + assert len(results) == 2 + assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} + + +def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance: MemoryInterface): + """Test that AttackIdentifierFilter returns empty when nothing matches.""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.HASH, + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 67a4292f87..457169b911 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,6 +14,12 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry +from pyrit.memory.identifier_filters import ( + AttackIdentifierFilter, + AttackIdentifierProperty, + TargetIdentifierFilter, + TargetIdentifierProperty, +) from pyrit.models import ( Message, MessagePiece, @@ -1248,3 +1254,112 @@ def 
test_get_request_from_response_raises_error_for_sequence_less_than_one(sqlit with pytest.raises(ValueError, match="The provided request does not have a preceding request \\(sequence < 1\\)."): sqlite_instance.get_request_from_response(response=response_without_request) + + +def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryInterface): + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello 1", + attack_identifier=attack1.get_identifier(), + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="assistant", + original_value="Hello 2", + attack_identifier=attack2.get_identifier(), + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by exact attack hash + results = sqlite_instance.get_message_pieces( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.HASH, + value_to_match=attack1.get_identifier().hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello 1" + + # No match + results = sqlite_instance.get_message_pieces( + attack_identifier_filter=AttackIdentifierFilter( + property_path=AttackIdentifierProperty.HASH, + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 + + +def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryInterface): + target_id_1 = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="AzureChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + 
original_value="Hello OpenAI", + prompt_target_identifier=target_id_1, + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="Hello Azure", + prompt_target_identifier=target_id_2, + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by target hash + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.HASH, + value_to_match=target_id_1.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello OpenAI" + + # Filter by endpoint partial match + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.ENDPOINT, + value_to_match="openai", + partial_match=True, + ), + ) + assert len(results) == 1 + assert results[0].original_value == "Hello OpenAI" + + # No match + results = sqlite_instance.get_message_pieces( + prompt_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.HASH, + value_to_match="nonexistent", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index e513e8b873..51b64a819b 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -9,6 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import TargetIdentifierFilter, TargetIdentifierProperty from pyrit.models import ( AttackOutcome, AttackResult, @@ -645,3 +646,116 @@ def test_combined_filters(sqlite_instance: MemoryInterface): assert len(results) == 1 assert results[0].scenario_identifier.pyrit_version == "0.5.0" assert "gpt-4" in 
results[0].objective_target_identifier.params["model_name"] + + +def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): + """Test filtering scenario results by TargetIdentifierFilter with hash.""" + target_id_1 = ComponentIdentifier( + class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="Azure", + class_module="test", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + attack_result1 = create_attack_result("conv_1", "Objective 1") + attack_result2 = create_attack_result("conv_2", "Objective 2") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + objective_target_identifier=target_id_1, + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + scenario2 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + objective_target_identifier=target_id_2, + attack_results={"Attack2": [attack_result2]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) + + # Filter by target hash + results = sqlite_instance.get_scenario_results( + objective_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.HASH, + value_to_match=target_id_1.hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].scenario_identifier.name == "Scenario OpenAI" + + +def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instance: MemoryInterface): + """Test filtering scenario results by TargetIdentifierFilter with endpoint.""" + target_id_1 = ComponentIdentifier( + 
class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + ) + target_id_2 = ComponentIdentifier( + class_name="Azure", + class_module="test", + params={"endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + ) + + attack_result1 = create_attack_result("conv_1", "Objective 1") + attack_result2 = create_attack_result("conv_2", "Objective 2") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario OpenAI", scenario_version=1), + objective_target_identifier=target_id_1, + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + scenario2 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Scenario Azure", scenario_version=1), + objective_target_identifier=target_id_2, + attack_results={"Attack2": [attack_result2]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1, scenario2]) + + # Filter by endpoint partial match + results = sqlite_instance.get_scenario_results( + objective_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.ENDPOINT, + value_to_match="openai", + partial_match=True, + ), + ) + assert len(results) == 1 + assert results[0].scenario_identifier.name == "Scenario OpenAI" + + +def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instance: MemoryInterface): + """Test that TargetIdentifierFilter returns empty when nothing matches.""" + attack_result1 = create_attack_result("conv_1", "Objective 1") + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1]) + + scenario1 = ScenarioResult( + scenario_identifier=ScenarioIdentifier(name="Test Scenario", scenario_version=1), + objective_target_identifier=ComponentIdentifier( + 
class_name="OpenAI", + class_module="test", + params={"endpoint": "https://api.openai.com"}, + ), + attack_results={"Attack1": [attack_result1]}, + objective_scorer_identifier=get_mock_scorer_identifier(), + ) + sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1]) + + results = sqlite_instance.get_scenario_results( + objective_target_identifier_filter=TargetIdentifierFilter( + property_path=TargetIdentifierProperty.HASH, + value_to_match="nonexistent_hash", + partial_match=False, + ), + ) + assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 6087af1418..e9945bfc2e 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -13,6 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry +from pyrit.memory.identifier_filters import ScorerIdentifierFilter, ScorerIdentifierProperty from pyrit.models import ( MessagePiece, Score, @@ -227,3 +228,76 @@ async def test_get_seeds_no_filters(sqlite_instance: MemoryInterface): assert len(result) == 2 assert result[0].value == "prompt1" assert result[1].value == "prompt2" + + +def test_get_scores_by_scorer_identifier_filter( + sqlite_instance: MemoryInterface, sample_conversation_entries: Sequence[PromptMemoryEntry], +): + prompt_id = sample_conversation_entries[0].id + sqlite_instance._insert_entries(entries=sample_conversation_entries) + + score_a = Score( + score_value="0.9", + score_value_description="High", + score_type="float_scale", + score_category=["cat_a"], + score_rationale="Rationale A", + score_metadata={}, + scorer_class_identifier=_test_scorer_id("ScorerAlpha"), + message_piece_id=prompt_id, + ) + score_b = Score( + score_value="0.1", + 
score_value_description="Low", + score_type="float_scale", + score_category=["cat_b"], + score_rationale="Rationale B", + score_metadata={}, + scorer_class_identifier=_test_scorer_id("ScorerBeta"), + message_piece_id=prompt_id, + ) + + sqlite_instance.add_scores_to_memory(scores=[score_a, score_b]) + + # Filter by exact class_name match + results = sqlite_instance.get_scores( + scorer_identifier_filter=ScorerIdentifierFilter( + property_path=ScorerIdentifierProperty.CLASS_NAME, + value_to_match="ScorerAlpha", + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].score_value == "0.9" + + # Filter by partial class_name match + results = sqlite_instance.get_scores( + scorer_identifier_filter=ScorerIdentifierFilter( + property_path=ScorerIdentifierProperty.CLASS_NAME, + value_to_match="Scorer", + partial_match=True, + ), + ) + assert len(results) == 2 + + # Filter by hash + scorer_hash = score_a.scorer_class_identifier.hash + results = sqlite_instance.get_scores( + scorer_identifier_filter=ScorerIdentifierFilter( + property_path=ScorerIdentifierProperty.HASH, + value_to_match=scorer_hash, + partial_match=False, + ), + ) + assert len(results) == 1 + assert results[0].score_value == "0.9" + + # No match + results = sqlite_instance.get_scores( + scorer_identifier_filter=ScorerIdentifierFilter( + property_path=ScorerIdentifierProperty.CLASS_NAME, + value_to_match="NonExistent", + partial_match=False, + ), + ) + assert len(results) == 0 From 01aaa159e559247699bee95217923722d6955d46 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 1 Apr 2026 13:56:45 -0700 Subject: [PATCH 02/30] forgot formatting --- pyrit/memory/__init__.py | 11 ++++++++++- pyrit/memory/azure_sql_memory.py | 12 +++++------- pyrit/memory/memory_interface.py | 16 ++++++++-------- pyrit/memory/sqlite_memory.py | 8 ++++---- .../memory_interface/test_interface_scores.py | 3 ++- 5 files changed, 29 insertions(+), 21 deletions(-) diff --git a/pyrit/memory/__init__.py 
b/pyrit/memory/__init__.py index 102a1f8607..a22469de00 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -7,9 +7,18 @@ This package defines the core `MemoryInterface` and concrete implementations for different storage backends. """ -from pyrit.memory.identifier_filters import AttackIdentifierFilter, AttackIdentifierProperty, ConverterIdentifierFilter, ConverterIdentifierProperty, ScorerIdentifierFilter, ScorerIdentifierProperty, TargetIdentifierFilter, TargetIdentifierProperty from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory +from pyrit.memory.identifier_filters import ( + AttackIdentifierFilter, + AttackIdentifierProperty, + ConverterIdentifierFilter, + ConverterIdentifierProperty, + ScorerIdentifierFilter, + ScorerIdentifierProperty, + TargetIdentifierFilter, + TargetIdentifierProperty, +) from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_interface import MemoryInterface diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 48ae2c5df2..fc7a951f1e 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -319,10 +319,10 @@ def _get_condition_json_property_match( return text( f"""ISJSON("{table_name}".{column_name}) = 1 AND LOWER(JSON_VALUE("{table_name}".{column_name}, :property_path)) {"LIKE" if partial_match else "="} :match_property_value""" # noqa: E501 - ).bindparams( - property_path=property_path, - match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match, - ) + ).bindparams( + property_path=property_path, + match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match, + ) # The above return statement already handles both partial and exact matches # The following code is now unreachable and can be removed @@ -360,9 +360,7 @@ def _get_condition_json_array_match( 
bindparams_dict[param_name] = match_value.lower() if case_insensitive else match_value combined = " AND ".join(conditions) - return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams( - **bindparams_dict - ) + return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) def _get_unique_json_property_values( self, diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 5bc1f4ad3e..0fcdfc6f3c 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -374,7 +374,7 @@ def get_unique_attack_class_names(self) -> list[str]: """ return self._get_unique_json_array_values( json_column=AttackResultEntry.atomic_attack_identifier, - path_to_array=AttackIdentifierProperty.ATTACK_CLASS_NAME + path_to_array=AttackIdentifierProperty.ATTACK_CLASS_NAME, ) def get_unique_converter_class_names(self) -> list[str]: @@ -638,7 +638,7 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. - attack_identifier_filter (Optional[AttackIdentifierFilter], optional): + attack_identifier_filter (Optional[AttackIdentifierFilter], optional): An AttackIdentifierFilter object that allows filtering by various attack identifier JSON properties. Defaults to None. 
prompt_target_identifier_filter (Optional[TargetIdentifierFilter], optional): @@ -658,7 +658,7 @@ def get_message_pieces( self._get_condition_json_property_match( json_column=PromptMemoryEntry.attack_identifier, property_path=AttackIdentifierProperty.HASH, - value_to_match=str(attack_id) + value_to_match=str(attack_id), ) ) if role: @@ -1770,12 +1770,12 @@ def get_scenario_results( # Use database-specific JSON query method conditions.append( self._get_condition_json_property_match( - json_column=ScenarioResultEntry.objective_target_identifier, - property_path=TargetIdentifierProperty.ENDPOINT, - value_to_match=objective_target_endpoint, - partial_match=True, + json_column=ScenarioResultEntry.objective_target_identifier, + property_path=TargetIdentifierProperty.ENDPOINT, + value_to_match=objective_target_endpoint, + partial_match=True, + ) ) - ) if objective_target_model_name: # Use database-specific JSON query method diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index a41dbffc90..3e94e0e2ea 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -234,9 +234,9 @@ def _get_condition_json_array_match( } conditions.append( text( - f'''EXISTS(SELECT 1 FROM json_each( + f"""EXISTS(SELECT 1 FROM json_each( json_extract("{table_name}".{column_name}, :property_path)) - WHERE {value_expression} = :{param_name})''' + WHERE {value_expression} = :{param_name})""" ).bindparams(**bind_params) ) return and_(*conditions) @@ -257,10 +257,10 @@ def _get_unique_json_array_values( column_name = json_column.key rows = session.execute( text( - f'''SELECT DISTINCT json_extract(j.value, :sub_path) AS value + f"""SELECT DISTINCT json_extract(j.value, :sub_path) AS value FROM "{table_name}", json_each(json_extract("{table_name}".{column_name}, :path_to_array)) AS j - WHERE json_extract(j.value, :sub_path) IS NOT NULL''' + WHERE json_extract(j.value, :sub_path) IS NOT NULL""" ).bindparams( path_to_array=path_to_array, sub_path=sub_path, 
diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index e9945bfc2e..bb9478c3b6 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -231,7 +231,8 @@ async def test_get_seeds_no_filters(sqlite_instance: MemoryInterface): def test_get_scores_by_scorer_identifier_filter( - sqlite_instance: MemoryInterface, sample_conversation_entries: Sequence[PromptMemoryEntry], + sqlite_instance: MemoryInterface, + sample_conversation_entries: Sequence[PromptMemoryEntry], ): prompt_id = sample_conversation_entries[0].id sqlite_instance._insert_entries(entries=sample_conversation_entries) From e77b43c0b604e162242791578df6611f44376a5b Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 1 Apr 2026 14:05:36 -0700 Subject: [PATCH 03/30] return str --- pyrit/memory/identifier_filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 8792f03241..10aba39aa5 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -12,7 +12,7 @@ class _StrEnum(str, Enum): """Base class that mimics StrEnum behavior for Python < 3.11.""" def __str__(self) -> str: - return self.value + return str(self.value) T = TypeVar("T", bound=_StrEnum) From a06b5060ca25add903bef9055f5435dfe9a05779 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 1 Apr 2026 14:08:55 -0700 Subject: [PATCH 04/30] fix method name --- pyrit/memory/azure_sql_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index fc7a951f1e..cf9c5f6d49 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -362,7 +362,7 @@ def _get_condition_json_array_match( combined = " AND ".join(conditions) return 
text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) - def _get_unique_json_property_values( + def _get_unique_json_array_values( self, *, json_column: Any, From 9d3cb5f378ea1f30d163a798aed57f84875c4964 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 09:39:51 -0700 Subject: [PATCH 05/30] add back public methods --- pyrit/memory/azure_sql_memory.py | 21 +++++++++++++++++++++ pyrit/memory/sqlite_memory.py | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index cf9c5f6d49..8941078e4f 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -571,6 +571,27 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] """ self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) + def get_unique_attack_class_names(self) -> list[str]: + """ + Azure SQL implementation: extract unique class_name values from + the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique attack class name strings. + """ + return super().get_unique_attack_class_names() + + def get_unique_converter_class_names(self) -> list[str]: + """ + Azure SQL implementation: extract unique converter class_name values + from the children.attack.children.request_converters array + in the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique converter class name strings. + """ + return super().get_unique_converter_class_names() + def dispose_engine(self) -> None: """ Dispose the engine and clean up resources. 
diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 3e94e0e2ea..f76f300a3d 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -268,6 +268,27 @@ def _get_unique_json_array_values( ).fetchall() return sorted(row[0] for row in rows) + def get_unique_attack_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique class_name values from + the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique attack class name strings. + """ + return super().get_unique_attack_class_names() + + def get_unique_converter_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique converter class_name values + from the children.attack.children.request_converters array in the + atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique converter class name strings. + """ + return super().get_unique_converter_class_names() + def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. 
From 5389a9f4c85cabbf987873077ae97da6c2c1b97f Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 11:42:13 -0700 Subject: [PATCH 06/30] custom subpath for array match and make all matches case insensitive --- pyrit/memory/azure_sql_memory.py | 14 +++++--------- pyrit/memory/memory_interface.py | 6 +++--- pyrit/memory/sqlite_memory.py | 14 ++++++-------- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 8941078e4f..5ff9710ceb 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -321,18 +321,16 @@ def _get_condition_json_property_match( AND LOWER(JSON_VALUE("{table_name}".{column_name}, :property_path)) {"LIKE" if partial_match else "="} :match_property_value""" # noqa: E501 ).bindparams( property_path=property_path, - match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match, + match_property_value=f"%{value_to_match.lower()}%", ) - # The above return statement already handles both partial and exact matches - # The following code is now unreachable and can be removed def _get_condition_json_array_match( self, *, json_column: Any, property_path: str, - array_to_match: Sequence[str], - case_insensitive: bool = False, + sub_path: str | None = None, + array_to_match: Sequence[str] ) -> Any: table_name = json_column.class_.__tablename__ column_name = json_column.key @@ -343,9 +341,7 @@ def _get_condition_json_array_match( OR JSON_QUERY("{table_name}".{column_name}, :property_path) = '[]')""" ).bindparams(property_path=property_path) - value_expression = "JSON_VALUE(value, '$.class_name')" - if case_insensitive: - value_expression = f"LOWER({value_expression})" + value_expression = f"LOWER(JSON_VALUE(value, '{sub_path}'))" if sub_path else "LOWER(value)" conditions = [] bindparams_dict: dict[str, str] = {"property_path": property_path} @@ -357,7 +353,7 @@ def _get_condition_json_array_match( :property_path)) 
WHERE {value_expression} = :{param_name})""" ) - bindparams_dict[param_name] = match_value.lower() if case_insensitive else match_value + bindparams_dict[param_name] = match_value.lower() combined = " AND ".join(conditions) return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 0fcdfc6f3c..74f99f0217 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -149,8 +149,8 @@ def _get_condition_json_array_match( *, json_column: InstrumentedAttribute[Any], property_path: str, + sub_path: Optional[str] = None, array_to_match: Sequence[str], - case_insensitive: bool = False, ) -> Any: """ Return a database-specific condition for matching an array at a given path within a JSON object. @@ -158,10 +158,10 @@ def _get_condition_json_array_match( Args: json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. property_path (str): The JSON path for the target array. + sub_path (Optional[str]): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. For a match, ALL values in this array must be present in the JSON array. If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. - case_insensitive (bool): Whether string comparison should ignore casing. Returns: Any: A database-specific SQLAlchemy condition. 
@@ -1500,8 +1500,8 @@ def get_attack_results( self._get_condition_json_array_match( json_column=AttackResultEntry.atomic_attack_identifier, property_path=AttackIdentifierProperty.REQUEST_CONVERTERS, + sub_path=ConverterIdentifierProperty.CLASS_NAME, array_to_match=converter_classes, - case_insensitive=True, ) ) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index f76f300a3d..fa9487055e 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -198,18 +198,18 @@ def _get_condition_json_property_match( value_to_match: str, partial_match: bool = False, ) -> Any: - extracted_value = func.json_extract(json_column, property_path) + extracted_value = func.lower(func.json_extract(json_column, property_path)) if partial_match: - return func.lower(extracted_value).like(f"%{value_to_match.lower()}%") - return extracted_value == value_to_match + return extracted_value.like(f"%{value_to_match.lower()}%") + return extracted_value == value_to_match.lower() def _get_condition_json_array_match( self, *, json_column: Any, property_path: str, + sub_path: str | None = None, array_to_match: Sequence[str], - case_insensitive: bool = False, ) -> Any: array_expr = func.json_extract(json_column, property_path) if len(array_to_match) == 0: @@ -221,16 +221,14 @@ def _get_condition_json_array_match( table_name = json_column.class_.__tablename__ column_name = json_column.key - value_expression = "json_extract(value, '$.class_name')" - if case_insensitive: - value_expression = f"LOWER({value_expression})" + value_expression = f"LOWER(json_extract(value, '{sub_path}'))" if sub_path else "LOWER(value)" conditions = [] for index, match_value in enumerate(array_to_match): param_name = f"match_value_{index}" bind_params: dict[str, str] = { "property_path": property_path, - param_name: match_value.lower() if case_insensitive else match_value, + param_name: match_value.lower(), } conditions.append( text( From 3fa071367a9a25486ec842fcc419ab3c4fa58027 
Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 11:47:45 -0700 Subject: [PATCH 07/30] format --- pyrit/memory/azure_sql_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 5ff9710ceb..916f64508d 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -330,7 +330,7 @@ def _get_condition_json_array_match( json_column: Any, property_path: str, sub_path: str | None = None, - array_to_match: Sequence[str] + array_to_match: Sequence[str], ) -> Any: table_name = json_column.class_.__tablename__ column_name = json_column.key From 24f61d1ecb73b3c171c29072a497ef9c67981ab6 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 14:26:39 -0700 Subject: [PATCH 08/30] allow free-form paths in identifier filters --- pyrit/memory/__init__.py | 20 +---- pyrit/memory/identifier_filters.py | 82 +------------------ pyrit/memory/memory_interface.py | 55 ++++++------- .../test_interface_attack_results.py | 23 ++---- .../test_interface_prompts.py | 27 +++--- .../test_interface_scenario_results.py | 18 ++-- .../memory_interface/test_interface_scores.py | 18 ++-- 7 files changed, 64 insertions(+), 179 deletions(-) diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index a22469de00..6098122d7d 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -9,16 +9,7 @@ from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory -from pyrit.memory.identifier_filters import ( - AttackIdentifierFilter, - AttackIdentifierProperty, - ConverterIdentifierFilter, - ConverterIdentifierProperty, - ScorerIdentifierFilter, - ScorerIdentifierProperty, - TargetIdentifierFilter, - TargetIdentifierProperty, -) +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import 
MemoryExporter from pyrit.memory.memory_interface import MemoryInterface @@ -27,10 +18,6 @@ __all__ = [ "AttackResultEntry", - "AttackIdentifierFilter", - "AttackIdentifierProperty", - "ConverterIdentifierFilter", - "ConverterIdentifierProperty", "AzureSQLMemory", "CentralMemory", "SQLiteMemory", @@ -39,9 +26,6 @@ "MemoryEmbedding", "MemoryExporter", "PromptMemoryEntry", - "ScorerIdentifierFilter", - "ScorerIdentifierProperty", "SeedEntry", - "TargetIdentifierFilter", - "TargetIdentifierProperty", + "IdentifierFilter", ] diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 10aba39aa5..74c62c877a 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -1,95 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from abc import ABC from dataclasses import dataclass -from enum import Enum -from typing import Generic, TypeVar - - -# TODO: if/when we move to python 3.11+, we can replace this with StrEnum -class _StrEnum(str, Enum): - """Base class that mimics StrEnum behavior for Python < 3.11.""" - - def __str__(self) -> str: - return str(self.value) - - -T = TypeVar("T", bound=_StrEnum) - - -class IdentifierProperty(_StrEnum): - """Allowed JSON paths for identifier filtering.""" - - HASH = "$.hash" - CLASS_NAME = "$.class_name" @dataclass(frozen=True) -class IdentifierFilter(ABC, Generic[T]): +class IdentifierFilter: """Immutable filter definition for matching JSON-backed identifier properties.""" - property_path: T | str + property_path: str value_to_match: str partial_match: bool = False def __post_init__(self) -> None: """Normalize and validate the configured property path.""" object.__setattr__(self, "property_path", str(self.property_path)) - - -class AttackIdentifierProperty(_StrEnum): - """Allowed JSON paths for attack identifier filtering.""" - - HASH = "$.hash" - ATTACK_CLASS_NAME = "$.children.attack.class_name" - REQUEST_CONVERTERS = 
"$.children.attack.children.request_converters" - - -class TargetIdentifierProperty(_StrEnum): - """Allowed JSON paths for target identifier filtering.""" - - HASH = "$.hash" - ENDPOINT = "$.endpoint" - MODEL_NAME = "$.model_name" - - -class ConverterIdentifierProperty(_StrEnum): - """Allowed JSON paths for converter identifier filtering.""" - - HASH = "$.hash" - CLASS_NAME = "$.class_name" - - -class ScorerIdentifierProperty(_StrEnum): - """Allowed JSON paths for scorer identifier filtering.""" - - HASH = "$.hash" - CLASS_NAME = "$.class_name" - - -@dataclass(frozen=True) -class AttackIdentifierFilter(IdentifierFilter[AttackIdentifierProperty]): - """ - Immutable filter definition for matching JSON-backed attack identifier properties. - - Args: - property_path: The JSON path of the property to filter on. - value_to_match: The value to match against the property. - partial_match: Whether to allow partial matches (default: False). - """ - - -@dataclass(frozen=True) -class TargetIdentifierFilter(IdentifierFilter[TargetIdentifierProperty]): - """Immutable filter definition for matching JSON-backed target identifier properties.""" - - -@dataclass(frozen=True) -class ConverterIdentifierFilter(IdentifierFilter[ConverterIdentifierProperty]): - """Immutable filter definition for matching JSON-backed converter identifier properties.""" - - -@dataclass(frozen=True) -class ScorerIdentifierFilter(IdentifierFilter[ScorerIdentifierProperty]): - """Immutable filter definition for matching JSON-backed scorer identifier properties.""" diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 74f99f0217..1ef99789d8 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -19,14 +19,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from pyrit.common.path import DB_DATA_PATH -from pyrit.memory.identifier_filters import ( - AttackIdentifierFilter, - AttackIdentifierProperty, - ConverterIdentifierProperty, - 
ScorerIdentifierFilter, - TargetIdentifierFilter, - TargetIdentifierProperty, -) +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_embedding import ( MemoryEmbedding, default_memory_embedding_factory, @@ -374,7 +367,7 @@ def get_unique_attack_class_names(self) -> list[str]: """ return self._get_unique_json_array_values( json_column=AttackResultEntry.atomic_attack_identifier, - path_to_array=AttackIdentifierProperty.ATTACK_CLASS_NAME, + path_to_array="$.children.attack.class_name", ) def get_unique_converter_class_names(self) -> list[str]: @@ -389,8 +382,8 @@ def get_unique_converter_class_names(self) -> list[str]: """ return self._get_unique_json_array_values( json_column=AttackResultEntry.atomic_attack_identifier, - path_to_array=AttackIdentifierProperty.REQUEST_CONVERTERS, - sub_path=ConverterIdentifierProperty.CLASS_NAME, + path_to_array="$.children.attack.children.request_converters", + sub_path="$.class_name", ) @abc.abstractmethod @@ -447,7 +440,7 @@ def get_scores( score_category: Optional[str] = None, sent_after: Optional[datetime] = None, sent_before: Optional[datetime] = None, - scorer_identifier_filter: Optional[ScorerIdentifierFilter] = None, + scorer_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -458,7 +451,7 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. - scorer_identifier_filter (Optional[ScorerIdentifierFilter]): A ScorerIdentifierFilter object that + scorer_identifier_filter (Optional[IdentifierFilter]): An IdentifierFilter object that allows filtering by various scorer identifier JSON properties. Defaults to None. 
Returns: @@ -615,8 +608,8 @@ def get_message_pieces( data_type: Optional[str] = None, not_data_type: Optional[str] = None, converted_value_sha256: Optional[Sequence[str]] = None, - attack_identifier_filter: Optional[AttackIdentifierFilter] = None, - prompt_target_identifier_filter: Optional[TargetIdentifierFilter] = None, + attack_identifier_filter: Optional[IdentifierFilter] = None, + prompt_target_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -638,11 +631,11 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. - attack_identifier_filter (Optional[AttackIdentifierFilter], optional): - An AttackIdentifierFilter object that + attack_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various attack identifier JSON properties. Defaults to None. - prompt_target_identifier_filter (Optional[TargetIdentifierFilter], optional): - A TargetIdentifierFilter object that + prompt_target_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various target identifier JSON properties. Defaults to None. 
Returns: @@ -657,7 +650,7 @@ def get_message_pieces( conditions.append( self._get_condition_json_property_match( json_column=PromptMemoryEntry.attack_identifier, - property_path=AttackIdentifierProperty.HASH, + property_path="$.hash", value_to_match=str(attack_id), ) ) @@ -1431,7 +1424,7 @@ def get_attack_results( converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, - attack_identifier_filter: Optional[AttackIdentifierFilter] = None, + attack_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -1459,8 +1452,8 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. - attack_identifier_filter (Optional[AttackIdentifierFilter], optional): - An AttackIdentifierFilter object that allows filtering by various attack identifier + attack_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various attack identifier JSON properties. Defaults to None. 
Returns: @@ -1488,7 +1481,7 @@ def get_attack_results( conditions.append( self._get_condition_json_property_match( json_column=AttackResultEntry.atomic_attack_identifier, - property_path=AttackIdentifierProperty.ATTACK_CLASS_NAME, + property_path="$.children.attack.class_name", value_to_match=attack_class, ) ) @@ -1499,8 +1492,8 @@ def get_attack_results( conditions.append( self._get_condition_json_array_match( json_column=AttackResultEntry.atomic_attack_identifier, - property_path=AttackIdentifierProperty.REQUEST_CONVERTERS, - sub_path=ConverterIdentifierProperty.CLASS_NAME, + property_path="$.children.attack.children.request_converters", + sub_path="$.class_name", array_to_match=converter_classes, ) ) @@ -1705,7 +1698,7 @@ def get_scenario_results( labels: Optional[dict[str, str]] = None, objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, - objective_target_identifier_filter: Optional[TargetIdentifierFilter] = None, + objective_target_identifier_filter: Optional[IdentifierFilter] = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. @@ -1729,8 +1722,8 @@ def get_scenario_results( objective_target_model_name (Optional[str], optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. - objective_target_identifier_filter (Optional[TargetIdentifierFilter], optional): - A TargetIdentifierFilter object that allows filtering by various target identifier JSON properties. + objective_target_identifier_filter (Optional[IdentifierFilter], optional): + An IdentifierFilter object that allows filtering by various target identifier JSON properties. Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. 
@@ -1771,7 +1764,7 @@ def get_scenario_results( conditions.append( self._get_condition_json_property_match( json_column=ScenarioResultEntry.objective_target_identifier, - property_path=TargetIdentifierProperty.ENDPOINT, + property_path="$.endpoint", value_to_match=objective_target_endpoint, partial_match=True, ) @@ -1782,7 +1775,7 @@ def get_scenario_results( conditions.append( self._get_condition_json_property_match( json_column=ScenarioResultEntry.objective_target_identifier, - property_path=TargetIdentifierProperty.MODEL_NAME, + property_path="$.model_name", value_to_match=objective_target_model_name, partial_match=True, ) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index de238952f4..84cda0b409 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -9,7 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import MemoryInterface -from pyrit.memory.identifier_filters import AttackIdentifierFilter, AttackIdentifierProperty +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( AttackOutcome, @@ -1176,15 +1176,6 @@ def test_get_attack_results_by_attack_class_no_match(sqlite_instance: MemoryInte assert len(results) == 0 -def test_get_attack_results_by_attack_class_case_sensitive(sqlite_instance: MemoryInterface): - """Test that attack_class filter is case-sensitive (exact match).""" - ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") - sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - - results = sqlite_instance.get_attack_results(attack_class="crescendoattack") - assert len(results) == 0 - - def 
test_get_attack_results_by_attack_class_no_identifier(sqlite_instance: MemoryInterface): """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_class filter.""" ar1 = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} @@ -1363,8 +1354,8 @@ def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: Me # Filter by hash of ar1's attack identifier results = sqlite_instance.get_attack_results( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.HASH, + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=ar1.atomic_attack_identifier.hash, partial_match=False, ), @@ -1382,8 +1373,8 @@ def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instan # Filter by partial attack class name results = sqlite_instance.get_attack_results( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.ATTACK_CLASS_NAME, + attack_identifier_filter=IdentifierFilter( + property_path="$.children.attack.class_name", value_to_match="Crescendo", partial_match=True, ), @@ -1398,8 +1389,8 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) results = sqlite_instance.get_attack_results( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.HASH, + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ), diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 457169b911..eec4d3d88a 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,12 +14,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack 
from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry -from pyrit.memory.identifier_filters import ( - AttackIdentifierFilter, - AttackIdentifierProperty, - TargetIdentifierFilter, - TargetIdentifierProperty, -) +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( Message, MessagePiece, @@ -1281,8 +1276,8 @@ def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryI # Filter by exact attack hash results = sqlite_instance.get_message_pieces( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.HASH, + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=attack1.get_identifier().hash, partial_match=False, ), @@ -1292,8 +1287,8 @@ def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryI # No match results = sqlite_instance.get_message_pieces( - attack_identifier_filter=AttackIdentifierFilter( - property_path=AttackIdentifierProperty.HASH, + attack_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ), @@ -1334,8 +1329,8 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI # Filter by target hash results = sqlite_instance.get_message_pieces( - prompt_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.HASH, + prompt_target_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=target_id_1.hash, partial_match=False, ), @@ -1345,8 +1340,8 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI # Filter by endpoint partial match results = sqlite_instance.get_message_pieces( - prompt_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.ENDPOINT, + prompt_target_identifier_filter=IdentifierFilter( + property_path="$.endpoint", value_to_match="openai", 
partial_match=True, ), @@ -1356,8 +1351,8 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI # No match results = sqlite_instance.get_message_pieces( - prompt_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.HASH, + prompt_target_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match="nonexistent", partial_match=False, ), diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 51b64a819b..ee2933b70a 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -9,7 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface -from pyrit.memory.identifier_filters import TargetIdentifierFilter, TargetIdentifierProperty +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( AttackOutcome, AttackResult, @@ -649,7 +649,7 @@ def test_combined_filters(sqlite_instance: MemoryInterface): def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: MemoryInterface): - """Test filtering scenario results by TargetIdentifierFilter with hash.""" + """Test filtering scenario results by identifier filter.""" target_id_1 = ComponentIdentifier( class_name="OpenAI", class_module="test", @@ -681,8 +681,8 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: # Filter by target hash results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.HASH, + objective_target_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=target_id_1.hash, partial_match=False, ), @@ -692,7 +692,7 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: def 
test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instance: MemoryInterface): - """Test filtering scenario results by TargetIdentifierFilter with endpoint.""" + """Test filtering scenario results by identifier filter with endpoint.""" target_id_1 = ComponentIdentifier( class_name="OpenAI", class_module="test", @@ -724,8 +724,8 @@ def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instan # Filter by endpoint partial match results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.ENDPOINT, + objective_target_identifier_filter=IdentifierFilter( + property_path="$.endpoint", value_to_match="openai", partial_match=True, ), @@ -752,8 +752,8 @@ def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instan sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1]) results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=TargetIdentifierFilter( - property_path=TargetIdentifierProperty.HASH, + objective_target_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ), diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index bb9478c3b6..2c90b18313 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -13,7 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry -from pyrit.memory.identifier_filters import ScorerIdentifierFilter, ScorerIdentifierProperty +from pyrit.memory.identifier_filters import IdentifierFilter from pyrit.models import ( MessagePiece, Score, @@ -262,8 +262,8 @@ def 
test_get_scores_by_scorer_identifier_filter( # Filter by exact class_name match results = sqlite_instance.get_scores( - scorer_identifier_filter=ScorerIdentifierFilter( - property_path=ScorerIdentifierProperty.CLASS_NAME, + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", value_to_match="ScorerAlpha", partial_match=False, ), @@ -273,8 +273,8 @@ def test_get_scores_by_scorer_identifier_filter( # Filter by partial class_name match results = sqlite_instance.get_scores( - scorer_identifier_filter=ScorerIdentifierFilter( - property_path=ScorerIdentifierProperty.CLASS_NAME, + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", value_to_match="Scorer", partial_match=True, ), @@ -284,8 +284,8 @@ def test_get_scores_by_scorer_identifier_filter( # Filter by hash scorer_hash = score_a.scorer_class_identifier.hash results = sqlite_instance.get_scores( - scorer_identifier_filter=ScorerIdentifierFilter( - property_path=ScorerIdentifierProperty.HASH, + scorer_identifier_filter=IdentifierFilter( + property_path="$.hash", value_to_match=scorer_hash, partial_match=False, ), @@ -295,8 +295,8 @@ def test_get_scores_by_scorer_identifier_filter( # No match results = sqlite_instance.get_scores( - scorer_identifier_filter=ScorerIdentifierFilter( - property_path=ScorerIdentifierProperty.CLASS_NAME, + scorer_identifier_filter=IdentifierFilter( + property_path="$.class_name", value_to_match="NonExistent", partial_match=False, ), From 39361af24d7ce24f0740fed0aaff6eed811ecea2 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Thu, 2 Apr 2026 14:54:41 -0700 Subject: [PATCH 09/30] unncecessary post-init --- pyrit/memory/identifier_filters.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 74c62c877a..122d89965b 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -11,7 +11,3 @@ class IdentifierFilter: property_path: str 
value_to_match: str partial_match: bool = False - - def __post_init__(self) -> None: - """Normalize and validate the configured property path.""" - object.__setattr__(self, "property_path", str(self.property_path)) From d2191a20aa6be2383300896f9c95587491cf700f Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Fri, 3 Apr 2026 13:52:08 -0700 Subject: [PATCH 10/30] fix exact match in azsql --- pyrit/memory/azure_sql_memory.py | 2 +- tests/unit/memory/test_azure_sql_memory.py | 23 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 916f64508d..96f8b35342 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -321,7 +321,7 @@ def _get_condition_json_property_match( AND LOWER(JSON_VALUE("{table_name}".{column_name}, :property_path)) {"LIKE" if partial_match else "="} :match_property_value""" # noqa: E501 ).bindparams( property_path=property_path, - match_property_value=f"%{value_to_match.lower()}%", + match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match.lower(), ) def _get_condition_json_array_match( diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index 5723800396..c9e4497625 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -326,6 +326,29 @@ def test_update_labels_by_conversation_id(memory_interface: AzureSQLMemory): assert updated_entry.labels["test1"] == "change" +@pytest.mark.parametrize( + "partial_match, expected_value", + [ + (False, "testvalue"), + (True, "%testvalue%"), + ], + ids=["exact_match", "partial_match"], +) +def test_get_condition_json_property_match_bind_params( + memory_interface: AzureSQLMemory, partial_match: bool, expected_value: str +): + condition = memory_interface._get_condition_json_property_match( + json_column=PromptMemoryEntry.labels, + property_path="$.key", 
+ value_to_match="TestValue", + partial_match=partial_match, + ) + # Extract the compiled bind parameters + params = condition.compile().params + assert params["match_property_value"] == expected_value + assert params["property_path"] == "$.key" + + def test_update_prompt_metadata_by_conversation_id(memory_interface: AzureSQLMemory): # Insert a test entry entry = PromptMemoryEntry( From fd22ab82ccd31bd7ab90e2f583273dbd67963a73 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Fri, 3 Apr 2026 14:03:40 -0700 Subject: [PATCH 11/30] use bind_param in new methods to avoid sql injection --- pyrit/memory/azure_sql_memory.py | 4 +++- pyrit/memory/sqlite_memory.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 96f8b35342..c87dfad8a5 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -341,10 +341,12 @@ def _get_condition_json_array_match( OR JSON_QUERY("{table_name}".{column_name}, :property_path) = '[]')""" ).bindparams(property_path=property_path) - value_expression = f"LOWER(JSON_VALUE(value, '{sub_path}'))" if sub_path else "LOWER(value)" + value_expression = "LOWER(JSON_VALUE(value, :sub_path))" if sub_path else "LOWER(value)" conditions = [] bindparams_dict: dict[str, str] = {"property_path": property_path} + if sub_path: + bindparams_dict["sub_path"] = sub_path for index, match_value in enumerate(array_to_match): param_name = f"match_value_{index}" diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index fa9487055e..3dac83658f 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -221,7 +221,7 @@ def _get_condition_json_array_match( table_name = json_column.class_.__tablename__ column_name = json_column.key - value_expression = f"LOWER(json_extract(value, '{sub_path}'))" if sub_path else "LOWER(value)" + value_expression = "LOWER(json_extract(value, :sub_path))" if sub_path else 
"LOWER(value)" conditions = [] for index, match_value in enumerate(array_to_match): @@ -230,6 +230,8 @@ def _get_condition_json_array_match( "property_path": property_path, param_name: match_value.lower(), } + if sub_path: + bind_params["sub_path"] = sub_path conditions.append( text( f"""EXISTS(SELECT 1 FROM json_each( From 227e7e582a5820d8fd69bd804544cfbbb2d451c9 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Fri, 3 Apr 2026 14:59:16 -0700 Subject: [PATCH 12/30] prevent text collisions using a uuid for bind_params --- pyrit/memory/azure_sql_memory.py | 60 +++++++++++++--------- pyrit/memory/memory_interface.py | 5 ++ pyrit/memory/sqlite_memory.py | 34 +++++++----- tests/unit/memory/test_azure_sql_memory.py | 10 ++-- 4 files changed, 67 insertions(+), 42 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index c87dfad8a5..bea8aa9915 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -313,16 +313,19 @@ def _get_condition_json_property_match( value_to_match: str, partial_match: bool = False, ) -> Any: + uid = self._uid() table_name = json_column.class_.__tablename__ column_name = json_column.key + pp_param = f"pp_{uid}" + mv_param = f"mv_{uid}" return text( f"""ISJSON("{table_name}".{column_name}) = 1 - AND LOWER(JSON_VALUE("{table_name}".{column_name}, :property_path)) {"LIKE" if partial_match else "="} :match_property_value""" # noqa: E501 - ).bindparams( - property_path=property_path, - match_property_value=f"%{value_to_match.lower()}%" if partial_match else value_to_match.lower(), - ) + AND LOWER(JSON_VALUE("{table_name}".{column_name}, :{pp_param})) {"LIKE" if partial_match else "="} :{mv_param}""" # noqa: E501 + ).bindparams(**{ + pp_param: property_path, + mv_param: f"%{value_to_match.lower()}%" if partial_match else value_to_match.lower(), + }) def _get_condition_json_array_match( self, @@ -332,30 +335,34 @@ def _get_condition_json_array_match( sub_path: str | None = None, 
array_to_match: Sequence[str], ) -> Any: + uid = self._uid() table_name = json_column.class_.__tablename__ column_name = json_column.key + pp_param = f"pp_{uid}" + sp_param = f"sp_{uid}" + if len(array_to_match) == 0: return text( f"""("{table_name}".{column_name} IS NULL - OR JSON_QUERY("{table_name}".{column_name}, :property_path) IS NULL - OR JSON_QUERY("{table_name}".{column_name}, :property_path) = '[]')""" - ).bindparams(property_path=property_path) + OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) IS NULL + OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) = '[]')""" + ).bindparams(**{pp_param: property_path}) - value_expression = "LOWER(JSON_VALUE(value, :sub_path))" if sub_path else "LOWER(value)" + value_expression = f"LOWER(JSON_VALUE(value, :{sp_param}))" if sub_path else "LOWER(value)" conditions = [] - bindparams_dict: dict[str, str] = {"property_path": property_path} + bindparams_dict: dict[str, str] = {pp_param: property_path} if sub_path: - bindparams_dict["sub_path"] = sub_path + bindparams_dict[sp_param] = sub_path for index, match_value in enumerate(array_to_match): - param_name = f"match_value_{index}" + mv_param = f"mv_{uid}_{index}" conditions.append( f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("{table_name}".{column_name}, - :property_path)) - WHERE {value_expression} = :{param_name})""" + :{pp_param})) + WHERE {value_expression} = :{mv_param})""" ) - bindparams_dict[param_name] = match_value.lower() + bindparams_dict[mv_param] = match_value.lower() combined = " AND ".join(conditions) return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) @@ -367,30 +374,33 @@ def _get_unique_json_array_values( path_to_array: str, sub_path: str | None = None, ) -> list[str]: + uid = self._uid() + pa_param = f"pa_{uid}" + sp_param = f"sp_{uid}" table_name = json_column.class_.__tablename__ column_name = json_column.key with closing(self.get_session()) as session: if sub_path is None: rows = 
session.execute( text( - f"""SELECT DISTINCT JSON_VALUE("{table_name}".{column_name}, :path_to_array) AS value + f"""SELECT DISTINCT JSON_VALUE("{table_name}".{column_name}, :{pa_param}) AS value FROM "{table_name}" WHERE ISJSON("{table_name}".{column_name}) = 1 - AND JSON_VALUE("{table_name}".{column_name}, :path_to_array) IS NOT NULL""" - ).bindparams(path_to_array=path_to_array) + AND JSON_VALUE("{table_name}".{column_name}, :{pa_param}) IS NOT NULL""" + ).bindparams(**{pa_param: path_to_array}) ).fetchall() else: rows = session.execute( text( - f"""SELECT DISTINCT JSON_VALUE(items.value, :sub_path) AS value + f"""SELECT DISTINCT JSON_VALUE(items.value, :{sp_param}) AS value FROM "{table_name}" - CROSS APPLY OPENJSON(JSON_QUERY("{table_name}".{column_name}, :path_to_array)) AS items + CROSS APPLY OPENJSON(JSON_QUERY("{table_name}".{column_name}, :{pa_param})) AS items WHERE ISJSON("{table_name}".{column_name}) = 1 - AND JSON_VALUE(items.value, :sub_path) IS NOT NULL""" - ).bindparams( - path_to_array=path_to_array, - sub_path=sub_path, - ) + AND JSON_VALUE(items.value, :{sp_param}) IS NOT NULL""" + ).bindparams(**{ + pa_param: path_to_array, + sp_param: sub_path, + }) ).fetchall() return sorted(row[0] for row in rows) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 1ef99789d8..c7a439965d 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -75,6 +75,11 @@ class MemoryInterface(abc.ABC): results_path: str = None engine: Engine = None + @staticmethod + def _uid() -> str: + """Return a short unique suffix for bind-param deduplication.""" + return uuid.uuid4().hex[:8] + def __init__(self, embedding_model: Optional[Any] = None) -> None: """ Initialize the MemoryInterface. 
diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 3dac83658f..a2151ff318 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -219,24 +219,27 @@ def _get_condition_json_array_match( array_expr == "[]", ) + uid = self._uid() table_name = json_column.class_.__tablename__ column_name = json_column.key - value_expression = "LOWER(json_extract(value, :sub_path))" if sub_path else "LOWER(value)" + pp_param = f"property_path_{uid}" + sp_param = f"sub_path_{uid}" + value_expression = f"LOWER(json_extract(value, :{sp_param}))" if sub_path else "LOWER(value)" conditions = [] for index, match_value in enumerate(array_to_match): - param_name = f"match_value_{index}" + mv_param = f"mv_{uid}_{index}" bind_params: dict[str, str] = { - "property_path": property_path, - param_name: match_value.lower(), + pp_param: property_path, + mv_param: match_value.lower(), } if sub_path: - bind_params["sub_path"] = sub_path + bind_params[sp_param] = sub_path conditions.append( text( f"""EXISTS(SELECT 1 FROM json_each( - json_extract("{table_name}".{column_name}, :property_path)) - WHERE {value_expression} = :{param_name})""" + json_extract("{table_name}".{column_name}, :{pp_param})) + WHERE {value_expression} = :{mv_param})""" ).bindparams(**bind_params) ) return and_(*conditions) @@ -253,18 +256,21 @@ def _get_unique_json_array_values( property_expr = func.json_extract(json_column, path_to_array) rows = session.query(property_expr).filter(property_expr.isnot(None)).distinct().all() else: + uid = self._uid() + pa_param = f"path_to_array_{uid}" + sp_param = f"sub_path_{uid}" table_name = json_column.class_.__tablename__ column_name = json_column.key rows = session.execute( text( - f"""SELECT DISTINCT json_extract(j.value, :sub_path) AS value + f"""SELECT DISTINCT json_extract(j.value, :{sp_param}) AS value FROM "{table_name}", - json_each(json_extract("{table_name}".{column_name}, :path_to_array)) AS j - WHERE json_extract(j.value, 
:sub_path) IS NOT NULL""" - ).bindparams( - path_to_array=path_to_array, - sub_path=sub_path, - ) + json_each(json_extract("{table_name}".{column_name}, :{pa_param})) AS j + WHERE json_extract(j.value, :{sp_param}) IS NOT NULL""" + ).bindparams(**{ + pa_param: path_to_array, + sp_param: sub_path, + }) ).fetchall() return sorted(row[0] for row in rows) diff --git a/tests/unit/memory/test_azure_sql_memory.py b/tests/unit/memory/test_azure_sql_memory.py index c9e4497625..e0d488a61f 100644 --- a/tests/unit/memory/test_azure_sql_memory.py +++ b/tests/unit/memory/test_azure_sql_memory.py @@ -343,10 +343,14 @@ def test_get_condition_json_property_match_bind_params( value_to_match="TestValue", partial_match=partial_match, ) - # Extract the compiled bind parameters + # Extract the compiled bind parameters (param names include a random uid suffix) params = condition.compile().params - assert params["match_property_value"] == expected_value - assert params["property_path"] == "$.key" + pp_params = {k: v for k, v in params.items() if k.startswith("pp_")} + mv_params = {k: v for k, v in params.items() if k.startswith("mv_")} + assert len(pp_params) == 1 + assert list(pp_params.values())[0] == "$.key" + assert len(mv_params) == 1 + assert list(mv_params.values())[0] == expected_value def test_update_prompt_metadata_by_conversation_id(memory_interface: AzureSQLMemory): From 7b3b5c1fae811d3b864e4024e76219b8f5bf2e96 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Fri, 3 Apr 2026 15:02:47 -0700 Subject: [PATCH 13/30] format --- pyrit/memory/azure_sql_memory.py | 20 ++++++++++++-------- pyrit/memory/sqlite_memory.py | 10 ++++++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index bea8aa9915..0e3c5aead8 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -322,10 +322,12 @@ def _get_condition_json_property_match( return text( 
f"""ISJSON("{table_name}".{column_name}) = 1 AND LOWER(JSON_VALUE("{table_name}".{column_name}, :{pp_param})) {"LIKE" if partial_match else "="} :{mv_param}""" # noqa: E501 - ).bindparams(**{ - pp_param: property_path, - mv_param: f"%{value_to_match.lower()}%" if partial_match else value_to_match.lower(), - }) + ).bindparams( + **{ + pp_param: property_path, + mv_param: f"%{value_to_match.lower()}%" if partial_match else value_to_match.lower(), + } + ) def _get_condition_json_array_match( self, @@ -397,10 +399,12 @@ def _get_unique_json_array_values( CROSS APPLY OPENJSON(JSON_QUERY("{table_name}".{column_name}, :{pa_param})) AS items WHERE ISJSON("{table_name}".{column_name}) = 1 AND JSON_VALUE(items.value, :{sp_param}) IS NOT NULL""" - ).bindparams(**{ - pa_param: path_to_array, - sp_param: sub_path, - }) + ).bindparams( + **{ + pa_param: path_to_array, + sp_param: sub_path, + } + ) ).fetchall() return sorted(row[0] for row in rows) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index a2151ff318..dab6d61893 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -267,10 +267,12 @@ def _get_unique_json_array_values( FROM "{table_name}", json_each(json_extract("{table_name}".{column_name}, :{pa_param})) AS j WHERE json_extract(j.value, :{sp_param}) IS NOT NULL""" - ).bindparams(**{ - pa_param: path_to_array, - sp_param: sub_path, - }) + ).bindparams( + **{ + pa_param: path_to_array, + sp_param: sub_path, + } + ) ).fetchall() return sorted(row[0] for row in rows) From ede7e7792ebb834d82225ad19665cb44f56a7238 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 7 Apr 2026 16:00:22 -0700 Subject: [PATCH 14/30] more generic filters + doc fixes --- pyrit/memory/azure_sql_memory.py | 55 ++++- pyrit/memory/identifier_filters.py | 23 ++ pyrit/memory/memory_interface.py | 209 ++++++++++++------ pyrit/memory/sqlite_memory.py | 55 ++++- .../test_interface_attack_results.py | 53 +++-- .../test_interface_prompts.py | 148 
++++++++++--- .../test_interface_scenario_results.py | 44 ++-- .../memory_interface/test_interface_scores.py | 58 +++-- 8 files changed, 499 insertions(+), 146 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 0e3c5aead8..43ddbadd82 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -312,20 +312,39 @@ def _get_condition_json_property_match( property_path: str, value_to_match: str, partial_match: bool = False, + case_sensitive: bool = False, ) -> Any: uid = self._uid() table_name = json_column.class_.__tablename__ column_name = json_column.key pp_param = f"pp_{uid}" mv_param = f"mv_{uid}" + """ + Return an Azure SQL DB condition for matching a value at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. + property_path (str): The JSON path for the property to match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a case-insensitive substring match. + case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. 
+ """ + json_func = "JSON_VALUE" if case_sensitive else "LOWER(JSON_VALUE)" + operator = "LIKE" if partial_match else "=" + target = value_to_match if case_sensitive else value_to_match.lower() + if partial_match: + target = f"%{target}%" return text( f"""ISJSON("{table_name}".{column_name}) = 1 - AND LOWER(JSON_VALUE("{table_name}".{column_name}, :{pp_param})) {"LIKE" if partial_match else "="} :{mv_param}""" # noqa: E501 + AND {json_func}("{table_name}".{column_name}, :{pp_param}) {operator} :{mv_param}""" ).bindparams( **{ pp_param: property_path, - mv_param: f"%{value_to_match.lower()}%" if partial_match else value_to_match.lower(), + mv_param: target, } ) @@ -337,6 +356,20 @@ def _get_condition_json_array_match( sub_path: str | None = None, array_to_match: Sequence[str], ) -> Any: + """ + Return an Azure SQL DB condition for matching an array at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. + property_path (str): The JSON path for the target array. + sub_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_to_match (Sequence[str]): The array that must match the extracted JSON array values. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. + + Returns: + Any: A database-specific SQLAlchemy condition. + """ uid = self._uid() table_name = json_column.class_.__tablename__ column_name = json_column.key @@ -376,6 +409,22 @@ def _get_unique_json_array_values( path_to_array: str, sub_path: str | None = None, ) -> list[str]: + """ + Return sorted unique values in an array located at a given path within a JSON object in an Azure SQL DB Column. + + This method performs a database-level query to extract distinct values from a + an array within a JSON-type column. 
When ``sub_path`` is provided, the distinct values are + extracted from each array item using the sub-path. + + Args: + json_column (Any): The JSON-backed model field to query. + path_to_array (str): The JSON path to the array whose unique values are extracted. + sub_path (str | None): Optional JSON path applied to each array + item before collecting distinct values. + + Returns: + list[str]: A sorted list of unique values in the array. + """ uid = self._uid() pa_param = f"pa_{uid}" sp_param = f"sp_{uid}" table_name = json_column.class_.__tablename__ column_name = json_column.key @@ -580,6 +629,8 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] """ Insert a list of message pieces into the memory storage. + Args: + message_pieces (Sequence[MessagePiece]): A sequence of MessagePiece instances to be added. """ self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 122d89965b..18a6423f0b 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -3,11 +3,34 @@ from dataclasses import dataclass +from enum import Enum + + +class IdentifierType(Enum): + """Enumeration of supported identifier types for filtering.""" + + ATTACK = "attack" + TARGET = "target" + SCORER = "scorer" + CONVERTER = "converter" + @dataclass(frozen=True) class IdentifierFilter: """Immutable filter definition for matching JSON-backed identifier properties.""" + identifier_type: IdentifierType property_path: str + sub_path: str | None value_to_match: str partial_match: bool = False + + def __post_init__(self) -> None: + """ + Validate the filter configuration. + + Raises: + ValueError: If the filter configuration is not valid. 
+ """ + if self.partial_match and self.sub_path: + raise ValueError("Cannot use sub_path with partial_match") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index c7a439965d..fdc326cb27 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -19,7 +19,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from pyrit.common.path import DB_DATA_PATH -from pyrit.memory.identifier_filters import IdentifierFilter +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.memory.memory_embedding import ( MemoryEmbedding, default_memory_embedding_factory, @@ -119,6 +119,70 @@ def disable_embedding(self) -> None: """ self.memory_embedding = None + def _get_identifier_property_match_condition( + self, identifier_column: Any, identifier_filter: IdentifierFilter + ) -> Any: + """ + Build a SQLAlchemy condition that matches a JSON identifier column against the given filter. + + Args: + identifier_column (Any): The JSON-backed SQLAlchemy column to query. + identifier_filter (IdentifierFilter): The filter specifying the property path, + optional sub-path, value to match, and whether to use partial matching. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. + """ + return self._get_condition_json_match( + json_column=identifier_column, + property_path=identifier_filter.property_path, + sub_path=identifier_filter.sub_path, + value_to_match=identifier_filter.value_to_match, + partial_match=identifier_filter.partial_match, + ) + + def _get_condition_json_match( + self, + *, + json_column: InstrumentedAttribute[Any], + property_path: str, + sub_path: str | None = None, + value_to_match: str, + partial_match: bool = False, + case_sensitive: bool = False, + ) -> Any: + """ + Return a database-specific condition for matching a value at a given path within a JSON object + or within items of a JSON array if sub_path is provided. 
+ + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. + property_path (str): The JSON path for the property to match. + sub_path (str | None): An optional JSON path that indicates property at property_path is an array + and the condition should resolve if any element in that array matches the value. + Cannot be used with partial_match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a case-insensitive substring match. + case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. + """ + if sub_path: + return self._get_condition_json_array_match( + json_column=json_column, + property_path=property_path, + sub_path=sub_path, + array_to_match=[value_to_match], + ) + return self._get_condition_json_property_match( + json_column=json_column, + property_path=property_path, + value_to_match=value_to_match, + partial_match=partial_match, + case_sensitive=case_sensitive, + ) + @abc.abstractmethod def _get_condition_json_property_match( self, @@ -127,6 +191,7 @@ def _get_condition_json_property_match( property_path: str, value_to_match: str, partial_match: bool = False, + case_sensitive: bool = False, ) -> Any: """ Return a database-specific condition for matching a value at a given path within a JSON object. @@ -136,6 +201,7 @@ def _get_condition_json_property_match( property_path (str): The JSON path for the property to match. value_to_match (str): The string value that must match the extracted JSON property value. partial_match (bool): Whether to perform a case-insensitive substring match. + case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. Returns: Any: A SQLAlchemy condition for the backend-specific JSON query. 
@@ -445,7 +511,7 @@ def get_scores( score_category: Optional[str] = None, sent_after: Optional[datetime] = None, sent_before: Optional[datetime] = None, - scorer_identifier_filter: Optional[IdentifierFilter] = None, + identifier_filters: Optional[set[IdentifierFilter]] = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -456,7 +522,7 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. - scorer_identifier_filter (Optional[IdentifierFilter]): An IdentifierFilter object that + identifier_filters (Optional[set[IdentifierFilter]]): A set of IdentifierFilter objects that allows filtering by various scorer identifier JSON properties. Defaults to None. Returns: @@ -474,15 +540,21 @@ def get_scores( conditions.append(ScoreEntry.timestamp >= sent_after) if sent_before: conditions.append(ScoreEntry.timestamp <= sent_before) - if scorer_identifier_filter: - conditions.append( - self._get_condition_json_property_match( - json_column=ScoreEntry.scorer_class_identifier, - property_path=scorer_identifier_filter.property_path, - value_to_match=scorer_identifier_filter.value_to_match, - partial_match=scorer_identifier_filter.partial_match, - ) - ) + if identifier_filters: + for identifier_filter in identifier_filters: + column = None + + match identifier_filter.identifier_type: + case IdentifierType.SCORER: + column = ScoreEntry.scorer_class_identifier + + if column is not None: + conditions.append( + self._get_identifier_property_match_condition( + identifier_column=column, + identifier_filter=identifier_filter, + ) + ) if not conditions: return [] @@ -613,8 +685,7 @@ def get_message_pieces( data_type: Optional[str] = None, not_data_type: Optional[str] = None, converted_value_sha256: Optional[Sequence[str]] = None, - 
attack_identifier_filter: Optional[IdentifierFilter] = None, - prompt_target_identifier_filter: Optional[IdentifierFilter] = None, + identifier_filters: Optional[set[IdentifierFilter]] = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -636,12 +707,9 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. - attack_identifier_filter (Optional[IdentifierFilter], optional): - An IdentifierFilter object that - allows filtering by various attack identifier JSON properties. Defaults to None. - prompt_target_identifier_filter (Optional[IdentifierFilter], optional): - An IdentifierFilter object that - allows filtering by various target identifier JSON properties. Defaults to None. + identifier_filters (Optional[set[IdentifierFilter]], optional): + A set of IdentifierFilter objects that + allow filtering by various identifier JSON properties. Defaults to None. Returns: Sequence[MessagePiece]: A list of MessagePiece objects that match the specified filters. 
@@ -684,25 +752,25 @@ def get_message_pieces( conditions.append(PromptMemoryEntry.converted_value_data_type != not_data_type) if converted_value_sha256: conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) - if attack_identifier_filter: - conditions.append( - self._get_condition_json_property_match( - json_column=PromptMemoryEntry.attack_identifier, - property_path=attack_identifier_filter.property_path, - value_to_match=attack_identifier_filter.value_to_match, - partial_match=attack_identifier_filter.partial_match, - ) - ) - if prompt_target_identifier_filter: - conditions.append( - self._get_condition_json_property_match( - json_column=PromptMemoryEntry.prompt_target_identifier, - property_path=prompt_target_identifier_filter.property_path, - value_to_match=prompt_target_identifier_filter.value_to_match, - partial_match=prompt_target_identifier_filter.partial_match, - ) - ) - + if identifier_filters: + for identifier_filter in identifier_filters: + column: Any = None + + match identifier_filter.identifier_type: + case IdentifierType.ATTACK: + column = PromptMemoryEntry.attack_identifier + case IdentifierType.TARGET: + column = PromptMemoryEntry.prompt_target_identifier + case IdentifierType.CONVERTER: + column = PromptMemoryEntry.converter_identifiers + + if column is not None: + conditions.append( + self._get_identifier_property_match_condition( + identifier_column=column, + identifier_filter=identifier_filter, + ) + ) try: memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True @@ -1429,7 +1497,7 @@ def get_attack_results( converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, - attack_identifier_filter: Optional[IdentifierFilter] = None, + identifier_filters: Optional[set[IdentifierFilter]] = None, ) -> Sequence[AttackResult]: """ 
Retrieve a list of AttackResult objects based on the specified filters. @@ -1457,8 +1525,8 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. - attack_identifier_filter (Optional[IdentifierFilter], optional): - An IdentifierFilter object that allows filtering by various attack identifier + identifier_filters (Optional[set[IdentifierFilter]], optional): + A set of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. Returns: @@ -1488,6 +1556,7 @@ def get_attack_results( json_column=AttackResultEntry.atomic_attack_identifier, property_path="$.children.attack.class_name", value_to_match=attack_class, + case_sensitive=True, ) ) @@ -1513,15 +1582,21 @@ def get_attack_results( # Use database-specific JSON query method conditions.append(self._get_attack_result_label_condition(labels=labels)) - if attack_identifier_filter: - conditions.append( - self._get_condition_json_property_match( - json_column=AttackResultEntry.atomic_attack_identifier, - property_path=attack_identifier_filter.property_path, - value_to_match=attack_identifier_filter.value_to_match, - partial_match=attack_identifier_filter.partial_match, - ) - ) + if identifier_filters: + for identifier_filter in identifier_filters: + column = None + + match identifier_filter.identifier_type: + case IdentifierType.ATTACK: + column = AttackResultEntry.atomic_attack_identifier + + if column is not None: + conditions.append( + self._get_identifier_property_match_condition( + identifier_column=column, + identifier_filter=identifier_filter, + ) + ) try: entries: Sequence[AttackResultEntry] = self._query_entries( @@ -1703,7 +1778,7 @@ def get_scenario_results( labels: Optional[dict[str, str]] = None, objective_target_endpoint: Optional[str] = None, objective_target_model_name: 
Optional[str] = None, - objective_target_identifier_filter: Optional[IdentifierFilter] = None, + identifier_filters: Optional[set[IdentifierFilter]] = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. @@ -1727,8 +1802,8 @@ def get_scenario_results( objective_target_model_name (Optional[str], optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. - objective_target_identifier_filter (Optional[IdentifierFilter], optional): - An IdentifierFilter object that allows filtering by various target identifier JSON properties. + identifier_filters (Optional[set[IdentifierFilter]], optional): + A set of IdentifierFilter objects that allows filtering by various target identifier JSON properties. Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. @@ -1786,15 +1861,23 @@ def get_scenario_results( ) ) - if objective_target_identifier_filter: - conditions.append( - self._get_condition_json_property_match( - json_column=ScenarioResultEntry.objective_target_identifier, - property_path=objective_target_identifier_filter.property_path, - value_to_match=objective_target_identifier_filter.value_to_match, - partial_match=objective_target_identifier_filter.partial_match, - ) - ) + if identifier_filters: + for identifier_filter in identifier_filters: + column = None + + match identifier_filter.identifier_type: + case IdentifierType.SCORER: + column = ScenarioResultEntry.objective_scorer_identifier + case IdentifierType.TARGET: + column = ScenarioResultEntry.objective_target_identifier + + if column is not None: + conditions.append( + self._get_identifier_property_match_condition( + identifier_column=column, + identifier_filter=identifier_filter, + ) + ) try: entries: Sequence[ScenarioResultEntry] = self._query_entries( diff --git a/pyrit/memory/sqlite_memory.py 
b/pyrit/memory/sqlite_memory.py index dab6d61893..d9f6909274 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -197,11 +197,30 @@ def _get_condition_json_property_match( property_path: str, value_to_match: str, partial_match: bool = False, + case_sensitive: bool = False, ) -> Any: - extracted_value = func.lower(func.json_extract(json_column, property_path)) + """ + Return a SQLite DB condition for matching a value at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. + property_path (str): The JSON path for the property to match. + value_to_match (str): The string value that must match the extracted JSON property value. + partial_match (bool): Whether to perform a case-insensitive substring match. + case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. + + Returns: + Any: A SQLAlchemy condition for the backend-specific JSON query. + """ + raw = func.json_extract(json_column, property_path) + if case_sensitive: + extracted_value, target = raw, value_to_match + else: + extracted_value, target = func.lower(raw), value_to_match.lower() + if partial_match: - return extracted_value.like(f"%{value_to_match.lower()}%") - return extracted_value == value_to_match.lower() + return extracted_value.like(f"%{target}%") + return extracted_value == target def _get_condition_json_array_match( self, @@ -211,6 +230,20 @@ def _get_condition_json_array_match( sub_path: str | None = None, array_to_match: Sequence[str], ) -> Any: + """ + Return a SQLite DB condition for matching an array at a given path within a JSON object. + + Args: + json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. + property_path (str): The JSON path for the target array. + sub_path (Optional[str]): An optional JSON path applied to each array item before matching. 
+ array_to_match (Sequence[str]): The array that must match the extracted JSON array values. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. + + Returns: + Any: A database-specific SQLAlchemy condition. + """ array_expr = func.json_extract(json_column, property_path) if len(array_to_match) == 0: return or_( @@ -251,6 +284,22 @@ def _get_unique_json_array_values( path_to_array: str, sub_path: str | None = None, ) -> list[str]: + """ + Return sorted unique values in an array located at a given path within a JSON object in a SQLite DB Column. + + This method performs a database-level query to extract distinct values from + an array within a JSON-type column. When ``sub_path`` is provided, the distinct values are + extracted from each array item using the sub-path. + + Args: + json_column (Any): The JSON-backed model field to query. + path_to_array (str): The JSON path to the array whose unique values are extracted. + sub_path (str | None): Optional JSON path applied to each array + item before collecting distinct values. + + Returns: + list[str]: A sorted list of unique values in the array.
+ """ with closing(self.get_session()) as session: if sub_path is None: property_expr = func.json_extract(json_column, path_to_array) diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 84cda0b409..b17fe35fd5 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -9,7 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory import MemoryInterface -from pyrit.memory.identifier_filters import IdentifierFilter +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( AttackOutcome, @@ -1176,6 +1176,15 @@ def test_get_attack_results_by_attack_class_no_match(sqlite_instance: MemoryInte assert len(results) == 0 +def test_get_attack_results_by_attack_class_case_sensitive(sqlite_instance: MemoryInterface): + """Test that attack_class filter is case-sensitive (exact match).""" + ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") + sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) + + results = sqlite_instance.get_attack_results(attack_class="crescendoattack") + assert len(results) == 0 + + def test_get_attack_results_by_attack_class_no_identifier(sqlite_instance: MemoryInterface): """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_class filter.""" ar1 = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} @@ -1354,11 +1363,15 @@ def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: Me # Filter by hash of ar1's attack identifier results = sqlite_instance.get_attack_results( - attack_identifier_filter=IdentifierFilter( - property_path="$.hash", - 
value_to_match=ar1.atomic_attack_identifier.hash, - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + sub_path=None, + value_to_match=ar1.atomic_attack_identifier.hash, + partial_match=False, + ) + }, ) assert len(results) == 1 assert results[0].conversation_id == "conv_1" @@ -1373,11 +1386,15 @@ def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instan # Filter by partial attack class name results = sqlite_instance.get_attack_results( - attack_identifier_filter=IdentifierFilter( - property_path="$.children.attack.class_name", - value_to_match="Crescendo", - partial_match=True, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.children.attack.class_name", + sub_path=None, + value_to_match="Crescendo", + partial_match=True, + ) + }, ) assert len(results) == 2 assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} @@ -1389,10 +1406,14 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) results = sqlite_instance.get_attack_results( - attack_identifier_filter=IdentifierFilter( - property_path="$.hash", - value_to_match="nonexistent_hash", - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + sub_path=None, + value_to_match="nonexistent_hash", + partial_match=False, + ) + }, ) assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index eec4d3d88a..7126968907 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -14,7 +14,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from 
pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry -from pyrit.memory.identifier_filters import IdentifierFilter +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.models import ( Message, MessagePiece, @@ -1276,22 +1276,30 @@ def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryI # Filter by exact attack hash results = sqlite_instance.get_message_pieces( - attack_identifier_filter=IdentifierFilter( - property_path="$.hash", - value_to_match=attack1.get_identifier().hash, - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + sub_path=None, + value_to_match=attack1.get_identifier().hash, + partial_match=False, + ) + }, ) assert len(results) == 1 assert results[0].original_value == "Hello 1" # No match results = sqlite_instance.get_message_pieces( - attack_identifier_filter=IdentifierFilter( - property_path="$.hash", - value_to_match="nonexistent_hash", - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.hash", + sub_path=None, + value_to_match="nonexistent_hash", + partial_match=False, + ) + }, ) assert len(results) == 0 @@ -1329,32 +1337,122 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI # Filter by target hash results = sqlite_instance.get_message_pieces( - prompt_target_identifier_filter=IdentifierFilter( - property_path="$.hash", - value_to_match=target_id_1.hash, - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + sub_path=None, + value_to_match=target_id_1.hash, + partial_match=False, + ) + }, ) assert len(results) == 1 assert results[0].original_value == "Hello OpenAI" # Filter by endpoint partial match results = sqlite_instance.get_message_pieces( - 
prompt_target_identifier_filter=IdentifierFilter( - property_path="$.endpoint", - value_to_match="openai", - partial_match=True, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.endpoint", + sub_path=None, + value_to_match="openai", + partial_match=True, + ) + }, ) assert len(results) == 1 assert results[0].original_value == "Hello OpenAI" # No match results = sqlite_instance.get_message_pieces( - prompt_target_identifier_filter=IdentifierFilter( - property_path="$.hash", - value_to_match="nonexistent", - partial_match=False, + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + sub_path=None, + value_to_match="nonexistent", + partial_match=False, + ) + }, + ) + assert len(results) == 0 + + +def test_get_message_pieces_by_converter_identifier_filter_with_sub_path(sqlite_instance: MemoryInterface): + converter_a = ComponentIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter", + ) + converter_b = ComponentIdentifier( + class_name="ROT13Converter", + class_module="pyrit.prompt_converter", + ) + + entries = [ + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="With Base64", + converter_identifiers=[converter_a], + ) ), + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="With both converters", + converter_identifiers=[converter_a, converter_b], + ) + ), + PromptMemoryEntry( + entry=MessagePiece( + role="user", + original_value="No converters", + ) + ), + ] + + sqlite_instance._insert_entries(entries=entries) + + # Filter by converter class_name using sub_path (array element matching) + results = sqlite_instance.get_message_pieces( + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.CONVERTER, + property_path="$", + sub_path="$.class_name", + value_to_match="Base64Converter", + ) + }, + ) + assert len(results) == 2 + original_values = {r.original_value 
for r in results} + assert original_values == {"With Base64", "With both converters"} + + # Filter by ROT13Converter — only the entry with both converters + results = sqlite_instance.get_message_pieces( + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.CONVERTER, + property_path="$", + sub_path="$.class_name", + value_to_match="ROT13Converter", + ) + }, + ) + assert len(results) == 1 + assert results[0].original_value == "With both converters" + + # No match + results = sqlite_instance.get_message_pieces( + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.CONVERTER, + property_path="$", + sub_path="$.class_name", + value_to_match="NonexistentConverter", + ) + }, ) assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index ee2933b70a..d04931d470 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -9,7 +9,7 @@ from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface -from pyrit.memory.identifier_filters import IdentifierFilter +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.models import ( AttackOutcome, AttackResult, @@ -681,11 +681,15 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: # Filter by target hash results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=IdentifierFilter( - property_path="$.hash", - value_to_match=target_id_1.hash, - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + sub_path=None, + value_to_match=target_id_1.hash, + partial_match=False, + ) + }, ) assert len(results) == 1 assert results[0].scenario_identifier.name == "Scenario OpenAI" 
@@ -724,11 +728,15 @@ def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instan # Filter by endpoint partial match results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=IdentifierFilter( - property_path="$.endpoint", - value_to_match="openai", - partial_match=True, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.endpoint", + sub_path=None, + value_to_match="openai", + partial_match=True, + ) + }, ) assert len(results) == 1 assert results[0].scenario_identifier.name == "Scenario OpenAI" @@ -752,10 +760,14 @@ def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instan sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1]) results = sqlite_instance.get_scenario_results( - objective_target_identifier_filter=IdentifierFilter( - property_path="$.hash", - value_to_match="nonexistent_hash", - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.TARGET, + property_path="$.hash", + sub_path=None, + value_to_match="nonexistent_hash", + partial_match=False, + ) + }, ) assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 2c90b18313..10d0888ea7 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -13,7 +13,7 @@ from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ComponentIdentifier from pyrit.memory import MemoryInterface, PromptMemoryEntry -from pyrit.memory.identifier_filters import IdentifierFilter +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.models import ( MessagePiece, Score, @@ -262,43 +262,59 @@ def test_get_scores_by_scorer_identifier_filter( # Filter by exact 
class_name match results = sqlite_instance.get_scores( - scorer_identifier_filter=IdentifierFilter( - property_path="$.class_name", - value_to_match="ScorerAlpha", - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.class_name", + sub_path=None, + value_to_match="ScorerAlpha", + partial_match=False, + ) + }, ) assert len(results) == 1 assert results[0].score_value == "0.9" # Filter by partial class_name match results = sqlite_instance.get_scores( - scorer_identifier_filter=IdentifierFilter( - property_path="$.class_name", - value_to_match="Scorer", - partial_match=True, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.class_name", + sub_path=None, + value_to_match="Scorer", + partial_match=True, + ) + }, ) assert len(results) == 2 # Filter by hash scorer_hash = score_a.scorer_class_identifier.hash results = sqlite_instance.get_scores( - scorer_identifier_filter=IdentifierFilter( - property_path="$.hash", - value_to_match=scorer_hash, - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.hash", + sub_path=None, + value_to_match=scorer_hash, + partial_match=False, + ) + }, ) assert len(results) == 1 assert results[0].score_value == "0.9" # No match results = sqlite_instance.get_scores( - scorer_identifier_filter=IdentifierFilter( - property_path="$.class_name", - value_to_match="NonExistent", - partial_match=False, - ), + identifier_filters={ + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.class_name", + sub_path=None, + value_to_match="NonExistent", + partial_match=False, + ) + }, ) assert len(results) == 0 From 4dcddad33069edb3bb23c48b967ba70582cf1ec5 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 7 Apr 2026 16:14:08 -0700 Subject: [PATCH 15/30] add casesensitive --- pyrit/memory/identifier_filters.py | 21 
+++++++++++++++++---- pyrit/memory/memory_interface.py | 11 ++++++----- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 18a6423f0b..6a119e2ceb 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -17,20 +17,33 @@ class IdentifierType(Enum): @dataclass(frozen=True) class IdentifierFilter: - """Immutable filter definition for matching JSON-backed identifier properties.""" + """ + Immutable filter definition for matching JSON-backed identifier properties. + + Attributes: + identifier_type: The type of identifier column to filter on. + property_path: The JSON path for the property to match. + sub_path: An optional JSON path that indicates the property at property_path is an array + and the condition should resolve if any element in that array matches the value. + Cannot be used with partial_match. + value_to_match: The string value that must match the extracted JSON property value. + partial_match: Whether to perform a substring match. Cannot be used with sub_path. + case_sensitive: Whether the match should be case-sensitive. Defaults to False. + """ identifier_type: IdentifierType property_path: str sub_path: str | None value_to_match: str partial_match: bool = False + case_sensitive: bool = False def __post_init__(self) -> None: """ - Validate that the filter configuration. + Validate the filter configuration. Raises: ValueError: If the filter configuration is not valid. 
""" - if self.partial_match and self.sub_path: - raise ValueError("Cannot use sub_path with partial_match") + if self.sub_path and (self.partial_match or self.case_sensitive): + raise ValueError("Cannot use sub_path with partial_match or case_sensitive") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index fdc326cb27..5599aefc60 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -119,7 +119,7 @@ def disable_embedding(self) -> None: """ self.memory_embedding = None - def _get_identifier_property_match_condition( + def _get_condition_identifier_property_match( self, identifier_column: Any, identifier_filter: IdentifierFilter ) -> Any: """ @@ -139,6 +139,7 @@ def _get_identifier_property_match_condition( sub_path=identifier_filter.sub_path, value_to_match=identifier_filter.value_to_match, partial_match=identifier_filter.partial_match, + case_sensitive=identifier_filter.case_sensitive, ) def _get_condition_json_match( @@ -550,7 +551,7 @@ def get_scores( if column is not None: conditions.append( - self._get_identifier_property_match_condition( + self._get_condition_identifier_property_match( identifier_column=column, identifier_filter=identifier_filter, ) @@ -766,7 +767,7 @@ def get_message_pieces( if column is not None: conditions.append( - self._get_identifier_property_match_condition( + self._get_condition_identifier_property_match( identifier_column=column, identifier_filter=identifier_filter, ) @@ -1592,7 +1593,7 @@ def get_attack_results( if column is not None: conditions.append( - self._get_identifier_property_match_condition( + self._get_condition_identifier_property_match( identifier_column=column, identifier_filter=identifier_filter, ) @@ -1873,7 +1874,7 @@ def get_scenario_results( if column is not None: conditions.append( - self._get_identifier_property_match_condition( + self._get_condition_identifier_property_match( identifier_column=column, identifier_filter=identifier_filter, ) From 
b6fa8ee7714f92cc61ca2eae828d12527279d912 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 7 Apr 2026 16:31:07 -0700 Subject: [PATCH 16/30] enum --- pyrit/memory/identifier_filters.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index 6a119e2ceb..a906528daf 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -2,8 +2,7 @@ # Licensed under the MIT license. from dataclasses import dataclass - -from prometheus_client import Enum +from enum import Enum class IdentifierType(Enum): From 8379e71dd599d916fc76e04e0b196eb7c4b792d1 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 7 Apr 2026 17:00:00 -0700 Subject: [PATCH 17/30] optimizations --- pyrit/memory/azure_sql_memory.py | 10 +-- pyrit/memory/identifier_filters.py | 2 +- pyrit/memory/memory_interface.py | 138 ++++++++++++++++------------- 3 files changed, 81 insertions(+), 69 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 43ddbadd82..a2b3c40ea0 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -314,11 +314,6 @@ def _get_condition_json_property_match( partial_match: bool = False, case_sensitive: bool = False, ) -> Any: - uid = self._uid() - table_name = json_column.class_.__tablename__ - column_name = json_column.key - pp_param = f"pp_{uid}" - mv_param = f"mv_{uid}" """ Return an Azure SQL DB condition for matching a value at a given path within a JSON object. @@ -332,6 +327,11 @@ def _get_condition_json_property_match( Returns: Any: A SQLAlchemy condition for the backend-specific JSON query. 
""" + uid = self._uid() + table_name = json_column.class_.__tablename__ + column_name = json_column.key + pp_param = f"pp_{uid}" + mv_param = f"mv_{uid}" json_func = "JSON_VALUE" if case_sensitive else "LOWER(JSON_VALUE)" operator = "LIKE" if partial_match else "=" target = value_to_match if case_sensitive else value_to_match.lower() diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index a906528daf..c0e545ec41 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -32,8 +32,8 @@ class IdentifierFilter: identifier_type: IdentifierType property_path: str - sub_path: str | None value_to_match: str + sub_path: str | None = None partial_match: bool = False case_sensitive: bool = False diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 5599aefc60..48d6de5f54 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -120,7 +120,7 @@ def disable_embedding(self) -> None: self.memory_embedding = None def _get_condition_identifier_property_match( - self, identifier_column: Any, identifier_filter: IdentifierFilter + self, *, identifier_column: Any, identifier_filter: IdentifierFilter ) -> Any: """ Build a SQLAlchemy condition that matches a JSON identifier column against the given filter. @@ -142,6 +142,45 @@ def _get_condition_identifier_property_match( case_sensitive=identifier_filter.case_sensitive, ) + def _build_identifier_filter_conditions( + self, + *, + identifier_filters: set[IdentifierFilter], + identifier_column_map: dict[IdentifierType, Any], + caller: str, + ) -> list[Any]: + """ + Build SQLAlchemy conditions from a set of IdentifierFilters. + + Args: + identifier_filters (set[IdentifierFilter]): The filters to convert to conditions. + identifier_column_map (dict[IdentifierType, Any]): Mapping from IdentifierType to the + JSON-backed SQLAlchemy column that should be queried for that type. 
+ caller (str): Name of the calling method, used in error messages. + + Returns: + list[Any]: A list of SQLAlchemy conditions. + + Raises: + ValueError: If a filter uses an IdentifierType not in identifier_column_map. + """ + conditions: list[Any] = [] + for identifier_filter in identifier_filters: + column = identifier_column_map.get(identifier_filter.identifier_type) + if column is None: + supported = ", ".join(t.name for t in identifier_column_map) + raise ValueError( + f"{caller} does not support identifier type " + f"{identifier_filter.identifier_type!r}. Supported: {supported}" + ) + conditions.append( + self._get_condition_identifier_property_match( + identifier_column=column, + identifier_filter=identifier_filter, + ) + ) + return conditions + def _get_condition_json_match( self, *, @@ -542,20 +581,13 @@ def get_scores( if sent_before: conditions.append(ScoreEntry.timestamp <= sent_before) if identifier_filters: - for identifier_filter in identifier_filters: - column = None - - match identifier_filter.identifier_type: - case IdentifierType.SCORER: - column = ScoreEntry.scorer_class_identifier - - if column is not None: - conditions.append( - self._get_condition_identifier_property_match( - identifier_column=column, - identifier_filter=identifier_filter, - ) - ) + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={IdentifierType.SCORER: ScoreEntry.scorer_class_identifier}, + caller="get_scores", + ) + ) if not conditions: return [] @@ -754,24 +786,17 @@ def get_message_pieces( if converted_value_sha256: conditions.append(PromptMemoryEntry.converted_value_sha256.in_(converted_value_sha256)) if identifier_filters: - for identifier_filter in identifier_filters: - column: Any = None - - match identifier_filter.identifier_type: - case IdentifierType.ATTACK: - column = PromptMemoryEntry.attack_identifier - case IdentifierType.TARGET: - column = PromptMemoryEntry.prompt_target_identifier 
- case IdentifierType.CONVERTER: - column = PromptMemoryEntry.converter_identifiers - - if column is not None: - conditions.append( - self._get_condition_identifier_property_match( - identifier_column=column, - identifier_filter=identifier_filter, - ) - ) + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={ + IdentifierType.ATTACK: PromptMemoryEntry.attack_identifier, + IdentifierType.TARGET: PromptMemoryEntry.prompt_target_identifier, + IdentifierType.CONVERTER: PromptMemoryEntry.converter_identifiers, + }, + caller="get_message_pieces", + ) + ) try: memory_entries: Sequence[PromptMemoryEntry] = self._query_entries( PromptMemoryEntry, conditions=and_(*conditions) if conditions else None, join_scores=True @@ -1584,20 +1609,13 @@ def get_attack_results( conditions.append(self._get_attack_result_label_condition(labels=labels)) if identifier_filters: - for identifier_filter in identifier_filters: - column = None - - match identifier_filter.identifier_type: - case IdentifierType.ATTACK: - column = AttackResultEntry.atomic_attack_identifier - - if column is not None: - conditions.append( - self._get_condition_identifier_property_match( - identifier_column=column, - identifier_filter=identifier_filter, - ) - ) + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={IdentifierType.ATTACK: AttackResultEntry.atomic_attack_identifier}, + caller="get_attack_results", + ) + ) try: entries: Sequence[AttackResultEntry] = self._query_entries( @@ -1863,22 +1881,16 @@ def get_scenario_results( ) if identifier_filters: - for identifier_filter in identifier_filters: - column = None - - match identifier_filter.identifier_type: - case IdentifierType.SCORER: - column = ScenarioResultEntry.objective_scorer_identifier - case IdentifierType.TARGET: - column = ScenarioResultEntry.objective_target_identifier - - if column is not None: - 
conditions.append( - self._get_condition_identifier_property_match( - identifier_column=column, - identifier_filter=identifier_filter, - ) - ) + conditions.extend( + self._build_identifier_filter_conditions( + identifier_filters=identifier_filters, + identifier_column_map={ + IdentifierType.SCORER: ScenarioResultEntry.objective_scorer_identifier, + IdentifierType.TARGET: ScenarioResultEntry.objective_target_identifier, + }, + caller="get_scenario_results", + ) + ) try: entries: Sequence[ScenarioResultEntry] = self._query_entries( From 7b206f575cb970d1fc8ca0975f013f031a424cb9 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 7 Apr 2026 17:09:05 -0700 Subject: [PATCH 18/30] optimize more --- pyrit/memory/__init__.py | 3 ++- .../memory/memory_interface/test_interface_attack_results.py | 3 --- tests/unit/memory/memory_interface/test_interface_prompts.py | 5 ----- .../memory_interface/test_interface_scenario_results.py | 3 --- tests/unit/memory/memory_interface/test_interface_scores.py | 4 ---- 5 files changed, 2 insertions(+), 16 deletions(-) diff --git a/pyrit/memory/__init__.py b/pyrit/memory/__init__.py index 6098122d7d..cb4f8af272 100644 --- a/pyrit/memory/__init__.py +++ b/pyrit/memory/__init__.py @@ -9,7 +9,7 @@ from pyrit.memory.azure_sql_memory import AzureSQLMemory from pyrit.memory.central_memory import CentralMemory -from pyrit.memory.identifier_filters import IdentifierFilter +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType from pyrit.memory.memory_embedding import MemoryEmbedding from pyrit.memory.memory_exporter import MemoryExporter from pyrit.memory.memory_interface import MemoryInterface @@ -28,4 +28,5 @@ "PromptMemoryEntry", "SeedEntry", "IdentifierFilter", + "IdentifierType", ] diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index b17fe35fd5..6999ff6bb7 100644 --- 
a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -1367,7 +1367,6 @@ def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: Me IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.hash", - sub_path=None, value_to_match=ar1.atomic_attack_identifier.hash, partial_match=False, ) @@ -1390,7 +1389,6 @@ def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instan IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.children.attack.class_name", - sub_path=None, value_to_match="Crescendo", partial_match=True, ) @@ -1410,7 +1408,6 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.hash", - sub_path=None, value_to_match="nonexistent_hash", partial_match=False, ) diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 7126968907..9225d06364 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -1280,7 +1280,6 @@ def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryI IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.hash", - sub_path=None, value_to_match=attack1.get_identifier().hash, partial_match=False, ) @@ -1295,7 +1294,6 @@ def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryI IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.hash", - sub_path=None, value_to_match="nonexistent_hash", partial_match=False, ) @@ -1341,7 +1339,6 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.hash", - sub_path=None, 
value_to_match=target_id_1.hash, partial_match=False, ) @@ -1356,7 +1353,6 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.endpoint", - sub_path=None, value_to_match="openai", partial_match=True, ) @@ -1371,7 +1367,6 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.hash", - sub_path=None, value_to_match="nonexistent", partial_match=False, ) diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index d04931d470..32fdb0a7ee 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -685,7 +685,6 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.hash", - sub_path=None, value_to_match=target_id_1.hash, partial_match=False, ) @@ -732,7 +731,6 @@ def test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instan IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.endpoint", - sub_path=None, value_to_match="openai", partial_match=True, ) @@ -764,7 +762,6 @@ def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instan IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.hash", - sub_path=None, value_to_match="nonexistent_hash", partial_match=False, ) diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 10d0888ea7..4fbd9bb865 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -266,7 +266,6 @@ def 
test_get_scores_by_scorer_identifier_filter( IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.class_name", - sub_path=None, value_to_match="ScorerAlpha", partial_match=False, ) @@ -281,7 +280,6 @@ def test_get_scores_by_scorer_identifier_filter( IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.class_name", - sub_path=None, value_to_match="Scorer", partial_match=True, ) @@ -296,7 +294,6 @@ def test_get_scores_by_scorer_identifier_filter( IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.hash", - sub_path=None, value_to_match=scorer_hash, partial_match=False, ) @@ -311,7 +308,6 @@ def test_get_scores_by_scorer_identifier_filter( IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.class_name", - sub_path=None, value_to_match="NonExistent", partial_match=False, ) From c51cb35c8513bc9f8826ccaae09b5df0c5f9e397 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 7 Apr 2026 17:25:46 -0700 Subject: [PATCH 19/30] little fixes --- pyrit/memory/azure_sql_memory.py | 26 +++----------------------- pyrit/memory/memory_interface.py | 7 ++++++- pyrit/memory/sqlite_memory.py | 5 +++-- 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index a2b3c40ea0..048b580646 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -321,7 +321,7 @@ def _get_condition_json_property_match( json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. property_path (str): The JSON path for the property to match. value_to_match (str): The string value that must match the extracted JSON property value. - partial_match (bool): Whether to perform a case-insensitive substring match. + partial_match (bool): Whether to perform a substring match. case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. 
Returns: @@ -336,7 +336,8 @@ def _get_condition_json_property_match( operator = "LIKE" if partial_match else "=" target = value_to_match if case_sensitive else value_to_match.lower() if partial_match: - target = f"%{target}%" + escaped = target.replace("%", "\\%").replace("_", "\\_") + target = f"%{escaped}%" return text( f"""ISJSON("{table_name}".{column_name}) = 1 @@ -634,27 +635,6 @@ def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece] """ self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces]) - def get_unique_attack_class_names(self) -> list[str]: - """ - Azure SQL implementation: extract unique class_name values from - the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique attack class name strings. - """ - return super().get_unique_attack_class_names() - - def get_unique_converter_class_names(self) -> list[str]: - """ - Azure SQL implementation: extract unique converter class_name values - from the children.attack.children.request_converters array - in the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique converter class name strings. - """ - return super().get_unique_converter_class_names() - def dispose_engine(self) -> None: """ Dispose the engine and clean up resources. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 48d6de5f54..a4d941ab23 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -202,12 +202,17 @@ def _get_condition_json_match( and the condition should resolve if any element in that array matches the value. Cannot be used with partial_match. value_to_match (str): The string value that must match the extracted JSON property value. - partial_match (bool): Whether to perform a case-insensitive substring match. + partial_match (bool): Whether to perform a substring match. case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. 
Returns: Any: A SQLAlchemy condition for the backend-specific JSON query. + + Raises: + ValueError: If sub_path is provided together with partial_match or case_sensitive """ + if sub_path and (partial_match or case_sensitive): + raise ValueError("sub_path cannot be combined with partial_match or case_sensitive") if sub_path: return self._get_condition_json_array_match( json_column=json_column, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index d9f6909274..f2fc61cebe 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -206,7 +206,7 @@ def _get_condition_json_property_match( json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. property_path (str): The JSON path for the property to match. value_to_match (str): The string value that must match the extracted JSON property value. - partial_match (bool): Whether to perform a case-insensitive substring match. + partial_match (bool): Whether to perform a substring match. case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. 
Returns: @@ -219,7 +219,8 @@ def _get_condition_json_property_match( extracted_value, target = func.lower(raw), value_to_match.lower() if partial_match: - return extracted_value.like(f"%{target}%") + escaped = target.replace("%", "\\%").replace("_", "\\_") + return extracted_value.like(f"%{escaped}%", escape="\\") return extracted_value == target def _get_condition_json_array_match( From 71a87417df4a701211a87bb94cd9f97359e4cde2 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Tue, 7 Apr 2026 17:44:18 -0700 Subject: [PATCH 20/30] ghcp feedback --- pyrit/memory/memory_interface.py | 3 +- pyrit/memory/sqlite_memory.py | 40 +++---------- tests/unit/memory/test_identifier_filters.py | 59 ++++++++++++++++++++ 3 files changed, 70 insertions(+), 32 deletions(-) create mode 100644 tests/unit/memory/test_identifier_filters.py diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index a4d941ab23..133c21541b 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -258,7 +258,7 @@ def _get_condition_json_array_match( *, json_column: InstrumentedAttribute[Any], property_path: str, - sub_path: Optional[str] = None, + sub_path: str | None = None, array_to_match: Sequence[str], ) -> Any: """ @@ -1828,6 +1828,7 @@ def get_scenario_results( Defaults to None. identifier_filters (Optional[set[IdentifierFilter]], optional): A set of IdentifierFilter objects that allows filtering by various target identifier JSON properties. + Defaults to None. Returns: Sequence[ScenarioResult]: A list of ScenarioResult objects that match the specified filters. 
diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index f2fc61cebe..b62f2cf8e1 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -261,22 +261,21 @@ def _get_condition_json_array_match( value_expression = f"LOWER(json_extract(value, :{sp_param}))" if sub_path else "LOWER(value)" conditions = [] + bindparams_dict: dict[str, str] = {pp_param: property_path} + if sub_path: + bindparams_dict[sp_param] = sub_path + for index, match_value in enumerate(array_to_match): mv_param = f"mv_{uid}_{index}" - bind_params: dict[str, str] = { - pp_param: property_path, - mv_param: match_value.lower(), - } - if sub_path: - bind_params[sp_param] = sub_path conditions.append( - text( - f"""EXISTS(SELECT 1 FROM json_each( + f"""EXISTS(SELECT 1 FROM json_each( json_extract("{table_name}".{column_name}, :{pp_param})) WHERE {value_expression} = :{mv_param})""" - ).bindparams(**bind_params) ) - return and_(*conditions) + bindparams_dict[mv_param] = match_value.lower() + + combined = " AND ".join(conditions) + return text(combined).bindparams(**bindparams_dict) def _get_unique_json_array_values( self, @@ -326,27 +325,6 @@ def _get_unique_json_array_values( ).fetchall() return sorted(row[0] for row in rows) - def get_unique_attack_class_names(self) -> list[str]: - """ - SQLite implementation: extract unique class_name values from - the atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique attack class name strings. - """ - return super().get_unique_attack_class_names() - - def get_unique_converter_class_names(self) -> list[str]: - """ - SQLite implementation: extract unique converter class_name values - from the children.attack.children.request_converters array in the - atomic_attack_identifier JSON column. - - Returns: - Sorted list of unique converter class name strings. 
- """ - return super().get_unique_converter_class_names() - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. diff --git a/tests/unit/memory/test_identifier_filters.py b/tests/unit/memory/test_identifier_filters.py new file mode 100644 index 0000000000..21349a0520 --- /dev/null +++ b/tests/unit/memory/test_identifier_filters.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.memory import MemoryInterface +from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType +from pyrit.memory.memory_models import AttackResultEntry + + +@pytest.mark.parametrize( + "sub_path, partial_match, case_sensitive", + [ + ("$.class_name", True, False), + ("$.class_name", False, True), + ("$.class_name", True, True), + ], + ids=["sub_path+partial_match", "sub_path+case_sensitive", "sub_path+both"], +) +def test_identifier_filter_sub_path_with_partial_or_case_sensitive_raises( + sub_path: str, partial_match: bool, case_sensitive: bool +): + with pytest.raises(ValueError, match="Cannot use sub_path with partial_match or case_sensitive"): + IdentifierFilter( + identifier_type=IdentifierType.ATTACK, + property_path="$.children", + value_to_match="test", + sub_path=sub_path, + partial_match=partial_match, + case_sensitive=case_sensitive, + ) + + +def test_identifier_filter_valid_with_sub_path(): + f = IdentifierFilter( + identifier_type=IdentifierType.CONVERTER, + property_path="$", + value_to_match="Base64Converter", + sub_path="$.class_name", + ) + assert f.sub_path == "$.class_name" + assert not f.partial_match + assert not f.case_sensitive + + +def test_build_identifier_filter_conditions_unsupported_type_raises(sqlite_instance: MemoryInterface): + filters = { + IdentifierFilter( + identifier_type=IdentifierType.SCORER, + property_path="$.class_name", + value_to_match="MyScorer", + ) + } + with 
pytest.raises(ValueError, match="does not support identifier type"): + sqlite_instance._build_identifier_filter_conditions( + identifier_filters=filters, + identifier_column_map={IdentifierType.ATTACK: AttackResultEntry.atomic_attack_identifier}, + caller="test_caller", + ) From 93daed24d518dd0fed92714a40f334a5d3141879 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 07:23:21 -0700 Subject: [PATCH 21/30] nits --- pyrit/memory/memory_interface.py | 33 +++++++------------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 133c21541b..508bef9bc8 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -119,29 +119,6 @@ def disable_embedding(self) -> None: """ self.memory_embedding = None - def _get_condition_identifier_property_match( - self, *, identifier_column: Any, identifier_filter: IdentifierFilter - ) -> Any: - """ - Build a SQLAlchemy condition that matches a JSON identifier column against the given filter. - - Args: - identifier_column (Any): The JSON-backed SQLAlchemy column to query. - identifier_filter (IdentifierFilter): The filter specifying the property path, - optional sub-path, value to match, and whether to use partial matching. - - Returns: - Any: A SQLAlchemy condition for the backend-specific JSON query. - """ - return self._get_condition_json_match( - json_column=identifier_column, - property_path=identifier_filter.property_path, - sub_path=identifier_filter.sub_path, - value_to_match=identifier_filter.value_to_match, - partial_match=identifier_filter.partial_match, - case_sensitive=identifier_filter.case_sensitive, - ) - def _build_identifier_filter_conditions( self, *, @@ -174,9 +151,13 @@ def _build_identifier_filter_conditions( f"{identifier_filter.identifier_type!r}. 
Supported: {supported}" ) conditions.append( - self._get_condition_identifier_property_match( - identifier_column=column, - identifier_filter=identifier_filter, + self._get_condition_json_match( + json_column=column, + property_path=identifier_filter.property_path, + sub_path=identifier_filter.sub_path, + value_to_match=identifier_filter.value_to_match, + partial_match=identifier_filter.partial_match, + case_sensitive=identifier_filter.case_sensitive, ) ) return conditions From f7a99be9cb6622882e019c9d74d56697e728c44a Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 07:29:56 -0700 Subject: [PATCH 22/30] escape --- pyrit/memory/azure_sql_memory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 048b580646..526c860664 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -339,9 +339,10 @@ def _get_condition_json_property_match( escaped = target.replace("%", "\\%").replace("_", "\\_") target = f"%{escaped}%" + escape_clause = " ESCAPE '\\'" if partial_match else "" return text( f"""ISJSON("{table_name}".{column_name}) = 1 - AND {json_func}("{table_name}".{column_name}, :{pp_param}) {operator} :{mv_param}""" + AND {json_func}("{table_name}".{column_name}, :{pp_param}) {operator} :{mv_param}{escape_clause}""" ).bindparams( **{ pp_param: property_path, From b7174d4bedb22de1f3f635ff241d907eb17ac816 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 07:44:32 -0700 Subject: [PATCH 23/30] copilot recommendations --- pyrit/memory/memory_interface.py | 24 +++++++------- .../test_interface_attack_results.py | 12 +++---- .../test_interface_prompts.py | 32 +++++++++---------- .../test_interface_scenario_results.py | 12 +++---- .../memory_interface/test_interface_scores.py | 16 +++++----- 5 files changed, 48 insertions(+), 48 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 
508bef9bc8..5ae19e7598 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -122,7 +122,7 @@ def disable_embedding(self) -> None: def _build_identifier_filter_conditions( self, *, - identifier_filters: set[IdentifierFilter], + identifier_filters: Sequence[IdentifierFilter], identifier_column_map: dict[IdentifierType, Any], caller: str, ) -> list[Any]: @@ -130,7 +130,7 @@ def _build_identifier_filter_conditions( Build SQLAlchemy conditions from a set of IdentifierFilters. Args: - identifier_filters (set[IdentifierFilter]): The filters to convert to conditions. + identifier_filters (Sequence[IdentifierFilter]): The filters to convert to conditions. identifier_column_map (dict[IdentifierType, Any]): Mapping from IdentifierType to the JSON-backed SQLAlchemy column that should be queried for that type. caller (str): Name of the calling method, used in error messages. @@ -193,7 +193,7 @@ def _get_condition_json_match( ValueError: If sub_path is provided together with partial_match or case_sensitive """ if sub_path and (partial_match or case_sensitive): - raise ValueError("sub_path cannot be combined with partial_match or case_sensitive") + raise ValueError("Cannot use sub_path with partial_match or case_sensitive") if sub_path: return self._get_condition_json_array_match( json_column=json_column, @@ -226,7 +226,7 @@ def _get_condition_json_property_match( json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. property_path (str): The JSON path for the property to match. value_to_match (str): The string value that must match the extracted JSON property value. - partial_match (bool): Whether to perform a case-insensitive substring match. + partial_match (bool): Whether to perform a substring match. case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False. 
Returns: @@ -537,7 +537,7 @@ def get_scores( score_category: Optional[str] = None, sent_after: Optional[datetime] = None, sent_before: Optional[datetime] = None, - identifier_filters: Optional[set[IdentifierFilter]] = None, + identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[Score]: """ Retrieve a list of Score objects based on the specified filters. @@ -548,7 +548,7 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. - identifier_filters (Optional[set[IdentifierFilter]]): A set of IdentifierFilter objects that + identifier_filters (Optional[Sequence[IdentifierFilter]]): A set of IdentifierFilter objects that allows filtering by various scorer identifier JSON properties. Defaults to None. Returns: @@ -704,7 +704,7 @@ def get_message_pieces( data_type: Optional[str] = None, not_data_type: Optional[str] = None, converted_value_sha256: Optional[Sequence[str]] = None, - identifier_filters: Optional[set[IdentifierFilter]] = None, + identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[MessagePiece]: """ Retrieve a list of MessagePiece objects based on the specified filters. @@ -726,7 +726,7 @@ def get_message_pieces( not_data_type (Optional[str], optional): The data type to exclude. Defaults to None. converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. - identifier_filters (Optional[set[IdentifierFilter]], optional): + identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A set of IdentifierFilter objects that allow filtering by various identifier JSON properties. Defaults to None. 
@@ -1509,7 +1509,7 @@ def get_attack_results( converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, - identifier_filters: Optional[set[IdentifierFilter]] = None, + identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[AttackResult]: """ Retrieve a list of AttackResult objects based on the specified filters. @@ -1537,7 +1537,7 @@ def get_attack_results( labels (Optional[dict[str, str]], optional): A dictionary of memory labels to filter results by. These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. - identifier_filters (Optional[set[IdentifierFilter]], optional): + identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A set of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. @@ -1783,7 +1783,7 @@ def get_scenario_results( labels: Optional[dict[str, str]] = None, objective_target_endpoint: Optional[str] = None, objective_target_model_name: Optional[str] = None, - identifier_filters: Optional[set[IdentifierFilter]] = None, + identifier_filters: Optional[Sequence[IdentifierFilter]] = None, ) -> Sequence[ScenarioResult]: """ Retrieve a list of ScenarioResult objects based on the specified filters. @@ -1807,7 +1807,7 @@ def get_scenario_results( objective_target_model_name (Optional[str], optional): Filter for scenarios where the objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. - identifier_filters (Optional[set[IdentifierFilter]], optional): + identifier_filters (Optional[Sequence[IdentifierFilter]], optional): A set of IdentifierFilter objects that allows filtering by various target identifier JSON properties. Defaults to None. 
diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 6999ff6bb7..03600e3260 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -1363,14 +1363,14 @@ def test_get_attack_results_by_attack_identifier_filter_hash(sqlite_instance: Me # Filter by hash of ar1's attack identifier results = sqlite_instance.get_attack_results( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.hash", value_to_match=ar1.atomic_attack_identifier.hash, partial_match=False, ) - }, + ], ) assert len(results) == 1 assert results[0].conversation_id == "conv_1" @@ -1385,14 +1385,14 @@ def test_get_attack_results_by_attack_identifier_filter_class_name(sqlite_instan # Filter by partial attack class name results = sqlite_instance.get_attack_results( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.children.attack.class_name", value_to_match="Crescendo", partial_match=True, ) - }, + ], ) assert len(results) == 2 assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} @@ -1404,13 +1404,13 @@ def test_get_attack_results_by_attack_identifier_filter_no_match(sqlite_instance sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) results = sqlite_instance.get_attack_results( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ) - }, + ], ) assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 9225d06364..4921da7df7 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ 
b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -1276,28 +1276,28 @@ def test_get_message_pieces_by_attack_identifier_filter(sqlite_instance: MemoryI # Filter by exact attack hash results = sqlite_instance.get_message_pieces( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.hash", value_to_match=attack1.get_identifier().hash, partial_match=False, ) - }, + ], ) assert len(results) == 1 assert results[0].original_value == "Hello 1" # No match results = sqlite_instance.get_message_pieces( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ) - }, + ], ) assert len(results) == 0 @@ -1335,42 +1335,42 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI # Filter by target hash results = sqlite_instance.get_message_pieces( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.hash", value_to_match=target_id_1.hash, partial_match=False, ) - }, + ], ) assert len(results) == 1 assert results[0].original_value == "Hello OpenAI" # Filter by endpoint partial match results = sqlite_instance.get_message_pieces( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.endpoint", value_to_match="openai", partial_match=True, ) - }, + ], ) assert len(results) == 1 assert results[0].original_value == "Hello OpenAI" # No match results = sqlite_instance.get_message_pieces( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.hash", value_to_match="nonexistent", partial_match=False, ) - }, + ], ) assert len(results) == 0 @@ -1412,14 +1412,14 @@ def test_get_message_pieces_by_converter_identifier_filter_with_sub_path(sqlite_ # Filter by converter 
class_name using sub_path (array element matching) results = sqlite_instance.get_message_pieces( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.CONVERTER, property_path="$", sub_path="$.class_name", value_to_match="Base64Converter", ) - }, + ], ) assert len(results) == 2 original_values = {r.original_value for r in results} @@ -1427,27 +1427,27 @@ def test_get_message_pieces_by_converter_identifier_filter_with_sub_path(sqlite_ # Filter by ROT13Converter — only the entry with both converters results = sqlite_instance.get_message_pieces( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.CONVERTER, property_path="$", sub_path="$.class_name", value_to_match="ROT13Converter", ) - }, + ], ) assert len(results) == 1 assert results[0].original_value == "With both converters" # No match results = sqlite_instance.get_message_pieces( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.CONVERTER, property_path="$", sub_path="$.class_name", value_to_match="NonexistentConverter", ) - }, + ], ) assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 32fdb0a7ee..3696705c5a 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -681,14 +681,14 @@ def test_get_scenario_results_by_target_identifier_filter_hash(sqlite_instance: # Filter by target hash results = sqlite_instance.get_scenario_results( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.hash", value_to_match=target_id_1.hash, partial_match=False, ) - }, + ], ) assert len(results) == 1 assert results[0].scenario_identifier.name == "Scenario OpenAI" @@ -727,14 +727,14 @@ def 
test_get_scenario_results_by_target_identifier_filter_endpoint(sqlite_instan # Filter by endpoint partial match results = sqlite_instance.get_scenario_results( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.endpoint", value_to_match="openai", partial_match=True, ) - }, + ], ) assert len(results) == 1 assert results[0].scenario_identifier.name == "Scenario OpenAI" @@ -758,13 +758,13 @@ def test_get_scenario_results_by_target_identifier_filter_no_match(sqlite_instan sqlite_instance.add_scenario_results_to_memory(scenario_results=[scenario1]) results = sqlite_instance.get_scenario_results( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.TARGET, property_path="$.hash", value_to_match="nonexistent_hash", partial_match=False, ) - }, + ], ) assert len(results) == 0 diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 4fbd9bb865..1b2f79f47b 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -262,55 +262,55 @@ def test_get_scores_by_scorer_identifier_filter( # Filter by exact class_name match results = sqlite_instance.get_scores( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.class_name", value_to_match="ScorerAlpha", partial_match=False, ) - }, + ], ) assert len(results) == 1 assert results[0].score_value == "0.9" # Filter by partial class_name match results = sqlite_instance.get_scores( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.class_name", value_to_match="Scorer", partial_match=True, ) - }, + ], ) assert len(results) == 2 # Filter by hash scorer_hash = score_a.scorer_class_identifier.hash results = sqlite_instance.get_scores( - 
identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.hash", value_to_match=scorer_hash, partial_match=False, ) - }, + ], ) assert len(results) == 1 assert results[0].score_value == "0.9" # No match results = sqlite_instance.get_scores( - identifier_filters={ + identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.class_name", value_to_match="NonExistent", partial_match=False, ) - }, + ], ) assert len(results) == 0 From b2a7f4126449caa24e8adb73471fa661eb2df510 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 07:50:59 -0700 Subject: [PATCH 24/30] doc update --- pyrit/memory/memory_interface.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 5ae19e7598..2202521294 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -127,7 +127,7 @@ def _build_identifier_filter_conditions( caller: str, ) -> list[Any]: """ - Build SQLAlchemy conditions from a set of IdentifierFilters. + Build SQLAlchemy conditions from a sequence of IdentifierFilters. Args: identifier_filters (Sequence[IdentifierFilter]): The filters to convert to conditions. @@ -548,7 +548,7 @@ def get_scores( score_category (Optional[str]): The category of the score to filter by. sent_after (Optional[datetime]): Filter for scores sent after this datetime. sent_before (Optional[datetime]): Filter for scores sent before this datetime. - identifier_filters (Optional[Sequence[IdentifierFilter]]): A set of IdentifierFilter objects that + identifier_filters (Optional[Sequence[IdentifierFilter]]): A sequence of IdentifierFilter objects that allows filtering by various scorer identifier JSON properties. Defaults to None. 
Returns: @@ -727,7 +727,7 @@ def get_message_pieces( converted_value_sha256 (Optional[Sequence[str]], optional): A list of SHA256 hashes of converted values. Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): - A set of IdentifierFilter objects that + A sequence of IdentifierFilter objects that allow filtering by various identifier JSON properties. Defaults to None. Returns: @@ -1538,7 +1538,7 @@ def get_attack_results( These labels are associated with the prompts themselves, used for custom tagging and tracking. Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): - A set of IdentifierFilter objects that allows filtering by various attack identifier + A sequence of IdentifierFilter objects that allows filtering by various attack identifier JSON properties. Defaults to None. Returns: @@ -1808,7 +1808,7 @@ def get_scenario_results( objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): - A set of IdentifierFilter objects that allows filtering by various target identifier JSON properties. + A sequence of IdentifierFilter objects that allows filtering by various target identifier JSON properties. Defaults to None. 
Returns: From 61a042046fb5f82e8a650c17ba5272c0f0236df4 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 08:02:24 -0700 Subject: [PATCH 25/30] sequence in test --- tests/unit/memory/test_identifier_filters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/memory/test_identifier_filters.py b/tests/unit/memory/test_identifier_filters.py index 21349a0520..62d9c0f745 100644 --- a/tests/unit/memory/test_identifier_filters.py +++ b/tests/unit/memory/test_identifier_filters.py @@ -44,13 +44,13 @@ def test_identifier_filter_valid_with_sub_path(): def test_build_identifier_filter_conditions_unsupported_type_raises(sqlite_instance: MemoryInterface): - filters = { + filters = [ IdentifierFilter( identifier_type=IdentifierType.SCORER, property_path="$.class_name", value_to_match="MyScorer", ) - } + ] with pytest.raises(ValueError, match="does not support identifier type"): sqlite_instance._build_identifier_filter_conditions( identifier_filters=filters, From 9f22cecc1a919e64f9148f08cf62b116fe1a1ff6 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 08:16:59 -0700 Subject: [PATCH 26/30] doc --- pyrit/memory/memory_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 2202521294..c371d816b8 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1808,7 +1808,7 @@ def get_scenario_results( objective_target_identifier has a model_name attribute containing this value (case-insensitive). Defaults to None. identifier_filters (Optional[Sequence[IdentifierFilter]], optional): - A sequence of IdentifierFilter objects that allows filtering by various target identifier JSON properties. + A sequence of IdentifierFilter objects that allows filtering by identifier JSON properties. Defaults to None. 
Returns: From 899864d1c30008424d30ed00e81aedc5a258030c Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 08:30:28 -0700 Subject: [PATCH 27/30] drop the generic unique value methods. not related to identifier filters --- pyrit/memory/azure_sql_memory.py | 98 ++++++++++++++------------------ pyrit/memory/memory_interface.py | 36 +----------- pyrit/memory/sqlite_memory.py | 86 +++++++++++++--------------- 3 files changed, 83 insertions(+), 137 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 526c860664..5108e85030 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -404,61 +404,6 @@ def _get_condition_json_array_match( combined = " AND ".join(conditions) return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict) - def _get_unique_json_array_values( - self, - *, - json_column: Any, - path_to_array: str, - sub_path: str | None = None, - ) -> list[str]: - """ - Return sorted unique values in an array located at a given path within a JSON object in an Azure SQL DB Column. - - This method performs a database-level query to extract distinct values from a - an array within a JSON-type column. When ``sub_path`` is provided, the distinct values are - extracted from each array item using the sub-path. - - Args: - json_column (Any): The JSON-backed model field to query. - path_to_array (str): The JSON path to the array whose unique values are extracted. - sub_path (str | None): Optional JSON path applied to each array - item before collecting distinct values. - - Returns: - list[str]: A sorted list of unique values in the array. 
- """ - uid = self._uid() - pa_param = f"pa_{uid}" - sp_param = f"sp_{uid}" - table_name = json_column.class_.__tablename__ - column_name = json_column.key - with closing(self.get_session()) as session: - if sub_path is None: - rows = session.execute( - text( - f"""SELECT DISTINCT JSON_VALUE("{table_name}".{column_name}, :{pa_param}) AS value - FROM "{table_name}" - WHERE ISJSON("{table_name}".{column_name}) = 1 - AND JSON_VALUE("{table_name}".{column_name}, :{pa_param}) IS NOT NULL""" - ).bindparams(**{pa_param: path_to_array}) - ).fetchall() - else: - rows = session.execute( - text( - f"""SELECT DISTINCT JSON_VALUE(items.value, :{sp_param}) AS value - FROM "{table_name}" - CROSS APPLY OPENJSON(JSON_QUERY("{table_name}".{column_name}, :{pa_param})) AS items - WHERE ISJSON("{table_name}".{column_name}) = 1 - AND JSON_VALUE(items.value, :{sp_param}) IS NOT NULL""" - ).bindparams( - **{ - pa_param: path_to_array, - sp_param: sub_path, - } - ) - ).fetchall() - return sorted(row[0] for row in rows) - def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any: """ Get the SQL Azure implementation for filtering AttackResults by targeted harm categories. @@ -526,6 +471,49 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) ) + def get_unique_attack_class_names(self) -> list[str]: + """ + Azure SQL implementation: extract unique class_name values from + the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique attack class name strings. 
+ """ + with closing(self.get_session()) as session: + rows = session.execute( + text( + """SELECT DISTINCT JSON_VALUE(atomic_attack_identifier, + '$.children.attack.class_name') AS cls + FROM "AttackResultEntries" + WHERE ISJSON(atomic_attack_identifier) = 1 + AND JSON_VALUE(atomic_attack_identifier, + '$.children.attack.class_name') IS NOT NULL""" + ) + ).fetchall() + return sorted(row[0] for row in rows) + + def get_unique_converter_class_names(self) -> list[str]: + """ + Azure SQL implementation: extract unique converter class_name values + from the children.attack.children.request_converters array + in the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique converter class name strings. + """ + with closing(self.get_session()) as session: + rows = session.execute( + text( + """SELECT DISTINCT JSON_VALUE(c.value, '$.class_name') AS cls + FROM "AttackResultEntries" + CROSS APPLY OPENJSON(JSON_QUERY(atomic_attack_identifier, + '$.children.attack.children.request_converters')) AS c + WHERE ISJSON(atomic_attack_identifier) = 1 + AND JSON_VALUE(c.value, '$.class_name') IS NOT NULL""" + ) + ).fetchall() + return sorted(row[0] for row in rows) + def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: """ Azure SQL implementation: lightweight aggregate stats per conversation. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index c371d816b8..20608cb8c7 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -257,31 +257,6 @@ def _get_condition_json_array_match( Any: A database-specific SQLAlchemy condition. """ - @abc.abstractmethod - def _get_unique_json_array_values( - self, - *, - json_column: Any, - path_to_array: str, - sub_path: str | None = None, - ) -> list[str]: - """ - Return sorted unique values in an array located at a given path within a JSON object. 
- - This method performs a database-level query to extract distinct values from a - an array within a JSON-type column. When ``sub_path`` is provided, the distinct values are - extracted from each array item using the sub-path. - - Args: - json_column (Any): The JSON-backed model field to query. - path_to_array (str): The JSON path to the array whose unique values are extracted. - sub_path (str | None): Optional JSON path applied to each array - item before collecting distinct values. - - Returns: - list[str]: A sorted list of unique values in the array. - """ - @abc.abstractmethod def get_all_embeddings(self) -> Sequence[EmbeddingDataEntry]: """ @@ -452,6 +427,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: Database-specific SQLAlchemy condition. """ + @abc.abstractmethod def get_unique_attack_class_names(self) -> list[str]: """ Return sorted unique attack class names from all stored attack results. @@ -462,11 +438,8 @@ def get_unique_attack_class_names(self) -> list[str]: Returns: Sorted list of unique attack class name strings. """ - return self._get_unique_json_array_values( - json_column=AttackResultEntry.atomic_attack_identifier, - path_to_array="$.children.attack.class_name", - ) + @abc.abstractmethod def get_unique_converter_class_names(self) -> list[str]: """ Return sorted unique converter class names used across all attack results. @@ -477,11 +450,6 @@ def get_unique_converter_class_names(self) -> list[str]: Returns: Sorted list of unique converter class name strings. 
""" - return self._get_unique_json_array_values( - json_column=AttackResultEntry.atomic_attack_identifier, - path_to_array="$.children.attack.children.request_converters", - sub_path="$.class_name", - ) @abc.abstractmethod def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, "ConversationStats"]: diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index b62f2cf8e1..0ea7dc6cab 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -277,54 +277,6 @@ def _get_condition_json_array_match( combined = " AND ".join(conditions) return text(combined).bindparams(**bindparams_dict) - def _get_unique_json_array_values( - self, - *, - json_column: Any, - path_to_array: str, - sub_path: str | None = None, - ) -> list[str]: - """ - Return sorted unique values in an array located at a given path within a JSON object in a SQLite DB Column. - - This method performs a database-level query to extract distinct values from a - an array within a JSON-type column. When ``sub_path`` is provided, the distinct values are - extracted from each array item using the sub-path. - - Args: - json_column (Any): The JSON-backed model field to query. - path_to_array (str): The JSON path to the array whose unique values are extracted. - sub_path (str | None): Optional JSON path applied to each array - item before collecting distinct values. - - Returns: - list[str]: A sorted list of unique values in the array. 
- """ - with closing(self.get_session()) as session: - if sub_path is None: - property_expr = func.json_extract(json_column, path_to_array) - rows = session.query(property_expr).filter(property_expr.isnot(None)).distinct().all() - else: - uid = self._uid() - pa_param = f"path_to_array_{uid}" - sp_param = f"sub_path_{uid}" - table_name = json_column.class_.__tablename__ - column_name = json_column.key - rows = session.execute( - text( - f"""SELECT DISTINCT json_extract(j.value, :{sp_param}) AS value - FROM "{table_name}", - json_each(json_extract("{table_name}".{column_name}, :{pa_param})) AS j - WHERE json_extract(j.value, :{sp_param}) IS NOT NULL""" - ).bindparams( - **{ - pa_param: path_to_array, - sp_param: sub_path, - } - ) - ).fetchall() - return sorted(row[0] for row in rows) - def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None: """ Insert a list of message pieces into the memory storage. @@ -652,6 +604,44 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) return labels_subquery # noqa: RET504 + def get_unique_attack_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique class_name values from + the atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique attack class name strings. + """ + with closing(self.get_session()) as session: + class_name_expr = func.json_extract( + AttackResultEntry.atomic_attack_identifier, "$.children.attack.class_name" + ) + rows = session.query(class_name_expr).filter(class_name_expr.isnot(None)).distinct().all() + return sorted(row[0] for row in rows) + + def get_unique_converter_class_names(self) -> list[str]: + """ + SQLite implementation: extract unique converter class_name values + from the children.attack.children.request_converters array in the + atomic_attack_identifier JSON column. + + Returns: + Sorted list of unique converter class name strings. 
+ """ + with closing(self.get_session()) as session: + rows = session.execute( + text( + """SELECT DISTINCT json_extract(j.value, '$.class_name') AS cls + FROM "AttackResultEntries", + json_each( + json_extract("AttackResultEntries".atomic_attack_identifier, + '$.children.attack.children.request_converters') + ) AS j + WHERE cls IS NOT NULL""" + ) + ).fetchall() + return sorted(row[0] for row in rows) + def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: """ SQLite implementation: lightweight aggregate stats per conversation. From ce8fd543a567c4d27d78d75e7b4bc1c348dbf1ee Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 08:47:28 -0700 Subject: [PATCH 28/30] renames --- pyrit/memory/azure_sql_memory.py | 10 ++++---- pyrit/memory/identifier_filters.py | 19 ++++++++------- pyrit/memory/memory_interface.py | 24 +++++++++---------- pyrit/memory/sqlite_memory.py | 12 +++++----- .../test_interface_prompts.py | 10 ++++---- tests/unit/memory/test_identifier_filters.py | 18 +++++++------- 6 files changed, 48 insertions(+), 45 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 5108e85030..0e8777a0aa 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -355,7 +355,7 @@ def _get_condition_json_array_match( *, json_column: Any, property_path: str, - sub_path: str | None = None, + array_element_path: str | None = None, array_to_match: Sequence[str], ) -> Any: """ @@ -364,7 +364,7 @@ def _get_condition_json_array_match( Args: json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. property_path (str): The JSON path for the target array. - sub_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. 
array_to_match (Sequence[str]): The array that must match the extracted JSON array values. For a match, ALL values in this array must be present in the JSON array. If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. @@ -385,12 +385,12 @@ def _get_condition_json_array_match( OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) = '[]')""" ).bindparams(**{pp_param: property_path}) - value_expression = f"LOWER(JSON_VALUE(value, :{sp_param}))" if sub_path else "LOWER(value)" + value_expression = f"LOWER(JSON_VALUE(value, :{sp_param}))" if array_element_path else "LOWER(value)" conditions = [] bindparams_dict: dict[str, str] = {pp_param: property_path} - if sub_path: - bindparams_dict[sp_param] = sub_path + if array_element_path: + bindparams_dict[sp_param] = array_element_path for index, match_value in enumerate(array_to_match): mv_param = f"mv_{uid}_{index}" diff --git a/pyrit/memory/identifier_filters.py b/pyrit/memory/identifier_filters.py index c0e545ec41..357625bbb6 100644 --- a/pyrit/memory/identifier_filters.py +++ b/pyrit/memory/identifier_filters.py @@ -22,18 +22,19 @@ class IdentifierFilter: Attributes: identifier_type: The type of identifier column to filter on. property_path: The JSON path for the property to match. - sub_path: An optional JSON path that indicates the property at property_path is an array - and the condition should resolve if any element in that array matches the value. - Cannot be used with partial_match. + array_element_path : An optional JSON path that indicates the property at property_path is an array + and the condition should resolve if the value at array_element_path matches the target + for any element in that array. Cannot be used with partial_match or case_sensitive. value_to_match: The string value that must match the extracted JSON property value. - partial_match: Whether to perform a substring match. Cannot be used with sub_path. 
- case_sensitive: Whether the match should be case-sensitive. Defaults to False. + partial_match: Whether to perform a substring match. Cannot be used with array_element_path or case_sensitive. + case_sensitive: Whether the match should be case-sensitive. + Cannot be used with array_element_path or partial_match. """ identifier_type: IdentifierType property_path: str value_to_match: str - sub_path: str | None = None + array_element_path: str | None = None partial_match: bool = False case_sensitive: bool = False @@ -44,5 +45,7 @@ def __post_init__(self) -> None: Raises: ValueError: If the filter configuration is not valid. """ - if self.sub_path and (self.partial_match or self.case_sensitive): - raise ValueError("Cannot use sub_path with partial_match or case_sensitive") + if self.array_element_path and (self.partial_match or self.case_sensitive): + raise ValueError("Cannot use array_element_path with partial_match or case_sensitive") + if self.partial_match and self.case_sensitive: + raise ValueError("case_sensitive is not reliably supported with partial_match across all backends") diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 20608cb8c7..828ab7c96f 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -154,7 +154,7 @@ def _build_identifier_filter_conditions( self._get_condition_json_match( json_column=column, property_path=identifier_filter.property_path, - sub_path=identifier_filter.sub_path, + array_element_path=identifier_filter.array_element_path, value_to_match=identifier_filter.value_to_match, partial_match=identifier_filter.partial_match, case_sensitive=identifier_filter.case_sensitive, @@ -167,19 +167,19 @@ def _get_condition_json_match( *, json_column: InstrumentedAttribute[Any], property_path: str, - sub_path: str | None = None, + array_element_path: str | None = None, value_to_match: str, partial_match: bool = False, case_sensitive: bool = False, ) -> Any: """ Return a 
database-specific condition for matching a value at a given path within a JSON object - or within items of a JSON array if sub_path is provided. + or within items of a JSON array if array_element_path is provided. Args: json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query. property_path (str): The JSON path for the property to match. - sub_path (str | None): An optional JSON path that indicates property at property_path is an array + array_element_path (str | None): An optional JSON path that indicates property at property_path is an array and the condition should resolve if any element in that array matches the value. Cannot be used with partial_match. value_to_match (str): The string value that must match the extracted JSON property value. @@ -190,15 +190,15 @@ def _get_condition_json_match( Any: A SQLAlchemy condition for the backend-specific JSON query. Raises: - ValueError: If sub_path is provided together with partial_match or case_sensitive + ValueError: If array_element_path is provided together with partial_match or case_sensitive """ - if sub_path and (partial_match or case_sensitive): - raise ValueError("Cannot use sub_path with partial_match or case_sensitive") - if sub_path: + if array_element_path and (partial_match or case_sensitive): + raise ValueError("Cannot use array_element_path with partial_match or case_sensitive") + if array_element_path: return self._get_condition_json_array_match( json_column=json_column, property_path=property_path, - sub_path=sub_path, + array_element_path=array_element_path, array_to_match=[value_to_match], ) return self._get_condition_json_property_match( @@ -239,7 +239,7 @@ def _get_condition_json_array_match( *, json_column: InstrumentedAttribute[Any], property_path: str, - sub_path: str | None = None, + array_element_path: str | None = None, array_to_match: Sequence[str], ) -> Any: """ @@ -248,7 +248,7 @@ def _get_condition_json_array_match( Args: json_column (InstrumentedAttribute[Any]): 
The JSON-backed SQLAlchemy field to query. property_path (str): The JSON path for the target array. - sub_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. For a match, ALL values in this array must be present in the JSON array. If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. @@ -1547,7 +1547,7 @@ def get_attack_results( self._get_condition_json_array_match( json_column=AttackResultEntry.atomic_attack_identifier, property_path="$.children.attack.children.request_converters", - sub_path="$.class_name", + array_element_path="$.class_name", array_to_match=converter_classes, ) ) diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 0ea7dc6cab..ce65bb2381 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -228,7 +228,7 @@ def _get_condition_json_array_match( *, json_column: Any, property_path: str, - sub_path: str | None = None, + array_element_path: str | None = None, array_to_match: Sequence[str], ) -> Any: """ @@ -237,7 +237,7 @@ def _get_condition_json_array_match( Args: json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query. property_path (str): The JSON path for the target array. - sub_path (Optional[str]): An optional JSON path applied to each array item before matching. + array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. For a match, ALL values in this array must be present in the JSON array. If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. 
@@ -257,13 +257,13 @@ def _get_condition_json_array_match( table_name = json_column.class_.__tablename__ column_name = json_column.key pp_param = f"property_path_{uid}" - sp_param = f"sub_path_{uid}" - value_expression = f"LOWER(json_extract(value, :{sp_param}))" if sub_path else "LOWER(value)" + sp_param = f"array_element_path_{uid}" + value_expression = f"LOWER(json_extract(value, :{sp_param}))" if array_element_path else "LOWER(value)" conditions = [] bindparams_dict: dict[str, str] = {pp_param: property_path} - if sub_path: - bindparams_dict[sp_param] = sub_path + if array_element_path: + bindparams_dict[sp_param] = array_element_path for index, match_value in enumerate(array_to_match): mv_param = f"mv_{uid}_{index}" diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index 4921da7df7..e85ce02739 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -1375,7 +1375,7 @@ def test_get_message_pieces_by_target_identifier_filter(sqlite_instance: MemoryI assert len(results) == 0 -def test_get_message_pieces_by_converter_identifier_filter_with_sub_path(sqlite_instance: MemoryInterface): +def test_get_message_pieces_by_converter_identifier_filter_with_array_element_path(sqlite_instance: MemoryInterface): converter_a = ComponentIdentifier( class_name="Base64Converter", class_module="pyrit.prompt_converter", @@ -1410,13 +1410,13 @@ def test_get_message_pieces_by_converter_identifier_filter_with_sub_path(sqlite_ sqlite_instance._insert_entries(entries=entries) - # Filter by converter class_name using sub_path (array element matching) + # Filter by converter class_name using array_element_path (array element matching) results = sqlite_instance.get_message_pieces( identifier_filters=[ IdentifierFilter( identifier_type=IdentifierType.CONVERTER, property_path="$", - sub_path="$.class_name", + 
array_element_path="$.class_name", value_to_match="Base64Converter", ) ], @@ -1431,7 +1431,7 @@ def test_get_message_pieces_by_converter_identifier_filter_with_sub_path(sqlite_ IdentifierFilter( identifier_type=IdentifierType.CONVERTER, property_path="$", - sub_path="$.class_name", + array_element_path="$.class_name", value_to_match="ROT13Converter", ) ], @@ -1445,7 +1445,7 @@ def test_get_message_pieces_by_converter_identifier_filter_with_sub_path(sqlite_ IdentifierFilter( identifier_type=IdentifierType.CONVERTER, property_path="$", - sub_path="$.class_name", + array_element_path="$.class_name", value_to_match="NonexistentConverter", ) ], diff --git a/tests/unit/memory/test_identifier_filters.py b/tests/unit/memory/test_identifier_filters.py index 62d9c0f745..8316ef08ba 100644 --- a/tests/unit/memory/test_identifier_filters.py +++ b/tests/unit/memory/test_identifier_filters.py @@ -9,36 +9,36 @@ @pytest.mark.parametrize( - "sub_path, partial_match, case_sensitive", + "array_element_path, partial_match, case_sensitive", [ ("$.class_name", True, False), ("$.class_name", False, True), ("$.class_name", True, True), ], - ids=["sub_path+partial_match", "sub_path+case_sensitive", "sub_path+both"], + ids=["array_element_path+partial_match", "array_element_path+case_sensitive", "array_element_path+both"], ) -def test_identifier_filter_sub_path_with_partial_or_case_sensitive_raises( - sub_path: str, partial_match: bool, case_sensitive: bool +def test_identifier_filter_array_element_path_with_partial_or_case_sensitive_raises( + array_element_path: str, partial_match: bool, case_sensitive: bool ): - with pytest.raises(ValueError, match="Cannot use sub_path with partial_match or case_sensitive"): + with pytest.raises(ValueError, match="Cannot use array_element_path with partial_match or case_sensitive"): IdentifierFilter( identifier_type=IdentifierType.ATTACK, property_path="$.children", value_to_match="test", - sub_path=sub_path, + array_element_path=array_element_path, 
partial_match=partial_match, case_sensitive=case_sensitive, ) -def test_identifier_filter_valid_with_sub_path(): +def test_identifier_filter_valid_with_array_element_path(): f = IdentifierFilter( identifier_type=IdentifierType.CONVERTER, property_path="$", value_to_match="Base64Converter", - sub_path="$.class_name", + array_element_path="$.class_name", ) - assert f.sub_path == "$.class_name" + assert f.array_element_path == "$.class_name" assert not f.partial_match assert not f.case_sensitive From 775a7a56e4ab93b19614e49b2a483ef38edac641 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 08:57:30 -0700 Subject: [PATCH 29/30] nits --- pyrit/memory/azure_sql_memory.py | 10 +++++----- pyrit/memory/memory_interface.py | 2 ++ pyrit/memory/sqlite_memory.py | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 0e8777a0aa..d9a641ac60 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -12,7 +12,7 @@ from sqlalchemy import and_, create_engine, event, exists, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import joinedload, sessionmaker +from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker from sqlalchemy.orm.session import Session from sqlalchemy.sql.expression import TextClause @@ -308,7 +308,7 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) def _get_condition_json_property_match( self, *, - json_column: Any, + json_column: InstrumentedAttribute[Any], property_path: str, value_to_match: str, partial_match: bool = False, @@ -353,7 +353,7 @@ def _get_condition_json_property_match( def _get_condition_json_array_match( self, *, - json_column: Any, + json_column: InstrumentedAttribute[Any], property_path: str, array_element_path: str | None = None, array_to_match: Sequence[str], @@ -366,8 +366,8 @@ def 
_get_condition_json_array_match( property_path (str): The JSON path for the target array. array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. - For a match, ALL values in this array must be present in the JSON array. - If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition matches only if the target is also an empty array or None. Returns: Any: A database-specific SQLAlchemy condition. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 828ab7c96f..e52d869053 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -194,6 +194,8 @@ def _get_condition_json_match( """ if array_element_path and (partial_match or case_sensitive): raise ValueError("Cannot use array_element_path with partial_match or case_sensitive") + if partial_match and case_sensitive: + raise ValueError("case_sensitive is not reliably supported with partial_match across all backends") if array_element_path: return self._get_condition_json_array_match( json_column=json_column, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index ce65bb2381..8b70322af4 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -13,7 +13,7 @@ from sqlalchemy import and_, create_engine, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import joinedload, sessionmaker +from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker from sqlalchemy.orm.session import Session from sqlalchemy.pool import StaticPool from sqlalchemy.sql.expression import TextClause @@ -193,7 +193,7 @@ def _get_seed_metadata_conditions(self, 
*, metadata: dict[str, Union[str, int]]) def _get_condition_json_property_match( self, *, - json_column: Any, + json_column: InstrumentedAttribute[Any], property_path: str, value_to_match: str, partial_match: bool = False, @@ -226,7 +226,7 @@ def _get_condition_json_property_match( def _get_condition_json_array_match( self, *, - json_column: Any, + json_column: InstrumentedAttribute[Any], property_path: str, array_element_path: str | None = None, array_to_match: Sequence[str], From 4113b81d82540f2f6920edb4b6dd313fdc69f135 Mon Sep 17 00:00:00 2001 From: Behnam Ousat Date: Wed, 8 Apr 2026 09:05:12 -0700 Subject: [PATCH 30/30] docs --- pyrit/memory/memory_interface.py | 4 ++-- pyrit/memory/sqlite_memory.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index e52d869053..b60c5ecaeb 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -252,8 +252,8 @@ def _get_condition_json_array_match( property_path (str): The JSON path for the target array. array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. - For a match, ALL values in this array must be present in the JSON array. - If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition matches only if the target is also an empty array or None. Returns: Any: A database-specific SQLAlchemy condition. diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 8b70322af4..942c4384d2 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -239,8 +239,8 @@ def _get_condition_json_array_match( property_path (str): The JSON path for the target array. 
array_element_path (Optional[str]): An optional JSON path applied to each array item before matching. array_to_match (Sequence[str]): The array that must match the extracted JSON array values. - For a match, ALL values in this array must be present in the JSON array. - If `array_to_match` is empty, the condition must match only if the target is also an empty array or None. + For a match, ALL values in this array must be present in the JSON array. + If `array_to_match` is empty, the condition matches only if the target is also an empty array or None. Returns: Any: A database-specific SQLAlchemy condition.