diff --git a/docs/notebooks b/docs/notebooks index 510b92da9..17d368281 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 510b92da918efbb8b38cf7d5c3989b8e3ed19618 +Subproject commit 17d368281ea7b11f7e1174436bbc2429191a0245 diff --git a/src/squidpy/tl/_sliding_window.py b/src/squidpy/tl/_sliding_window.py index 72a06eac8..634efc349 100644 --- a/src/squidpy/tl/_sliding_window.py +++ b/src/squidpy/tl/_sliding_window.py @@ -1,6 +1,7 @@ from __future__ import annotations from itertools import product +from typing import Literal import numpy as np import pandas as pd @@ -23,7 +24,8 @@ def sliding_window( coord_columns: tuple[str, str] = ("globalX", "globalY"), sliding_window_key: str = "sliding_window_assignment", spatial_key: str = "spatial", - drop_partial_windows: bool = False, + partial_windows: Literal["adaptive", "drop", "split"] | None = None, + max_nr_cells: int | None = None, copy: bool = False, ) -> pd.DataFrame | None: """ @@ -42,8 +44,14 @@ def sliding_window( overlap: int Overlap size between consecutive windows. (0 = no overlap) %(spatial_key)s - drop_partial_windows: bool - If True, drop windows that are smaller than the window size at the borders. + partial_windows: Literal["adaptive", "drop", "split"] | None + If None, possibly small windows at the edges are kept. + If `adaptive`, all windows might be shrunken a bit to avoid small windows at the edges. + If `drop`, possibly small windows at the edges are removed. + If `split`, windows are split into subwindows until not exceeding `max_nr_cells` + max_nr_cells: int | None + The maximum number of cells allowed after merging two windows. + Required if `partial_windows = split` copy: bool If True, return the result, otherwise save it to the adata object. @@ -52,8 +60,18 @@ def sliding_window( If ``copy = True``, returns the sliding window annotation(s) as pandas dataframe Otherwise, stores the sliding window annotation(s) in .obs. """ - if overlap < 0: - raise ValueError("Overlap must be non-negative.") + if partial_windows == "split": + if max_nr_cells is None: + raise ValueError("`max_nr_cells` must be set when `partial_windows == split`.") + if window_size is not None: + logg.warning(f"Ingoring `window_size` when using `{partial_windows}`") + if overlap != 0: + logg.warning("Ignoring `overlap` as it cannot be used with `split`") + else: + if max_nr_cells is not None: + logg.warning("Ignoring `max_nr_cells` as `partial_windows != split`") + if overlap < 0: + raise ValueError("Overlap must be non-negative.") if isinstance(adata, SpatialData): adata = adata.table @@ -86,8 +104,13 @@ def sliding_window( # mostly arbitrary choice, except that full integers usually generate windows with 1-2 cells at the borders window_size = max(int(np.floor(coord_range // 3.95)), 1) - if window_size <= 0: - raise ValueError("Window size must be larger than 0.") + if partial_windows != "split": + if window_size <= 0: + raise ValueError("Window size must be larger than 0.") + if overlap >= window_size: + raise ValueError("Overlap must be less than the window size.") + if overlap >= window_size // 2 and partial_windows == "adaptive": + raise ValueError("Overlap must be less than `window_size` // 2 when using `adaptive`.") if library_key is not None and library_key not in adata.obs: raise ValueError(f"Library key '{library_key}' not found in adata.obs") @@ -119,7 +142,10 @@ def sliding_window( max_y=max_y, window_size=window_size, overlap=overlap, - drop_partial_windows=drop_partial_windows, + partial_windows=partial_windows, + lib_coords=lib_coords, + coord_columns=(x_col, y_col), + max_nr_cells=max_nr_cells, ) lib_key = f"{lib}_" if lib is not None else "" @@ -131,15 +157,17 @@ def sliding_window( y_start = window["y_start"] y_end = window["y_end"] - mask = ( - (lib_coords[x_col] >= x_start) - & (lib_coords[x_col] <= x_end) - & (lib_coords[y_col] >= y_start) - & (lib_coords[y_col] <= y_end) + mask = _get_window_mask( + coord_columns=(x_col, y_col), + lib_coords=lib_coords, + x_start=x_start, + x_end=x_end, + y_start=y_start, + y_end=y_end, ) obs_indices = lib_coords.index[mask] - if overlap == 0: + if overlap == 0 or partial_windows == "split": mask = ( (lib_coords[x_col] >= x_start) & (lib_coords[x_col] <= x_end) @@ -175,6 +203,50 @@ def sliding_window( _save_data(adata, attr="obs", key=col_name, data=col_data) +def _get_window_mask( + coord_columns: tuple[str, str], + lib_coords: pd.DataFrame, + x_start: int, + x_end: int, + y_start: int, + y_end: int, +) -> pd.Series: + """ + Compute a boolean mask selecting coordinates that fall within a given window. + + Parameters + ---------- + coord_columns: Tuple[str, str] + Tuple of column names in `adata.obs` that specify the coordinates (x, y), i.e. ('globalX', 'globalY') + lib_coords: pd.DataFrame + DataFrame containing spatial coordinates (e.g. `adata.obs` subset for one library). + Coordinate values are expected to be integers. + x_start: int + Lower bound of the window in x-direction (inclusive). + x_end: int + Upper bound of the window in x-direction (inclusive). + y_start: int + Lower bound of the window in y-direction (inclusive). + y_end: int + Upper bound of the window in y-direction (inclusive). + + Returns + ------- + pd.Series + Boolean mask indicating which rows in `lib_coords` fall inside the specified window. + """ + x_col, y_col = coord_columns + + mask = ( + (lib_coords[x_col] >= x_start) + & (lib_coords[x_col] <= x_end) + & (lib_coords[y_col] >= y_start) + & (lib_coords[y_col] <= y_end) + ) + + return mask + + def _calculate_window_corners( min_x: int, max_x: int, @@ -182,7 +254,10 @@ def _calculate_window_corners( max_y: int, window_size: int, overlap: int = 0, - drop_partial_windows: bool = False, + partial_windows: Literal["adaptive", "drop", "split"] | None = None, + lib_coords: pd.DataFrame | None = None, + coord_columns: tuple[str, str] | None = None, + max_nr_cells: int | None = None, ) -> pd.DataFrame: """ Calculate the corner points of all windows covering the area from min_x to max_x and min_y to max_y, @@ -200,23 +275,45 @@ def _calculate_window_corners( maximum Y coordinate window_size: float size of each window + lib_coords: pd.DataFrame | None + coordinates of all samples for one library + coord_columns: Tuple[str, str] + Tuple of column names in `adata.obs` that specify the coordinates (x, y), i.e. ('globalX', 'globalY') overlap: float overlap between consecutive windows (must be less than window_size) - drop_partial_windows: bool - if True, drop border windows that are smaller than window_size; - if False, create smaller windows at the borders to cover the remaining space. + partial_windows: Literal["adaptive", "drop", "split"] | None + If None, possibly small windows at the edges are kept. + If 'adaptive', all windows might be shrunken a bit to avoid small windows at the edges. + If 'drop', possibly small windows at the edges are removed. + If 'split', windows are split into subwindows until not exceeding `max_nr_cells` Returns ------- windows: pandas DataFrame with columns ['x_start', 'x_end', 'y_start', 'y_end'] """ - if overlap < 0: - raise ValueError("Overlap must be non-negative.") - if overlap >= window_size: - raise ValueError("Overlap must be less than the window size.") + # adjust x and y window size if 'adaptive' + if partial_windows == "adaptive": + total_width = max_x - min_x + total_height = max_y - min_y + + # number of windows in x and y direction + number_x_windows = np.ceil((total_width - overlap) / (window_size - overlap)) + number_y_windows = np.ceil((total_height - overlap) / (window_size - overlap)) + + # window size in x and y direction + x_window_size = (total_width + (number_x_windows - 1) * overlap) / number_x_windows + y_window_size = (total_height + (number_y_windows - 1) * overlap) / number_y_windows + + # avoid float errors + x_window_size = np.ceil(x_window_size) + y_window_size = np.ceil(y_window_size) + else: + x_window_size = window_size + y_window_size = window_size - x_step = window_size - overlap - y_step = window_size - overlap + # create the step sizes for each window + x_step = x_window_size - overlap + y_step = y_window_size - overlap # Generate starting points x_starts = np.arange(min_x, max_x, x_step) @@ -225,16 +322,117 @@ def _calculate_window_corners( # Create all combinations of x and y starting points starts = list(product(x_starts, y_starts)) windows = pd.DataFrame(starts, columns=["x_start", "y_start"]) - windows["x_end"] = windows["x_start"] + window_size - windows["y_end"] = windows["y_start"] + window_size + windows["x_end"] = windows["x_start"] + x_window_size + windows["y_end"] = windows["y_start"] + y_window_size # Adjust windows that extend beyond the bounds - if not drop_partial_windows: + if partial_windows is None: windows["x_end"] = windows["x_end"].clip(upper=max_x) windows["y_end"] = windows["y_end"].clip(upper=max_y) - else: + elif partial_windows == "adaptive": + # as window_size is an integer to avoid float errors, it can exceed max_x and max_y -> clip + windows["x_end"] = windows["x_end"].clip(upper=max_x) + windows["y_end"] = windows["y_end"].clip(upper=max_y) + + # remove redundant windows in the corners + redundant_windows = ((windows["x_end"] - windows["x_start"]) <= overlap) | ( + (windows["y_end"] - windows["y_start"]) <= overlap + ) + windows = windows[~redundant_windows] + elif partial_windows == "drop": valid_windows = (windows["x_end"] <= max_x) & (windows["y_end"] <= max_y) windows = windows[valid_windows] + elif partial_windows == "split": + # split the slide recursively into windows with at most max_nr_cells + x_col, y_col = coord_columns + + coord_x_sorted = lib_coords.sort_values(by=[x_col]) + coord_y_sorted = lib_coords.sort_values(by=[y_col]) + + windows = _split_window( + max_nr_cells, (x_col, y_col), coord_x_sorted, coord_y_sorted, min_x, max_x, min_y, max_y + ).sort_values(["x_start", "x_end", "y_start", "y_end"]) + else: + raise ValueError(f"{partial_windows} is not a valid partial_windows argument.") windows = windows.reset_index(drop=True) return windows[["x_start", "x_end", "y_start", "y_end"]] + + +def _split_window( + max_cells: int, + coord_columns: tuple[str, str], + coord_x_sorted: pd.DataFrame, + coord_y_sorted: pd.DataFrame, + x_start: int, + x_end: int, + y_start: int, + y_end: int, +) -> pd.DataFrame: + """ + Recursively split a rectangular window into subwindows such that each subwindow + contains at most `max_cells` cells and at least `max_cells` // 2 cells. + + Parameters + ---------- + max_cells : int + Maximum number of cells allowed per window. + coord_columns: Tuple[str, str] + Tuple of column names in `adata.obs` that specify the coordinates (x, y), i.e. ('globalX', 'globalY') + coord_x_sorted : pandas.DataFrame + DataFrame containing cell coordinates, sorted by `x_col`. + coord_y_sorted : pandas.DataFrame + DataFrame containing cell coordinates, sorted by `y_col`. + x_start : int + Left (minimum) x coordinate of the current window. + x_end : int + Right (maximum) x coordinate of the current window. + y_start : int + Bottom (minimum) y coordinate of the current window. + y_end : int + Top (maximum) y coordinate of the current window. + + Returns + ------- + windows: pandas DataFrame with columns ['x_start', 'x_end', 'y_start', 'y_end'] + """ + x_col, y_col = coord_columns + + # return current window if it contains less cells than max_cells + n_cells = _get_window_mask(coord_columns, coord_x_sorted, x_start, x_end, y_start, y_end).sum() + + if n_cells <= max_cells: + return pd.DataFrame({"x_start": [x_start], "x_end": [x_end], "y_start": [y_start], "y_end": [y_end]}) + + # define start and stop indices of subsetted windows + sub_coord_x_sorted = coord_x_sorted[ + _get_window_mask(coord_columns, coord_x_sorted, x_start, x_end, y_start, y_end) + ].reset_index(drop=True) + + sub_coord_y_sorted = coord_y_sorted[ + _get_window_mask(coord_columns, coord_y_sorted, x_start, x_end, y_start, y_end) + ].reset_index(drop=True) + + middle_pos = len(sub_coord_x_sorted) // 2 + + if (x_end - x_start) > (y_end - y_start): + # vertical split + x_middle = sub_coord_x_sorted[x_col].iloc[middle_pos] + + indices = ((x_start, x_middle, y_start, y_end), (x_middle, x_end, y_start, y_end)) + else: + # horizontal split + y_middle = sub_coord_y_sorted.loc[middle_pos, y_col] + + indices = ((x_start, x_end, y_start, y_middle), (x_start, x_end, y_middle, y_end)) + + # recursively continue with either left&right or upper&lower windows pairs + windows = [] + for x_start, x_end, y_start, y_end in indices: + windows.append( + _split_window( + max_cells, (x_col, y_col), sub_coord_x_sorted, sub_coord_y_sorted, x_start, x_end, y_start, y_end + ) + ) + + return pd.concat(windows) diff --git a/tests/tools/test_sliding_window.py b/tests/tools/test_sliding_window.py index 3dd670b00..672290de9 100644 --- a/tests/tools/test_sliding_window.py +++ b/tests/tools/test_sliding_window.py @@ -8,17 +8,19 @@ class TestSlidingWindow: @pytest.mark.parametrize( - "windowsize_overlap_drop", + "windowsize_overlap_partial", [ - (300, 0, False), - (300, 50, False), - (300, 50, True), + (300, 0, None), + (300, 50, None), + (300, 50, "drop"), + (300, 0, "adaptive"), + (300, 50, "adaptive"), ], ) def test_sliding_window_several_slices( self, adata_mibitof: AnnData, - windowsize_overlap_drop: tuple[int, int, bool], + windowsize_overlap_partial: tuple[int, int, str | None], sliding_window_key: str = "sliding_window_key", library_key: str = "library_id", ): @@ -30,7 +32,7 @@ def _count_total_assignments(): total_cells += df[col].sum() return total_cells - window_size, overlap, drop_partial_windows = windowsize_overlap_drop + window_size, overlap, partial_windows = windowsize_overlap_partial df = sliding_window( adata_mibitof, library_key=library_key, @@ -39,7 +41,7 @@ def _count_total_assignments(): coord_columns=("globalX", "globalY"), sliding_window_key=sliding_window_key, copy=True, - drop_partial_windows=drop_partial_windows, + partial_windows=partial_windows, ) if overlap == 0: @@ -50,9 +52,12 @@ def _count_total_assignments(): else: sliding_window_cols = df.columns[df.columns.str.contains("sliding_window")] - if drop_partial_windows: + if partial_windows == "drop": assert len(sliding_window_cols) == 27 assert _count_total_assignments() == 2536 + elif partial_windows == "adaptive": + assert len(sliding_window_cols) == 48 + assert _count_total_assignments() == 4411 else: assert len(sliding_window_cols) == 70 assert _count_total_assignments() == 4569 @@ -110,6 +115,53 @@ def test_sliding_window_invalid_window_size( copy=True, ) + with pytest.raises(ValueError, match="`max_nr_cells` must be set when `partial_windows == split`."): + sliding_window( + adata_squaregrid, + window_size=None, + overlap=0, + partial_windows="split", + coord_columns=("globalX", "globalY"), + copy=True, + ) + + def test_sliding_window_split_nr_cells( + self, + adata_mibitof: AnnData, + sliding_window_key: str = "sliding_window_key", + library_key: str = "library_id", + ): + """ + Test that when using 'split', each window contains at most max_nr_cells + and at least max_nr_cells // 2 cells, + unless the total number of cells is smaller than max_nr_cells // 2. + """ + max_nr_cells = 100 + total_cells = adata_mibitof.n_obs + + df = sliding_window( + adata_mibitof, + library_key=library_key, + sliding_window_key=sliding_window_key, + partial_windows="split", + max_nr_cells=max_nr_cells, + copy=True, + ) + + counts = df[sliding_window_key].value_counts() + + # all windows respect the upper bound + assert counts.max() <= max_nr_cells + + # determine strict lower bound + lower_bound = max_nr_cells // 2 + if total_cells < lower_bound: + # if total cells are too few, just one window is allowed smaller + assert counts.max() == total_cells + else: + # otherwise, every window must satisfy the lower bound + assert (counts >= lower_bound).all() + def test_calculate_window_corners_overlap(self): min_x = 0 max_x = 200 @@ -125,7 +177,7 @@ def test_calculate_window_corners_overlap(self): max_y=max_y, window_size=window_size, overlap=overlap, - drop_partial_windows=False, + partial_windows=None, ) assert windows.shape == (9, 4) @@ -147,7 +199,7 @@ def test_calculate_window_corners_no_overlap(self): max_y=max_y, window_size=window_size, overlap=overlap, - drop_partial_windows=False, + partial_windows=None, ) assert windows.shape == (4, 4) @@ -169,9 +221,31 @@ def test_calculate_window_corners_drop_partial_windows(self): max_y=max_y, window_size=window_size, overlap=overlap, - drop_partial_windows=True, + partial_windows="drop", ) assert windows.shape == (4, 4) assert windows.iloc[0].values.tolist() == [0, 100, 0, 100] assert windows.iloc[-1].values.tolist() == [80, 180, 80, 180] + + def test_calculate_window_corners_adaptive_partial_windows(self): + min_x = 0 + max_x = 200 + min_y = 0 + max_y = 200 + window_size = 100 + overlap = 20 + + windows = _calculate_window_corners( + min_x=min_x, + max_x=max_x, + min_y=min_y, + max_y=max_y, + window_size=window_size, + overlap=overlap, + partial_windows="adaptive", + ) + + assert windows.shape == (9, 4) + assert windows.iloc[0].values.tolist() == [0, 80, 0, 80] + assert windows.iloc[-1].values.tolist() == [120, 200, 120, 200]