diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 5e2ce65b9..6f720b303 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -70,6 +70,8 @@ class ConfigurationLoader(YamlLoadable): env_files: List of environment file paths to load. None means "use defaults (.env, .env.local)", [] means "load nothing". silent: Whether to suppress initialization messages. + operator: Name for the current operator, e.g. a team or username. + operation: Name for the current operation. Example YAML configuration: memory_db_type: sqlite diff --git a/pyrit/setup/initializers/airt.py b/pyrit/setup/initializers/airt.py index 96740565d..51ca7289d 100644 --- a/pyrit/setup/initializers/airt.py +++ b/pyrit/setup/initializers/airt.py @@ -8,8 +8,12 @@ AIRT configuration including converters, scorers, and targets using Azure OpenAI. """ +import json import os from collections.abc import Callable +from pathlib import Path + +import yaml from pyrit.auth import get_azure_openai_auth, get_azure_token_provider from pyrit.common.apply_defaults import set_default_value, set_global_variable @@ -43,12 +47,15 @@ class AIRTInitializer(PyRITInitializer): - Converter targets with Azure OpenAI configuration - Composite harm and objective scorers - Adversarial target configurations for attacks + - Use of an Azure SQL database Required Environment Variables: - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI endpoint for converters and targets - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL: Azure OpenAI model name for converters and targets - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI endpoint for scoring - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2: Azure OpenAI model name for scoring + - AZURE_SQL_DB_CONNECTION_STRING: Azure SQL database connection string + - AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: Azure SQL database location Optional Environment Variables: - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY: API key for converter endpoint. If not set, Entra ID auth is used. @@ -90,6 +97,8 @@ def required_env_vars(self) -> list[str]: "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", "AZURE_CONTENT_SAFETY_API_ENDPOINT", + "AZURE_SQL_DB_CONNECTION_STRING", + "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL", ] async def initialize_async(self) -> None: @@ -102,6 +111,9 @@ async def initialize_async(self) -> None: 3. Adversarial target configurations 4. Default values for all attack types """ + # Ensure op_name, username, and email are populated from GLOBAL_MEMORY_LABELS. + self._validate_operation_fields() + # Get environment variables (validated by validate() method) converter_endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") converter_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL") @@ -255,3 +267,36 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name: parameter_name="attack_adversarial_config", value=adversarial_config, ) + + def _validate_operation_fields(self) -> None: + """ + Check that mandatory global memory labels (operation, operator) + are populated. + + Raises: + ValueError: If mandatory global memory labels are missing. + """ + config_path = Path.home() / ".pyrit" / ".pyrit_conf" + with open(config_path) as f: + data = yaml.load(f, Loader=yaml.SafeLoader) + + if "operator" not in data: + raise ValueError( + "Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." + ) + + if "operation" not in data: + raise ValueError( + "Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." + ) + + raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS") + labels = dict(json.loads(raw_labels)) if raw_labels else {} + + if "username" not in labels: + labels["username"] = data["operator"] + + if "op_name" not in labels: + labels["op_name"] = data["operation"] + + os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels) diff --git a/tests/unit/setup/test_airt_initializer.py b/tests/unit/setup/test_airt_initializer.py index 2a8606cde..5f059302c 100644 --- a/tests/unit/setup/test_airt_initializer.py +++ b/tests/unit/setup/test_airt_initializer.py @@ -3,14 +3,27 @@ import os import sys -from unittest.mock import patch +from pathlib import Path +from unittest.mock import mock_open, patch import pytest +import yaml from pyrit.common.apply_defaults import reset_default_values from pyrit.setup.initializers import AIRTInitializer +@pytest.fixture +def patch_pyrit_conf(tmp_path): + """Create a temporary .pyrit_conf file and patch _validate_operation_fields to read from it.""" + conf_file = tmp_path / ".pyrit_conf" + conf_file.write_text(yaml.dump({"operator": "test_user", "operation": "test_op"})) + with patch.object(Path, "home", return_value=tmp_path): + (tmp_path / ".pyrit").mkdir(exist_ok=True) + (tmp_path / ".pyrit" / ".pyrit_conf").write_text(yaml.dump({"operator": "test_user", "operation": "test_op"})) + yield + + class TestAIRTInitializer: """Tests for AIRTInitializer class - basic functionality.""" @@ -41,6 +54,9 @@ def setup_method(self) -> None: os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2"] = "https://test-scorer.openai.azure.com" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2"] = "gpt-4" os.environ["AZURE_CONTENT_SAFETY_API_ENDPOINT"] = "https://test-safety.cognitiveservices.azure.com" + os.environ["AZURE_SQL_DB_CONNECTION_STRING"] = "Server=test.database.windows.net;Database=testdb" + os.environ["AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL"] = "https://teststorage.blob.core.windows.net/data" + os.environ["GLOBAL_MEMORY_LABELS"] = '{"op_name": "test_op", "username": "test_user", "email": "test@test.com"}' # Clean up globals for attr in [ "default_converter_target", @@ -61,6 +77,9 @@ def teardown_method(self) -> None: "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", "AZURE_CONTENT_SAFETY_API_ENDPOINT", + "AZURE_SQL_DB_CONNECTION_STRING", + "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL", + "GLOBAL_MEMORY_LABELS", ]: if var in os.environ: del os.environ[var] @@ -75,7 +94,7 @@ def teardown_method(self) -> None: delattr(sys.modules["__main__"], attr) @pytest.mark.asyncio - async def test_initialize_runs_without_error(self): + async def test_initialize_runs_without_error(self, patch_pyrit_conf): """Test that initialize runs without errors when no API keys are set (Entra auth fallback).""" init = AIRTInitializer() with ( @@ -85,7 +104,7 @@ async def test_initialize_runs_without_error(self): await init.initialize_async() @pytest.mark.asyncio - async def test_initialize_uses_api_keys_when_set(self): + async def test_initialize_uses_api_keys_when_set(self, patch_pyrit_conf): """Test that initialize uses API keys from env vars when they are set.""" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY"] = "converter-key" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2"] = "scorer-key" @@ -110,7 +129,7 @@ async def test_initialize_uses_api_keys_when_set(self): del os.environ[var] @pytest.mark.asyncio - async def test_get_info_after_initialize_has_populated_data(self): + async def test_get_info_after_initialize_has_populated_data(self, patch_pyrit_conf): """Test that get_info_async() returns populated data after initialization.""" init = AIRTInitializer() with ( @@ -174,6 +193,36 @@ def test_validate_missing_multiple_env_vars_raises_error(self): assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT" in error_message assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL" in error_message + def test_validate_missing_operator_raises_error(self): + """Test that _validate_operation_fields raises error when operator is missing from .pyrit_conf.""" + conf_data = yaml.dump({"operation": "test_op"}) + init = AIRTInitializer() + with ( + patch("builtins.open", mock_open(read_data=conf_data)), + pytest.raises(ValueError, match="operator"), + ): + init._validate_operation_fields() + + def test_validate_missing_operation_raises_error(self): + """Test that _validate_operation_fields raises error when operation is missing from .pyrit_conf.""" + conf_data = yaml.dump({"operator": "test_user"}) + init = AIRTInitializer() + with ( + patch("builtins.open", mock_open(read_data=conf_data)), + pytest.raises(ValueError, match="operation"), + ): + init._validate_operation_fields() + + def test_validate_db_connection_raises_error(self): + """Test that validate raises error when AZURE_SQL_DB_CONNECTION_STRING is missing.""" + del os.environ["AZURE_SQL_DB_CONNECTION_STRING"] + init = AIRTInitializer() + with pytest.raises(ValueError) as exc_info: + init.validate() + + error_message = str(exc_info.value) + assert "AZURE_SQL_DB_CONNECTION_STRING" in error_message + class TestAIRTInitializerGetInfo: """Tests for AIRTInitializer.get_info method - basic functionality."""