Skip to content

Model communication-computation overlap in the sharding ILP#353

Draft
fmassa wants to merge 3 commits intomainfrom
fmassa/comms_compute_overlap_model
Draft

Model communication-computation overlap in the sharding ILP#353
fmassa wants to merge 3 commits intomainfrom
fmassa/comms_compute_overlap_model

Conversation

@fmassa
Copy link
Contributor

@fmassa fmassa commented Mar 7, 2026

The ILP objective currently treats communication and computation as fully sequential (Σ(comm + compute) * x), which is an upper bound on actual runtime. In practice, parameter redistributions (e.g., all-gathers) can overlap with preceding compute (prefetch), and gradient reduce-scatters can overlap with subsequent compute (post-compute overlap).

This PR models overlap within the ILP using continuous "savings" variables. For each overlappable edge, a savings variable is created with:

  • savings <= comm_cost(selected) — can't save more than the communication
  • Σ savings_using_A <= compute_cost(A, selected) — can't save more than the available compute budget

The solver maximizes savings (since they're subtracted from the objective), computing savings = min(comm, compute_budget).

Two scan passes identify overlappable edges:

  • Forward scan: edges from parameter-derived inputs (propagated transitively through dtype_cast, views, etc.) overlap with preceding compute
  • Reverse scan: edges into terminal-derived nodes (all paths lead to output) overlap with subsequent compute

A shared compute budget constraint across both scans prevents double-counting.

The feature is off by default (enable_prefetch_overlap=False) and can be enabled via AutoParallel or auto_parallel().

Authored with Claude.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 7, 2026
fmassa added 2 commits March 8, 2026 12:18
Previously each savings variable was added directly to every compute partner's budget constraint, which meant savings <= min(compute_cost(A), compute_cost(B)) — bounding it
by the smallest node in the group. This is too conservative: a 12-unit comm overlapping with two compute nodes (5 and 10) should allow up to 12 units of savings, not 5.

Fix by splitting each savings into per-node contribution variables (savings = contrib_A + contrib_B), where each contribution is non-negative and participates in its node's
budget constraint. The solver can now allocate e.g. 5 from A and 7 from B to fully hide 12 units of comm.

Authored with Claude.
… savings logging

Three improvements to the prefetch overlap model:

The forward scan now creates savings variables for all param-derived input edges (not just boundary edges), since the ILP may place the all-gather anywhere in the param
chain (e.g. param → dtype_cast). The compute group is only reset at boundary edges (non-param-derived consumer), so intermediate param-derived edges share the same compute
window and don't fragment the budget. The reverse scan applies symmetric boundary logic.

The violated-constraints logger now uses a 1e-6 tolerance, fixing false positives from floating-point residuals in the continuous savings/contribution equality constraints.

The cost summary now reports overlap_savings and effective_cost when prefetch overlap is enabled.

Authored with Claude.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant