From 81844e398ae2820ec23e18a18b643b53238a1f40 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Thu, 23 Oct 2025 15:17:07 +0200 Subject: [PATCH 1/8] Update predict function to return prediction URL --- cog_safe_push/predict.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cog_safe_push/predict.py b/cog_safe_push/predict.py index 4e15217..cc99347 100644 --- a/cog_safe_push/predict.py +++ b/cog_safe_push/predict.py @@ -252,7 +252,7 @@ async def predict( inputs: dict, timeout_seconds: float, prediction_index: int | None = None, -) -> tuple[Any | None, str | None]: +) -> tuple[Any | None, str | None, str]: prefix = f"[{prediction_index}] " if prediction_index is not None else "" log.vv( f"{prefix}Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}" @@ -283,7 +283,8 @@ async def predict( else: raise - log.v(f"{prefix}Prediction URL: https://replicate.com/p/{prediction.id}") + prediction_url = f"https://replicate.com/p/{prediction.id}" + log.v(f"{prefix}Prediction URL: {prediction_url}") while prediction.status not in ["succeeded", "failed", "canceled"]: await asyncio.sleep(0.5) @@ -295,7 +296,7 @@ async def predict( if prediction.status == "failed": log.v(f"{prefix}Got error: {prediction.error} ({duration:.2f} sec)") - return None, prediction.error + return None, prediction.error, prediction_url output = prediction.output if _has_output_iterator_array_type(version): @@ -303,4 +304,4 @@ async def predict( log.v(f"{prefix}Got output: {truncate(output)} ({duration:.2f} sec)") - return output, None + return output, None, prediction_url From f3942fcd45a93cf60bfb9c8df268536cd91d482b Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Thu, 23 Oct 2025 15:17:33 +0200 Subject: [PATCH 2/8] Update CheckOutputsMatch to store prediction URL --- cog_safe_push/tasks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cog_safe_push/tasks.py b/cog_safe_push/tasks.py index 7590034..283243b 100644 --- a/cog_safe_push/tasks.py +++ b/cog_safe_push/tasks.py @@ -28,6 +28,7 @@ class CheckOutputsMatch(Task): fuzz_disabled_inputs: list[str] fuzz_prompt: str | None prediction_index: int | None = None + prediction_url: str | None = None async def run(self) -> None: if self.first_test_case_inputs is not None: @@ -57,7 +58,7 @@ async def run(self) -> None: log.v( f"{prefix}Checking outputs match between existing version and test version, with inputs: {inputs}" ) - test_output, test_error = await predict( + test_output, test_error, test_url = await predict( model=self.context.test_model, train=self.context.is_train(), train_destination=self.context.train_destination, @@ -65,7 +66,8 @@ async def run(self) -> None: timeout_seconds=self.timeout_seconds, prediction_index=self.prediction_index, ) - output, error = await predict( + self.prediction_url = test_url + output, error, _ = await predict( model=self.context.model, train=self.context.is_train(), train_destination=self.context.train_destination, From b670a8cc34896954aea8e70b5863f0c41000297e Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Thu, 23 Oct 2025 15:17:41 +0200 Subject: [PATCH 3/8] Update RunTestCase to store prediction URL --- cog_safe_push/tasks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cog_safe_push/tasks.py b/cog_safe_push/tasks.py index 283243b..f5de1db 100644 --- a/cog_safe_push/tasks.py +++ b/cog_safe_push/tasks.py @@ -99,13 +99,14 @@ class RunTestCase(Task): checker: OutputChecker predict_timeout: int prediction_index: int | None = None + prediction_url: str | None = None async def run(self) -> None: prefix = ( f"[{self.prediction_index}] " if self.prediction_index is not None else "" ) log.v(f"{prefix}Running test case with inputs: {self.inputs}") - output, error = await predict( + output, error, url = await predict( model=self.context.test_model, train=self.context.is_train(), train_destination=self.context.train_destination, @@ -113,6 +114,7 @@ async def run(self) -> None: timeout_seconds=self.predict_timeout, prediction_index=self.prediction_index, ) + self.prediction_url = url await self.checker(output, error) From 68b89498aaa10b39b970016ecd12a7668a883bc4 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Thu, 23 Oct 2025 15:17:51 +0200 Subject: [PATCH 4/8] Update FuzzModel to store prediction URL --- cog_safe_push/tasks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cog_safe_push/tasks.py b/cog_safe_push/tasks.py index f5de1db..aeb8545 100644 --- a/cog_safe_push/tasks.py +++ b/cog_safe_push/tasks.py @@ -154,6 +154,7 @@ class FuzzModel(Task): inputs_queue: Queue[dict[str, Any]] predict_timeout: int prediction_index: int | None = None + prediction_url: str | None = None async def run(self) -> None: inputs = await asyncio.wait_for(self.inputs_queue.get(), timeout=60) @@ -163,7 +164,7 @@ async def run(self) -> None: ) log.v(f"{prefix}Fuzzing with inputs: {inputs}") try: - output, error = await predict( + output, error, url = await predict( model=self.context.test_model, train=self.context.is_train(), train_destination=self.context.train_destination, @@ -171,6 +172,7 @@ async def run(self) -> None: timeout_seconds=self.predict_timeout, prediction_index=self.prediction_index, ) + self.prediction_url = url except PredictionTimeoutError: raise FuzzError(f"{prefix}Prediction timed out") if error is not None: From 3c6ccc761ffb71e3cea5a0e249d55636f9e830cf Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Thu, 23 Oct 2025 15:18:05 +0200 Subject: [PATCH 5/8] Update run_tasks to show all failures with test case number and prediction URL --- cog_safe_push/main.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py index 6a31ff8..f29fb9a 100644 --- a/cog_safe_push/main.py +++ b/cog_safe_push/main.py @@ -451,14 +451,14 @@ async def run_tasks(tasks: list[Task], parallel: int) -> None: log.info(f"Running tasks with parallelism {parallel}") semaphore = asyncio.Semaphore(parallel) - errors: list[Exception] = [] + errors: list[tuple[Task, Exception]] = [] async def run_with_semaphore(task: Task) -> None: async with semaphore: try: await task.run() except Exception as e: - errors.append(e) + errors.append((task, e)) # Create task coroutines and run them concurrently task_coroutines = [run_with_semaphore(task) for task in tasks] @@ -467,11 +467,20 @@ async def run_with_semaphore(task: Task) -> None: await asyncio.gather(*task_coroutines, return_exceptions=True) if errors: - # If there are multiple errors, we'll raise the first one - # but log all of them - for error in errors[1:]: - log.error(f"Additional error occurred: {error}") - raise errors[0] + # Log all failures with their test case number and prediction URL + for task, error in errors: + prediction_index = getattr(task, "prediction_index", None) + prediction_url = getattr(task, "prediction_url", None) + + prefix = f"[{prediction_index}] " if prediction_index is not None else "" + message = str(error) + + if prediction_url: + log.error(f"{prefix}Test case failed: {message}; Prediction URL: {prediction_url}") + else: + log.error(f"{prefix}Test case failed: {message}") + + raise errors[0][1] def parse_inputs(inputs_list: list[str]) -> dict[str, Any]: From 5864a653ce624c9d3984d9bf4ce294eda5d69fb0 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Thu, 23 Oct 2025 15:18:38 +0200 Subject: [PATCH 6/8] Add imports for testing run_tasks --- test/test_main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_main.py b/test/test_main.py index 13dec72..af1f9a2 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -1,13 +1,17 @@ +import asyncio + import pytest from cog_safe_push import log -from cog_safe_push.exceptions import ArgumentError +from cog_safe_push.exceptions import ArgumentError, TestCaseFailedError from cog_safe_push.main import ( parse_args_and_config, parse_input_value, parse_inputs, parse_model, + run_tasks, ) +from cog_safe_push.tasks import Task def test_parse_args_minimal(monkeypatch): From a5dc7d64a560852d564cf6787ca1704b18eb3315 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Thu, 23 Oct 2025 15:19:05 +0200 Subject: [PATCH 7/8] Add test for error reporting in run_tasks --- test/test_main.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/test_main.py b/test/test_main.py index af1f9a2..aa1bb4f 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -480,3 +480,42 @@ def test_parse_args_push_official_model(monkeypatch): assert config.model == "user/model" assert not no_push assert push_official_model + + +@pytest.mark.asyncio +async def test_run_tasks_reports_all_errors_with_details(capsys): + class FailingTask(Task): + def __init__(self, prediction_index, prediction_url, error_message): + self.prediction_index = prediction_index + self.prediction_url = prediction_url + self.error_message = error_message + + async def run(self): + raise TestCaseFailedError(self.error_message) + + class FailingTaskNoUrl(Task): + def __init__(self, prediction_index, error_message): + self.prediction_index = prediction_index + self.error_message = error_message + + async def run(self): + raise TestCaseFailedError(self.error_message) + + tasks = [ + FailingTask(1, "https://replicate.com/p/abc123", "Output mismatch"), + FailingTask(2, "https://replicate.com/p/def456", "Timeout occurred"), + FailingTaskNoUrl(3, "Invalid input"), + ] + + with pytest.raises(TestCaseFailedError): + await run_tasks(tasks, parallel=2) + + captured = capsys.readouterr() + error_output = captured.err + + # All errors should be logged + assert "[1] Test case failed: Test case failed: Output mismatch; Prediction URL: https://replicate.com/p/abc123" in error_output + assert "[2] Test case failed: Test case failed: Timeout occurred; Prediction URL: https://replicate.com/p/def456" in error_output + assert "[3] Test case failed: Test case failed: Invalid input" in error_output + # Should not have prediction URL in the third error + assert "Invalid input; Prediction URL:" not in error_output From ab6a98ad39e520cf442be7aa1fb54dadd6be8414 Mon Sep 17 00:00:00 2001 From: andreasjansson Date: Thu, 23 Oct 2025 15:20:47 +0200 Subject: [PATCH 8/8] lint --- cog_safe_push/main.py | 10 ++++++---- test/test_main.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py index f29fb9a..a892290 100644 --- a/cog_safe_push/main.py +++ b/cog_safe_push/main.py @@ -471,15 +471,17 @@ async def run_with_semaphore(task: Task) -> None: for task, error in errors: prediction_index = getattr(task, "prediction_index", None) prediction_url = getattr(task, "prediction_url", None) - + prefix = f"[{prediction_index}] " if prediction_index is not None else "" message = str(error) - + if prediction_url: - log.error(f"{prefix}Test case failed: {message}; Prediction URL: {prediction_url}") + log.error( + f"{prefix}Test case failed: {message}; Prediction URL: {prediction_url}" + ) else: log.error(f"{prefix}Test case failed: {message}") - + raise errors[0][1] diff --git a/test/test_main.py b/test/test_main.py index aa1bb4f..413ac84 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -1,5 +1,3 @@ -import asyncio - import pytest from cog_safe_push import log @@ -514,8 +512,14 @@ async def run(self): error_output = captured.err # All errors should be logged - assert "[1] Test case failed: Test case failed: Output mismatch; Prediction URL: https://replicate.com/p/abc123" in error_output - assert "[2] Test case failed: Test case failed: Timeout occurred; Prediction URL: https://replicate.com/p/def456" in error_output + assert ( + "[1] Test case failed: Test case failed: Output mismatch; Prediction URL: https://replicate.com/p/abc123" + in error_output + ) + assert ( + "[2] Test case failed: Test case failed: Timeout occurred; Prediction URL: https://replicate.com/p/def456" + in error_output + ) assert "[3] Test case failed: Test case failed: Invalid input" in error_output # Should not have prediction URL in the third error assert "Invalid input; Prediction URL:" not in error_output