-
Notifications
You must be signed in to change notification settings - Fork 3
feat: Add distributed version of ChunkSampler
#150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b0ee20d
570f30e
5216dc9
7002aa9
772162c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| from ._chunk_sampler import ChunkSampler | ||
| from ._chunk_sampler import ChunkSampler, ChunkSamplerDistributed | ||
|
|
||
| __all__ = [ | ||
| "ChunkSampler", | ||
| "ChunkSamplerDistributed", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| import math | ||
| from typing import TYPE_CHECKING | ||
| from typing import TYPE_CHECKING, Literal | ||
|
|
||
| import numpy as np | ||
|
|
||
|
|
@@ -12,7 +12,7 @@ | |
| from annbatch.utils import _spawn_worker_rng, check_lt_1, split_given_size | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Iterator | ||
| from collections.abc import Callable, Iterator | ||
|
|
||
| from annbatch.types import LoadRequest | ||
|
|
||
|
|
@@ -184,3 +184,127 @@ def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> l | |
| offsets = np.cumsum(offsets) | ||
| starts, stops = offsets[:-1][chunk_indices], offsets[1:][chunk_indices] | ||
| return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)] | ||
|
|
||
|
|
||
| def _get_dist_info_torch() -> tuple[int, int]: | ||
| """Get rank and world_size from ``torch.distributed``.""" | ||
| import torch.distributed as dist | ||
|
|
||
| if not dist.is_initialized(): | ||
| raise RuntimeError( | ||
| "torch.distributed is not initialized. " | ||
| "Initialize it before creating a ChunkSamplerDistributed with dist_info='torch'." | ||
| ) | ||
| return dist.get_rank(), dist.get_world_size() | ||
|
|
||
|
|
||
| def _get_dist_info_jax() -> tuple[int, int]: | ||
| """Get rank and world_size from JAX multi-process API.""" | ||
| import jax | ||
|
|
||
| if not jax.distributed.is_initialized(): | ||
| raise RuntimeError( | ||
| "JAX distributed is not initialized. " | ||
| "Call jax.distributed.initialize() before creating a ChunkSamplerDistributed with dist_info='jax'." | ||
| ) | ||
| return jax.process_index(), jax.process_count() | ||
|
|
||
|
|
||
| DISTRIBUTED_BACKENDS: dict[str, Callable[[], tuple[int, int]]] = { | ||
| "torch": _get_dist_info_torch, | ||
| "jax": _get_dist_info_jax, | ||
| } | ||
|
|
||
|
|
||
| class ChunkSamplerDistributed(ChunkSampler): | ||
| """Distributed chunk-based sampler that shards data across distributed processes. | ||
|
|
||
| Partitions the full observation range into ``world_size`` contiguous shards | ||
| using the ``mask`` mechanism of :class:`ChunkSampler`. Each rank receives a | ||
| non-overlapping slice of the data. The shard boundaries are computed lazily | ||
| when ``n_obs`` becomes known. | ||
|
|
||
| When ``enforce_equal_batches`` is *True* (the default), the per-rank observation | ||
| count is rounded down to the nearest multiple of ``batch_size``, | ||
| guaranteeing that every rank yields exactly the same number of complete | ||
| batches. | ||
|
|
||
| Rank and world size are obtained from ``dist_info`` at construction time. | ||
| The corresponding distributed framework must already be initialized. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| chunk_size | ||
| Size of each chunk i.e. the range of each chunk yielded. | ||
| preload_nchunks | ||
| Number of chunks to load per iteration. | ||
| batch_size | ||
| Number of observations per batch. | ||
| dist_info | ||
| How to obtain rank and world size. | ||
| Either a string naming a distributed backend (``"torch"`` or ``"jax"``), | ||
| or a callable that returns ``(rank, world_size)``. | ||
| shuffle | ||
| Whether to shuffle chunk and index order. | ||
| drop_last | ||
| Whether to drop the last incomplete batch. | ||
| rng | ||
| Random number generator for shuffling. | ||
| enforce_equal_batches | ||
| If *True*, round each rank's observation count down to a multiple of ``batch_size`` so that all workers (ranks) yield the same number of batches. | ||
| Set to *False* to use the raw ``n_obs // world_size`` split, which may result in an uneven number of batches per worker. | ||
| """ | ||
|
|
||
| _rank: int | ||
| _world_size: int | ||
| _enforce_equal_batches: bool | ||
|
|
||
| def __init__( | ||
| self, | ||
| chunk_size: int, | ||
| preload_nchunks: int, | ||
| batch_size: int, | ||
| *, | ||
| dist_info: Literal["torch", "jax"] | Callable[[], tuple[int, int]], | ||
| shuffle: bool = False, | ||
| drop_last: bool = False, | ||
| rng: np.random.Generator | None = None, | ||
| enforce_equal_batches: bool = True, | ||
| ): | ||
| if callable(dist_info): | ||
| self._rank, self._world_size = dist_info() | ||
| elif dist_info in DISTRIBUTED_BACKENDS: | ||
| self._rank, self._world_size = DISTRIBUTED_BACKENDS[dist_info]() | ||
| else: | ||
| raise ValueError(f"Unknown dist_info {dist_info!r}. Supported backends: {sorted(DISTRIBUTED_BACKENDS)}") | ||
| self._enforce_equal_batches = enforce_equal_batches | ||
|
|
||
| super().__init__( | ||
| chunk_size=chunk_size, | ||
| preload_nchunks=preload_nchunks, | ||
| batch_size=batch_size, | ||
| shuffle=shuffle, | ||
| drop_last=drop_last, | ||
| rng=_spawn_worker_rng(rng, self._rank) if rng else None, | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I also added the rng handling like you suggested. Maybe this already addresses your concerns? @selmanozleyen
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah that fixes having same batch indices across nodes |
||
| ) | ||
|
|
||
| def _shard_mask(self, n_obs: int) -> slice: | ||
| """Return the contiguous observation slice for this rank.""" | ||
| per_rank = n_obs // self._world_size | ||
| if self._enforce_equal_batches: | ||
| per_rank = per_rank // self._batch_size * self._batch_size | ||
| rank_start = self._rank * per_rank | ||
| rank_stop = rank_start + per_rank | ||
| return slice(rank_start, rank_stop) | ||
|
|
||
| def n_iters(self, n_obs: int) -> int: | ||
| self._mask = self._shard_mask(n_obs) | ||
| return super().n_iters(n_obs) | ||
|
|
||
| def validate(self, n_obs: int) -> None: | ||
| self._mask = self._shard_mask(n_obs) | ||
| super().validate(n_obs) | ||
|
|
||
| def _sample(self, n_obs: int) -> Iterator[LoadRequest]: | ||
| self._mask = self._shard_mask(n_obs) | ||
| yield from super()._sample(n_obs) | ||
|
Comment on lines
+308
to
+310
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a bit hacky and requires the child to know too much about the inner logic of its parent function's details. From @ilan-gold
I mostly agree. That's why I wanted to name chunksampler a window/mask/slice sampler to be technically correct. For categorical samplers we shouldn't need list[slice] because categorical sampler should be able to use multiple chunksampler's right? I think list[slice] isn't trivial for ChunkSampler itself because Easiest change could be having an effective_mask attribute which is the identity self._mask by default and passed to _sampler(n_obs,mask)? and can be overridden in
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's put a pin into this and revisit if we see a need to update! For now, it's not really worth it (see comment from Ilan as well below) |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,12 +3,13 @@ | |
| from __future__ import annotations | ||
|
|
||
| import math | ||
| from unittest.mock import patch | ||
| import sys | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from annbatch import ChunkSampler | ||
| from annbatch import ChunkSampler, ChunkSamplerDistributed | ||
| from annbatch.abc import Sampler | ||
| from annbatch.samplers._utils import WorkerInfo | ||
|
|
||
|
|
@@ -369,3 +370,195 @@ def test_automatic_batching_respects_shuffle_flag(shuffle: bool): | |
| assert all_indices != list(range(n_obs)), "Indices should be shuffled" | ||
| else: | ||
| assert all_indices == list(range(n_obs)), "Indices should be sequential" | ||
|
|
||
|
|
||
| # ============================================================================= | ||
| # ChunkSamplerDistributed tests | ||
| # ============================================================================= | ||
|
|
||
|
|
||
| def _make_distributed_sampler_torch(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed: | ||
| """Create a ChunkSamplerDistributed with mocked torch.distributed backend.""" | ||
| mock_dist = MagicMock() | ||
| mock_dist.is_initialized.return_value = True | ||
| mock_dist.get_rank.return_value = rank | ||
| mock_dist.get_world_size.return_value = world_size | ||
| mock_torch = MagicMock() | ||
| mock_torch.distributed = mock_dist | ||
| with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}): | ||
|
Comment on lines
+380
to
+388
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is why an
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my comment above. I could remove those now that I added the |
||
| return ChunkSamplerDistributed(dist_info="torch", **kwargs) | ||
|
|
||
|
|
||
| def _make_distributed_sampler_jax(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed: | ||
| """Create a ChunkSamplerDistributed with mocked jax backend.""" | ||
| mock_jax = MagicMock() | ||
| mock_jax.process_index.return_value = rank | ||
| mock_jax.process_count.return_value = world_size | ||
| mock_jax.distributed.is_initialized.return_value = True | ||
| with patch.dict(sys.modules, {"jax": mock_jax}): | ||
| return ChunkSamplerDistributed(dist_info="jax", **kwargs) | ||
|
|
||
|
|
||
| _SAMPLER_FACTORIES = { | ||
| "torch": _make_distributed_sampler_torch, | ||
| "jax": _make_distributed_sampler_jax, | ||
| } | ||
|
|
||
|
|
||
| @pytest.fixture(params=["torch", "jax"]) | ||
| def make_distributed_sampler(request): | ||
| """Fixture that yields a sampler factory for each backend.""" | ||
| return _SAMPLER_FACTORIES[request.param] | ||
|
|
||
|
|
||
| class TestChunkSamplerDistributed: | ||
| """Tests for ChunkSamplerDistributed, parameterized over all backends.""" | ||
|
|
||
| def test_not_initialized_raises_torch(self): | ||
| """RuntimeError when torch.distributed is not initialized.""" | ||
| mock_dist = MagicMock() | ||
| mock_dist.is_initialized.return_value = False | ||
| mock_torch = MagicMock() | ||
| mock_torch.distributed = mock_dist | ||
| with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}): | ||
| with pytest.raises(RuntimeError, match="torch.distributed is not initialized"): | ||
| ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="torch") | ||
|
|
||
| def test_not_initialized_raises_jax(self): | ||
| """RuntimeError when jax.distributed is not initialized.""" | ||
| mock_jax = MagicMock() | ||
| mock_jax.distributed.is_initialized.return_value = False | ||
| with patch.dict(sys.modules, {"jax": mock_jax}): | ||
| with pytest.raises(RuntimeError, match="JAX distributed is not initialized"): | ||
| ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="jax") | ||
|
|
||
| def test_unknown_dist_info_raises(self): | ||
| """ValueError for an unsupported dist_info string.""" | ||
| with pytest.raises(ValueError, match="Unknown dist_info"): | ||
| ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="mpi") | ||
|
|
||
| def test_shards_are_disjoint_and_cover_full_dataset(self, make_distributed_sampler): | ||
| """All ranks receive non-overlapping shards that together cover the full dataset.""" | ||
| n_obs, world_size = 200, 4 | ||
| chunk_size, preload_nchunks, batch_size = 10, 2, 10 | ||
|
|
||
| all_indices = [] | ||
| for rank in range(world_size): | ||
| sampler = make_distributed_sampler( | ||
| rank=rank, | ||
| world_size=world_size, | ||
| chunk_size=chunk_size, | ||
| preload_nchunks=preload_nchunks, | ||
| batch_size=batch_size, | ||
| ) | ||
| all_indices.append(collect_indices(sampler, n_obs)) | ||
|
|
||
| # Shards must be disjoint | ||
| for i in range(world_size): | ||
| for j in range(i + 1, world_size): | ||
| assert set(all_indices[i]).isdisjoint(set(all_indices[j])) | ||
|
|
||
| # Together they cover the full dataset (evenly divisible case) | ||
| assert set().union(*all_indices) == set(range(n_obs)) | ||
|
|
||
| @pytest.mark.parametrize( | ||
| "n_obs,world_size,batch_size,chunk_size,preload_nchunks", | ||
| [ | ||
| pytest.param(200, 4, 10, 10, 2, id="evenly_divisible"), | ||
| pytest.param(205, 3, 10, 10, 2, id="remainder_obs"), | ||
| pytest.param(1000, 7, 5, 10, 2, id="prime_world_size"), | ||
| pytest.param(100, 3, 5, 10, 2, id="small_dataset"), | ||
| ], | ||
| ) | ||
| def test_enforce_equal_batches_all_ranks_same_count( | ||
| self, make_distributed_sampler, n_obs, world_size, batch_size, chunk_size, preload_nchunks | ||
| ): | ||
| """enforce_equal_batches=True guarantees identical batch counts across ranks.""" | ||
| batch_counts = [] | ||
| for rank in range(world_size): | ||
| sampler = make_distributed_sampler( | ||
| rank=rank, | ||
| world_size=world_size, | ||
| chunk_size=chunk_size, | ||
| preload_nchunks=preload_nchunks, | ||
| batch_size=batch_size, | ||
| enforce_equal_batches=True, | ||
| ) | ||
| n_batches = sum(len(lr["splits"]) for lr in sampler.sample(n_obs)) | ||
| batch_counts.append(n_batches) | ||
|
|
||
| assert len(set(batch_counts)) == 1, f"Batch counts differ across ranks: {batch_counts}" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not assert this to batch_size also like batch_counts[0] == batch_size
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. at this point I don't have the batch_size info anymore, I'm just tracking the number of yielded batches here (that's what the test is supposed to cover). I'd argue what you're suggesting is already covered by the non distributed tests. So would focus the test here that the dataset sharding is done right etc |
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("enforce_equal_batches", "expected"), | ||
| [(True, 30), (False, 35)], | ||
| ids=["rounded", "raw"], | ||
| ) | ||
| def test_enforce_equal_batches_per_rank_count(self, make_distributed_sampler, enforce_equal_batches, expected): | ||
| """enforce_equal_batches controls whether per_rank is rounded down to a multiple of batch_size.""" | ||
| n_obs, world_size = 107, 3 | ||
| chunk_size, preload_nchunks, batch_size = 10, 1, 10 | ||
| # raw per_rank = 107 // 3 = 35, rounded = 35 // 10 * 10 = 30 | ||
| sampler = make_distributed_sampler( | ||
| rank=0, | ||
| world_size=world_size, | ||
| chunk_size=chunk_size, | ||
| preload_nchunks=preload_nchunks, | ||
| batch_size=batch_size, | ||
| enforce_equal_batches=enforce_equal_batches, | ||
| ) | ||
| indices = collect_indices(sampler, n_obs) | ||
| assert len(set(indices)) == expected | ||
|
|
||
| def test_batch_shuffle_is_reproducible_with_same_seed_rng(self, make_distributed_sampler): | ||
| """Test that batch shuffling is reproducible when passing in rngs with identical seeds.""" | ||
| n_obs, chunk_size, preload_nchunks, batch_size = 200, 10, 2, 5 | ||
| world_size = 4 | ||
| seed = 42 | ||
|
|
||
| def collect_splits(sampler: ChunkSamplerDistributed) -> list[list[int]]: | ||
| all_splits: list[list[int]] = [] | ||
| for load_request in sampler.sample(n_obs): | ||
| for split in load_request["splits"]: | ||
| all_splits.append(split.tolist()) | ||
| return all_splits | ||
|
|
||
| splits_per_run: list[dict[int, list[list[int]]]] = [] | ||
| for _ in range(3): # test 3 runs to ensure reproducibility | ||
| splits_by_rank: dict[int, list[list[int]]] = {} | ||
| for rank in range(world_size): | ||
| sampler = make_distributed_sampler( | ||
| rank=rank, | ||
| world_size=world_size, | ||
| chunk_size=chunk_size, | ||
| preload_nchunks=preload_nchunks, | ||
| batch_size=batch_size, | ||
| shuffle=True, | ||
| rng=np.random.default_rng(seed), | ||
| ) | ||
| splits_by_rank[rank] = collect_splits(sampler) | ||
| splits_per_run.append(splits_by_rank) | ||
|
|
||
| for rank in range(world_size): | ||
| assert splits_per_run[0][rank] == splits_per_run[1][rank], ( | ||
| f"Rank {rank}: batch shuffling should be reproducible with same seed" | ||
| ) | ||
|
|
||
| def test_n_iters_matches_actual_batch_count(self, make_distributed_sampler): | ||
| """n_iters should match the actual number of yielded batches.""" | ||
| n_obs, world_size = 205, 3 | ||
| chunk_size, preload_nchunks, batch_size = 10, 2, 10 | ||
|
|
||
| for rank in range(world_size): | ||
| sampler = make_distributed_sampler( | ||
| rank=rank, | ||
| world_size=world_size, | ||
| chunk_size=chunk_size, | ||
| preload_nchunks=preload_nchunks, | ||
| batch_size=batch_size, | ||
| enforce_equal_batches=True, | ||
| drop_last=True, | ||
| ) | ||
| expected = sampler.n_iters(n_obs) | ||
| actual = sum(len(lr["splits"]) for lr in sampler.sample(n_obs)) | ||
| assert actual == expected, f"rank {rank}: n_iters={expected}, actual={actual}" | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think really don't need backend="jax" kind of interface here. We can get as an input a callable like
Why a callable? So that the they are called whenever the distributed system is spawned/launched.
In an example tutorial we can show anyone can inject anything that gives these. Or we can also implement convenience dist_info_fn's. Anyone who's knowledgeable enough to need this can use this so we can have a more mature interface. No need to add training wheels that will age bad here IMO
Even though it's a very strong opinion of mine it's still just a preference.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fair point, I will add the option to provide a callable as well. Callable then just return a tuple
(rank, world_size). Then people can easily extend this to other frameworks. I would keep the torch + jax option though. Most people who do distributed training are not really aware of those concepts. E.g. Lightning abstracts those away. So it's easy for them to use it; they just need to know which framework they're using.