Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions src/spatialexperiment/SpatialExperiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ def img_source(
sample_id: Union[str, bool, None] = None,
image_id: Union[str, bool, None] = None,
path=False,
):
) -> Union[str, Path, None, List[Union[str, Path]]]:
"""Retrieve the source(s) for images stored in the SpatialExperiment object.

Args:
Expand Down Expand Up @@ -958,9 +958,42 @@ def img_source(

return img_sources

def img_raster(self, sample_id=None, image_id=None):
# NOTE: this function seems redundant, might be an artifact of the different subclasses of SpatialImage in the R implementation? just call `get_img()` for now
self.get_img(sample_id=sample_id, image_id=image_id)
def img_raster(self, sample_id=None, image_id=None) -> Union[Image.Image, List[Image.Image], None]:
"""Retrieve and load (if necessary) the images stored in the SpatialExperiment object.

Args:
sample_id:
- `sample_id=True`: Matches all samples.
- `sample_id=None`: Matches the first sample.
- `sample_id="<str>"`: Matches a sample by its id.

image_id:
- `image_id=True`: Matches all images for the specified sample(s).
- `image_id=None`: Matches the first image for the sample(s).
- `image_id="<str>"`: Matches image(s) by its(their) id.

Returns:
The loaded image(s) for the matching criteria. Returns `None` if `img_data` is `None`.
When a single image matches, returns its loaded image.
When multiple images match, returns a list of loaded images.

Raises:
ValueError: If no row matches the provided sample_id and image_id pair.

Note:
See :py:meth:`~get_img` for detailed behavior regarding sample_id and image_id parameters.
"""
spis = self.get_img(sample_id=sample_id, image_id=image_id)

if spis is None:
return None

if isinstance(spis, VirtualSpatialImage):
return spis.img_raster()

img_rasters = [spi.img_raster() for spi in spis]

return img_rasters

def rotate_img(self, sample_id=None, image_id=None, degrees=90):
raise NotImplementedError()
Expand Down
148 changes: 148 additions & 0 deletions tests/test_img_raster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import pytest
from copy import deepcopy
from PIL import Image
import numpy as np
from spatialexperiment import construct_spatial_image_class
from spatialexperiment.SpatialImage import (
LoadedSpatialImage,
StoredSpatialImage,
RemoteSpatialImage
)

def test_loaded_spatial_image_img_raster():
image = Image.open("tests/images/sample_image2.png")
spi_loaded = construct_spatial_image_class(image, is_url=False)
raster = spi_loaded.img_raster()

assert isinstance(spi_loaded, LoadedSpatialImage)
assert isinstance(raster, Image.Image)

np_image = np.zeros((100, 100, 3), dtype=np.uint8)
spi_loaded_np = construct_spatial_image_class(np_image)
raster_np = spi_loaded_np.img_raster()

assert isinstance(spi_loaded_np, LoadedSpatialImage)
assert isinstance(raster_np, Image.Image)


def test_stored_spatial_image_img_raster():
image_path = "tests/images/sample_image1.jpg"
spi_stored = construct_spatial_image_class(image_path, is_url=False)
raster = spi_stored.img_raster()

assert isinstance(spi_stored, StoredSpatialImage)
assert isinstance(raster, Image.Image)


def test_remote_spatial_image_img_raster(monkeypatch):
image_url = "https://example.com/test_image.jpg"
spi_remote = construct_spatial_image_class(image_url, is_url=True)

# Mock the _download_image method to return an image
mock_path = "tests/images/sample_image1.jpg"
monkeypatch.setattr(spi_remote, "_download_image", lambda: mock_path)

raster = spi_remote.img_raster()

assert isinstance(spi_remote, RemoteSpatialImage)
assert isinstance(raster, Image.Image)

# Test LRU cache works as expected
num_calls = 0

def mock_download():
num_calls += 1
return mock_path

monkeypatch.setattr(spi_remote, "_download_image", mock_download)

raster2 = spi_remote.img_raster()
assert num_calls == 0
assert raster2 is raster


def test_img_raster_no_img_data(spe):
tspe = deepcopy(spe)
tspe.img_data = None
assert not tspe.img_raster()


def test_img_raster_no_matches(spe):
with pytest.raises(ValueError):
res = spe.img_raster(sample_id="foo", image_id="foo")


def test_img_raster_both_str(spe):
res = spe.img_raster(sample_id="sample_1", image_id="dice")
expected_raster = spe.get_img(sample_id="sample_1", image_id="dice").img_raster()

assert isinstance(res, Image.Image)
assert res == expected_raster


def test_img_raster_both_true(spe):
res = spe.img_raster(sample_id=True, image_id=True)
images = spe.get_img(sample_id=True, image_id=True)
expected_rasters = [img.img_raster() for img in images]

assert isinstance(res, list)
assert res == expected_rasters


def test_img_raster_both_none(spe):
res = spe.img_raster(sample_id=None, image_id=None)
expected_raster = spe.get_img(sample_id=None, image_id=None).img_raster()

assert isinstance(res, Image.Image)
assert res == expected_raster


def test_img_raster_sample_str_image_true(spe):
res = spe.img_raster(sample_id="sample_1", image_id=True)
images = spe.get_img(sample_id="sample_1", image_id=True)
expected_rasters = [img.img_raster() for img in images]

assert isinstance(res, list)
assert res == expected_rasters


def test_img_raster_sample_true_image_str(spe):
res = spe.img_raster(sample_id=True, image_id="desert")
expected_raster = spe.get_img(sample_id=True, image_id="desert").img_raster()

assert isinstance(res, Image.Image)
assert res == expected_raster


def test_img_raster_sample_str_image_none(spe):
res = spe.img_raster(sample_id="sample_1", image_id=None)
expected_raster = spe.get_img(sample_id="sample_1", image_id=None).img_raster()

assert isinstance(res, Image.Image)
assert res == expected_raster


def test_img_raster_sample_none_image_str(spe):
res = spe.img_raster(sample_id=None, image_id="aurora")
expected_raster = spe.get_img(sample_id=None, image_id="aurora").img_raster()

assert isinstance(res, Image.Image)
assert res == expected_raster


def test_img_raster_sample_true_image_none(spe):
res = spe.img_raster(sample_id=True, image_id=None)
images = spe.get_img(sample_id=True, image_id=None)
expected_rasters = [img.img_raster() for img in images]

assert isinstance(res, list)
assert res == expected_rasters


def test_img_raster_sample_none_image_true(spe):
res = spe.img_raster(sample_id=None, image_id=True)
images = spe.get_img(sample_id=None, image_id=True)
expected_rasters = [img.img_raster() for img in images]

assert isinstance(res, list)
assert res == expected_rasters