35 commits
- 5232e12 typo (selmanozleyen, Feb 11, 2026)
- 5d14e5c _collection_added' defined outside (selmanozleyen, Feb 11, 2026)
- 2558b40 consistent naming with add_anndatas (selmanozleyen, Feb 11, 2026)
- 36af588 ruff format (selmanozleyen, Feb 11, 2026)
- d841e13 typo2 (selmanozleyen, Feb 11, 2026)
- 304dcbb adapt add_anndatas change to tests (selmanozleyen, Feb 11, 2026)
- 4958749 add torch and h5py to mypy ignore_missing_imports (selmanozleyen, Feb 11, 2026)
- 830f2d4 fix Mapping.copy() call in write_sharded callback (selmanozleyen, Feb 11, 2026)
- 6d6067a wrap categories in pd.Index for Categorical.from_codes (selmanozleyen, Feb 11, 2026)
- 0f5aa1d add asserts for match/case narrowing and rename idxs variable (selmanozleyen, Feb 11, 2026)
- 12830d5 is none == is none works better with mypy (selmanozleyen, Feb 11, 2026)
- a60774c other add_anndatas renames + changelog (selmanozleyen, Feb 12, 2026)
- 9c495b3 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 12, 2026)
- 38f408a Merge branch 'main' into fix/typos-n-cleanup (selmanozleyen, Feb 12, 2026)
- 3ff60d2 Merge branch 'main' into fix/typos-n-cleanup (selmanozleyen, Feb 13, 2026)
- 7f33e78 Revert "is none == is none works better with mypy" (selmanozleyen, Feb 13, 2026)
- 462ca48 update changelogs (selmanozleyen, Feb 13, 2026)
- 400ee88 updatechangelog again (selmanozleyen, Feb 13, 2026)
- fea4e70 update changelog (selmanozleyen, Feb 13, 2026)
- 425bf6e Merge branch 'main' into fix/typos-n-cleanup (selmanozleyen, Feb 17, 2026)
- ac14eb9 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 17, 2026)
- 43909a2 Merge branch 'main' into fix/typos-n-cleanup (selmanozleyen, Feb 24, 2026)
- 8d23871 Update src/annbatch/io.py (selmanozleyen, Feb 24, 2026)
- 0129d8 anndata_paths (selmanozleyen, Feb 24, 2026)
- d8c9577 load_adata to load_anndata (selmanozleyen, Feb 24, 2026)
- 50293fb fix mistake (selmanozleyen, Feb 24, 2026)
- a485103 rename from anndata to adata (selmanozleyen, Feb 25, 2026)
- b1af3b6 update changelog (selmanozleyen, Feb 25, 2026)
- 8716f94 Apply suggestions from code review (ilan-gold, Feb 26, 2026)
- c885049 Apply suggestion from @ilan-gold (ilan-gold, Feb 26, 2026)
- bab856e Merge branch 'main' into fix/typos-n-cleanup (selmanozleyen, Mar 3, 2026)
- 4305a8b fix after merge conflict (selmanozleyen, Mar 3, 2026)
- 33061e6 undo dataset collection changes (selmanozleyen, Mar 3, 2026)
- 9da30ca conftest (selmanozleyen, Mar 3, 2026)
- 5098d5c Update src/annbatch/io.py (ilan-gold, Mar 3, 2026)
16 changes: 12 additions & 4 deletions CHANGELOG.md
@@ -10,8 +10,15 @@ and this project adheres to [Semantic Versioning][].

## [0.0.9]

### Breaking
- Renamed `annbatch.Loader.add_anndatas` to {meth}`annbatch.Loader.add_adatas`.
- Renamed `annbatch.Loader.add_anndata` to {meth}`annbatch.Loader.add_adata`.

### Fixed
- Formatted progress bar descriptions to be more readable.
- {class}`annbatch.DatasetCollection` now accepts a `rng` argument to the {meth}`annbatch.DatasetCollection.add_adatas` method.


## [0.0.8]

- {class}`~annbatch.Loader` now accepts an `rng` argument
@@ -39,7 +46,7 @@ and this project adheres to [Semantic Versioning][].

## [0.0.4]

- Load into memory nullables/categoricals from `obs` by default when shuffling (i.e., no custom `load_adata` argument to {meth}`annbatch.DatasetCollection.add_adatas`)
- Load into memory nullables/categoricals from `obs` by default when shuffling (i.e., no custom `load_adata` argument to `annbatch.DatasetCollection.add_adatas`)

## [0.0.3]

@@ -50,12 +57,13 @@ and this project adheres to [Semantic Versioning][].

## [0.0.2]


### Breaking

- `ZarrSparseDataset` and `ZarrDenseDataset` have been consolidated into {class}`annbatch.Loader`
- `create_anndata_collection` and `add_to_collection` have been moved into the {meth}`annbatch.DatasetCollection.add_adatas` method
- Default reading of input data is now fully lazy in {meth}`annbatch.DatasetCollection.add_adatas`, and therefore the shuffle process may now be slower although have better memory properties. Use `load_adata` argument in {meth}`annbatch.DatasetCollection.add_adatas` to customize this behavior.
- Files shuffled under the old `create_anndata_collection` will not be recognized by {class}`annbatch.DatasetCollection` and therefore are not usable with the new {class}`annbatch.Loader.use_collection` API. At the moment, the file metadata we maintain is only for internal purposes - however, if you wish to migrate to be able to use {class}`annbatch.DatasetCollection` in conjunction with {class}`annbatch.Loader.use_collection`, the root folder of the old collection must have attrs `{"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"}` and be a {class}`zarr.Group`. The subfolders (i.e., datasets) must be called `dataset_([0-9]*)`. Otherwise you can use the {meth}`annbatch.Loader.add_anndatas` as before.
- `create_anndata_collection` and `add_to_collection` have been moved into the `annbatch.DatasetCollection.add_adatas` method
- Default reading of input data is now fully lazy in `annbatch.DatasetCollection.add_adatas`, and therefore the shuffle process may now be slower although have better memory properties. Use `load_adata` argument in `annbatch.DatasetCollection.add_adatas` to customize this behavior.
- Files shuffled under the old `create_anndata_collection` will not be recognized by {class}`annbatch.DatasetCollection` and therefore are not usable with the new {class}`annbatch.Loader.use_collection` API. At the moment, the file metadata we maintain is only for internal purposes - however, if you wish to migrate to be able to use {class}`annbatch.DatasetCollection` in conjunction with {class}`annbatch.Loader.use_collection`, the root folder of the old collection must have attrs `{"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"}` and be a {class}`zarr.Group`. The subfolders (i.e., datasets) must be called `dataset_([0-9]*)`. Otherwise you can use the `annbatch.DatasetCollection.add_adatas` as before.
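  The migration requirements in the bullet above can be sketched as a small check; the attrs dict is taken from the changelog note, while the folder names below are hypothetical examples (a real migration would also confirm the root is a valid `zarr.Group`):

  ```python
  import re

  # Attrs the old collection's root group must carry after migration
  # (as stated in the changelog note above).
  REQUIRED_ATTRS = {
      "encoding-type": "annbatch-preshuffled",
      "encoding-version": "0.1.0",
  }

  # Hypothetical subfolder names found under an old collection root.
  names = ["dataset_0", "dataset_12", "raw", "dataset_x"]

  # Only subfolders matching dataset_([0-9]*) are recognized as datasets.
  pattern = re.compile(r"dataset_([0-9]*)")
  datasets = [n for n in names if pattern.fullmatch(n)]
  assert datasets == ["dataset_0", "dataset_12"]
  ```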

### Changed

4 changes: 2 additions & 2 deletions README.md
@@ -86,7 +86,7 @@ zarr.config.set(

# Create a collection at the given path. The subgroups will all be anndata stores.
collection = DatasetCollection("path/to/output/collection.zarr")
collection.add_adatas(
collection.add_adata(
adata_paths=[
"path/to/your/file1.h5ad",
"path/to/your/file2.h5ad"
@@ -98,7 +98,7 @@ collection.add_adatas(
Data loading:

> [!IMPORTANT]
> Without custom loading via {meth}`annbatch.Loader.use_collection` or `load_anndata{s}` or `load_dataset{s}`, *all* columns of the (obs) {class}`pandas.DataFrame` will be loaded and yielded potentially degrading performance.
> Without custom loading via {meth}`annbatch.Loader.use_collection` or `load_adata{s}` or `load_dataset{s}`, *all* columns of the (obs) {class}`pandas.DataFrame` will be loaded and yielded potentially degrading performance.

```python
from pathlib import Path
2 changes: 1 addition & 1 deletion docs/index.md
@@ -9,7 +9,7 @@ Let's go through the above example:
### Preprocessing

```python
colleciton = DatasetCollection("path/to/output/store.zarr").add_adatas(
colleciton = DatasetCollection("path/to/output/store.zarr").add_adata(
adata_paths=[
"path/to/your/file1.h5ad",
"path/to/your/file2.h5ad"
12 changes: 6 additions & 6 deletions docs/notebooks/example.ipynb
@@ -133,7 +133,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {
"tags": [
"hide-output"
@@ -178,7 +178,7 @@
"\n",
"\n",
"# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n",
"# To have a store that only contains raw counts, we can write the following load_adata function\n",
"# To have a store that only contains raw counts, we can write the following `load_adata` function\n",
"def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n",
" \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n",
" # IMPORTANT: Large data should always be loaded lazily to reduce the memory footprint\n",
@@ -227,7 +227,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {
"tags": [
"hide-output"
@@ -328,7 +328,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {
"tags": [
"hide-output"
@@ -381,7 +381,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": "annbatch",
"language": "python",
"name": "python3"
},
@@ -395,7 +395,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.12"
}
},
"nbformat": 4,
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -168,7 +168,8 @@ run.patch = [ "subprocess" ]
run.source = [ "annbatch" ]

[tool.mypy]
overrides = [ { module = [ "anndata.*", "cupyx.*", "cupy.*" ], ignore_missing_imports = true } ]
[[tool.mypy.overrides]]
overrides = [ { module = [ "anndata.*", "cupyx.*", "cupy.*", "torch.*", "h5py.*" ], ignore_missing_imports = true } ]

[tool.cruft]
skip = [
36 changes: 19 additions & 17 deletions src/annbatch/io.py
@@ -103,7 +103,7 @@ def callback(
iospec: ad.experimental.IOSpec,
):
# Ensure we're not overriding anything here
dataset_kwargs = dataset_kwargs.copy()
dataset_kwargs = dict(dataset_kwargs)
if iospec.encoding_type in {"array"} and (
any(n in store.name for n in {"obsm", "layers", "obsp"}) or "X" == elem_name
):
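The `dict(dataset_kwargs)` change above works because the abstract `Mapping` interface does not provide a `.copy()` method; a minimal stdlib illustration (the key names are made up):

```python
from collections.abc import Mapping
from types import MappingProxyType

# The Mapping ABC guarantees read access but not a .copy() method.
assert not hasattr(Mapping, "copy")

kwargs: Mapping[str, int] = MappingProxyType({"chunk_size": 64})
# dict() accepts any Mapping and returns a fresh, mutable shallow copy,
# so downstream mutation can never leak back to the caller.
local = dict(kwargs)
local["shard_size"] = 8
assert "shard_size" not in kwargs
```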
@@ -135,7 +135,7 @@


def _check_for_mismatched_keys[T: zarr.Group | h5py.Group | PathLike[str] | str](
paths_or_anndatas: Iterable[T | ad.AnnData],
paths_or_adata: Iterable[T | ad.AnnData],
*,
load_adata: Callable[[T], ad.AnnData] = lambda x: ad.experimental.read_lazy(x, load_annotation_index=False),
):
@@ -145,7 +145,7 @@ def _check_for_mismatched_keys[T: zarr.Group | h5py.Group | PathLike[str] | str]
"obsm": defaultdict(lambda: 0),
"obs": defaultdict(lambda: 0),
}
for path_or_anndata in tqdm(paths_or_anndatas, desc="checking for mismatched keys"):
for path_or_anndata in tqdm(paths_or_adata, desc="Checking for mismatched keys"):
if not isinstance(path_or_anndata, ad.AnnData):
adata = load_adata(path_or_anndata)
else:
@@ -157,31 +157,31 @@
key_count[key] += 1
if adata.raw is not None:
num_raw_in_adata += 1
if num_raw_in_adata != (num_anndatas := len(list(paths_or_anndatas))) and num_raw_in_adata != 0:
if num_raw_in_adata != (num_anndatas := len(list(paths_or_adata))) and num_raw_in_adata != 0:
warnings.warn(
f"Found raw keys not present in all anndatas {paths_or_anndatas}, consider deleting raw or moving it to a shared layer/X location via `load_adata`",
f"Found raw keys not present in all anndatas {paths_or_adata}, consider deleting raw or moving it to a shared layer/X location via `load_adata`",
stacklevel=2,
)
for elem_name, key_count in found_keys.items():
elem_keys_mismatched = [key for key, count in key_count.items() if (count != num_anndatas and count != 0)]
if len(elem_keys_mismatched) > 0:
warnings.warn(
f"Found {elem_name} keys {elem_keys_mismatched} not present in all anndatas {paths_or_anndatas}, consider stopping and using the `load_adata` argument to alter {elem_name} accordingly.",
f"Found {elem_name} keys {elem_keys_mismatched} not present in all anndatas {paths_or_adata}, consider stopping and using the `load_adata` argument to alter {elem_name} accordingly.",
stacklevel=2,
)
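The key-counting logic in `_check_for_mismatched_keys` reduces to a small pattern; a stdlib-only sketch with hypothetical per-file `layers` keys:

```python
from collections import defaultdict

# Hypothetical layers keys found in each input file.
per_file_keys = [["counts"], ["counts", "spliced"], ["counts"]]

key_count: dict[str, int] = defaultdict(int)
for keys in per_file_keys:
    for key in keys:
        key_count[key] += 1

# A key present in some files but not all of them is mismatched.
n_files = len(per_file_keys)
mismatched = [k for k, n in key_count.items() if n != n_files]
assert mismatched == ["spliced"]
```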


def _lazy_load_anndatas[T: zarr.Group | h5py.Group | PathLike[str] | str](
def _lazy_load_adata[T: zarr.Group | h5py.Group | PathLike[str] | str](
paths: Iterable[T],
load_adata: Callable[[T], ad.AnnData] = _default_load_adata,
):
adatas = []
categoricals_in_all_adatas: dict[str, pd.Index] = {}
for i, path in tqdm(enumerate(paths), desc="loading"):
for i, path in tqdm(enumerate(paths), total=len(paths), desc="Lazy loading adata"):
adata = load_adata(path)
# Track the source file for this given anndata object
adata.obs["src_path"] = pd.Categorical.from_codes(
np.ones((adata.shape[0],), dtype="int") * i, categories=[str(p) for p in paths]
np.ones((adata.shape[0],), dtype="int") * i, categories=pd.Index([str(p) for p in paths])
)
# Concatenating Dataset2D drops categoricals so we need to track them
if isinstance(adata.obs, Dataset2D):
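The `pd.Index` wrapper added above matches the expected type of the `categories` parameter of `Categorical.from_codes`; a small sketch of the source-tracking column (the paths are hypothetical):

```python
import numpy as np
import pandas as pd

paths = ["file1.h5ad", "file2.h5ad"]
i, n_obs = 0, 3  # first file, three observations

# Every row gets code i, decoded against the full list of source paths.
src = pd.Categorical.from_codes(
    np.ones((n_obs,), dtype="int") * i,
    categories=pd.Index([str(p) for p in paths]),
)
assert list(src) == ["file1.h5ad"] * 3
```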
@@ -239,11 +239,13 @@ def _create_chunks_for_shuffling(
if use_single_chunking:
return [np.concatenate(idxs)]
# unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements.
idxs = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs])
idxs_as_slices = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs])
return [
np.concatenate([np.arange(s.start, s.stop) for s in idx])
for idx in (
split_given_size(idxs, n_slices_per_dataset) if n_chunkings is None else np.array_split(idxs, n_chunkings)
split_given_size(idxs_as_slices, n_slices_per_dataset)
if n_chunkings is None
else np.array_split(idxs_as_slices, n_chunkings)
)
]
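
The object-array-of-slices trick in the rewritten chunking code can be demonstrated in isolation; run lengths here are made up:

```python
import numpy as np

# Index runs of uneven length, as the shuffler produces them.
idxs = [np.arange(0, 3), np.arange(3, 7)]

# np.array over slice objects yields a 1-D object array, so the split
# functions cannot try (and fail) to stack the uneven runs into a matrix.
idxs_as_slices = np.array([slice(int(a[0]), int(a[-1] + 1)) for a in idxs])

chunks = [
    np.concatenate([np.arange(s.start, s.stop) for s in part])
    for part in np.array_split(idxs_as_slices, 2)
]
assert [c.tolist() for c in chunks] == [[0, 1, 2], [3, 4, 5, 6]]
```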

@@ -385,7 +387,7 @@ def __iter__(self) -> Generator[zarr.Group]:

@property
def is_empty(self) -> bool:
"""Wether or not there is an existing store at the group location."""
"""Whether or not there is an existing store at the group location."""
return (
(not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0)
if isinstance(self._group, zarr.Group)
@@ -571,7 +573,7 @@ def _create_collection(
if not self.is_empty:
raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection")
_check_for_mismatched_keys(adata_paths, load_adata=load_adata)
adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata)
adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata)
adata_concat.obs_names_make_unique()
n_obs_per_dataset = min(adata_concat.shape[0], n_obs_per_dataset)
chunks = _create_chunks_for_shuffling(
@@ -584,7 +586,7 @@

if var_subset is None:
var_subset = adata_concat.var_names
for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")):
for i, chunk in enumerate(tqdm(chunks, desc="Creating collection")):
var_mask = adata_concat.var_names.isin(var_subset)
# np.sort: It's more efficient to access elements sequentially from dask arrays
# The data will be shuffled later on, we just want the elements at this point
@@ -663,11 +665,11 @@ def _add_to_collection(
Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended.
"""
if self.is_empty:
raise ValueError("Store is empty. Please run `DatasetCollection.add` first.")
raise ValueError("Store is empty. Please run `DatasetCollection.add_adatas` first.")
# Check for mismatched keys among the inputs.
_check_for_mismatched_keys(adata_paths, load_adata=load_adata)

adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata)
adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata)
if math.ceil(adata_concat.shape[0] / shuffle_chunk_size) < len(self._dataset_keys):
raise ValueError(
f"Use a shuffle size small enough to distribute the input data with {adata_concat.shape[0]} obs across {len(self._dataset_keys)} anndata stores."
@@ -685,7 +687,7 @@

adata_concat.obs_names_make_unique()
for dataset, chunk in tqdm(
zip(self._dataset_keys, chunks, strict=True), total=len(self._dataset_keys), desc="processing chunks"
zip(self._dataset_keys, chunks, strict=True), total=len(self._dataset_keys), desc="Extending collection"
):
adata_dataset = ad.io.read_elem(self._group[dataset])
subset_adata = _to_categorical_obs(
28 changes: 16 additions & 12 deletions src/annbatch/loader.py
@@ -131,7 +131,7 @@ class Loader[
batch_size=4096,
chunk_size=32,
preload_nchunks=512,
).add_anndata(my_anndata)
).add_adata(my_anndata)
>>> for batch in ds:
# optionally convert to dense
# batch = batch.to_dense()
@@ -161,6 +161,7 @@ class Loader[
_batch_sampler: Sampler
_concat_strategy: None | concat_strategies = None
_dataset_intervals: pd.IntervalIndex | None = None
_collection_added: bool = False

def __init__(
self,
@@ -295,52 +296,55 @@ def batch_sampler(self) -> Sampler:
return self._batch_sampler

def use_collection(
self, collection: DatasetCollection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs_and_var
self,
collection: DatasetCollection,
*,
load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs_and_var,
) -> Self:
"""Load from an existing :class:`annbatch.DatasetCollection`.

This function can only be called once. If you want to manually add more data, use :meth:`Loader.add_anndatas` or open an issue.
This function can only be called once. If you want to manually add more data, use :meth:`Loader.add_adatas` or open an issue.

Parameters
----------
collection
The collection who on-disk datasets should be used in this loader.
The collection whose on-disk datasets should be used in this loader.
load_adata
A custom load function - recall that whatever is found in :attr:`~anndata.AnnData.X` and :attr:`~anndata.AnnData.obs` will be yielded in batches.
Default is to just load `X` and all of `obs`.
This default behavior can degrade performance if you don't need all columns in `obs` - it is recommended to use the `load_adata` argument.
"""
if collection.is_empty:
raise ValueError("DatasetCollection is empty")
if getattr(self, "_collection_added", False):
if self._collection_added:
raise RuntimeError(
"You should not add multiple collections, independently shuffled - please preshuffle multiple collections, use `add_anndatas` manually if you know what you are doing, or open an issue if you believe that this should be supported at an API level higher than `add_anndatas`."
"You should not add multiple collections, independently shuffled - please preshuffle multiple collections, use `add_adatas` manually if you know what you are doing, or open an issue if you believe that this should be supported at an API level higher than `add_adatas`."
)
adatas = [load_adata(g) for g in collection]
self.add_anndatas(adatas)
self.add_adatas(adatas)
self._collection_added = True
return self

@validate_sampler
def add_anndatas(
def add_adatas(
self,
adatas: list[ad.AnnData],
) -> Self:
"""Append anndatas to this dataset.
"""Append adatas to this dataset.

Parameters
----------
adatas
List of :class:`anndata.AnnData` objects, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`.
"""
check_lt_1([len(adatas)], ["Number of anndatas"])
check_lt_1([len(adatas)], ["Number of adatas"])
for adata in adatas:
dataset, obs, var = self._prepare_dataset_obs_and_var(adata)
self._add_dataset_unchecked(dataset, obs, var)
return self

def add_anndata(self, adata: ad.AnnData) -> Self:
"""Append an anndata to this dataset.
def add_adata(self, adata: ad.AnnData) -> Self:
"""Append an adata to this dataset.

Parameters
----------
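The `_collection_added` guard introduced in loader.py relies on a class-level default, so the flag reads as `False` even before `use_collection` has ever run; a stripped-down sketch (the class and message are illustrative, not the real `Loader`):

```python
class MiniLoader:
    # Class-level default: instances that never called use_collection
    # read False here, with no need for getattr(self, ..., False).
    _collection_added: bool = False

    def use_collection(self, collection: object) -> "MiniLoader":
        if self._collection_added:
            raise RuntimeError("a collection was already added")
        # Assignment creates an instance attribute shadowing the default.
        self._collection_added = True
        return self

loader = MiniLoader().use_collection(object())
try:
    loader.use_collection(object())
    raise AssertionError("second call should have raised")
except RuntimeError:
    pass
```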
12 changes: 6 additions & 6 deletions tests/test_dataset.py
@@ -552,16 +552,16 @@ def test_mismatched_var_raises_error(tmp_path: Path, subtests):
var=adata2.var,
)

with subtests.test(msg="add_anndata"):
with subtests.test(msg="add_adata"):
loader = Loader(chunk_size=10, preload_nchunks=4, batch_size=20)
loader.add_anndata(adata1_on_disk)
loader.add_adata(adata1_on_disk)
with pytest.raises(ValueError, match="All datasets must have identical var DataFrames"):
loader.add_anndata(adata2_on_disk)
loader.add_adata(adata2_on_disk)

with subtests.test(msg="add_anndatas"):
with subtests.test(msg="add_adatas"):
loader = Loader(chunk_size=10, preload_nchunks=4, batch_size=20)
with pytest.raises(ValueError, match="All datasets must have identical var DataFrames"):
loader.add_anndatas([adata1_on_disk, adata2_on_disk])
loader.add_adatas([adata1_on_disk, adata2_on_disk])

with subtests.test(msg="add_dataset"):
loader = Loader(chunk_size=10, preload_nchunks=4, batch_size=20)
@@ -585,7 +585,7 @@ def test_preload_dtype(tmp_path: Path, dtype_in: np.dtype, expected: np.dtype):
z = zarr.open(tmp_path / "foo.zarr")
write_sharded(z, ad.AnnData(X=sp.random(100, 10, dtype=dtype_in, format="csr", rng=np.random.default_rng())))
adata = ad.AnnData(X=ad.io.sparse_dataset(z["X"]))
loader = Loader(preload_to_gpu=True, batch_size=10, chunk_size=10, preload_nchunks=2, to_torch=False).add_anndata(
loader = Loader(preload_to_gpu=True, batch_size=10, chunk_size=10, preload_nchunks=2, to_torch=False).add_adata(
adata
)
assert next(iter(loader))["X"].dtype == expected