diff --git a/CHANGELOG.md b/CHANGELOG.md index ff8c6d5..13b1d68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Changelog -## [Unreleased] +## Version 0.0.7 +- Added `img_source` function in main SpatialExperiment class and child classes of VirtualSpatialExperiment (PR #36) - Added `remove_img` function (PR #34) - Refactored `get_img_idx` for improved maintainability - Disambiguated `get_img_data` between `_imgutils.py` and `SpatialExperiment.py` diff --git a/src/spatialexperiment/SpatialExperiment.py b/src/spatialexperiment/SpatialExperiment.py index 715c519..c289fe5 100644 --- a/src/spatialexperiment/SpatialExperiment.py +++ b/src/spatialexperiment/SpatialExperiment.py @@ -745,8 +745,7 @@ def get_img( sample_id: Union[str, bool, None] = None, image_id: Union[str, bool, None] = None, ) -> Union[VirtualSpatialImage, List[VirtualSpatialImage]]: - """ - Retrieve spatial images based on the provided sample and image ids. + """Retrieve spatial images based on the provided sample and image ids. Args: sample_id: @@ -760,7 +759,9 @@ def get_img( - `image_id=""`: Matches image(s) by its(their) id. Returns: - Zero, one, or more `VirtualSpatialImage` objects. + The image(s) matching the specified criteria. Returns `None` if `img_data` is `None`. + When a single image matches, returns a :py:class:`~VirtualSpatialImage` object. + When multiple images match, returns a list of :py:class:`~VirtualSpatialImage` objects. Behavior: - sample_id = True, image_id = True: @@ -783,6 +784,9 @@ def get_img( - sample_id = , image_id = : Returns the image matching the specified sample and image identifiers. + + Raises: + ValueError: If no row matches the provided sample_id and image_id pair. """ _validate_id(sample_id) _validate_id(image_id) @@ -792,6 +796,9 @@ def get_img( indices = get_img_idx(img_data=self.img_data, sample_id=sample_id, image_id=image_id) + if len(indices) == 0: + raise ValueError(f"No matching rows for sample_id={sample_id} and image_id={image_id}") + images = self.img_data[indices,]["data"] return images[0] if len(images) == 1 else images @@ -834,6 +841,9 @@ def add_img( Raises: ValueError: If the sample_id and image_id pair already exists. + + Note: + See :py:meth:`~get_img` for detailed behavior regarding sample_id and image_id parameters. """ _validate_sample_image_ids(img_data=self._img_data, new_sample_id=sample_id, new_image_id=image_id) @@ -880,12 +890,24 @@ def remove_img( in_place: Whether to modify the ``SpatialExperiment`` in place. Defaults to False. + + Returns: + A modified ``SpatialExperiment`` object, either as a copy of the original or as a reference to the (in-place-modified) original. + + Raises: + ValueError: If no row matches the provided sample_id and image_id pair. + + Note: + See :py:meth:`~get_img` for detailed behavior regarding sample_id and image_id parameters. """ _validate_id(sample_id) _validate_id(image_id) indices = get_img_idx(img_data=self.img_data, sample_id=sample_id, image_id=image_id) + if len(indices) == 0: + raise ValueError(f"No matching rows for sample_id={sample_id} and image_id={image_id}") + new_img_data = self._img_data.remove_rows(indices) output = self._define_output(in_place=in_place) @@ -898,7 +920,43 @@ def img_source( image_id: Union[str, bool, None] = None, path=False, ): - raise NotImplementedError("This function is irrelevant because it is for `RemoteSpatialImages`") + """Retrieve the source(s) for images stored in the SpatialExperiment object. + + Args: + sample_id: + - `sample_id=True`: Matches all samples. + - `sample_id=None`: Matches the first sample. + - `sample_id=""`: Matches a sample by its id. + + image_id: + - `image_id=True`: Matches all images for the specified sample(s). + - `image_id=None`: Matches the first image for the sample(s). + - `image_id=""`: Matches image(s) by its(their) id. + + path: If True, returns path as string. Defaults to False. + + Returns: + The image source(s) for the matching criteria. Returns `None` if `img_data` is `None`. + When a single image matches, returns its source as a `str`, `Path`, or `None`. + When multiple images match, returns a list of sources. + + Raises: + ValueError: If no row matches the provided sample_id and image_id pair. + + Note: + See :py:meth:`~get_img` for detailed behavior regarding sample_id and image_id parameters. + """ + spis = self.get_img(sample_id=sample_id, image_id=image_id) + + if spis is None: + return None + + if isinstance(spis, VirtualSpatialImage): + return spis.img_source(as_path=path) + + img_sources = [spi.img_source(as_path=path) for spi in spis] + + return img_sources def img_raster(self, sample_id=None, image_id=None): # NOTE: this function seems redundant, might be an artifact of the different subclasses of SpatialImage in the R implementation? just call `get_img()` for now diff --git a/src/spatialexperiment/SpatialImage.py b/src/spatialexperiment/SpatialImage.py index 5ac1266..852af7d 100644 --- a/src/spatialexperiment/SpatialImage.py +++ b/src/spatialexperiment/SpatialImage.py @@ -105,6 +105,18 @@ def dimensions(self) -> Tuple[int, int]: ######>> img utils <<####### ############################ + @abstractmethod + def img_source(self, as_path: bool = False) -> Union[str, None]: + """Get the source of the image. + + Args: + as_path: If True, returns path as string. Defaults to False. + + Returns: + Source path/URL of the image, or None if loaded in memory. + """ + pass + @abstractmethod def img_raster(self) -> Image.Image: """Get the image as a PIL Image object.""" @@ -282,6 +294,14 @@ def image(self, image: Union[Image.Image, np.ndarray]): ) return self.set_image(image=image, in_place=True) + def img_source(self, as_path: bool = False) -> None: + """Get the source of the loaded image. + + Returns: + Always returns None. + """ + return None + ############################ ######>> img utils <<####### ############################ @@ -440,7 +460,14 @@ def path(self, path: Union[str, Path]): return self.set_path(path=path, in_place=True) def img_source(self, as_path: bool = False) -> str: - """Get the source path of the image.""" + """Get the source path of the image. + + Args: + as_path: If True, returns string path. Defaults to False. + + Returns: + Path to the image. + """ return str(self._path) if as_path is True else self._path ############################ @@ -633,7 +660,14 @@ def img_raster(self) -> Image.Image: return Image.open(cache_path) def img_source(self, as_path: bool = False) -> str: - """Get the source URL or cached path of the image.""" + """Get the source URL or cached path of the image. + + Args: + as_path: If True, returns downloaded path. Defaults to False. + + Returns: + URL or cached path of the image. + """ if as_path: return str(self._download_image()) return self._url diff --git a/tests/test_get_img.py b/tests/test_get_img.py index 2c66405..6da0bc1 100644 --- a/tests/test_get_img.py +++ b/tests/test_get_img.py @@ -1,6 +1,7 @@ import pytest from copy import deepcopy import numpy as np + from spatialexperiment.SpatialImage import VirtualSpatialImage __author__ = "keviny2" diff --git a/tests/test_img_source.py b/tests/test_img_source.py new file mode 100644 index 0000000..4346924 --- /dev/null +++ b/tests/test_img_source.py @@ -0,0 +1,176 @@ +import pytest +from copy import deepcopy +from pathlib import Path +from PIL import Image +import numpy as np +from spatialexperiment import construct_spatial_image_class +from spatialexperiment.SpatialImage import ( + LoadedSpatialImage, + StoredSpatialImage, + RemoteSpatialImage +) + +def test_loaded_spatial_image_img_source(): + image = Image.open("tests/images/sample_image2.png") + spi_loaded = construct_spatial_image_class(image, is_url=False) + + assert isinstance(spi_loaded, LoadedSpatialImage) + assert spi_loaded.img_source() is None + assert spi_loaded.img_source(as_path=True) is None + + np_image = np.zeros((100, 100, 3), dtype=np.uint8) + spi_loaded_np = construct_spatial_image_class(np_image) + + assert isinstance(spi_loaded_np, LoadedSpatialImage) + assert spi_loaded_np.img_source() is None + assert spi_loaded_np.img_source(as_path=True) is None + + +def test_stored_spatial_image_img_source(): + image_path = "tests/images/sample_image1.jpg" + spi_stored = construct_spatial_image_class(image_path, is_url=False) + + assert isinstance(spi_stored, StoredSpatialImage) + + source_path = spi_stored.img_source() + assert isinstance(source_path, Path) + assert image_path in str(source_path) + + source_str = spi_stored.img_source(as_path=True) + assert isinstance(source_str, str) + assert image_path in source_str + + assert str(source_path) == str(spi_stored.path) + + +def test_remote_spatial_image_img_source(): + image_url = "https://example.com/test_image.jpg" + spi_remote = construct_spatial_image_class(image_url, is_url=True) + + assert isinstance(spi_remote, RemoteSpatialImage) + + source = spi_remote.img_source() + assert isinstance(source, str) + assert source == image_url + + +def test_remote_spatial_image_img_source_with_mock(monkeypatch): + image_url = "https://example.com/test_image.jpg" + spi_remote = construct_spatial_image_class(image_url, is_url=True) + + assert isinstance(spi_remote, RemoteSpatialImage) + + # Mock the _download_image method to return a fixed path + mock_path = Path("/tmp/image.jpg") + monkeypatch.setattr(spi_remote, "_download_image", lambda: mock_path) + + # Test with as_path=True (returns the cached path) + source_path = spi_remote.img_source(as_path=True) + assert source_path == str(mock_path) + + # Test default behavior returns URL + assert spi_remote.img_source() == image_url + + +def test_img_source_no_img_data(spe): + tspe = deepcopy(spe) + tspe.img_data = None + assert not tspe.img_source() + + +def test_img_source_no_matches(spe): + with pytest.raises(ValueError): + sources = spe.img_source(sample_id="foo", image_id="foo") + + +def test_img_source_both_str(spe): + res = spe.img_source(sample_id="sample_1", image_id="dice") + expected_source = spe.get_img(sample_id="sample_1", image_id="dice").img_source() + + assert isinstance(res, Path) + assert res == expected_source + + +def test_img_source_both_str_path(spe): + res = spe.img_source(sample_id="sample_1", image_id="dice", path=True) + expected_source = spe.get_img(sample_id="sample_1", image_id="dice").img_source(as_path=True) + + assert isinstance(res, str) + assert res == expected_source + + +def test_img_source_both_true(spe): + res = spe.img_source(sample_id=True, image_id=True) + images = spe.get_img(sample_id=True, image_id=True) + expected_sources = [img.img_source() for img in images] + + assert isinstance(res, list) + assert res == expected_sources + + +def test_img_source_both_true_path(spe): + res = spe.img_source(sample_id=True, image_id=True, path=True) + images = spe.get_img(sample_id=True, image_id=True) + expected_sources = [img.img_source(as_path=True) for img in images] + + assert isinstance(res, list) + assert res == expected_sources + + +def test_img_source_both_none(spe): + res = spe.img_source(sample_id=None, image_id=None) + expected_source = spe.get_img(sample_id=None, image_id=None).img_source() + + assert isinstance(res, Path) + assert res == expected_source + + +def test_img_source_sample_str_image_true(spe): + res = spe.img_source(sample_id="sample_1", image_id=True) + images = spe.get_img(sample_id="sample_1", image_id=True) + expected_sources = [img.img_source() for img in images] + + assert isinstance(res, list) + assert res == expected_sources + + +def test_img_source_sample_true_image_str(spe): + res = spe.img_source(sample_id=True, image_id="desert") + expected_source = spe.get_img(sample_id=True, image_id="desert").img_source() + + assert isinstance(res, Path) + assert res == expected_source + + +def test_img_source_sample_str_image_none(spe): + res = spe.img_source(sample_id="sample_1", image_id=None) + expected_source = spe.get_img(sample_id="sample_1", image_id=None).img_source() + + assert isinstance(res, Path) + assert res == expected_source + + +def test_img_source_sample_none_image_str(spe): + res = spe.img_source(sample_id=None, image_id="aurora") + expected_source = spe.get_img(sample_id=None, image_id="aurora").img_source() + + assert isinstance(res, Path) + assert res == expected_source + + +def test_img_source_sample_true_image_none(spe): + res = spe.img_source(sample_id=True, image_id=None) + images = spe.get_img(sample_id=True, image_id=None) + expected_sources = [img.img_source() for img in images] + + assert isinstance(res, list) + assert res == expected_sources + + +def test_img_source_sample_none_image_true(spe): + res = spe.img_source(sample_id=None, image_id=True) + images = spe.get_img(sample_id=None, image_id=True) + expected_sources = [img.img_source() for img in images] + + assert isinstance(res, list) + assert res == expected_sources diff --git a/tests/test_spi.py b/tests/test_spi.py index f7e93c5..f5a8ad4 100644 --- a/tests/test_spi.py +++ b/tests/test_spi.py @@ -45,16 +45,18 @@ def test_spi_constructor_image(): def test_spi_constructor_url(): - image_url = "https://i.redd.it/3pw5uah7xo041.jpg" + image_url = "https://example.com/test_image.jpg" spi_remote = construct_spatial_image_class(image_url, is_url=True) assert issubclass(type(spi_remote), VirtualSpatialImage) assert isinstance(spi_remote, RemoteSpatialImage) assert spi_remote.url == image_url -def test_invalid_input(): - with pytest.raises(Exception): - construct_spatial_image_class(5, is_url=False) +def test_auto_detect_url(): + url = "https://example.com/image.jpg" + img = construct_spatial_image_class(url) + assert isinstance(img, RemoteSpatialImage) + assert img.img_source() == url def test_spi_equality(): @@ -67,7 +69,7 @@ def test_spi_equality(): assert spi_path_1 == spi_path_2 - image_url = "https://i.redd.it/3pw5uah7xo041.jpg" + image_url = "https://example.com/test_image.jpg" spi_url_1 = construct_spatial_image_class(image_url, is_url=True) spi_url_2 = construct_spatial_image_class(image_url, is_url=True) @@ -82,3 +84,8 @@ def test_spi_equality(): assert spi_path_1 != spi_url_1 assert spi_path_1 != spi_image_1 assert spi_url_1 != spi_image_1 + + +def test_invalid_input(): + with pytest.raises(Exception): + construct_spatial_image_class(5, is_url=False)