-
Notifications
You must be signed in to change notification settings - Fork 718
FEAT: Standardizing AIRTInitializer #1578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,8 +8,12 @@ | |
| AIRT configuration including converters, scorers, and targets using Azure OpenAI. | ||
| """ | ||
|
|
||
| import json | ||
| import os | ||
| from collections.abc import Callable | ||
| from pathlib import Path | ||
|
|
||
| import yaml | ||
|
|
||
| from pyrit.auth import get_azure_openai_auth, get_azure_token_provider | ||
| from pyrit.common.apply_defaults import set_default_value, set_global_variable | ||
|
|
@@ -43,12 +47,15 @@ class AIRTInitializer(PyRITInitializer): | |
| - Converter targets with Azure OpenAI configuration | ||
| - Composite harm and objective scorers | ||
| - Adversarial target configurations for attacks | ||
| - Use of an Azure SQL database | ||
|
|
||
| Required Environment Variables: | ||
| - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI endpoint for converters and targets | ||
| - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL: Azure OpenAI model name for converters and targets | ||
| - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI endpoint for scoring | ||
| - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2: Azure OpenAI model name for scoring | ||
| - AZURE_SQL_DB_CONNECTION_STRING: Azure SQL database connection string | ||
| - AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL: Azure SQL database location | ||
|
|
||
| Optional Environment Variables: | ||
| - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY: API key for converter endpoint. If not set, Entra ID auth is used. | ||
|
|
@@ -90,6 +97,8 @@ def required_env_vars(self) -> list[str]: | |
| "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", | ||
| "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", | ||
| "AZURE_CONTENT_SAFETY_API_ENDPOINT", | ||
| "AZURE_SQL_DB_CONNECTION_STRING", | ||
| "AZURE_STORAGE_ACCOUNT_DB_DATA_CONTAINER_URL", | ||
| ] | ||
|
|
||
| async def initialize_async(self) -> None: | ||
|
|
@@ -102,6 +111,9 @@ async def initialize_async(self) -> None: | |
| 3. Adversarial target configurations | ||
| 4. Default values for all attack types | ||
| """ | ||
| # Ensure op_name, username, and email are populated from GLOBAL_MEMORY_LABELS. | ||
| self._validate_operation_fields() | ||
|
|
||
| # Get environment variables (validated by validate() method) | ||
| converter_endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") | ||
| converter_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL") | ||
|
|
@@ -255,3 +267,36 @@ def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name: | |
| parameter_name="attack_adversarial_config", | ||
| value=adversarial_config, | ||
| ) | ||
|
|
||
| def _validate_operation_fields(self) -> None: | ||
| """ | ||
| Check that mandatory global memory labels (operation, operator) | ||
| are populated. | ||
|
|
||
| Raises: | ||
| ValueError: If mandatory global memory labels are missing. | ||
| """ | ||
| config_path = Path.home() / ".pyrit" / ".pyrit_conf" | ||
| with open(config_path) as f: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not 100% sure how the configuration loader and initializers tie into each other, but it sounds like we have configuration_loader to avoid parsing YAML directly? In that case, can we use that to resolve these labels? (Skimming through the config loader, it looks like the config loader references an initializer, which made me think: maybe moving this whole validation function to the config loader makes more sense? But the other validation on required_env_vars is happening here, so maybe this is a good place for label validation too... hence why I don't have a strong opinion.) |
||
| data = yaml.load(f, Loader=yaml.SafeLoader) | ||
|
|
||
| if "operator" not in data: | ||
| raise ValueError( | ||
| "Error: `operator` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." | ||
| ) | ||
|
|
||
| if "operation" not in data: | ||
| raise ValueError( | ||
| "Error: `operation` was not set in .pyrit_conf. This is a required value for the AIRTInitializer." | ||
| ) | ||
|
|
||
| raw_labels = os.environ.get("GLOBAL_MEMORY_LABELS") | ||
| labels = dict(json.loads(raw_labels)) if raw_labels else {} | ||
|
|
||
| if "username" not in labels: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems a bit odd that the fields in
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are |
||
| labels["username"] = data["operator"] | ||
|
|
||
| if "op_name" not in labels: | ||
| labels["op_name"] = data["operation"] | ||
|
|
||
| os.environ["GLOBAL_MEMORY_LABELS"] = json.dumps(labels) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you reference DEFAULT_CONFIG_PATH in path.py?