diff --git a/CHANGELOG.md b/CHANGELOG.md index e58f69bb..831e622b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,15 @@ and this project adheres to [Semantic Versioning][]. ## [0.0.9] +### Breaking +- Renamed `annbatch.Loader.add_anndatas` to {meth}`annbatch.Loader.add_adatas`. +- Renamed `annbatch.Loader.add_anndata` to {meth}`annbatch.Loader.add_adata`. + +### Fixed +- Formatted progress bar descriptions to be more readable. - {class}`annbatch.DatasetCollection` now accepts a `rng` argument to the {meth}`annbatch.DatasetCollection.add_adatas` method. + ## [0.0.8] - {class}`~annbatch.Loader` acccepts an `rng` argument now @@ -39,7 +46,7 @@ and this project adheres to [Semantic Versioning][]. ## [0.0.4] -- Load into memory nullables/categoricals from `obs` by default when shuffling (i.e., no custom `load_adata` argument to {meth}`annbatch.DatasetCollection.add_adatas`) +- Load into memory nullables/categoricals from `obs` by default when shuffling (i.e., no custom `load_adata` argument to `annbatch.DatasetCollection.add_adatas`) ## [0.0.3] @@ -50,12 +57,13 @@ and this project adheres to [Semantic Versioning][]. ## [0.0.2] + ### Breaking - `ZarrSparseDataset` and `ZarrDenseDataset` have been conslidated into {class}`annbatch.Loader` -- `create_anndata_collection` and `add_to_collection` have been moved into the {meth}`annbatch.DatasetCollection.add_adatas` method -- Default reading of input data is now fully lazy in {meth}`annbatch.DatasetCollection.add_adatas`, and therefore the shuffle process may now be slower although have better memory properties. Use `load_adata` argument in {meth}`annbatch.DatasetCollection.add_adatas` to customize this behavior. -- Files shuffled under the old `create_anndata_collection` will not be recognized by {class}`annbatch.DatasetCollection` and therefore are not usable with the new {class}`annbatch.Loader.use_collection` API. 
At the moment, the file metadata we maintain is only for internal purposes - however, if you wish to migrate to be able to use {class}`annbatch.DatasetCollection` in conjunction with {class}`annbatch.Loader.use_collection`, the root folder of the old collection must have attrs `{"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"}` and be a {class}`zarr.Group`. The subfolders (i.e., datasets) must be called `dataset_([0-9]*)`. Otherwise you can use the {meth}`annbatch.Loader.add_anndatas` as before. +- `create_anndata_collection` and `add_to_collection` have been moved into the `annbatch.DatasetCollection.add_adatas` method +- Default reading of input data is now fully lazy in `annbatch.DatasetCollection.add_adatas`, and therefore the shuffle process may now be slower although have better memory properties. Use `load_adata` argument in `annbatch.DatasetCollection.add_adatas` to customize this behavior. +- Files shuffled under the old `create_anndata_collection` will not be recognized by {class}`annbatch.DatasetCollection` and therefore are not usable with the new {class}`annbatch.Loader.use_collection` API. At the moment, the file metadata we maintain is only for internal purposes - however, if you wish to migrate to be able to use {class}`annbatch.DatasetCollection` in conjunction with {class}`annbatch.Loader.use_collection`, the root folder of the old collection must have attrs `{"encoding-type": "annbatch-preshuffled", "encoding-version": "0.1.0"}` and be a {class}`zarr.Group`. The subfolders (i.e., datasets) must be called `dataset_([0-9]*)`. Otherwise you can use the `annbatch.DatasetCollection.add_adatas` as before. ### Changed diff --git a/README.md b/README.md index 3df65708..e7052e12 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ zarr.config.set( # Create a collection at the given path. The subgroups will all be anndata stores. 
collection = DatasetCollection("path/to/output/collection.zarr") -collection.add_adatas( +collection.add_adatas( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" @@ -98,7 +98,7 @@ collection.add_adatas( Data loading: > [!IMPORTANT] -> Without custom loading via {meth}`annbatch.Loader.use_collection` or `load_anndata{s}` or `load_dataset{s}`, *all* columns of the (obs) {class}`pandas.DataFrame` will be loaded and yielded potentially degrading performance. +> Without custom loading via {meth}`annbatch.Loader.use_collection` or `load_adata{s}` or `load_dataset{s}`, *all* columns of the (obs) {class}`pandas.DataFrame` will be loaded and yielded potentially degrading performance. ```python from pathlib import Path diff --git a/docs/index.md b/docs/index.md index f9ee017e..dd8b4a89 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ Let's go through the above example: ### Preprocessing ```python -colleciton = DatasetCollection("path/to/output/store.zarr").add_adatas( +colleciton = DatasetCollection("path/to/output/store.zarr").add_adatas( adata_paths=[ "path/to/your/file1.h5ad", "path/to/your/file2.h5ad" diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index f7085129..7bf3c822 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -133,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "tags": [ "hide-output" ] }, @@ -178,7 +178,7 @@ "\n", "\n", "# For CELLxGENE data, the raw counts can either be found under .raw.X or under .X (if .raw is not supplied).\n", - "# To have a store that only contains raw counts, we can write the following load_adata function\n", + "# To have a store that only contains raw counts, we can write the following `load_adata` function\n", "def read_lazy_x_and_obs_only(path) -> ad.AnnData:\n", " \"\"\"Custom load function to only load raw counts from CxG data.\"\"\"\n", " # IMPORTANT: Large data should always be loaded lazily to 
reduce the memory footprint\n", @@ -227,7 +227,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "tags": [ "hide-output" ] }, @@ -328,7 +328,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "tags": [ "hide-output" ] }, @@ -381,7 +381,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "annbatch", "language": "python", "name": "python3" }, @@ -395,7 +395,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.12.12" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index b9679925..1114211d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,9 @@ run.patch = [ "subprocess" ] run.source = [ "annbatch" ] [tool.mypy] -overrides = [ { module = [ "anndata.*", "cupyx.*", "cupy.*" ], ignore_missing_imports = true } ] +[[tool.mypy.overrides]] +module = [ "anndata.*", "cupyx.*", "cupy.*", "torch.*", "h5py.*" ] +ignore_missing_imports = true [tool.cruft] skip = [ diff --git a/src/annbatch/io.py b/src/annbatch/io.py index f8501f32..d5d131f6 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -103,7 +103,7 @@ def callback( iospec: ad.experimental.IOSpec, ): # Ensure we're not overriding anything here - dataset_kwargs = dataset_kwargs.copy() + dataset_kwargs = dict(dataset_kwargs) if iospec.encoding_type in {"array"} and ( any(n in store.name for n in {"obsm", "layers", "obsp"}) or "X" == elem_name ): @@ -135,7 +135,7 @@ def callback( def _check_for_mismatched_keys[T: zarr.Group | h5py.Group | PathLike[str] | str]( - paths_or_anndatas: Iterable[T | ad.AnnData], + paths_or_adata: Iterable[T | ad.AnnData], *, load_adata: Callable[[T], ad.AnnData] = lambda x: ad.experimental.read_lazy(x, load_annotation_index=False), ): @@ -145,7 +145,7 @@ def _check_for_mismatched_keys[T: zarr.Group | h5py.Group | PathLike[str] | str] "obsm": defaultdict(lambda: 0), 
"obs": defaultdict(lambda: 0), } - for path_or_anndata in tqdm(paths_or_anndatas, desc="checking for mismatched keys"): + for path_or_anndata in tqdm(paths_or_adata, desc="Checking for mismatched keys"): if not isinstance(path_or_anndata, ad.AnnData): adata = load_adata(path_or_anndata) else: @@ -157,31 +157,31 @@ def _check_for_mismatched_keys[T: zarr.Group | h5py.Group | PathLike[str] | str] key_count[key] += 1 if adata.raw is not None: num_raw_in_adata += 1 - if num_raw_in_adata != (num_anndatas := len(list(paths_or_anndatas))) and num_raw_in_adata != 0: + if num_raw_in_adata != (num_anndatas := len(list(paths_or_adata))) and num_raw_in_adata != 0: warnings.warn( - f"Found raw keys not present in all anndatas {paths_or_anndatas}, consider deleting raw or moving it to a shared layer/X location via `load_adata`", + f"Found raw keys not present in all anndatas {paths_or_adata}, consider deleting raw or moving it to a shared layer/X location via `load_adata`", stacklevel=2, ) for elem_name, key_count in found_keys.items(): elem_keys_mismatched = [key for key, count in key_count.items() if (count != num_anndatas and count != 0)] if len(elem_keys_mismatched) > 0: warnings.warn( - f"Found {elem_name} keys {elem_keys_mismatched} not present in all anndatas {paths_or_anndatas}, consider stopping and using the `load_adata` argument to alter {elem_name} accordingly.", + f"Found {elem_name} keys {elem_keys_mismatched} not present in all anndatas {paths_or_adata}, consider stopping and using the `load_adata` argument to alter {elem_name} accordingly.", stacklevel=2, ) -def _lazy_load_anndatas[T: zarr.Group | h5py.Group | PathLike[str] | str]( +def _lazy_load_adata[T: zarr.Group | h5py.Group | PathLike[str] | str]( paths: Iterable[T], load_adata: Callable[[T], ad.AnnData] = _default_load_adata, ): adatas = [] categoricals_in_all_adatas: dict[str, pd.Index] = {} - for i, path in tqdm(enumerate(paths), desc="loading"): + for i, path in tqdm(enumerate(paths), total=len(paths), 
desc="Lazy loading adata"): adata = load_adata(path) # Track the source file for this given anndata object adata.obs["src_path"] = pd.Categorical.from_codes( - np.ones((adata.shape[0],), dtype="int") * i, categories=[str(p) for p in paths] + np.ones((adata.shape[0],), dtype="int") * i, categories=pd.Index([str(p) for p in paths]) ) # Concatenating Dataset2D drops categoricals so we need to track them if isinstance(adata.obs, Dataset2D): @@ -239,11 +239,13 @@ def _create_chunks_for_shuffling( if use_single_chunking: return [np.concatenate(idxs)] # unfortunately, this is the only way to prevent numpy.split from trying to np.array the idxs list, which can have uneven elements. - idxs = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs]) + idxs_as_slices = np.array([slice(int(idx[0]), int(idx[-1] + 1)) for idx in idxs]) return [ np.concatenate([np.arange(s.start, s.stop) for s in idx]) for idx in ( - split_given_size(idxs, n_slices_per_dataset) if n_chunkings is None else np.array_split(idxs, n_chunkings) + split_given_size(idxs_as_slices, n_slices_per_dataset) + if n_chunkings is None + else np.array_split(idxs_as_slices, n_chunkings) ) ] @@ -385,7 +387,7 @@ def __iter__(self) -> Generator[zarr.Group]: @property def is_empty(self) -> bool: - """Wether or not there is an existing store at the group location.""" + """Whether or not there is an existing store at the group location.""" return ( (not (V1_ENCODING.items() <= self._group.attrs.items()) or len(self._dataset_keys) == 0) if isinstance(self._group, zarr.Group) @@ -571,7 +573,7 @@ def _create_collection( if not self.is_empty: raise RuntimeError("Cannot create a collection at a location that already has a shuffled collection") _check_for_mismatched_keys(adata_paths, load_adata=load_adata) - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata) adata_concat.obs_names_make_unique() n_obs_per_dataset = 
min(adata_concat.shape[0], n_obs_per_dataset) chunks = _create_chunks_for_shuffling( @@ -584,7 +586,7 @@ def _create_collection( if var_subset is None: var_subset = adata_concat.var_names - for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")): + for i, chunk in enumerate(tqdm(chunks, desc="Creating collection")): var_mask = adata_concat.var_names.isin(var_subset) # np.sort: It's more efficient to access elements sequentially from dask arrays # The data will be shuffled later on, we just want the elements at this point @@ -663,11 +665,11 @@ def _add_to_collection( Whether or not to shuffle when adding. Otherwise, the incoming data will just be split up and appended. """ if self.is_empty: - raise ValueError("Store is empty. Please run `DatasetCollection.add` first.") + raise ValueError("Store is empty. Please run `DatasetCollection.add_adatas` first.") # Check for mismatched keys among the inputs. _check_for_mismatched_keys(adata_paths, load_adata=load_adata) - adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata) + adata_concat = _lazy_load_adata(adata_paths, load_adata=load_adata) if math.ceil(adata_concat.shape[0] / shuffle_chunk_size) < len(self._dataset_keys): raise ValueError( f"Use a shuffle size small enough to distribute the input data with {adata_concat.shape[0]} obs across {len(self._dataset_keys)} anndata stores." 
@@ -685,7 +687,7 @@ def _add_to_collection( adata_concat.obs_names_make_unique() for dataset, chunk in tqdm( - zip(self._dataset_keys, chunks, strict=True), total=len(self._dataset_keys), desc="processing chunks" + zip(self._dataset_keys, chunks, strict=True), total=len(self._dataset_keys), desc="Extending collection" ): adata_dataset = ad.io.read_elem(self._group[dataset]) subset_adata = _to_categorical_obs( diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 973a47d9..0e969c74 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -131,7 +131,7 @@ class Loader[ batch_size=4096, chunk_size=32, preload_nchunks=512, - ).add_anndata(my_anndata) + ).add_adata(my_anndata) >>> for batch in ds: # optionally convert to dense # batch = batch.to_dense() @@ -161,6 +161,7 @@ class Loader[ _batch_sampler: Sampler _concat_strategy: None | concat_strategies = None _dataset_intervals: pd.IntervalIndex | None = None + _collection_added: bool = False def __init__( self, @@ -295,16 +296,19 @@ def batch_sampler(self) -> Sampler: return self._batch_sampler def use_collection( - self, collection: DatasetCollection, *, load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs_and_var + self, + collection: DatasetCollection, + *, + load_adata: Callable[[zarr.Group], ad.AnnData] = load_x_and_obs_and_var, ) -> Self: """Load from an existing :class:`annbatch.DatasetCollection`. - This function can only be called once. If you want to manually add more data, use :meth:`Loader.add_anndatas` or open an issue. + This function can only be called once. If you want to manually add more data, use :meth:`Loader.add_adatas` or open an issue. Parameters ---------- collection - The collection who on-disk datasets should be used in this loader. + The collection whose on-disk datasets should be used in this loader. load_adata A custom load function - recall that whatever is found in :attr:`~anndata.AnnData.X` and :attr:`~anndata.AnnData.obs` will be yielded in batches. 
Default is to just load `X` and all of `obs`. @@ -312,35 +316,35 @@ def use_collection( """ if collection.is_empty: raise ValueError("DatasetCollection is empty") - if getattr(self, "_collection_added", False): + if self._collection_added: raise RuntimeError( - "You should not add multiple collections, independently shuffled - please preshuffle multiple collections, use `add_anndatas` manually if you know what you are doing, or open an issue if you believe that this should be supported at an API level higher than `add_anndatas`." + "You should not add multiple collections, independently shuffled - please preshuffle multiple collections, use `add_adatas` manually if you know what you are doing, or open an issue if you believe that this should be supported at an API level higher than `add_adatas`." ) adatas = [load_adata(g) for g in collection] - self.add_anndatas(adatas) + self.add_adatas(adatas) self._collection_added = True return self @validate_sampler - def add_anndatas( + def add_adatas( self, adatas: list[ad.AnnData], ) -> Self: - """Append anndatas to this dataset. + """Append adatas to this dataset. Parameters ---------- adatas List of :class:`anndata.AnnData` objects, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. """ - check_lt_1([len(adatas)], ["Number of anndatas"]) + check_lt_1([len(adatas)], ["Number of adatas"]) for adata in adatas: dataset, obs, var = self._prepare_dataset_obs_and_var(adata) self._add_dataset_unchecked(dataset, obs, var) return self - def add_anndata(self, adata: ad.AnnData) -> Self: - """Append an anndata to this dataset. + def add_adata(self, adata: ad.AnnData) -> Self: + """Append an adata to this dataset. 
Parameters ---------- diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 8a80924a..18684717 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -552,16 +552,16 @@ def test_mismatched_var_raises_error(tmp_path: Path, subtests): var=adata2.var, ) - with subtests.test(msg="add_anndata"): + with subtests.test(msg="add_adata"): loader = Loader(chunk_size=10, preload_nchunks=4, batch_size=20) - loader.add_anndata(adata1_on_disk) + loader.add_adata(adata1_on_disk) with pytest.raises(ValueError, match="All datasets must have identical var DataFrames"): - loader.add_anndata(adata2_on_disk) + loader.add_adata(adata2_on_disk) - with subtests.test(msg="add_anndatas"): + with subtests.test(msg="add_adatas"): loader = Loader(chunk_size=10, preload_nchunks=4, batch_size=20) with pytest.raises(ValueError, match="All datasets must have identical var DataFrames"): - loader.add_anndatas([adata1_on_disk, adata2_on_disk]) + loader.add_adatas([adata1_on_disk, adata2_on_disk]) with subtests.test(msg="add_dataset"): loader = Loader(chunk_size=10, preload_nchunks=4, batch_size=20) @@ -585,7 +585,7 @@ def test_preload_dtype(tmp_path: Path, dtype_in: np.dtype, expected: np.dtype): z = zarr.open(tmp_path / "foo.zarr") write_sharded(z, ad.AnnData(X=sp.random(100, 10, dtype=dtype_in, format="csr", rng=np.random.default_rng()))) adata = ad.AnnData(X=ad.io.sparse_dataset(z["X"])) - loader = Loader(preload_to_gpu=True, batch_size=10, chunk_size=10, preload_nchunks=2, to_torch=False).add_anndata( + loader = Loader(preload_to_gpu=True, batch_size=10, chunk_size=10, preload_nchunks=2, to_torch=False).add_adata( adata ) assert next(iter(loader))["X"].dtype == expected