Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7e67cc1
fix
selmanozleyen Aug 24, 2025
dbdda0c
attempt to add tests and add better documentation to the functions to…
selmanozleyen Aug 29, 2025
a3366e0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 29, 2025
b3ddcb7
add vibe coded tests prune later
selmanozleyen Aug 29, 2025
2fffcfd
add tests for spectral transform
selmanozleyen Sep 1, 2025
eeab937
Merge branch 'fix/spectral-transform' of https://github.com/scverse/s…
selmanozleyen Sep 4, 2025
2ce6e85
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2025
651505a
complete the merge conflict issue
selmanozleyen Sep 4, 2025
303d5c4
add fau and clean up code
selmanozleyen Sep 4, 2025
8c881be
make loops parallel
selmanozleyen Sep 4, 2025
bcb5989
Merge branch 'fix/spectral-transform' of https://github.com/scverse/s…
selmanozleyen Sep 4, 2025
172080d
specify fast arrayutils dep
selmanozleyen Sep 4, 2025
806273a
Merge branch 'main' into fix/spectral-transform
selmanozleyen Sep 4, 2025
a3db2d0
Merge branch 'main' into fix/spectral-transform
flying-sheep Sep 4, 2025
ea457c3
Merge branch 'main' into fix/spectral-transform
selmanozleyen Oct 1, 2025
6e60c6c
Merge branch 'main' into fix/spectral-transform
selmanozleyen Nov 3, 2025
22acdad
Add fast_array_utils to project dependencies
selmanozleyen Nov 3, 2025
017a763
Merge branch 'main' into fix/spectral-transform
selmanozleyen Nov 7, 2025
126fd38
cache kernel
selmanozleyen Nov 7, 2025
a24419a
forgot to save file bf commit
selmanozleyen Nov 7, 2025
4eae80c
Merge branch 'main' into fix/spectral-transform
selmanozleyen Nov 7, 2025
4f6010b
Apply suggestion from @flying-sheep
flying-sheep Nov 7, 2025
8f178cc
remove unused imports (idk why this didn't fail on linter)
selmanozleyen Nov 7, 2025
c7c8f83
remove more unused imports
selmanozleyen Nov 7, 2025
3d82ed5
Merge branch 'main' into fix/spectral-transform
selmanozleyen Nov 10, 2025
6680ec2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2025
0086dac
move to float32
selmanozleyen Nov 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ dependencies = [
"xarray>=2024.10.0",
"imagecodecs>=2025.8.2,<2026",
"zarr>=2.6.1",
"fast-array-utils",
]

[project.optional-dependencies]
Expand Down
73 changes: 57 additions & 16 deletions src/squidpy/gr/_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
import pandas as pd
from anndata import AnnData
from anndata.utils import make_index_unique
from numba import njit
from fast_array_utils import stats as fau_stats
from numba import njit, prange
from scanpy import logging as logg
from scipy.sparse import (
SparseEfficiencyWarning,
block_diag,
csr_array,
csr_matrix,
isspmatrix_csr,
spmatrix,
Expand Down Expand Up @@ -388,7 +390,7 @@ def _build_connectivity(
if delaunay:
tri = Delaunay(coords)
indptr, indices = tri.vertex_neighbor_vertices
Adj = csr_matrix((np.ones_like(indices, dtype=np.float64), indices, indptr), shape=(N, N))
Adj = csr_matrix((np.ones_like(indices, dtype=np.float32), indices, indptr), shape=(N, N))

if return_distance:
# fmt: off
Expand Down Expand Up @@ -423,7 +425,7 @@ def _build_connectivity(
col_indices = np.concatenate(col_indices)

Adj = csr_matrix(
(np.ones_like(row_indices, dtype=np.float64), (row_indices, col_indices)),
(np.ones_like(row_indices, dtype=np.float32), (row_indices, col_indices)),
shape=(N, N),
)
if return_distance:
Expand All @@ -438,29 +440,68 @@ def _build_connectivity(
return Adj


@njit
def outer(indices: NDArrayA, indptr: NDArrayA, degrees: NDArrayA) -> NDArrayA:
res = np.empty_like(indices, dtype=np.float64)
start = 0
for i in range(len(indptr) - 1):
ixs = indices[indptr[i] : indptr[i + 1]]
res[start : start + len(ixs)] = degrees[i] * degrees[ixs]
start += len(ixs)
@njit(parallel=True, cache=True)
def _csr_bilateral_diag_scale_helper(
mat: csr_array | csr_matrix,
degrees: NDArrayA,
) -> NDArrayA:
"""
Return an array F aligned with CSR non-zeros such that
F[k] = d[i] * data[k] * d[j] for the k-th non-zero (i, j) in CSR order.

Parameters
----------

data : array of float
CSR `data` (non-zero values).
indices : array of int
CSR `indices` (column indices).
indptr : array of int
CSR `indptr` (row pointer).
degrees : array of float, shape (n,)
Diagonal scaling vector.

Returns
-------
array of float
Length equals len(data). Entry-wise factors d_i * d_j * data[k]
"""

res = np.empty_like(mat.data, dtype=np.float32)
for i in prange(len(mat.indptr) - 1):
ixs = mat.indices[mat.indptr[i] : mat.indptr[i + 1]]
res[mat.indptr[i] : mat.indptr[i + 1]] = degrees[i] * degrees[ixs] * mat.data[mat.indptr[i] : mat.indptr[i + 1]]

return res


def symmetric_normalize_csr(adj: spmatrix) -> csr_matrix:
"""
Return D^{-1/2} * A * D^{-1/2}, where D = diag(degrees(A)) and A = adj.


Parameters
----------
adj : scipy.sparse.csr_matrix

Returns
-------
scipy.sparse.csr_matrix
"""
degrees = np.squeeze(np.array(np.sqrt(1.0 / fau_stats.sum(adj, axis=0))))
if adj.shape[0] != len(degrees):
raise ValueError("len(degrees) must equal number of rows of adj")
res_data = _csr_bilateral_diag_scale_helper(adj, degrees)
return csr_matrix((res_data, adj.indices, adj.indptr), shape=adj.shape)


def _transform_a_spectral(a: spmatrix) -> spmatrix:
if not isspmatrix_csr(a):
a = a.tocsr()
if not a.nnz:
return a

degrees = np.squeeze(np.array(np.sqrt(1.0 / a.sum(axis=0))))
a = a.multiply(outer(a.indices, a.indptr, degrees))
a.eliminate_zeros()

return a
return symmetric_normalize_csr(a)


def _transform_a_cosine(a: spmatrix) -> spmatrix:
Expand Down
69 changes: 69 additions & 0 deletions tests/graph/test_spatial_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,72 @@ def test_mask_graph(
negative_mask=True,
key_added=key_added,
)

def test_spatial_neighbors_transform_mathematical_properties(self, non_visium_adata: AnnData):
"""
Test mathematical properties of each transform.
"""
# Test spectral transform properties
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform="spectral")
adj_spectral = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()

# Spectral transform should be symmetric
np.testing.assert_allclose(adj_spectral, adj_spectral.T, atol=1e-10)

# Spectral transform should have normalized rows (L2 norm <= 1)
row_norms = np.sqrt(np.sum(adj_spectral**2, axis=1))
np.testing.assert_array_less(row_norms, 1.0 + 1e-10)

# Test cosine transform properties
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform="cosine")
adj_cosine = non_visium_adata.obsp[Key.obsp.spatial_conn()].toarray()

# Cosine transform should be symmetric
np.testing.assert_allclose(adj_cosine, adj_cosine.T, atol=1e-10)

# Cosine transform should have values in [-1, 1]
np.testing.assert_array_less(-1.0 - 1e-10, adj_cosine)
np.testing.assert_array_less(adj_cosine, 1.0 + 1e-10)

# Diagonal of cosine transform should be 1 (self-similarity)
np.testing.assert_allclose(np.diag(adj_cosine), 1.0, atol=1e-10)

def test_spatial_neighbors_transform_edge_cases(self, non_visium_adata: AnnData):
"""
Test transforms with edge cases (empty graph, single node, etc.).
"""
# Test with a very small dataset
small_adata = non_visium_adata[:5].copy() # Only 5 points

# Test all transforms with small dataset
for transform in [None, "spectral", "cosine"]:
spatial_neighbors(small_adata, delaunay=True, coord_type=None, transform=transform)
assert Key.obsp.spatial_conn() in small_adata.obsp
assert Key.obsp.spatial_dist() in small_adata.obsp

# Verify transform parameter is saved
assert small_adata.uns[Key.uns.spatial_neighs()]["params"]["transform"] == transform

def test_spatial_neighbors_spectral_transform_properties(self, non_visium_adata: AnnData):
"""
Test that spectral transform preserves nonzero pattern and normalizes rows to sum to 1.
"""
# Apply spatial_neighbors without transform
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform=None)
adj_no_transform = non_visium_adata.obsp[Key.obsp.spatial_conn()].copy()

# Apply spatial_neighbors with spectral transform
spatial_neighbors(non_visium_adata, delaunay=True, coord_type=None, transform="spectral")
adj_spectral = non_visium_adata.obsp[Key.obsp.spatial_conn()]

# Check that nonzero patterns are identical
np.testing.assert_array_equal(
adj_no_transform.nonzero(),
adj_spectral.nonzero(),
err_msg="Spectral transform should preserve the sparsity pattern",
)

w = np.linalg.eigvals(adj_spectral.toarray())
# Eigenvalues should be in range [-1, 1]
np.testing.assert_array_less(w, 1.0, err_msg="Eigenvalues should be <= 1")
np.testing.assert_array_less(-1.0, w, err_msg="Eigenvalues should be >= -1")
Loading