-
Notifications
You must be signed in to change notification settings - Fork 79
feat: wire async task-queue scheduler into ColumnWiseDatasetBuilder #429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
18f807a
5474d4e
cbbf162
f06b90d
3facfef
c650a2f
71e7412
102f9c6
d43fc41
259828d
40e16fe
37e3b62
954117b
d7dd2ee
b1b6741
372e274
2114d3b
166cfff
2fa4a3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,8 @@ | |
| import contextlib | ||
| import logging | ||
| import time | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Any, Callable | ||
|
|
||
| import data_designer.lazy_heavy_imports as lazy | ||
|
|
@@ -27,6 +29,16 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @dataclass | ||
| class _RowGroupState: | ||
| """Lifecycle state for a single admitted row group.""" | ||
|
|
||
| size: int | ||
| seeds_dispatched: bool = False | ||
| pre_batch_done: bool = False | ||
| in_flight_count: int = 0 | ||
|
|
||
|
|
||
| class AsyncTaskScheduler: | ||
| """Dependency-aware async task scheduler for the dataset builder. | ||
|
|
||
|
|
@@ -46,6 +58,12 @@ def __init__( | |
| max_submitted_tasks: int = 256, | ||
| salvage_max_rounds: int = 2, | ||
| on_row_group_complete: Callable[[int], None] | None = None, | ||
| on_checkpoint_complete: Callable[[Path | str], None] | None = None, | ||
| on_seeds_complete: Callable[[int, int], None] | None = None, | ||
| on_before_checkpoint: Callable[[int, int], None] | None = None, | ||
| shutdown_error_rate: float = 0.5, | ||
| shutdown_error_window: int = 10, | ||
| disable_early_shutdown: bool = False, | ||
| trace: bool = False, | ||
| ) -> None: | ||
| self._generators = generators | ||
|
|
@@ -62,6 +80,15 @@ def __init__( | |
| self._wake_event = asyncio.Event() | ||
| self._salvage_max_rounds = salvage_max_rounds | ||
| self._on_row_group_complete = on_row_group_complete | ||
| self._on_checkpoint_complete = on_checkpoint_complete | ||
| self._on_seeds_complete = on_seeds_complete | ||
| self._on_before_checkpoint = on_before_checkpoint | ||
|
|
||
| # Error rate shutdown (caller passes pre-normalized values via RunConfig) | ||
| self._shutdown_error_rate = shutdown_error_rate | ||
| self._shutdown_error_window = shutdown_error_window | ||
| self._disable_early_shutdown = disable_early_shutdown | ||
| self._early_shutdown = False | ||
|
|
||
| # Multi-column dedup: group output columns by generator identity | ||
| instance_to_columns: dict[int, list[str]] = {} | ||
|
|
@@ -75,13 +102,12 @@ def __init__( | |
| if gen.is_order_dependent and id(gen) not in self._stateful_locks: | ||
| self._stateful_locks[id(gen)] = asyncio.Lock() | ||
|
|
||
| # Per-RG lifecycle state (admitted but not yet checkpointed) | ||
| self._rg_states: dict[int, _RowGroupState] = {} | ||
|
|
||
| # Deferred retryable failures (retried in salvage rounds) | ||
| self._deferred: list[Task] = [] | ||
|
|
||
| # Active row groups (admitted but not yet checkpointed) | ||
| self._active_rgs: list[tuple[int, int]] = [] | ||
| self._admitted_rg_ids: set[int] = set() | ||
|
|
||
| # Tracing | ||
| self._trace = trace | ||
| self.traces: list[TaskTrace] = [] | ||
|
|
@@ -94,12 +120,14 @@ def __init__( | |
| # Pre-compute row-group sizes for O(1) lookup | ||
| self._rg_size_map: dict[int, int] = dict(row_groups) | ||
|
|
||
| # Pre-compute seed columns (graph is static) | ||
| self._seed_cols: frozenset[str] = frozenset(c for c in graph.columns if not graph.get_upstream_columns(c)) | ||
|
|
||
| async def _admit_row_groups(self) -> None: | ||
| """Admit row groups as semaphore slots become available.""" | ||
| for rg_id, rg_size in self._row_groups: | ||
| await self._rg_semaphore.acquire() | ||
| self._active_rgs.append((rg_id, rg_size)) | ||
| self._admitted_rg_ids.add(rg_id) | ||
| self._rg_states[rg_id] = _RowGroupState(size=rg_size) | ||
|
|
||
| if self._buffer_manager is not None: | ||
| self._buffer_manager.init_row_group(rg_id, rg_size) | ||
|
|
@@ -112,25 +140,44 @@ async def _admit_row_groups(self) -> None: | |
| async def run(self) -> None: | ||
| """Main scheduler loop.""" | ||
| all_columns = self._graph.columns | ||
| seed_cols = self._seed_cols | ||
| has_pre_batch = self._on_seeds_complete is not None | ||
|
|
||
| # Launch admission as a background task so it interleaves with dispatch. | ||
andreatgretel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| admission_task = asyncio.create_task(self._admit_row_groups()) | ||
|
|
||
| # Main dispatch loop | ||
| while True: | ||
| if self._early_shutdown: | ||
| logger.warning("Early shutdown triggered - error rate exceeded threshold") | ||
| self._checkpoint_completed_row_groups(all_columns) | ||
| break | ||
|
|
||
| self._wake_event.clear() | ||
|
|
||
| ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) | ||
| self._run_seeds_complete_check(seed_cols) | ||
|
|
||
| admitted_ids = set(self._rg_states) | ||
| ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) | ||
| # Gate non-seed tasks on pre-batch completion when a pre-batch callback is configured | ||
| if has_pre_batch: | ||
| ready = [ | ||
| t | ||
| for t in ready | ||
| if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols | ||
| ] | ||
| for task in ready: | ||
| await self._submission_semaphore.acquire() | ||
| self._dispatched.add(task) | ||
| self._in_flight.add(task) | ||
| if (s := self._rg_states.get(task.row_group)) is not None: | ||
| s.in_flight_count += 1 | ||
| asyncio.create_task(self._execute_task(task)) | ||
|
|
||
| self._checkpoint_completed_row_groups(all_columns) | ||
|
|
||
| # Are we done? | ||
| all_done = self._all_rgs_admitted and not self._active_rgs and not self._in_flight | ||
| all_done = self._all_rgs_admitted and not self._rg_states and not self._in_flight | ||
| if all_done: | ||
| break | ||
|
|
||
|
|
@@ -185,29 +232,41 @@ async def run(self) -> None: | |
| Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch") | ||
| ) | ||
| self._in_flight.add(task) | ||
| if (s := self._rg_states.get(task.row_group)) is not None: | ||
| s.in_flight_count += 1 | ||
| asyncio.create_task(self._execute_seed_task(task, gid)) | ||
| else: | ||
| self._dispatched.discard(task) | ||
| # Drain: dispatch frontier tasks and any newly-ready downstream tasks | ||
| # until nothing remains in-flight or in the frontier. | ||
| await self._drain_frontier() | ||
| await self._drain_frontier(seed_cols, has_pre_batch, all_columns) | ||
| self._checkpoint_completed_row_groups(all_columns) | ||
|
|
||
| if self._active_rgs: | ||
| incomplete = [rg_id for rg_id, _ in self._active_rgs] | ||
| if self._rg_states: | ||
| incomplete = list(self._rg_states) | ||
| logger.error( | ||
| f"Scheduler exited with {len(self._active_rgs)} unfinished row group(s): {incomplete}. " | ||
| f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. " | ||
| "These row groups were not checkpointed." | ||
| ) | ||
|
|
||
| async def _drain_frontier(self) -> None: | ||
| async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, all_columns: list[str]) -> None: | ||
| """Dispatch all frontier tasks and their downstream until quiescent.""" | ||
| while True: | ||
| ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) | ||
| self._run_seeds_complete_check(seed_cols) | ||
| admitted_ids = set(self._rg_states) | ||
| ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) | ||
| if has_pre_batch: | ||
| ready = [ | ||
| t | ||
| for t in ready | ||
| if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols | ||
| ] | ||
| for task in ready: | ||
| await self._submission_semaphore.acquire() | ||
| self._dispatched.add(task) | ||
| self._in_flight.add(task) | ||
| if (s := self._rg_states.get(task.row_group)) is not None: | ||
| s.in_flight_count += 1 | ||
| asyncio.create_task(self._execute_task(task)) | ||
| if not self._in_flight: | ||
| break | ||
|
|
@@ -217,25 +276,81 @@ async def _drain_frontier(self) -> None: | |
| def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: | ||
| """Checkpoint any row groups that reached completion.""" | ||
| completed = [ | ||
| (rg_id, rg_size) | ||
| for rg_id, rg_size in self._active_rgs | ||
| if self._tracker.is_row_group_complete(rg_id, rg_size, all_columns) | ||
| (rg_id, state.size) | ||
| for rg_id, state in self._rg_states.items() | ||
| if self._tracker.is_row_group_complete(rg_id, state.size, all_columns) | ||
| ] | ||
| for rg_id, rg_size in completed: | ||
| self._active_rgs.remove((rg_id, rg_size)) | ||
| del self._rg_states[rg_id] | ||
| dropped = False | ||
| try: | ||
andreatgretel marked this conversation as resolved.
Show resolved
Hide resolved
andreatgretel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if self._buffer_manager is not None: | ||
| self._buffer_manager.checkpoint_row_group(rg_id) | ||
| if self._on_row_group_complete: | ||
| if self._on_before_checkpoint: | ||
| try: | ||
| self._on_before_checkpoint(rg_id, rg_size) | ||
| except Exception: | ||
| # Post-batch is mandatory; drop rather than checkpoint unprocessed data. | ||
| logger.error( | ||
| f"on_before_checkpoint failed for row group {rg_id}, dropping row group.", | ||
| exc_info=True, | ||
| ) | ||
| for ri in range(rg_size): | ||
| if self._buffer_manager: | ||
|
Comment on lines
+285
to
+297
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
**`on_before_checkpoint` failure drops buffer rows but not tracker rows.** When `on_before_checkpoint` raises, every row in the row group is marked as dropped in `self._buffer_manager`, but `self._tracker.drop_row(rg_id, ri)` is never called, so the tracker still reports those rows as complete rather than dropped. In the current scheduler this is functionally harmless because `del self._rg_states[rg_id]` runs before the try-block, so the row group is never re-admitted — but `tracker.is_dropped(rg_id, ri)` returning `False` for discarded rows makes post-mortem inspection misleading. Syncing the drops to the tracker mirrors what the `on_seeds_complete` failure path already does and keeps the two failure paths consistent: except Exception:
logger.error(
f"on_before_checkpoint failed for row group {rg_id}, dropping row group.",
exc_info=True,
)
for ri in range(rg_size):
if self._buffer_manager:
self._buffer_manager.drop_row(rg_id, ri)
self._tracker.drop_row(rg_id, ri) # keep tracker in syncPrompt To Fix With AIThis is a comment left during a code review.
Path: packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py
Line: 285-297
Comment:
**`on_before_checkpoint` failure drops buffer rows but not tracker rows**
When `on_before_checkpoint` raises (lines 287–297), every row in the row group is marked as dropped in `self._buffer_manager` but `self._tracker.drop_row(rg_id, ri)` is never called. The tracker still has those rows marked as *complete* rather than *dropped*.
In the current scheduler this is functionally harmless because `del self._rg_states[rg_id]` runs before the try-block (line 284), so `admitted_ids` will never include this `rg_id` again and `get_ready_tasks` will silently discard any stale frontier entries. However, `tracker.is_dropped(rg_id, ri)` returns `False` for rows that were actually discarded, making post-mortem inspection of the tracker (e.g., via `task_traces` or future diagnostics) misleading.
Syncing the drops to the tracker mirrors what `on_seeds_complete` already does (lines 272–274) and makes the two failure paths consistent:
```python
except Exception:
logger.error(
f"on_before_checkpoint failed for row group {rg_id}, dropping row group.",
exc_info=True,
)
for ri in range(rg_size):
if self._buffer_manager:
self._buffer_manager.drop_row(rg_id, ri)
self._tracker.drop_row(rg_id, ri) # keep tracker in sync
```
How can I resolve this? If you propose a fix, please make it concise. |
||
| self._buffer_manager.drop_row(rg_id, ri) | ||
| dropped = True | ||
| if not dropped and self._buffer_manager is not None: | ||
| if self._on_checkpoint_complete is not None: | ||
|
|
||
| def on_complete(final_path: Path | str | None) -> None: | ||
| if final_path is not None: | ||
| self._on_checkpoint_complete(final_path) | ||
|
|
||
| self._buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete) | ||
| else: | ||
| self._buffer_manager.checkpoint_row_group(rg_id) | ||
| if not dropped and self._on_row_group_complete: | ||
| self._on_row_group_complete(rg_id) | ||
| except Exception: | ||
| logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True) | ||
| finally: | ||
| self._rg_semaphore.release() | ||
andreatgretel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: | ||
| """Run pre-batch callbacks for row groups whose seeds just completed.""" | ||
| for rg_id, state in list(self._rg_states.items()): | ||
| if state.seeds_dispatched and not state.pre_batch_done: | ||
| all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols) | ||
| if all_seeds_done and state.in_flight_count == 0: | ||
| state.pre_batch_done = True | ||
| if self._on_seeds_complete: | ||
| try: | ||
| self._on_seeds_complete(rg_id, state.size) | ||
| except Exception as exc: | ||
| logger.warning(f"Pre-batch processor failed for row group {rg_id}, skipping: {exc}") | ||
| for ri in range(state.size): | ||
| self._tracker.drop_row(rg_id, ri) | ||
| if self._buffer_manager: | ||
| self._buffer_manager.drop_row(rg_id, ri) | ||
andreatgretel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def _in_flight_for_rg(self, rg_id: int) -> bool: | ||
| """Check if any tasks are in-flight for a given row group.""" | ||
| state = self._rg_states.get(rg_id) | ||
| return state is not None and state.in_flight_count > 0 | ||
|
|
||
| def _check_error_rate(self) -> None: | ||
| """Trigger early shutdown if error rate exceeds threshold.""" | ||
| if self._disable_early_shutdown: | ||
| return | ||
| completed = self._success_count + self._error_count | ||
| if completed < self._shutdown_error_window: | ||
| return | ||
| error_rate = self._error_count / max(1, completed) | ||
| if error_rate > self._shutdown_error_rate: | ||
| self._early_shutdown = True | ||
|
|
||
| async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: | ||
| """Dispatch from_scratch tasks for a row group.""" | ||
| seed_cols = [col for col in self._graph.get_topological_order() if not self._graph.get_upstream_columns(col)] | ||
| self._rg_states[rg_id].seeds_dispatched = True | ||
| seed_cols = self._seed_cols | ||
| seen_instances: set[int] = set() | ||
|
|
||
| for col in seed_cols: | ||
|
|
@@ -268,6 +383,8 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: | |
| ) | ||
| self._dispatched.add(Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="batch")) | ||
| self._in_flight.add(task) | ||
| if (s := self._rg_states.get(task.row_group)) is not None: | ||
| s.in_flight_count += 1 | ||
| asyncio.create_task(self._execute_seed_task(task, gid)) | ||
|
|
||
| async def _execute_seed_task(self, task: Task, generator_id: int) -> None: | ||
|
|
@@ -300,7 +417,7 @@ async def _execute_task_inner(self, task: Task) -> None: | |
| # Skip tasks whose row group was already checkpointed (can happen | ||
| # when a vacuously-ready downstream is dispatched via create_task | ||
| # in the same loop iteration that checkpoints the row group). | ||
| if not any(rg_id == task.row_group for rg_id, _ in self._active_rgs): | ||
| if task.row_group not in self._rg_states: | ||
| skipped = True | ||
| return | ||
|
|
||
|
|
@@ -330,6 +447,7 @@ async def _execute_task_inner(self, task: Task) -> None: | |
|
|
||
| except Exception as exc: | ||
| self._error_count += 1 | ||
| self._check_error_rate() | ||
| if self._trace and trace: | ||
| trace.status = "error" | ||
| trace.error = str(exc) | ||
|
|
@@ -360,6 +478,8 @@ async def _execute_task_inner(self, task: Task) -> None: | |
| self.traces.append(trace) | ||
|
|
||
| self._in_flight.discard(task) | ||
| if (s := self._rg_states.get(task.row_group)) is not None: | ||
| s.in_flight_count = max(0, s.in_flight_count - 1) | ||
| if not retryable and not skipped: | ||
| self._dispatched.discard(task) | ||
| self._submission_semaphore.release() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
per-task here may need a little more context?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstring already reads "collect per-task tracing data when using the async engine" - lmk if you'd like more detail