Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions src/spatialexperiment/SpatialImage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
import biocutils as ut
import numpy as np
import requests
from PIL import Image
from PIL import Image, ImageChops

__author__ = "jkanche"
__copyright__ = "jkanche"
__author__ = "jkanche, keviny2"
__copyright__ = "jkanche, keviny2"
__license__ = "MIT"


Expand All @@ -24,6 +24,21 @@ class VirtualSpatialImage(ABC):
def __init__(self, metadata: Optional[dict] = None):
self._metadata = metadata if metadata is not None else {}

#########################
######>> Equality <<#####
#########################

def __eq__(self, other) -> bool:
if not isinstance(other, type(self)):
return False

return self.metadata == other.metadata

def __hash__(self):
# Note: This exists primarily to support lru_cache.
# Generally, these classes are mutable and shouldn't be used as dict keys or in sets.
return hash(frozenset(self._metadata.items()))

###########################
######>> metadata <<#######
###########################
Expand Down Expand Up @@ -151,6 +166,18 @@ def _define_output(self, in_place: bool = False) -> "LoadedSpatialImage":
else:
return self.__copy__()

#########################
######>> Equality <<#####
#########################

def __eq__(self, other) -> bool:
diff = ImageChops.difference(self.image, other.image)

return super().__eq__(other) and not diff.getbbox()

def __hash__(self):
return hash((super().__hash__(), self._image.tobytes()))

#########################
######>> Copying <<######
#########################
Expand Down Expand Up @@ -294,6 +321,16 @@ def _define_output(self, in_place: bool = False) -> "LoadedSpatialImage":
else:
return self.__copy__()

#########################
######>> Equality <<#####
#########################

def __eq__(self, other):
return super().__eq__(other) and self.path == other.path

def __hash__(self):
return hash((super().__hash__(), str(self._path)))

#########################
######>> Copying <<######
#########################
Expand Down Expand Up @@ -454,6 +491,16 @@ def _define_output(self, in_place: bool = False) -> "RemoteSpatialImage":
else:
return self.__copy__()

#########################
######>> Equality <<#####
#########################

def __eq__(self, other) -> bool:
return super().__eq__(other) and self.url == other.url

def __hash__(self):
return hash((super().__hash__(), self._url))

#########################
######>> Copying <<######
#########################
Expand Down
2 changes: 1 addition & 1 deletion tests/test_img_data_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_add_img(spe):
image_id="unsplash",
)

tspe.img_data.shape[0] == spe.img_data.shape[0] + 1
assert tspe.img_data.shape[0] == spe.img_data.shape[0] + 1


def test_add_img_already_exists(spe):
Expand Down
48 changes: 0 additions & 48 deletions tests/test_si.py

This file was deleted.

72 changes: 72 additions & 0 deletions tests/test_spi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from PIL import Image
from spatialexperiment import construct_spatial_image_class
from spatialexperiment.SpatialImage import VirtualSpatialImage, StoredSpatialImage, LoadedSpatialImage, RemoteSpatialImage

__author__ = "keviny2"
__copyright__ = "keviny2"
__license__ = "MIT"


def test_spi_constructor_path():
spi = construct_spatial_image_class("tests/images/sample_image1.jpg", is_url=False)

assert issubclass(type(spi), VirtualSpatialImage)
assert isinstance(spi, StoredSpatialImage)

assert "tests/images/sample_image1.jpg" in str(spi.path)


def test_spi_constructor_spi():
spi_1 = construct_spatial_image_class("tests/images/sample_image1.jpg", is_url=False)
spi_2 = construct_spatial_image_class(spi_1, is_url=False)

assert issubclass(type(spi_2), VirtualSpatialImage)
assert isinstance(spi_2, StoredSpatialImage)

assert str(spi_1.path) == str(spi_2.path)


def test_spi_constructor_image():
image = Image.open("tests/images/sample_image2.png")
spi = construct_spatial_image_class(image, is_url=False)

assert issubclass(type(spi), VirtualSpatialImage)
assert isinstance(spi, LoadedSpatialImage)

assert spi.image == image


def test_spi_constructor_url():
image_url = "https://i.redd.it/3pw5uah7xo041.jpg"
spi_remote = construct_spatial_image_class(image_url, is_url=True)
assert issubclass(type(spi_remote), VirtualSpatialImage)
assert isinstance(spi_remote, RemoteSpatialImage)
assert spi_remote.url == image_url


def test_invalid_input():
with pytest.raises(Exception):
construct_spatial_image_class(5, is_url=False)

def test_spi_equality():
spi_path_1 = construct_spatial_image_class("tests/images/sample_image1.jpg", is_url=False)
spi_path_2 = construct_spatial_image_class("tests/images/sample_image1.jpg", is_url=False)

assert spi_path_1 == spi_path_2

image_url = "https://i.redd.it/3pw5uah7xo041.jpg"
spi_url_1 = construct_spatial_image_class(image_url, is_url=True)
spi_url_2 = construct_spatial_image_class(image_url, is_url=True)

assert spi_url_1 == spi_url_2

image = Image.open("tests/images/sample_image2.png")
spi_image_1 = construct_spatial_image_class(image, is_url=False)
spi_image_2 = construct_spatial_image_class(image, is_url=False)

assert spi_image_1 == spi_image_2

assert spi_path_1 != spi_url_1
assert spi_path_1 != spi_image_1
assert spi_url_1 != spi_image_1
Loading