
feat: Add distributed version of ChunkSampler #150

Open
felix0097 wants to merge 5 commits into main from ff/torch-dist-sampler

Conversation

@felix0097
Collaborator

This PR adds an extension of the ChunkSampler class that shards the dataset for training using torch.distributed.
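The sharding idea can be sketched roughly like this (a hypothetical `shard_mask` helper using a contiguous per-rank split; the actual `ChunkSamplerDistributed` logic may differ):

```python
# Hypothetical sketch of per-rank dataset sharding. The names rank and
# world_size mirror torch.distributed concepts; shard_mask itself is
# illustrative, not the PR's implementation.
import numpy as np

def shard_mask(n_obs: int, rank: int, world_size: int) -> np.ndarray:
    """Boolean mask selecting this rank's contiguous share of the dataset."""
    per_rank = n_obs // world_size          # floor split; remainder is dropped
    mask = np.zeros(n_obs, dtype=bool)
    start = rank * per_rank
    mask[start : start + per_rank] = True
    return mask

# 107 observations over 3 ranks -> 35 per rank, 2 observations left over
m = shard_mask(107, rank=0, world_size=3)
```

Each rank's mask is disjoint from the others, so no observation is seen twice within one pass.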

@felix0097 felix0097 self-assigned this Feb 13, 2026
@felix0097 felix0097 added enhancement New feature or request skip-gpu-ci Whether gpu ci should be skipped labels Feb 13, 2026
@codecov

codecov bot commented Feb 13, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 91.86%. Comparing base (c80f1ee) to head (570f30e).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #150      +/-   ##
==========================================
- Coverage   93.71%   91.86%   -1.85%     
==========================================
  Files          11       11              
  Lines         811      848      +37     
==========================================
+ Hits          760      779      +19     
- Misses         51       69      +18     
Files with missing lines                | Coverage Δ
src/annbatch/__init__.py                | 100.00% <100.00%> (ø)
src/annbatch/samplers/__init__.py       | 100.00% <100.00%> (ø)
src/annbatch/samplers/_chunk_sampler.py | 95.20% <100.00%> (+2.01%) ⬆️

... and 3 files with indirect coverage changes


Collaborator

@ilan-gold ilan-gold left a comment

This should be made generalizable. We can make an IsDistributable Protocol that the samplers inherit. That way, anything that implements a _mask attribute can be handled this way.

To do this, we may need to consider making _mask a bit more general, like a list[slice] for things like categorical samplers

Maybe I'm misunderstanding though, it seems like this is just masking.

Also no reason to have a class-per-package - if everything is just based on rank/world, we can make this generic (i.e., users can pass in rank/world if torch or one of the recognized packages is not installed)
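A minimal sketch of that Protocol idea (all names here, including `IsDistributable` and `shard_in_place`, are illustrative, not the actual annbatch API):

```python
# Sketch: any sampler exposing a boolean _mask attribute can be sharded
# generically by rank/world_size. Names are illustrative only.
from typing import Protocol

import numpy as np

class IsDistributable(Protocol):
    _mask: np.ndarray  # boolean mask over all observations

def shard_in_place(sampler: IsDistributable, rank: int, world_size: int) -> None:
    """Keep every world_size-th currently selected observation for this rank."""
    selected = np.flatnonzero(sampler._mask)
    new_mask = np.zeros_like(sampler._mask)
    new_mask[selected[rank::world_size]] = True
    sampler._mask = new_mask

class ToySampler:
    """Stand-in for any sampler that satisfies the protocol."""
    def __init__(self, n_obs: int):
        self._mask = np.ones(n_obs, dtype=bool)

s = ToySampler(10)
shard_in_place(s, rank=1, world_size=3)
# rank 1 of 3 keeps indices 1, 4, 7
```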

Collaborator

@ilan-gold ilan-gold left a comment

If @giovp approves, so do I (same for @selmanozleyen )

Comment on lines +492 to +523
def test_enforce_equal_batches_rounds_down_per_rank(self, make_distributed_sampler):
    """enforce_equal_batches=True rounds per_rank down to a multiple of batch_size."""
    n_obs, world_size = 107, 3
    chunk_size, preload_nchunks, batch_size = 10, 1, 10
    # raw per_rank = 107 // 3 = 35, rounded = 35 // 10 * 10 = 30
    sampler = make_distributed_sampler(
        rank=0,
        world_size=world_size,
        chunk_size=chunk_size,
        preload_nchunks=preload_nchunks,
        batch_size=batch_size,
        enforce_equal_batches=True,
    )
    indices = collect_indices(sampler, n_obs)
    assert len(set(indices)) == 30

def test_enforce_equal_batches_false_uses_raw_split(self, make_distributed_sampler):
    """enforce_equal_batches=False uses n_obs // world_size without rounding."""
    n_obs, world_size = 107, 3
    chunk_size, preload_nchunks, batch_size = 10, 1, 10
    # raw per_rank = 107 // 3 = 35 (not rounded to 30)
    sampler = make_distributed_sampler(
        rank=0,
        world_size=world_size,
        chunk_size=chunk_size,
        preload_nchunks=preload_nchunks,
        batch_size=batch_size,
        enforce_equal_batches=False,
    )
    indices = collect_indices(sampler, n_obs)
    assert len(set(indices)) == 35

Collaborator

Make this one test, parametrized by enforce_equal_batches



def _get_dist_info_torch() -> tuple[int, int]:
    """Get rank and world_size from ``torch.distributed``."""
Member

@selmanozleyen selmanozleyen Feb 13, 2026

I think we really don't need a backend="jax" kind of interface here. We can take a callable as input, like:

class DistInfo(TypedDict):
    process_index: int
    process_count: int

class ChunkedDistSampler:
    def __init__(self, dist_info_fn: Callable[[], DistInfo]):
        ...

Why a callable? So that it is called whenever the distributed system is spawned/launched.

In an example tutorial we can show that anyone can inject anything that provides these. Or we can also implement convenience dist_info_fn's. Anyone who's knowledgeable enough to need this can use it, so we can have a more mature interface. No need to add training wheels that will age badly here IMO.

Even though it's a very strong opinion of mine, it's still just a preference.
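A rough sketch of what such convenience dist_info_fn's might look like (`DistInfo` follows the snippet above; `shard_bounds` and the function names are hypothetical, and `torch_dist_info` assumes torch.distributed was initialized by the launcher):

```python
# Illustrative convenience functions for the dist_info_fn idea.
# These names are not a shipped annbatch API.
from typing import Callable, TypedDict

class DistInfo(TypedDict):
    process_index: int
    process_count: int

def single_process_info() -> DistInfo:
    """Fallback when no distributed backend is running."""
    return {"process_index": 0, "process_count": 1}

def torch_dist_info() -> DistInfo:
    # Deferred import: only needed if torch is actually used.
    import torch.distributed as dist
    return {"process_index": dist.get_rank(), "process_count": dist.get_world_size()}

def shard_bounds(n_obs: int, dist_info_fn: Callable[[], DistInfo]) -> tuple[int, int]:
    """Called at iteration time, i.e. after the distributed system is launched."""
    info = dist_info_fn()
    per_rank = n_obs // info["process_count"]
    start = info["process_index"] * per_rank
    return start, start + per_rank
```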

Collaborator Author

@felix0097 felix0097 Feb 24, 2026

Fair point, I will add the option to provide a callable as well. The callable would just return a tuple (rank, world_size), so people can easily extend this to other frameworks.

I would keep the torch + jax options though. Most people who do distributed training are not really aware of those concepts; e.g. Lightning abstracts them away. This way it's easy for them to use: they just need to know which framework they're using.

Comment on lines +309 to +311
def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
    self._mask = self._shard_mask(n_obs)
    yield from super()._sample(n_obs)
Member

This is a bit hacky and requires the child to know too much about the inner logic of its parent's functions.

From @ilan-gold

This should be made generalizable. We can make an IsDistributable Protocol that the samplers inherit. That way, anything that implements a _mask attribute can be handled this way.

To do this, we may need to consider making _mask a bit more general, like a list[slice] for things like categorical samplers

Maybe I'm misunderstanding though, it seems like this is just masking.

Also no reason to have a class-per-package - if everything is just based on rank/world, we can make this generic (i.e., users can pass in rank/world if torch or one of the recognized packages is not installed)

I mostly agree. That's why I wanted to name ChunkSampler a window/mask/slice sampler, to be technically correct.

For categorical samplers we shouldn't need list[slice], because a categorical sampler should be able to use multiple ChunkSamplers, right?

I think list[slice] isn't trivial for ChunkSampler itself because _compute_chunks relies on continuity assumptions, for example for the last chunk. If we want to generalize, maybe we can have a base class like MaskListSampler which enforces drop_last=True (even though drop_last isn't a problem for dist training because each model there is different, we would need this if we want to generalize).

The easiest change could be having an effective_mask attribute, which is the identity self._mask by default, gets passed to _sample(n_obs, mask), and can be overridden in IsDistributables.

Collaborator Author

Let's put a pin in this and revisit if we see a need to update! For now, it's not really worth it (see the comment from Ilan below as well).

Comment on lines +380 to +388
def _make_distributed_sampler_torch(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed:
    """Create a ChunkSamplerDistributed with mocked torch.distributed backend."""
    mock_dist = MagicMock()
    mock_dist.is_initialized.return_value = True
    mock_dist.get_rank.return_value = rank
    mock_dist.get_world_size.return_value = world_size
    mock_torch = MagicMock()
    mock_torch.distributed = mock_dist
    with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}):
Member

This is why a dist_info_fn as input would be helpful: we wouldn't need these hacky ways to test.

Collaborator Author

See my comment above. I could remove those now that I added the dist_info callable. But then we wouldn't have the tests anymore where we check that errors get raised if the dist backend is not initialized (like I said above, I would keep those for ease of use).

If *True*, round each rank's observation count down to a multiple of
``batch_size`` so that all ranks yield the same number of batches.
Set to *False* to use the raw ``n_obs // world_size`` split, which may
result in an uneven number of batches per worker.
Member

"result in an uneven number of batches per worker"
worker or world? I'd like to get the terminology right.

Collaborator Author

thx, fixed it

batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
rng=rng,
Member

@selmanozleyen selmanozleyen Feb 13, 2026

You haven't considered rng handling. Each node would be passed the same rng, which would result in the same batch_indices, which isn't good for training. Here you can split the rngs based on the process_index by using _spawn_worker_rng, like this:

batch_rng = _spawn_worker_rng(self._rng, worker_info.id) if worker_info else self._rng

maybe we can even rename this to split_rng and make it return the root rng again to match https://docs.jax.dev/en/latest/_autosummary/jax.random.split.html wdyt @ilan-gold ?

Please add an rng test that is reproducible per rank, and maybe even per worker within a rank, using of course the mocked dist and workers.
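Since `_spawn_worker_rng` is project-internal, here is a hedged stand-in showing one way to derive independent per-rank generators, using numpy's SeedSequence (whose spawn semantics resemble jax.random.split):

```python
# Stand-in for a per-rank rng split; _spawn_worker_rng itself may differ.
import numpy as np

def split_rng(seed: int, world_size: int) -> list[np.random.Generator]:
    """Derive one statistically independent generator per rank from a root seed."""
    children = np.random.SeedSequence(seed).spawn(world_size)
    return [np.random.default_rng(s) for s in children]

rngs = split_rng(0, 3)
# Each rank now shuffles with its own stream, so batch indices differ per rank,
# while the whole scheme stays reproducible from the single root seed.
perms = [rng.permutation(8).tolist() for rng in rngs]
```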

n_batches = sum(len(lr["splits"]) for lr in sampler.sample(n_obs))
batch_counts.append(n_batches)

assert len(set(batch_counts)) == 1, f"Batch counts differ across ranks: {batch_counts}"
Member

why not also assert this against batch_size, like batch_counts[0] == batch_size?

Collaborator Author

At this point I don't have the batch_size info anymore; I'm just tracking the number of yielded batches here (that's what the test is supposed to cover). I'd argue what you're suggesting is already covered by the non-distributed tests, so I would focus this test on checking that the dataset sharding is done right.

enforce_equal_batches=True,
)
indices = collect_indices(sampler, n_obs)
assert len(set(indices)) == 30
Member

Is this different from len(indices)? From what I know we shouldn't have overlapping indices, right?

Collaborator Author

If enforce_equal_batches=True we might not cover the full dataset anymore (as it is not guaranteed that we can shard it equally across x workers). That's what the test is supposed to cover.

Member

But is it normal if len(set(indices)) != len(indices)? If so it's fine, but it seems like len(indices) == 30 should hold as well. Never mind if we can have overlapping indices here.

@ilan-gold
Collaborator

In an example tutorial we can show that anyone can inject anything that provides these. Or we can also implement convenience dist_info_fn's. Anyone who's knowledgeable enough to need this can use it, so we can have a more mature interface. No need to add training wheels that will age badly here IMO.

I am not sure what a "training wheel" is, but I do think anyone should be able to provide their concept of rank/world (via a function, like @selmanozleyen is saying) and get back a loader that does the thing. I don't think we need a tutorial for this though; it's pretty niche, since we will provide out-of-the-box Jax and torch versions.

You haven't considered rng handling. Each node would be passed the same rng, which would result in the same batch_indices, which isn't good for training.

Why would this matter? Each node would have a different subset of the pre-shuffled dataset. Like, concretely, what could go wrong here?

This is a bit hacky and requires the child to know too much about the inner logic of its parent's functions.

Yeah, so I don't feel strongly. As a V1 this seems fine to me, we can always extend/grow the requirements but I don't have enough of a sense of where to go with a mask API for now. @selmanozleyen how would distributed training over a collection that is shuffled-by-category look? Probably masking out subsets of each category? But also giving categories per node? This probably depends on how the weights are merged based on my limited understanding.

In short, that's why I gave up on the point, but I'm happy to revisit (and we will need to definitely).

@selmanozleyen
Member

Why would this matter? Each node would have a different subset of the pre-shuffled dataset. Like, concretely, what could go wrong here?

Maybe "bad for training" is too strong, since it is hard to prove with concrete examples, but in theory, if you write down the training, batches are sampled under i.i.d. assumptions. Having the same batch indices would hurt those assumptions more.

@selmanozleyen
Member

@selmanozleyen how would distributed training over a collection that is shuffled-by-category look? Probably masking out subsets of each category? But also giving categories per node? This probably depends on how the weights are merged based on my limited understanding.

We don't need to share categories per node unless we have a concept of epoch (sampling without replacement). In fact, we can't share categories per node during the whole training if num_categories % num_nodes != 0, and weights would make it more complicated. So I wouldn't support sampling without replacement on distributed training for categoricals.

@ilan-gold ilan-gold changed the title feat: Add torch.distributed version of ChunkSampler feat: Add distributed version of ChunkSampler Feb 24, 2026
@ilan-gold
Collaborator

API

It sounds like there would not be much overlap then, so my original comment

As a V1 this seems fine to me, we can always extend/grow the requirements ...

holds, @felix0097, so I think leaving the usage of _mask is fine for now. If we run across another use case, we can generalize, but it sounds like the requirements for categorical samplers would be different.

RNG

Maybe "bad for training" is too strong, since it is hard to prove with concrete examples, but in theory, if you write down the training, batches are sampled under i.i.d. assumptions. Having the same batch indices would hurt those assumptions more.

But if the data is pre-shuffled, we wouldn't care about having the same batch indices, no?

@selmanozleyen
Member

But the preshuffling in theory could be learned by the model after a certain number of epochs. When the chunks divide the dataset evenly, the chunk-id-to-slice mapping is always the same; the order of the chunk ids is the only thing that's randomized at that stage. It is low probability, but I'm trying to show a counterexample exists in this case (they always exist, but we would have one fewer of them): if we have a degenerate rng in one part of the data that doesn't do a good job of shuffling batches and keeps the elements of a chunk in the same batch, the model might pick up on that. You don't even need a degenerate rng; in theory you can reverse engineer by noticing that a chunk of elements is always seen within at most preload_nchunks * chunk_size / batch_size adjacent batches. Then you can learn the chunks. So preshuffling in theory doesn't change the i.i.d. assumptions at all, it just makes the signal given to the model weaker.

In fact, I don't think it's a stretch at all that the model might pick up that the given batch indices are the same in each gradient step, because it will already have signal about chunk locality.

I think we should handle rng a bit more carefully, at least. Here is an interesting paper on how the seed can affect training performance: https://arxiv.org/abs/2109.08203.

@felix0097
Collaborator Author

Not sure I fully understand your statement @selmanozleyen. In general, we always load several chunks at the same time (e.g. 128) and shuffle them in memory. So we do break up the on-disk chunk structure, and the model can't simply "remember" the on-disk chunks.

batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
rng=_spawn_worker_rng(rng, self._rank) if rng else None,
Collaborator Author

I also added the rng handling like you suggested. Maybe this already addresses your concerns? @selmanozleyen

Member

Yeah, that fixes having the same batch indices across nodes.

Member

@selmanozleyen selmanozleyen left a comment

Lgtm.

Collaborator

@ilan-gold ilan-gold left a comment

Ok, last question - I know this would be a big re-architecture (eh, not that big), but something we accidentally have as a result of ChunkSampler.mask as an arg is training-test-validation splits with shuffled datasets.

How would one do that here? I think this is a point in favor of baking this into ChunkSampler
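A sketch of that observation (the mask semantics and the `ChunkSampler(mask=...)` usage are assumptions here, not the actual API):

```python
# If the sampler accepts a boolean mask over observations, train/validation
# splits on a shuffled dataset fall out for free: one sampler per mask.
import numpy as np

rng = np.random.default_rng(0)
n_obs = 1_000
is_train = rng.random(n_obs) < 0.8   # ~80/20 split, fixed by the seed

train_mask = is_train
val_mask = ~is_train

# Hypothetical usage:
# train_sampler = ChunkSampler(mask=train_mask, ...)
# val_sampler = ChunkSampler(mask=val_mask, ...)

disjoint = not np.any(train_mask & val_mask)   # no observation in both splits
complete = bool(np.all(train_mask | val_mask)) # every observation assigned
```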

@selmanozleyen
Member

selmanozleyen commented Feb 26, 2026

Ok, last question - I know this would be a big re-architecture (eh, not that big), but something we accidentally have as a result of ChunkSampler.mask as an arg is training-test-validation splits with shuffled datasets.

How would one do that here? I think this is a point in favor of baking this into ChunkSampler

Very good point. But can't we have an IntervalListSampler that does this and uses ChunkSampler internally? This way we won't have to check whether everything supported by ChunkSampler is also supported when mask is a list of masks. Then the categorical sampler would use that.

We could also just expect train, test, and val to be categories in obs. I think this is also good practice, if we can have a categorical sampler and specify which categories we want it to sample from, like this: CategoricalSampler(., groupby=['split', 'cell_line', 'drug'], select=(('train',),)).

If we do a training pipeline guide, for example, we can suggest this as the more common approach. For example, in perturbation training you can split by cell_line or do something more complicated, but in the end I think it is always worth labeling each cell with the split it belongs to.

Plus: another good thing is that for splits based on cell_line, for example, you wouldn't need to reorder the dataset for the new split. If you can detect that the split respects the categorical hierarchy (a cell_line is either in train or val, for example; no two cells with the same category can be in different splits), then you wouldn't need reordering on disk.


Labels

enhancement New feature or request skip-gpu-ci Whether gpu ci should be skipped

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants