Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ def traceback(node, parent_value_index):
del tree.nodes[node]["_pointers"]


def _reconstruct_sum(tree: nx.DiGraph, key: str, index: int | None, fixed_nodes: set | None = None) -> None:
    """Set each internal node's value to the sum of its children's values.

    Nodes are visited in reverse topological order, so every child is
    processed before its parent and sums accumulate from the leaves upward.

    Parameters
    ----------
    tree
        Directed graph whose node values are read and written in place.
    key
        Attribute key forwarded to ``_get_node_value`` / ``_set_node_value``.
    index
        Forwarded unchanged to ``_get_node_value`` / ``_set_node_value``
        (presumably selects one component of an array-valued attribute —
        confirm against those helpers).
    fixed_nodes
        Optional set of nodes whose existing non-``None`` value must be kept.

    Returns
    -------
    None. The tree is modified in place.
    """
    bottom_up = list(nx.topological_sort(tree))
    bottom_up.reverse()
    for current in bottom_up:
        current_value = _get_node_value(tree, current, key, index)
        # Nodes without successors keep whatever value they already carry.
        if tree.out_degree(current) == 0:
            continue
        # A fixed node is only preserved when it actually has a value.
        if fixed_nodes is not None and current in fixed_nodes and current_value is not None:
            continue
        present = []
        for child in tree.successors(current):
            child_value = _get_node_value(tree, child, key, index)
            if child_value is not None:
                present.append(child_value)
        # With no usable child values the node is explicitly set to None.
        total = sum(present) if present else None
        _set_node_value(tree, current, key, total, index)


def _reconstruct_mean(tree: nx.DiGraph, key: str, index: int | None, fixed_nodes: set | None = None) -> None:
"""Reconstructs ancestral by averaging the values of the children."""

Expand Down Expand Up @@ -243,6 +255,8 @@ def _ancestral_states(
_reconstruct_fitch_hartigan(tree, key, missing, index, fixed_nodes)
elif method == "mean":
_reconstruct_mean(tree, key, index, fixed_nodes)
elif method == "sum":
_reconstruct_sum(tree, key, index, fixed_nodes)
elif method == "mode":
_reconstruct_list(tree, key, _most_common, index, fixed_nodes)
elif callable(method):
Expand Down Expand Up @@ -308,6 +322,7 @@ def ancestral_states(
Method to reconstruct ancestral states:

* 'mean' : The mean of leaves in subtree.
* 'sum' : The sum of leaves in subtree (iterative bottom-up traversal).
* 'mode' : The most common value in the subtree.
* 'fitch_hartigan' : The Fitch-Hartigan algorithm.
* 'sankoff' : The Sankoff algorithm with specified costs.
Expand Down Expand Up @@ -365,7 +380,7 @@ def ancestral_states(
if method in ["fitch_hartigan", "sankoff"]:
raise ValueError(f"Method {method} requires categorical data.")
if dtypes.intersection({"O", "S"}):
if method in ["mean"]:
if method in ["mean", "sum"]:
raise ValueError(f"Method {method} requires numeric data.")
# Determine fixed internal nodes for nodes/subset alignment
leaves_set = set(get_leaves(t))
Expand Down
31 changes: 31 additions & 0 deletions tests/test_ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,37 @@ def test_ancestral_states_nodes_fitch(nodes_tdata):
assert tree.nodes["C"]["str_value"] == "1" # C value preserved


def test_ancestral_states_sum(tdata):
    """Check method="sum" on the scalar "value" key for both trees."""
    # tree1 layout: root -> B(0), C; C -> D(0), E(3)
    # (index order: B=0, D=0, E=3, F=2)
    # Expected: C = 0 + 3 = 3, root = 0 + 3 = 3.
    states = ancestral_states(tdata, "value", method="sum", copy=True)
    tree1 = tdata.obst["tree1"]
    assert tree1.nodes["C"]["value"] == 3
    assert tree1.nodes["root"]["value"] == 3
    # tree2 layout: root -> F(2), so the root sum is 2.
    tree2 = tdata.obst["tree2"]
    assert tree2.nodes["root"]["value"] == 2
    # copy=True must also hand back the reconstructed states.
    assert states is not None
    root_row = ("tree1", "root")
    assert states["value"].loc[root_row] == 3


def test_ancestral_states_sum_array(tdata):
    """Check method="sum" on the array-valued "spatial" key."""
    # tree1 spatial leaf values: B=[0,4], D=[1,1], E=[2,1]
    # Expected: C = [1+2, 1+1] = [3, 2]; root = [0+3, 4+2] = [3, 6].
    states = ancestral_states(tdata, "spatial", method="sum", copy=True)
    tree1 = tdata.obst["tree1"]
    expected_c = [3, 2]
    expected_root = [3, 6]
    assert tree1.nodes["C"]["spatial"] == expected_c
    assert tree1.nodes["root"]["spatial"] == expected_root
    # copy=True must also hand back the reconstructed states.
    assert states is not None
    assert states.loc[("tree1", "root"), "spatial"] == expected_root


def test_ancestral_states_sum_fixed_nodes(nodes_tdata):
    """Check that a fixed internal node keeps its value under method="sum"."""
    # C=5 is fixed, so it acts like a leaf during reconstruction:
    # root = B(0) + C(5) = 5.
    ancestral_states(nodes_tdata, "value", method="sum", copy=False)
    tree = nodes_tdata.obst["tree"]
    c_value = tree.nodes["C"]["value"]
    assert c_value == 5  # fixed, unchanged
    root_value = tree.nodes["root"]["value"]
    assert root_value == pytest.approx(5.0)


def test_ancestral_states_invalid(tdata):
with pytest.raises(ValueError):
ancestral_states(tdata, "characters", method="sankoff")
Expand Down
Loading