Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/annbatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from . import abc, types
from .io import DatasetCollection, write_sharded
from .loader import Loader
from .samplers._chunk_sampler import ChunkSampler
from .samplers._chunk_sampler import ChunkSampler, ChunkSamplerDistributed

__version__ = version("annbatch")

Expand All @@ -15,5 +15,6 @@
"types",
"write_sharded",
"ChunkSampler",
"ChunkSamplerDistributed",
"abc",
]
3 changes: 2 additions & 1 deletion src/annbatch/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._chunk_sampler import ChunkSampler
from ._chunk_sampler import ChunkSampler, ChunkSamplerDistributed

__all__ = [
"ChunkSampler",
"ChunkSamplerDistributed",
]
128 changes: 126 additions & 2 deletions src/annbatch/samplers/_chunk_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import math
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

import numpy as np

Expand All @@ -12,7 +12,7 @@
from annbatch.utils import _spawn_worker_rng, check_lt_1, split_given_size

if TYPE_CHECKING:
from collections.abc import Iterator
from collections.abc import Callable, Iterator

from annbatch.types import LoadRequest

Expand Down Expand Up @@ -184,3 +184,127 @@ def _compute_chunks(self, chunk_indices: np.ndarray, start: int, stop: int) -> l
offsets = np.cumsum(offsets)
starts, stops = offsets[:-1][chunk_indices], offsets[1:][chunk_indices]
return [slice(int(s), int(e)) for s, e in zip(starts, stops, strict=True)]


def _get_dist_info_torch() -> tuple[int, int]:
"""Get rank and world_size from ``torch.distributed``."""
Copy link
Member

@selmanozleyen selmanozleyen Feb 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think really don't need backend="jax" kind of interface here. We can get as an input a callable like

class DistInfo(TypedDict):
   process_index
   process_count

class ChunkedDistSampler

    def __init__(
          dist_info_fn:Callable()->DistInfo(TypedDict)
    ):
    pass

Why a callable? So that the they are called whenever the distributed system is spawned/launched.

In an example tutorial we can show anyone can inject anything that gives these. Or we can also implement convenience dist_info_fn's. Anyone who's knowledgeable enough to need this can use this so we can have a more mature interface. No need to add training wheels that will age bad here IMO

Even though it's a very strong opinion of mine it's still just a preference.

Copy link
Collaborator Author

@felix0097 felix0097 Feb 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point, I will add the option to provide a callable as well. Callable then just return a tuple (rank, world_size) then people can easily extend this to other frameworks.

I would keep the torch + jax option though. Most people who do distributed training are not really aware of those concepts. E.g lightning abstracts those away. So, it's easy for them to use it, they just need to know which framework their using.

import torch.distributed as dist

if not dist.is_initialized():
raise RuntimeError(
"torch.distributed is not initialized. "
"Initialize it before creating a ChunkSamplerDistributed with dist_info='torch'."
)
return dist.get_rank(), dist.get_world_size()


def _get_dist_info_jax() -> tuple[int, int]:
"""Get rank and world_size from JAX multi-process API."""
import jax

if not jax.distributed.is_initialized():
raise RuntimeError(
"JAX distributed is not initialized. "
"Call jax.distributed.initialize() before creating a ChunkSamplerDistributed with dist_info='jax'."
)
return jax.process_index(), jax.process_count()


# Registry of built-in distributed backends: maps a backend name to a
# zero-argument callable returning ``(rank, world_size)`` for this process.
DISTRIBUTED_BACKENDS: dict[str, Callable[[], tuple[int, int]]] = {
    "torch": _get_dist_info_torch,
    "jax": _get_dist_info_jax,
}


class ChunkSamplerDistributed(ChunkSampler):
    """Chunk-based sampler that shards observations across distributed processes.

    The full observation range is partitioned into ``world_size`` contiguous
    shards via the ``mask`` mechanism of :class:`ChunkSampler`; each rank is
    assigned a non-overlapping slice. Shard boundaries are computed lazily
    once ``n_obs`` becomes known.

    With ``enforce_equal_batches=True`` (the default) each rank's observation
    count is rounded down to the nearest multiple of ``batch_size``, so every
    rank yields exactly the same number of complete batches.

    Rank and world size are resolved from ``dist_info`` at construction time;
    the corresponding distributed framework must already be initialized.

    Parameters
    ----------
    chunk_size
        Size of each chunk i.e. the range of each chunk yielded.
    preload_nchunks
        Number of chunks to load per iteration.
    batch_size
        Number of observations per batch.
    dist_info
        How to obtain rank and world size.
        Either a string naming a distributed backend (``"torch"`` or ``"jax"``),
        or a callable that returns ``(rank, world_size)``.
    shuffle
        Whether to shuffle chunk and index order.
    drop_last
        Whether to drop the last incomplete batch.
    rng
        Random number generator for shuffling; a per-rank child generator is
        spawned from it so ranks do not draw identical batch orders.
    enforce_equal_batches
        If *True*, round each rank's observation count down to a multiple of
        ``batch_size`` so that all workers (ranks) yield the same number of
        batches. Set to *False* to use the raw ``n_obs // world_size`` split,
        which may result in an uneven number of batches per worker.
    """

    # Rank of this process within the distributed group.
    _rank: int
    # Total number of processes in the distributed group.
    _world_size: int
    # Whether per-rank counts are rounded to whole batches.
    _enforce_equal_batches: bool

    def __init__(
        self,
        chunk_size: int,
        preload_nchunks: int,
        batch_size: int,
        *,
        dist_info: Literal["torch", "jax"] | Callable[[], tuple[int, int]],
        shuffle: bool = False,
        drop_last: bool = False,
        rng: np.random.Generator | None = None,
        enforce_equal_batches: bool = True,
    ):
        # Resolve rank/world_size either from a user callable or a named backend.
        if callable(dist_info):
            rank, world_size = dist_info()
        elif dist_info in DISTRIBUTED_BACKENDS:
            rank, world_size = DISTRIBUTED_BACKENDS[dist_info]()
        else:
            raise ValueError(f"Unknown dist_info {dist_info!r}. Supported backends: {sorted(DISTRIBUTED_BACKENDS)}")
        self._rank, self._world_size = rank, world_size
        self._enforce_equal_batches = enforce_equal_batches

        # Spawn a per-rank RNG so each rank shuffles differently but reproducibly.
        per_rank_rng = _spawn_worker_rng(rng, self._rank) if rng else None
        super().__init__(
            chunk_size=chunk_size,
            preload_nchunks=preload_nchunks,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            rng=per_rank_rng,
        )

    def _shard_mask(self, n_obs: int) -> slice:
        """Return the contiguous observation slice for this rank."""
        per_rank = n_obs // self._world_size
        if self._enforce_equal_batches:
            # Round down to a whole number of batches so all ranks stay in sync.
            per_rank -= per_rank % self._batch_size
        start = self._rank * per_rank
        return slice(start, start + per_rank)

    def n_iters(self, n_obs: int) -> int:
        # Refresh the shard mask before delegating, since n_obs may change.
        self._mask = self._shard_mask(n_obs)
        return super().n_iters(n_obs)

    def validate(self, n_obs: int) -> None:
        # Refresh the shard mask before delegating, since n_obs may change.
        self._mask = self._shard_mask(n_obs)
        super().validate(n_obs)

    def _sample(self, n_obs: int) -> Iterator[LoadRequest]:
        # Refresh the shard mask so sampling only covers this rank's slice.
        self._mask = self._shard_mask(n_obs)
        yield from super()._sample(n_obs)
Comment on lines +308 to +310
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit hacky and requires the child to know too much about the inner logic of its parent's implementation details.

From @ilan-gold

This should be made generalizable. We can make an IsDistributable Protocol that the samplers inherit. That way anything that implements a `_mask` attribute can be handled this way.

To do this, we may need to consider making _mask a bit more general, like a list[slice] for things like categorical samplers

Maybe I'm misunderstanding though, it seems like this is just masking.

Also no reason to have a class-per-package - if everything is just based on rank/world, we can make this generic (i.e., users can pass in rank/world if torch or one of the recognized packages is not installed)

I mostly agree. That's why I wanted to name chunksampler a window/mask/slice sampler to be technically correct.

For categorical samplers we shouldn't need list[slice] because categorical sampler should be able to use multiple chunksampler's right?

I think list[slice] isn't trivial for ChunkSampler itself because _compute_chunks relies on assumptions of continuity. For example the case for the last chunk. If we want to generalize maybe we can have a base class like MaskListSampler which enforces drop_last=True (even though drop_last isn't a problem for dist training because each model there is different we would need this if we want to generalize).

Easiest change could be having an effective_mask attribute which is the identity self._mask by default and passed to _sampler(n_obs,mask)? and can be overridden in IsDistributables.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's put a pin into this and revisit if we see a need to update! For now, it's not really worth it (see comment from Ilan as well below)

197 changes: 195 additions & 2 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from __future__ import annotations

import math
from unittest.mock import patch
import sys
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from annbatch import ChunkSampler
from annbatch import ChunkSampler, ChunkSamplerDistributed
from annbatch.abc import Sampler
from annbatch.samplers._utils import WorkerInfo

Expand Down Expand Up @@ -369,3 +370,195 @@ def test_automatic_batching_respects_shuffle_flag(shuffle: bool):
assert all_indices != list(range(n_obs)), "Indices should be shuffled"
else:
assert all_indices == list(range(n_obs)), "Indices should be sequential"


# =============================================================================
# ChunkSamplerDistributed tests
# =============================================================================


def _make_distributed_sampler_torch(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed:
    """Create a ChunkSamplerDistributed with mocked torch.distributed backend."""
    fake_dist = MagicMock()
    fake_dist.is_initialized.return_value = True
    fake_dist.get_rank.return_value = rank
    fake_dist.get_world_size.return_value = world_size
    fake_torch = MagicMock(distributed=fake_dist)
    # Inject the fakes into sys.modules so the lazy in-function import sees them.
    with patch.dict(sys.modules, {"torch": fake_torch, "torch.distributed": fake_dist}):
        return ChunkSamplerDistributed(dist_info="torch", **kwargs)


def _make_distributed_sampler_jax(rank: int, world_size: int, **kwargs) -> ChunkSamplerDistributed:
    """Create a ChunkSamplerDistributed with mocked jax backend."""
    fake_jax = MagicMock()
    fake_jax.distributed.is_initialized.return_value = True
    fake_jax.process_index.return_value = rank
    fake_jax.process_count.return_value = world_size
    # Inject the fake into sys.modules so the lazy in-function import sees it.
    with patch.dict(sys.modules, {"jax": fake_jax}):
        return ChunkSamplerDistributed(dist_info="jax", **kwargs)


# Map backend name -> factory that builds a sampler with that backend mocked out.
_SAMPLER_FACTORIES = {
    "torch": _make_distributed_sampler_torch,
    "jax": _make_distributed_sampler_jax,
}


@pytest.fixture(params=["torch", "jax"])
def make_distributed_sampler(request):
    """Fixture that yields a sampler factory for each backend."""
    return _SAMPLER_FACTORIES[request.param]


class TestChunkSamplerDistributed:
    """Tests for ChunkSamplerDistributed, parameterized over all backends."""

    def test_not_initialized_raises_torch(self):
        """RuntimeError when torch.distributed is not initialized."""
        fake_dist = MagicMock()
        fake_dist.is_initialized.return_value = False
        fake_torch = MagicMock(distributed=fake_dist)
        with patch.dict(sys.modules, {"torch": fake_torch, "torch.distributed": fake_dist}):
            with pytest.raises(RuntimeError, match="torch.distributed is not initialized"):
                ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="torch")

    def test_not_initialized_raises_jax(self):
        """RuntimeError when jax.distributed is not initialized."""
        fake_jax = MagicMock()
        fake_jax.distributed.is_initialized.return_value = False
        with patch.dict(sys.modules, {"jax": fake_jax}):
            with pytest.raises(RuntimeError, match="JAX distributed is not initialized"):
                ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="jax")

    def test_unknown_dist_info_raises(self):
        """ValueError for an unsupported dist_info string."""
        with pytest.raises(ValueError, match="Unknown dist_info"):
            ChunkSamplerDistributed(chunk_size=10, preload_nchunks=2, batch_size=10, dist_info="mpi")

    def test_shards_are_disjoint_and_cover_full_dataset(self, make_distributed_sampler):
        """All ranks receive non-overlapping shards that together cover the full dataset."""
        n_obs, world_size = 200, 4

        per_rank_indices = [
            collect_indices(
                make_distributed_sampler(
                    rank=r,
                    world_size=world_size,
                    chunk_size=10,
                    preload_nchunks=2,
                    batch_size=10,
                ),
                n_obs,
            )
            for r in range(world_size)
        ]

        # Shards must be pairwise disjoint.
        for i in range(world_size):
            for j in range(i + 1, world_size):
                assert set(per_rank_indices[i]).isdisjoint(set(per_rank_indices[j]))

        # Together they cover the full dataset (evenly divisible case).
        assert set().union(*per_rank_indices) == set(range(n_obs))

    @pytest.mark.parametrize(
        "n_obs,world_size,batch_size,chunk_size,preload_nchunks",
        [
            pytest.param(200, 4, 10, 10, 2, id="evenly_divisible"),
            pytest.param(205, 3, 10, 10, 2, id="remainder_obs"),
            pytest.param(1000, 7, 5, 10, 2, id="prime_world_size"),
            pytest.param(100, 3, 5, 10, 2, id="small_dataset"),
        ],
    )
    def test_enforce_equal_batches_all_ranks_same_count(
        self, make_distributed_sampler, n_obs, world_size, batch_size, chunk_size, preload_nchunks
    ):
        """enforce_equal_batches=True guarantees identical batch counts across ranks."""

        def count_batches(rank: int) -> int:
            # Number of batches this rank actually yields over one epoch.
            sampler = make_distributed_sampler(
                rank=rank,
                world_size=world_size,
                chunk_size=chunk_size,
                preload_nchunks=preload_nchunks,
                batch_size=batch_size,
                enforce_equal_batches=True,
            )
            return sum(len(lr["splits"]) for lr in sampler.sample(n_obs))

        batch_counts = [count_batches(r) for r in range(world_size)]
        assert len(set(batch_counts)) == 1, f"Batch counts differ across ranks: {batch_counts}"

    @pytest.mark.parametrize(
        ("enforce_equal_batches", "expected"),
        [(True, 30), (False, 35)],
        ids=["rounded", "raw"],
    )
    def test_enforce_equal_batches_per_rank_count(self, make_distributed_sampler, enforce_equal_batches, expected):
        """enforce_equal_batches controls whether per_rank is rounded down to a multiple of batch_size."""
        # raw per_rank = 107 // 3 = 35; rounded = 35 // 10 * 10 = 30
        sampler = make_distributed_sampler(
            rank=0,
            world_size=3,
            chunk_size=10,
            preload_nchunks=1,
            batch_size=10,
            enforce_equal_batches=enforce_equal_batches,
        )
        assert len(set(collect_indices(sampler, 107))) == expected

    def test_batch_shuffle_is_reproducible_with_same_seed_rng(self, make_distributed_sampler):
        """Test that batch shuffling is reproducible when passing in rngs with identical seeds."""
        n_obs, world_size, seed = 200, 4, 42

        def collect_splits(sampler: ChunkSamplerDistributed) -> list[list[int]]:
            # Flatten every yielded split into a list of index lists.
            return [
                split.tolist()
                for load_request in sampler.sample(n_obs)
                for split in load_request["splits"]
            ]

        def run_all_ranks() -> dict[int, list[list[int]]]:
            # One full epoch of splits for every rank, each seeded identically.
            return {
                rank: collect_splits(
                    make_distributed_sampler(
                        rank=rank,
                        world_size=world_size,
                        chunk_size=10,
                        preload_nchunks=2,
                        batch_size=5,
                        shuffle=True,
                        rng=np.random.default_rng(seed),
                    )
                )
                for rank in range(world_size)
            }

        splits_per_run = [run_all_ranks() for _ in range(3)]  # 3 runs to ensure reproducibility
        for rank in range(world_size):
            assert splits_per_run[0][rank] == splits_per_run[1][rank], (
                f"Rank {rank}: batch shuffling should be reproducible with same seed"
            )

    def test_n_iters_matches_actual_batch_count(self, make_distributed_sampler):
        """n_iters should match the actual number of yielded batches."""
        n_obs, world_size = 205, 3

        for rank in range(world_size):
            sampler = make_distributed_sampler(
                rank=rank,
                world_size=world_size,
                chunk_size=10,
                preload_nchunks=2,
                batch_size=10,
                enforce_equal_batches=True,
                drop_last=True,
            )
            expected = sampler.n_iters(n_obs)
            actual = sum(len(lr["splits"]) for lr in sampler.sample(n_obs))
            assert actual == expected, f"rank {rank}: n_iters={expected}, actual={actual}"
Loading