diff --git a/cog_safe_push/ai.py b/cog_safe_push/ai.py index db4ec3b..9ea7526 100644 --- a/cog_safe_push/ai.py +++ b/cog_safe_push/ai.py @@ -12,6 +12,8 @@ from . import log from .exceptions import AIError, ArgumentError +MAX_TOKENS = 8192 + def async_retry(attempts=3): def decorator_retry(func): @@ -38,13 +40,13 @@ async def wrapper_retry(*args, **kwargs): async def boolean( prompt: str, files: list[Path] | None = None, include_file_metadata: bool = False ) -> bool: - system_prompt = "You only answer YES or NO, and absolutely nothing else. Your response will be used in a programmatic context so it's important that you only ever answer with either the string YES or the string NO." - # system_prompt = "You are a helpful assistant" + system_prompt = "You are a boolean classifier. You must only respond with either YES or NO, and absolutely nothing else. Your response will be used in a programmatic context so it is critical that you only ever answer with either the string YES or the string NO." output = await call( system_prompt=system_prompt, prompt=prompt.strip(), files=files, include_file_metadata=include_file_metadata, + thinking=True, ) if output == "YES": return True @@ -54,9 +56,30 @@ async def boolean( @async_retry(3) -async def json_object(prompt: str, files: list[Path] | None = None) -> dict: - system_prompt = "You always respond with valid JSON, and nothing else (no backticks, etc.). Your outputs will be used in a programmatic context." - output = await call(system_prompt=system_prompt, prompt=prompt.strip(), files=files) +async def json_object( + prompt: str, + files: list[Path] | None = None, + system_prompt: str = "", + thinking: bool = False, +) -> dict: + if system_prompt: + system_prompt = system_prompt.strip() + "\n\n" + system_prompt += "You always respond with valid JSON, and nothing else (no backticks, etc.). Your outputs will be used in a programmatic context." 
+ output = await call( + system_prompt=system_prompt, + prompt=prompt.strip(), + files=files, + thinking=thinking, + ) + + if output.startswith("```json"): + output = output[7:] + elif output.startswith("```"): + output = output[3:] + if output.endswith("```"): + output = output[:-3] + output = output.strip() + try: return json.loads(output) except json.JSONDecodeError: @@ -68,12 +91,13 @@ async def call( prompt: str, files: list[Path] | None = None, include_file_metadata: bool = False, + thinking: bool = False, ) -> str: api_key = os.environ.get("ANTHROPIC_API_KEY") if not api_key: raise ArgumentError("ANTHROPIC_API_KEY is not defined") - model = "claude-sonnet-4-20250514" + model = "claude-sonnet-4-5" client = anthropic.AsyncAnthropic(api_key=api_key) try: @@ -96,15 +120,30 @@ async def call( {"role": "user", "content": content} ] - response = await client.messages.create( - model=model, - messages=messages, - system=system_prompt, - max_tokens=4096, - stream=False, - temperature=1.0, - ) - content = cast("anthropic.types.TextBlock", response.content[0]) + if thinking: + response = await client.messages.create( + model=model, + messages=messages, + system=system_prompt, + max_tokens=MAX_TOKENS, + stream=False, + temperature=1.0, + thinking={"type": "enabled", "budget_tokens": 2048}, + ) + else: + response = await client.messages.create( + model=model, + messages=messages, + system=system_prompt, + max_tokens=MAX_TOKENS, + stream=False, + temperature=1.0, + ) + + text_blocks = [block for block in response.content if block.type == "text"] + if not text_blocks: + raise AIError("No text content in response") + content = cast("anthropic.types.TextBlock", text_blocks[0]) finally: await client.close() diff --git a/cog_safe_push/match_outputs.py b/cog_safe_push/match_outputs.py index 09edce2..a3dbae2 100644 --- a/cog_safe_push/match_outputs.py +++ b/cog_safe_push/match_outputs.py @@ -22,12 +22,18 @@ async def output_matches_prompt(output: Any, prompt: str) -> tuple[bool, 
str]: urls = output if isinstance(output, list) else list(output.values()) with download_many(urls) as tmp_files: - claude_prompt = """You are part of an automatic evaluation that compares media (text, audio, image, video, etc.) to captions. I want to know if the caption matches the text or file.. + claude_prompt = """You are part of an automatic evaluation that compares media (text, audio, image, video, etc.) to descriptions. I want to know if the description matches the text or file. """ if urls: claude_prompt += f"""Does this file(s) and the attached content of the file(s) match the description? Pay close attention to the metadata about the attached files which is included below, especially if the description mentions file type, image dimensions, or any other aspect that is described in the metadata. Do not infer file type or image dimensions from the image content, but from the attached metadata. +The description may be specific or vague, but you should match on whatever is in the description. For example: +* If the description is 'a jpg image' and it's a jpg image of a cat, that's still a match. +* If the description is 'an image of a cat' and the image is actually of a dog, it's not a match. +* If the description is 'an audio file' it should match any audio files regardless of content. +* etc. + Description to evaluate: {prompt} Filename(s): {output}""" @@ -132,6 +138,9 @@ async def strings_match(s1: str, s2: str, is_deterministic: bool) -> tuple[bool, f""" Have these two strings been generated by the same generative AI model inputs/prompt? 
+* If the two strings are identical, respond with YES +* If the two strings have very similar content, respond with YES + String 1: '{s1}' String 2: '{s2}' """ @@ -175,11 +184,23 @@ def is_video(url: str) -> bool: def extensions_match(url1: str, url2: str) -> bool: - ext1 = Path(urlparse(url1).path).suffix - ext2 = Path(urlparse(url2).path).suffix + ext1 = normalize_suffix(Path(urlparse(url1).path).suffix) + ext2 = normalize_suffix(Path(urlparse(url2).path).suffix) return ext1.lower() == ext2.lower() +def normalize_suffix(suffix: str) -> str: + suffix = suffix.lower() + normalizations = { + ".jpeg": ".jpg", + ".jpe": ".jpg", + ".tiff": ".tif", + ".mpeg": ".mpg", + ".htm": ".html", + } + return normalizations.get(suffix, suffix) + + def is_url(s: str) -> bool: return s.startswith(("http://", "https://")) @@ -204,7 +225,11 @@ async def images_match( return True, "" fuzzy_match = await ai.boolean( - "These two images have been generated by or modified by an AI model. Is it highly likely that those two predictions of the model had the same inputs?", + """I provide you with _two_ input images. These two images have been generated by or modified by an AI model. Is it highly likely that those two predictions of the model had the same inputs? + +* If the two images are identical, respond with YES. +* If the two images have very similar subject matters that have probably been generated by the same prompt, respond with YES. 
+ """, files=[tmp1, tmp2], ) if fuzzy_match: diff --git a/cog_safe_push/predict.py b/cog_safe_push/predict.py index b56ca96..82e23e1 100644 --- a/cog_safe_push/predict.py +++ b/cog_safe_push/predict.py @@ -3,6 +3,7 @@ import time from typing import Any, cast +import httpx import replicate from replicate.exceptions import ReplicateError from replicate.model import Model @@ -16,38 +17,25 @@ from .utils import truncate -async def make_predict_inputs( - schemas: dict, - train: bool, - only_required: bool, - seed: int | None, - fixed_inputs: dict[str, Any], - disabled_inputs: list[str], - fuzz_prompt: str | None, - inputs_history: list[dict] | None = None, - attempt=0, -) -> tuple[dict, bool]: - input_name = "TrainingInput" if train else "Input" - input_schema = schemas[input_name] - properties = input_schema["properties"] - required = input_schema.get("required", []) +async def make_fuzz_system_prompt() -> str: + async with httpx.AsyncClient() as client: + response = await client.get( + "https://multimedia-example-files.replicate.dev/index.txt" + ) + multimedia_example_files = response.text + return ( + """# Replicate model fuzzing inputs - is_deterministic = False - if "seed" in properties and seed is not None: - is_deterministic = True - del properties["seed"] +Your task is to generate inputs for model fuzzing of a Replicate model. - fixed_inputs = {k: v for k, v in fixed_inputs.items() if k not in disabled_inputs} +Given a model input JSON schema, return a valid JSON payload for this model. 
- schemas_str = json.dumps(schemas, indent=2) - prompt = ( - ''' -Below is an example of an OpenAPI schema for a Cog model: +## Example + +For example, { - "''' - + input_name - + '''": { + "Input": { "properties": { "my_bool": { "description": "A bool.", @@ -99,9 +87,7 @@ async def make_predict_inputs( "my_choice", "my_constrained_int" ], - "title": "''' - + input_name - + """", + "title": "Input", "type": "object" }, "my_choice": { @@ -116,7 +102,7 @@ async def make_predict_inputs( } } -A valid json payload for that input schema would be: +A valid JSON payload for that input schema would be: { "my_bool": true, @@ -127,42 +113,69 @@ async def make_predict_inputs( "text": "world", } +The following is NOT a valid JSON payload: + +{ + "my_bool": true, + "my_choice": "foo", + "my_constrained_int": 11, + "my_float": 3.14, + "my_int": 10, + "text": "world", +} + +...because my_constrained_int is greater than the maximum in the schema. + +## Respect constraints + +Be careful to respect constraints. For example: +* If there is a "maximum" or "minimum" constraint on a number input, your generated input value must not be below the minimum or above the maximum +* If there is an allOf constraint, your input values must be one of the valid enumeration values +* If the description of an input describes constraints, your generated input must respect those constraints +* etc. + +## Multimedia file inputs + +If an input has format=uri and you decide to populate that input, you should use one of the media URLs from the Multimedia example files section below. + +Make sure you pick an appropriate URL for the input, e.g. pick one of the image examples below if the input expects an image. Also make sure you respect any hints or documentation about file types. 
+ """ - + f""" -Now, given the following OpenAPI schemas: + + multimedia_example_files + ) + + +async def make_fuzz_inputs( + schemas: dict, + train: bool, + only_required: bool, + seed: int | None, + fixed_inputs: dict[str, Any], + disabled_inputs: list[str], + fuzz_prompt: str | None, + inputs_history: list[dict] | None = None, + attempt=0, +) -> tuple[dict, bool]: + input_name = "TrainingInput" if train else "Input" + input_schema = schemas[input_name] + properties = input_schema["properties"] + required = input_schema.get("required", []) + + is_deterministic = False + if "seed" in properties and seed is not None: + is_deterministic = True + del properties["seed"] + + fixed_inputs = {k: v for k, v in fixed_inputs.items() if k not in disabled_inputs} + + schemas_str = json.dumps(schemas, indent=2) + prompt = f"""Given the following OpenAPI schemas: {schemas_str} -Generate a json payload for the {input_name} schema. - -If an input have format=uri and you decide to populate that input, you should use one of the following media URLs. Make sure you pick an appropriate URL for the the input, e.g. pick one of the image examples below if the input expects represents an image. 
- -Image: -* https://storage.googleapis.com/cog-safe-push-public/skull.jpg -* https://storage.googleapis.com/cog-safe-push-public/fast-car.jpg -* https://storage.googleapis.com/cog-safe-push-public/forest.png -* https://storage.googleapis.com/cog-safe-push-public/face.gif -Video: -* https://storage.googleapis.com/cog-safe-push-public/harry-truman.webm -* https://storage.googleapis.com/cog-safe-push-public/mariner-launch.ogv -Music audio: -* https://storage.googleapis.com/cog-safe-push-public/folk-music.mp3 -* https://storage.googleapis.com/cog-safe-push-public/ocarina.ogg -* https://storage.googleapis.com/cog-safe-push-public/nu-style-kick.wav -Test audio: -* https://storage.googleapis.com/cog-safe-push-public/clap.ogg -* https://storage.googleapis.com/cog-safe-push-public/beeps.mp3 -Long speech: -* https://storage.googleapis.com/cog-safe-push-public/chekhov-article.ogg -* https://storage.googleapis.com/cog-safe-push-public/momentos-spanish.ogg -Short speech: -* https://storage.googleapis.com/cog-safe-push-public/de-experiment-german-word.ogg -* https://storage.googleapis.com/cog-safe-push-public/de-ionendosis-german-word.ogg - -If the schema has default values for some of the inputs, feel free to either use the defaults or come up with new values. - - """ - ) +Generate a valid JSON payload for the {input_name} schema. + +""" if fixed_inputs: fixed_inputs_str = json.dumps(fixed_inputs) @@ -182,7 +195,7 @@ async def make_predict_inputs( inputs_history_str = "\n".join(["* " + json.dumps(i) for i in inputs_history]) prompt += f""" -Return a new combination of inputs that you haven't used before, ideally that's quite diverse from inputs you've used before. You have previously used these inputs: +Return a new combination of inputs that you haven't used before, ideally that's quite diverse from inputs you've used before -- but still make sure you respect the constraints in the input schema (respecting those constraints is very important!). 
You have previously used these inputs: {inputs_history_str}""" if fuzz_prompt: @@ -192,14 +205,15 @@ async def make_predict_inputs( You must follow these instructions: {fuzz_prompt}""" - inputs = await ai.json_object(prompt) + system_prompt = await make_fuzz_system_prompt() + inputs = await ai.json_object(prompt, system_prompt=system_prompt, thinking=True) if set(required) - set(inputs.keys()): max_attempts = 5 if attempt == max_attempts: raise AIError( f"Failed to generate a json payload with the correct keys after {max_attempts} attempts, giving up" ) - return await make_predict_inputs( + return await make_fuzz_inputs( schemas=schemas, train=train, only_required=only_required, @@ -207,6 +221,7 @@ async def make_predict_inputs( fixed_inputs=fixed_inputs, disabled_inputs=disabled_inputs, fuzz_prompt=fuzz_prompt, + inputs_history=inputs_history, attempt=attempt + 1, ) diff --git a/cog_safe_push/tasks.py b/cog_safe_push/tasks.py index af0ff32..7590034 100644 --- a/cog_safe_push/tasks.py +++ b/cog_safe_push/tasks.py @@ -11,7 +11,7 @@ ) from .match_outputs import outputs_match from .output_checkers import OutputChecker -from .predict import make_predict_inputs, predict +from .predict import make_fuzz_inputs, predict from .task_context import TaskContext @@ -41,7 +41,7 @@ async def run(self) -> None: schemas = schema.get_schemas( self.context.model, train=self.context.is_train() ) - inputs, is_deterministic = await make_predict_inputs( + inputs, is_deterministic = await make_fuzz_inputs( schemas, train=self.context.is_train(), only_required=True, @@ -130,7 +130,7 @@ async def run(self) -> None: ) inputs_history = [] for _ in range(self.num_inputs): - inputs, _ = await make_predict_inputs( + inputs, _ = await make_fuzz_inputs( schemas, train=self.context.is_train(), only_required=False, diff --git a/end-to-end-test/fixtures/image-base/predict.py b/end-to-end-test/fixtures/image-base/predict.py index 8c92f78..682f447 100644 --- 
a/end-to-end-test/fixtures/image-base/predict.py +++ b/end-to-end-test/fixtures/image-base/predict.py @@ -11,7 +11,9 @@ def setup(self): def predict( self, - image: Path = Input(description="Input image."), + image: Path = Input( + description="Input image. Valid file types are: jpg, png, webp, bmp, gif (not animated)" + ), width: int = Input(description="New width.", ge=1, le=2000), height: int = Input(description="New height.", ge=1, le=1000), ) -> Path: diff --git a/integration-test/test_non_matching_images.py b/integration-test/test_non_matching_images.py new file mode 100644 index 0000000..68f1df0 --- /dev/null +++ b/integration-test/test_non_matching_images.py @@ -0,0 +1,29 @@ +import pytest + +from cog_safe_push.match_outputs import outputs_match + + +@pytest.mark.asyncio +async def test_output_match_similar_images(): + url1 = "https://replicate.delivery/xezq/OrGhA2j4ACZ8FdbZgTxyaav6EKSxZ4jBnNzZwXIZZleq8TvKA/out-0.webp" + url2 = "https://replicate.delivery/xezq/Z4UKfUkAqp0RRaGQRIerW3ZGansA1Rqg6eodiOfYTfedZeTvKA/out-0.webp" + matches, error_message = await outputs_match(url1, url2, is_deterministic=False) + assert matches, error_message + + +@pytest.mark.asyncio +async def test_output_match_same_image(): + url = "https://replicate.delivery/xezq/OrGhA2j4ACZ8FdbZgTxyaav6EKSxZ4jBnNzZwXIZZleq8TvKA/out-0.webp" + matches, error_message = await outputs_match(url, url, is_deterministic=False) + assert matches, error_message + + +@pytest.mark.asyncio +async def test_output_match_not_similar_images(): + url1 = "https://replicate.delivery/xezq/OrGhA2j4ACZ8FdbZgTxyaav6EKSxZ4jBnNzZwXIZZleq8TvKA/out-0.webp" + url2 = "https://replicate.delivery/xezq/NtEEOzxwpTaFFF5fhalpLevI1HwrmGc3bNX799EzWmf51P9qA/out-0.webp" + matches, error_message = await outputs_match(url1, url2, is_deterministic=False) + assert not matches + assert error_message == "Images are not similar", ( + f"Expected 'Images are not similar' but got: {error_message}" + ) diff --git 
a/integration-test/test_output_matches_prompt.py b/integration-test/test_output_matches_prompt.py index 749124f..92e28d3 100644 --- a/integration-test/test_output_matches_prompt.py +++ b/integration-test/test_output_matches_prompt.py @@ -13,14 +13,14 @@ "A webp image of a bird", "A webp image of a red bird", ], - "https://replicate.delivery/czjl/QFrZ9RF8VroFM5Ml9MKt3rm0vP8ZHTWaqfO1oT6bouj0m76JA/tmpn888w5a8.jpg": [ + "https://replicate.delivery/xezq/7mpYHTkoCW5hFhYeyZHhc8pbNGjpZrSVypReBr4JsbXeLd7qA/tmpqsp1ykrz.jpg": [ "A jpg image of a formula one car", "a jpg image of a car", "A jpg image", "Formula 1 car", "car", ], - "https://replicate.delivery/czjl/8C4OJCR6w7rQEFeernSerHH5e3xe2f9cYYsGTW8k5Eob57d9E/tmpjwitpu7f.png": [ + "https://replicate.delivery/xezq/Gf7onwGGPDzgVaiQfARPWEOJZGzq94QKS3qsqlRfL9xUfi1VB/tmppfxwubub.png": [ "480x320px png image", "480x320px image of a formula one car", ], diff --git a/script/integration-test b/script/integration-test index c90b239..16da579 100755 --- a/script/integration-test +++ b/script/integration-test @@ -1,3 +1,3 @@ #!/bin/bash -eu -pytest -n4 -s integration-test/ +pytest -n8 -s integration-test/ diff --git a/test/test_predict.py b/test/test_predict.py index 140874d..dfd894d 100644 --- a/test/test_predict.py +++ b/test/test_predict.py @@ -3,7 +3,7 @@ import pytest from cog_safe_push.exceptions import AIError -from cog_safe_push.predict import make_predict_inputs +from cog_safe_push.predict import make_fuzz_inputs @pytest.fixture @@ -34,7 +34,7 @@ def sample_schemas(): async def test_make_predict_inputs_basic(mock_json_object, sample_schemas): mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"} - inputs, is_deterministic = await make_predict_inputs( + inputs, is_deterministic = await make_fuzz_inputs( sample_schemas, train=False, only_required=True, @@ -52,7 +52,7 @@ async def test_make_predict_inputs_with_seed(sample_schemas): with patch("cog_safe_push.predict.ai.json_object") as mock_json_object: 
mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"} - inputs, is_deterministic = await make_predict_inputs( + inputs, is_deterministic = await make_fuzz_inputs( sample_schemas, train=False, only_required=True, @@ -70,7 +70,7 @@ async def test_make_predict_inputs_with_fixed_inputs(sample_schemas): with patch("cog_safe_push.predict.ai.json_object") as mock_json_object: mock_json_object.return_value = {"text": "hello", "number": 42, "choice": "A"} - inputs, _ = await make_predict_inputs( + inputs, _ = await make_fuzz_inputs( sample_schemas, train=False, only_required=True, @@ -92,7 +92,7 @@ async def test_make_predict_inputs_with_disabled_inputs(sample_schemas): "optional": True, } - inputs, _ = await make_predict_inputs( + inputs, _ = await make_fuzz_inputs( sample_schemas, train=False, only_required=False, @@ -114,7 +114,7 @@ async def test_make_predict_inputs_with_inputs_history(sample_schemas): {"text": "older", "number": 21, "choice": "B"}, ] - inputs, _ = await make_predict_inputs( + inputs, _ = await make_fuzz_inputs( sample_schemas, train=False, only_required=True, @@ -136,7 +136,7 @@ async def test_make_predict_inputs_ai_error(sample_schemas): {"text": "hello", "number": 42, "choice": "A"}, # Correct input ] - inputs, _ = await make_predict_inputs( + inputs, _ = await make_fuzz_inputs( sample_schemas, train=False, only_required=True, @@ -157,7 +157,7 @@ async def test_make_predict_inputs_max_attempts_reached(sample_schemas): } # Always missing required fields with pytest.raises(AIError): - await make_predict_inputs( + await make_fuzz_inputs( sample_schemas, train=False, only_required=True, @@ -179,7 +179,7 @@ async def test_make_predict_inputs_filters_null_values(sample_schemas): "input_image": None, # This should be filtered out } - inputs, _ = await make_predict_inputs( + inputs, _ = await make_fuzz_inputs( sample_schemas, train=False, only_required=False, @@ -220,7 +220,7 @@ async def 
test_make_predict_inputs_filters_various_null_representations(): "optional_field": None, # Optional field with null that should be filtered } - inputs, _ = await make_predict_inputs( + inputs, _ = await make_fuzz_inputs( schemas, train=False, only_required=False, @@ -263,7 +263,7 @@ async def test_make_predict_inputs_preserves_valid_values(): "null_field": None, # Should be filtered out } - inputs, _ = await make_predict_inputs( + inputs, _ = await make_fuzz_inputs( schemas, train=False, only_required=False,