diff --git a/src/spatialexperiment/SpatialExperiment.py b/src/spatialexperiment/SpatialExperiment.py index c289fe5..a5f66e6 100644 --- a/src/spatialexperiment/SpatialExperiment.py +++ b/src/spatialexperiment/SpatialExperiment.py @@ -919,7 +919,7 @@ def img_source( sample_id: Union[str, bool, None] = None, image_id: Union[str, bool, None] = None, path=False, - ): + ) -> Union[str, Path, None, List[Union[str, Path]]]: """Retrieve the source(s) for images stored in the SpatialExperiment object. Args: @@ -958,9 +958,42 @@ def img_source( return img_sources - def img_raster(self, sample_id=None, image_id=None): - # NOTE: this function seems redundant, might be an artifact of the different subclasses of SpatialImage in the R implementation? just call `get_img()` for now - self.get_img(sample_id=sample_id, image_id=image_id) + def img_raster(self, sample_id=None, image_id=None) -> Union[Image.Image, List[Image.Image], None]: + """Retrieve and load (if necessary) the images stored in the SpatialExperiment object. + + Args: + sample_id: + - `sample_id=True`: Matches all samples. + - `sample_id=None`: Matches the first sample. + - `sample_id=""`: Matches a sample by its id. + + image_id: + - `image_id=True`: Matches all images for the specified sample(s). + - `image_id=None`: Matches the first image for the sample(s). + - `image_id=""`: Matches image(s) by its(their) id. + + Returns: + The loaded image(s) for the matching criteria. Returns `None` if `img_data` is `None`. + When a single image matches, returns its loaded image. + When multiple images match, returns a list of loaded images. + + Raises: + ValueError: If no row matches the provided sample_id and image_id pair. + + Note: + See :py:meth:`~get_img` for detailed behavior regarding sample_id and image_id parameters. + """ + spis = self.get_img(sample_id=sample_id, image_id=image_id) + + if spis is None: + return None + + if isinstance(spis, VirtualSpatialImage): + return spis.img_raster() + + img_rasters = [spi.img_raster() for spi in spis] + + return img_rasters def rotate_img(self, sample_id=None, image_id=None, degrees=90): raise NotImplementedError() diff --git a/tests/test_img_raster.py b/tests/test_img_raster.py new file mode 100644 index 0000000..c53d153 --- /dev/null +++ b/tests/test_img_raster.py @@ -0,0 +1,148 @@ +import pytest +from copy import deepcopy +from PIL import Image +import numpy as np +from spatialexperiment import construct_spatial_image_class +from spatialexperiment.SpatialImage import ( + LoadedSpatialImage, + StoredSpatialImage, + RemoteSpatialImage +) + +def test_loaded_spatial_image_img_raster(): + image = Image.open("tests/images/sample_image2.png") + spi_loaded = construct_spatial_image_class(image, is_url=False) + raster = spi_loaded.img_raster() + + assert isinstance(spi_loaded, LoadedSpatialImage) + assert isinstance(raster, Image.Image) + + np_image = np.zeros((100, 100, 3), dtype=np.uint8) + spi_loaded_np = construct_spatial_image_class(np_image) + raster_np = spi_loaded_np.img_raster() + + assert isinstance(spi_loaded_np, LoadedSpatialImage) + assert isinstance(raster_np, Image.Image) + + +def test_stored_spatial_image_img_raster(): + image_path = "tests/images/sample_image1.jpg" + spi_stored = construct_spatial_image_class(image_path, is_url=False) + raster = spi_stored.img_raster() + + assert isinstance(spi_stored, StoredSpatialImage) + assert isinstance(raster, Image.Image) + + +def test_remote_spatial_image_img_raster(monkeypatch): + image_url = "https://example.com/test_image.jpg" + spi_remote = construct_spatial_image_class(image_url, is_url=True) + + # Mock the _download_image method to return an image + mock_path = "tests/images/sample_image1.jpg" + monkeypatch.setattr(spi_remote, "_download_image", lambda: mock_path) + + raster = spi_remote.img_raster() + + assert isinstance(spi_remote, RemoteSpatialImage) + assert isinstance(raster, Image.Image) + + # Test LRU cache works as expected + num_calls = 0 + + def mock_download(): + num_calls += 1 + return mock_path + + monkeypatch.setattr(spi_remote, "_download_image", mock_download) + + raster2 = spi_remote.img_raster() + assert num_calls == 0 + assert raster2 is raster + + +def test_img_raster_no_img_data(spe): + tspe = deepcopy(spe) + tspe.img_data = None + assert not tspe.img_raster() + + +def test_img_raster_no_matches(spe): + with pytest.raises(ValueError): + res = spe.img_raster(sample_id="foo", image_id="foo") + + +def test_img_raster_both_str(spe): + res = spe.img_raster(sample_id="sample_1", image_id="dice") + expected_raster = spe.get_img(sample_id="sample_1", image_id="dice").img_raster() + + assert isinstance(res, Image.Image) + assert res == expected_raster + + +def test_img_raster_both_true(spe): + res = spe.img_raster(sample_id=True, image_id=True) + images = spe.get_img(sample_id=True, image_id=True) + expected_rasters = [img.img_raster() for img in images] + + assert isinstance(res, list) + assert res == expected_rasters + + +def test_img_raster_both_none(spe): + res = spe.img_raster(sample_id=None, image_id=None) + expected_raster = spe.get_img(sample_id=None, image_id=None).img_raster() + + assert isinstance(res, Image.Image) + assert res == expected_raster + + +def test_img_raster_sample_str_image_true(spe): + res = spe.img_raster(sample_id="sample_1", image_id=True) + images = spe.get_img(sample_id="sample_1", image_id=True) + expected_rasters = [img.img_raster() for img in images] + + assert isinstance(res, list) + assert res == expected_rasters + + +def test_img_raster_sample_true_image_str(spe): + res = spe.img_raster(sample_id=True, image_id="desert") + expected_raster = spe.get_img(sample_id=True, image_id="desert").img_raster() + + assert isinstance(res, Image.Image) + assert res == expected_raster + + +def test_img_raster_sample_str_image_none(spe): + res = spe.img_raster(sample_id="sample_1", image_id=None) + expected_raster = spe.get_img(sample_id="sample_1", image_id=None).img_raster() + + assert isinstance(res, Image.Image) + assert res == expected_raster + + +def test_img_raster_sample_none_image_str(spe): + res = spe.img_raster(sample_id=None, image_id="aurora") + expected_raster = spe.get_img(sample_id=None, image_id="aurora").img_raster() + + assert isinstance(res, Image.Image) + assert res == expected_raster + + +def test_img_raster_sample_true_image_none(spe): + res = spe.img_raster(sample_id=True, image_id=None) + images = spe.get_img(sample_id=True, image_id=None) + expected_rasters = [img.img_raster() for img in images] + + assert isinstance(res, list) + assert res == expected_rasters + + +def test_img_raster_sample_none_image_true(spe): + res = spe.img_raster(sample_id=None, image_id=True) + images = spe.get_img(sample_id=None, image_id=True) + expected_rasters = [img.img_raster() for img in images] + + assert isinstance(res, list) + assert res == expected_rasters