Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ This changelog is effective from the 2025 releases.
* Added a `balance` function to `scm.plams.tools.reaction`, which returns a `ReactionEquation`
* `view` function to visualize molecules/chemical systems using AMSView
* `config.job.on_status_change` callback which fires any time a job status is updated
* `plot_image_grid` to plot multiple images (e.g. those generated from `view`) in a grid format

### Changed
* `JobAnalysis` returns an updated copy on modification instead of performing the operation in-place
Expand Down
2 changes: 2 additions & 0 deletions src/scm/plams/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@
plot_phonons_thermodynamic_properties,
plot_correlation,
plot_grid_molecules,
plot_image_grid,
plot_molecule,
plot_msd,
plot_work_function,
Expand Down Expand Up @@ -328,6 +329,7 @@
"plot_band_structure",
"plot_molecule",
"plot_grid_molecules",
"plot_image_grid",
"get_correlation_xy",
"plot_correlation",
"plot_msd",
Expand Down
67 changes: 67 additions & 0 deletions src/scm/plams/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"plot_phonons_dos",
"plot_phonons_thermodynamic_properties",
"plot_molecule",
"plot_image_grid",
"plot_correlation",
"plot_msd",
"plot_work_function",
Expand Down Expand Up @@ -408,6 +409,72 @@ def plot_grid_molecules(
return img


@requires_optional_package("matplotlib")
def plot_image_grid(
images: Dict[str, "PilImage.Image"],
rows: Optional[int] = None,
cols: Optional[int] = None,
figsize: Optional[Tuple[float, float]] = None,
show_labels: bool = True,
save_path: Optional[Union[str, "PathLike"]] = None,
) -> np.ndarray:
"""Plot a dictionary of images in a matplotlib grid.

:param images: dictionary with labels as keys and images as values; iteration order determines image order in the grid
:param rows: number of rows in the grid; if ``None``, infer from ``cols`` and number of images
:param cols: number of columns in the grid; if ``None``, infer from ``rows`` and number of images
:param figsize: matplotlib figure size; if ``None``, uses a grid-proportional default
:param show_labels: whether to show labels above images; labels are taken from dictionary keys
:param save_path: optional path to save the plotted grid image using matplotlib ``savefig``
:return: 2D numpy array of matplotlib axes with shape ``(rows, cols)``
:rtype: np.ndarray
"""
import matplotlib.pyplot as plt

items = list(images.items())
n_images = len(items)

if n_images == 0:
raise ValueError("images must contain at least one image")

if rows is not None and rows <= 0:
raise ValueError(f"rows must be a positive integer when provided, but got {rows}")
if cols is not None and cols <= 0:
raise ValueError(f"cols must be a positive integer when provided, but got {cols}")

if rows is None and cols is None:
cols = int(np.ceil(np.sqrt(n_images)))
rows = int(np.ceil(n_images / cols))
elif rows is None:
rows = int(np.ceil(n_images / cols)) # type: ignore[operator]
elif cols is None:
cols = int(np.ceil(n_images / rows))

grid_size = rows * cols # type: ignore[operator]
if n_images > grid_size:
raise ValueError(f"Grid of shape ({rows}, {cols}) can hold at most {grid_size} images, but got {n_images}")

if figsize is None:
figsize = ((4.0 * cols), (4.0 * rows)) # type: ignore[operator]
fig, axes = plt.subplots(rows, cols, figsize=figsize) # type: ignore[arg-type]
axes = np.array(axes, dtype=object).reshape(rows, cols) # type: ignore[arg-type]

for ax in axes.flat:
ax.axis("off")

for i, (key, image) in enumerate(items):
row, col = divmod(i, cols) # type: ignore[operator]
ax = cast(Any, axes[row, col])
ax.imshow(image)
if show_labels:
ax.set_title(key)

if save_path is not None:
fig.savefig(save_path)

return axes


def get_correlation_xy(
job1: Union[AMSJob, List[AMSJob]],
job2: Union[AMSJob, List[AMSJob]],
Expand Down
Loading