Problem
`_partition_results` in `training_controller.py` (lines 409-414) uses simple round-robin assignment:

```python
for i, result in enumerate(results):
    partitions[i % self.dp_size].append(result)
```

When `micro_batch_size > 1`, samples within a batch may have significantly different sequence lengths. Since padding aligns to the longest sequence per rank, some DP ranks end up with disproportionately more compute, causing faster ranks to idle at all-reduce barriers.
Note: This is only relevant when `micro_batch_size > 1`. With `mbs=1`, each rank gets exactly one sample and there is nothing to balance.
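A small standalone sketch of the failure mode, using hypothetical sequence lengths (not taken from the repo): with round-robin assignment, each rank pads to its own longest sequence, so a rank that happens to receive the long samples does far more compute.

```python
# Hypothetical illustration of round-robin imbalance with per-rank padding.
dp_size = 2
seq_lengths = [100, 900, 120, 880]  # made-up sample lengths, mbs=2 per rank

# Round-robin assignment, as in the current _partition_results.
ranks = [[] for _ in range(dp_size)]
for i, length in enumerate(seq_lengths):
    ranks[i % dp_size].append(length)

# Padding aligns to the longest sequence on each rank, so effective
# compute per rank ~ num_samples * max_seq_len_on_that_rank.
padded_loads = [len(r) * max(r) for r in ranks]
print(padded_loads)  # → [240, 1800]
```

Here rank 0 holds the two short samples and rank 1 the two long ones, so rank 0 idles at the all-reduce barrier while rank 1 does roughly 7x the padded work.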
Possible Approaches
A. Greedy bin-packing at dispatch time
Replace round-robin in _partition_results with greedy assignment (longest-first, assign to lightest rank):
```python
def _partition_results(self, results):
    partitions = [[] for _ in range(self.dp_size)]
    loads = [0] * self.dp_size
    # Longest-first: assign each result to the currently lightest rank.
    for result in sorted(results, key=lambda r: r.seq_length, reverse=True):
        min_rank = loads.index(min(loads))
        partitions[min_rank].append(result)
        loads[min_rank] += result.seq_length
    return partitions
```

Pros: Minimal change; works with the existing pipeline.
Cons: Limited by the samples available in a single dispatch batch.
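For comparison with the round-robin example above, here is the greedy assignment run standalone on the same hypothetical lengths (`seq_length` on `InferenceOutput` is an assumption, per the Files note below):

```python
from types import SimpleNamespace

# Stand-ins for InferenceOutput; seq_length is an assumed field.
dp_size = 2
results = [SimpleNamespace(seq_length=n) for n in (100, 900, 120, 880)]

partitions = [[] for _ in range(dp_size)]
loads = [0] * dp_size
# Longest-first greedy: each sample goes to the currently lightest rank.
for r in sorted(results, key=lambda r: r.seq_length, reverse=True):
    min_rank = loads.index(min(loads))
    partitions[min_rank].append(r)
    loads[min_rank] += r.seq_length

print(loads)  # → [1000, 1000]
```

The same inputs that round-robin split 240 vs 1800 come out perfectly balanced here, though in general greedy is only an approximation and remains limited to the samples in one dispatch batch.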
B. Sample buffer with larger balancing window
Instead of dispatching immediately when `pool_size >= dispatch_batch_size`, maintain a larger sample buffer (e.g. 2x-4x `dispatch_batch_size`) and pick the best-balanced subset to dispatch. This gives more room for bin-packing.
Pros: Better balance from a larger candidate pool.
Cons: Slightly higher dispatch latency; adds buffering complexity.
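A minimal sketch of the buffering idea, under assumptions: `pick_dispatch_subset` is a hypothetical helper, and samples expose a known `seq_length`. Sorting the buffer by length and taking the contiguous window with the smallest length spread yields a dispatch batch whose samples pad similarly.

```python
from types import SimpleNamespace

def pick_dispatch_subset(buffer, dispatch_batch_size):
    """Pick the most length-uniform subset of the buffer to dispatch.

    Among all subsets of this size, a contiguous window of the
    length-sorted buffer minimizes (max_len - min_len), i.e. padding waste.
    """
    ordered = sorted(buffer, key=lambda s: s.seq_length)
    best_start, best_spread = 0, float("inf")
    for start in range(len(ordered) - dispatch_batch_size + 1):
        window = ordered[start:start + dispatch_batch_size]
        spread = window[-1].seq_length - window[0].seq_length
        if spread < best_spread:
            best_start, best_spread = start, spread
    chosen = ordered[best_start:best_start + dispatch_batch_size]
    remaining = ordered[:best_start] + ordered[best_start + dispatch_batch_size:]
    return chosen, remaining

# Demo on made-up lengths: the three ~500-token samples dispatch together.
buf = [SimpleNamespace(seq_length=n) for n in (10, 500, 505, 510, 1000, 12)]
chosen, rest = pick_dispatch_subset(buf, 3)
print([s.seq_length for s in chosen])  # → [500, 505, 510]
```

Samples left in `remaining` wait for a later dispatch, which is where the extra latency in the Cons comes from; outlier-length samples can linger until enough peers accumulate.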
C. Pre-sort during data preprocessing
Sort or bucket samples by sequence length during dataset loading / tokenization, so that consecutive samples in the prompt buffer already have similar lengths. Dispatch then naturally produces balanced batches.
Pros: Zero dispatch overhead; works regardless of mbs.
Cons: Reduces data randomness (may need to shuffle within length buckets); requires changes to the data pipeline.
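One common way to recover some randomness, sketched here with a hypothetical `bucket_shuffle` helper (bucket size is a tunable assumption): sort by length, shuffle within fixed-size buckets, then shuffle the bucket order, so neighboring samples still have similar lengths but ordering is not fully deterministic.

```python
import random

def bucket_shuffle(samples, lengths, bucket_size, seed=0):
    """Length-bucketed shuffle: local randomness, globally length-sorted."""
    rng = random.Random(seed)
    # Sort indices by sequence length, then cut into fixed-size buckets.
    order = sorted(range(len(samples)), key=lambda i: lengths[i])
    buckets = [order[i:i + bucket_size] for i in range(0, len(order), bucket_size)]
    for b in buckets:
        rng.shuffle(b)      # randomness within similar-length groups
    rng.shuffle(buckets)    # randomize which length range comes first
    return [samples[i] for b in buckets for i in b]

# Demo: items emerge grouped by length, in shuffled bucket order.
out = bucket_shuffle(list("abcdef"), [6, 1, 5, 2, 4, 3], bucket_size=2)
print(out)
```

With this in the data pipeline, dispatch-time balancing (A or B) becomes largely unnecessary, at the cost of correlated batches within each bucket.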
Considerations
- Only implement when `micro_batch_size > 1`; with `mbs=1` this is a no-op
- Measure DP rank training time variance before and after to validate impact
- Approach A is the simplest starting point; B and C are worth considering if A is insufficient
Files
- `torchspec/controller/training_controller.py`
- `torchspec/utils/types.py` (InferenceOutput may need seq_length field)
- `torchspec/data/dataset.py` (for approach C)