From 214c4d7d4c726b1ee8327035bf6e5868a266642d Mon Sep 17 00:00:00 2001 From: gpalla Date: Tue, 16 Dec 2025 18:20:01 -0800 Subject: [PATCH] distributed dataloader --- src/annbatch/distributed.py | 17 ++++ src/annbatch/loader.py | 131 ++++++++++++++++++++++++++++--- tests/test_distributed_loader.py | 124 +++++++++++++++++++++++++++++ 3 files changed, 260 insertions(+), 12 deletions(-) create mode 100644 src/annbatch/distributed.py create mode 100644 tests/test_distributed_loader.py diff --git a/src/annbatch/distributed.py b/src/annbatch/distributed.py new file mode 100644 index 00000000..0baa8847 --- /dev/null +++ b/src/annbatch/distributed.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from importlib.util import find_spec + + +def get_rank_and_world_size() -> tuple[int, int]: + """Return (rank, world_size) if torch.distributed is initialized, else (0, 1).""" + if find_spec("torch") is None: + return 0, 1 + + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + return dist.get_rank(), dist.get_world_size() + return 0, 1 + + diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 0e572336..c8867c0e 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -2,6 +2,7 @@ import asyncio import math +import warnings from collections import OrderedDict, defaultdict from functools import singledispatchmethod from importlib.util import find_spec @@ -16,6 +17,7 @@ from scipy import sparse as sp from zarr import Array as ZarrArray +from annbatch.distributed import get_rank_and_world_size from annbatch.types import BackingArray_T, InputInMemoryArray_T, OutputInMemoryArray_T from annbatch.utils import ( CSRContainer, @@ -35,8 +37,15 @@ CupyCSRMatrix = NoneType CupyArray = NoneType try: - from torch.utils.data import IterableDataset as _IterableDataset -except ImportError: + with warnings.catch_warnings(): + # pytest config treats warnings as errors; torch import may emit a FutureWarning in some envs + warnings.filterwarnings( + "ignore", + message=r"The pynvml package is deprecated.*", + category=FutureWarning, + ) + from torch.utils.data import IterableDataset as _IterableDataset +except Exception: class _IterableDataset: pass @@ -129,6 +138,12 @@ class Loader[BackingArray: BackingArray_T, InputInMemoryArray: InputInMemoryArra _worker_handle: WorkerHandle _chunk_size: int _dataset_elem_cache: dict[int, CSRDatasetElems] + _distributed: bool + _shuffle_seed: int + _epoch: int + _drop_last_indices: bool + _pad_indices: bool + _n_obs_effective: int | None def __init__( self, @@ -136,6 +151,10 @@ def __init__( chunk_size: int = 512, preload_nchunks: int = 32, shuffle: bool = True, + shuffle_seed: int = 0, + distributed: bool = True, + drop_last_indices: bool = False, + pad_indices: bool = True, return_index: bool = False, batch_size: int = 1, preload_to_gpu: bool = find_spec("cupy") is not None, @@ -165,11 +184,21 @@ def __init__( self._chunk_size = chunk_size self._preload_nchunks = preload_nchunks self._shuffle = shuffle + self._shuffle_seed = shuffle_seed + self._distributed = distributed + self._drop_last_indices = drop_last_indices + self._pad_indices = pad_indices + self._epoch = 0 + self._n_obs_effective = None self._worker_handle = WorkerHandle() self._train_datasets = [] self._shapes = [] self._dataset_elem_cache = {} + def set_epoch(self, epoch: int) -> None: + """Set epoch for deterministic shuffling (recommended for distributed training).""" + self._epoch = epoch + def __len__(self) -> int: return self.n_obs @@ -421,9 +450,59 @@ def _get_chunks(self, chunk_size: int) -> np.ndarray: ------- A :class:`numpy.ndarray` of chunk ids. """ - chunks = np.arange(math.ceil(self.n_obs / chunk_size)) + rank, world_size = (0, 1) + if self._distributed: + rank, world_size = get_rank_and_world_size() + + if world_size < 1: + raise ValueError(f"Expected world_size >= 1 but got {world_size}") + if not (0 <= rank < world_size): + raise ValueError(f"Expected rank in [0, {world_size}) but got {rank}") + + n_obs = self.n_obs + + # NOTE: Minimal distributed implementation shards work at the chunk-id level. + # TODO: Changing DataLoader num_workers changes the interleaving of yielded batches (not worker-count invariant). + if world_size > 1: + if (not self._drop_last_indices) and (not self._pad_indices): + raise ValueError( + "When distributed, set either drop_last_indices=True (drop tail for even ranks) " + "or pad_indices=True (wrap-around padding for even ranks)." + ) + + chunk_multiple = chunk_size * world_size + if self._drop_last_indices: + # Drop tail so each rank gets exactly the same number of samples and we never need to wrap. + n_obs_effective = (n_obs // chunk_multiple) * chunk_multiple + else: + # Pad (by wrap-around at iteration time) so each rank gets exactly the same number of samples. + n_obs_effective = math.ceil(n_obs / chunk_multiple) * chunk_multiple + else: + n_obs_effective = n_obs + + self._n_obs_effective = n_obs_effective + + if n_obs_effective == 0: + return np.array([], dtype=int) + + n_chunks = math.ceil(n_obs_effective / chunk_size) + chunks = np.arange(n_chunks) + if self._shuffle: - self._worker_handle.shuffle(chunks) + # Need to handle shuffle + rng = np.random.default_rng(self._shuffle_seed + self._epoch) + rng.shuffle(chunks) + + if world_size > 1: + # In even mode, n_obs_effective is a multiple of (chunk_size * world_size), + # so n_chunks must be divisible by world_size. Keep this as a sanity check. + if (n_chunks % world_size) != 0: + raise ValueError( + "Internal invariant violated: expected n_chunks divisible by world_size " + f"but got {n_chunks} % {world_size} != 0." + ) + per_rank = n_chunks // world_size + chunks = chunks[rank * per_rank : (rank + 1) * per_rank] return self._worker_handle.get_part_for_worker(chunks) @@ -605,14 +684,40 @@ def __iter__( in_memory_labels = None in_memory_indices = None mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np + + def _wrap_slice(s: slice) -> list[slice]: + # Map a potentially out-of-range slice in the (possibly padded) global space back into [0, n_obs). + # For padding, we wrap around to the start of the dataset. + start = 0 if s.start is None else int(s.start) + stop = 0 if s.stop is None else int(s.stop) + if start >= self.n_obs: + return [] + if stop <= self.n_obs: + return [slice(start, stop)] + if not self._pad_indices: + return [slice(start, min(stop, self.n_obs))] + if self.n_obs == 0: + return [] + start_mod = start % self.n_obs + stop_mod = stop % self.n_obs + if start_mod < stop_mod: + return [slice(start_mod, stop_mod)] + return [slice(start_mod, self.n_obs), slice(0, stop_mod)] + + rank, _ = (0, 1) + if self._distributed: + rank, _ = get_rank_and_world_size() + worker_id = 0 + if (wi := self._worker_handle._worker_info) is not None: # noqa: SLF001 + worker_id = wi.id + for chunk_indices in _batched(self._get_chunks(self._chunk_size), self._preload_nchunks): - slices = [ - slice( - index * self._chunk_size, - min(self.n_obs, (index + 1) * self._chunk_size), - ) - for index in chunk_indices - ] + n_obs_effective = self._n_obs_effective if self._n_obs_effective is not None else self.n_obs + slices: list[slice] = [] + for index in chunk_indices: + start = index * self._chunk_size + stop = min(n_obs_effective, (index + 1) * self._chunk_size) + slices.extend([ss for ss in _wrap_slice(slice(start, stop)) if ss.start is not None and ss.stop is not None and ss.start < ss.stop]) dataset_index_to_slices = self._slices_to_slices_with_array_index(slices) # Fetch the data over slices chunks: list[InputInMemoryArray] = zsync.sync(self._index_datasets(dataset_index_to_slices)) @@ -675,7 +780,9 @@ def __iter__( # save it for the next iteration. batch_indices = np.arange(in_memory_data.shape[0]) if self._shuffle: - np.random.default_rng().shuffle(batch_indices) + # Deterministic per-epoch shuffling for a fixed (rank, num_workers). + rng = np.random.default_rng(self._shuffle_seed + self._epoch + (rank * 10_000) + worker_id) + rng.shuffle(batch_indices) splits = split_given_size(batch_indices, self._batch_size) for i, s in enumerate(splits): if s.shape[0] == self._batch_size: diff --git a/tests/test_distributed_loader.py b/tests/test_distributed_loader.py new file mode 100644 index 00000000..24fbf2d9 --- /dev/null +++ b/tests/test_distributed_loader.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +import zarr + +from annbatch import Loader + + +def _dense_datasets_from_store(store_path: Path) -> list[zarr.Array]: + return [zarr.open(p)["X"] for p in sorted(store_path.glob("*.zarr"))] + + +@pytest.mark.parametrize("world_size", [2, 3]) +def test_distributed_no_overlap_drop_last_indices( + monkeypatch, + adata_with_zarr_path_same_var_space: tuple[object, Path], + world_size: int, +): + # chunk_size chosen so that we drop some tail data at chunk granularity + chunk_size = 7 + preload_nchunks = 1 + batch_size = 1 + + store_path = adata_with_zarr_path_same_var_space[1] + datasets = _dense_datasets_from_store(store_path) + + per_rank_indices: list[np.ndarray] = [] + for rank in range(world_size): + monkeypatch.setattr("annbatch.loader.get_rank_and_world_size", lambda r=rank: (r, world_size)) + + ds = Loader( + shuffle=False, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + return_index=True, + preload_to_gpu=False, + to_torch=False, + distributed=True, + drop_last_indices=True, + pad_indices=False, + ) + ds.add_datasets(datasets) + + idxs = np.concatenate([idx for _, _, idx in ds]).ravel() + per_rank_indices.append(idxs) + + # No overlap between ranks + for i in range(world_size): + for j in range(i + 1, world_size): + assert set(per_rank_indices[i]).isdisjoint(set(per_rank_indices[j])) + + # Dropped tail at chunk granularity -> total yielded is a multiple of (chunk_size * world_size) + total = sum(len(v) for v in per_rank_indices) + assert total % (chunk_size * world_size) == 0 + + +def test_distributed_padding_repeats_when_enabled(monkeypatch, adata_with_zarr_path_same_var_space: tuple[object, Path]): + world_size = 2 + chunk_size = 7 + preload_nchunks = 1 + batch_size = 1 + + store_path = adata_with_zarr_path_same_var_space[1] + datasets = _dense_datasets_from_store(store_path) + + all_indices = [] + for rank in range(world_size): + monkeypatch.setattr("annbatch.loader.get_rank_and_world_size", lambda r=rank: (r, world_size)) + + ds = Loader( + shuffle=False, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + batch_size=batch_size, + return_index=True, + preload_to_gpu=False, + to_torch=False, + distributed=True, + drop_last_indices=False, + pad_indices=True, + ) + ds.add_datasets(datasets) + all_indices.append(np.concatenate([idx for _, _, idx in ds]).ravel()) + + concatenated = np.concatenate(all_indices) + assert len(concatenated) % (chunk_size * world_size) == 0 + # Padding implies repeats (since underlying dataset is finite) + assert len(np.unique(concatenated)) <= len(concatenated) + assert len(np.unique(concatenated)) == sum(d.shape[0] for d in datasets) + + +def test_distributed_deterministic_for_fixed_rank(monkeypatch, adata_with_zarr_path_same_var_space: tuple[object, Path]): + world_size = 2 + rank = 0 + monkeypatch.setattr("annbatch.loader.get_rank_and_world_size", lambda: (rank, world_size)) + + store_path = adata_with_zarr_path_same_var_space[1] + datasets = _dense_datasets_from_store(store_path) + + ds = Loader( + shuffle=True, + shuffle_seed=123, + chunk_size=7, + preload_nchunks=2, + batch_size=3, + return_index=True, + preload_to_gpu=False, + to_torch=False, + distributed=True, + drop_last_indices=True, + pad_indices=False, + ) + ds.add_datasets(datasets) + ds.set_epoch(0) + + idxs1 = np.concatenate([idx for _, _, idx in ds]).ravel() + idxs2 = np.concatenate([idx for _, _, idx in ds]).ravel() + assert np.array_equal(idxs1, idxs2) + +