diff --git a/docs/examples/KFe2As2-00838.tif b/docs/examples/data/KFe2As2-00838.tif similarity index 100% rename from docs/examples/KFe2As2-00838.tif rename to docs/examples/data/KFe2As2-00838.tif diff --git a/news/load-image-fix.rst b/news/load-image-fix.rst new file mode 100644 index 0000000..cbbb56e --- /dev/null +++ b/news/load-image-fix.rst @@ -0,0 +1,23 @@ +**Added:** + +* + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* `load_image()` function correctly finds files when passed a relative path. + +**Security:** + +* diff --git a/src/diffpy/srxplanar/loadimage.py b/src/diffpy/srxplanar/loadimage.py index b30e97e..88143ea 100644 --- a/src/diffpy/srxplanar/loadimage.py +++ b/src/diffpy/srxplanar/loadimage.py @@ -16,6 +16,7 @@ import fnmatch import os import time +from pathlib import Path import numpy as np @@ -24,14 +25,14 @@ try: import fabio - def openImage(im): + def open_image(im): rv = fabio.openimage.openimage(im) return rv.data except ImportError: import tifffile - def openImage(im): + def open_image(im): rv = tifffile.imread(im) return rv @@ -53,7 +54,7 @@ def __init__(self, p): self.config = p return - def flipImage(self, pic): + def flip_image(self, pic): """Flip image if configured in config. :param pic: 2d array, image array @@ -65,32 +66,41 @@ def flipImage(self, pic): pic = np.array(pic[::-1, :]) return pic - def loadImage(self, filename): - """Load image file, if failed (for example loading an incomplete - file), then it will keep trying loading file for 5s. + def load_image(self, filename): + """Load image file. If loading fails (e.g. incomplete file), + retry for 5 seconds (10×0.5s). - :param filename: str, image file name - :return: 2d ndarray, 2d image array (flipped) + :param filename: str or Path, image file name or path + :return: 2D ndarray, flipped image array """ - if os.path.exists(filename): + filename = Path( + filename + ).expanduser() # handle "~", make it a Path object + if filename.exists(): filenamefull = filename else: - filenamefull = os.path.join(self.opendirectory, filename) - image = np.zeros(10000).reshape(100, 100) - if os.path.exists(filenamefull): - i = 0 - while i < 10: - try: - if os.path.splitext(filenamefull)[-1] == ".npy": - image = np.load(filenamefull) - else: - image = openImage(filenamefull) - i = 10 - except FileNotFoundError: - i = i + 1 - time.sleep(0.5) - image = self.flipImage(image) - image[image < 0] = 0 + found_files = list(Path.home().rglob(filename.name)) + filenamefull = found_files[0] if found_files else None + + if filenamefull is None or not filenamefull.exists(): + raise FileNotFoundError( + f"Error: file not found: {filename}, " + f"Please rerun specifying a valid filename." + ) + return np.zeros((100, 100)) + + image = np.zeros((100, 100)) + for _ in range(10): # retry 10 times (5 seconds total) + try: + if filenamefull.suffix == ".npy": + image = np.load(filenamefull) + else: + image = open_image(filenamefull) + break + except FileNotFoundError: + time.sleep(0.5) + image = self.flip_image(image) + image[image < 0] = 0 return image def genFileList( diff --git a/src/diffpy/srxplanar/srxplanar.py b/src/diffpy/srxplanar/srxplanar.py index d4c025d..261766c 100644 --- a/src/diffpy/srxplanar/srxplanar.py +++ b/src/diffpy/srxplanar/srxplanar.py @@ -155,7 +155,7 @@ def _getPic(self, image, flip=None, correction=None): rv += self._getPic(imagefile) rv /= len(image) elif isinstance(image, str): - rv = self.loadimage.loadImage(image) + rv = self.loadimage.load_image(image) if correction is None or correction is True: ce = self.config.cropedges rv[ce[2] : -ce[3], ce[0] : -ce[1]] = ( @@ -165,7 +165,7 @@ def _getPic(self, image, flip=None, correction=None): else: rv = image if flip is True: - rv = self.loadimage.flipImage(rv) + rv = self.loadimage.flip_image(rv) if correction is True: # rv *= self.correction ce = self.config.cropedges @@ -339,13 +339,13 @@ def createMask(self, filename=None, pic=None, addmask=None): pic = self.pic else: pic = ( - self.loadimage.loadImage(filelist[0]) + self.loadimage.load_image(filelist[0]) if len(filelist) > 0 else None ) else: pic = ( - self.loadimage.loadImage(filelist[0]) + self.loadimage.load_image(filelist[0]) if len(filelist) > 0 else None ) diff --git a/tests/conftest.py b/tests/conftest.py index e3b6313..ca56905 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,13 +7,22 @@ @pytest.fixture def user_filesystem(tmp_path): base_dir = Path(tmp_path) - home_dir = base_dir / "home_dir" - home_dir.mkdir(parents=True, exist_ok=True) cwd_dir = base_dir / "cwd_dir" - cwd_dir.mkdir(parents=True, exist_ok=True) + home_dir = base_dir / "home_dir" + test_dir = base_dir / "test_dir" + for dir in (cwd_dir, home_dir, test_dir): + dir.mkdir(parents=True, exist_ok=True) - home_config_data = {"username": "home_username", "email": "home@email.com"} + home_config_data = { + "username": "home_username", + "email": "home@email.com", + } with open(home_dir / "diffpyconfig.json", "w") as f: json.dump(home_config_data, f) - yield tmp_path + yield { + "base": base_dir, + "cwd": cwd_dir, + "home": home_dir, + "test": test_dir, + } diff --git a/tests/test_load_image.py b/tests/test_load_image.py new file mode 100644 index 0000000..3324140 --- /dev/null +++ b/tests/test_load_image.py @@ -0,0 +1,66 @@ +import os +import shutil +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from diffpy.srxplanar.loadimage import LoadImage + +PROJECT_ROOT = Path(__file__).resolve().parents[1] + +load_image_param = [ + # case 1: just filename of file in current directory. + # expect function loads tiff file from cwd + (["KFe2As2-00838.tif", False, False], [0, 26, 173]), + # case 2: absolute file path to file in another directory. + # expect file is found and correctly read. + ( + ["home_dir/KFe2As2-00838.tif", True, False], + [102, 57, 136], + ), + # case 3: relative file path to file in another directory. + # expect file is found and correctly read + (["./KFe2As2-00838.tif", False, True], [39, 7, 0]), + # case 4: non-existent file that incurred by mistype. + ( + ["nonexistent_file.tif", False, False], + FileNotFoundError, + ), + # case 5: relative file path to file in another directory. + # expect file to be flip both horizontally and vertically + # and correctly read + (["./KFe2As2-00838.tif", True, True], [0, 53, 21]), +] + + +@pytest.mark.parametrize("inputs, expected", load_image_param) +def test_load_image(inputs, expected, user_filesystem): + home_dir = user_filesystem["home"] + cwd_dir = user_filesystem["cwd"] + os.chdir(cwd_dir) + + expected_mean = 2595.7087 + expected_shape = (2048, 2048) + + # locate source example file inside project docs + source_file = ( + PROJECT_ROOT / "docs" / "examples" / "data" / "KFe2As2-00838.tif" + ) + shutil.copy(source_file, cwd_dir / "KFe2As2-00838.tif") + shutil.copy(source_file, home_dir / "KFe2As2-00838.tif") + config = SimpleNamespace(fliphorizontal=inputs[1], flipvertical=inputs[2]) + try: + loader = LoadImage(config) + actual = loader.load_image(inputs[0]) + assert actual.shape == expected_shape + assert actual.mean() == expected_mean + assert actual[1][0] == expected[0] + assert actual[1][1] == expected[1] + assert actual[2][5] == expected[2] + except FileNotFoundError: + pytest.raises( + FileNotFoundError, + match=r"file not found:" + r" .*Please rerun specifying a valid filename\.", + )