Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
18f807a
feat: wire async task-queue scheduler into ColumnWiseDatasetBuilder
andreatgretel Mar 17, 2026
5474d4e
chore: add async benchmark notebook and demo scripts
andreatgretel Mar 17, 2026
cbbf162
fix: address all PR review comments on async builder integration
andreatgretel Mar 18, 2026
f06b90d
style: fix ruff format for lambda expression
andreatgretel Mar 18, 2026
3facfef
fix: address open review issues on async scheduler
andreatgretel Mar 18, 2026
c650a2f
fix: sync pre-batch row drops to CompletionTracker and restore stderr…
andreatgretel Mar 18, 2026
71e7412
fix: prune _admitted_rg_ids on row group checkpoint
andreatgretel Mar 18, 2026
102f9c6
Merge branch 'main' into andreatgretel/feat/async-builder-integration
andreatgretel Mar 18, 2026
d43fc41
chore: remove demo/async files from PR
andreatgretel Mar 18, 2026
259828d
fix: wire disable_early_shutdown into AsyncTaskScheduler
andreatgretel Mar 18, 2026
40e16fe
test: add e2e test for async engine concurrency
andreatgretel Mar 18, 2026
37e3b62
fix: drop row group on on_before_checkpoint failure instead of writin…
andreatgretel Mar 18, 2026
954117b
fix: skip on_before_checkpoint when no POST_BATCH processors configured
andreatgretel Mar 18, 2026
d7dd2ee
fix: address remaining review nits from nabinchha and greptile summary
andreatgretel Mar 19, 2026
b1b6741
fix: preserve async callback contract and e2e setup
andreatgretel Mar 19, 2026
372e274
fix: prune _seeds_dispatched_rgs and _pre_batch_done_rgs on checkpoint
andreatgretel Mar 19, 2026
2114d3b
refactor: consolidate per-RG state into _RowGroupState dataclass
andreatgretel Mar 19, 2026
166cfff
fix: skip checkpoint and callbacks when on_before_checkpoint fails
andreatgretel Mar 19, 2026
2fa4a3a
Merge branch 'main' into andreatgretel/feat/async-builder-integration
andreatgretel Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class RunConfig(ConfigBase):
max_conversation_correction_steps: Maximum number of correction rounds permitted within a
single conversation when generation tasks call `ModelFacade.generate(...)`. Must be >= 0.
Default is 0.
async_trace: If True, collect per-task tracing data when using the async engine
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

per-task here may need a little more context?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring already reads "collect per-task tracing data when using the async engine" - lmk if you'd like more detail

(DATA_DESIGNER_ASYNC_ENGINE=1). Has no effect on the sync path. Default is False.
"""

disable_early_shutdown: bool = False
Expand All @@ -42,6 +44,7 @@ class RunConfig(ConfigBase):
non_inference_max_parallel_workers: int = Field(default=4, ge=1)
max_conversation_restarts: int = Field(default=5, ge=0)
max_conversation_correction_steps: int = Field(default=0, ge=0)
async_trace: bool = False

@model_validator(mode="after")
def normalize_shutdown_settings(self) -> Self:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import contextlib
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

import data_designer.lazy_heavy_imports as lazy
Expand All @@ -27,6 +29,16 @@
logger = logging.getLogger(__name__)


@dataclass
class _RowGroupState:
    """Lifecycle state for a single admitted row group.

    One instance per admitted-but-not-yet-checkpointed row group, stored in
    the scheduler's ``_rg_states`` map; the entry is deleted when the row
    group is checkpointed (or dropped).
    """

    # Number of rows in this row group.
    size: int
    # Set once from-scratch (seed) tasks for this row group have been dispatched.
    seeds_dispatched: bool = False
    # Set once all seed columns completed and the pre-batch callback (if any) ran.
    pre_batch_done: bool = False
    # Number of this row group's tasks currently executing (in-flight).
    in_flight_count: int = 0


class AsyncTaskScheduler:
"""Dependency-aware async task scheduler for the dataset builder.

Expand All @@ -46,6 +58,12 @@ def __init__(
max_submitted_tasks: int = 256,
salvage_max_rounds: int = 2,
on_row_group_complete: Callable[[int], None] | None = None,
on_checkpoint_complete: Callable[[Path | str], None] | None = None,
on_seeds_complete: Callable[[int, int], None] | None = None,
on_before_checkpoint: Callable[[int, int], None] | None = None,
shutdown_error_rate: float = 0.5,
shutdown_error_window: int = 10,
disable_early_shutdown: bool = False,
trace: bool = False,
) -> None:
self._generators = generators
Expand All @@ -62,6 +80,15 @@ def __init__(
self._wake_event = asyncio.Event()
self._salvage_max_rounds = salvage_max_rounds
self._on_row_group_complete = on_row_group_complete
self._on_checkpoint_complete = on_checkpoint_complete
self._on_seeds_complete = on_seeds_complete
self._on_before_checkpoint = on_before_checkpoint

# Error rate shutdown (caller passes pre-normalized values via RunConfig)
self._shutdown_error_rate = shutdown_error_rate
self._shutdown_error_window = shutdown_error_window
self._disable_early_shutdown = disable_early_shutdown
self._early_shutdown = False

# Multi-column dedup: group output columns by generator identity
instance_to_columns: dict[int, list[str]] = {}
Expand All @@ -75,13 +102,12 @@ def __init__(
if gen.is_order_dependent and id(gen) not in self._stateful_locks:
self._stateful_locks[id(gen)] = asyncio.Lock()

# Per-RG lifecycle state (admitted but not yet checkpointed)
self._rg_states: dict[int, _RowGroupState] = {}

# Deferred retryable failures (retried in salvage rounds)
self._deferred: list[Task] = []

# Active row groups (admitted but not yet checkpointed)
self._active_rgs: list[tuple[int, int]] = []
self._admitted_rg_ids: set[int] = set()

# Tracing
self._trace = trace
self.traces: list[TaskTrace] = []
Expand All @@ -94,12 +120,14 @@ def __init__(
# Pre-compute row-group sizes for O(1) lookup
self._rg_size_map: dict[int, int] = dict(row_groups)

# Pre-compute seed columns (graph is static)
self._seed_cols: frozenset[str] = frozenset(c for c in graph.columns if not graph.get_upstream_columns(c))

async def _admit_row_groups(self) -> None:
"""Admit row groups as semaphore slots become available."""
for rg_id, rg_size in self._row_groups:
await self._rg_semaphore.acquire()
self._active_rgs.append((rg_id, rg_size))
self._admitted_rg_ids.add(rg_id)
self._rg_states[rg_id] = _RowGroupState(size=rg_size)

if self._buffer_manager is not None:
self._buffer_manager.init_row_group(rg_id, rg_size)
Expand All @@ -112,25 +140,44 @@ async def _admit_row_groups(self) -> None:
async def run(self) -> None:
"""Main scheduler loop."""
all_columns = self._graph.columns
seed_cols = self._seed_cols
has_pre_batch = self._on_seeds_complete is not None

# Launch admission as a background task so it interleaves with dispatch.
admission_task = asyncio.create_task(self._admit_row_groups())

# Main dispatch loop
while True:
if self._early_shutdown:
logger.warning("Early shutdown triggered - error rate exceeded threshold")
self._checkpoint_completed_row_groups(all_columns)
break

self._wake_event.clear()

ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids)
self._run_seeds_complete_check(seed_cols)

admitted_ids = set(self._rg_states)
ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids)
# Gate non-seed tasks on pre-batch completion when a pre-batch callback is configured
if has_pre_batch:
ready = [
t
for t in ready
if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols
]
for task in ready:
await self._submission_semaphore.acquire()
self._dispatched.add(task)
self._in_flight.add(task)
if (s := self._rg_states.get(task.row_group)) is not None:
s.in_flight_count += 1
asyncio.create_task(self._execute_task(task))

self._checkpoint_completed_row_groups(all_columns)

# Are we done?
all_done = self._all_rgs_admitted and not self._active_rgs and not self._in_flight
all_done = self._all_rgs_admitted and not self._rg_states and not self._in_flight
if all_done:
break

Expand Down Expand Up @@ -185,29 +232,41 @@ async def run(self) -> None:
Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch")
)
self._in_flight.add(task)
if (s := self._rg_states.get(task.row_group)) is not None:
s.in_flight_count += 1
asyncio.create_task(self._execute_seed_task(task, gid))
else:
self._dispatched.discard(task)
# Drain: dispatch frontier tasks and any newly-ready downstream tasks
# until nothing remains in-flight or in the frontier.
await self._drain_frontier()
await self._drain_frontier(seed_cols, has_pre_batch, all_columns)
self._checkpoint_completed_row_groups(all_columns)

if self._active_rgs:
incomplete = [rg_id for rg_id, _ in self._active_rgs]
if self._rg_states:
incomplete = list(self._rg_states)
logger.error(
f"Scheduler exited with {len(self._active_rgs)} unfinished row group(s): {incomplete}. "
f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. "
"These row groups were not checkpointed."
)

async def _drain_frontier(self) -> None:
async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, all_columns: list[str]) -> None:
"""Dispatch all frontier tasks and their downstream until quiescent."""
while True:
ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids)
self._run_seeds_complete_check(seed_cols)
admitted_ids = set(self._rg_states)
ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids)
if has_pre_batch:
ready = [
t
for t in ready
if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols
]
for task in ready:
await self._submission_semaphore.acquire()
self._dispatched.add(task)
self._in_flight.add(task)
if (s := self._rg_states.get(task.row_group)) is not None:
s.in_flight_count += 1
asyncio.create_task(self._execute_task(task))
if not self._in_flight:
break
Expand All @@ -217,25 +276,81 @@ async def _drain_frontier(self) -> None:
def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None:
"""Checkpoint any row groups that reached completion."""
completed = [
(rg_id, rg_size)
for rg_id, rg_size in self._active_rgs
if self._tracker.is_row_group_complete(rg_id, rg_size, all_columns)
(rg_id, state.size)
for rg_id, state in self._rg_states.items()
if self._tracker.is_row_group_complete(rg_id, state.size, all_columns)
]
for rg_id, rg_size in completed:
self._active_rgs.remove((rg_id, rg_size))
del self._rg_states[rg_id]
dropped = False
try:
if self._buffer_manager is not None:
self._buffer_manager.checkpoint_row_group(rg_id)
if self._on_row_group_complete:
if self._on_before_checkpoint:
try:
self._on_before_checkpoint(rg_id, rg_size)
except Exception:
# Post-batch is mandatory; drop rather than checkpoint unprocessed data.
logger.error(
f"on_before_checkpoint failed for row group {rg_id}, dropping row group.",
exc_info=True,
)
for ri in range(rg_size):
if self._buffer_manager:
Comment on lines +285 to +297
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 on_before_checkpoint failure drops buffer rows but not tracker rows

When on_before_checkpoint raises (lines 287–297), every row in the row group is marked as dropped in self._buffer_manager but self._tracker.drop_row(rg_id, ri) is never called. The tracker still has those rows marked as complete rather than dropped.

In the current scheduler this is functionally harmless because del self._rg_states[rg_id] runs before the try-block (line 284), so admitted_ids will never include this rg_id again and get_ready_tasks will silently discard any stale frontier entries. However, tracker.is_dropped(rg_id, ri) returns False for rows that were actually discarded, making post-mortem inspection of the tracker (e.g., via task_traces or future diagnostics) misleading.

Syncing the drops to the tracker mirrors what on_seeds_complete already does (lines 272–274) and makes the two failure paths consistent:

except Exception:
    logger.error(
        f"on_before_checkpoint failed for row group {rg_id}, dropping row group.",
        exc_info=True,
    )
    for ri in range(rg_size):
        if self._buffer_manager:
            self._buffer_manager.drop_row(rg_id, ri)
        self._tracker.drop_row(rg_id, ri)   # keep tracker in sync
Prompt To Fix With AI
This is a comment left during a code review.
Path: packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py
Line: 285-297

Comment:
**`on_before_checkpoint` failure drops buffer rows but not tracker rows**

When `on_before_checkpoint` raises (lines 287–297), every row in the row group is marked as dropped in `self._buffer_manager` but `self._tracker.drop_row(rg_id, ri)` is never called. The tracker still has those rows marked as *complete* rather than *dropped*.

In the current scheduler this is functionally harmless because `del self._rg_states[rg_id]` runs before the try-block (line 284), so `admitted_ids` will never include this `rg_id` again and `get_ready_tasks` will silently discard any stale frontier entries. However, `tracker.is_dropped(rg_id, ri)` returns `False` for rows that were actually discarded, making post-mortem inspection of the tracker (e.g., via `task_traces` or future diagnostics) misleading.

Syncing the drops to the tracker mirrors what `on_seeds_complete` already does (lines 272–274) and makes the two failure paths consistent:

```python
except Exception:
    logger.error(
        f"on_before_checkpoint failed for row group {rg_id}, dropping row group.",
        exc_info=True,
    )
    for ri in range(rg_size):
        if self._buffer_manager:
            self._buffer_manager.drop_row(rg_id, ri)
        self._tracker.drop_row(rg_id, ri)   # keep tracker in sync
```

How can I resolve this? If you propose a fix, please make it concise.

self._buffer_manager.drop_row(rg_id, ri)
dropped = True
if not dropped and self._buffer_manager is not None:
if self._on_checkpoint_complete is not None:

def on_complete(final_path: Path | str | None) -> None:
if final_path is not None:
self._on_checkpoint_complete(final_path)

self._buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete)
else:
self._buffer_manager.checkpoint_row_group(rg_id)
if not dropped and self._on_row_group_complete:
self._on_row_group_complete(rg_id)
except Exception:
logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True)
finally:
self._rg_semaphore.release()

def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None:
"""Run pre-batch callbacks for row groups whose seeds just completed."""
for rg_id, state in list(self._rg_states.items()):
if state.seeds_dispatched and not state.pre_batch_done:
all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols)
if all_seeds_done and state.in_flight_count == 0:
state.pre_batch_done = True
if self._on_seeds_complete:
try:
self._on_seeds_complete(rg_id, state.size)
except Exception as exc:
logger.warning(f"Pre-batch processor failed for row group {rg_id}, skipping: {exc}")
for ri in range(state.size):
self._tracker.drop_row(rg_id, ri)
if self._buffer_manager:
self._buffer_manager.drop_row(rg_id, ri)

def _in_flight_for_rg(self, rg_id: int) -> bool:
"""Check if any tasks are in-flight for a given row group."""
state = self._rg_states.get(rg_id)
return state is not None and state.in_flight_count > 0

def _check_error_rate(self) -> None:
"""Trigger early shutdown if error rate exceeds threshold."""
if self._disable_early_shutdown:
return
completed = self._success_count + self._error_count
if completed < self._shutdown_error_window:
return
error_rate = self._error_count / max(1, completed)
if error_rate > self._shutdown_error_rate:
self._early_shutdown = True

async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None:
"""Dispatch from_scratch tasks for a row group."""
seed_cols = [col for col in self._graph.get_topological_order() if not self._graph.get_upstream_columns(col)]
self._rg_states[rg_id].seeds_dispatched = True
seed_cols = self._seed_cols
seen_instances: set[int] = set()

for col in seed_cols:
Expand Down Expand Up @@ -268,6 +383,8 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None:
)
self._dispatched.add(Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="batch"))
self._in_flight.add(task)
if (s := self._rg_states.get(task.row_group)) is not None:
s.in_flight_count += 1
asyncio.create_task(self._execute_seed_task(task, gid))

async def _execute_seed_task(self, task: Task, generator_id: int) -> None:
Expand Down Expand Up @@ -300,7 +417,7 @@ async def _execute_task_inner(self, task: Task) -> None:
# Skip tasks whose row group was already checkpointed (can happen
# when a vacuously-ready downstream is dispatched via create_task
# in the same loop iteration that checkpoints the row group).
if not any(rg_id == task.row_group for rg_id, _ in self._active_rgs):
if task.row_group not in self._rg_states:
skipped = True
return

Expand Down Expand Up @@ -330,6 +447,7 @@ async def _execute_task_inner(self, task: Task) -> None:

except Exception as exc:
self._error_count += 1
self._check_error_rate()
if self._trace and trace:
trace.status = "error"
trace.error = str(exc)
Expand Down Expand Up @@ -360,6 +478,8 @@ async def _execute_task_inner(self, task: Task) -> None:
self.traces.append(trace)

self._in_flight.discard(task)
if (s := self._rg_states.get(task.row_group)) is not None:
s.in_flight_count = max(0, s.in_flight_count - 1)
if not retryable and not skipped:
self._dispatched.discard(task)
self._submission_semaphore.release()
Expand Down
Loading
Loading