From f8bf582477571e0a07e0e04146f4caf37023422b Mon Sep 17 00:00:00 2001 From: David Ormrod Morley Date: Tue, 17 Mar 2026 16:12:12 +0100 Subject: [PATCH 1/2] Add plot_image_grid method SO-- --- CHANGELOG.md | 1 + src/scm/plams/__init__.py | 2 ++ src/scm/plams/tools/plot.py | 67 +++++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 287c2ca7a..69495f559 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ This changelog is effective from the 2025 releases. * Added a `balance` function to `scm.plams.tools.reaction`, which returns a `ReactionEquation` * `view` function to visualize molecules/chemical systems using AMSView * `config.job.on_status_change` callback which fires any time a job status is updated +* `plot_image_grid` to plot multiple images (e.g. those generated from `view`) in a grid format ### Changed * `JobAnalysis` returns an updated copy on modification instead of performing the operation in-place diff --git a/src/scm/plams/__init__.py b/src/scm/plams/__init__.py index 96bf256a1..a786aba7b 100644 --- a/src/scm/plams/__init__.py +++ b/src/scm/plams/__init__.py @@ -156,6 +156,7 @@ plot_phonons_thermodynamic_properties, plot_correlation, plot_grid_molecules, + plot_image_grid, plot_molecule, plot_msd, plot_work_function, @@ -328,6 +329,7 @@ "plot_band_structure", "plot_molecule", "plot_grid_molecules", + "plot_image_grid", "get_correlation_xy", "plot_correlation", "plot_msd", diff --git a/src/scm/plams/tools/plot.py b/src/scm/plams/tools/plot.py index 45603c155..2823f7149 100644 --- a/src/scm/plams/tools/plot.py +++ b/src/scm/plams/tools/plot.py @@ -26,6 +26,7 @@ "plot_phonons_dos", "plot_phonons_thermodynamic_properties", "plot_molecule", + "plot_image_grid", "plot_correlation", "plot_msd", "plot_work_function", @@ -408,6 +409,72 @@ def plot_grid_molecules( return img +@requires_optional_package("matplotlib") +def plot_image_grid( + images: Dict[str, "PilImage.Image"], + rows: Optional[int] = None, + cols: Optional[int] = None, + figsize: Optional[Tuple[float, float]] = None, + show_labels: bool = True, + save_path: Optional[Union[str, "PathLike"]] = None, +) -> np.ndarray: + """Plot a dictionary of images in a matplotlib grid. + + :param images: dictionary with labels as keys and images as values; iteration order determines image order in the grid + :param rows: number of rows in the grid; if ``None``, infer from ``cols`` and number of images + :param cols: number of columns in the grid; if ``None``, infer from ``rows`` and number of images + :param figsize: matplotlib figure size; if ``None``, uses a grid-proportional default + :param show_labels: whether to show labels above images; labels are taken from dictionary keys + :param save_path: optional path to save the plotted grid image using matplotlib ``savefig`` + :return: 2D numpy array of matplotlib axes with shape ``(rows, cols)`` + :rtype: np.ndarray + """ + import matplotlib.pyplot as plt + + items = list(images.items()) + n_images = len(items) + + if n_images == 0: + raise ValueError("images must contain at least one image") + + if rows is not None and rows <= 0: + raise ValueError(f"rows must be a positive integer when provided, but got {rows}") + if cols is not None and cols <= 0: + raise ValueError(f"cols must be a positive integer when provided, but got {cols}") + + if rows is None and cols is None: + cols = int(np.ceil(np.sqrt(n_images))) + rows = int(np.ceil(n_images / cols)) + elif rows is None: + rows = int(np.ceil(n_images / cols)) + elif cols is None: + cols = int(np.ceil(n_images / rows)) + + grid_size = rows * cols + if n_images > grid_size: + raise ValueError(f"Grid of shape ({rows}, {cols}) can hold at most {grid_size} images, but got {n_images}") + + if figsize is None: + figsize = ((4.0 * cols), (4.0 * rows)) + fig, axes = plt.subplots(rows, cols, figsize=figsize) + axes = np.array(axes, dtype=object).reshape(rows, cols) + + for ax in axes.flat: + ax.axis("off") + + for i, (key, image) in enumerate(items): + row, col = divmod(i, cols) + ax = cast(Any, axes[row, col]) + ax.imshow(image) + if show_labels: + ax.set_title(key) + + if save_path is not None: + fig.savefig(save_path) + + return axes + + def get_correlation_xy( job1: Union[AMSJob, List[AMSJob]], job2: Union[AMSJob, List[AMSJob]], From 319f211029b41f23be1f7a623575dbb26b74087a Mon Sep 17 00:00:00 2001 From: David Ormrod Morley Date: Tue, 17 Mar 2026 16:46:11 +0100 Subject: [PATCH 2/2] Add mypy ignores --- src/scm/plams/tools/plot.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/scm/plams/tools/plot.py b/src/scm/plams/tools/plot.py index 2823f7149..7686a3f8c 100644 --- a/src/scm/plams/tools/plot.py +++ b/src/scm/plams/tools/plot.py @@ -446,24 +446,24 @@ def plot_image_grid( cols = int(np.ceil(np.sqrt(n_images))) rows = int(np.ceil(n_images / cols)) elif rows is None: - rows = int(np.ceil(n_images / cols)) + rows = int(np.ceil(n_images / cols)) # type: ignore[operator] elif cols is None: cols = int(np.ceil(n_images / rows)) - grid_size = rows * cols + grid_size = rows * cols # type: ignore[operator] if n_images > grid_size: raise ValueError(f"Grid of shape ({rows}, {cols}) can hold at most {grid_size} images, but got {n_images}") if figsize is None: - figsize = ((4.0 * cols), (4.0 * rows)) - fig, axes = plt.subplots(rows, cols, figsize=figsize) - axes = np.array(axes, dtype=object).reshape(rows, cols) + figsize = ((4.0 * cols), (4.0 * rows)) # type: ignore[operator] + fig, axes = plt.subplots(rows, cols, figsize=figsize) # type: ignore[arg-type] + axes = np.array(axes, dtype=object).reshape(rows, cols) # type: ignore[arg-type] for ax in axes.flat: ax.axis("off") for i, (key, image) in enumerate(items): - row, col = divmod(i, cols) + row, col = divmod(i, cols) # type: ignore[operator] ax = cast(Any, axes[row, col]) ax.imshow(image) if show_labels: