Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyrit/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pyrit.memory.azure_sql_memory import AzureSQLMemory
from pyrit.memory.central_memory import CentralMemory
from pyrit.memory.identifier_filters import IdentifierFilter, IdentifierType
from pyrit.memory.memory_embedding import MemoryEmbedding
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.memory.memory_interface import MemoryInterface
Expand All @@ -26,4 +27,6 @@
"MemoryExporter",
"PromptMemoryEntry",
"SeedEntry",
"IdentifierFilter",
"IdentifierType",
]
214 changes: 102 additions & 112 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy import and_, create_engine, event, exists, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload, sessionmaker
from sqlalchemy.orm import InstrumentedAttribute, joinedload, sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import TextClause

Expand Down Expand Up @@ -250,22 +250,6 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str
condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()})
return [condition]

def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any:
    """
    Build the Azure SQL condition restricting message pieces to a single attack.

    Reads the attack hash out of the attack_identifier JSON column via the
    Azure SQL JSON_VALUE() function, guarded by ISJSON() so rows holding
    malformed JSON are skipped rather than erroring.

    Args:
        attack_id (str): The attack identifier to filter by.

    Returns:
        Any: SQLAlchemy text condition with bound parameter.
    """
    clause = text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id")
    return clause.bindparams(json_id=str(attack_id))

def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]:
"""
Generate SQL conditions for filtering by prompt metadata.
Expand Down Expand Up @@ -321,6 +305,105 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
"""
return self._get_metadata_conditions(prompt_metadata=metadata)[0]

def _get_condition_json_property_match(
    self,
    *,
    json_column: InstrumentedAttribute[Any],
    property_path: str,
    value_to_match: str,
    partial_match: bool = False,
    case_sensitive: bool = False,
) -> Any:
    """
    Return an Azure SQL DB condition for matching a value at a given path within a JSON object.

    Args:
        json_column (InstrumentedAttribute[Any]): The JSON-backed model field to query.
        property_path (str): The JSON path for the property to match.
        value_to_match (str): The string value that must match the extracted JSON property value.
        partial_match (bool): Whether to perform a substring match.
        case_sensitive (bool): Whether the match should be case-sensitive. Defaults to False.

    Returns:
        Any: A SQLAlchemy condition for the backend-specific JSON query.
    """
    # Unique suffix keeps bind-parameter names distinct when several
    # conditions built by this helper are combined into one statement.
    uid = self._uid()
    table_name = json_column.class_.__tablename__
    column_name = json_column.key
    pp_param = f"pp_{uid}"
    mv_param = f"mv_{uid}"

    # Extract the JSON property, wrapping the WHOLE extraction in LOWER()
    # for case-insensitive matching. (The previous version emitted the
    # invalid SQL `LOWER(JSON_VALUE)(col, :pp)` instead of
    # `LOWER(JSON_VALUE(col, :pp))`.)
    json_expr = f'JSON_VALUE("{table_name}".{column_name}, :{pp_param})'
    if not case_sensitive:
        json_expr = f"LOWER({json_expr})"

    operator = "LIKE" if partial_match else "="
    target = value_to_match if case_sensitive else value_to_match.lower()
    if partial_match:
        # Escape every character T-SQL LIKE treats specially: the escape
        # character itself, '%', '_' and '[' (bracket character classes).
        escaped = target.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_").replace("[", "\\[")
        target = f"%{escaped}%"

    escape_clause = " ESCAPE '\\'" if partial_match else ""
    return text(
        f"""ISJSON("{table_name}".{column_name}) = 1
        AND {json_expr} {operator} :{mv_param}{escape_clause}"""
    ).bindparams(
        **{
            pp_param: property_path,
            mv_param: target,
        }
    )

def _get_condition_json_array_match(
    self,
    *,
    json_column: InstrumentedAttribute[Any],
    property_path: str,
    array_element_path: str | None = None,
    array_to_match: Sequence[str],
) -> Any:
    """
    Return an Azure SQL DB condition for matching an array at a given path within a JSON object.

    Args:
        json_column (InstrumentedAttribute[Any]): The JSON-backed SQLAlchemy field to query.
        property_path (str): The JSON path for the target array.
        array_element_path (Optional[str]): An optional JSON path applied to each array item before matching.
        array_to_match (Sequence[str]): The array that must match the extracted JSON array values.
            For a match, ALL values in this array must be present in the JSON array.
            If `array_to_match` is empty, the condition matches only if the target is also an empty array or None.

    Returns:
        Any: A database-specific SQLAlchemy condition.
    """
    # Unique suffix keeps bind-parameter names distinct when several
    # conditions built by this helper are combined in one statement.
    uid = self._uid()
    table_name = json_column.class_.__tablename__
    column_name = json_column.key
    pp_param = f"pp_{uid}"  # bind name for property_path
    sp_param = f"sp_{uid}"  # bind name for array_element_path

    # Empty match list means "no values expected": accept a NULL column,
    # a missing/NULL array at the path, or a literal empty JSON array.
    if len(array_to_match) == 0:
        return text(
            f"""("{table_name}".{column_name} IS NULL
            OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) IS NULL
            OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) = '[]')"""
        ).bindparams(**{pp_param: property_path})

    # OPENJSON exposes each array element as the column `value`; drill into it
    # with JSON_VALUE when a per-element path was supplied. Matching is
    # case-insensitive on both sides (LOWER here, .lower() on the targets).
    value_expression = f"LOWER(JSON_VALUE(value, :{sp_param}))" if array_element_path else "LOWER(value)"

    conditions = []
    bindparams_dict: dict[str, str] = {pp_param: property_path}
    if array_element_path:
        bindparams_dict[sp_param] = array_element_path

    # One EXISTS(...) clause per required value; AND-ing them below enforces
    # that ALL requested values are present in the stored JSON array.
    for index, match_value in enumerate(array_to_match):
        mv_param = f"mv_{uid}_{index}"
        conditions.append(
            f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("{table_name}".{column_name},
            :{pp_param}))
            WHERE {value_expression} = :{mv_param})"""
        )
        bindparams_dict[mv_param] = match_value.lower()

    combined = " AND ".join(conditions)
    # ISJSON guard skips rows whose column does not contain valid JSON.
    return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict)

def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
"""
Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.
Expand Down Expand Up @@ -388,67 +471,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
)
)

def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any:
    """
    Build the Azure SQL condition for filtering AttackResults by attack class.

    Applies JSON_VALUE() to the atomic_attack_identifier JSON column, guarded
    by ISJSON() so rows with malformed JSON are excluded.

    Args:
        attack_class (str): Exact attack class name to match.

    Returns:
        Any: SQLAlchemy text condition with bound parameter.
    """
    clause = text(
        """ISJSON("AttackResultEntries".atomic_attack_identifier) = 1
        AND JSON_VALUE("AttackResultEntries".atomic_attack_identifier,
        '$.children.attack.class_name') = :attack_class"""
    )
    return clause.bindparams(attack_class=attack_class)

def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any:
    """
    Azure SQL implementation for filtering AttackResults by converter classes.

    Uses JSON_VALUE()/JSON_QUERY()/OPENJSON() on the atomic_attack_identifier
    JSON column.

    When converter_classes is empty, matches attacks with no converters.
    When non-empty, uses OPENJSON() to check all specified classes are present
    (AND logic, case-insensitive).

    Args:
        converter_classes (Sequence[str]): List of converter class names. Empty list means no converters.

    Returns:
        Any: SQLAlchemy combined condition with bound parameters.
    """
    if len(converter_classes) == 0:
        # Explicitly "no converters": match attacks where the converter list
        # is absent, null, or empty in the stored JSON.
        return text(
            """("AttackResultEntries".atomic_attack_identifier IS NULL
            OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier,
            '$.children.attack.children.request_converters') IS NULL
            OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier,
            '$.children.attack.children.request_converters') = '[]')"""
        )

    # One EXISTS(...) clause per requested class; AND-ing them below requires
    # every class to appear among the stored request converters. Comparison is
    # case-insensitive (LOWER on the stored value, .lower() on the parameter).
    conditions = []
    bindparams_dict: dict[str, str] = {}
    for i, cls in enumerate(converter_classes):
        param_name = f"conv_cls_{i}"
        conditions.append(
            f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".atomic_attack_identifier,
            '$.children.attack.children.request_converters'))
            WHERE LOWER(JSON_VALUE(value, '$.class_name')) = :{param_name})"""
        )
        bindparams_dict[param_name] = cls.lower()

    combined = " AND ".join(conditions)
    # ISJSON guard skips rows whose identifier column is not valid JSON.
    return text(f"""ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 AND {combined}""").bindparams(
        **bindparams_dict
    )

def get_unique_attack_class_names(self) -> list[str]:
"""
Azure SQL implementation: extract unique class_name values from
Expand Down Expand Up @@ -593,44 +615,12 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any
conditions.append(condition)
return and_(*conditions)

def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> TextClause:
    """
    Build the Azure SQL condition for filtering ScenarioResults by target endpoint.

    Uses the Azure SQL JSON_VALUE() function on the objective_target_identifier
    JSON column, guarded by ISJSON().

    Args:
        endpoint (str): The endpoint URL substring to filter by (case-insensitive).

    Returns:
        Any: SQLAlchemy text condition with bound parameter.
    """
    # Substring match: wrap the lowered needle in LIKE wildcards.
    pattern = f"%{endpoint.lower()}%"
    clause = text(
        """ISJSON(objective_target_identifier) = 1
        AND LOWER(JSON_VALUE(objective_target_identifier, '$.endpoint')) LIKE :endpoint"""
    )
    return clause.bindparams(endpoint=pattern)

def _get_scenario_result_target_model_condition(self, *, model_name: str) -> TextClause:
    """
    Build the Azure SQL condition for filtering ScenarioResults by target model name.

    Uses the Azure SQL JSON_VALUE() function on the objective_target_identifier
    JSON column, guarded by ISJSON().

    Args:
        model_name (str): The model name substring to filter by (case-insensitive).

    Returns:
        Any: SQLAlchemy text condition with bound parameter.
    """
    # Substring match: wrap the lowered needle in LIKE wildcards.
    pattern = f"%{model_name.lower()}%"
    clause = text(
        """ISJSON(objective_target_identifier) = 1
        AND LOWER(JSON_VALUE(objective_target_identifier, '$.model_name')) LIKE :model_name"""
    )
    return clause.bindparams(model_name=pattern)

def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None:
    """
    Persist the given message pieces to the memory storage.

    Each piece is wrapped in a PromptMemoryEntry before insertion.

    Args:
        message_pieces (Sequence[MessagePiece]): A sequence of MessagePiece instances to be added.
    """
    entries = [PromptMemoryEntry(entry=piece) for piece in message_pieces]
    self._insert_entries(entries=entries)

Expand Down
51 changes: 51 additions & 0 deletions pyrit/memory/identifier_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from enum import Enum


class IdentifierType(Enum):
"""Enumeration of supported identifier types for filtering."""

ATTACK = "attack"
TARGET = "target"
SCORER = "scorer"
CONVERTER = "converter"


@dataclass(frozen=True)
class IdentifierFilter:
"""
Immutable filter definition for matching JSON-backed identifier properties.

Attributes:
identifier_type: The type of identifier column to filter on.
property_path: The JSON path for the property to match.
array_element_path : An optional JSON path that indicates the property at property_path is an array
and the condition should resolve if the value at array_element_path matches the target
for any element in that array. Cannot be used with partial_match or case_sensitive.
value_to_match: The string value that must match the extracted JSON property value.
partial_match: Whether to perform a substring match. Cannot be used with array_element_path or case_sensitive.
case_sensitive: Whether the match should be case-sensitive.
Cannot be used with array_element_path or partial_match.
"""

identifier_type: IdentifierType
property_path: str
value_to_match: str
array_element_path: str | None = None
partial_match: bool = False
case_sensitive: bool = False

def __post_init__(self) -> None:
"""
Validate the filter configuration.

Raises:
ValueError: If the filter configuration is not valid.
"""
if self.array_element_path and (self.partial_match or self.case_sensitive):
raise ValueError("Cannot use array_element_path with partial_match or case_sensitive")
if self.partial_match and self.case_sensitive:
raise ValueError("case_sensitive is not reliably supported with partial_match across all backends")
Loading
Loading