Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dev = [
"pytest-cov",
"pytest-xdist",
"nox",
"piqtree",
]

[tool.pytest.ini_options]
Expand Down
103 changes: 72 additions & 31 deletions src/phylim/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from cogent3.core.tree import PhyloNode
from cogent3.draw.dendrogram import Dendrogram
from cogent3.evolve import ns_substitution_model, predicate, substitution_model
from cogent3.evolve.models import register_model
from cogent3.evolve.parameter_controller import AlignmentLikelihoodFunction
from cogent3.evolve.predicate import MotifChange

Expand Down Expand Up @@ -243,30 +244,23 @@ class phylim_to_model_result:
this app assume tree object derived from piqtree
"""

excludes = ["length", "mprobs"]

def __init__(self, stationarity: bool = True) -> None:
self.stationarity = stationarity

def main(self, tree: PhyloNode) -> model_result:
params = tree.get_root().params
mprobs = tree.params["mprobs"]

# build predicates, excluding length and mprobs
predicates = [k for k in params.keys() if k not in self.excludes]

params, predicate_names, motif_probs = _parse_params(tree)
# decide model type based on number of rates/predicates
if len(predicates) < 6:
predicates = [predicate.parse(k) for k in predicates]
submodel = substitution_model.TimeReversibleNucleotide(
predicates=predicates
)
if tree.params["model"] == "UNREST":
submodel = _build_unrest_model(predicate_names)
# skip "T/G" since it's reference rate
params = [p for p in params if p.get("par_name") != "T/G"]
else:
submodel = _gn_constructor(predicates)
submodel = _build_reversible_model(predicate_names)

lf = submodel.make_likelihood_function(tree, aligned=True)
lf.set_motif_probs(mprobs)

# set to stationary as default
lf.set_motif_probs(motif_probs)
lf.apply_param_rules(params)
lf.set_param_rule("mprobs", is_constant=self.stationarity)

result = model_result(
Expand All @@ -278,21 +272,68 @@ def main(self, tree: PhyloNode) -> model_result:

return result

def _gn_constructor(
    predicates: "dict[str, float] | list[str]",
) -> ns_substitution_model.NonReversibleNucleotide:
    """Build a non-reversible nucleotide model from named rate terms.

    Parameters
    ----------
    predicates
        either a mapping of ``"X/Y"`` rate names to their values, or a
        plain sequence of ``"X/Y"`` rate names. For a mapping, entries
        whose value equals 1 are left out (presumably the reference
        rate -- confirm against the producer of these values).

    Returns
    -------
    A ``NonReversibleNucleotide`` model with one forward-only
    ``MotifChange`` predicate per retained name, aliased to that name.
    """
    # The caller may hand over bare names (a list) rather than a
    # name -> value mapping; calling ``.items()`` unconditionally would
    # raise AttributeError for the list case.
    if hasattr(predicates, "items"):
        names = [name for name, value in predicates.items() if value != 1]
    else:
        names = list(predicates)

    motif_changes = [
        MotifChange(*name.split("/"), forward_only=True).aliased(name)
        for name in names
    ]

    required = {
        "optimise_motif_probs": False,
        "predicates": motif_changes,
    }
    # ``required`` is merged last so its entries take precedence
    kwargs = {"recode_gaps": True, "model_gaps": False} | required
    return ns_substitution_model.NonReversibleNucleotide(
        **kwargs,
    )

def _parse_params(result: PhyloNode) -> tuple[list[dict], list[str], dict]:
    """Collect the model parameters stored on the non-root nodes of a tree.

    Each node's ``params`` mapping is scanned and, per parameter name, the
    distinct values seen are grouped with the names of the nodes carrying
    them. Only ``float`` and ``dict`` values are recorded; a ``dict`` value
    is frozen into a key-sorted tuple of ``(key, value)`` pairs so that it
    can serve as a grouping key.

    Parameters
    ----------
    result
        the tree to inspect; the root node itself is not visited

    Returns
    -------
    param_rules
        one ``{"par_name": name, "init": value}`` dict per parameter other
        than "mprobs"
    predicate_names
        the parameter names other than "length" and "mprobs"
    motif_probs
        the single motif probability dict found on the nodes

    Raises
    ------
    NotImplementedError
        if the nodes do not carry exactly one set of motif probabilities,
        or if any other parameter does not have exactly one recorded value
    """
    param_vals: dict = {}
    for node in result.preorder(include_self=False):
        for par_name, raw in node.params.items():
            # register the name even if its value is of an unsupported
            # type, so that it is reported below instead of vanishing
            seen = param_vals.setdefault(par_name, {})
            if isinstance(raw, float):
                key = raw
            elif isinstance(raw, dict):
                key = tuple((k, raw[k]) for k in sorted(raw))
            else:
                continue
            seen.setdefault(key, []).append(node.name)

    predicate_names = [n for n in param_vals if n not in ("length", "mprobs")]

    mprobs_sets = param_vals.pop("mprobs", {})
    if not mprobs_sets:
        raise NotImplementedError("No motif probabilities found on tree nodes")
    if len(mprobs_sets) > 1:
        raise NotImplementedError("More than one set of motif probabilities")
    motif_probs = dict(next(iter(mprobs_sets)))

    # NOTE(review): "length" is excluded from predicate_names but not from
    # the rules built here, so nodes storing differing "length" values in
    # ``params`` will trigger the error below -- confirm this is intended.
    param_rules = []
    for par_name, values in param_vals.items():
        if len(values) != 1:
            raise NotImplementedError(
                f"Parameter {par_name!r} has {len(values)} distinct values, "
                "expected exactly one."
            )
        param_rules.append({"par_name": par_name, "init": next(iter(values))})

    return param_rules, predicate_names, motif_probs


# NOTE(review): this builder needs a positional argument yet is registered
# under the "nucleotide" type -- confirm that registration is intended.
@register_model("nucleotide")
def _build_reversible_model(
    predicate_names: list[str],
) -> substitution_model.TimeReversibleNucleotide:
    """Build a time-reversible nucleotide model from predicate names.

    Parameters
    ----------
    predicate_names
        textual predicates, each handed to ``predicate.parse``

    Returns
    -------
    A ``TimeReversibleNucleotide`` model using the parsed predicates.
    """
    parsed = list(map(predicate.parse, predicate_names))
    model = substitution_model.TimeReversibleNucleotide(predicates=parsed)
    return model


@register_model("nucleotide")
def _build_unrest_model(
    predicate_names: list[str],
) -> ns_substitution_model.NonReversibleNucleotide:
    """Build a non-reversible nucleotide model from ``"X/Y"`` rate names.

    Parameters
    ----------
    predicate_names
        candidate parameter names; only those containing "/" are used,
        and the reference rate "T/G" is left out

    Returns
    -------
    A ``NonReversibleNucleotide`` model with one forward-only
    ``MotifChange`` predicate per retained name, aliased to that name.

    Raises
    ------
    ValueError
        if no usable rate name remains
    """
    reference_rate = "T/G"
    # The reference rate is matched by exact name, the same way the caller
    # drops its parameter rule, so the predicates and the rules stay in step.
    predicates = [
        MotifChange(*name.split("/"), forward_only=True).aliased(name)
        for name in predicate_names
        if "/" in name and name != reference_rate
    ]
    if not predicates:
        raise ValueError("No valid predicates found for UNREST model")

    return ns_substitution_model.NonReversibleNucleotide(
        optimise_motif_probs=False,
        predicates=predicates,
        recode_gaps=True,
        model_gaps=False,
    )


@define_app
Expand Down
9 changes: 6 additions & 3 deletions tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,18 @@ def test_check_fit_boundary():
assert isinstance(res, BoundsViolation)


@pytest.mark.parametrize("tree_name", ["hky_tree", "gtr_tree", "unrest_tree"])
def test_convert_piq_build_treeto_model_result(tree_name):
tree = deserialise_object(f"{DATADIR}/piqtree/{tree_name}.json")
@pytest.mark.xfail(reason="need new version of piqtree")
@pytest.mark.parametrize("model_name", ["HKY", "GTR", "UNREST"])
def test_convert_piq_build_treeto_model_result(model_name):
fit = get_app("piq_fit_tree", tree=_algn.quick_tree(), model=model_name)
tree = fit(_algn)
converter = phylim_to_model_result()
res = converter(tree)
res.lf.set_alignment(_algn)
assert allclose(res.lf.lnL, tree.params["lnL"])


@pytest.mark.xfail(reason="need new version of piqtree")
@pytest.mark.skipif(
sys.platform.startswith("win"), reason="Test not supported on Windows"
)
Expand Down