diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1635bdd2a..cd1d60ade 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -53,13 +53,43 @@ jobs: fi fi uv sync --group=test + # Start storage emulators (S3, Azure, GCS) only on Linux; service containers are not available on Windows/macOS + - name: Build and start storage emulators + if: matrix.os == 'ubuntu-latest' + run: | + docker build -f tests/io/remote_storage/Dockerfile.emulators -t spatialdata-emulators . + docker run --rm -d --name spatialdata-emulators \ + -p 5000:5000 -p 10000:10000 -p 4443:4443 \ + spatialdata-emulators + - name: Wait for emulator ports + if: matrix.os == 'ubuntu-latest' + run: | + echo "Waiting for S3 (5000), Azure (10000), GCS (4443)..." + python3 -c " + import socket, time + for _ in range(45): + try: + for p in (5000, 10000, 4443): + socket.create_connection(('127.0.0.1', p), timeout=2) + print('Emulators ready.') + break + except (socket.error, OSError): + time.sleep(2) + else: + raise SystemExit('Emulators did not become ready.') + " + # On Linux, emulators run above so full suite (incl. tests/io/remote_storage/) runs. On Windows/macOS, skip remote_storage. 
- name: Test env: MPLBACKEND: agg PLATFORM: ${{ matrix.os }} DISPLAY: :42 run: | - uv run pytest --cov --color=yes --cov-report=xml + if [[ "${{ matrix.os }}" == "ubuntu-latest" ]]; then + uv run pytest --cov --color=yes --cov-report=xml + else + uv run pytest --cov --color=yes --cov-report=xml --ignore=tests/io/remote_storage/ + fi - name: Upload coverage to Codecov uses: codecov/codecov-action@v5 with: diff --git a/pyproject.toml b/pyproject.toml index e5f3134aa..cce73720b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,9 @@ dev = [ "bump2version", ] test = [ + "adlfs", + "gcsfs", + "moto[server]", "pytest", "pytest-cov", "pytest-mock", diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 739b225fe..810713d45 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -121,7 +121,7 @@ def __init__( tables: dict[str, AnnData] | Tables | None = None, attrs: Mapping[Any, Any] | None = None, ) -> None: - self._path: Path | None = None + self._path: Path | UPath | None = None self._shared_keys: set[str | None] = set() self._images: Images = Images(shared_keys=self._shared_keys) @@ -548,16 +548,16 @@ def is_backed(self) -> bool: return self.path is not None @property - def path(self) -> Path | None: + def path(self) -> Path | UPath | None: """Path to the Zarr storage.""" return self._path @path.setter - def path(self, value: Path | None) -> None: - if value is None or isinstance(value, str | Path): + def path(self, value: Path | UPath | None) -> None: + if value is None or isinstance(value, (str, Path, UPath)): self._path = value else: - raise TypeError("Path must be `None`, a `str` or a `Path` object.") + raise TypeError("Path must be `None`, a `str`, a `Path` or a `UPath` object.") def locate_element(self, element: SpatialElement) -> list[str]: """ @@ -1032,18 +1032,34 @@ def _symmetric_difference_with_zarr_store(self) -> tuple[list[str], list[str]]: def 
_validate_can_safely_write_to_path( self, - file_path: str | Path, + file_path: str | Path | UPath, overwrite: bool = False, saving_an_element: bool = False, ) -> None: - from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder, _resolve_zarr_store + from spatialdata._io._utils import ( + _backed_elements_contained_in_path, + _is_subfolder, + _remote_zarr_store_exists, + _resolve_zarr_store, + ) if isinstance(file_path, str): file_path = Path(file_path) - if not isinstance(file_path, Path): - raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") + if not isinstance(file_path, (Path, UPath)): + raise ValueError(f"file_path must be a string, Path or UPath object, type(file_path) = {type(file_path)}.") + + if isinstance(file_path, UPath): + store = _resolve_zarr_store(file_path) + if _remote_zarr_store_exists(store) and not overwrite: + raise ValueError( + "The Zarr store already exists. Use `overwrite=True` to try overwriting the store. " + "Please note that only Zarr stores not currently in use by the current SpatialData object can be " + "overwritten." + ) + return + # Local Path: existing logic # TODO: add test for this if os.path.exists(file_path): store = _resolve_zarr_store(file_path) @@ -1072,8 +1088,13 @@ def _validate_can_safely_write_to_path( ERROR_MSG + "\nDetails: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object." 
+ WORKAROUND ) - if self.path is not None and ( - _is_subfolder(parent=self.path, child=file_path) or _is_subfolder(parent=file_path, child=self.path) + # Subfolder checks only for local paths (Path); skip when self.path is UPath + if ( + self.path is not None + and isinstance(self.path, Path) + and ( + _is_subfolder(parent=self.path, child=file_path) or _is_subfolder(parent=file_path, child=self.path) + ) ): if saving_an_element and _is_subfolder(parent=self.path, child=file_path): raise ValueError( @@ -1102,7 +1123,7 @@ def _validate_all_elements(self) -> None: @_deprecation_alias(format="sdata_formats", version="0.7.0") def write( self, - file_path: str | Path, + file_path: str | Path | UPath | None = None, overwrite: bool = False, consolidate_metadata: bool = True, update_sdata_path: bool = True, @@ -1115,7 +1136,7 @@ def write( Parameters ---------- file_path - The path to the Zarr store to write to. + The path to the Zarr store to write to. If ``None``, uses :attr:`path` (must be set). overwrite If `True`, overwrite the Zarr store if it already exists. If `False`, `write()` will fail if the Zarr store already exists. 
@@ -1161,8 +1182,13 @@ def write( parsed = _parse_formats(sdata_formats) + if file_path is None: + if self.path is None: + raise ValueError("file_path must be provided when SpatialData.path is not set.") + file_path = self.path if isinstance(file_path, str): file_path = Path(file_path) + # Keep UPath as-is; do not convert to Path self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() @@ -1192,7 +1218,7 @@ def write( def _write_element( self, element: SpatialElement | AnnData, - zarr_container_path: Path, + zarr_container_path: Path | UPath, element_type: str, element_name: str, overwrite: bool, @@ -1201,10 +1227,8 @@ def _write_element( ) -> None: from spatialdata._io.io_zarr import _get_groups_for_element - if not isinstance(zarr_container_path, Path): - raise ValueError( - f"zarr_container_path must be a Path object, type(zarr_container_path) = {type(zarr_container_path)}." - ) + if not isinstance(zarr_container_path, (Path, UPath)): + raise ValueError(f"zarr_container_path must be a Path or UPath, got {type(zarr_container_path).__name__}.") file_path_of_element = zarr_container_path / element_type / element_name self._validate_can_safely_write_to_path( file_path=file_path_of_element, overwrite=overwrite, saving_an_element=True @@ -1489,7 +1513,7 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st # check if the element exists in the Zarr storage if not _group_for_element_exists( - zarr_path=Path(self.path), + zarr_path=self.path, element_type=element_type, element_name=element_name, ): @@ -1503,7 +1527,7 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st # warn the users if the element is not self-contained, that is, it is Dask-backed by files outside the Zarr # group for the element - element_zarr_path = Path(self.path) / element_type / element_name + element_zarr_path = self.path / element_type / element_name if not 
_is_element_self_contained(element=element, element_path=element_zarr_path): logger.info( f"Element {element_type}/{element_name} is not self-contained. The metadata will be" @@ -1544,7 +1568,7 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: _, _, element_group = _get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False + zarr_path=self.path, element_type=element_type, element_name=element_name, use_consolidated=False ) from spatialdata._io._utils import overwrite_channel_names @@ -1588,7 +1612,7 @@ def write_transformations(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have a conditional assert self.path is not None _, _, element_group = _get_groups_for_element( - zarr_path=Path(self.path), + zarr_path=self.path, element_type=element_type, element_name=element_name, use_consolidated=False, @@ -1956,7 +1980,8 @@ def h(s: str) -> str: descr = "SpatialData object" if self.path is not None: - descr += f", with associated Zarr store: {self.path.resolve()}" + path_descr = str(self.path) if isinstance(self.path, UPath) else self.path.resolve() + descr += f", with associated Zarr store: {path_descr}" non_empty_elements = self._non_empty_elements() last_element_index = len(non_empty_elements) - 1 diff --git a/src/spatialdata/_io/__init__.py b/src/spatialdata/_io/__init__.py index 38ff8c6bb..9e4b11de1 100644 --- a/src/spatialdata/_io/__init__.py +++ b/src/spatialdata/_io/__init__.py @@ -1,5 +1,7 @@ from __future__ import annotations +# Patch da.to_zarr so ome_zarr's **kwargs are passed as zarr_array_kwargs (avoids FutureWarning) +import spatialdata._io._dask_zarr_compat # noqa: F401 from spatialdata._io._utils import get_dask_backing_files from spatialdata._io.format import 
SpatialDataFormatType from spatialdata._io.io_points import write_points diff --git a/src/spatialdata/_io/_dask_zarr_compat.py b/src/spatialdata/_io/_dask_zarr_compat.py new file mode 100644 index 000000000..b0988aef7 --- /dev/null +++ b/src/spatialdata/_io/_dask_zarr_compat.py @@ -0,0 +1,55 @@ +"""Compatibility layer for dask.array.to_zarr when callers pass array options via **kwargs. + +ome_zarr.writer calls da.to_zarr(..., **options) with array options (compressor, dimension_names, +etc.). Dask deprecated **kwargs in favor of zarr_array_kwargs. This module patches da.to_zarr to +forward such kwargs into zarr_array_kwargs (excluding dask-internal keys like zarr_format that +zarr.Group.create_array() does not accept), avoiding the FutureWarning and keeping behavior correct. +""" + +from __future__ import annotations + +from typing import Any + +import dask.array as _da + +_orig_to_zarr = _da.to_zarr + +# Keys from ome_zarr/dask **kwargs that must not be passed to zarr.Group.create_array() +# dimension_separator: not accepted by all zarr versions in the create_array() path. 
+_DASK_INTERNAL_KEYS = frozenset({"zarr_format", "dimension_separator"}) + + +def _to_zarr( + arr: Any, + url: Any, + component: Any = None, + storage_options: Any = None, + region: Any = None, + compute: bool = True, + return_stored: bool = False, + zarr_array_kwargs: Any = None, + zarr_read_kwargs: Any = None, + **kwargs: Any, +) -> Any: + """Forward deprecated **kwargs into zarr_array_kwargs, excluding _DASK_INTERNAL_KEYS.""" + if kwargs: + zarr_array_kwargs = dict(zarr_array_kwargs) if zarr_array_kwargs else {} + for k, v in kwargs.items(): + if k not in _DASK_INTERNAL_KEYS: + zarr_array_kwargs[k] = v + kwargs = {} + return _orig_to_zarr( + arr, + url, + component=component, + storage_options=storage_options, + region=region, + compute=compute, + return_stored=return_stored, + zarr_array_kwargs=zarr_array_kwargs, + zarr_read_kwargs=zarr_read_kwargs, + **kwargs, + ) + + +_da.to_zarr = _to_zarr diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 6690d1118..2a5d44e26 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import filecmp +import json import os.path import re import sys @@ -14,6 +15,7 @@ from pathlib import Path from typing import Any, Literal +import fsspec.asyn as _asyn_mod import zarr from anndata import AnnData from dask._task_spec import Task @@ -23,6 +25,7 @@ from upath import UPath from upath.implementations.local import PosixUPath, WindowsUPath from xarray import DataArray, DataTree +from zarr.errors import GroupNotFoundError from zarr.storage import FsspecStore, LocalStore from spatialdata._core.spatialdata import SpatialData @@ -38,6 +41,74 @@ from spatialdata.transformations.transformations import BaseTransformation, _get_current_output_axes +class _FsspecStoreRoot: + """Path-like root for FsspecStore (no .root attribute); supports __truediv__ and str() as full URL.""" + + __slots__ = ("_store", "_path") + + def __init__(self, store: 
FsspecStore, path: str | None = None) -> None: + self._store = store + self._path = (path or store.path).rstrip("/") + + def __truediv__(self, other: str | Path) -> _FsspecStoreRoot: + return _FsspecStoreRoot(self._store, self._path + "/" + str(other).lstrip("/")) + + def __str__(self) -> str: + protocol = getattr(self._store.fs, "protocol", None) + if isinstance(protocol, (list, tuple)): + protocol = protocol[0] if protocol else "file" + elif protocol is None: + protocol = "file" + return f"{protocol}://{self._path}" + + def __fspath__(self) -> str: + return str(self) + + +def _storage_options_from_fs(fs: Any) -> dict[str, Any]: + """Build storage_options dict from an fsspec filesystem for use with to_parquet/write_parquet. + + Ensures parquet writes to remote stores (Azure, S3, GCS) use the same credentials as the + zarr store. + """ + out: dict[str, Any] = {} + name = type(fs).__name__ + if name == "AzureBlobFileSystem": + if getattr(fs, "connection_string", None): + out["connection_string"] = fs.connection_string + elif getattr(fs, "account_name", None) and getattr(fs, "account_key", None): + out["account_name"] = fs.account_name + out["account_key"] = fs.account_key + if getattr(fs, "anon", None) is not None: + out["anon"] = fs.anon + elif name in ("S3FileSystem", "MotoS3FS"): + if getattr(fs, "endpoint_url", None): + out["endpoint_url"] = fs.endpoint_url + if getattr(fs, "key", None): + out["key"] = fs.key + if getattr(fs, "secret", None): + out["secret"] = fs.secret + if getattr(fs, "anon", None) is not None: + out["anon"] = fs.anon + elif name == "GCSFileSystem": + if getattr(fs, "token", None) is not None: + out["token"] = fs.token + if getattr(fs, "_endpoint", None): + out["endpoint_url"] = fs._endpoint + if getattr(fs, "project", None): + out["project"] = fs.project + return out + + +def _get_store_root(store: LocalStore | FsspecStore) -> Path | _FsspecStoreRoot: + """Return a path-like root for the store (supports / and str()). 
Use for building paths to parquet etc.""" + if isinstance(store, LocalStore): + return Path(store.root) + if isinstance(store, FsspecStore): + return _FsspecStoreRoot(store) + raise TypeError(f"Unsupported store type: {type(store)}") + + def _get_transformations_from_ngff_dict( list_of_encoded_ngff_transformations: list[dict[str, Any]], ) -> MappingToCoordinateSystem_t: @@ -370,7 +441,9 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No files.append(os.path.realpath(parquet_file)) -def _backed_elements_contained_in_path(path: Path, object: SpatialData | SpatialElement | AnnData) -> list[bool]: +def _backed_elements_contained_in_path( + path: Path | UPath, object: SpatialData | SpatialElement | AnnData +) -> list[bool]: """ Return the list of boolean values indicating if backing files for an object are child directory of a path. @@ -390,8 +463,10 @@ def _backed_elements_contained_in_path(path: Path, object: SpatialData | Spatial If an object does not have a Dask computational graph, it will return an empty list. It is possible for a single SpatialElement to contain multiple files in their Dask computational graph. 
""" + if isinstance(path, UPath): + return [] # no local backing files are "contained" in a remote path if not isinstance(path, Path): - raise TypeError(f"Expected a Path object, got {type(path)}") + raise TypeError(f"Expected a Path or UPath object, got {type(path)}") return [_is_subfolder(parent=path, child=Path(fp)) for fp in get_dask_backing_files(object)] @@ -420,14 +495,58 @@ def _is_subfolder(parent: Path, child: Path) -> bool: def _is_element_self_contained( - element: DataArray | DataTree | DaskDataFrame | GeoDataFrame | AnnData, element_path: Path + element: DataArray | DataTree | DaskDataFrame | GeoDataFrame | AnnData, + element_path: Path | UPath, ) -> bool: + if isinstance(element_path, UPath): + return True # treat remote-backed as self-contained for this check if isinstance(element, DaskDataFrame): pass # TODO when running test_save_transformations it seems that for the same element this is called multiple times return all(_backed_elements_contained_in_path(path=element_path, object=element)) +def _is_azure_http_response_error(exc: BaseException) -> bool: + """Return True if exc is an Azure SDK HttpResponseError (e.g. emulator API mismatch).""" + t = type(exc) + return t.__name__ == "HttpResponseError" and (getattr(t, "__module__", "") or "").startswith("azure.") + + +def _remote_zarr_store_exists(store: zarr.storage.StoreLike) -> bool: + """Return True if the store contains a zarr group. Closes the store. Handles Azure emulator errors.""" + try: + zarr.open_group(store, mode="r") + return True + except (GroupNotFoundError, OSError, FileNotFoundError): + return False + except Exception as e: + if _is_azure_http_response_error(e): + return False + raise + finally: + store.close() + + +def _ensure_async_fs(fs: Any) -> Any: + """Return an async fsspec filesystem for use with zarr's FsspecStore. + + Zarr's FsspecStore expects an async filesystem. 
If the given fs is synchronous, + it is converted using fsspec's public API (async instance or AsyncFileSystemWrapper) + so that ZarrUserWarning is not raised. + """ + if getattr(fs, "asynchronous", False): + return fs + import fsspec + + if getattr(fs, "async_impl", False): + fs_dict = json.loads(fs.to_json()) + fs_dict["asynchronous"] = True + return fsspec.AbstractFileSystem.from_json(json.dumps(fs_dict)) + from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper + + return AsyncFileSystemWrapper(fs, asynchronous=True) + + def _resolve_zarr_store( path: str | Path | UPath | zarr.storage.StoreLike | zarr.Group, **kwargs: Any ) -> zarr.storage.StoreLike: @@ -477,17 +596,24 @@ def _resolve_zarr_store( if isinstance(path.store, FsspecStore): # if the store within the zarr.Group is an FSStore, return it # but extend the path of the store with that of the zarr.Group - return FsspecStore(path.store.path + "/" + path.path, fs=path.store.fs, **kwargs) + return FsspecStore( + path.store.path + "/" + path.path, + fs=_ensure_async_fs(path.store.fs), + **kwargs, + ) if isinstance(path.store, zarr.storage.ConsolidatedMetadataStore): # if the store is a ConsolidatedMetadataStore, just return the underlying FSSpec store return path.store.store raise ValueError(f"Unsupported store type or zarr.Group: {type(path.store)}") + if isinstance(path, _FsspecStoreRoot): + # path-like from read_zarr that carries the same fs (preserves Azure/GCS credentials) + return FsspecStore(_ensure_async_fs(path._store.fs), path=path._path, **kwargs) + if isinstance(path, UPath): + # if input is a remote UPath, map it to an FSStore (check before StoreLike to avoid UnionType isinstance) + return FsspecStore(_ensure_async_fs(path.fs), path=path.path, **kwargs) if isinstance(path, zarr.storage.StoreLike): # if the input already a store, wrap it in an FSStore return FsspecStore(path, **kwargs) - if isinstance(path, UPath): - # if input is a remote UPath, map it to an FSStore - return 
FsspecStore(path.path, fs=path.fs, **kwargs) raise TypeError(f"Unsupported type: {type(path)}") @@ -545,3 +671,20 @@ def handle_read_errors( else: # on_bad_files == BadFileHandleMethod.ERROR # Let it raise exceptions yield + + +# Avoid RuntimeError "Loop is not running" when fsspec closes async sessions at process exit +# (remote storage: Azure, S3, GCS). _utils is used for all store resolution. +_orig_sync = _asyn_mod.sync + + +def _fsspec_sync_wrapped(loop: Any, func: Any, *args: Any, timeout: Any = None, **kwargs: Any) -> Any: + try: + return _orig_sync(loop, func, *args, timeout=timeout, **kwargs) + except RuntimeError as e: + if "Loop is not running" in str(e) or "different loop" in str(e): + return None + raise + + +_asyn_mod.sync = _fsspec_sync_wrapped diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index b47fc418c..684b39a27 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from typing import Any import zarr from dask.dataframe import DataFrame as DaskDataFrame @@ -8,7 +9,11 @@ from ome_zarr.format import Format from spatialdata._io._utils import ( + _FsspecStoreRoot, + _get_store_root, _get_transformations_from_ngff_dict, + _resolve_zarr_store, + _storage_options_from_fs, _write_metadata, overwrite_coordinate_transformations_non_raster, ) @@ -24,17 +29,21 @@ def _read_points( store: str | Path, ) -> DaskDataFrame: """Read points from a zarr store.""" - f = zarr.open(store, mode="r") + resolved_store = _resolve_zarr_store(store) + f = zarr.open(resolved_store, mode="r") version = _parse_version(f, expect_attrs_key=True) assert version is not None points_format = PointsFormats[version] - store_root = f.store_path.store.root + store_root = _get_store_root(f.store_path.store) path = store_root / f.path / "points.parquet" # cache on remote file needed for parquet reader to work # TODO: allow reading in the metadata 
without caching all the data - points = read_parquet("simplecache::" + str(path) if str(path).startswith("http") else path) + if isinstance(path, _FsspecStoreRoot): + points = read_parquet(str(path), storage_options=_storage_options_from_fs(path._store.fs)) + else: + points = read_parquet("simplecache::" + str(path) if str(path).startswith("http") else path) assert isinstance(points, DaskDataFrame) transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) @@ -68,7 +77,7 @@ def write_points( axes = get_axes_names(points) transformations = _get_transformations(points) - store_root = group.store_path.store.root + store_root = _get_store_root(group.store_path.store) path = store_root / group.path / "points.parquet" # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is @@ -84,7 +93,10 @@ def write_points( points_without_transform = points.copy() del points_without_transform.attrs["transform"] - points_without_transform.to_parquet(path) + storage_options: dict[str, Any] = {} + if isinstance(path, _FsspecStoreRoot): + storage_options = _storage_options_from_fs(path._store.fs) + points_without_transform.to_parquet(str(path), storage_options=storage_options or None) attrs = element_format.attrs_to_dict(points.attrs) attrs["version"] = element_format.spatialdata_format_version diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index df7e1cb8f..767232fdd 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -19,6 +19,7 @@ from spatialdata._io._utils import ( _get_transformations_from_ngff_dict, + _resolve_zarr_store, overwrite_coordinate_transformations_raster, ) from spatialdata._io.format import ( @@ -41,11 +42,11 @@ def _read_multiscale( store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format ) -> DataArray | DataTree: - assert isinstance(store, str | Path) assert raster_type in ["image", 
"labels"] + resolved_store = _resolve_zarr_store(store) nodes: list[Node] = [] - image_loc = ZarrLocation(store, fmt=reader_format) + image_loc = ZarrLocation(resolved_store, fmt=reader_format) if exists := image_loc.exists(): image_reader = Reader(image_loc)() image_nodes = list(image_reader) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index b07256273..adf4716f3 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -1,5 +1,8 @@ from __future__ import annotations +import contextlib +import os +import tempfile from pathlib import Path from typing import Any, Literal @@ -11,7 +14,11 @@ from shapely import from_ragged_array, to_ragged_array from spatialdata._io._utils import ( + _FsspecStoreRoot, + _get_store_root, _get_transformations_from_ngff_dict, + _resolve_zarr_store, + _storage_options_from_fs, _write_metadata, overwrite_coordinate_transformations_non_raster, ) @@ -34,7 +41,8 @@ def _read_shapes( store: str | Path, ) -> GeoDataFrame: """Read shapes from a zarr store.""" - f = zarr.open(store, mode="r") + resolved_store = _resolve_zarr_store(store) + f = zarr.open(resolved_store, mode="r") version = _parse_version(f, expect_attrs_key=True) assert version is not None shape_format = ShapesFormats[version] @@ -54,9 +62,12 @@ def _read_shapes( geometry = from_ragged_array(typ, coords, offsets) geo_df = GeoDataFrame({"geometry": geometry}, index=index) elif isinstance(shape_format, ShapesFormatV02 | ShapesFormatV03): - store_root = f.store_path.store.root - path = Path(store_root) / f.path / "shapes.parquet" - geo_df = read_parquet(path) + store_root = _get_store_root(f.store_path.store) + path = store_root / f.path / "shapes.parquet" + if isinstance(path, _FsspecStoreRoot): + geo_df = read_parquet(str(path), storage_options=_storage_options_from_fs(path._store.fs)) + else: + geo_df = read_parquet(path) else: raise ValueError( f"Unsupported shapes format {shape_format} from version {version}. 
Please update the spatialdata library." @@ -150,6 +161,67 @@ def _write_shapes_v01(shapes: GeoDataFrame, group: zarr.Group, element_format: F return attrs +def _parse_fsspec_remote_path(path: _FsspecStoreRoot) -> tuple[str, str]: + """Return (bucket_or_container, blob_key) from an fsspec store path.""" + remote = str(path) + if "://" in remote: + remote = remote.split("://", 1)[1] + parts = remote.split("/", 1) + bucket_or_container = parts[0] + blob_key = parts[1] if len(parts) > 1 else "" + return bucket_or_container, blob_key + + +def _upload_parquet_to_azure(tmp_path: str, bucket: str, key: str, fs: Any) -> None: + from azure.storage.blob import BlobServiceClient + + client = BlobServiceClient.from_connection_string(fs.connection_string) + blob_client = client.get_blob_client(container=bucket, blob=key) + with open(tmp_path, "rb") as f: + blob_client.upload_blob(f, overwrite=True) + + +def _upload_parquet_to_s3(tmp_path: str, bucket: str, key: str, fs: Any) -> None: + import boto3 + + endpoint = getattr(fs, "endpoint_url", None) or os.environ.get("AWS_ENDPOINT_URL") + s3 = boto3.client( + "s3", + endpoint_url=endpoint, + aws_access_key_id=getattr(fs, "key", None) or os.environ.get("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=getattr(fs, "secret", None) or os.environ.get("AWS_SECRET_ACCESS_KEY"), + region_name=os.environ.get("AWS_DEFAULT_REGION", "us-east-1"), + ) + s3.upload_file(tmp_path, bucket, key) + + +def _upload_parquet_to_gcs(tmp_path: str, bucket: str, key: str, fs: Any) -> None: + from google.auth.credentials import AnonymousCredentials + from google.cloud import storage + + client = storage.Client( + credentials=AnonymousCredentials(), + project=getattr(fs, "project", None) or "test", + ) + blob = client.bucket(bucket).blob(key) + blob.upload_from_filename(tmp_path) + + +def _upload_parquet_to_fsspec(path: _FsspecStoreRoot, tmp_path: str) -> None: + """Upload local parquet file to remote fsspec store using sync APIs to avoid event-loop issues.""" + 
fs = path._store.fs + bucket, key = _parse_fsspec_remote_path(path) + fs_name = type(fs).__name__ + if fs_name == "AzureBlobFileSystem" and getattr(fs, "connection_string", None): + _upload_parquet_to_azure(tmp_path, bucket, key, fs) + elif fs_name in ("S3FileSystem", "MotoS3FS"): + _upload_parquet_to_s3(tmp_path, bucket, key, fs) + elif fs_name == "GCSFileSystem": + _upload_parquet_to_gcs(tmp_path, bucket, key, fs) + else: + fs.put(tmp_path, str(path)) + + def _write_shapes_v02_v03( shapes: GeoDataFrame, group: zarr.Group, element_format: Format, geometry_encoding: Literal["WKB", "geoarrow"] ) -> Any: @@ -169,13 +241,23 @@ def _write_shapes_v02_v03( """ from spatialdata.models._utils import TRANSFORM_KEY - store_root = group.store_path.store.root + store_root = _get_store_root(group.store_path.store) path = store_root / group.path / "shapes.parquet" # Temporarily remove transformations from attrs to avoid serialization issues transforms = shapes.attrs[TRANSFORM_KEY] del shapes.attrs[TRANSFORM_KEY] - shapes.to_parquet(path, geometry_encoding=geometry_encoding) + if isinstance(path, _FsspecStoreRoot): + with tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) as tmp: + tmp_path = tmp.name + try: + shapes.to_parquet(tmp_path, geometry_encoding=geometry_encoding) + _upload_parquet_to_fsspec(path, tmp_path) + finally: + with contextlib.suppress(OSError): + os.unlink(tmp_path) + else: + shapes.to_parquet(path, geometry_encoding=geometry_encoding) shapes.attrs[TRANSFORM_KEY] = transforms attrs = element_format.attrs_to_dict(shapes.attrs) diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 8cd7b8385..03ec78526 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -9,6 +9,7 @@ from anndata._io.specs import write_elem as write_adata from ome_zarr.format import Format +from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import ( CurrentTablesFormat, TablesFormats, @@ 
-20,9 +21,10 @@ def _read_table(store: str | Path) -> AnnData: - table = read_anndata_zarr(str(store)) + resolved_store = _resolve_zarr_store(store) + table = read_anndata_zarr(resolved_store) - f = zarr.open(store, mode="r") + f = zarr.open(resolved_store, mode="r") version = _parse_version(f, expect_attrs_key=False) assert version is not None table_format = TablesFormats[version] diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 4c410fab0..48795513c 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import warnings from collections.abc import Callable from json import JSONDecodeError @@ -19,6 +18,8 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import ( BadFileHandleMethod, + _FsspecStoreRoot, + _get_store_root, _resolve_zarr_store, handle_read_errors, ) @@ -32,7 +33,7 @@ def _read_zarr_group_spatialdata_element( root_group: zarr.Group, - root_store_path: str, + root_store_path: Path | _FsspecStoreRoot, sdata_version: Literal["0.1", "0.2"], selector: set[str], read_func: Callable[..., Any], @@ -54,7 +55,7 @@ def _read_zarr_group_spatialdata_element( # skip hidden files like .zgroup or .zmetadata continue elem_group = group[subgroup_name] - elem_group_path = os.path.join(root_store_path, elem_group.path) + elem_group_path = root_store_path / elem_group.path with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", @@ -170,7 +171,7 @@ def read_zarr( UserWarning, stacklevel=2, ) - root_store_path = root_group.store.root + root_store_path = _get_store_root(root_group.store) images: dict[str, Raster_T] = {} labels: dict[str, Raster_T] = {} @@ -231,12 +232,12 @@ def read_zarr( tables=tables, attrs=attrs, ) - sdata.path = resolved_store.root + sdata.path = store if isinstance(store, UPath) else resolved_store.root return sdata def _get_groups_for_element( - zarr_path: Path, element_type: 
str, element_name: str, use_consolidated: bool = True + zarr_path: Path | UPath, element_type: str, element_name: str, use_consolidated: bool = True ) -> tuple[zarr.Group, zarr.Group, zarr.Group]: """ Get the Zarr groups for the root, element_type and element for a specific element. @@ -265,8 +266,8 @@ def _get_groups_for_element( ------- The Zarr groups for the root, element_type and element for a specific element. """ - if not isinstance(zarr_path, Path): - raise ValueError("zarr_path should be a Path object") + if not isinstance(zarr_path, (Path, UPath)): + raise ValueError("zarr_path should be a Path or UPath object") if element_type not in [ "images", @@ -289,7 +290,7 @@ def _get_groups_for_element( return root_group, element_type_group, element_name_group -def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: +def _group_for_element_exists(zarr_path: Path | UPath, element_type: str, element_name: str) -> bool: """ Check if the group for an element exists. @@ -319,9 +320,13 @@ def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: return exists -def _write_consolidated_metadata(path: Path | str | None) -> None: +def _write_consolidated_metadata(path: Path | UPath | str | None) -> None: if path is not None: - f = zarr.open_group(path, mode="r+", use_consolidated=False) + if isinstance(path, UPath): + store = _resolve_zarr_store(path) + f = zarr.open_group(store, mode="r+", use_consolidated=False) + else: + f = zarr.open_group(path, mode="r+", use_consolidated=False) # .parquet files are not recognized as proper zarr and thus throw a warning. This does not affect SpatialData. # and therefore we silence it for our users as they can't do anything about this. 
# TODO check with remote PR whether we can prevent this warning at least for points data and whether with zarrv3 diff --git a/tests/conftest.py b/tests/conftest.py index c97939129..a6deba0ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -89,12 +89,18 @@ def tables() -> list[AnnData]: @pytest.fixture() def full_sdata() -> SpatialData: + # Use two regions so the table categorical has two categories; otherwise anndata does not + # write the obs/region/codes/c/0 chunk (only codes/zarr.json), causing 404 on remote read. return SpatialData( images=_get_images(), labels=_get_labels(), shapes=_get_shapes(), points=_get_points(), - tables=_get_tables(region="labels2d", region_key="region", instance_key="instance_id"), + tables=_get_tables( + region=["labels2d", "poly"], + region_key="region", + instance_key="instance_id", + ), ) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 68b538e0a..a898bed0c 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -559,14 +559,15 @@ def test_init_from_elements(full_sdata: SpatialData) -> None: def test_subset(full_sdata: SpatialData) -> None: - element_names = ["image2d", "points_0", "circles", "poly"] + # Exclude labels and poly so the default table (annotating labels2d and poly) is not included + element_names = ["image2d", "points_0", "circles"] subset0 = full_sdata.subset(element_names) unique_names = set() for _, k, _ in subset0.gen_spatial_elements(): unique_names.add(k) assert "image3d_xarray" in full_sdata.images assert unique_names == set(element_names) - # no table since the labels are not present in the subset + # no table since neither labels2d nor poly are in the subset assert "table" not in subset0.tables adata = AnnData( @@ -675,7 +676,9 @@ def test_transform_to_data_extent(full_sdata: SpatialData, maintain_positioning: def 
test_validate_table_in_spatialdata(full_sdata): table = full_sdata["table"] region, region_key, _ = get_table_keys(table) - assert region == "labels2d" + # full_sdata uses two regions (labels2d, poly) so the table annotates both + expected = {"labels2d", "poly"} + assert set(region if isinstance(region, list) else [region]) == expected full_sdata.validate_table_in_spatialdata(table) diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 63e7a6f19..c28725681 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -914,6 +914,9 @@ def test_filter_table_non_annotating(full_sdata): def test_labels_table_joins(full_sdata): + # Restrict table to labels2d only so the join returns one row per label (full_sdata default has two regions) + full_sdata["table"].obs["region"] = pd.Categorical(["labels2d"] * full_sdata["table"].n_obs) + full_sdata["table"].uns["spatialdata_attrs"]["region"] = "labels2d" element_dict, table = join_spatialelement_table( sdata=full_sdata, spatial_element_names="labels2d", diff --git a/tests/io/remote_storage/Dockerfile.emulators b/tests/io/remote_storage/Dockerfile.emulators new file mode 100644 index 000000000..bc3bb6f53 --- /dev/null +++ b/tests/io/remote_storage/Dockerfile.emulators @@ -0,0 +1,28 @@ +# Storage emulators for tests in this directory (S3, Azure, GCS). +# Emulator URLs: S3 127.0.0.1:5000 | Azure 127.0.0.1:10000 | GCS 127.0.0.1:4443 +# +# Build (from project root): +# docker build -f tests/io/remote_storage/Dockerfile.emulators -t spatialdata-emulators . 
+# +# Run in background (detached): +# docker run --rm -d --name spatialdata-emulators -p 5000:5000 -p 10000:10000 -p 4443:4443 spatialdata-emulators +# +# Run in foreground (attach to terminal): +# docker run --rm --name spatialdata-emulators -p 5000:5000 -p 10000:10000 -p 4443:4443 spatialdata-emulators +# +# Stop / remove: +# docker stop spatialdata-emulators +# docker rm -f spatialdata-emulators # if already stopped or to force-remove +FROM node:20-slim +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 python3-pip python3-venv curl ca-certificates \ + && rm -rf /var/lib/apt/lists/* +RUN python3 -m venv /opt/venv && /opt/venv/bin/pip install --no-cache-dir 'moto[server]' +ENV PATH="/opt/venv/bin:$PATH" +RUN cd /tmp && curl -sSL -o fgs.tgz https://github.com/fsouza/fake-gcs-server/releases/download/v1.54.0/fake-gcs-server_1.54.0_linux_amd64.tar.gz \ + && tar xzf fgs.tgz && mv fake-gcs-server /usr/local/bin/ 2>/dev/null || mv fake-gcs-server_*/fake-gcs-server /usr/local/bin/ \ + && chmod +x /usr/local/bin/fake-gcs-server && rm -f fgs.tgz +RUN mkdir -p /data +EXPOSE 5000 10000 4443 +RUN echo 'moto_server -H 0.0.0.0 -p 5000 & npx --yes azurite --silent --location /data --blobHost 0.0.0.0 --skipApiVersionCheck & fake-gcs-server -scheme http -port 4443 & wait' > /start.sh && chmod +x /start.sh +CMD ["/bin/sh", "/start.sh"] diff --git a/tests/io/remote_storage/conftest.py b/tests/io/remote_storage/conftest.py new file mode 100644 index 000000000..c650ab53c --- /dev/null +++ b/tests/io/remote_storage/conftest.py @@ -0,0 +1,193 @@ +"""Minimal pytest config for IO tests. Creates buckets/containers when remote emulators are running. + +Assumes emulators are already running (e.g. Docker: + docker run -p 5000:5000 -p 10000:10000 -p 4443:4443 spatialdata-emulators). +Ports: S3/moto 5000, Azure/Azurite 10000, GCS/fake-gcs-server 4443. 
+""" + +from __future__ import annotations + +import os +import socket +import time + +import pytest + +# Error messages from asyncio when closing sessions after the event loop is gone (e.g. at process exit) +_LOOP_GONE_ERRORS = ("different loop", "Loop is not running") + + +def _patch_fsspec_sync_for_shutdown() -> None: + """If fsspec.asyn.sync() runs at exit when the loop is gone, return None instead of raising.""" + import fsspec.asyn as asyn_mod + + _orig = asyn_mod.sync + + def _wrapped(loop, func, *args, timeout=None, **kwargs): + try: + return _orig(loop, func, *args, timeout=timeout, **kwargs) + except RuntimeError as e: + if any(msg in str(e) for msg in _LOOP_GONE_ERRORS): + return None + raise + + asyn_mod.sync = _wrapped + + +def _patch_gcsfs_close_session_for_shutdown() -> None: + """If gcsfs close_session fails (loop gone), close the connector synchronously instead of raising.""" + import asyncio + + import fsspec + import fsspec.asyn as asyn_mod + import gcsfs.core + + @staticmethod + def _close_session(loop, session, asynchronous=False): + if session.closed: + return + try: + running = asyncio.get_running_loop() + except RuntimeError: + running = None + + use_force_close = False + if loop and loop.is_running(): + loop.create_task(session.close()) + elif running and running.is_running() and asynchronous: + running.create_task(session.close()) + elif asyn_mod.loop[0] is not None and asyn_mod.loop[0].is_running(): + try: + asyn_mod.sync(asyn_mod.loop[0], session.close, timeout=0.1) + except (RuntimeError, fsspec.FSTimeoutError): + use_force_close = True + else: + use_force_close = True + + if use_force_close: + connector = getattr(session, "_connector", None) + if connector is not None: + connector._close() + + gcsfs.core.GCSFileSystem.close_session = _close_session + + +def _apply_resilient_async_close_patches() -> None: + """Avoid RuntimeError tracebacks when aiohttp sessions are closed at process exit (loop already gone).""" + 
_patch_fsspec_sync_for_shutdown() + _patch_gcsfs_close_session_for_shutdown() + + +def pytest_configure(config: pytest.Config) -> None: + """Apply patches for remote storage tests (resilient async close at shutdown).""" + _apply_resilient_async_close_patches() + + +EMULATOR_PORTS = {"s3": 5000, "azure": 10000, "gcs": 4443} +S3_BUCKETS = ("bucket", "test-azure", "test-s3", "test-gcs") +AZURE_CONTAINERS = ("test-container", "test-azure", "test-s3", "test-gcs") +GCS_BUCKETS = ("bucket", "test-azure", "test-s3", "test-gcs") + +AZURITE_CONNECTION_STRING = ( + "DefaultEndpointsProtocol=http;" + "AccountName=devstoreaccount1;" + "AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;" + "BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;" +) + + +def _port_open(host: str = "127.0.0.1", port: int | None = None, timeout: float = 2.0) -> bool: + if port is None: + return False + try: + with socket.create_connection((host, port), timeout=timeout): + return True + except (OSError, TimeoutError): + return False + + +def _ensure_s3_buckets(host: str) -> None: + if not _port_open(host, EMULATOR_PORTS["s3"]): + return + os.environ.setdefault("AWS_ENDPOINT_URL", "http://127.0.0.1:5000") + os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing") + os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing") + import boto3 + from botocore.config import Config + + client = boto3.client( + "s3", + endpoint_url=os.environ["AWS_ENDPOINT_URL"], + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + region_name="us-east-1", + config=Config(signature_version="s3v4"), + ) + existing = {b["Name"] for b in client.list_buckets().get("Buckets", [])} + for name in S3_BUCKETS: + if name not in existing: + client.create_bucket(Bucket=name) + + +def _ensure_azure_containers(host: str) -> None: + if not _port_open(host, EMULATOR_PORTS["azure"]): + return + from azure.storage.blob import 
BlobServiceClient + + client = BlobServiceClient.from_connection_string(AZURITE_CONNECTION_STRING) + existing = {c.name for c in client.list_containers()} + for name in AZURE_CONTAINERS: + if name not in existing: + client.create_container(name) + + +def _ensure_gcs_buckets(host: str) -> None: + if not _port_open(host, EMULATOR_PORTS["gcs"]): + return + os.environ.setdefault("STORAGE_EMULATOR_HOST", "http://127.0.0.1:4443") + from google.auth.credentials import AnonymousCredentials + from google.cloud import storage + + client = storage.Client(credentials=AnonymousCredentials(), project="test") + existing = {b.name for b in client.list_buckets()} + for name in GCS_BUCKETS: + if name not in existing: + client.create_bucket(name) + + +def _wait_for_emulator_ports(host: str = "127.0.0.1", timeout: float = 60.0, check_interval: float = 2.0) -> None: + """Wait until all three emulator ports accept connections (e.g. after docker run).""" + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if all(_port_open(host, EMULATOR_PORTS[p]) for p in ("s3", "azure", "gcs")): + return + time.sleep(check_interval) + raise RuntimeError( + f"Emulators did not become ready within {timeout}s. " + "Ensure the container is running: docker run --rm -d -p 5000:5000 " + "-p 10000:10000 -p 4443:4443 spatialdata-emulators" + ) + + +@pytest.fixture(scope="session") +def _remote_storage_buckets_containers(): + """Create buckets/containers on running emulators so remote storage tests can run. 
+ + Run with emulators up, e.g.: + docker run --rm -d -p 5000:5000 -p 10000:10000 -p 4443:4443 spatialdata-emulators + Then: pytest tests/io/remote_storage/test_remote_storage.py -v + """ + host = "127.0.0.1" + _wait_for_emulator_ports(host) + _ensure_s3_buckets(host) + _ensure_azure_containers(host) + _ensure_gcs_buckets(host) + yield + + +def pytest_collection_modifyitems(config: pytest.Config, items: list) -> None: + """Inject bucket/container creation for test_remote_storage.py.""" + for item in items: + path = getattr(item, "path", None) or getattr(item, "fspath", None) + if path and "test_remote_storage" in str(path): + item.add_marker(pytest.mark.usefixtures("_remote_storage_buckets_containers")) diff --git a/tests/io/remote_storage/test_remote_storage.py b/tests/io/remote_storage/test_remote_storage.py new file mode 100644 index 000000000..44685061a --- /dev/null +++ b/tests/io/remote_storage/test_remote_storage.py @@ -0,0 +1,190 @@ +"""Integration tests for remote storage (Azure, S3, GCS) using real emulators. + +Emulators must be running (e.g. Docker: docker run -p 5000:5000 -p 10000:10000 -p 4443:4443 spatialdata-emulators). +Ports: S3/moto 5000, Azure/Azurite 10000, GCS/fake-gcs-server 4443. +tests/io/remote_storage/conftest.py creates the required buckets/containers when emulators are up. + +All remote paths use uuid.uuid4().hex so each test run writes to a unique location. +""" + +from __future__ import annotations + +import os +import uuid + +import pytest +from upath import UPath + +from spatialdata import SpatialData +from spatialdata.testing import assert_spatial_data_objects_are_identical + +# Azure emulator connection string (Azurite default). 
+# https://learn.microsoft.com/en-us/azure/storage/common/storage-configure-connection-string +AZURE_CONNECTION_STRING = ( + "DefaultEndpointsProtocol=http;" + "AccountName=devstoreaccount1;" + "AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;" + "BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1;" +) + + +def _get_azure_upath(container: str = "test-container", path: str = "test.zarr") -> UPath: + """Create Azure UPath for testing with Azurite (local emulator).""" + return UPath(f"az://{container}/{path}", connection_string=AZURE_CONNECTION_STRING) + + +def _get_s3_upath(container: str = "bucket", path: str = "test.zarr") -> UPath: + """Create S3 UPath for testing (moto emulator at 5000).""" + endpoint = os.environ.get("AWS_ENDPOINT_URL", "http://127.0.0.1:5000") + if endpoint: + return UPath( + f"s3://{container}/{path}", + endpoint_url=endpoint, + key=os.environ.get("AWS_ACCESS_KEY_ID", "testing"), + secret=os.environ.get("AWS_SECRET_ACCESS_KEY", "testing"), + ) + return UPath(f"s3://{container}/{path}", anon=True) + + +def _get_gcs_upath(container: str = "bucket", path: str = "test.zarr") -> UPath: + """Create GCS UPath for testing with fake-gcs-server (port 4443).""" + os.environ.setdefault("STORAGE_EMULATOR_HOST", "http://127.0.0.1:4443") + return UPath( + f"gs://{container}/{path}", + endpoint_url=os.environ["STORAGE_EMULATOR_HOST"], + token="anon", + project="test", + ) + + +GET_UPATH_PARAMS = pytest.mark.parametrize( + "get_upath", [_get_azure_upath, _get_s3_upath, _get_gcs_upath], ids=["azure", "s3", "gcs"] +) +REMOTE_STORAGE_PARAMS = pytest.mark.parametrize( + "get_upath,storage_name", + [(_get_azure_upath, "azure"), (_get_s3_upath, "s3"), (_get_gcs_upath, "gcs")], + ids=["azure", "s3", "gcs"], +) + +# Ensure buckets/containers exist on emulators before any test (see tests/io/remote_storage/conftest.py) +pytestmark = pytest.mark.usefixtures("_remote_storage_buckets_containers") + + +def 
_assert_read_identical(expected: SpatialData, upath: UPath, *, check_path: bool = True) -> None: + """Read SpatialData from upath and assert it equals expected; optionally assert path.""" + sdata_read = SpatialData.read(upath) + if check_path: + assert isinstance(sdata_read.path, UPath) + assert sdata_read.path == upath + assert_spatial_data_objects_are_identical(expected, sdata_read) + + +class TestPathSetter: + """Test SpatialData.path setter with UPath objects.""" + + @GET_UPATH_PARAMS + def test_path_setter_accepts_upath(self, get_upath) -> None: + """Test that SpatialData.path setter accepts UPath for remote storage. + + Regression test for issue #441: the SpatialData.path setter used to accept only + None | str | Path, not UPath, preventing the use of remote storage. + """ + sdata = SpatialData() + upath = get_upath(path=f"test-accept-{uuid.uuid4().hex}.zarr") + sdata.path = upath + assert sdata.path == upath + + @GET_UPATH_PARAMS + def test_write_with_upath_sets_path(self, get_upath) -> None: + """Test that writing to UPath sets SpatialData.path correctly. + + Regression test: SpatialData.write() used to reject UPath in + _validate_can_safely_write_to_path() before it could set sdata.path. + """ + sdata = SpatialData() + upath = get_upath(path=f"test-write-path-{uuid.uuid4().hex}.zarr") + sdata.write(upath) + assert isinstance(sdata.path, UPath) + + def test_path_setter_rejects_other_types(self) -> None: + """Test that SpatialData.path setter rejects other types.""" + sdata = SpatialData() + with pytest.raises(TypeError, match="Path must be.*str.*Path"): + sdata.path = 123 + with pytest.raises(TypeError, match="Path must be.*str.*Path"): + sdata.path = {"not": "a path"} + + +class TestRemoteStorage: + """Test end-to-end remote storage workflows with UPath. + + Note: These tests require appropriate emulators running (Azurite for Azure, + moto for S3, fake-gcs-server for GCS). Tests will fail if emulators are not available. 
+ """ + + @REMOTE_STORAGE_PARAMS + def test_write_read_roundtrip_remote(self, full_sdata: SpatialData, get_upath, storage_name: str) -> None: + """Test writing and reading SpatialData to/from remote storage. + + This test verifies the full workflow: + 1. Write SpatialData to remote storage using UPath + 2. Read SpatialData from remote storage using UPath + 3. Verify data integrity (round-trip) + """ + upath = get_upath(container=f"test-{storage_name}", path=f"roundtrip-{uuid.uuid4().hex}.zarr") + full_sdata.write(upath, overwrite=True) + assert isinstance(full_sdata.path, UPath) + assert full_sdata.path == upath + _assert_read_identical(full_sdata, upath) + + @REMOTE_STORAGE_PARAMS + def test_path_setter_with_remote_then_operations( + self, full_sdata: SpatialData, get_upath, storage_name: str + ) -> None: + """Test setting remote path, then performing operations. + + This test verifies that after setting a remote path: + 1. Path is correctly stored + 2. Write operations work + 3. Read operations work + """ + upath = get_upath(container=f"test-{storage_name}", path=f"operations-{uuid.uuid4().hex}.zarr") + full_sdata.path = upath + assert full_sdata.path == upath + assert full_sdata.is_backed() is True + full_sdata.write(overwrite=True) + assert full_sdata.path == upath + _assert_read_identical(full_sdata, upath) + + @REMOTE_STORAGE_PARAMS + def test_overwrite_existing_remote_data(self, full_sdata: SpatialData, get_upath, storage_name: str) -> None: + """Test overwriting existing data in remote storage. + + Verifies that overwriting existing remote data works (path-exists handling) + and data integrity after overwrite. Round-trip is covered by + test_write_read_roundtrip_remote. 
+ """ + upath = get_upath(container=f"test-{storage_name}", path=f"overwrite-{uuid.uuid4().hex}.zarr") + full_sdata.write(upath, overwrite=True) + full_sdata.write(upath, overwrite=True) + _assert_read_identical(full_sdata, upath, check_path=False) + + @REMOTE_STORAGE_PARAMS + def test_write_element_to_remote_storage(self, full_sdata: SpatialData, get_upath, storage_name: str) -> None: + """Test writing individual elements to remote storage using write_element(). + + This test verifies that: + 1. Setting path to remote UPath works + 2. write_element() works with remote storage + 3. Written elements can be read back correctly + """ + upath = get_upath(container=f"test-{storage_name}", path=f"write-element-{uuid.uuid4().hex}.zarr") + # Create empty SpatialData and write to remote storage + empty_sdata = SpatialData() + empty_sdata.write(upath, overwrite=True) + full_sdata.path = upath + assert full_sdata.path == upath + # Write each element type individually + for _element_type, element_name, _ in full_sdata.gen_elements(): + full_sdata.write_element(element_name, overwrite=True) + _assert_read_identical(full_sdata, upath, check_path=False) diff --git a/tests/io/remote_storage/test_resolve_zarr_store.py b/tests/io/remote_storage/test_resolve_zarr_store.py new file mode 100644 index 000000000..d8c90d46d --- /dev/null +++ b/tests/io/remote_storage/test_resolve_zarr_store.py @@ -0,0 +1,55 @@ +"""Unit tests for remote-storage-specific store resolution and credential handling. + +Covers only code paths used when reading/writing from remote backends (Azure, S3, GCS): +- _FsspecStoreRoot resolution (used when reading elements from a remote zarr store). +- _storage_options_from_fs for Azure and GCS (used when writing parquet to remote). 
+""" + +from __future__ import annotations + +from zarr.storage import FsspecStore + +from spatialdata._io._utils import _FsspecStoreRoot, _resolve_zarr_store, _storage_options_from_fs + + +def test_resolve_zarr_store_fsspec_store_root() -> None: + """_FsspecStoreRoot is resolved to FsspecStore when reading from remote (e.g. points/shapes paths).""" + import fsspec + from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper + + fs = fsspec.filesystem("memory") + async_fs = AsyncFileSystemWrapper(fs, asynchronous=True) + base = FsspecStore(async_fs, path="/") + root = _FsspecStoreRoot(base, "/") + store = _resolve_zarr_store(root) + assert isinstance(store, FsspecStore) + + +def test_storage_options_from_fs_azure_account_key() -> None: + """_storage_options_from_fs extracts Azure credentials for writing parquet to remote Azure Blob.""" + + class AzureBlobFileSystemMock: + account_name = "dev" + account_key = "key123" + connection_string = None + anon = None + + AzureBlobFileSystemMock.__name__ = "AzureBlobFileSystem" + out = _storage_options_from_fs(AzureBlobFileSystemMock()) + assert out["account_name"] == "dev" + assert out["account_key"] == "key123" + + +def test_storage_options_from_fs_gcs_endpoint() -> None: + """_storage_options_from_fs extracts GCS endpoint and project for writing parquet to remote GCS.""" + + class GCSFileSystemMock: + token = "anon" + _endpoint = "http://localhost:4443" + project = "test" + + GCSFileSystemMock.__name__ = "GCSFileSystem" + out = _storage_options_from_fs(GCSFileSystemMock()) + assert out["token"] == "anon" + assert out["endpoint_url"] == "http://localhost:4443" + assert out["project"] == "test" diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index abaaea8d2..5c6bcf6e2 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -113,6 +113,8 @@ def test_set_table_nonexisting_target(self, full_sdata): def test_set_table_annotates_spatialelement(self, full_sdata, 
tmp_path): tmpdir = Path(tmp_path) / "tmp.zarr" del full_sdata["table"].uns[TableModel.ATTRS_KEY] + # full_sdata table has region labels2d+poly; set to labels2d only so set_table_annotates_spatialelement succeeds + full_sdata["table"].obs["region"] = pd.Categorical(["labels2d"] * full_sdata["table"].n_obs) with pytest.raises( TypeError, match="No current annotation metadata found. Please specify both region_key and instance_key." ):