diff --git a/cog_safe_push/main.py b/cog_safe_push/main.py
index 6a31ff8..a892290 100644
--- a/cog_safe_push/main.py
+++ b/cog_safe_push/main.py
@@ -451,14 +451,14 @@ async def run_tasks(tasks: list[Task], parallel: int) -> None:
     log.info(f"Running tasks with parallelism {parallel}")
     semaphore = asyncio.Semaphore(parallel)
 
-    errors: list[Exception] = []
+    errors: list[tuple[Task, Exception]] = []
 
     async def run_with_semaphore(task: Task) -> None:
         async with semaphore:
             try:
                 await task.run()
             except Exception as e:
-                errors.append(e)
+                errors.append((task, e))
 
     # Create task coroutines and run them concurrently
     task_coroutines = [run_with_semaphore(task) for task in tasks]
@@ -467,11 +467,22 @@ async def run_with_semaphore(task: Task) -> None:
     await asyncio.gather(*task_coroutines, return_exceptions=True)
 
     if errors:
-        # If there are multiple errors, we'll raise the first one
-        # but log all of them
-        for error in errors[1:]:
-            log.error(f"Additional error occurred: {error}")
-        raise errors[0]
+        # Log all failures with their test case number and prediction URL
+        for task, error in errors:
+            prediction_index = getattr(task, "prediction_index", None)
+            prediction_url = getattr(task, "prediction_url", None)
+
+            prefix = f"[{prediction_index}] " if prediction_index is not None else ""
+            message = str(error)
+
+            if prediction_url:
+                log.error(
+                    f"{prefix}Test case failed: {message}; Prediction URL: {prediction_url}"
+                )
+            else:
+                log.error(f"{prefix}Test case failed: {message}")
+
+        raise errors[0][1]
 
 
 def parse_inputs(inputs_list: list[str]) -> dict[str, Any]:
diff --git a/cog_safe_push/predict.py b/cog_safe_push/predict.py
index 4e15217..cc99347 100644
--- a/cog_safe_push/predict.py
+++ b/cog_safe_push/predict.py
@@ -252,7 +252,7 @@ async def predict(
     inputs: dict,
     timeout_seconds: float,
     prediction_index: int | None = None,
-) -> tuple[Any | None, str | None]:
+) -> tuple[Any | None, str | None, str]:
     prefix = f"[{prediction_index}] " if prediction_index is not None else ""
     log.vv(
         f"{prefix}Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}"
@@ -283,7 +283,8 @@ async def predict(
         else:
             raise
 
-    log.v(f"{prefix}Prediction URL: https://replicate.com/p/{prediction.id}")
+    prediction_url = f"https://replicate.com/p/{prediction.id}"
+    log.v(f"{prefix}Prediction URL: {prediction_url}")
 
     while prediction.status not in ["succeeded", "failed", "canceled"]:
         await asyncio.sleep(0.5)
@@ -295,7 +296,7 @@
 
     if prediction.status == "failed":
         log.v(f"{prefix}Got error: {prediction.error} ({duration:.2f} sec)")
-        return None, prediction.error
+        return None, prediction.error, prediction_url
 
     output = prediction.output
     if _has_output_iterator_array_type(version):
@@ -303,4 +304,4 @@
 
     log.v(f"{prefix}Got output: {truncate(output)} ({duration:.2f} sec)")
 
-    return output, None
+    return output, None, prediction_url
diff --git a/cog_safe_push/tasks.py b/cog_safe_push/tasks.py
index 7590034..aeb8545 100644
--- a/cog_safe_push/tasks.py
+++ b/cog_safe_push/tasks.py
@@ -28,6 +28,7 @@ class CheckOutputsMatch(Task):
     fuzz_disabled_inputs: list[str]
     fuzz_prompt: str | None
     prediction_index: int | None = None
+    prediction_url: str | None = None
 
     async def run(self) -> None:
         if self.first_test_case_inputs is not None:
@@ -57,7 +58,7 @@ async def run(self) -> None:
         log.v(
             f"{prefix}Checking outputs match between existing version and test version, with inputs: {inputs}"
         )
-        test_output, test_error = await predict(
+        test_output, test_error, test_url = await predict(
             model=self.context.test_model,
             train=self.context.is_train(),
             train_destination=self.context.train_destination,
@@ -65,7 +66,8 @@ async def run(self) -> None:
             timeout_seconds=self.timeout_seconds,
             prediction_index=self.prediction_index,
         )
-        output, error = await predict(
+        self.prediction_url = test_url
+        output, error, _ = await predict(
             model=self.context.model,
             train=self.context.is_train(),
             train_destination=self.context.train_destination,
@@ -97,13 +99,14 @@ class RunTestCase(Task):
     checker: OutputChecker
     predict_timeout: int
     prediction_index: int | None = None
+    prediction_url: str | None = None
 
     async def run(self) -> None:
         prefix = (
             f"[{self.prediction_index}] " if self.prediction_index is not None else ""
         )
         log.v(f"{prefix}Running test case with inputs: {self.inputs}")
-        output, error = await predict(
+        output, error, url = await predict(
             model=self.context.test_model,
             train=self.context.is_train(),
             train_destination=self.context.train_destination,
@@ -111,6 +114,7 @@ async def run(self) -> None:
             timeout_seconds=self.predict_timeout,
             prediction_index=self.prediction_index,
         )
+        self.prediction_url = url
 
         await self.checker(output, error)
 
@@ -150,6 +154,7 @@ class FuzzModel(Task):
    inputs_queue: Queue[dict[str, Any]]
    predict_timeout: int
    prediction_index: int | None = None
+   prediction_url: str | None = None

    async def run(self) -> None:
        inputs = await asyncio.wait_for(self.inputs_queue.get(), timeout=60)
@@ -159,7 +164,7 @@ async def run(self) -> None:
         )
         log.v(f"{prefix}Fuzzing with inputs: {inputs}")
         try:
-            output, error = await predict(
+            output, error, url = await predict(
                 model=self.context.test_model,
                 train=self.context.is_train(),
                 train_destination=self.context.train_destination,
@@ -167,6 +172,7 @@ async def run(self) -> None:
                 timeout_seconds=self.predict_timeout,
                 prediction_index=self.prediction_index,
             )
+            self.prediction_url = url
         except PredictionTimeoutError:
             raise FuzzError(f"{prefix}Prediction timed out")
         if error is not None:
diff --git a/test/test_main.py b/test/test_main.py
index 13dec72..413ac84 100644
--- a/test/test_main.py
+++ b/test/test_main.py
@@ -1,13 +1,15 @@
 import pytest
 
 from cog_safe_push import log
-from cog_safe_push.exceptions import ArgumentError
+from cog_safe_push.exceptions import ArgumentError, TestCaseFailedError
 from cog_safe_push.main import (
     parse_args_and_config,
     parse_input_value,
     parse_inputs,
     parse_model,
+    run_tasks,
 )
+from cog_safe_push.tasks import Task
 
 
 def test_parse_args_minimal(monkeypatch):
@@ -476,3 +478,48 @@ def test_parse_args_push_official_model(monkeypatch):
     assert config.model == "user/model"
     assert not no_push
     assert push_official_model
+
+
+@pytest.mark.asyncio
+async def test_run_tasks_reports_all_errors_with_details(capsys):
+    class FailingTask(Task):
+        def __init__(self, prediction_index, prediction_url, error_message):
+            self.prediction_index = prediction_index
+            self.prediction_url = prediction_url
+            self.error_message = error_message
+
+        async def run(self):
+            raise TestCaseFailedError(self.error_message)
+
+    class FailingTaskNoUrl(Task):
+        def __init__(self, prediction_index, error_message):
+            self.prediction_index = prediction_index
+            self.error_message = error_message
+
+        async def run(self):
+            raise TestCaseFailedError(self.error_message)
+
+    tasks = [
+        FailingTask(1, "https://replicate.com/p/abc123", "Output mismatch"),
+        FailingTask(2, "https://replicate.com/p/def456", "Timeout occurred"),
+        FailingTaskNoUrl(3, "Invalid input"),
+    ]
+
+    with pytest.raises(TestCaseFailedError):
+        await run_tasks(tasks, parallel=2)
+
+    captured = capsys.readouterr()
+    error_output = captured.err
+
+    # All errors should be logged
+    assert (
+        "[1] Test case failed: Test case failed: Output mismatch; Prediction URL: https://replicate.com/p/abc123"
+        in error_output
+    )
+    assert (
+        "[2] Test case failed: Test case failed: Timeout occurred; Prediction URL: https://replicate.com/p/def456"
+        in error_output
+    )
+    assert "[3] Test case failed: Test case failed: Invalid input" in error_output
+    # Should not have prediction URL in the third error
+    assert "Invalid input; Prediction URL:" not in error_output
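
The following is not part of the patch: a minimal, self-contained sketch of the error-aggregation pattern the main.py hunks introduce, for reading in isolation. Task here is a hypothetical stand-in for cog_safe_push.tasks.Task and print() stands in for log.error(); the semaphore-bounded gather, the (task, error) tuple collection, and the per-task failure report mirror the patched run_tasks.

import asyncio


class Task:
    # Hypothetical stand-in; the real tasks carry these two attributes,
    # which the failure report reads off each failed task.
    prediction_index: int | None = None
    prediction_url: str | None = None

    async def run(self) -> None:
        raise NotImplementedError


async def run_tasks(tasks: list[Task], parallel: int) -> None:
    semaphore = asyncio.Semaphore(parallel)
    errors: list[tuple[Task, Exception]] = []

    async def run_with_semaphore(task: Task) -> None:
        async with semaphore:
            try:
                await task.run()
            except Exception as e:
                # Keep the task alongside its error so the report can
                # recover prediction_index and prediction_url later.
                errors.append((task, e))

    await asyncio.gather(*(run_with_semaphore(t) for t in tasks), return_exceptions=True)

    if errors:
        for task, error in errors:
            prefix = f"[{task.prediction_index}] " if task.prediction_index is not None else ""
            suffix = f"; Prediction URL: {task.prediction_url}" if task.prediction_url else ""
            print(f"{prefix}Test case failed: {error}{suffix}")
        # Re-raise the first failure only after every failure has been reported.
        raise errors[0][1]


class Failing(Task):
    prediction_index = 1
    prediction_url = "https://replicate.com/p/abc123"  # example value, as in the test above

    async def run(self) -> None:
        raise RuntimeError("Output mismatch")


if __name__ == "__main__":
    # Prints "[1] Test case failed: Output mismatch; Prediction URL: ..." then raises.
    asyncio.run(run_tasks([Failing()], parallel=2))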