diff --git a/pyproject.toml b/pyproject.toml index b2ccba7..d909ba5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dev = [ "pytest-cov", "pytest-xdist", "nox", + "piqtree", ] [tool.pytest.ini_options] diff --git a/src/phylim/apps.py b/src/phylim/apps.py index f777458..38dd3bc 100644 --- a/src/phylim/apps.py +++ b/src/phylim/apps.py @@ -9,6 +9,7 @@ from cogent3.core.tree import PhyloNode from cogent3.draw.dendrogram import Dendrogram from cogent3.evolve import ns_substitution_model, predicate, substitution_model +from cogent3.evolve.models import register_model from cogent3.evolve.parameter_controller import AlignmentLikelihoodFunction from cogent3.evolve.predicate import MotifChange @@ -243,30 +244,23 @@ class phylim_to_model_result: this app assume tree object derived from piqtree """ - excludes = ["length", "mprobs"] - def __init__(self, stationarity: bool = True) -> None: self.stationarity = stationarity def main(self, tree: PhyloNode) -> model_result: - params = tree.get_root().params - mprobs = tree.params["mprobs"] - - # build predicates, excluding length and mprobs - predicates = [k for k in params.keys() if k not in self.excludes] - + params, predicate_names, motif_probs = _parse_params(tree) # decide model type based on number of rates/predicates - if len(predicates) < 6: - predicates = [predicate.parse(k) for k in predicates] - submodel = substitution_model.TimeReversibleNucleotide( - predicates=predicates - ) + if tree.params["model"] == "UNREST": + submodel = _build_unrest_model(predicate_names) + # skip "T/G" since it's reference rate + params = [p for p in params if p.get("par_name") != "T/G"] else: - submodel = _gn_constructor(predicates) + submodel = _build_reversible_model(predicate_names) + lf = submodel.make_likelihood_function(tree, aligned=True) - lf.set_motif_probs(mprobs) - # set to stationary as default + lf.set_motif_probs(motif_probs) + lf.apply_param_rules(params) lf.set_param_rule("mprobs", is_constant=self.stationarity) result = model_result( @@ -278,21 +272,68 @@ def main(self, tree: PhyloNode) -> model_result: return result -def _gn_constructor( - predicates: list - ) -> ns_substitution_model.NonReversibleNucleotide: - predicates = [MotifChange(*n.split("/"), forward_only=True).aliased(n) - for n, value in predicates.items() if value != 1] - - - required = { - "optimise_motif_probs": False, - "predicates": predicates, - } - kwargs = {"recode_gaps": True, "model_gaps": False} | required - return ns_substitution_model.NonReversibleNucleotide( - **kwargs, - ) + +def _parse_params(result: PhyloNode) -> tuple[list[dict], list[str], dict]: + param_vals = {} + for node in result.preorder(include_self=False): + for k, v in node.params.items(): + value = param_vals.get(k, {}) + param_vals[k] = value + if isinstance(v, float): + nodes = value.get(v, []) + nodes.append(node.name) + value[v] = nodes + elif isinstance(v, dict): + val = tuple((a, v[a]) for a in sorted(v)) + nodes = value.get(val, []) + nodes.append(node.name) + value[val] = nodes + + predicate_names = [n for n in param_vals if n not in ("length", "mprobs")] + if len(param_vals.get("mprobs", {})) != 1: + raise NotImplementedError("More than one set of motif probabilities") + + motif_probs, _ = param_vals.pop("mprobs", {}).popitem() + motif_probs = dict(motif_probs) + + param_rules = [] + for par_name, values in param_vals.items(): + if len(values) != 1: + raise NotImplementedError(f"Parameter: {par_name!r} has a multiple values.") + value, _ = values.popitem() + param_rules.append({"par_name": par_name, "init": value}) + + return param_rules, predicate_names, motif_probs + + +@register_model("nucleotide") +def _build_reversible_model( + predicate_names: list[str], +) -> substitution_model.TimeReversibleNucleotide: + predicates = [predicate.parse(name) for name in predicate_names] + return substitution_model.TimeReversibleNucleotide(predicates=predicates) + + +@register_model("nucleotide") +def _build_unrest_model( + predicate_names: list, +) -> ns_substitution_model.NonReversibleNucleotide: + predicates = [ + MotifChange(*name.split("/"), forward_only=True).aliased(name) + for name in predicate_names + if "/" in name and "T/G" not in name # skip reference rate + ] + if not predicates: + raise ValueError("No valid predicates found for UNREST model") + + required = { + "optimise_motif_probs": False, + "predicates": predicates, + } + kwargs = required | {"recode_gaps": True, "model_gaps": False} + return ns_substitution_model.NonReversibleNucleotide( + **kwargs, + ) @define_app diff --git a/tests/test_apps.py b/tests/test_apps.py index c30296a..ca585e4 100644 --- a/tests/test_apps.py +++ b/tests/test_apps.py @@ -138,15 +138,18 @@ def test_check_fit_boundary(): assert isinstance(res, BoundsViolation) -@pytest.mark.parametrize("tree_name", ["hky_tree", "gtr_tree", "unrest_tree"]) -def test_convert_piq_build_treeto_model_result(tree_name): - tree = deserialise_object(f"{DATADIR}/piqtree/{tree_name}.json") +@pytest.mark.xfail(reason="need new version of piqtree") +@pytest.mark.parametrize("model_name", ["HKY", "GTR", "UNREST"]) +def test_convert_piq_build_treeto_model_result(model_name): + fit = get_app("piq_fit_tree", tree=_algn.quick_tree(), model=model_name) + tree = fit(_algn) converter = phylim_to_model_result() res = converter(tree) res.lf.set_alignment(_algn) assert allclose(res.lf.lnL, tree.params["lnL"]) +@pytest.mark.xfail(reason="need new version of piqtree") @pytest.mark.skipif( sys.platform.startswith("win"), reason="Test not supported on Windows" )