Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Changelog

## [Unreleased]
## Version 0.0.7
- Added `img_source` function in main SpatialExperiment class and child classes of VirtualSpatialExperiment (PR #36)
- Added `remove_img` function (PR #34)
- Refactored `get_img_idx` for improved maintainability
- Disambiguated `get_img_data` between `_imgutils.py` and `SpatialExperiment.py`
Expand Down
66 changes: 62 additions & 4 deletions src/spatialexperiment/SpatialExperiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,7 @@ def get_img(
sample_id: Union[str, bool, None] = None,
image_id: Union[str, bool, None] = None,
) -> Union[VirtualSpatialImage, List[VirtualSpatialImage]]:
"""
Retrieve spatial images based on the provided sample and image ids.
"""Retrieve spatial images based on the provided sample and image ids.

Args:
sample_id:
Expand All @@ -760,7 +759,9 @@ def get_img(
- `image_id="<str>"`: Matches image(s) by its(their) id.

Returns:
Zero, one, or more `VirtualSpatialImage` objects.
The image(s) matching the specified criteria. Returns `None` if `img_data` is `None`.
When a single image matches, returns a :py:class:`~VirtualSpatialImage` object.
When multiple images match, returns a list of :py:class:`~VirtualSpatialImage` objects.

Behavior:
- sample_id = True, image_id = True:
Expand All @@ -783,6 +784,9 @@ def get_img(

- sample_id = <str>, image_id = <str>:
Returns the image matching the specified sample and image identifiers.

Raises:
ValueError: If no row matches the provided sample_id and image_id pair.
"""
_validate_id(sample_id)
_validate_id(image_id)
Expand All @@ -792,6 +796,9 @@ def get_img(

indices = get_img_idx(img_data=self.img_data, sample_id=sample_id, image_id=image_id)

if len(indices) == 0:
raise ValueError(f"No matching rows for sample_id={sample_id} and image_id={image_id}")

images = self.img_data[indices,]["data"]
return images[0] if len(images) == 1 else images

Expand Down Expand Up @@ -834,6 +841,9 @@ def add_img(

Raises:
ValueError: If the sample_id and image_id pair already exists.

Note:
See :py:meth:`~get_img` for detailed behavior regarding sample_id and image_id parameters.
"""
_validate_sample_image_ids(img_data=self._img_data, new_sample_id=sample_id, new_image_id=image_id)

Expand Down Expand Up @@ -880,12 +890,24 @@ def remove_img(
in_place:
Whether to modify the ``SpatialExperiment`` in place.
Defaults to False.

Returns:
A modified ``SpatialExperiment`` object, either as a copy of the original or as a reference to the (in-place-modified) original.

Raises:
ValueError: If no row matches the provided sample_id and image_id pair.

Note:
See :py:meth:`~get_img` for detailed behavior regarding sample_id and image_id parameters.
"""
_validate_id(sample_id)
_validate_id(image_id)

indices = get_img_idx(img_data=self.img_data, sample_id=sample_id, image_id=image_id)

if len(indices) == 0:
raise ValueError(f"No matching rows for sample_id={sample_id} and image_id={image_id}")

new_img_data = self._img_data.remove_rows(indices)

output = self._define_output(in_place=in_place)
Expand All @@ -898,7 +920,43 @@ def img_source(
image_id: Union[str, bool, None] = None,
path=False,
):
raise NotImplementedError("This function is irrelevant because it is for `RemoteSpatialImages`")
"""Retrieve the source(s) for images stored in the SpatialExperiment object.

Args:
sample_id:
- `sample_id=True`: Matches all samples.
- `sample_id=None`: Matches the first sample.
- `sample_id="<str>"`: Matches a sample by its id.

image_id:
- `image_id=True`: Matches all images for the specified sample(s).
- `image_id=None`: Matches the first image for the sample(s).
- `image_id="<str>"`: Matches image(s) by its(their) id.

path: If True, returns path as string. Defaults to False.

Returns:
The image source(s) for the matching criteria. Returns `None` if `img_data` is `None`.
When a single image matches, returns its source as a `str`, `Path`, or `None`.
When multiple images match, returns a list of sources.

Raises:
ValueError: If no row matches the provided sample_id and image_id pair.

Note:
See :py:meth:`~get_img` for detailed behavior regarding sample_id and image_id parameters.
"""
spis = self.get_img(sample_id=sample_id, image_id=image_id)

if spis is None:
return None

if isinstance(spis, VirtualSpatialImage):
return spis.img_source(as_path=path)

img_sources = [spi.img_source(as_path=path) for spi in spis]

return img_sources

def img_raster(self, sample_id=None, image_id=None):
# NOTE: this function seems redundant, might be an artifact of the different subclasses of SpatialImage in the R implementation? just call `get_img()` for now
Expand Down
38 changes: 36 additions & 2 deletions src/spatialexperiment/SpatialImage.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ def dimensions(self) -> Tuple[int, int]:
######>> img utils <<#######
############################

@abstractmethod
def img_source(self, as_path: bool = False) -> Union[str, None]:
"""Get the source of the image.

Args:
as_path: If True, returns path as string. Defaults to False.

Returns:
Source path/URL of the image, or None if loaded in memory.
"""
pass

@abstractmethod
def img_raster(self) -> Image.Image:
"""Get the image as a PIL Image object."""
Expand Down Expand Up @@ -282,6 +294,14 @@ def image(self, image: Union[Image.Image, np.ndarray]):
)
return self.set_image(image=image, in_place=True)

def img_source(self, as_path: bool = False) -> None:
"""Get the source of the loaded image.

Returns:
Always returns None.
"""
return None

############################
######>> img utils <<#######
############################
Expand Down Expand Up @@ -440,7 +460,14 @@ def path(self, path: Union[str, Path]):
return self.set_path(path=path, in_place=True)

def img_source(self, as_path: bool = False) -> str:
"""Get the source path of the image."""
"""Get the source path of the image.

Args:
as_path: If True, returns string path. Defaults to False.

Returns:
Path to the image.
"""
return str(self._path) if as_path is True else self._path

############################
Expand Down Expand Up @@ -633,7 +660,14 @@ def img_raster(self) -> Image.Image:
return Image.open(cache_path)

def img_source(self, as_path: bool = False) -> str:
"""Get the source URL or cached path of the image."""
"""Get the source URL or cached path of the image.

Args:
as_path: If True, returns downloaded path. Defaults to False.

Returns:
URL or cached path of the image.
"""
if as_path:
return str(self._download_image())
return self._url
Expand Down
1 change: 1 addition & 0 deletions tests/test_get_img.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from copy import deepcopy
import numpy as np

from spatialexperiment.SpatialImage import VirtualSpatialImage

__author__ = "keviny2"
Expand Down
176 changes: 176 additions & 0 deletions tests/test_img_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import pytest
from copy import deepcopy
from pathlib import Path
from PIL import Image
import numpy as np
from spatialexperiment import construct_spatial_image_class
from spatialexperiment.SpatialImage import (
LoadedSpatialImage,
StoredSpatialImage,
RemoteSpatialImage
)

def test_loaded_spatial_image_img_source():
image = Image.open("tests/images/sample_image2.png")
spi_loaded = construct_spatial_image_class(image, is_url=False)

assert isinstance(spi_loaded, LoadedSpatialImage)
assert spi_loaded.img_source() is None
assert spi_loaded.img_source(as_path=True) is None

np_image = np.zeros((100, 100, 3), dtype=np.uint8)
spi_loaded_np = construct_spatial_image_class(np_image)

assert isinstance(spi_loaded_np, LoadedSpatialImage)
assert spi_loaded_np.img_source() is None
assert spi_loaded_np.img_source(as_path=True) is None


def test_stored_spatial_image_img_source():
image_path = "tests/images/sample_image1.jpg"
spi_stored = construct_spatial_image_class(image_path, is_url=False)

assert isinstance(spi_stored, StoredSpatialImage)

source_path = spi_stored.img_source()
assert isinstance(source_path, Path)
assert image_path in str(source_path)

source_str = spi_stored.img_source(as_path=True)
assert isinstance(source_str, str)
assert image_path in source_str

assert str(source_path) == str(spi_stored.path)


def test_remote_spatial_image_img_source():
image_url = "https://example.com/test_image.jpg"
spi_remote = construct_spatial_image_class(image_url, is_url=True)

assert isinstance(spi_remote, RemoteSpatialImage)

source = spi_remote.img_source()
assert isinstance(source, str)
assert source == image_url


def test_remote_spatial_image_img_source_with_mock(monkeypatch):
image_url = "https://example.com/test_image.jpg"
spi_remote = construct_spatial_image_class(image_url, is_url=True)

assert isinstance(spi_remote, RemoteSpatialImage)

# Mock the _download_image method to return a fixed path
mock_path = Path("/tmp/image.jpg")
monkeypatch.setattr(spi_remote, "_download_image", lambda: mock_path)

# Test with as_path=True (returns the cached path)
source_path = spi_remote.img_source(as_path=True)
assert source_path == str(mock_path)

# Test default behavior returns URL
assert spi_remote.img_source() == image_url


def test_img_source_no_img_data(spe):
tspe = deepcopy(spe)
tspe.img_data = None
assert not tspe.img_source()


def test_img_source_no_matches(spe):
with pytest.raises(ValueError):
sources = spe.img_source(sample_id="foo", image_id="foo")


def test_img_source_both_str(spe):
res = spe.img_source(sample_id="sample_1", image_id="dice")
expected_source = spe.get_img(sample_id="sample_1", image_id="dice").img_source()

assert isinstance(res, Path)
assert res == expected_source


def test_img_source_both_str_path(spe):
res = spe.img_source(sample_id="sample_1", image_id="dice", path=True)
expected_source = spe.get_img(sample_id="sample_1", image_id="dice").img_source(as_path=True)

assert isinstance(res, str)
assert res == expected_source


def test_img_source_both_true(spe):
res = spe.img_source(sample_id=True, image_id=True)
images = spe.get_img(sample_id=True, image_id=True)
expected_sources = [img.img_source() for img in images]

assert isinstance(res, list)
assert res == expected_sources


def test_img_source_both_true_path(spe):
res = spe.img_source(sample_id=True, image_id=True, path=True)
images = spe.get_img(sample_id=True, image_id=True)
expected_sources = [img.img_source(as_path=True) for img in images]

assert isinstance(res, list)
assert res == expected_sources


def test_img_source_both_none(spe):
res = spe.img_source(sample_id=None, image_id=None)
expected_source = spe.get_img(sample_id=None, image_id=None).img_source()

assert isinstance(res, Path)
assert res == expected_source


def test_img_source_sample_str_image_true(spe):
res = spe.img_source(sample_id="sample_1", image_id=True)
images = spe.get_img(sample_id="sample_1", image_id=True)
expected_sources = [img.img_source() for img in images]

assert isinstance(res, list)
assert res == expected_sources


def test_img_source_sample_true_image_str(spe):
res = spe.img_source(sample_id=True, image_id="desert")
expected_source = spe.get_img(sample_id=True, image_id="desert").img_source()

assert isinstance(res, Path)
assert res == expected_source


def test_img_source_sample_str_image_none(spe):
res = spe.img_source(sample_id="sample_1", image_id=None)
expected_source = spe.get_img(sample_id="sample_1", image_id=None).img_source()

assert isinstance(res, Path)
assert res == expected_source


def test_img_source_sample_none_image_str(spe):
res = spe.img_source(sample_id=None, image_id="aurora")
expected_source = spe.get_img(sample_id=None, image_id="aurora").img_source()

assert isinstance(res, Path)
assert res == expected_source


def test_img_source_sample_true_image_none(spe):
res = spe.img_source(sample_id=True, image_id=None)
images = spe.get_img(sample_id=True, image_id=None)
expected_sources = [img.img_source() for img in images]

assert isinstance(res, list)
assert res == expected_sources


def test_img_source_sample_none_image_true(spe):
res = spe.img_source(sample_id=None, image_id=True)
images = spe.get_img(sample_id=None, image_id=True)
expected_sources = [img.img_source() for img in images]

assert isinstance(res, list)
assert res == expected_sources
Loading