diff --git a/CHANGELOG.md b/CHANGELOG.md index 287c2ca7..69495f55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ This changelog is effective from the 2025 releases. * Added a `balance` function to `scm.plams.tools.reaction`, which returns a `ReactionEquation` * `view` function to visualize molecules/chemical systems using AMSView * `config.job.on_status_change` callback which fires any time a job status is updated +* `plot_image_grid` to plot multiple images (e.g. those generated from `view`) in a grid format ### Changed * `JobAnalysis` returns an updated copy on modification instead of performing the operation in-place diff --git a/src/scm/plams/__init__.py b/src/scm/plams/__init__.py index 96bf256a..a786aba7 100644 --- a/src/scm/plams/__init__.py +++ b/src/scm/plams/__init__.py @@ -156,6 +156,7 @@ plot_phonons_thermodynamic_properties, plot_correlation, plot_grid_molecules, + plot_image_grid, plot_molecule, plot_msd, plot_work_function, @@ -328,6 +329,7 @@ "plot_band_structure", "plot_molecule", "plot_grid_molecules", + "plot_image_grid", "get_correlation_xy", "plot_correlation", "plot_msd", diff --git a/src/scm/plams/tools/plot.py b/src/scm/plams/tools/plot.py index 45603c15..7686a3f8 100644 --- a/src/scm/plams/tools/plot.py +++ b/src/scm/plams/tools/plot.py @@ -26,6 +26,7 @@ "plot_phonons_dos", "plot_phonons_thermodynamic_properties", "plot_molecule", + "plot_image_grid", "plot_correlation", "plot_msd", "plot_work_function", @@ -408,6 +409,72 @@ def plot_grid_molecules( return img +@requires_optional_package("matplotlib") +def plot_image_grid( + images: Dict[str, "PilImage.Image"], + rows: Optional[int] = None, + cols: Optional[int] = None, + figsize: Optional[Tuple[float, float]] = None, + show_labels: bool = True, + save_path: Optional[Union[str, "PathLike"]] = None, +) -> np.ndarray: + """Plot a dictionary of images in a matplotlib grid. + + :param images: dictionary with labels as keys and images as values; iteration order determines image order in the grid + :param rows: number of rows in the grid; if ``None``, infer from ``cols`` and number of images + :param cols: number of columns in the grid; if ``None``, infer from ``rows`` and number of images + :param figsize: matplotlib figure size; if ``None``, uses a grid-proportional default + :param show_labels: whether to show labels above images; labels are taken from dictionary keys + :param save_path: optional path to save the plotted grid image using matplotlib ``savefig`` + :return: 2D numpy array of matplotlib axes with shape ``(rows, cols)`` + :rtype: np.ndarray + """ + import matplotlib.pyplot as plt + + items = list(images.items()) + n_images = len(items) + + if n_images == 0: + raise ValueError("images must contain at least one image") + + if rows is not None and rows <= 0: + raise ValueError(f"rows must be a positive integer when provided, but got {rows}") + if cols is not None and cols <= 0: + raise ValueError(f"cols must be a positive integer when provided, but got {cols}") + + if rows is None and cols is None: + cols = int(np.ceil(np.sqrt(n_images))) + rows = int(np.ceil(n_images / cols)) + elif rows is None: + rows = int(np.ceil(n_images / cols)) # type: ignore[operator] + elif cols is None: + cols = int(np.ceil(n_images / rows)) + + grid_size = rows * cols # type: ignore[operator] + if n_images > grid_size: + raise ValueError(f"Grid of shape ({rows}, {cols}) can hold at most {grid_size} images, but got {n_images}") + + if figsize is None: + figsize = ((4.0 * cols), (4.0 * rows)) # type: ignore[operator] + fig, axes = plt.subplots(rows, cols, figsize=figsize) # type: ignore[arg-type] + axes = np.array(axes, dtype=object).reshape(rows, cols) # type: ignore[arg-type] + + for ax in axes.flat: + ax.axis("off") + + for i, (key, image) in enumerate(items): + row, col = divmod(i, cols) # type: ignore[operator] + ax = cast(Any, axes[row, col]) + ax.imshow(image) + if show_labels: + ax.set_title(key) + + if save_path is not None: + fig.savefig(save_path) + + return axes + + def get_correlation_xy( job1: Union[AMSJob, List[AMSJob]], job2: Union[AMSJob, List[AMSJob]],