From 54d418e2a28e2b92b11fab710b223f40c46dcb0e Mon Sep 17 00:00:00 2001 From: colganwi Date: Sun, 29 Mar 2026 22:07:32 -0400 Subject: [PATCH 1/2] feat(tl): improve ancestral_linkage permutation test in single-target mode Three related changes: - Pairwise permutation: update docstring to reflect that the stored linkage matrix is observed - permuted_mean (not z-scores). - Single-target + permutation: add tdata.obs['{target}_norm_linkage'] with per-cell z-scores normalized by each cell's source-category null distribution. - Single-target + by_tree + permutation: run permutation, stats, and normalization independently per tree so each cell is normalized by its own tree's null model. Co-Authored-By: Claude Sonnet 4.6 --- src/pycea/tl/ancestral_linkage.py | 145 +++++++++++++++++++----------- tests/test_ancestral_linkage.py | 72 +++++++++++++++ 2 files changed, 167 insertions(+), 50 deletions(-) diff --git a/src/pycea/tl/ancestral_linkage.py b/src/pycea/tl/ancestral_linkage.py index b285a4a..30ef505 100644 --- a/src/pycea/tl/ancestral_linkage.py +++ b/src/pycea/tl/ancestral_linkage.py @@ -578,7 +578,7 @@ def ancestral_linkage( - ``'permutation'``: randomly shuffle cell-category labels ``n_permutations`` times and recompute linkage each time to build a null distribution. Z-scores and p-values are added to the stats table. The stored linkage - matrix is replaced by z-scores when this test is run. + matrix is replaced by ``observed - permuted_mean`` when this test is run. alternative The alternative hypothesis for the permutation test (ignored when ``test=None``): @@ -637,9 +637,12 @@ def ancestral_linkage( * ``tdata.obs['{target}_linkage']`` : :class:`Series ` (dtype ``float``) – single-target mode only. Per-cell distance to the nearest cell of the target category. + * ``tdata.obs['{target}_norm_linkage']`` : :class:`Series ` (dtype ``float``) – single-target mode with ``test='permutation'`` only. + Per-cell z-score: ``sign * (score - null_mean) / null_std``, where the null + distribution is taken from the cell's source-category permutations. * ``tdata.uns['{key_added}_linkage']`` : :class:`DataFrame ` – pairwise mode only. Category × category linkage matrix (source rows, target columns). - Contains z-scores instead of raw distances when ``test='permutation'``. + Contains ``observed - permuted_mean`` instead of raw distances when ``test='permutation'``. * ``tdata.uns['{key_added}_linkage_params']`` : ``dict`` – pairwise mode only. Parameters used to compute the linkage matrix. * ``tdata.uns['{key_added}_linkage_stats']`` : :class:`DataFrame ` – pairwise mode only. @@ -710,32 +713,28 @@ def ancestral_linkage( # Always use "closest": min path or max lca single_agg: str = "max" if metric == "lca" else "min" - all_scores = _compute_scores(tdata, trees, leaf_to_cat, [target], single_agg, metric, depth_key) + sign = 1.0 if metric == "lca" else -1.0 - # Per-leaf scores - score_map = {leaf: scores.get(target, np.nan) for leaf, scores in all_scores.items()} - tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(score_map, dtype=float)) - - if test == "permutation": - # Observed: mean per source category - obs_cat_scores: dict = {} + def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=None): + """Run the permutation test for one tree (or globally) and return (rows, norm_map).""" + tree_obs_cat: dict = {} for cat in all_cats: - vals = [score_map[l] for l in cat_to_leaves[cat] if l in score_map and not np.isnan(score_map[l])] - obs_cat_scores[cat] = float(np.mean(vals)) if vals else np.nan + vals = [tree_sm[l] for l in tree_cl[cat] if l in tree_sm and not np.isnan(tree_sm[l])] + tree_obs_cat[cat] = float(np.mean(vals)) if vals else np.nan - all_leaf_list = list(leaf_to_cat.keys()) + tree_leaf_list = list(tree_lc.keys()) perm_seeds = np.random.randint(0, 2**31, size=n_permutations) if permutation_mode == "non_target": - t_leaves = cat_to_leaves[target] - t_set = set(t_leaves) - nt_leaves = [l for l in all_leaf_list if l not in t_set] + t_lv = tree_cl[target] + t_set = set(t_lv) + nt_lv = [l for l in tree_leaf_list if l not in t_set] _PERM_SINGLE_NON_TARGET_DATA.clear() _PERM_SINGLE_NON_TARGET_DATA.update({ - "fixed_scores": score_map, # precomputed above, target leaves fixed - "target_leaves": list(t_leaves), - "nt_leaves": nt_leaves, - "nt_cats": [leaf_to_cat[l] for l in nt_leaves], + "fixed_scores": tree_sm, + "target_leaves": list(t_lv), + "nt_leaves": nt_lv, + "nt_cats": [tree_lc[l] for l in nt_lv], "target": target, "all_cats": all_cats, }) @@ -744,9 +743,9 @@ def ancestral_linkage( _PERM_SINGLE_DATA.clear() _PERM_SINGLE_DATA.update({ "tdata": tdata, - "trees": trees, - "all_leaves": all_leaf_list, - "all_cat_vals": [leaf_to_cat[l] for l in all_leaf_list], + "trees": single_tree, + "all_leaves": tree_leaf_list, + "all_cat_vals": [tree_lc[l] for l in tree_leaf_list], "target": target, "single_agg": single_agg, "metric": metric, @@ -755,19 +754,21 @@ def ancestral_linkage( }) null_results = _run_parallel(_perm_single_target_worker, perm_seeds, n_threads) - null_cat_scores: dict = defaultdict(list) + null_cat: dict = defaultdict(list) for perm_result in null_results: for cat in all_cats: - null_cat_scores[cat].append(perm_result[cat]) + null_cat[cat].append(perm_result[cat]) - rows = [] + rows: list = [] + cat_null_mean: dict = {} + cat_null_std: dict = {} for cat in all_cats: - obs_val = obs_cat_scores[cat] - null_vals = np.array([v for v in null_cat_scores[cat] if not np.isnan(v)], dtype=float) + obs_val = tree_obs_cat[cat] + null_vals = np.array([v for v in null_cat[cat] if not np.isnan(v)], dtype=float) if len(null_vals) > 0: perm_val = float(np.mean(null_vals)) - sign = 1.0 if metric == "lca" else -1.0 - z = sign * (obs_val - perm_val) / (float(np.std(null_vals)) + 1e-10) + null_std = float(np.std(null_vals)) + z = sign * (obs_val - perm_val) / (null_std + 1e-10) if alternative == "two-sided": p = float(np.mean(np.abs(null_vals - perm_val) >= abs(obs_val - perm_val))) elif metric == "lca": @@ -775,32 +776,76 @@ def ancestral_linkage( else: p = float(np.mean(null_vals <= obs_val)) else: - perm_val, z, p = np.nan, np.nan, np.nan - rows.append( - { - "source": cat, - "target": target, - "value": obs_val, - "permuted_value": perm_val, - "z_score": z, - "p_value": p, - } + perm_val, null_std, z, p = np.nan, 0.0, np.nan, np.nan + cat_null_mean[cat] = perm_val + cat_null_std[cat] = null_std + row: dict = {"source": cat, "target": target, "value": obs_val, + "permuted_value": perm_val, "z_score": z, "p_value": p} + if extra_row_fields: + row.update(extra_row_fields) + rows.append(row) + + norm_map: dict = {} + for leaf, score in tree_sm.items(): + cat = tree_lc.get(leaf) + if cat is not None and not np.isnan(score): + norm_map[leaf] = sign * (score - cat_null_mean.get(cat, np.nan)) / (cat_null_std.get(cat, 0.0) + 1e-10) + else: + norm_map[leaf] = np.nan + return rows, norm_map + + if by_tree and test == "permutation": + # Per-tree: compute scores, run permutation, normalize per cell independently + merged_score_map: dict = {} + merged_norm_map: dict = {} + all_rows: list = [] + + for tree_key, t in trees.items(): + t_nodes = set(t.nodes()) + single_tree = {tree_key: t} + tree_lc: dict = {l: c for l, c in leaf_to_cat.items() if l in t_nodes} + tree_cl: dict = defaultdict(list) + for l, c in tree_lc.items(): + tree_cl[c].append(l) + + tree_all_scores = _compute_scores( + tdata, single_tree, tree_lc, [target], single_agg, metric, depth_key ) + tree_sm: dict = {l: s.get(target, np.nan) for l, s in tree_all_scores.items()} + merged_score_map.update(tree_sm) + + rows, norm_map = _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, + extra_row_fields={"tree": tree_key}) + all_rows.extend(rows) + merged_norm_map.update(norm_map) - test_df = pd.DataFrame(rows) + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(merged_score_map, dtype=float)) + tdata.obs[f"{target}_norm_linkage"] = tdata.obs.index.map(pd.Series(merged_norm_map, dtype=float)) + test_df = pd.DataFrame(all_rows) tdata.uns[f"{key_added}_test"] = test_df if copy: return test_df - if copy: - result_series = pd.Series( - { - cat: float(np.nanmean([score_map.get(l, np.nan) for l in cat_to_leaves[cat]])) - for cat in all_cats - }, - name=f"{target}_linkage", - ) - return result_series.to_frame() + else: + # Global (non-by_tree) path + all_scores = _compute_scores(tdata, trees, leaf_to_cat, [target], single_agg, metric, depth_key) + score_map = {leaf: scores.get(target, np.nan) for leaf, scores in all_scores.items()} + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(score_map, dtype=float)) + + if test == "permutation": + rows, norm_map = _run_single_perm(trees, leaf_to_cat, score_map, cat_to_leaves) + tdata.obs[f"{target}_norm_linkage"] = tdata.obs.index.map(pd.Series(norm_map, dtype=float)) + test_df = pd.DataFrame(rows) + tdata.uns[f"{key_added}_test"] = test_df + if copy: + return test_df + + if copy: + result_series = pd.Series( + {cat: float(np.nanmean([score_map.get(l, np.nan) for l in cat_to_leaves[cat]])) for cat in all_cats}, + name=f"{target}_linkage", + ) + return result_series.to_frame() # ── pairwise mode ───────────────────────────────────────────────────────── else: diff --git a/tests/test_ancestral_linkage.py b/tests/test_ancestral_linkage.py index 172082f..a840653 100644 --- a/tests/test_ancestral_linkage.py +++ b/tests/test_ancestral_linkage.py @@ -669,3 +669,75 @@ def test_symmetrize_invalid_raises(): with pytest.raises(ValueError, match="symmetrize"): tl.ancestral_linkage(tdata, groupby="ct", symmetrize="meen") + + +# ── permutation improvement tests ───────────────────────────────────────────── + + +def test_pairwise_permutation_linkage_is_value_minus_permuted(balanced_tdata): + """uns linkage matrix stores observed - permuted_mean (not z-scores) when test='permutation'.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", test="permutation", n_permutations=20, random_state=0) + linkage = tdata.uns["celltype_linkage"] + stats = tdata.uns["celltype_linkage_stats"] + for _, row in stats.iterrows(): + expected = row["value"] - row["permuted_value"] + actual = linkage.loc[row["source"], row["target"]] + assert actual == pytest.approx(expected, abs=1e-9) + + +def test_single_target_permutation_adds_norm_linkage(balanced_tdata): + """test='permutation' in single-target mode adds _norm_linkage column to obs.""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", target="B", test="permutation", + n_permutations=30, random_state=1, + ) + assert "B_linkage" in tdata.obs.columns + assert "B_norm_linkage" in tdata.obs.columns + # norm_linkage should be finite for cells with valid scores + valid = tdata.obs["B_linkage"].notna() + assert tdata.obs.loc[valid, "B_norm_linkage"].notna().all() + + +def test_single_target_norm_linkage_sign(balanced_tdata): + """Within-target cells should have positive norm_linkage (closer to target = higher z-score).""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", target="B", test="permutation", + n_permutations=100, random_state=42, + ) + # B cells are closest to B (path=1.0); A cells are farther (path=2.0) + # For metric='path', sign=-1, so B cells get higher (less negative) norm_linkage + b_norm = tdata.obs.loc[["b1", "b2"], "B_norm_linkage"].mean() + a_norm = tdata.obs.loc[["a1", "a2"], "B_norm_linkage"].mean() + assert b_norm > a_norm + + +def test_single_target_by_tree_permutation(three_cat_tdata): + """by_tree=True with single target + permutation runs per-tree and adds norm_linkage.""" + tdata = three_cat_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", target="A", test="permutation", by_tree=True, + n_permutations=20, random_state=7, + ) + assert "A_linkage" in tdata.obs.columns + assert "A_norm_linkage" in tdata.obs.columns + test_df = tdata.uns["celltype_test"] + # by_tree adds a "tree" column + assert "tree" in test_df.columns + # Each tree in tdata.obst should appear in the test results + for tree_key in tdata.obst: + assert tree_key in test_df["tree"].values + + +def test_single_target_by_tree_perm_non_target(balanced_tdata): + """by_tree + permutation_mode='non_target' in single-target mode runs and produces norm_linkage.""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", target="B", test="permutation", + by_tree=True, permutation_mode="non_target", n_permutations=20, random_state=3, + ) + assert "B_norm_linkage" in tdata.obs.columns + test_df = tdata.uns["celltype_test"] + assert "tree" in test_df.columns From 57858939170f7fe91eae72dbc70d945dbe4d0b19 Mon Sep 17 00:00:00 2001 From: colganwi Date: Sun, 29 Mar 2026 22:14:48 -0400 Subject: [PATCH 2/2] feat(tl): add normalize parameter to ancestral_linkage permutation test Adds normalize: bool = False to ancestral_linkage. When True and test='permutation': - Pairwise: linkage matrix stores observed - permuted_mean instead of raw values. - Single-target: tdata.obs['{target}_linkage'] is replaced with cell_score - category_permuted_mean. - Both modes support by_tree, where per-tree null distributions are used for normalization independently per tree. Co-Authored-By: Claude Sonnet 4.6 --- src/pycea/tl/ancestral_linkage.py | 62 +++++++++++--------- tests/test_ancestral_linkage.py | 95 ++++++++++++++++++++----------- 2 files changed, 99 insertions(+), 58 deletions(-) diff --git a/src/pycea/tl/ancestral_linkage.py b/src/pycea/tl/ancestral_linkage.py index 30ef505..8a173e3 100644 --- a/src/pycea/tl/ancestral_linkage.py +++ b/src/pycea/tl/ancestral_linkage.py @@ -477,6 +477,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + normalize: bool = False, alternative: Literal["two-sided", None] = None, permutation_mode: Literal["all", "non_target"] = "all", n_permutations: int = 100, @@ -499,6 +500,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + normalize: bool = False, alternative: Literal["two-sided", None] = None, permutation_mode: Literal["all", "non_target"] = "all", n_permutations: int = 100, @@ -519,6 +521,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + normalize: bool = False, alternative: Literal["two-sided", None] = None, permutation_mode: Literal["all", "non_target"] = "all", n_permutations: int = 100, @@ -577,8 +580,12 @@ def ancestral_linkage( - ``'permutation'``: randomly shuffle cell-category labels ``n_permutations`` times and recompute linkage each time to build a null distribution. - Z-scores and p-values are added to the stats table. The stored linkage - matrix is replaced by ``observed - permuted_mean`` when this test is run. + Z-scores and p-values are added to the stats table. + normalize + If ``True`` and ``test='permutation'``, subtract the permuted mean from the + observed values: pairwise linkage matrix becomes ``observed - permuted_mean``; + single-target ``tdata.obs['{target}_linkage']`` becomes + ``cell_score - category_permuted_mean``. Ignored when ``test=None``. alternative The alternative hypothesis for the permutation test (ignored when ``test=None``): @@ -636,13 +643,13 @@ def ancestral_linkage( Sets the following fields: * ``tdata.obs['{target}_linkage']`` : :class:`Series ` (dtype ``float``) – single-target mode only. - Per-cell distance to the nearest cell of the target category. - * ``tdata.obs['{target}_norm_linkage']`` : :class:`Series ` (dtype ``float``) – single-target mode with ``test='permutation'`` only. - Per-cell z-score: ``sign * (score - null_mean) / null_std``, where the null - distribution is taken from the cell's source-category permutations. + Per-cell distance to the nearest cell of the target category. When + ``normalize=True`` and ``test='permutation'``, replaced by + ``cell_score - category_permuted_mean``. * ``tdata.uns['{key_added}_linkage']`` : :class:`DataFrame ` – pairwise mode only. Category × category linkage matrix (source rows, target columns). - Contains ``observed - permuted_mean`` instead of raw distances when ``test='permutation'``. + When ``normalize=True`` and ``test='permutation'``, contains + ``observed - permuted_mean`` instead of raw distances. * ``tdata.uns['{key_added}_linkage_params']`` : ``dict`` – pairwise mode only. Parameters used to compute the linkage matrix. * ``tdata.uns['{key_added}_linkage_stats']`` : :class:`DataFrame ` – pairwise mode only. @@ -716,7 +723,7 @@ def ancestral_linkage( sign = 1.0 if metric == "lca" else -1.0 def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=None): - """Run the permutation test for one tree (or globally) and return (rows, norm_map).""" + """Run the permutation test for one tree (or globally) and return (rows, cat_null_mean).""" tree_obs_cat: dict = {} for cat in all_cats: vals = [tree_sm[l] for l in tree_cl[cat] if l in tree_sm and not np.isnan(tree_sm[l])] @@ -761,7 +768,6 @@ def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=No rows: list = [] cat_null_mean: dict = {} - cat_null_std: dict = {} for cat in all_cats: obs_val = tree_obs_cat[cat] null_vals = np.array([v for v in null_cat[cat] if not np.isnan(v)], dtype=float) @@ -778,21 +784,13 @@ def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=No else: perm_val, null_std, z, p = np.nan, 0.0, np.nan, np.nan cat_null_mean[cat] = perm_val - cat_null_std[cat] = null_std row: dict = {"source": cat, "target": target, "value": obs_val, "permuted_value": perm_val, "z_score": z, "p_value": p} if extra_row_fields: row.update(extra_row_fields) rows.append(row) - norm_map: dict = {} - for leaf, score in tree_sm.items(): - cat = tree_lc.get(leaf) - if cat is not None and not np.isnan(score): - norm_map[leaf] = sign * (score - cat_null_mean.get(cat, np.nan)) / (cat_null_std.get(cat, 0.0) + 1e-10) - else: - norm_map[leaf] = np.nan - return rows, norm_map + return rows, cat_null_mean if by_tree and test == "permutation": # Per-tree: compute scores, run permutation, normalize per cell independently @@ -814,13 +812,18 @@ def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=No tree_sm: dict = {l: s.get(target, np.nan) for l, s in tree_all_scores.items()} merged_score_map.update(tree_sm) - rows, norm_map = _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, - extra_row_fields={"tree": tree_key}) + rows, cat_null_mean = _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, + extra_row_fields={"tree": tree_key}) all_rows.extend(rows) - merged_norm_map.update(norm_map) + if normalize: + for leaf, score in tree_sm.items(): + cat = tree_lc.get(leaf) + perm_val = cat_null_mean.get(cat, np.nan) if cat is not None else np.nan + merged_norm_map[leaf] = (score - perm_val) if not np.isnan(score) else np.nan tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(merged_score_map, dtype=float)) - tdata.obs[f"{target}_norm_linkage"] = tdata.obs.index.map(pd.Series(merged_norm_map, dtype=float)) + if normalize: + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(merged_norm_map, dtype=float)) test_df = pd.DataFrame(all_rows) tdata.uns[f"{key_added}_test"] = test_df if copy: @@ -833,8 +836,14 @@ def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=No tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(score_map, dtype=float)) if test == "permutation": - rows, norm_map = _run_single_perm(trees, leaf_to_cat, score_map, cat_to_leaves) - tdata.obs[f"{target}_norm_linkage"] = tdata.obs.index.map(pd.Series(norm_map, dtype=float)) + rows, cat_null_mean = _run_single_perm(trees, leaf_to_cat, score_map, cat_to_leaves) + if normalize: + norm_map = { + leaf: (score - cat_null_mean.get(leaf_to_cat.get(leaf), np.nan)) + if not np.isnan(score) else np.nan + for leaf, score in score_map.items() + } + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(norm_map, dtype=float)) test_df = pd.DataFrame(rows) tdata.uns[f"{key_added}_test"] = test_df if copy: @@ -931,8 +940,8 @@ def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=No row["p_value"] = global_p_df.loc[src_cat, tgt_cat] stats_rows.append(row) - # uns[linkage] = observed - permuted_mean (symmetrized) if test ran, else raw linkage (symmetrized) - output_df: pd.DataFrame = (linkage_df - global_null_mean_df) if test == "permutation" else linkage_df + # uns[linkage] = observed - permuted_mean if normalize, else raw linkage (both symmetrized if requested) + output_df: pd.DataFrame = (linkage_df - global_null_mean_df) if (test == "permutation" and normalize) else linkage_df if symmetrize is not None: output_df = _symmetrize_matrix(output_df, symmetrize) @@ -942,6 +951,7 @@ def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=No "metric": metric, "symmetrize": symmetrize, "test": test, + "normalize": normalize, "by_tree": by_tree, "depth_key": depth_key, } diff --git a/tests/test_ancestral_linkage.py b/tests/test_ancestral_linkage.py index a840653..382d359 100644 --- a/tests/test_ancestral_linkage.py +++ b/tests/test_ancestral_linkage.py @@ -671,73 +671,104 @@ def test_symmetrize_invalid_raises(): tl.ancestral_linkage(tdata, groupby="ct", symmetrize="meen") -# ── permutation improvement tests ───────────────────────────────────────────── +# ── normalize parameter tests ────────────────────────────────────────────────── -def test_pairwise_permutation_linkage_is_value_minus_permuted(balanced_tdata): - """uns linkage matrix stores observed - permuted_mean (not z-scores) when test='permutation'.""" +def test_pairwise_normalize_stores_value_minus_permuted(balanced_tdata): + """normalize=True stores observed - permuted_mean in the pairwise linkage matrix.""" tdata = balanced_tdata - tl.ancestral_linkage(tdata, groupby="celltype", test="permutation", n_permutations=20, random_state=0) + tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", normalize=True, + n_permutations=20, random_state=0, + ) linkage = tdata.uns["celltype_linkage"] stats = tdata.uns["celltype_linkage_stats"] for _, row in stats.iterrows(): expected = row["value"] - row["permuted_value"] - actual = linkage.loc[row["source"], row["target"]] - assert actual == pytest.approx(expected, abs=1e-9) + assert linkage.loc[row["source"], row["target"]] == pytest.approx(expected, abs=1e-9) -def test_single_target_permutation_adds_norm_linkage(balanced_tdata): - """test='permutation' in single-target mode adds _norm_linkage column to obs.""" +def test_pairwise_no_normalize_stores_raw_linkage(balanced_tdata): + """normalize=False (default) stores raw linkage even when test='permutation'.""" tdata = balanced_tdata tl.ancestral_linkage( - tdata, groupby="celltype", target="B", test="permutation", - n_permutations=30, random_state=1, + tdata, groupby="celltype", test="permutation", normalize=False, + n_permutations=20, random_state=0, ) - assert "B_linkage" in tdata.obs.columns - assert "B_norm_linkage" in tdata.obs.columns - # norm_linkage should be finite for cells with valid scores - valid = tdata.obs["B_linkage"].notna() - assert tdata.obs.loc[valid, "B_norm_linkage"].notna().all() + linkage = tdata.uns["celltype_linkage"] + stats = tdata.uns["celltype_linkage_stats"] + for _, row in stats.iterrows(): + assert linkage.loc[row["source"], row["target"]] == pytest.approx(row["value"], abs=1e-9) -def test_single_target_norm_linkage_sign(balanced_tdata): - """Within-target cells should have positive norm_linkage (closer to target = higher z-score).""" +def test_single_target_normalize_overwrites_linkage(balanced_tdata): + """normalize=True replaces _linkage in obs with score - category_permuted_mean.""" tdata = balanced_tdata - tl.ancestral_linkage( - tdata, groupby="celltype", target="B", test="permutation", - n_permutations=100, random_state=42, - ) - # B cells are closest to B (path=1.0); A cells are farther (path=2.0) - # For metric='path', sign=-1, so B cells get higher (less negative) norm_linkage - b_norm = tdata.obs.loc[["b1", "b2"], "B_norm_linkage"].mean() - a_norm = tdata.obs.loc[["a1", "a2"], "B_norm_linkage"].mean() - assert b_norm > a_norm + # Raw linkage without normalize + tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation", + normalize=False, n_permutations=30, random_state=1) + raw = tdata.obs["B_linkage"].copy() + + # Normalized linkage + tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation", + normalize=True, n_permutations=30, random_state=1) + norm = tdata.obs["B_linkage"].copy() + test_df = tdata.uns["celltype_test"] + + # Normalized values should differ from raw (permuted_mean is non-trivial) + # And for each cell: norm = raw - cat_permuted_mean + for cell in ["a1", "a2", "b1", "b2"]: + cat = tdata.obs.loc[cell, "celltype"] + perm_val = test_df.loc[test_df["source"] == cat, "permuted_value"].iloc[0] + assert norm[cell] == pytest.approx(raw[cell] - perm_val, abs=1e-9) + + assert "B_norm_linkage" not in tdata.obs.columns + + +def test_single_target_no_normalize_keeps_raw(balanced_tdata): + """normalize=False (default) does not overwrite _linkage in obs.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation", + normalize=False, n_permutations=20, random_state=2) + # With normalize=False, a1 (category A, not in B) has raw min-path score 2.0 + assert tdata.obs.loc["a1", "B_linkage"] == pytest.approx(2.0) def test_single_target_by_tree_permutation(three_cat_tdata): - """by_tree=True with single target + permutation runs per-tree and adds norm_linkage.""" + """by_tree=True with single target + permutation runs per-tree (stats has tree column).""" tdata = three_cat_tdata tl.ancestral_linkage( tdata, groupby="celltype", target="A", test="permutation", by_tree=True, n_permutations=20, random_state=7, ) assert "A_linkage" in tdata.obs.columns - assert "A_norm_linkage" in tdata.obs.columns + assert "A_norm_linkage" not in tdata.obs.columns test_df = tdata.uns["celltype_test"] - # by_tree adds a "tree" column assert "tree" in test_df.columns - # Each tree in tdata.obst should appear in the test results for tree_key in tdata.obst: assert tree_key in test_df["tree"].values +def test_single_target_by_tree_normalize(three_cat_tdata): + """by_tree=True + normalize=True normalizes each cell using its own tree's null distribution.""" + tdata = three_cat_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", target="A", test="permutation", + by_tree=True, normalize=True, n_permutations=20, random_state=7, + ) + assert "A_linkage" in tdata.obs.columns + assert "A_norm_linkage" not in tdata.obs.columns + # _linkage values should be finite for leaves with valid scores + assert tdata.obs["A_linkage"].notna().any() + + def test_single_target_by_tree_perm_non_target(balanced_tdata): - """by_tree + permutation_mode='non_target' in single-target mode runs and produces norm_linkage.""" + """by_tree + permutation_mode='non_target' in single-target mode produces tree column in stats.""" tdata = balanced_tdata tl.ancestral_linkage( tdata, groupby="celltype", target="B", test="permutation", by_tree=True, permutation_mode="non_target", n_permutations=20, random_state=3, ) - assert "B_norm_linkage" in tdata.obs.columns + assert "B_norm_linkage" not in tdata.obs.columns test_df = tdata.uns["celltype_test"] assert "tree" in test_df.columns