Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/annbatch/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from importlib.util import find_spec


def get_rank_and_world_size() -> tuple[int, int]:
    """Return the current ``(rank, world_size)`` pair.

    Falls back to ``(0, 1)`` when torch is not installed, or when
    ``torch.distributed`` is unavailable or has not been initialized.
    """
    if find_spec("torch") is not None:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
    return 0, 1


131 changes: 119 additions & 12 deletions src/annbatch/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import math
import warnings
from collections import OrderedDict, defaultdict
from functools import singledispatchmethod
from importlib.util import find_spec
Expand All @@ -16,6 +17,7 @@
from scipy import sparse as sp
from zarr import Array as ZarrArray

from annbatch.distributed import get_rank_and_world_size
from annbatch.types import BackingArray_T, InputInMemoryArray_T, OutputInMemoryArray_T
from annbatch.utils import (
CSRContainer,
Expand All @@ -35,8 +37,15 @@
CupyCSRMatrix = NoneType
CupyArray = NoneType
try:
from torch.utils.data import IterableDataset as _IterableDataset
except ImportError:
with warnings.catch_warnings():
# pytest config treats warnings as errors; torch import may emit a FutureWarning in some envs
warnings.filterwarnings(
"ignore",
message=r"The pynvml package is deprecated.*",
category=FutureWarning,
)
from torch.utils.data import IterableDataset as _IterableDataset
except Exception:

class _IterableDataset:
pass
Expand Down Expand Up @@ -129,13 +138,23 @@ class Loader[BackingArray: BackingArray_T, InputInMemoryArray: InputInMemoryArra
_worker_handle: WorkerHandle
_chunk_size: int
_dataset_elem_cache: dict[int, CSRDatasetElems]
_distributed: bool
_shuffle_seed: int
_epoch: int
_drop_last_indices: bool
_pad_indices: bool
_n_obs_effective: int | None

def __init__(
self,
*,
chunk_size: int = 512,
preload_nchunks: int = 32,
shuffle: bool = True,
shuffle_seed: int = 0,
distributed: bool = True,
drop_last_indices: bool = False,
pad_indices: bool = True,
return_index: bool = False,
batch_size: int = 1,
preload_to_gpu: bool = find_spec("cupy") is not None,
Expand Down Expand Up @@ -165,11 +184,21 @@ def __init__(
self._chunk_size = chunk_size
self._preload_nchunks = preload_nchunks
self._shuffle = shuffle
self._shuffle_seed = shuffle_seed
self._distributed = distributed
self._drop_last_indices = drop_last_indices
self._pad_indices = pad_indices
self._epoch = 0
self._n_obs_effective = None
self._worker_handle = WorkerHandle()
self._train_datasets = []
self._shapes = []
self._dataset_elem_cache = {}

def set_epoch(self, epoch: int) -> None:
    """Set epoch for deterministic shuffling (recommended for distributed training).

    The epoch is added to the shuffle seed when drawing the chunk
    permutation, so ranks/workers that call this with the same value see
    the same ordering for that epoch while still reshuffling across epochs.
    """
    self._epoch = epoch

def __len__(self) -> int:
    """Return the total number of observations (``self.n_obs``).

    NOTE(review): in distributed mode the number of samples actually
    yielded per rank is derived from ``_n_obs_effective`` (tail dropped
    or padded), which can differ from ``n_obs`` — confirm callers of
    ``len()`` expect the global, un-sharded count.
    """
    return self.n_obs

Expand Down Expand Up @@ -421,9 +450,59 @@ def _get_chunks(self, chunk_size: int) -> np.ndarray:
-------
A :class:`numpy.ndarray` of chunk ids.
"""
chunks = np.arange(math.ceil(self.n_obs / chunk_size))
rank, world_size = (0, 1)
if self._distributed:
rank, world_size = get_rank_and_world_size()

if world_size < 1:
raise ValueError(f"Expected world_size >= 1 but got {world_size}")
if not (0 <= rank < world_size):
raise ValueError(f"Expected rank in [0, {world_size}) but got {rank}")

n_obs = self.n_obs

# NOTE: Minimal distributed implementation shards work at the chunk-id level.
# TODO: Changing DataLoader num_workers changes the interleaving of yielded batches (not worker-count invariant).
if world_size > 1:
if (not self._drop_last_indices) and (not self._pad_indices):
raise ValueError(
"When distributed, set either drop_last_indices=True (drop tail for even ranks) "
"or pad_indices=True (wrap-around padding for even ranks)."
)

chunk_multiple = chunk_size * world_size
if self._drop_last_indices:
# Drop tail so each rank gets exactly the same number of samples and we never need to wrap.
n_obs_effective = (n_obs // chunk_multiple) * chunk_multiple
else:
# Pad (by wrap-around at iteration time) so each rank gets exactly the same number of samples.
n_obs_effective = math.ceil(n_obs / chunk_multiple) * chunk_multiple
else:
n_obs_effective = n_obs

self._n_obs_effective = n_obs_effective

if n_obs_effective == 0:
return np.array([], dtype=int)

n_chunks = math.ceil(n_obs_effective / chunk_size)
chunks = np.arange(n_chunks)

if self._shuffle:
self._worker_handle.shuffle(chunks)
# Need to handle shuffle
rng = np.random.default_rng(self._shuffle_seed + self._epoch)
rng.shuffle(chunks)

if world_size > 1:
# In even mode, n_obs_effective is a multiple of (chunk_size * world_size),
# so n_chunks must be divisible by world_size. Keep this as a sanity check.
if (n_chunks % world_size) != 0:
raise ValueError(
"Internal invariant violated: expected n_chunks divisible by world_size "
f"but got {n_chunks} % {world_size} != 0."
)
per_rank = n_chunks // world_size
chunks = chunks[rank * per_rank : (rank + 1) * per_rank]

return self._worker_handle.get_part_for_worker(chunks)

Expand Down Expand Up @@ -605,14 +684,40 @@ def __iter__(
in_memory_labels = None
in_memory_indices = None
mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np

def _wrap_slice(s: slice) -> list[slice]:
    """Map a slice in the (possibly padded) global index space into [0, n_obs).

    Returns a list of 0, 1, or 2 concrete slices over the real data. In
    padded mode, out-of-range coordinates wrap around to the start of the
    dataset; otherwise the slice is truncated at ``n_obs``.

    NOTE: assumes slice length <= n_obs (i.e. chunk_size <= n_obs) so a
    wrapped slice crosses the boundary at most once — TODO confirm.
    """
    start = 0 if s.start is None else int(s.start)
    stop = 0 if s.stop is None else int(s.stop)
    if self.n_obs == 0 or stop <= start:
        return []
    if not self._pad_indices:
        # Truncation mode: drop anything at or past the end of the data.
        if start >= self.n_obs:
            return []
        return [slice(start, min(stop, self.n_obs))]
    if stop <= self.n_obs:
        # Entirely within the real data — no wrapping needed.
        return [slice(start, stop)]
    # Padded mode: wrap BOTH endpoints back into [0, n_obs). A chunk can lie
    # entirely in the padded region (start >= n_obs) when the pad exceeds one
    # chunk_size (world_size > 1 with a small remainder); previously such
    # chunks were dropped, which broke the even-per-rank guarantee.
    start_mod = start % self.n_obs
    stop_mod = stop % self.n_obs
    if start_mod < stop_mod:
        return [slice(start_mod, stop_mod)]
    return [slice(start_mod, self.n_obs), slice(0, stop_mod)]

rank, _ = (0, 1)
if self._distributed:
rank, _ = get_rank_and_world_size()
worker_id = 0
if (wi := self._worker_handle._worker_info) is not None: # noqa: SLF001
worker_id = wi.id

for chunk_indices in _batched(self._get_chunks(self._chunk_size), self._preload_nchunks):
slices = [
slice(
index * self._chunk_size,
min(self.n_obs, (index + 1) * self._chunk_size),
)
for index in chunk_indices
]
n_obs_effective = self._n_obs_effective if self._n_obs_effective is not None else self.n_obs
slices: list[slice] = []
for index in chunk_indices:
start = index * self._chunk_size
stop = min(n_obs_effective, (index + 1) * self._chunk_size)
slices.extend([ss for ss in _wrap_slice(slice(start, stop)) if ss.start is not None and ss.stop is not None and ss.start < ss.stop])
dataset_index_to_slices = self._slices_to_slices_with_array_index(slices)
# Fetch the data over slices
chunks: list[InputInMemoryArray] = zsync.sync(self._index_datasets(dataset_index_to_slices))
Expand Down Expand Up @@ -675,7 +780,9 @@ def __iter__(
# save it for the next iteration.
batch_indices = np.arange(in_memory_data.shape[0])
if self._shuffle:
np.random.default_rng().shuffle(batch_indices)
# Deterministic per-epoch shuffling for a fixed (rank, num_workers).
rng = np.random.default_rng(self._shuffle_seed + self._epoch + (rank * 10_000) + worker_id)
rng.shuffle(batch_indices)
splits = split_given_size(batch_indices, self._batch_size)
for i, s in enumerate(splits):
if s.shape[0] == self._batch_size:
Expand Down
124 changes: 124 additions & 0 deletions tests/test_distributed_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
import zarr

from annbatch import Loader


def _dense_datasets_from_store(store_path: Path) -> list[zarr.Array]:
    """Collect the ``X`` array from every ``*.zarr`` group under *store_path*, in sorted order."""
    zarr_paths = sorted(store_path.glob("*.zarr"))
    return [zarr.open(zarr_path)["X"] for zarr_path in zarr_paths]


@pytest.mark.parametrize("world_size", [2, 3])
def test_distributed_no_overlap_drop_last_indices(
    monkeypatch,
    adata_with_zarr_path_same_var_space: tuple[object, Path],
    world_size: int,
):
    """Each rank yields a disjoint index set, with the tail dropped at chunk granularity."""
    # chunk_size chosen so that we drop some tail data at chunk granularity
    chunk_size = 7

    datasets = _dense_datasets_from_store(adata_with_zarr_path_same_var_space[1])

    indices_by_rank: dict[int, np.ndarray] = {}
    for current_rank in range(world_size):
        # Bind current_rank as a default arg so each patch returns its own rank.
        monkeypatch.setattr(
            "annbatch.loader.get_rank_and_world_size",
            lambda r=current_rank: (r, world_size),
        )

        loader = Loader(
            shuffle=False,
            chunk_size=chunk_size,
            preload_nchunks=1,
            batch_size=1,
            return_index=True,
            preload_to_gpu=False,
            to_torch=False,
            distributed=True,
            drop_last_indices=True,
            pad_indices=False,
        )
        loader.add_datasets(datasets)
        indices_by_rank[current_rank] = np.concatenate([batch_idx for _, _, batch_idx in loader]).ravel()

    # No rank's indices overlap with any other rank's.
    seen: set[int] = set()
    for per_rank in indices_by_rank.values():
        as_set = set(per_rank.tolist())
        assert not (seen & as_set)
        seen |= as_set

    # Dropped tail at chunk granularity -> total yielded is a multiple of (chunk_size * world_size)
    total_yielded = sum(len(v) for v in indices_by_rank.values())
    assert total_yielded % (chunk_size * world_size) == 0


def test_distributed_padding_repeats_when_enabled(monkeypatch, adata_with_zarr_path_same_var_space: tuple[object, Path]):
    """With ``pad_indices=True``, the total yield across ranks is a multiple of
    ``chunk_size * world_size`` and every observation appears at least once."""
    world_size = 2
    chunk_size = 7
    preload_nchunks = 1
    batch_size = 1

    store_path = adata_with_zarr_path_same_var_space[1]
    datasets = _dense_datasets_from_store(store_path)

    all_indices = []
    for rank in range(world_size):
        # Bind `rank` as a default arg so each iteration's patch returns its own rank.
        monkeypatch.setattr("annbatch.loader.get_rank_and_world_size", lambda r=rank: (r, world_size))

        ds = Loader(
            shuffle=False,
            chunk_size=chunk_size,
            preload_nchunks=preload_nchunks,
            batch_size=batch_size,
            return_index=True,
            preload_to_gpu=False,
            to_torch=False,
            distributed=True,
            drop_last_indices=False,
            pad_indices=True,
        )
        ds.add_datasets(datasets)
        all_indices.append(np.concatenate([idx for _, _, idx in ds]).ravel())

    concatenated = np.concatenate(all_indices)
    assert len(concatenated) % (chunk_size * world_size) == 0
    # Padding implies repeats (since underlying dataset is finite)
    # NOTE(review): this assertion is vacuously true — unique count can never
    # exceed total count, so it cannot fail; consider a strict inequality when
    # the fixture size is not already a multiple of chunk_size * world_size.
    assert len(np.unique(concatenated)) <= len(concatenated)
    # Padding must not drop data: every observation appears at least once.
    assert len(np.unique(concatenated)) == sum(d.shape[0] for d in datasets)


def test_distributed_deterministic_for_fixed_rank(monkeypatch, adata_with_zarr_path_same_var_space: tuple[object, Path]):
    """Two passes over the same loader at a fixed (rank, epoch, seed) yield identical indices."""
    world_size = 2
    rank = 0
    monkeypatch.setattr("annbatch.loader.get_rank_and_world_size", lambda: (rank, world_size))

    datasets = _dense_datasets_from_store(adata_with_zarr_path_same_var_space[1])

    loader = Loader(
        shuffle=True,
        shuffle_seed=123,
        chunk_size=7,
        preload_nchunks=2,
        batch_size=3,
        return_index=True,
        preload_to_gpu=False,
        to_torch=False,
        distributed=True,
        drop_last_indices=True,
        pad_indices=False,
    )
    loader.add_datasets(datasets)
    loader.set_epoch(0)

    passes = [np.concatenate([idx for _, _, idx in loader]).ravel() for _ in range(2)]
    assert np.array_equal(passes[0], passes[1])


Loading