diff --git a/src/pycea/tl/ancestral_linkage.py b/src/pycea/tl/ancestral_linkage.py index b285a4a..8a173e3 100644 --- a/src/pycea/tl/ancestral_linkage.py +++ b/src/pycea/tl/ancestral_linkage.py @@ -477,6 +477,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + normalize: bool = False, alternative: Literal["two-sided", None] = None, permutation_mode: Literal["all", "non_target"] = "all", n_permutations: int = 100, @@ -499,6 +500,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + normalize: bool = False, alternative: Literal["two-sided", None] = None, permutation_mode: Literal["all", "non_target"] = "all", n_permutations: int = 100, @@ -519,6 +521,7 @@ def ancestral_linkage( metric: _TreeMetric = "path", symmetrize: Literal["mean", "max", "min", None] = None, test: Literal["permutation", None] = None, + normalize: bool = False, alternative: Literal["two-sided", None] = None, permutation_mode: Literal["all", "non_target"] = "all", n_permutations: int = 100, @@ -577,8 +580,12 @@ def ancestral_linkage( - ``'permutation'``: randomly shuffle cell-category labels ``n_permutations`` times and recompute linkage each time to build a null distribution. - Z-scores and p-values are added to the stats table. The stored linkage - matrix is replaced by z-scores when this test is run. + Z-scores and p-values are added to the stats table. + normalize + If ``True`` and ``test='permutation'``, subtract the permuted mean from the + observed values: pairwise linkage matrix becomes ``observed - permuted_mean``; + single-target ``tdata.obs['{target}_linkage']`` becomes + ``cell_score - category_permuted_mean``. Ignored when ``test=None``. alternative The alternative hypothesis for the permutation test (ignored when ``test=None``): @@ -636,10 +643,13 @@ def ancestral_linkage( Sets the following fields: * ``tdata.obs['{target}_linkage']`` : :class:`Series ` (dtype ``float``) – single-target mode only. - Per-cell distance to the nearest cell of the target category. + Per-cell distance to the nearest cell of the target category. When + ``normalize=True`` and ``test='permutation'``, replaced by + ``cell_score - category_permuted_mean``. * ``tdata.uns['{key_added}_linkage']`` : :class:`DataFrame ` – pairwise mode only. Category × category linkage matrix (source rows, target columns). - Contains z-scores instead of raw distances when ``test='permutation'``. + When ``normalize=True`` and ``test='permutation'``, contains + ``observed - permuted_mean`` instead of raw distances. * ``tdata.uns['{key_added}_linkage_params']`` : ``dict`` – pairwise mode only. Parameters used to compute the linkage matrix. * ``tdata.uns['{key_added}_linkage_stats']`` : :class:`DataFrame ` – pairwise mode only. @@ -710,32 +720,28 @@ def ancestral_linkage( # Always use "closest": min path or max lca single_agg: str = "max" if metric == "lca" else "min" - all_scores = _compute_scores(tdata, trees, leaf_to_cat, [target], single_agg, metric, depth_key) + sign = 1.0 if metric == "lca" else -1.0 - # Per-leaf scores - score_map = {leaf: scores.get(target, np.nan) for leaf, scores in all_scores.items()} - tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(score_map, dtype=float)) - - if test == "permutation": - # Observed: mean per source category - obs_cat_scores: dict = {} + def _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, extra_row_fields=None): + """Run the permutation test for one tree (or globally) and return (rows, cat_null_mean).""" + tree_obs_cat: dict = {} for cat in all_cats: - vals = [score_map[l] for l in cat_to_leaves[cat] if l in score_map and not np.isnan(score_map[l])] - obs_cat_scores[cat] = float(np.mean(vals)) if vals else np.nan + vals = [tree_sm[l] for l in tree_cl[cat] if l in tree_sm and not np.isnan(tree_sm[l])] + tree_obs_cat[cat] = float(np.mean(vals)) if vals else np.nan - all_leaf_list = list(leaf_to_cat.keys()) + tree_leaf_list = list(tree_lc.keys()) perm_seeds = np.random.randint(0, 2**31, size=n_permutations) if permutation_mode == "non_target": - t_leaves = cat_to_leaves[target] - t_set = set(t_leaves) - nt_leaves = [l for l in all_leaf_list if l not in t_set] + t_lv = tree_cl[target] + t_set = set(t_lv) + nt_lv = [l for l in tree_leaf_list if l not in t_set] _PERM_SINGLE_NON_TARGET_DATA.clear() _PERM_SINGLE_NON_TARGET_DATA.update({ - "fixed_scores": score_map, # precomputed above, target leaves fixed - "target_leaves": list(t_leaves), - "nt_leaves": nt_leaves, - "nt_cats": [leaf_to_cat[l] for l in nt_leaves], + "fixed_scores": tree_sm, + "target_leaves": list(t_lv), + "nt_leaves": nt_lv, + "nt_cats": [tree_lc[l] for l in nt_lv], "target": target, "all_cats": all_cats, }) @@ -744,9 +750,9 @@ def ancestral_linkage( _PERM_SINGLE_DATA.clear() _PERM_SINGLE_DATA.update({ "tdata": tdata, - "trees": trees, - "all_leaves": all_leaf_list, - "all_cat_vals": [leaf_to_cat[l] for l in all_leaf_list], + "trees": single_tree, + "all_leaves": tree_leaf_list, + "all_cat_vals": [tree_lc[l] for l in tree_leaf_list], "target": target, "single_agg": single_agg, "metric": metric, @@ -755,19 +761,20 @@ def ancestral_linkage( }) null_results = _run_parallel(_perm_single_target_worker, perm_seeds, n_threads) - null_cat_scores: dict = defaultdict(list) + null_cat: dict = defaultdict(list) for perm_result in null_results: for cat in all_cats: - null_cat_scores[cat].append(perm_result[cat]) + null_cat[cat].append(perm_result[cat]) - rows = [] + rows: list = [] + cat_null_mean: dict = {} for cat in all_cats: - obs_val = obs_cat_scores[cat] - null_vals = np.array([v for v in null_cat_scores[cat] if not np.isnan(v)], dtype=float) + obs_val = tree_obs_cat[cat] + null_vals = np.array([v for v in null_cat[cat] if not np.isnan(v)], dtype=float) if len(null_vals) > 0: perm_val = float(np.mean(null_vals)) - sign = 1.0 if metric == "lca" else -1.0 - z = sign * (obs_val - perm_val) / (float(np.std(null_vals)) + 1e-10) + null_std = float(np.std(null_vals)) + z = sign * (obs_val - perm_val) / (null_std + 1e-10) if alternative == "two-sided": p = float(np.mean(np.abs(null_vals - perm_val) >= abs(obs_val - perm_val))) elif metric == "lca": @@ -775,32 +782,79 @@ def ancestral_linkage( else: p = float(np.mean(null_vals <= obs_val)) else: - perm_val, z, p = np.nan, np.nan, np.nan - rows.append( - { - "source": cat, - "target": target, - "value": obs_val, - "permuted_value": perm_val, - "z_score": z, - "p_value": p, - } - ) + perm_val, null_std, z, p = np.nan, 0.0, np.nan, np.nan + cat_null_mean[cat] = perm_val + row: dict = {"source": cat, "target": target, "value": obs_val, + "permuted_value": perm_val, "z_score": z, "p_value": p} + if extra_row_fields: + row.update(extra_row_fields) + rows.append(row) + + return rows, cat_null_mean + + if by_tree and test == "permutation": + # Per-tree: compute scores, run permutation, normalize per cell independently + merged_score_map: dict = {} + merged_norm_map: dict = {} + all_rows: list = [] + + for tree_key, t in trees.items(): + t_nodes = set(t.nodes()) + single_tree = {tree_key: t} + tree_lc: dict = {l: c for l, c in leaf_to_cat.items() if l in t_nodes} + tree_cl: dict = defaultdict(list) + for l, c in tree_lc.items(): + tree_cl[c].append(l) - test_df = pd.DataFrame(rows) + tree_all_scores = _compute_scores( + tdata, single_tree, tree_lc, [target], single_agg, metric, depth_key + ) + tree_sm: dict = {l: s.get(target, np.nan) for l, s in tree_all_scores.items()} + merged_score_map.update(tree_sm) + + rows, cat_null_mean = _run_single_perm(single_tree, tree_lc, tree_sm, tree_cl, + extra_row_fields={"tree": tree_key}) + all_rows.extend(rows) + if normalize: + for leaf, score in tree_sm.items(): + cat = tree_lc.get(leaf) + perm_val = cat_null_mean.get(cat, np.nan) if cat is not None else np.nan + merged_norm_map[leaf] = (score - perm_val) if not np.isnan(score) else np.nan + + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(merged_score_map, dtype=float)) + if normalize: + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(merged_norm_map, dtype=float)) + test_df = pd.DataFrame(all_rows) tdata.uns[f"{key_added}_test"] = test_df if copy: return test_df - if copy: - result_series = pd.Series( - { - cat: float(np.nanmean([score_map.get(l, np.nan) for l in cat_to_leaves[cat]])) - for cat in all_cats - }, - name=f"{target}_linkage", - ) - return result_series.to_frame() + else: + # Global (non-by_tree) path + all_scores = _compute_scores(tdata, trees, leaf_to_cat, [target], single_agg, metric, depth_key) + score_map = {leaf: scores.get(target, np.nan) for leaf, scores in all_scores.items()} + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(score_map, dtype=float)) + + if test == "permutation": + rows, cat_null_mean = _run_single_perm(trees, leaf_to_cat, score_map, cat_to_leaves) + if normalize: + norm_map = { + leaf: (score - cat_null_mean.get(leaf_to_cat.get(leaf), np.nan)) + if not np.isnan(score) else np.nan + for leaf, score in score_map.items() + } + tdata.obs[f"{target}_linkage"] = tdata.obs.index.map(pd.Series(norm_map, dtype=float)) + test_df = pd.DataFrame(rows) + tdata.uns[f"{key_added}_test"] = test_df + if copy: + return test_df + + if copy: + result_series = pd.Series( + {cat: float(np.nanmean([score_map.get(l, np.nan) for l in cat_to_leaves[cat]])) for cat in all_cats}, + name=f"{target}_linkage", + ) + return result_series.to_frame() # ── pairwise mode ───────────────────────────────────────────────────────── else: @@ -886,8 +940,8 @@ def ancestral_linkage( row["p_value"] = global_p_df.loc[src_cat, tgt_cat] stats_rows.append(row) - # uns[linkage] = observed - permuted_mean (symmetrized) if test ran, else raw linkage (symmetrized) - output_df: pd.DataFrame = (linkage_df - global_null_mean_df) if test == "permutation" else linkage_df + # uns[linkage] = observed - permuted_mean if normalize, else raw linkage (both symmetrized if requested) + output_df: pd.DataFrame = (linkage_df - global_null_mean_df) if (test == "permutation" and normalize) else linkage_df if symmetrize is not None: output_df = _symmetrize_matrix(output_df, symmetrize) @@ -897,6 +951,7 @@ def ancestral_linkage( "metric": metric, "symmetrize": symmetrize, "test": test, + "normalize": normalize, "by_tree": by_tree, "depth_key": depth_key, } diff --git a/tests/test_ancestral_linkage.py b/tests/test_ancestral_linkage.py index 172082f..382d359 100644 --- a/tests/test_ancestral_linkage.py +++ b/tests/test_ancestral_linkage.py @@ -669,3 +669,106 @@ def test_symmetrize_invalid_raises(): with pytest.raises(ValueError, match="symmetrize"): tl.ancestral_linkage(tdata, groupby="ct", symmetrize="meen") + + +# ── normalize parameter tests ────────────────────────────────────────────────── + + +def test_pairwise_normalize_stores_value_minus_permuted(balanced_tdata): + """normalize=True stores observed - permuted_mean in the pairwise linkage matrix.""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", normalize=True, + n_permutations=20, random_state=0, + ) + linkage = tdata.uns["celltype_linkage"] + stats = tdata.uns["celltype_linkage_stats"] + for _, row in stats.iterrows(): + expected = row["value"] - row["permuted_value"] + assert linkage.loc[row["source"], row["target"]] == pytest.approx(expected, abs=1e-9) + + +def test_pairwise_no_normalize_stores_raw_linkage(balanced_tdata): + """normalize=False (default) stores raw linkage even when test='permutation'.""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", test="permutation", normalize=False, + n_permutations=20, random_state=0, + ) + linkage = tdata.uns["celltype_linkage"] + stats = tdata.uns["celltype_linkage_stats"] + for _, row in stats.iterrows(): + assert linkage.loc[row["source"], row["target"]] == pytest.approx(row["value"], abs=1e-9) + + +def test_single_target_normalize_overwrites_linkage(balanced_tdata): + """normalize=True replaces _linkage in obs with score - category_permuted_mean.""" + tdata = balanced_tdata + # Raw linkage without normalize + tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation", + normalize=False, n_permutations=30, random_state=1) + raw = tdata.obs["B_linkage"].copy() + + # Normalized linkage + tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation", + normalize=True, n_permutations=30, random_state=1) + norm = tdata.obs["B_linkage"].copy() + test_df = tdata.uns["celltype_test"] + + # Normalized values should differ from raw (permuted_mean is non-trivial) + # And for each cell: norm = raw - cat_permuted_mean + for cell in ["a1", "a2", "b1", "b2"]: + cat = tdata.obs.loc[cell, "celltype"] + perm_val = test_df.loc[test_df["source"] == cat, "permuted_value"].iloc[0] + assert norm[cell] == pytest.approx(raw[cell] - perm_val, abs=1e-9) + + assert "B_norm_linkage" not in tdata.obs.columns + + +def test_single_target_no_normalize_keeps_raw(balanced_tdata): + """normalize=False (default) does not overwrite _linkage in obs.""" + tdata = balanced_tdata + tl.ancestral_linkage(tdata, groupby="celltype", target="B", test="permutation", + normalize=False, n_permutations=20, random_state=2) + # With normalize=False, a1 (category A, not in B) has raw min-path score 2.0 + assert tdata.obs.loc["a1", "B_linkage"] == pytest.approx(2.0) + + +def test_single_target_by_tree_permutation(three_cat_tdata): + """by_tree=True with single target + permutation runs per-tree (stats has tree column).""" + tdata = three_cat_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", target="A", test="permutation", by_tree=True, + n_permutations=20, random_state=7, + ) + assert "A_linkage" in tdata.obs.columns + assert "A_norm_linkage" not in tdata.obs.columns + test_df = tdata.uns["celltype_test"] + assert "tree" in test_df.columns + for tree_key in tdata.obst: + assert tree_key in test_df["tree"].values + + +def test_single_target_by_tree_normalize(three_cat_tdata): + """by_tree=True + normalize=True normalizes each cell using its own tree's null distribution.""" + tdata = three_cat_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", target="A", test="permutation", + by_tree=True, normalize=True, n_permutations=20, random_state=7, + ) + assert "A_linkage" in tdata.obs.columns + assert "A_norm_linkage" not in tdata.obs.columns + # _linkage values should be finite for leaves with valid scores + assert tdata.obs["A_linkage"].notna().any() + + +def test_single_target_by_tree_perm_non_target(balanced_tdata): + """by_tree + permutation_mode='non_target' in single-target mode produces tree column in stats.""" + tdata = balanced_tdata + tl.ancestral_linkage( + tdata, groupby="celltype", target="B", test="permutation", + by_tree=True, permutation_mode="non_target", n_permutations=20, random_state=3, + ) + assert "B_norm_linkage" not in tdata.obs.columns + test_df = tdata.uns["celltype_test"] + assert "tree" in test_df.columns