feat: Add distributed version of ChunkSampler #150
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
##             main     #150      +/-   ##
==========================================
- Coverage   93.71%   91.86%   -1.85%
==========================================
  Files          11       11
  Lines         811      848      +37
==========================================
+ Hits          760      779      +19
- Misses         51       69      +18
This should be made generalizable. We can make an IsDistributable Protocol that the samplers inherit. That way, anything that implements a _mask attribute can be handled this way.
To do this, we may need to consider making _mask a bit more general, like a list[slice], for things like categorical samplers.
Maybe I'm misunderstanding though; it seems like this is just masking.
Also, no reason to have a class per package - if everything is just based on rank/world, we can make this generic (i.e., users can pass in rank/world if torch or one of the recognized packages is not installed)
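One way the proposed IsDistributable Protocol and a generic rank/world shard could look - a sketch only, with hypothetical names, assuming the mask is a NumPy boolean array:

```python
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class IsDistributable(Protocol):
    """Anything exposing a boolean ``_mask`` can be sharded by rank/world."""

    _mask: np.ndarray


def shard_mask(mask: np.ndarray, rank: int, world_size: int) -> np.ndarray:
    """Keep only the masked indices assigned to this rank (strided split)."""
    out = np.zeros_like(mask)
    kept = np.flatnonzero(mask)
    out[kept[rank::world_size]] = True
    return out
```

Because the sharding only needs (rank, world_size), it works identically whether those numbers come from torch, jax, or a user-supplied pair.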
If @giovp approves, so do I (same for @selmanozleyen )
tests/test_sampler.py
Outdated
def test_enforce_equal_batches_rounds_down_per_rank(self, make_distributed_sampler):
    """enforce_equal_batches=True rounds per_rank down to a multiple of batch_size."""
    n_obs, world_size = 107, 3
    chunk_size, preload_nchunks, batch_size = 10, 1, 10
    # raw per_rank = 107 // 3 = 35, rounded = 35 // 10 * 10 = 30
    sampler = make_distributed_sampler(
        rank=0,
        world_size=world_size,
        chunk_size=chunk_size,
        preload_nchunks=preload_nchunks,
        batch_size=batch_size,
        enforce_equal_batches=True,
    )
    indices = collect_indices(sampler, n_obs)
    assert len(set(indices)) == 30
def test_enforce_equal_batches_false_uses_raw_split(self, make_distributed_sampler):
    """enforce_equal_batches=False uses n_obs // world_size without rounding."""
    n_obs, world_size = 107, 3
    chunk_size, preload_nchunks, batch_size = 10, 1, 10
    # raw per_rank = 107 // 3 = 35 (not rounded to 30)
    sampler = make_distributed_sampler(
        rank=0,
        world_size=world_size,
        chunk_size=chunk_size,
        preload_nchunks=preload_nchunks,
        batch_size=batch_size,
        enforce_equal_batches=False,
    )
    indices = collect_indices(sampler, n_obs)
    assert len(set(indices)) == 35
Make this one test, parametrized by enforce_equal_batches
def _get_dist_info_torch() -> tuple[int, int]:
    """Get rank and world_size from ``torch.distributed``."""
I think we really don't need a backend="jax" kind of interface here. We can take a callable as an input, like:

    class DistInfo(TypedDict):
        process_index: int
        process_count: int

    class ChunkedDistSampler:
        def __init__(self, dist_info_fn: Callable[[], DistInfo]):
            pass

Why a callable? So that it is called whenever the distributed system is spawned/launched.
In an example tutorial we can show that anyone can inject anything that provides these. Or we can also implement convenience dist_info_fn's. Anyone who's knowledgeable enough to need this can use it, so we can have a more mature interface. No need to add training wheels that will age badly here IMO.
Even though it's a very strong opinion of mine it's still just a preference.
Fair point, I will add the option to provide a callable as well. The callable then just returns a tuple (rank, world_size), so people can easily extend this to other frameworks.
I would keep the torch + jax options though. Most people who do distributed training are not really aware of those concepts. E.g. Lightning abstracts them away. So it's easy for them to use it; they just need to know which framework they're using.
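The two approaches compose: the built-in backends can simply be pre-made callables. A sketch, assuming the `(rank, world_size)` tuple convention agreed above (function names hypothetical; the torch and jax calls are the real framework APIs):

```python
def torch_dist_info() -> tuple[int, int]:
    """Return (rank, world_size) from torch.distributed, if initialized."""
    import torch.distributed as dist

    if not dist.is_initialized():
        raise RuntimeError("torch.distributed is not initialized")
    return dist.get_rank(), dist.get_world_size()


def jax_dist_info() -> tuple[int, int]:
    """Return (rank, world_size) from jax's multi-process API."""
    import jax

    return jax.process_index(), jax.process_count()


def single_process_info() -> tuple[int, int]:
    """Fallback a user of any other framework could inject themselves."""
    return 0, 1
```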
def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
    self._mask = self._shard_mask(n_obs)
    yield from super()._sample(n_obs)
This is a bit hacky and requires the child to know too much about the inner logic of its parent's implementation details.
From @ilan-gold
This should be made generalizable. We can make an IsDistributable Protocol that the samplers inherit. That way, anything that implements a _mask attribute can be handled this way.
To do this, we may need to consider making _mask a bit more general, like a list[slice], for things like categorical samplers.
Maybe I'm misunderstanding though; it seems like this is just masking.
Also, no reason to have a class per package - if everything is just based on rank/world, we can make this generic (i.e., users can pass in rank/world if torch or one of the recognized packages is not installed)
I mostly agree. That's why I wanted to name ChunkSampler a window/mask/slice sampler, to be technically correct.
For categorical samplers we shouldn't need list[slice], because a categorical sampler should be able to use multiple ChunkSamplers, right?
I think list[slice] isn't trivial for ChunkSampler itself, because _compute_chunks relies on assumptions of continuity, for example in the case of the last chunk. If we want to generalize, maybe we can have a base class like MaskListSampler which enforces drop_last=True (even though drop_last isn't a problem for dist training, because each model there is different, we would need this if we want to generalize).
The easiest change could be having an effective_mask attribute, which is the identity self._mask by default and is passed to _sampler(n_obs, mask), and which can be overridden in IsDistributables.
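The effective_mask idea could be sketched roughly like this - hypothetical minimal classes, not the PR's actual implementation, assuming a NumPy boolean mask:

```python
import numpy as np


class ChunkSampler:
    def __init__(self, n_obs: int):
        self._mask = np.ones(n_obs, dtype=bool)

    @property
    def effective_mask(self) -> np.ndarray:
        """Identity by default: the base sampler sees its full mask."""
        return self._mask


class ChunkSamplerDistributed(ChunkSampler):
    def __init__(self, n_obs: int, rank: int, world_size: int):
        super().__init__(n_obs)
        self._rank, self._world_size = rank, world_size

    @property
    def effective_mask(self) -> np.ndarray:
        """Narrow the mask to this rank's strided share of the indices."""
        out = np.zeros_like(self._mask)
        kept = np.flatnonzero(self._mask)
        out[kept[self._rank :: self._world_size]] = True
        return out
```

The parent would then sample from `effective_mask` instead of `_mask`, so the child never has to mutate parent state mid-iteration.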
Let's put a pin in this and revisit if we see a need to update! For now, it's not really worth it (see comment from Ilan below as well)
def _make_distributed_sampler_torch(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed:
    """Create a ChunkSamplerDistributed with mocked torch.distributed backend."""
    mock_dist = MagicMock()
    mock_dist.is_initialized.return_value = True
    mock_dist.get_rank.return_value = rank
    mock_dist.get_world_size.return_value = world_size
    mock_torch = MagicMock()
    mock_torch.distributed = mock_dist
    with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}):
This is why a dist_info_fn as input would be helpful: we wouldn't need these hacky ways to test.
See my comment above. I could remove those now that I added the dist_info callable. But then we wouldn't have the tests anymore where we check that errors get raised if the dist backend is not initialized (like I said above, I would keep those for ease of use)
If *True*, round each rank's observation count down to a multiple of
``batch_size`` so that all ranks yield the same number of batches.
Set to *False* to use the raw ``n_obs // world_size`` split, which may
result in an uneven number of batches per worker.
"result in an uneven number of batches per worker"
worker or world? I'd like to get the terminology right
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
rng=rng,
You haven't considered rng handling. If each node gets the same rng passed, it would produce the same batch_indices, which isn't good for training. Here you can split the rngs based on the process_index by using _spawn_worker_rng like this.
Maybe we can even rename this to split_rng and make it return the root rng again, to match https://docs.jax.dev/en/latest/_autosummary/jax.random.split.html - wdyt @ilan-gold ?
Please add an rng test where you can reproduce per each rank, and maybe even per each worker in a rank, using of course the mock dist and workers.
n_batches = sum(len(lr["splits"]) for lr in sampler.sample(n_obs))
batch_counts.append(n_batches)

assert len(set(batch_counts)) == 1, f"Batch counts differ across ranks: {batch_counts}"
Why not also assert this against batch_size, like batch_counts[0] == batch_size?
At this point I don't have the batch_size info anymore; I'm just tracking the number of yielded batches here (that's what the test is supposed to cover). I'd argue what you're suggesting is already covered by the non-distributed tests, so I would focus the test here on the dataset sharding being done right etc.
tests/test_sampler.py
Outdated
    enforce_equal_batches=True,
)
indices = collect_indices(sampler, n_obs)
assert len(set(indices)) == 30
Is this different from len(indices)? From what I know, we shouldn't have overlapping indices, right?
If enforce_equal_batches=True, we might not cover the full dataset anymore (as it is not guaranteed that we can equally shard it across x workers). That's what the test is supposed to cover.
But is it normal if len(set(indices)) != len(indices)? If so it's fine, but it seems like len(indices) == 30 should also hold. Never mind if we can have overlapping indices here.
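The arithmetic under discussion can be written out explicitly - a sketch with a hypothetical helper name, reproducing the numbers from the test comments:

```python
def per_rank_obs(n_obs: int, world_size: int, batch_size: int,
                 enforce_equal_batches: bool) -> int:
    """Observations yielded per rank under the two modes."""
    per_rank = n_obs // world_size  # raw even split, remainder dropped
    if enforce_equal_batches:
        # round down to a whole number of batches so all ranks agree
        per_rank = per_rank // batch_size * batch_size
    return per_rank


# 107 obs on 3 ranks with batch_size 10:
# raw split keeps 35 per rank; enforcing equal batches keeps 30
```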
I am not sure what a "training wheel" is, but I do think anyone should be able to provide their concept of rank/world (via a function, like @selmanozleyen is saying) and get back a loader that does the thing. I don't think we need a tutorial for this though - it's pretty niche, since we will provide out-of-the-box Jax and torch versions.
Why would this matter? Each node would have a different subset of the pre-shuffled dataset. Like, concretely, what could go wrong here?
Yeah, so I don't feel strongly. As a V1 this seems fine to me; we can always extend/grow the requirements, but I don't have enough of a sense of where to go with it. In short, that's why I gave up on the point, but I'm happy to revisit (and we will definitely need to).
Maybe "bad for training" is too strong, since it is hard to prove with concrete examples, but in theory, if you write down the training, batches are sampled under i.i.d. assumptions. Having the same batch indices would hurt those assumptions more.
We don't need to share categories per node unless we have a concept of epoch (sampling without replacement). In fact, we can't share categories per node during the whole training if
It sounds like there would not be much overlap with the ChunkSampler API then, so my original comment holds @felix0097, so I think leaving the usage of RNG
But if the data is pre-shuffled, we wouldn't care about having the same batch indices, no?
But the preshuffling could, in theory, be learned by the model after a certain number of epochs. When the chunks divide the dataset evenly, the chunk-id-to-slice mapping is always the same; the order of chunk ids is the only thing that's randomized at that stage. It is low probability, but I'm trying to show a counterexample exists in this case (they always exist, but we would have one less of them): if we have a degenerate rng in one part of the data that doesn't do a good job of shuffling batches and keeps the elements of a chunk in the same batch, the model might pick up on that.

You don't even need a degenerate rng: in theory, you can reverse engineer it by noticing that a chunk of elements is always seen within at most preload_nchunk*chunks/batchsize adjacent batches. Then you can learn the chunks. So preshuffling in theory doesn't change the i.i.d. assumptions at all; it just makes the signal given to the model weaker. In fact, I don't think it's a stretch at all that the model might pick up in each gradient step that the given batch indices are the same, because it will also have signal about the chunk locality already. I think we should handle rng a bit more carefully at least. Here is an interesting paper on how the seed can affect training performance: https://arxiv.org/abs/2109.08203
Not sure I fully understand your statement @selmanozleyen. In general, we always load several chunks at the same time (e.g. 128) and shuffle them in memory. So we do break up the on-disk chunk structure, such that the model can't simply "remember" the on-disk chunks.
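The preload-and-shuffle behaviour being described could be sketched as follows - a simplified toy model of the sampler (function and parameters are illustrative, not the actual implementation), showing how shuffling across a group of preloaded chunks breaks up the on-disk order:

```python
import numpy as np


def preload_and_shuffle(n_obs, chunk_size, preload_nchunks, batch_size, rng):
    """Shuffle chunk order, then shuffle observations across each preloaded group."""
    n_chunks = n_obs // chunk_size
    chunk_ids = rng.permutation(n_chunks)  # randomize which chunks go together
    for start in range(0, n_chunks, preload_nchunks):
        group = chunk_ids[start : start + preload_nchunks]
        # gather the row indices of all chunks in the preloaded group
        idx = np.concatenate(
            [np.arange(c * chunk_size, (c + 1) * chunk_size) for c in group]
        )
        rng.shuffle(idx)  # in-memory shuffle breaks up on-disk chunk structure
        for b in range(0, len(idx), batch_size):
            yield idx[b : b + batch_size]
```

With preload_nchunks well above 1 (e.g. 128 as mentioned above), each batch mixes rows from many on-disk chunks, which is the argument for why chunk locality is hard to exploit.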
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
rng=_spawn_worker_rng(rng, self._rank) if rng else None,
I also added the rng handling like you suggested. Maybe this already addresses your concerns? @selmanozleyen
Yeah that fixes having same batch indices across nodes
ilan-gold
left a comment
Ok, last question - I know this would be a big re-architecture (eh, not that big), but something we accidentally get as a result of ChunkSampler.mask being an arg is training-test-validation splits with shuffled datasets.
How would one do that here? I think this is a point in favor of baking this into ChunkSampler.
Very good point. But can't we have an IntervalListSampler that does this and uses ChunkSampler internally? This way we won't have to check that everything supported by ChunkSampler is also supported when mask is a list of masks. Then the categorical sampler would use that. Also, we can just expect train, test, val to be categories in obs; I think this is also good practice, if we can have a categorical sampler and specify which categories we want it to sample from. If we are to do a training pipeline guide, for example, we can suggest this to be more common. For example, in perturbation training you can split by cell_line or do something more complicated, but in the end I think it is always worth labeling each cell with the split it belongs to. Plus:
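Deriving masks from a split label in obs, as suggested, could look like this - a small illustrative sketch, where the helper name and label column are hypothetical:

```python
import numpy as np


def split_masks(split_labels: np.ndarray) -> dict[str, np.ndarray]:
    """One boolean mask per split category stored in obs."""
    return {name: split_labels == name for name in np.unique(split_labels)}


# e.g. a per-cell "split" column from obs
labels = np.array(["train", "train", "val", "test", "train"])
masks = split_masks(labels)
# each boolean mask could then drive its own masked/chunked sampler
```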
This PR adds an extension of the ChunkSampler class that shards the dataset for training using torch.distributed.