Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrit/setup/configuration_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class ConfigurationLoader(YamlLoadable):
env_files: List of environment file paths to load.
None means "use defaults (.env, .env.local)", [] means "load nothing".
silent: Whether to suppress initialization messages.
operator: Name for the current operator, e.g. a team or username.
operation: Name for the current operation.
Example YAML configuration:
memory_db_type: sqlite
Expand Down
45 changes: 45 additions & 0 deletions pyrit/setup/initializers/airt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
AIRT configuration including converters, scorers, and targets using Azure OpenAI.
"""

import json
import os
from collections.abc import Callable
from pathlib import Path

import yaml

from pyrit.auth import get_azure_openai_auth, get_azure_token_provider
from pyrit.common.apply_defaults import set_default_value, set_global_variable
Expand Down Expand Up @@ -43,12 +47,15 @@ class AIRTInitializer(PyRITInitializer):
- Converter targets with Azure OpenAI configuration
- Composite harm and objective scorers
- Adversarial target configurations for attacks
- Use of an Azure SQL database

Required Environment Variables:
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI endpoint for converters and targets
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL: Azure OpenAI model name for converters and targets
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI endpoint for scoring
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2: Azure OpenAI model name for scoring
- AZURE_SQL_DB_CONNECTION_STRING: Azure SQL database connection string
- AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: Azure SQL database location

Optional Environment Variables:
- AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY: API key for converter endpoint. If not set, Entra ID auth is used.
Expand Down Expand Up @@ -90,6 +97,8 @@ def required_env_vars(self) -> list[str]:
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2",
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2",
"AZURE_CONTENT_SAFETY_API_ENDPOINT",
"AZURE_SQL_DB_CONNECTION_STRING",
"AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL",
]

async def initialize_async(self) -> None:
Expand All @@ -102,6 +111,9 @@ async def initialize_async(self) -> None:
3. Adversarial target configurations
4. Default values for all attack types
"""
# Ensure op_name, username, and email are populated from GLOBAL_MEMORY_LABELS.
self._validate_operation_fields()

# Get environment variables (validated by validate() method)
converter_endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT")
converter_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL")
Expand Down Expand Up @@ -255,3 +267,36 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name:
parameter_name="attack_adversarial_config",
value=adversarial_config,
)

def _validate_operation_fields(self) -> None:
"""
Check that mandatory global memory labels (operation, operator)
are populated.

Raises:
ValueError: If mandatory global memory labels are missing.
"""
config_path = Path.home() / ".pyrit" / ".pyrit_conf"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you reference DEFAULT_CONFIG_PATH in path.py

with open(config_path) as f:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure how configuration loader and initializers tie into each other, but it sounds like we have configuration_loader to avoid parsing yml directly? in that case can we use that to resolve these lables?

( skimming throught config loader, it looks like config loader references an initializer, which made me think: maybe moving this whole validation function to the config loader makes more sense? but the other validation on require_env_vars is happening here, so maybe this is a good place for label validation too ... hence why I don't have a strong opinion)

data = yaml.load(f, Loader=yaml.SafeLoader)

if "operator" not in data:
raise ValueError(
"Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)

if "operation" not in data:
raise ValueError(
"Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer."
)

raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS")
labels = dict(json.loads(raw_labels)) if raw_labels else {}

if "username" not in labels:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems a bit odd that the fields in .pyrit_conf and GLOBAL_MEMORY_LABELS have different names... any chance we can align them as part of this PR? I don't think we'll be breaking any AIRT operator at this point...

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are username and op_name labels needed exactly with these key names? whatever systems/person wants to reference these, can't they just reference operator and operation ? (just because those are already in the example, and used in a few places (like the GUI, internal bootstrap script, etc)

labels["username"] = data["operator"]

if "op_name" not in labels:
labels["op_name"] = data["operation"]

os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels)
57 changes: 53 additions & 4 deletions tests/unit/setup/test_airt_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,27 @@

import os
import sys
from unittest.mock import patch
from pathlib import Path
from unittest.mock import mock_open, patch

import pytest
import yaml

from pyrit.common.apply_defaults import reset_default_values
from pyrit.setup.initializers import AIRTInitializer


@pytest.fixture
def patch_pyrit_conf(tmp_path):
"""Create a temporary .pyrit_conf file and patch _validate_operation_fields to read from it."""
conf_file = tmp_path / ".pyrit_conf"
conf_file.write_text(yaml.dump({"operator": "test_user", "operation": "test_op"}))
with patch.object(Path, "home", return_value=tmp_path):
(tmp_path / ".pyrit").mkdir(exist_ok=True)
(tmp_path / ".pyrit" / ".pyrit_conf").write_text(yaml.dump({"operator": "test_user", "operation": "test_op"}))
yield


class TestAIRTInitializer:
"""Tests for AIRTInitializer class - basic functionality."""

Expand Down Expand Up @@ -41,6 +54,9 @@ def setup_method(self) -> None:
os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2"] = "https://test-scorer.openai.azure.com"
os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2"] = "gpt-4"
os.environ["AZURE_CONTENT_SAFETY_API_ENDPOINT"] = "https://test-safety.cognitiveservices.azure.com"
os.environ["AZURE_SQL_DB_CONNECTION_STRING"] = "Server=test.database.windows.net;Database=testdb"
os.environ["AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL"] = "https://teststorage.blob.core.windows.net/data"
os.environ["GLOBAL_MEMORY_LABELS"] = '{"op_name": "test_op", "username": "test_user", "email": "test@test.com"}'
# Clean up globals
for attr in [
"default_converter_target",
Expand All @@ -61,6 +77,9 @@ def teardown_method(self) -> None:
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2",
"AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2",
"AZURE_CONTENT_SAFETY_API_ENDPOINT",
"AZURE_SQL_DB_CONNECTION_STRING",
"AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL",
"GLOBAL_MEMORY_LABELS",
]:
if var in os.environ:
del os.environ[var]
Expand All @@ -75,7 +94,7 @@ def teardown_method(self) -> None:
delattr(sys.modules["__main__"], attr)

@pytest.mark.asyncio
async def test_initialize_runs_without_error(self):
async def test_initialize_runs_without_error(self, patch_pyrit_conf):
"""Test that initialize runs without errors when no API keys are set (Entra auth fallback)."""
init = AIRTInitializer()
with (
Expand All @@ -85,7 +104,7 @@ async def test_initialize_runs_without_error(self):
await init.initialize_async()

@pytest.mark.asyncio
async def test_initialize_uses_api_keys_when_set(self):
async def test_initialize_uses_api_keys_when_set(self, patch_pyrit_conf):
"""Test that initialize uses API keys from env vars when they are set."""
os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY"] = "converter-key"
os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2"] = "scorer-key"
Expand All @@ -110,7 +129,7 @@ async def test_initialize_uses_api_keys_when_set(self):
del os.environ[var]

@pytest.mark.asyncio
async def test_get_info_after_initialize_has_populated_data(self):
async def test_get_info_after_initialize_has_populated_data(self, patch_pyrit_conf):
"""Test that get_info_async() returns populated data after initialization."""
init = AIRTInitializer()
with (
Expand Down Expand Up @@ -174,6 +193,36 @@ def test_validate_missing_multiple_env_vars_raises_error(self):
assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT" in error_message
assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL" in error_message

def test_validate_missing_operator_raises_error(self):
"""Test that _validate_operation_fields raises error when operator is missing from .pyrit_conf."""
conf_data = yaml.dump({"operation": "test_op"})
init = AIRTInitializer()
with (
patch("builtins.open", mock_open(read_data=conf_data)),
pytest.raises(ValueError, match="operator"),
):
init._validate_operation_fields()

def test_validate_missing_operation_raises_error(self):
"""Test that _validate_operation_fields raises error when operation is missing from .pyrit_conf."""
conf_data = yaml.dump({"operator": "test_user"})
init = AIRTInitializer()
with (
patch("builtins.open", mock_open(read_data=conf_data)),
pytest.raises(ValueError, match="operation"),
):
init._validate_operation_fields()

def test_validate_db_connection_raises_error(self):
"""Test that validate raises error when AZURE_SQL_DB_CONNECTION_STRING is missing."""
del os.environ["AZURE_SQL_DB_CONNECTION_STRING"]
init = AIRTInitializer()
with pytest.raises(ValueError) as exc_info:
init.validate()

error_message = str(exc_info.value)
assert "AZURE_SQL_DB_CONNECTION_STRING" in error_message


class TestAIRTInitializerGetInfo:
"""Tests for AIRTInitializer.get_info method - basic functionality."""
Expand Down
Loading