From b0ee20d552bf069baa0bb37a8458f22ee7a3fd07 Mon Sep 17 00:00:00 2001 From: Felix Fischer Date: Fri, 13 Feb 2026 14:32:45 +0100 Subject: [PATCH 1/4] Add ChunkSamplerTorchDistributed sampler --- src/annbatch/__init__.py | 3 +- src/annbatch/samplers/__init__.py | 3 +- src/annbatch/samplers/_chunk_sampler.py | 93 ++++++++++++++++ tests/test_sampler.py | 137 +++++++++++++++++++++++- 4 files changed, 232 insertions(+), 4 deletions(-) diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index 39180c0b..b61556cb 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -5,7 +5,7 @@ from . import abc, types from .io import DatasetCollection, write_sharded from .loader import Loader -from .samplers._chunk_sampler import ChunkSampler +from .samplers._chunk_sampler import ChunkSampler, ChunkSamplerTorchDistributed __version__ = version("annbatch") @@ -15,5 +15,6 @@ "types", "write_sharded", "ChunkSampler", + "ChunkSamplerTorchDistributed", "abc", ] diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py index 9f92bbf0..c8d04431 100644 --- a/src/annbatch/samplers/__init__.py +++ b/src/annbatch/samplers/__init__.py @@ -1,5 +1,6 @@ -from ._chunk_sampler import ChunkSampler +from ._chunk_sampler import ChunkSampler, ChunkSamplerTorchDistributed __all__ = [ "ChunkSampler", + "ChunkSamplerTorchDistributed", ] diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 691862a4..0306778a 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -184,3 +184,96 @@ def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> l offsets = np.cumsum(offsets) starts, stops = offsets[:-1][chunk_indices], offsets[1:][chunk_indices] return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)] + + +class ChunkSamplerTorchDistributed(ChunkSampler): + """Distributed chunk-based sampler that shards data across torch distributed 
processes. + + Partitions the full observation range into ``world_size`` contiguous shards + using the ``mask`` mechanism of :class:`ChunkSampler`. Each rank receives a + non-overlapping slice of the data. The shard boundaries are computed lazily + when ``n_obs`` becomes known. + + When ``enforce_equal_batches`` is *True* (the default), the per-rank observation + count is rounded down to the nearest multiple of ``batch_size``, + guaranteeing that every rank yields exactly the same number of complete + batches. + + Rank and world size are obtained from ``torch.distributed`` at construction + time, so ``torch.distributed`` must be initialized before creating an + instance of this sampler. + + Parameters + ---------- + chunk_size + Size of each chunk i.e. the range of each chunk yielded. + preload_nchunks + Number of chunks to load per iteration. + batch_size + Number of observations per batch. + shuffle + Whether to shuffle chunk and index order. + drop_last + Whether to drop the last incomplete batch. + rng + Random number generator for shuffling. + enforce_equal_batches + If *True*, round each rank's observation count down to a multiple of ``batch_size`` so that all ranks yield the same numberof batches. + Set to *False* to use the raw ``n_obs // world_size`` split, which may result in a uneven number of batches per worker. + """ + + _rank: int + _world_size: int + _enforce_equal_batches: bool + + def __init__( + self, + chunk_size: int, + preload_nchunks: int, + batch_size: int, + *, + shuffle: bool = False, + drop_last: bool = False, + rng: np.random.Generator | None = None, + enforce_equal_batches: bool = True, + ): + import torch.distributed as dist + + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. Initialize it before creating a ChunkSamplerTorchDistributed." 
+ ) + + self._rank = dist.get_rank() + self._world_size = dist.get_world_size() + self._enforce_equal_batches = enforce_equal_batches + + super().__init__( + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + rng=rng, + ) + + def _shard_mask(self, n_obs: int) -> slice: + """Return the contiguous observation slice for this rank.""" + per_rank = n_obs // self._world_size + if self._enforce_equal_batches: + per_rank = per_rank // self._batch_size * self._batch_size + rank_start = self._rank * per_rank + rank_stop = rank_start + per_rank + return slice(rank_start, rank_stop) + + def n_iters(self, n_obs: int) -> int: + self._mask = self._shard_mask(n_obs) + return super().n_iters(n_obs) + + def validate(self, n_obs: int) -> None: + self._mask = self._shard_mask(n_obs) + super().validate(n_obs) + + def _sample(self, n_obs: int) -> Iterator[LoadRequest]: + self._mask = self._shard_mask(n_obs) + yield from super()._sample(n_obs) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 8fff3011..5c85bcf9 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -3,12 +3,13 @@ from __future__ import annotations import math -from unittest.mock import patch +import sys +from unittest.mock import MagicMock, patch import numpy as np import pytest -from annbatch import ChunkSampler +from annbatch import ChunkSampler, ChunkSamplerTorchDistributed from annbatch.abc import Sampler from annbatch.samplers._utils import WorkerInfo @@ -369,3 +370,135 @@ def test_automatic_batching_respects_shuffle_flag(shuffle: bool): assert all_indices != list(range(n_obs)), "Indices should be shuffled" else: assert all_indices == list(range(n_obs)), "Indices should be sequential" + + +# ============================================================================= +# ChunkSamplerTorchDistributed tests +# ============================================================================= + + +def 
_make_distributed_sampler(rank: int, world_size: int, **kwargs) -> ChunkSamplerTorchDistributed: + """Create a ChunkSamplerTorchDistributed with mocked torch.distributed.""" + mock_dist = MagicMock() + mock_dist.is_initialized.return_value = True + mock_dist.get_rank.return_value = rank + mock_dist.get_world_size.return_value = world_size + mock_torch = MagicMock() + mock_torch.distributed = mock_dist + with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}): + return ChunkSamplerTorchDistributed(**kwargs) + + +class TestChunkSamplerTorchDistributed: + def test_not_initialized_raises(self): + """RuntimeError when torch.distributed is not initialized.""" + mock_dist = MagicMock() + mock_dist.is_initialized.return_value = False + mock_torch = MagicMock() + mock_torch.distributed = mock_dist + with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}): + with pytest.raises(RuntimeError, match="torch.distributed is not initialized"): + ChunkSamplerTorchDistributed(chunk_size=10, preload_nchunks=2, batch_size=10) + + def test_shards_are_disjoint_and_cover_full_dataset(self): + """All ranks receive non-overlapping shards that together cover the full dataset.""" + n_obs, world_size = 200, 4 + chunk_size, preload_nchunks, batch_size = 10, 2, 10 + + all_indices = [] + for rank in range(world_size): + sampler = _make_distributed_sampler( + rank=rank, + world_size=world_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + ) + all_indices.append(collect_indices(sampler, n_obs)) + + # Shards must be disjoint + for i in range(world_size): + for j in range(i + 1, world_size): + assert set(all_indices[i]).isdisjoint(set(all_indices[j])) + + # Together they cover the full dataset (evenly divisible case) + assert set().union(*all_indices) == set(range(n_obs)) + + @pytest.mark.parametrize( + "n_obs,world_size,batch_size,chunk_size,preload_nchunks", + [ + pytest.param(200, 4, 10, 10, 2, 
id="evenly_divisible"), + pytest.param(205, 3, 10, 10, 2, id="remainder_obs"), + pytest.param(1000, 7, 5, 10, 2, id="prime_world_size"), + pytest.param(100, 3, 5, 10, 2, id="small_dataset"), + ], + ) + def test_enforce_equal_batches_all_ranks_same_count( + self, n_obs, world_size, batch_size, chunk_size, preload_nchunks + ): + """enforce_equal_batches=True guarantees identical batch counts across ranks.""" + batch_counts = [] + for rank in range(world_size): + sampler = _make_distributed_sampler( + rank=rank, + world_size=world_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + enforce_equal_batches=True, + ) + n_batches = sum(len(lr["splits"]) for lr in sampler.sample(n_obs)) + batch_counts.append(n_batches) + + assert len(set(batch_counts)) == 1, f"Batch counts differ across ranks: {batch_counts}" + + def test_enforce_equal_batches_rounds_down_per_rank(self): + """enforce_equal_batches=True rounds per_rank down to a multiple of batch_size.""" + n_obs, world_size = 107, 3 + chunk_size, preload_nchunks, batch_size = 10, 1, 10 + # raw per_rank = 107 // 3 = 35, rounded = 35 // 10 * 10 = 30 + sampler = _make_distributed_sampler( + rank=0, + world_size=world_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + enforce_equal_batches=True, + ) + indices = collect_indices(sampler, n_obs) + assert len(set(indices)) == 30 + + def test_enforce_equal_batches_false_uses_raw_split(self): + """enforce_equal_batches=False uses n_obs // world_size without rounding.""" + n_obs, world_size = 107, 3 + chunk_size, preload_nchunks, batch_size = 10, 1, 10 + # raw per_rank = 107 // 3 = 35 (not rounded to 30) + sampler = _make_distributed_sampler( + rank=0, + world_size=world_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + enforce_equal_batches=False, + ) + indices = collect_indices(sampler, n_obs) + assert len(set(indices)) == 35 + + def 
test_n_iters_matches_actual_batch_count(self): + """n_iters should match the actual number of yielded batches.""" + n_obs, world_size = 205, 3 + chunk_size, preload_nchunks, batch_size = 10, 2, 10 + + for rank in range(world_size): + sampler = _make_distributed_sampler( + rank=rank, + world_size=world_size, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + enforce_equal_batches=True, + drop_last=True, + ) + expected = sampler.n_iters(n_obs) + actual = sum(len(lr["splits"]) for lr in sampler.sample(n_obs)) + assert actual == expected, f"rank {rank}: n_iters={expected}, actual={actual}" From 570f30e65f7dc403df6a0e3c28b835ee7b59cdd9 Mon Sep 17 00:00:00 2001 From: Felix Fischer Date: Fri, 13 Feb 2026 16:34:45 +0100 Subject: [PATCH 2/4] Make distributed backend configurable --- src/annbatch/__init__.py | 4 +- src/annbatch/samplers/__init__.py | 4 +- src/annbatch/samplers/_chunk_sampler.py | 64 +++++++++++++++------ tests/test_sampler.py | 75 ++++++++++++++++++------- 4 files changed, 108 insertions(+), 39 deletions(-) diff --git a/src/annbatch/__init__.py b/src/annbatch/__init__.py index b61556cb..a08b76ab 100644 --- a/src/annbatch/__init__.py +++ b/src/annbatch/__init__.py @@ -5,7 +5,7 @@ from . 
import abc, types from .io import DatasetCollection, write_sharded from .loader import Loader -from .samplers._chunk_sampler import ChunkSampler, ChunkSamplerTorchDistributed +from .samplers._chunk_sampler import ChunkSampler, ChunkSamplerDistributed __version__ = version("annbatch") @@ -15,6 +15,6 @@ "types", "write_sharded", "ChunkSampler", - "ChunkSamplerTorchDistributed", + "ChunkSamplerDistributed", "abc", ] diff --git a/src/annbatch/samplers/__init__.py b/src/annbatch/samplers/__init__.py index c8d04431..b9866273 100644 --- a/src/annbatch/samplers/__init__.py +++ b/src/annbatch/samplers/__init__.py @@ -1,6 +1,6 @@ -from ._chunk_sampler import ChunkSampler, ChunkSamplerTorchDistributed +from ._chunk_sampler import ChunkSampler, ChunkSamplerDistributed __all__ = [ "ChunkSampler", - "ChunkSamplerTorchDistributed", + "ChunkSamplerDistributed", ] diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index 0306778a..ecca0507 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -3,7 +3,7 @@ from __future__ import annotations import math -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np @@ -186,8 +186,38 @@ def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> l return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)] -class ChunkSamplerTorchDistributed(ChunkSampler): - """Distributed chunk-based sampler that shards data across torch distributed processes. +def _get_dist_info_torch() -> tuple[int, int]: + """Get rank and world_size from ``torch.distributed``.""" + import torch.distributed as dist + + if not dist.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. " + "Initialize it before creating a ChunkSamplerDistributed with backend='torch'." 
+ ) + return dist.get_rank(), dist.get_world_size() + + +def _get_dist_info_jax() -> tuple[int, int]: + """Get rank and world_size from JAX multi-process API.""" + import jax + + if not jax.distributed.is_initialized(): + raise RuntimeError( + "JAX distributed is not initialized. " + "Call jax.distributed.initialize() before creating a ChunkSamplerDistributed with backend='jax'." + ) + return jax.process_index(), jax.process_count() + + +DISTRIBUTED_BACKENDS: dict[str, callable] = { + "torch": _get_dist_info_torch, + "jax": _get_dist_info_jax, +} + + +class ChunkSamplerDistributed(ChunkSampler): + """Distributed chunk-based sampler that shards data across distributed processes. Partitions the full observation range into ``world_size`` contiguous shards using the ``mask`` mechanism of :class:`ChunkSampler`. Each rank receives a @@ -199,9 +229,9 @@ class ChunkSamplerTorchDistributed(ChunkSampler): guaranteeing that every rank yields exactly the same number of complete batches. - Rank and world size are obtained from ``torch.distributed`` at construction - time, so ``torch.distributed`` must be initialized before creating an - instance of this sampler. + Rank and world size are obtained from the distributed framework specified by + ``backend`` at construction time, so the framework must be initialized + before creating an instance of this sampler. Parameters ---------- @@ -211,6 +241,10 @@ class ChunkSamplerTorchDistributed(ChunkSampler): Number of chunks to load per iteration. batch_size Number of observations per batch. + backend + Distributed backend to query for rank and world size. + Supported values: ``"torch"`` (uses :mod:`torch.distributed`) and + ``"jax"`` (uses :func:`jax.process_index` / :func:`jax.process_count`). shuffle Whether to shuffle chunk and index order. drop_last @@ -218,8 +252,10 @@ class ChunkSamplerTorchDistributed(ChunkSampler): rng Random number generator for shuffling. 
enforce_equal_batches - If *True*, round each rank's observation count down to a multiple of ``batch_size`` so that all ranks yield the same numberof batches. - Set to *False* to use the raw ``n_obs // world_size`` split, which may result in a uneven number of batches per worker. + If *True*, round each rank's observation count down to a multiple of + ``batch_size`` so that all ranks yield the same number of batches. + Set to *False* to use the raw ``n_obs // world_size`` split, which may + result in an uneven number of batches per worker. """ _rank: int @@ -232,20 +268,16 @@ def __init__( preload_nchunks: int, batch_size: int, *, + backend: Literal["torch", "jax"], shuffle: bool = False, drop_last: bool = False, rng: np.random.Generator | None = None, enforce_equal_batches: bool = True, ): - import torch.distributed as dist - - if not dist.is_initialized(): - raise RuntimeError( - "torch.distributed is not initialized. Initialize it before creating a ChunkSamplerTorchDistributed." - ) + if backend not in DISTRIBUTED_BACKENDS: + raise ValueError(f"Unknown backend {backend!r}. 
Supported backends: {sorted(DISTRIBUTED_BACKENDS)}") - self._rank = dist.get_rank() - self._world_size = dist.get_world_size() + self._rank, self._world_size = DISTRIBUTED_BACKENDS[backend]() self._enforce_equal_batches = enforce_equal_batches super().__init__( diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 5c85bcf9..c811c5b2 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -9,7 +9,7 @@ import numpy as np import pytest -from annbatch import ChunkSampler, ChunkSamplerTorchDistributed +from annbatch import ChunkSampler, ChunkSamplerDistributed from annbatch.abc import Sampler from annbatch.samplers._utils import WorkerInfo @@ -373,12 +373,12 @@ def test_automatic_batching_respects_shuffle_flag(shuffle: bool): # ============================================================================= -# ChunkSamplerTorchDistributed tests +# ChunkSamplerDistributed tests # ============================================================================= -def _make_distributed_sampler(rank: int, world_size: int, **kwargs) -> ChunkSamplerTorchDistributed: - """Create a ChunkSamplerTorchDistributed with mocked torch.distributed.""" +def _make_distributed_sampler_torch(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed: + """Create a ChunkSamplerDistributed with mocked torch.distributed backend.""" mock_dist = MagicMock() mock_dist.is_initialized.return_value = True mock_dist.get_rank.return_value = rank @@ -386,11 +386,35 @@ def _make_distributed_sampler(rank: int, world_size: int, **kwargs) -> ChunkSamp mock_torch = MagicMock() mock_torch.distributed = mock_dist with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}): - return ChunkSamplerTorchDistributed(**kwargs) + return ChunkSamplerDistributed(backend="torch", **kwargs) -class TestChunkSamplerTorchDistributed: - def test_not_initialized_raises(self): +def _make_distributed_sampler_jax(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed: + 
"""Create a ChunkSamplerDistributed with mocked jax backend.""" + mock_jax = MagicMock() + mock_jax.process_index.return_value = rank + mock_jax.process_count.return_value = world_size + mock_jax.distributed.is_initialized.return_value = True + with patch.dict(sys.modules, {"jax": mock_jax}): + return ChunkSamplerDistributed(backend="jax", **kwargs) + + +_SAMPLER_FACTORIES = { + "torch": _make_distributed_sampler_torch, + "jax": _make_distributed_sampler_jax, +} + + +@pytest.fixture(params=["torch", "jax"]) +def make_distributed_sampler(request): + """Fixture that yields a sampler factory for each backend.""" + return _SAMPLER_FACTORIES[request.param] + + +class TestChunkSamplerDistributed: + """Tests for ChunkSamplerDistributed, parameterized over all backends.""" + + def test_not_initialized_raises_torch(self): """RuntimeError when torch.distributed is not initialized.""" mock_dist = MagicMock() mock_dist.is_initialized.return_value = False @@ -398,16 +422,29 @@ def test_not_initialized_raises(self): mock_torch.distributed = mock_dist with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}): with pytest.raises(RuntimeError, match="torch.distributed is not initialized"): - ChunkSamplerTorchDistributed(chunk_size=10, preload_nchunks=2, batch_size=10) - - def test_shards_are_disjoint_and_cover_full_dataset(self): + ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, backend="torch") + + def test_not_initialized_raises_jax(self): + """RuntimeError when jax.distributed is not initialized.""" + mock_jax = MagicMock() + mock_jax.distributed.is_initialized.return_value = False + with patch.dict(sys.modules, {"jax": mock_jax}): + with pytest.raises(RuntimeError, match="JAX distributed is not initialized"): + ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, backend="jax") + + def test_unknown_backend_raises(self): + """ValueError for an unsupported backend string.""" + with pytest.raises(ValueError, 
match="Unknown backend"): + ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, backend="mpi") + + def test_shards_are_disjoint_and_cover_full_dataset(self, make_distributed_sampler): """All ranks receive non-overlapping shards that together cover the full dataset.""" n_obs, world_size = 200, 4 chunk_size, preload_nchunks, batch_size = 10, 2, 10 all_indices = [] for rank in range(world_size): - sampler = _make_distributed_sampler( + sampler = make_distributed_sampler( rank=rank, world_size=world_size, chunk_size=chunk_size, @@ -434,12 +471,12 @@ def test_shards_are_disjoint_and_cover_full_dataset(self): ], ) def test_enforce_equal_batches_all_ranks_same_count( - self, n_obs, world_size, batch_size, chunk_size, preload_nchunks + self, make_distributed_sampler, n_obs, world_size, batch_size, chunk_size, preload_nchunks ): """enforce_equal_batches=True guarantees identical batch counts across ranks.""" batch_counts = [] for rank in range(world_size): - sampler = _make_distributed_sampler( + sampler = make_distributed_sampler( rank=rank, world_size=world_size, chunk_size=chunk_size, @@ -452,12 +489,12 @@ def test_enforce_equal_batches_all_ranks_same_count( assert len(set(batch_counts)) == 1, f"Batch counts differ across ranks: {batch_counts}" - def test_enforce_equal_batches_rounds_down_per_rank(self): + def test_enforce_equal_batches_rounds_down_per_rank(self, make_distributed_sampler): """enforce_equal_batches=True rounds per_rank down to a multiple of batch_size.""" n_obs, world_size = 107, 3 chunk_size, preload_nchunks, batch_size = 10, 1, 10 # raw per_rank = 107 // 3 = 35, rounded = 35 // 10 * 10 = 30 - sampler = _make_distributed_sampler( + sampler = make_distributed_sampler( rank=0, world_size=world_size, chunk_size=chunk_size, @@ -468,12 +505,12 @@ def test_enforce_equal_batches_rounds_down_per_rank(self): indices = collect_indices(sampler, n_obs) assert len(set(indices)) == 30 - def test_enforce_equal_batches_false_uses_raw_split(self): + 
def test_enforce_equal_batches_false_uses_raw_split(self, make_distributed_sampler): """enforce_equal_batches=False uses n_obs // world_size without rounding.""" n_obs, world_size = 107, 3 chunk_size, preload_nchunks, batch_size = 10, 1, 10 # raw per_rank = 107 // 3 = 35 (not rounded to 30) - sampler = _make_distributed_sampler( + sampler = make_distributed_sampler( rank=0, world_size=world_size, chunk_size=chunk_size, @@ -484,13 +521,13 @@ def test_enforce_equal_batches_false_uses_raw_split(self): indices = collect_indices(sampler, n_obs) assert len(set(indices)) == 35 - def test_n_iters_matches_actual_batch_count(self): + def test_n_iters_matches_actual_batch_count(self, make_distributed_sampler): """n_iters should match the actual number of yielded batches.""" n_obs, world_size = 205, 3 chunk_size, preload_nchunks, batch_size = 10, 2, 10 for rank in range(world_size): - sampler = _make_distributed_sampler( + sampler = make_distributed_sampler( rank=rank, world_size=world_size, chunk_size=chunk_size, From 5216dc92ec50e461f1181f49552e4c23ee121e56 Mon Sep 17 00:00:00 2001 From: Felix Fischer Date: Wed, 25 Feb 2026 12:21:50 +0100 Subject: [PATCH 3/4] Clean up test + fix rng + dist_info handling --- src/annbatch/samplers/_chunk_sampler.py | 41 +++++++++++----------- tests/test_sampler.py | 45 ++++++++++--------------- 2 files changed, 37 insertions(+), 49 deletions(-) diff --git a/src/annbatch/samplers/_chunk_sampler.py b/src/annbatch/samplers/_chunk_sampler.py index ecca0507..a24f6935 100644 --- a/src/annbatch/samplers/_chunk_sampler.py +++ b/src/annbatch/samplers/_chunk_sampler.py @@ -12,7 +12,7 @@ from annbatch.utils import _spawn_worker_rng, check_lt_1, split_given_size if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Iterator from annbatch.types import LoadRequest @@ -193,7 +193,7 @@ def _get_dist_info_torch() -> tuple[int, int]: if not dist.is_initialized(): raise RuntimeError( "torch.distributed is not 
initialized. " - "Initialize it before creating a ChunkSamplerDistributed with backend='torch'." + "Initialize it before creating a ChunkSamplerDistributed with dist_info='torch'." ) return dist.get_rank(), dist.get_world_size() @@ -205,12 +205,12 @@ def _get_dist_info_jax() -> tuple[int, int]: if not jax.distributed.is_initialized(): raise RuntimeError( "JAX distributed is not initialized. " - "Call jax.distributed.initialize() before creating a ChunkSamplerDistributed with backend='jax'." + "Call jax.distributed.initialize() before creating a ChunkSamplerDistributed with dist_info='jax'." ) return jax.process_index(), jax.process_count() -DISTRIBUTED_BACKENDS: dict[str, callable] = { +DISTRIBUTED_BACKENDS: dict[str, Callable[[], tuple[int, int]]] = { "torch": _get_dist_info_torch, "jax": _get_dist_info_jax, } @@ -229,9 +229,8 @@ class ChunkSamplerDistributed(ChunkSampler): guaranteeing that every rank yields exactly the same number of complete batches. - Rank and world size are obtained from the distributed framework specified by - ``backend`` at construction time, so the framework must be initialized - before creating an instance of this sampler. + Rank and world size are obtained from ``dist_info`` at construction time. + The corresponding distributed framework must already be initialized. Parameters ---------- @@ -241,10 +240,10 @@ class ChunkSamplerDistributed(ChunkSampler): Number of chunks to load per iteration. batch_size Number of observations per batch. - backend - Distributed backend to query for rank and world size. - Supported values: ``"torch"`` (uses :mod:`torch.distributed`) and - ``"jax"`` (uses :func:`jax.process_index` / :func:`jax.process_count`). + dist_info + How to obtain rank and world size. + Either a string naming a distributed backend (``"torch"`` or ``"jax"``), + or a callable that returns ``(rank, world_size)``. shuffle Whether to shuffle chunk and index order. 
drop_last @@ -252,10 +251,8 @@ class ChunkSamplerDistributed(ChunkSampler): rng Random number generator for shuffling. enforce_equal_batches - If *True*, round each rank's observation count down to a multiple of - ``batch_size`` so that all ranks yield the same number of batches. - Set to *False* to use the raw ``n_obs // world_size`` split, which may - result in an uneven number of batches per worker. + If *True*, round each rank's observation count down to a multiple of ``batch_size`` so that all workers (ranks) yield the same number of batches. + Set to *False* to use the raw ``n_obs // world_size`` split, which may result in an uneven number of batches per worker. """ _rank: int @@ -268,16 +265,18 @@ def __init__( preload_nchunks: int, batch_size: int, *, - backend: Literal["torch", "jax"], + dist_info: Literal["torch", "jax"] | Callable[[], tuple[int, int]], shuffle: bool = False, drop_last: bool = False, rng: np.random.Generator | None = None, enforce_equal_batches: bool = True, ): - if backend not in DISTRIBUTED_BACKENDS: - raise ValueError(f"Unknown backend {backend!r}. Supported backends: {sorted(DISTRIBUTED_BACKENDS)}") - - self._rank, self._world_size = DISTRIBUTED_BACKENDS[backend]() + if callable(dist_info): + self._rank, self._world_size = dist_info() + elif dist_info in DISTRIBUTED_BACKENDS: + self._rank, self._world_size = DISTRIBUTED_BACKENDS[dist_info]() + else: + raise ValueError(f"Unknown dist_info {dist_info!r}. 
Supported backends: {sorted(DISTRIBUTED_BACKENDS)}") self._enforce_equal_batches = enforce_equal_batches super().__init__( @@ -286,7 +285,7 @@ def __init__( batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, - rng=rng, + rng=_spawn_worker_rng(rng, self._rank) if rng else None, ) def _shard_mask(self, n_obs: int) -> slice: diff --git a/tests/test_sampler.py b/tests/test_sampler.py index c811c5b2..0148f570 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -386,7 +386,7 @@ def _make_distributed_sampler_torch(rank: int, world_size: int, **kwargs) -> Chu mock_torch = MagicMock() mock_torch.distributed = mock_dist with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}): - return ChunkSamplerDistributed(backend="torch", **kwargs) + return ChunkSamplerDistributed(dist_info="torch", **kwargs) def _make_distributed_sampler_jax(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed: @@ -396,7 +396,7 @@ def _make_distributed_sampler_jax(rank: int, world_size: int, **kwargs) -> Chunk mock_jax.process_count.return_value = world_size mock_jax.distributed.is_initialized.return_value = True with patch.dict(sys.modules, {"jax": mock_jax}): - return ChunkSamplerDistributed(backend="jax", **kwargs) + return ChunkSamplerDistributed(dist_info="jax", **kwargs) _SAMPLER_FACTORIES = { @@ -422,7 +422,7 @@ def test_not_initialized_raises_torch(self): mock_torch.distributed = mock_dist with patch.dict(sys.modules, {"torch": mock_torch, "torch.distributed": mock_dist}): with pytest.raises(RuntimeError, match="torch.distributed is not initialized"): - ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, backend="torch") + ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="torch") def test_not_initialized_raises_jax(self): """RuntimeError when jax.distributed is not initialized.""" @@ -430,12 +430,12 @@ def test_not_initialized_raises_jax(self): 
mock_jax.distributed.is_initialized.return_value = False with patch.dict(sys.modules, {"jax": mock_jax}): with pytest.raises(RuntimeError, match="JAX distributed is not initialized"): - ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, backend="jax") + ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="jax") - def test_unknown_backend_raises(self): - """ValueError for an unsupported backend string.""" - with pytest.raises(ValueError, match="Unknown backend"): - ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, backend="mpi") + def test_unknown_dist_info_raises(self): + """ValueError for an unsupported dist_info string.""" + with pytest.raises(ValueError, match="Unknown dist_info"): + ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="mpi") def test_shards_are_disjoint_and_cover_full_dataset(self, make_distributed_sampler): """All ranks receive non-overlapping shards that together cover the full dataset.""" @@ -489,8 +489,13 @@ def test_enforce_equal_batches_all_ranks_same_count( assert len(set(batch_counts)) == 1, f"Batch counts differ across ranks: {batch_counts}" - def test_enforce_equal_batches_rounds_down_per_rank(self, make_distributed_sampler): - """enforce_equal_batches=True rounds per_rank down to a multiple of batch_size.""" + @pytest.mark.parametrize( + ("enforce_equal_batches", "expected"), + [(True, 30), (False, 35)], + ids=["rounded", "raw"], + ) + def test_enforce_equal_batches_per_rank_count(self, make_distributed_sampler, enforce_equal_batches, expected): + """enforce_equal_batches controls whether per_rank is rounded down to a multiple of batch_size.""" n_obs, world_size = 107, 3 chunk_size, preload_nchunks, batch_size = 10, 1, 10 # raw per_rank = 107 // 3 = 35, rounded = 35 // 10 * 10 = 30 @@ -500,26 +505,10 @@ def test_enforce_equal_batches_rounds_down_per_rank(self, make_distributed_sampl chunk_size=chunk_size, 
preload_nchunks=preload_nchunks, batch_size=batch_size, - enforce_equal_batches=True, - ) - indices = collect_indices(sampler, n_obs) - assert len(set(indices)) == 30 - - def test_enforce_equal_batches_false_uses_raw_split(self, make_distributed_sampler): - """enforce_equal_batches=False uses n_obs // world_size without rounding.""" - n_obs, world_size = 107, 3 - chunk_size, preload_nchunks, batch_size = 10, 1, 10 - # raw per_rank = 107 // 3 = 35 (not rounded to 30) - sampler = make_distributed_sampler( - rank=0, - world_size=world_size, - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, - batch_size=batch_size, - enforce_equal_batches=False, + enforce_equal_batches=enforce_equal_batches, ) indices = collect_indices(sampler, n_obs) - assert len(set(indices)) == 35 + assert len(set(indices)) == expected def test_n_iters_matches_actual_batch_count(self, make_distributed_sampler): """n_iters should match the actual number of yielded batches.""" From 7002aa93c91519b7d8ae20d267ed0f63b5f9d1e0 Mon Sep 17 00:00:00 2001 From: Felix Fischer Date: Wed, 25 Feb 2026 12:31:42 +0100 Subject: [PATCH 4/4] Add test to check reproducibility --- tests/test_sampler.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 0148f570..357574eb 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -510,6 +510,40 @@ def test_enforce_equal_batches_per_rank_count(self, make_distributed_sampler, en indices = collect_indices(sampler, n_obs) assert len(set(indices)) == expected + def test_batch_shuffle_is_reproducible_with_same_seed_rng(self, make_distributed_sampler): + """Test that batch shuffling is reproducible when passing in rngs with identical seeds.""" + n_obs, chunk_size, preload_nchunks, batch_size = 200, 10, 2, 5 + world_size = 4 + seed = 42 + + def collect_splits(sampler: ChunkSamplerDistributed) -> list[list[int]]: + all_splits: list[list[int]] = [] + for load_request in 
sampler.sample(n_obs):
+                for split in load_request["splits"]:
+                    all_splits.append(split.tolist())
+            return all_splits
+
+        splits_per_run: list[dict[int, list[list[int]]]] = []
+        for _ in range(3):  # test 3 runs to ensure reproducibility
+            splits_by_rank: dict[int, list[list[int]]] = {}
+            for rank in range(world_size):
+                sampler = make_distributed_sampler(
+                    rank=rank,
+                    world_size=world_size,
+                    chunk_size=chunk_size,
+                    preload_nchunks=preload_nchunks,
+                    batch_size=batch_size,
+                    shuffle=True,
+                    rng=np.random.default_rng(seed),
+                )
+                splits_by_rank[rank] = collect_splits(sampler)
+            splits_per_run.append(splits_by_rank)
+
+        for rank in range(world_size):
+            assert all(run[rank] == splits_per_run[0][rank] for run in splits_per_run[1:]), (
+                f"Rank {rank}: batch shuffling should be reproducible with same seed"
+            )
+
     def test_n_iters_matches_actual_batch_count(self, make_distributed_sampler):
         """n_iters should match the actual number of yielded batches."""
         n_obs, world_size = 205, 3
         chunk_size, preload_nchunks, batch_size = 10, 2, 10