Skip to content
2 changes: 1 addition & 1 deletion docs/notebooks
254 changes: 226 additions & 28 deletions src/squidpy/tl/_sliding_window.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from itertools import product
from typing import Literal

import numpy as np
import pandas as pd
Expand All @@ -23,7 +24,8 @@ def sliding_window(
coord_columns: tuple[str, str] = ("globalX", "globalY"),
sliding_window_key: str = "sliding_window_assignment",
spatial_key: str = "spatial",
drop_partial_windows: bool = False,
partial_windows: Literal["adaptive", "drop", "split"] | None = None,
max_nr_cells: int | None = None,
copy: bool = False,
) -> pd.DataFrame | None:
"""
Expand All @@ -42,8 +44,14 @@ def sliding_window(
overlap: int
Overlap size between consecutive windows. (0 = no overlap)
%(spatial_key)s
drop_partial_windows: bool
If True, drop windows that are smaller than the window size at the borders.
partial_windows: Literal["adaptive", "drop", "split"] | None
If None, possibly small windows at the edges are kept.
If `adaptive`, all windows might be shrunken a bit to avoid small windows at the edges.
If `drop`, possibly small windows at the edges are removed.
If `split`, windows are split into subwindows until not exceeding `max_nr_cells`
max_nr_cells: int | None
The maximum number of cells allowed after merging two windows.
Required if `partial_windows = split`
copy: bool
If True, return the result, otherwise save it to the adata object.

Expand All @@ -52,8 +60,18 @@ def sliding_window(
If ``copy = True``, returns the sliding window annotation(s) as pandas dataframe
Otherwise, stores the sliding window annotation(s) in .obs.
"""
if overlap < 0:
raise ValueError("Overlap must be non-negative.")
if partial_windows == "split":
if max_nr_cells is None:
raise ValueError("`max_nr_cells` must be set when `partial_windows == split`.")
if window_size is not None:
logg.warning(f"Ingoring `window_size` when using `{partial_windows}`")
if overlap != 0:
logg.warning("Ignoring `overlap` as it cannot be used with `split`")
else:
if max_nr_cells is not None:
logg.warning("Ignoring `max_nr_cells` as `partial_windows != split`")
if overlap < 0:
raise ValueError("Overlap must be non-negative.")

if isinstance(adata, SpatialData):
adata = adata.table
Expand Down Expand Up @@ -86,8 +104,13 @@ def sliding_window(
# mostly arbitrary choice, except that full integers usually generate windows with 1-2 cells at the borders
window_size = max(int(np.floor(coord_range // 3.95)), 1)

if window_size <= 0:
raise ValueError("Window size must be larger than 0.")
if partial_windows != "split":
if window_size <= 0:
raise ValueError("Window size must be larger than 0.")
if overlap >= window_size:
raise ValueError("Overlap must be less than the window size.")
if overlap >= window_size // 2 and partial_windows == "adaptive":
raise ValueError("Overlap must be less than `window_size` // 2 when using `adaptive`.")

if library_key is not None and library_key not in adata.obs:
raise ValueError(f"Library key '{library_key}' not found in adata.obs")
Expand Down Expand Up @@ -119,7 +142,10 @@ def sliding_window(
max_y=max_y,
window_size=window_size,
overlap=overlap,
drop_partial_windows=drop_partial_windows,
partial_windows=partial_windows,
lib_coords=lib_coords,
coord_columns=(x_col, y_col),
max_nr_cells=max_nr_cells,
)

lib_key = f"{lib}_" if lib is not None else ""
Expand All @@ -131,15 +157,17 @@ def sliding_window(
y_start = window["y_start"]
y_end = window["y_end"]

mask = (
(lib_coords[x_col] >= x_start)
& (lib_coords[x_col] <= x_end)
& (lib_coords[y_col] >= y_start)
& (lib_coords[y_col] <= y_end)
mask = _get_window_mask(
coord_columns=(x_col, y_col),
lib_coords=lib_coords,
x_start=x_start,
x_end=x_end,
y_start=y_start,
y_end=y_end,
)
obs_indices = lib_coords.index[mask]

if overlap == 0:
if overlap == 0 or partial_windows == "split":
mask = (
(lib_coords[x_col] >= x_start)
& (lib_coords[x_col] <= x_end)
Expand Down Expand Up @@ -175,14 +203,61 @@ def sliding_window(
_save_data(adata, attr="obs", key=col_name, data=col_data)


def _get_window_mask(
coord_columns: tuple[str, str],
lib_coords: pd.DataFrame,
x_start: int,
x_end: int,
y_start: int,
y_end: int,
) -> pd.Series:
"""
Compute a boolean mask selecting coordinates that fall within a given window.

Parameters
----------
coord_columns: Tuple[str, str]
Tuple of column names in `adata.obs` that specify the coordinates (x, y), i.e. ('globalX', 'globalY')
lib_coords: pd.DataFrame
DataFrame containing spatial coordinates (e.g. `adata.obs` subset for one library).
Coordinate values are expected to be integers.
x_start: int
Lower bound of the window in x-direction (inclusive).
x_end: int
Upper bound of the window in x-direction (inclusive).
y_start: int
Lower bound of the window in y-direction (inclusive).
y_end: int
Upper bound of the window in y-direction (inclusive).

Returns
-------
pd.Series
Boolean mask indicating which rows in `lib_coords` fall inside the specified window.
"""
x_col, y_col = coord_columns

mask = (
(lib_coords[x_col] >= x_start)
& (lib_coords[x_col] <= x_end)
& (lib_coords[y_col] >= y_start)
& (lib_coords[y_col] <= y_end)
)

return mask


def _calculate_window_corners(
min_x: int,
max_x: int,
min_y: int,
max_y: int,
window_size: int,
overlap: int = 0,
drop_partial_windows: bool = False,
partial_windows: Literal["adaptive", "drop", "split"] | None = None,
lib_coords: pd.DataFrame | None = None,
coord_columns: tuple[str, str] | None = None,
max_nr_cells: int | None = None,
) -> pd.DataFrame:
"""
Calculate the corner points of all windows covering the area from min_x to max_x and min_y to max_y,
Expand All @@ -200,23 +275,45 @@ def _calculate_window_corners(
maximum Y coordinate
window_size: float
size of each window
lib_coords: pd.DataFrame | None
coordinates of all samples for one library
coord_columns: Tuple[str, str]
Tuple of column names in `adata.obs` that specify the coordinates (x, y), i.e. ('globalX', 'globalY')
overlap: float
overlap between consecutive windows (must be less than window_size)
drop_partial_windows: bool
if True, drop border windows that are smaller than window_size;
if False, create smaller windows at the borders to cover the remaining space.
partial_windows: Literal["adaptive", "drop", "split"] | None
If None, possibly small windows at the edges are kept.
If 'adaptive', all windows might be shrunken a bit to avoid small windows at the edges.
If 'drop', possibly small windows at the edges are removed.
If 'split', windows are split into subwindows until not exceeding `max_nr_cells`

Returns
-------
windows: pandas DataFrame with columns ['x_start', 'x_end', 'y_start', 'y_end']
"""
if overlap < 0:
raise ValueError("Overlap must be non-negative.")
if overlap >= window_size:
raise ValueError("Overlap must be less than the window size.")
# adjust x and y window size if 'adaptive'
if partial_windows == "adaptive":
total_width = max_x - min_x
total_height = max_y - min_y

# number of windows in x and y direction
number_x_windows = np.ceil((total_width - overlap) / (window_size - overlap))
number_y_windows = np.ceil((total_height - overlap) / (window_size - overlap))

# window size in x and y direction
x_window_size = (total_width + (number_x_windows - 1) * overlap) / number_x_windows
y_window_size = (total_height + (number_y_windows - 1) * overlap) / number_y_windows

# avoid float errors
x_window_size = np.ceil(x_window_size)
y_window_size = np.ceil(y_window_size)
else:
x_window_size = window_size
y_window_size = window_size

x_step = window_size - overlap
y_step = window_size - overlap
# create the step sizes for each window
x_step = x_window_size - overlap
y_step = y_window_size - overlap

# Generate starting points
x_starts = np.arange(min_x, max_x, x_step)
Expand All @@ -225,16 +322,117 @@ def _calculate_window_corners(
# Create all combinations of x and y starting points
starts = list(product(x_starts, y_starts))
windows = pd.DataFrame(starts, columns=["x_start", "y_start"])
windows["x_end"] = windows["x_start"] + window_size
windows["y_end"] = windows["y_start"] + window_size
windows["x_end"] = windows["x_start"] + x_window_size
windows["y_end"] = windows["y_start"] + y_window_size

# Adjust windows that extend beyond the bounds
if not drop_partial_windows:
if partial_windows is None:
windows["x_end"] = windows["x_end"].clip(upper=max_x)
windows["y_end"] = windows["y_end"].clip(upper=max_y)
else:
elif partial_windows == "adaptive":
# as window_size is an integer to avoid float errors, it can exceed max_x and max_y -> clip
windows["x_end"] = windows["x_end"].clip(upper=max_x)
windows["y_end"] = windows["y_end"].clip(upper=max_y)

# remove redundant windows in the corners
redundant_windows = ((windows["x_end"] - windows["x_start"]) <= overlap) | (
(windows["y_end"] - windows["y_start"]) <= overlap
)
windows = windows[~redundant_windows]
elif partial_windows == "drop":
valid_windows = (windows["x_end"] <= max_x) & (windows["y_end"] <= max_y)
windows = windows[valid_windows]
elif partial_windows == "split":
# split the slide recursively into windows with at most max_nr_cells
x_col, y_col = coord_columns

coord_x_sorted = lib_coords.sort_values(by=[x_col])
coord_y_sorted = lib_coords.sort_values(by=[y_col])

windows = _split_window(
max_nr_cells, (x_col, y_col), coord_x_sorted, coord_y_sorted, min_x, max_x, min_y, max_y
).sort_values(["x_start", "x_end", "y_start", "y_end"])
else:
raise ValueError(f"{partial_windows} is not a valid partial_windows argument.")

windows = windows.reset_index(drop=True)
return windows[["x_start", "x_end", "y_start", "y_end"]]


def _split_window(
max_cells: int,
coord_columns: tuple[str, str],
coord_x_sorted: pd.DataFrame,
coord_y_sorted: pd.DataFrame,
x_start: int,
x_end: int,
y_start: int,
y_end: int,
) -> pd.DataFrame:
"""
Recursively split a rectangular window into subwindows such that each subwindow
contains at most `max_cells` cells and at least `max_cells` // 2 cells.

Parameters
----------
max_cells : int
Maximum number of cells allowed per window.
coord_columns: Tuple[str, str]
Tuple of column names in `adata.obs` that specify the coordinates (x, y), i.e. ('globalX', 'globalY')
coord_x_sorted : pandas.DataFrame
DataFrame containing cell coordinates, sorted by `x_col`.
coord_y_sorted : pandas.DataFrame
DataFrame containing cell coordinates, sorted by `y_col`.
x_start : int
Left (minimum) x coordinate of the current window.
x_end : int
Right (maximum) x coordinate of the current window.
y_start : int
Bottom (minimum) y coordinate of the current window.
y_end : int
Top (maximum) y coordinate of the current window.

Returns
-------
windows: pandas DataFrame with columns ['x_start', 'x_end', 'y_start', 'y_end']
"""
x_col, y_col = coord_columns

# return current window if it contains less cells than max_cells
n_cells = _get_window_mask(coord_columns, coord_x_sorted, x_start, x_end, y_start, y_end).sum()

if n_cells <= max_cells:
return pd.DataFrame({"x_start": [x_start], "x_end": [x_end], "y_start": [y_start], "y_end": [y_end]})

# define start and stop indices of subsetted windows
sub_coord_x_sorted = coord_x_sorted[
_get_window_mask(coord_columns, coord_x_sorted, x_start, x_end, y_start, y_end)
].reset_index(drop=True)

sub_coord_y_sorted = coord_y_sorted[
_get_window_mask(coord_columns, coord_y_sorted, x_start, x_end, y_start, y_end)
].reset_index(drop=True)

middle_pos = len(sub_coord_x_sorted) // 2

if (x_end - x_start) > (y_end - y_start):
# vertical split
x_middle = sub_coord_x_sorted[x_col].iloc[middle_pos]

indices = ((x_start, x_middle, y_start, y_end), (x_middle, x_end, y_start, y_end))
else:
# horizontal split
y_middle = sub_coord_y_sorted.loc[middle_pos, y_col]

indices = ((x_start, x_end, y_start, y_middle), (x_start, x_end, y_middle, y_end))

# recursively continue with either left&right or upper&lower windows pairs
windows = []
for x_start, x_end, y_start, y_end in indices:
windows.append(
_split_window(
max_cells, (x_col, y_col), sub_coord_x_sorted, sub_coord_y_sorted, x_start, x_end, y_start, y_end
)
)

return pd.concat(windows)
Loading