diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py
index 39180c0..a08b76a 100644
--- a/src/annbatch/__init__.py
+++ b/src/annbatch/__init__.py
@@ -5,7 +5,7 @@
 from . import abc, types
 from .io import DatasetCollection, write_sharded
 from .loader import Loader
-from .samplers._chunk_sampler import ChunkSampler
+from .samplers._chunk_sampler import ChunkSampler, ChunkSamplerDistributed
 
 __version__ = version("annbatch")
 
@@ -15,5 +15,6 @@
     "types",
     "write_sharded",
     "ChunkSampler",
+    "ChunkSamplerDistributed",
     "abc",
 ]
diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py
index 9f92bbf..b986627 100644
--- a/src/annbatch/samplers/__init__.py
+++ b/src/annbatch/samplers/__init__.py
@@ -1,5 +1,6 @@
-from ._chunk_sampler import ChunkSampler
+from ._chunk_sampler import ChunkSampler, ChunkSamplerDistributed
 
 __all__ = [
     "ChunkSampler",
+    "ChunkSamplerDistributed",
 ]
diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py
index 691862a..a24f693 100644
--- a/src/annbatch/samplers/_chunk_sampler.py
+++ b/src/annbatch/samplers/_chunk_sampler.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import math
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 import numpy as np
 
@@ -12,7 +12,7 @@
 from annbatch.utils import _spawn_worker_rng, check_lt_1, split_given_size
 
 if TYPE_CHECKING:
-    from collections.abc import Iterator
+    from collections.abc import Callable, Iterator
 
     from annbatch.types import LoadRequest
 
@@ -184,3 +184,147 @@ def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> l
         offsets = np.cumsum(offsets)
         starts, stops = offsets[:-1][chunk_indices], offsets[1:][chunk_indices]
         return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)]
+
+
+def _get_dist_info_torch() -> tuple[int, int]:
+    """Get rank and world_size from ``torch.distributed``."""
+    import torch.distributed as dist
+
+    if not dist.is_initialized():
+        raise RuntimeError(
+            "torch.distributed is not initialized. "
+            "Initialize it before creating a ChunkSamplerDistributed with dist_info='torch'."
+        )
+    return dist.get_rank(), dist.get_world_size()
+
+
+def _get_dist_info_jax() -> tuple[int, int]:
+    """Get rank and world_size from JAX multi-process API."""
+    import jax
+
+    if not jax.distributed.is_initialized():
+        raise RuntimeError(
+            "JAX distributed is not initialized. "
+            "Call jax.distributed.initialize() before creating a ChunkSamplerDistributed with dist_info='jax'."
+        )
+    return jax.process_index(), jax.process_count()
+
+
+# Maps each backend name accepted by ``dist_info`` to a callable returning ``(rank, world_size)``.
+DISTRIBUTED_BACKENDS: dict[str, Callable[[], tuple[int, int]]] = {
+    "torch": _get_dist_info_torch,
+    "jax": _get_dist_info_jax,
+}
+
+
+class ChunkSamplerDistributed(ChunkSampler):
+    """Distributed chunk-based sampler that shards data across distributed processes.
+
+    Partitions the full observation range into ``world_size`` contiguous shards
+    using the ``mask`` mechanism of :class:`ChunkSampler`. Each rank receives a
+    non-overlapping slice of the data. The shard boundaries are computed lazily
+    when ``n_obs`` becomes known.
+
+    Every shard has the same length, so the trailing observations left over by the
+    ``n_obs // world_size`` split (and by the rounding described below) are not
+    assigned to any rank.
+
+    When ``enforce_equal_batches`` is *True* (the default), the per-rank observation
+    count is rounded down to the nearest multiple of ``batch_size``,
+    guaranteeing that every rank yields exactly the same number of complete
+    batches.
+
+    Rank and world size are obtained from ``dist_info`` at construction time.
+    The corresponding distributed framework must already be initialized.
+
+    Parameters
+    ----------
+    chunk_size
+        Size of each chunk i.e. the range of each chunk yielded.
+    preload_nchunks
+        Number of chunks to load per iteration.
+    batch_size
+        Number of observations per batch.
+    dist_info
+        How to obtain rank and world size.
+        Either a string naming a distributed backend (``"torch"`` or ``"jax"``),
+        or a callable that returns ``(rank, world_size)``.
+    shuffle
+        Whether to shuffle chunk and index order.
+    drop_last
+        Whether to drop the last incomplete batch.
+    rng
+        Random number generator for shuffling.
+    enforce_equal_batches
+        If *True*, round each rank's observation count down to a multiple of ``batch_size`` so that all workers (ranks) yield the same number of batches.
+        Set to *False* to use the raw ``n_obs // world_size`` split, which may result in an uneven number of batches per worker.
+
+    Raises
+    ------
+    ValueError
+        If ``dist_info`` is neither callable nor a supported backend name, or if the
+        resulting pair does not satisfy ``world_size >= 1`` and ``0 <= rank < world_size``.
+    """
+
+    _rank: int
+    _world_size: int
+    _enforce_equal_batches: bool
+
+    def __init__(
+        self,
+        chunk_size: int,
+        preload_nchunks: int,
+        batch_size: int,
+        *,
+        dist_info: Literal["torch", "jax"] | Callable[[], tuple[int, int]],
+        shuffle: bool = False,
+        drop_last: bool = False,
+        rng: np.random.Generator | None = None,
+        enforce_equal_batches: bool = True,
+    ):
+        if callable(dist_info):
+            self._rank, self._world_size = dist_info()
+        elif isinstance(dist_info, str) and dist_info in DISTRIBUTED_BACKENDS:
+            self._rank, self._world_size = DISTRIBUTED_BACKENDS[dist_info]()
+        else:
+            raise ValueError(f"Unknown dist_info {dist_info!r}. Supported backends: {sorted(DISTRIBUTED_BACKENDS)}")
+        # Reject inconsistent values early: ``_shard_mask`` divides by ``world_size``.
+        if self._world_size < 1 or not 0 <= self._rank < self._world_size:
+            raise ValueError(
+                f"Invalid distributed info: rank={self._rank}, world_size={self._world_size}. "
+                "Expected world_size >= 1 and 0 <= rank < world_size."
+            )
+        self._enforce_equal_batches = enforce_equal_batches
+
+        super().__init__(
+            chunk_size=chunk_size,
+            preload_nchunks=preload_nchunks,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last,
+            rng=_spawn_worker_rng(rng, self._rank) if rng is not None else None,
+        )
+
+    def _shard_mask(self, n_obs: int) -> slice:
+        """Return the contiguous observation slice for this rank."""
+        per_rank = n_obs // self._world_size
+        if self._enforce_equal_batches:
+            per_rank = per_rank // self._batch_size * self._batch_size
+        rank_start = self._rank * per_rank
+        rank_stop = rank_start + per_rank
+        return slice(rank_start, rank_stop)
+
+    def n_iters(self, n_obs: int) -> int:
+        """Set this rank's shard mask for ``n_obs``, then defer to :meth:`ChunkSampler.n_iters`."""
+        self._mask = self._shard_mask(n_obs)
+        return super().n_iters(n_obs)
+
+    def validate(self, n_obs: int) -> None:
+        """Set this rank's shard mask for ``n_obs``, then defer to :meth:`ChunkSampler.validate`."""
+        self._mask = self._shard_mask(n_obs)
+        super().validate(n_obs)
+
+    def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
+        """Set this rank's shard mask for ``n_obs``, then yield from :meth:`ChunkSampler._sample`."""
+        self._mask = self._shard_mask(n_obs)
+        yield from super()._sample(n_obs)
diff --git a/tests/test_sampler.py b/tests/test_sampler.py
index 8fff301..357574e 100644
--- a/tests/test_sampler.py
+++ b/tests/test_sampler.py
@@ -3,12 +3,13 @@
 from __future__ import annotations
 
 import math
-from unittest.mock import patch
+import sys
+from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pytest
 
-from annbatch import ChunkSampler
+from annbatch import ChunkSampler, ChunkSamplerDistributed
 from annbatch.abc import Sampler
 from annbatch.samplers._utils import WorkerInfo
 
@@ -369,3 +370,210 @@ def test_automatic_batching_respects_shuffle_flag(shuffle: bool):
         assert all_indices != list(range(n_obs)), "Indices should be shuffled"
     else:
         assert all_indices == list(range(n_obs)), "Indices should be sequential"
+
+
+# =============================================================================
+# ChunkSamplerDistributed tests
+# =============================================================================
+
+
+def _make_distributed_sampler_torch(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed:
+    """Create a ChunkSamplerDistributed with mocked torch.distributed backend."""
+    mock_dist = MagicMock()
+    mock_dist.is_initialized.return_value = True
+    mock_dist.get_rank.return_value = rank
+    mock_dist.get_world_size.return_value = world_size
+    mock_torch = MagicMock()
+    mock_torch.distributed = mock_dist
+    with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}):
+        return ChunkSamplerDistributed(dist_info="torch", **kwargs)
+
+
+def _make_distributed_sampler_jax(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed:
+    """Create a ChunkSamplerDistributed with mocked jax backend."""
+    mock_jax = MagicMock()
+    mock_jax.process_index.return_value = rank
+    mock_jax.process_count.return_value = world_size
+    mock_jax.distributed.is_initialized.return_value = True
+    with patch.dict(sys.modules, {"jax": mock_jax}):
+        return ChunkSamplerDistributed(dist_info="jax", **kwargs)
+
+
+_SAMPLER_FACTORIES = {
+    "torch": _make_distributed_sampler_torch,
+    "jax": _make_distributed_sampler_jax,
+}
+
+
+@pytest.fixture(params=["torch", "jax"])
+def make_distributed_sampler(request):
+    """Fixture that yields a sampler factory for each backend."""
+    return _SAMPLER_FACTORIES[request.param]
+
+
+class TestChunkSamplerDistributed:
+    """Tests for ChunkSamplerDistributed, parameterized over all backends."""
+
+    def test_not_initialized_raises_torch(self):
+        """RuntimeError when torch.distributed is not initialized."""
+        mock_dist = MagicMock()
+        mock_dist.is_initialized.return_value = False
+        mock_torch = MagicMock()
+        mock_torch.distributed = mock_dist
+        with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}):
+            with pytest.raises(RuntimeError, match="torch.distributed is not initialized"):
+                ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="torch")
+
+    def test_not_initialized_raises_jax(self):
+        """RuntimeError when jax.distributed is not initialized."""
+        mock_jax = MagicMock()
+        mock_jax.distributed.is_initialized.return_value = False
+        with patch.dict(sys.modules, {"jax": mock_jax}):
+            with pytest.raises(RuntimeError, match="JAX distributed is not initialized"):
+                ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="jax")
+
+    def test_unknown_dist_info_raises(self):
+        """ValueError for an unsupported dist_info string."""
+        with pytest.raises(ValueError, match="Unknown dist_info"):
+            ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="mpi")
+
+    @pytest.mark.parametrize(("rank", "world_size"), [(0, 0), (-1, 2), (2, 2)])
+    def test_invalid_rank_or_world_size_raises(self, rank, world_size):
+        """ValueError when the reported rank / world size pair is inconsistent."""
+        with pytest.raises(ValueError, match="Invalid distributed info"):
+            ChunkSamplerDistributed(
+                chunk_size=10, preload_nchunks=2, batch_size=10, dist_info=lambda: (rank, world_size)
+            )
+
+    def test_callable_dist_info(self):
+        """A callable ``dist_info`` supplies ``(rank, world_size)`` directly."""
+        sampler = ChunkSamplerDistributed(
+            chunk_size=10, preload_nchunks=2, batch_size=10, dist_info=lambda: (1, 4)
+        )
+        assert set(collect_indices(sampler, 200)) == set(range(50, 100))
+
+    def test_shards_are_disjoint_and_cover_full_dataset(self, make_distributed_sampler):
+        """All ranks receive non-overlapping shards that together cover the full dataset."""
+        n_obs, world_size = 200, 4
+        chunk_size, preload_nchunks, batch_size = 10, 2, 10
+
+        all_indices = []
+        for rank in range(world_size):
+            sampler = make_distributed_sampler(
+                rank=rank,
+                world_size=world_size,
+                chunk_size=chunk_size,
+                preload_nchunks=preload_nchunks,
+                batch_size=batch_size,
+            )
+            all_indices.append(collect_indices(sampler, n_obs))
+
+        # Shards must be disjoint
+        for i in range(world_size):
+            for j in range(i + 1, world_size):
+                assert set(all_indices[i]).isdisjoint(set(all_indices[j]))
+
+        # Together they cover the full dataset (evenly divisible case)
+        assert set().union(*all_indices) == set(range(n_obs))
+
+    @pytest.mark.parametrize(
+        "n_obs,world_size,batch_size,chunk_size,preload_nchunks",
+        [
+            pytest.param(200, 4, 10, 10, 2, id="evenly_divisible"),
+            pytest.param(205, 3, 10, 10, 2, id="remainder_obs"),
+            pytest.param(1000, 7, 5, 10, 2, id="prime_world_size"),
+            pytest.param(100, 3, 5, 10, 2, id="small_dataset"),
+        ],
+    )
+    def test_enforce_equal_batches_all_ranks_same_count(
+        self, make_distributed_sampler, n_obs, world_size, batch_size, chunk_size, preload_nchunks
+    ):
+        """enforce_equal_batches=True guarantees identical batch counts across ranks."""
+        batch_counts = []
+        for rank in range(world_size):
+            sampler = make_distributed_sampler(
+                rank=rank,
+                world_size=world_size,
+                chunk_size=chunk_size,
+                preload_nchunks=preload_nchunks,
+                batch_size=batch_size,
+                enforce_equal_batches=True,
+            )
+            n_batches = sum(len(lr["splits"]) for lr in sampler.sample(n_obs))
+            batch_counts.append(n_batches)
+
+        assert len(set(batch_counts)) == 1, f"Batch counts differ across ranks: {batch_counts}"
+
+    @pytest.mark.parametrize(
+        ("enforce_equal_batches", "expected"),
+        [(True, 30), (False, 35)],
+        ids=["rounded", "raw"],
+    )
+    def test_enforce_equal_batches_per_rank_count(self, make_distributed_sampler, enforce_equal_batches, expected):
+        """enforce_equal_batches controls whether per_rank is rounded down to a multiple of batch_size."""
+        n_obs, world_size = 107, 3
+        chunk_size, preload_nchunks, batch_size = 10, 1, 10
+        # raw per_rank = 107 // 3 = 35, rounded = 35 // 10 * 10 = 30
+        sampler = make_distributed_sampler(
+            rank=0,
+            world_size=world_size,
+            chunk_size=chunk_size,
+            preload_nchunks=preload_nchunks,
+            batch_size=batch_size,
+            enforce_equal_batches=enforce_equal_batches,
+        )
+        indices = collect_indices(sampler, n_obs)
+        assert len(set(indices)) == expected
+
+    def test_batch_shuffle_is_reproducible_with_same_seed_rng(self, make_distributed_sampler):
+        """Test that batch shuffling is reproducible when passing in rngs with identical seeds."""
+        n_obs, chunk_size, preload_nchunks, batch_size = 200, 10, 2, 5
+        world_size = 4
+        seed = 42
+
+        def collect_splits(sampler: ChunkSamplerDistributed) -> list[list[int]]:
+            all_splits: list[list[int]] = []
+            for load_request in sampler.sample(n_obs):
+                for split in load_request["splits"]:
+                    all_splits.append(split.tolist())
+            return all_splits
+
+        splits_per_run: list[dict[int, list[list[int]]]] = []
+        for _ in range(3):  # test 3 runs to ensure reproducibility
+            splits_by_rank: dict[int, list[list[int]]] = {}
+            for rank in range(world_size):
+                sampler = make_distributed_sampler(
+                    rank=rank,
+                    world_size=world_size,
+                    chunk_size=chunk_size,
+                    preload_nchunks=preload_nchunks,
+                    batch_size=batch_size,
+                    shuffle=True,
+                    rng=np.random.default_rng(seed),
+                )
+                splits_by_rank[rank] = collect_splits(sampler)
+            splits_per_run.append(splits_by_rank)
+
+        for rank in range(world_size):
+            assert all(run[rank] == splits_per_run[0][rank] for run in splits_per_run[1:]), (
+                f"Rank {rank}: batch shuffling should be reproducible with same seed"
+            )
+
+    def test_n_iters_matches_actual_batch_count(self, make_distributed_sampler):
+        """n_iters should match the actual number of yielded batches."""
+        n_obs, world_size = 205, 3
+        chunk_size, preload_nchunks, batch_size = 10, 2, 10
+
+        for rank in range(world_size):
+            sampler = make_distributed_sampler(
+                rank=rank,
+                world_size=world_size,
+                chunk_size=chunk_size,
+                preload_nchunks=preload_nchunks,
+                batch_size=batch_size,
+                enforce_equal_batches=True,
+                drop_last=True,
+            )
+            expected = sampler.n_iters(n_obs)
+            actual = sum(len(lr["splits"]) for lr in sampler.sample(n_obs))
+            assert actual == expected, f"rank {rank}: n_iters={expected}, actual={actual}"