Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 110 additions & 55 deletions src/pycea/tl/ancestral_linkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def ancestral_linkage(
metric: _TreeMetric = "path",
symmetrize: Literal["mean", "max", "min", None] = None,
test: Literal["permutation", None] = None,
normalize: bool = False,
alternative: Literal["two-sided", None] = None,
permutation_mode: Literal["all", "non_target"] = "all",
n_permutations: int = 100,
Expand All @@ -499,6 +500,7 @@ def ancestral_linkage(
metric: _TreeMetric = "path",
symmetrize: Literal["mean", "max", "min", None] = None,
test: Literal["permutation", None] = None,
normalize: bool = False,
alternative: Literal["two-sided", None] = None,
permutation_mode: Literal["all", "non_target"] = "all",
n_permutations: int = 100,
Expand All @@ -519,6 +521,7 @@ def ancestral_linkage(
metric: _TreeMetric = "path",
symmetrize: Literal["mean", "max", "min", None] = None,
test: Literal["permutation", None] = None,
normalize: bool = False,
alternative: Literal["two-sided", None] = None,
Comment on lines 523 to 525
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep normalize from shifting positional arguments

Adding normalize between test and alternative changes the positional call contract of this public API: existing callers that previously passed alternative (or later args) positionally will now silently bind those values to normalize instead. For example, a positional 'two-sided' now becomes a truthy normalize value while alternative stays None, which changes both p-value calculation and stored linkage values without raising an error. Please make normalize keyword-only or append it after existing positional parameters to preserve backward compatibility.

Useful? React with 👍 / 👎.

permutation_mode: Literal["all", "non_target"] = "all",
n_permutations: int = 100,
Expand Down Expand Up @@ -577,8 +580,12 @@ def ancestral_linkage(

- ``'permutation'``: randomly shuffle cell-category labels ``n_permutations``
times and recompute linkage each time to build a null distribution.
Z-scores and p-values are added to the stats table. The stored linkage
matrix is replaced by z-scores when this test is run.
Z-scores and p-values are added to the stats table.
normalize
If ``True`` and ``test='permutation'``, subtract the permuted mean from the
observed values: pairwise linkage matrix becomes ``observed - permuted_mean``;
single-target ``tdata.obs['{target}_linkage']`` becomes
``cell_score - category_permuted_mean``. Ignored when ``test=None``.
alternative
The alternative hypothesis for the permutation test (ignored when
``test=None``):
Expand Down Expand Up @@ -636,10 +643,13 @@ def ancestral_linkage(
Sets the following fields:

* ``tdata.obs['{target}_linkage']`` : :class:`Series <pandas.Series>` (dtype ``float``) – single-target mode only.
Per-cell distance to the nearest cell of the target category.
Per-cell distance to the nearest cell of the target category. When
``normalize=True`` and ``test='permutation'``, replaced by
``cell_score - category_permuted_mean``.
* ``tdata.uns['{key_added}_linkage']`` : :class:`DataFrame <pandas.DataFrame>` – pairwise mode only.
Category × category linkage matrix (source rows, target columns).
Contains z-scores instead of raw distances when ``test='permutation'``.
When ``normalize=True`` and ``test='permutation'``, contains
``observed - permuted_mean`` instead of raw distances.
* ``tdata.uns['{key_added}_linkage_params']`` : ``dict`` – pairwise mode only.
Parameters used to compute the linkage matrix.
* ``tdata.uns['{key_added}_linkage_stats']`` : :class:`DataFrame <pandas.DataFrame>` – pairwise mode only.
Expand Down Expand Up @@ -710,32 +720,28 @@ def ancestral_linkage(

# Always use "closest": min path or max lca
single_agg: str = "max" if metric == "lca" else "min"
all_scores = _compute_scores(tdata, trees, leaf_to_cat, [target], single_agg, metric, depth_key)
sign = 1.0 if metric == "lca" else -1.0

# Per-leaf scores
score_map = {leaf: scores.get(target, np.nan) for leaf, scores in all_scores.items()}
tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(score_map, dtype=float))

if test == "permutation":
# Observed: mean per source category
obs_cat_scores: dict = {}
def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=None):
"""Run the permutation test for one tree (or globally) and return (rows, cat_null_mean)."""
tree_obs_cat: dict = {}
for cat in all_cats:
vals = [score_map[l] for l in cat_to_leaves[cat] if l in score_map and not np.isnan(score_map[l])]
obs_cat_scores[cat] = float(np.mean(vals)) if vals else np.nan
vals = [tree_sm[l] for l in tree_cl[cat] if l in tree_sm and not np.isnan(tree_sm[l])]
tree_obs_cat[cat] = float(np.mean(vals)) if vals else np.nan

all_leaf_list = list(leaf_to_cat.keys())
tree_leaf_list = list(tree_lc.keys())
perm_seeds = np.random.randint(0, 2**31, size=n_permutations)

if permutation_mode == "non_target":
t_leaves = cat_to_leaves[target]
t_set = set(t_leaves)
nt_leaves = [l for l in all_leaf_list if l not in t_set]
t_lv = tree_cl[target]
t_set = set(t_lv)
nt_lv = [l for l in tree_leaf_list if l not in t_set]
_PERM_SINGLE_NON_TARGET_DATA.clear()
_PERM_SINGLE_NON_TARGET_DATA.update({
"fixed_scores": score_map, # precomputed above, target leaves fixed
"target_leaves": list(t_leaves),
"nt_leaves": nt_leaves,
"nt_cats": [leaf_to_cat[l] for l in nt_leaves],
"fixed_scores": tree_sm,
"target_leaves": list(t_lv),
"nt_leaves": nt_lv,
"nt_cats": [tree_lc[l] for l in nt_lv],
"target": target,
"all_cats": all_cats,
})
Expand All @@ -744,9 +750,9 @@ def ancestral_linkage(
_PERM_SINGLE_DATA.clear()
_PERM_SINGLE_DATA.update({
"tdata": tdata,
"trees": trees,
"all_leaves": all_leaf_list,
"all_cat_vals": [leaf_to_cat[l] for l in all_leaf_list],
"trees": single_tree,
"all_leaves": tree_leaf_list,
"all_cat_vals": [tree_lc[l] for l in tree_leaf_list],
"target": target,
"single_agg": single_agg,
"metric": metric,
Expand All @@ -755,52 +761,100 @@ def ancestral_linkage(
})
null_results = _run_parallel(_perm_single_target_worker, perm_seeds, n_threads)

null_cat_scores: dict = defaultdict(list)
null_cat: dict = defaultdict(list)
for perm_result in null_results:
for cat in all_cats:
null_cat_scores[cat].append(perm_result[cat])
null_cat[cat].append(perm_result[cat])

rows = []
rows: list = []
cat_null_mean: dict = {}
for cat in all_cats:
obs_val = obs_cat_scores[cat]
null_vals = np.array([v for v in null_cat_scores[cat] if not np.isnan(v)], dtype=float)
obs_val = tree_obs_cat[cat]
null_vals = np.array([v for v in null_cat[cat] if not np.isnan(v)], dtype=float)
if len(null_vals) > 0:
perm_val = float(np.mean(null_vals))
sign = 1.0 if metric == "lca" else -1.0
z = sign * (obs_val - perm_val) / (float(np.std(null_vals)) + 1e-10)
null_std = float(np.std(null_vals))
z = sign * (obs_val - perm_val) / (null_std + 1e-10)
if alternative == "two-sided":
p = float(np.mean(np.abs(null_vals - perm_val) >= abs(obs_val - perm_val)))
elif metric == "lca":
p = float(np.mean(null_vals >= obs_val))
else:
p = float(np.mean(null_vals <= obs_val))
else:
perm_val, z, p = np.nan, np.nan, np.nan
rows.append(
{
"source": cat,
"target": target,
"value": obs_val,
"permuted_value": perm_val,
"z_score": z,
"p_value": p,
}
)
perm_val, null_std, z, p = np.nan, 0.0, np.nan, np.nan
cat_null_mean[cat] = perm_val
row: dict = {"source": cat, "target": target, "value": obs_val,
"permuted_value": perm_val, "z_score": z, "p_value": p}
if extra_row_fields:
row.update(extra_row_fields)
rows.append(row)

return rows, cat_null_mean

if by_tree and test == "permutation":
# Per-tree: compute scores, run permutation, normalize per cell independently
merged_score_map: dict = {}
merged_norm_map: dict = {}
all_rows: list = []

for tree_key, t in trees.items():
t_nodes = set(t.nodes())
single_tree = {tree_key: t}
tree_lc: dict = {l: c for l, c in leaf_to_cat.items() if l in t_nodes}
tree_cl: dict = defaultdict(list)
for l, c in tree_lc.items():
tree_cl[c].append(l)

test_df = pd.DataFrame(rows)
tree_all_scores = _compute_scores(
tdata, single_tree, tree_lc, [target], single_agg, metric, depth_key
)
tree_sm: dict = {l: s.get(target, np.nan) for l, s in tree_all_scores.items()}
merged_score_map.update(tree_sm)

rows, cat_null_mean = _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl,
extra_row_fields={"tree": tree_key})
all_rows.extend(rows)
if normalize:
for leaf, score in tree_sm.items():
cat = tree_lc.get(leaf)
perm_val = cat_null_mean.get(cat, np.nan) if cat is not None else np.nan
merged_norm_map[leaf] = (score - perm_val) if not np.isnan(score) else np.nan

tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(merged_score_map, dtype=float))
if normalize:
tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(merged_norm_map, dtype=float))
test_df = pd.DataFrame(all_rows)
tdata.uns[f"{key_added}_test"] = test_df
if copy:
return test_df

if copy:
result_series = pd.Series(
{
cat: float(np.nanmean([score_map.get(l, np.nan) for l in cat_to_leaves[cat]]))
for cat in all_cats
},
name=f"{target}_linkage",
)
return result_series.to_frame()
else:
# Global (non-by_tree) path
all_scores = _compute_scores(tdata, trees, leaf_to_cat, [target], single_agg, metric, depth_key)
score_map = {leaf: scores.get(target, np.nan) for leaf, scores in all_scores.items()}
tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(score_map, dtype=float))

if test == "permutation":
rows, cat_null_mean = _run_single_perm(trees, leaf_to_cat, score_map, cat_to_leaves)
if normalize:
norm_map = {
leaf: (score - cat_null_mean.get(leaf_to_cat.get(leaf), np.nan))
if not np.isnan(score) else np.nan
for leaf, score in score_map.items()
}
tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(norm_map, dtype=float))
test_df = pd.DataFrame(rows)
tdata.uns[f"{key_added}_test"] = test_df
if copy:
return test_df

if copy:
result_series = pd.Series(
{cat: float(np.nanmean([score_map.get(l, np.nan) for l in cat_to_leaves[cat]])) for cat in all_cats},
name=f"{target}_linkage",
)
return result_series.to_frame()

# ── pairwise mode ─────────────────────────────────────────────────────────
else:
Expand Down Expand Up @@ -886,8 +940,8 @@ def ancestral_linkage(
row["p_value"] = global_p_df.loc[src_cat, tgt_cat]
stats_rows.append(row)

# uns[linkage] = observed - permuted_mean (symmetrized) if test ran, else raw linkage (symmetrized)
output_df: pd.DataFrame = (linkage_df - global_null_mean_df) if test == "permutation" else linkage_df
# uns[linkage] = observed - permuted_mean if normalize, else raw linkage (both symmetrized if requested)
output_df: pd.DataFrame = (linkage_df - global_null_mean_df) if (test == "permutation" and normalize) else linkage_df
if symmetrize is not None:
output_df = _symmetrize_matrix(output_df, symmetrize)

Expand All @@ -897,6 +951,7 @@ def ancestral_linkage(
"metric": metric,
"symmetrize": symmetrize,
"test": test,
"normalize": normalize,
"by_tree": by_tree,
"depth_key": depth_key,
}
Expand Down
103 changes: 103 additions & 0 deletions tests/test_ancestral_linkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,106 @@ def test_symmetrize_invalid_raises():

with pytest.raises(ValueError, match="symmetrize"):
tl.ancestral_linkage(tdata, groupby="ct", symmetrize="meen")


# ── normalize parameter tests ──────────────────────────────────────────────────


def test_pairwise_normalize_stores_value_minus_permuted(balanced_tdata):
"""normalize=True stores observed - permuted_mean in the pairwise linkage matrix."""
tdata = balanced_tdata
tl.ancestral_linkage(
tdata, groupby="celltype", test="permutation", normalize=True,
n_permutations=20, random_state=0,
)
linkage = tdata.uns["celltype_linkage"]
stats = tdata.uns["celltype_linkage_stats"]
for _, row in stats.iterrows():
expected = row["value"] - row["permuted_value"]
assert linkage.loc[row["source"], row["target"]] == pytest.approx(expected, abs=1e-9)


def test_pairwise_no_normalize_stores_raw_linkage(balanced_tdata):
"""normalize=False (default) stores raw linkage even when test='permutation'."""
tdata = balanced_tdata
tl.ancestral_linkage(
tdata, groupby="celltype", test="permutation", normalize=False,
n_permutations=20, random_state=0,
)
linkage = tdata.uns["celltype_linkage"]
stats = tdata.uns["celltype_linkage_stats"]
for _, row in stats.iterrows():
assert linkage.loc[row["source"], row["target"]] == pytest.approx(row["value"], abs=1e-9)


def test_single_target_normalize_overwrites_linkage(balanced_tdata):
"""normalize=True replaces _linkage in obs with score - category_permuted_mean."""
tdata = balanced_tdata
# Raw linkage without normalize
tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation",
normalize=False, n_permutations=30, random_state=1)
raw = tdata.obs["B_linkage"].copy()

# Normalized linkage
tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation",
normalize=True, n_permutations=30, random_state=1)
norm = tdata.obs["B_linkage"].copy()
test_df = tdata.uns["celltype_test"]

# Normalized values should differ from raw (permuted_mean is non-trivial)
# And for each cell: norm = raw - cat_permuted_mean
for cell in ["a1", "a2", "b1", "b2"]:
cat = tdata.obs.loc[cell, "celltype"]
perm_val = test_df.loc[test_df["source"] == cat, "permuted_value"].iloc[0]
assert norm[cell] == pytest.approx(raw[cell] - perm_val, abs=1e-9)

assert "B_norm_linkage" not in tdata.obs.columns


def test_single_target_no_normalize_keeps_raw(balanced_tdata):
"""normalize=False (default) does not overwrite _linkage in obs."""
tdata = balanced_tdata
tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation",
normalize=False, n_permutations=20, random_state=2)
# With normalize=False, a1 (category A, not in B) has raw min-path score 2.0
assert tdata.obs.loc["a1", "B_linkage"] == pytest.approx(2.0)


def test_single_target_by_tree_permutation(three_cat_tdata):
"""by_tree=True with single target + permutation runs per-tree (stats has tree column)."""
tdata = three_cat_tdata
tl.ancestral_linkage(
tdata, groupby="celltype", target="A", test="permutation", by_tree=True,
n_permutations=20, random_state=7,
)
assert "A_linkage" in tdata.obs.columns
assert "A_norm_linkage" not in tdata.obs.columns
test_df = tdata.uns["celltype_test"]
assert "tree" in test_df.columns
for tree_key in tdata.obst:
assert tree_key in test_df["tree"].values


def test_single_target_by_tree_normalize(three_cat_tdata):
"""by_tree=True + normalize=True normalizes each cell using its own tree's null distribution."""
tdata = three_cat_tdata
tl.ancestral_linkage(
tdata, groupby="celltype", target="A", test="permutation",
by_tree=True, normalize=True, n_permutations=20, random_state=7,
)
assert "A_linkage" in tdata.obs.columns
assert "A_norm_linkage" not in tdata.obs.columns
# _linkage values should be finite for leaves with valid scores
assert tdata.obs["A_linkage"].notna().any()


def test_single_target_by_tree_perm_non_target(balanced_tdata):
"""by_tree + permutation_mode='non_target' in single-target mode produces tree column in stats."""
tdata = balanced_tdata
tl.ancestral_linkage(
tdata, groupby="celltype", target="B", test="permutation",
by_tree=True, permutation_mode="non_target", n_permutations=20, random_state=3,
)
assert "B_norm_linkage" not in tdata.obs.columns
test_df = tdata.uns["celltype_test"]
assert "tree" in test_df.columns
Loading