
feat: reduce DP rank imbalance from variable sequence lengths #13

@cicirori

Description


Problem

_partition_results in training_controller.py (lines 409-414) uses simple round-robin assignment:

for i, result in enumerate(results):
    partitions[i % self.dp_size].append(result)

When micro_batch_size > 1, samples within a batch may have significantly different sequence lengths. Since padding aligns to the longest sequence per rank, some DP ranks end up with disproportionately more compute, causing faster ranks to idle at all-reduce barriers.

Note: This is only relevant when micro_batch_size > 1. With mbs=1, each rank gets exactly one sample and there's nothing to balance.
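To make the effect concrete, here is a toy sketch of how round-robin can concentrate long sequences on one rank (the lengths, dp_size=2, and mbs=2 below are made up for illustration):

```python
# Per-rank compute scales with num_samples * max(seq_len on that rank),
# because padding aligns every sample to the rank's longest sequence.

seq_lens = [1000, 100, 900, 200]  # hypothetical sample lengths
dp_size = 2

# Round-robin: rank 0 gets [1000, 900], rank 1 gets [100, 200]
partitions = [[] for _ in range(dp_size)]
for i, length in enumerate(seq_lens):
    partitions[i % dp_size].append(length)

# Padded tokens per rank = num_samples * longest sample on that rank
padded = [len(p) * max(p) for p in partitions]
print(padded)  # rank 0 processes 2000 padded tokens, rank 1 only 400
```

A balanced split (e.g., [1000, 200] vs [900, 100]) would bring the ranks to 2000 vs 1800 padded tokens instead of 2000 vs 400.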

Possible Approaches

A. Greedy bin-packing at dispatch time

Replace round-robin in _partition_results with greedy assignment (longest-first, assign to lightest rank):

def _partition_results(self, results):
    partitions = [[] for _ in range(self.dp_size)]
    loads = [0] * self.dp_size  # total sequence length assigned to each rank
    # Longest-first greedy: place each sample on the currently lightest rank.
    for result in sorted(results, key=lambda r: r.seq_length, reverse=True):
        min_rank = loads.index(min(loads))  # O(dp_size) scan; fine at this scale
        partitions[min_rank].append(result)
        loads[min_rank] += result.seq_length
    return partitions

Pros: Minimal change, works with existing pipeline.
Cons: Limited by the samples available in a single dispatch batch.

B. Sample buffer with larger balancing window

Instead of dispatching immediately once pool_size >= dispatch_batch_size, maintain a larger sample buffer (e.g., 2x to 4x dispatch_batch_size) and pick the best-balanced subset to dispatch. This gives the bin-packer more candidates to work with.

Pros: Better balance from a larger candidate pool.
Cons: Slightly higher dispatch latency; adds buffering complexity.

C. Pre-sort during data preprocessing

Sort or bucket samples by sequence length during dataset loading / tokenization, so that consecutive samples in the prompt buffer already have similar lengths. Dispatch then naturally produces balanced batches.

Pros: Zero dispatch overhead; works regardless of mbs.
Cons: Reduces data randomness (may need to shuffle within length buckets); requires changes to the data pipeline.
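A minimal sketch of length-bucketed shuffling (the bucket_shuffle helper and Sample stand-in are hypothetical): sort by length, shuffle within fixed-size buckets, then shuffle bucket order, so nearby samples keep similar lengths while ordering stays randomized.

```python
import random
from collections import namedtuple

# Hypothetical stand-in for a dataset record; only seq_length matters here.
Sample = namedtuple("Sample", ["id", "seq_length"])

def bucket_shuffle(samples, bucket_size, seed=0):
    """Order samples so consecutive ones have similar lengths, with
    randomness preserved inside each length bucket and across buckets."""
    rng = random.Random(seed)
    ordered = sorted(samples, key=lambda s: s.seq_length)
    buckets = [ordered[i:i + bucket_size]
               for i in range(0, len(ordered), bucket_size)]
    for b in buckets:
        rng.shuffle(b)       # randomness within each length bucket
    rng.shuffle(buckets)     # randomize which bucket comes first
    return [s for b in buckets for s in b]
```

Setting bucket_size to a multiple of dispatch_batch_size would keep each dispatched batch inside a single length bucket.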

Considerations

  • Only implement when micro_batch_size > 1 — with mbs=1 this is a no-op
  • Measure DP rank training time variance before and after to validate impact
  • Approach A is the simplest starting point; B and C are worth considering if A is insufficient
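For the before/after measurement, one simple metric is the straggler-to-mean ratio of per-rank step times (the numbers below are made up; real times would come from logs or an all_gather of timings):

```python
# Hypothetical per-rank training step times (seconds) for one step.
step_times = [12.4, 7.1, 11.8, 7.3]

# Imbalance ratio: how much slower the straggler rank is than the mean.
# 1.0 means perfectly balanced; the gap above 1.0 is idle time at the
# all-reduce barrier.
imbalance = max(step_times) / (sum(step_times) / len(step_times))
print(f"imbalance ratio: {imbalance:.2f}")
```

Tracking this ratio per step before and after the change gives a direct read on whether the rebalancing helps.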

Files

  • torchspec/controller/training_controller.py
  • torchspec/utils/types.py (InferenceOutput may need a seq_length field)
  • torchspec/data/dataset.py (for approach C)
