Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,14 +451,14 @@ async def run_tasks(tasks: list[Task], parallel: int) -> None:
log.info(f"Running tasks with parallelism {parallel}")

semaphore = asyncio.Semaphore(parallel)
errors: list[Exception] = []
errors: list[tuple[Task, Exception]] = []

async def run_with_semaphore(task: Task) -> None:
async with semaphore:
try:
await task.run()
except Exception as e:
errors.append(e)
errors.append((task, e))

# Create task coroutines and run them concurrently
task_coroutines = [run_with_semaphore(task) for task in tasks]
Expand All @@ -467,11 +467,22 @@ async def run_with_semaphore(task: Task) -> None:
await asyncio.gather(*task_coroutines, return_exceptions=True)

if errors:
# If there are multiple errors, we'll raise the first one
# but log all of them
for error in errors[1:]:
log.error(f"Additional error occurred: {error}")
raise errors[0]
# Log all failures with their test case number and prediction URL
for task, error in errors:
prediction_index = getattr(task, "prediction_index", None)
prediction_url = getattr(task, "prediction_url", None)

prefix = f"[{prediction_index}] " if prediction_index is not None else ""
message = str(error)

if prediction_url:
log.error(
f"{prefix}Test case failed: {message}; Prediction URL: {prediction_url}"
)
else:
log.error(f"{prefix}Test case failed: {message}")

raise errors[0][1]


def parse_inputs(inputs_list: list[str]) -> dict[str, Any]:
Expand Down
9 changes: 5 additions & 4 deletions cog_safe_push/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ async def predict(
inputs: dict,
timeout_seconds: float,
prediction_index: int | None = None,
) -> tuple[Any | None, str | None]:
) -> tuple[Any | None, str | None, str]:
prefix = f"[{prediction_index}] " if prediction_index is not None else ""
log.vv(
f"{prefix}Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}"
Expand Down Expand Up @@ -283,7 +283,8 @@ async def predict(
else:
raise

log.v(f"{prefix}Prediction URL: https://replicate.com/p/{prediction.id}")
prediction_url = f"https://replicate.com/p/{prediction.id}"
log.v(f"{prefix}Prediction URL: {prediction_url}")

while prediction.status not in ["succeeded", "failed", "canceled"]:
await asyncio.sleep(0.5)
Expand All @@ -295,12 +296,12 @@ async def predict(

if prediction.status == "failed":
log.v(f"{prefix}Got error: {prediction.error} ({duration:.2f} sec)")
return None, prediction.error
return None, prediction.error, prediction_url

output = prediction.output
if _has_output_iterator_array_type(version):
output = "".join(cast("list[str]", output))

log.v(f"{prefix}Got output: {truncate(output)} ({duration:.2f} sec)")

return output, None
return output, None, prediction_url
14 changes: 10 additions & 4 deletions cog_safe_push/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class CheckOutputsMatch(Task):
fuzz_disabled_inputs: list[str]
fuzz_prompt: str | None
prediction_index: int | None = None
prediction_url: str | None = None

async def run(self) -> None:
if self.first_test_case_inputs is not None:
Expand Down Expand Up @@ -57,15 +58,16 @@ async def run(self) -> None:
log.v(
f"{prefix}Checking outputs match between existing version and test version, with inputs: {inputs}"
)
test_output, test_error = await predict(
test_output, test_error, test_url = await predict(
model=self.context.test_model,
train=self.context.is_train(),
train_destination=self.context.train_destination,
inputs=inputs,
timeout_seconds=self.timeout_seconds,
prediction_index=self.prediction_index,
)
output, error = await predict(
self.prediction_url = test_url
output, error, _ = await predict(
model=self.context.model,
train=self.context.is_train(),
train_destination=self.context.train_destination,
Expand Down Expand Up @@ -97,20 +99,22 @@ class RunTestCase(Task):
checker: OutputChecker
predict_timeout: int
prediction_index: int | None = None
prediction_url: str | None = None

async def run(self) -> None:
prefix = (
f"[{self.prediction_index}] " if self.prediction_index is not None else ""
)
log.v(f"{prefix}Running test case with inputs: {self.inputs}")
output, error = await predict(
output, error, url = await predict(
model=self.context.test_model,
train=self.context.is_train(),
train_destination=self.context.train_destination,
inputs=self.inputs,
timeout_seconds=self.predict_timeout,
prediction_index=self.prediction_index,
)
self.prediction_url = url

await self.checker(output, error)

Expand Down Expand Up @@ -150,6 +154,7 @@ class FuzzModel(Task):
inputs_queue: Queue[dict[str, Any]]
predict_timeout: int
prediction_index: int | None = None
prediction_url: str | None = None

async def run(self) -> None:
inputs = await asyncio.wait_for(self.inputs_queue.get(), timeout=60)
Expand All @@ -159,14 +164,15 @@ async def run(self) -> None:
)
log.v(f"{prefix}Fuzzing with inputs: {inputs}")
try:
output, error = await predict(
output, error, url = await predict(
model=self.context.test_model,
train=self.context.is_train(),
train_destination=self.context.train_destination,
inputs=inputs,
timeout_seconds=self.predict_timeout,
prediction_index=self.prediction_index,
)
self.prediction_url = url
except PredictionTimeoutError:
raise FuzzError(f"{prefix}Prediction timed out")
if error is not None:
Expand Down
49 changes: 48 additions & 1 deletion test/test_main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pytest

from cog_safe_push import log
from cog_safe_push.exceptions import ArgumentError
from cog_safe_push.exceptions import ArgumentError, TestCaseFailedError
from cog_safe_push.main import (
parse_args_and_config,
parse_input_value,
parse_inputs,
parse_model,
run_tasks,
)
from cog_safe_push.tasks import Task


def test_parse_args_minimal(monkeypatch):
Expand Down Expand Up @@ -476,3 +478,48 @@ def test_parse_args_push_official_model(monkeypatch):
assert config.model == "user/model"
assert not no_push
assert push_official_model


@pytest.mark.asyncio
async def test_run_tasks_reports_all_errors_with_details(capsys):
    """run_tasks should log every task failure (including its prediction URL
    when the task carries one) and then re-raise the first error."""

    class AlwaysFailingTask(Task):
        # Failing task that exposes a prediction URL.
        def __init__(self, prediction_index, prediction_url, error_message):
            self.prediction_index = prediction_index
            self.prediction_url = prediction_url
            self.error_message = error_message

        async def run(self):
            raise TestCaseFailedError(self.error_message)

    class AlwaysFailingTaskNoUrl(Task):
        # Failing task with no prediction URL attribute at all.
        def __init__(self, prediction_index, error_message):
            self.prediction_index = prediction_index
            self.error_message = error_message

        async def run(self):
            raise TestCaseFailedError(self.error_message)

    failing_tasks = [
        AlwaysFailingTask(1, "https://replicate.com/p/abc123", "Output mismatch"),
        AlwaysFailingTask(2, "https://replicate.com/p/def456", "Timeout occurred"),
        AlwaysFailingTaskNoUrl(3, "Invalid input"),
    ]

    with pytest.raises(TestCaseFailedError):
        await run_tasks(failing_tasks, parallel=2)

    error_output = capsys.readouterr().err

    # Every failure must be logged, with its URL when one exists.
    expected_with_url = [
        "[1] Test case failed: Test case failed: Output mismatch; Prediction URL: https://replicate.com/p/abc123",
        "[2] Test case failed: Test case failed: Timeout occurred; Prediction URL: https://replicate.com/p/def456",
    ]
    for expected in expected_with_url:
        assert expected in error_output
    assert "[3] Test case failed: Test case failed: Invalid input" in error_output
    # The URL-less failure must not gain a prediction URL in the log.
    assert "Invalid input; Prediction URL:" not in error_output