From cfcffe38feed0cd38bf8bd3bfdc767d82d97a014 Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Wed, 21 Jan 2026 21:13:27 -0800 Subject: [PATCH 1/3] Disambiguate eps into min_branch_length; change default to 1e-8 --- CHANGELOG.md | 11 +++++++++ tsdate/core.py | 63 ++++++++++++++++++++++++++++++-------------------- 2 files changed, 49 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a92a0631..21938fdb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [0.2.5] - XXXXXXXXXX + +- The `eps` parameter was previously used for two things: the minimum time interval in + likelihood calculations for the discrete-time algorithms, and the minimum allowed branch length + when forcing positive branch lengths. The latter is now a separate parameter called + `min_branch_length` for all algorithms, while the `eps` parameter is only used for the + discrete-time algorithms. + +- The default `min_branch_length` and `eps` have been set to 1e-8 rather than 1e-10, to avoid + occasional issues with floating-point error. + ## [0.2.4] - 2025-09-18 - Add support for Python 3.13, minimum version is now 3.10. 
diff --git a/tsdate/core.py b/tsdate/core.py index b2e1c895..02bdbfdc 100644 --- a/tsdate/core.py +++ b/tsdate/core.py @@ -41,7 +41,8 @@ DEFAULT_RESCALING_INTERVALS = 1000 DEFAULT_RESCALING_ITERATIONS = 5 DEFAULT_MAX_ITERATIONS = 25 -DEFAULT_EPSILON = 1e-10 +DEFAULT_EPSILON = 1e-8 +DEFAULT_MIN_BRANCH_LENGTH = 1e-8 # Classes for each method @@ -87,6 +88,7 @@ def __init__( allow_unary=None, record_provenance=None, constr_iterations=None, + min_branch_length=None, set_metadata=None, progress=None, # deprecated params @@ -150,6 +152,13 @@ def __init__( ) self.constr_iterations = constr_iterations + if min_branch_length is None: + self.min_branch_length = DEFAULT_MIN_BRANCH_LENGTH + else: + if not min_branch_length > 0.0: + raise ValueError("Minimum branch length must be positive") + self.min_branch_length = min_branch_length + self.allow_unary = False if allow_unary is None else allow_unary if self.prior_grid_func_name is None: @@ -188,7 +197,7 @@ def __init__( self.edges_mutations, self.mutations_edge = util.mutation_span_array(ts) self.fixed_nodes = np.array(list(ts.samples())) - def get_modified_ts(self, result, eps): + def get_modified_ts(self, result): # Return a new ts based on the existing one, but with the various # time-related information correctly set. 
ts = self.ts @@ -215,7 +224,9 @@ def get_modified_ts(self, result, eps): # Constrain node ages for positive branch lengths constr_timing = time.time() - nodes.time = util.constrain_ages(ts, node_mean_t, eps, self.constr_iterations) + nodes.time = util.constrain_ages( + ts, node_mean_t, self.min_branch_length, self.constr_iterations + ) constr_timing -= time.time() logger.info(f"Constrained node ages in {abs(constr_timing):.2f} seconds") # Possibly change mutation nodes if phasing singletons @@ -282,10 +293,10 @@ def _time_md_array(table, mean, var): table.metadata_schema = default_schema table.packset_metadata(_time_md_array(table, mean, var)) - def parse_result(self, result, epsilon): + def parse_result(self, result): # Construct the tree sequence to return and add other stuff we might want to # return. pst_cols is a dict to be appended to the output posterior dict - ret = [self.get_modified_ts(result, epsilon)] + ret = [self.get_modified_ts(result)] if self.return_fit: ret.append(result.fit_object) if self.return_likelihood: @@ -451,7 +462,6 @@ def __init__(self, ts, **kwargs): def run( self, - eps, max_iterations, max_shape, rescaling_intervals, @@ -570,9 +580,8 @@ def maximization( the prior parameters for each node-to-be-dated. Note that different estimation methods may require different types of prior, as described in the documentation for each estimation method. - :param float eps: The error factor in time difference calculations, and the - minimum distance separating parent and child ages in the returned tree sequence. - Default: None, treated as 1e-10. + :param float eps: The error factor in time difference calculations. Default: None, + treated as 1e-8. :param int num_threads: The number of threads to use when precalculating likelihoods. A simpler unthreaded algorithm is used unless this is >= 1. 
Default: None :param string probability_space: Should the internal algorithm save @@ -587,7 +596,7 @@ def maximization( - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with updated node times based on the posterior mean, corrected where necessary to ensure that parents are strictly older than all their children by an amount - given by the ``eps`` parameter. + given by the ``min_branch_length`` parameter. - **marginal_likelihood** (:py:class:`float`) -- (Only returned if ``return_likelihood`` is ``True``) The marginal likelihood of the mutation data given the inferred node times. @@ -615,7 +624,7 @@ def maximization( cache_inside=cache_inside, probability_space=probability_space, ) - return dating_method.parse_result(result, eps) + return dating_method.parse_result(result) def inside_outside( @@ -692,9 +701,8 @@ def inside_outside( the prior parameters for each node-to-be-dated. Note that different estimation methods may require different types of prior, as described in the documentation for each estimation method. - :param float eps: The error factor in time difference calculations, and the - minimum distance separating parent and child ages in the returned tree sequence. - Default: None, treated as 1e-10. + :param float eps: The error factor in time difference calculations. Default: None, + treated as 1e-8. :param int num_threads: The number of threads to use when precalculating likelihoods. A simpler unthreaded algorithm is used unless this is >= 1. Default: None :param bool outside_standardize: Should the likelihoods be standardized during the @@ -720,7 +728,7 @@ def inside_outside( - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with updated node times based on the posterior mean, corrected where necessary to ensure that parents are strictly older than all their children by an amount - given by the ``eps`` parameter. + given by the ``min_branch_length`` parameter. 
- **fit** (:class:`~discrete.BeliefPropagation`) -- (Only returned if ``return_fit`` is ``True``) The underlying object used to run the dating inference. This can then be queried e.g. using @@ -757,14 +765,13 @@ def inside_outside( cache_inside=cache_inside, probability_space=probability_space, ) - return dating_method.parse_result(result, eps) + return dating_method.parse_result(result) def variational_gamma( tree_sequence, *, mutation_rate, - eps=None, max_iterations=None, rescaling_intervals=None, rescaling_iterations=None, @@ -773,10 +780,12 @@ def variational_gamma( max_shape=None, regularise_roots=None, singletons_phased=None, + # deprecated parameters + eps=None, **kwargs, ): """ - variational_gamma(tree_sequence, *, mutation_rate, eps=None, max_iterations=None,\ + variational_gamma(tree_sequence, *, mutation_rate, max_iterations=None,\ rescaling_intervals=None, rescaling_iterations=None,\ match_segregating_sites=None, **kwargs) @@ -797,8 +806,6 @@ def variational_gamma( :param ~tskit.TreeSequence tree_sequence: The input tree sequence to be dated. :param float mutation_rate: The estimated mutation rate per unit of genome per unit time. - :param float eps: The minimum distance separating parent and child ages in - the returned tree sequence. Default: None, treated as 1e-10 :param int max_iterations: The number of iterations used in the expectation propagation algorithm. Default: None, treated as 25. :param float rescaling_intervals: For time rescaling, the number of time @@ -820,7 +827,7 @@ def variational_gamma( - **ts** (:class:`~tskit.TreeSequence`) -- a copy of the input tree sequence with updated node times based on the posterior mean, corrected where necessary to ensure that parents are strictly older than all their children by an amount - given by the ``eps`` parameter. + given by the ``min_branch_length`` parameter. - **fit** (:class:`~variational.ExpectationPropagation`) -- (Only returned if ``return_fit`` is ``True``). 
The underlying object used to run the dating inference. This can then be queried e.g. using @@ -830,8 +837,6 @@ def variational_gamma( the mutation data given the inferred node times. Not currently implemented for this method (set to ``None``) """ - if eps is None: - eps = DEFAULT_EPSILON if max_iterations is None: max_iterations = DEFAULT_MAX_ITERATIONS if max_shape is None: @@ -848,6 +853,11 @@ def variational_gamma( regularise_roots = True if singletons_phased is None: singletons_phased = True + if eps is not None: + raise ValueError( + "The `eps` parameter has been disambiguated and is no longer used " + "for the variational gamma algorithm; use `min_branch_length` instead" + ) if tree_sequence.num_mutations == 0: raise ValueError( "No mutations present: these are required for the variational_gamma method" @@ -856,7 +866,6 @@ def variational_gamma( tree_sequence, mutation_rate=mutation_rate, **kwargs ) result = dating_method.run( - eps=eps, max_iterations=max_iterations, max_shape=max_shape, rescaling_intervals=rescaling_intervals, @@ -865,7 +874,7 @@ def variational_gamma( regularise_roots=regularise_roots, singletons_phased=singletons_phased, ) - return dating_method.parse_result(result, eps) + return dating_method.parse_result(result) estimation_methods = { @@ -893,6 +902,7 @@ def date( time_units=None, method=None, constr_iterations=None, + min_branch_length=None, set_metadata=None, return_fit=None, return_likelihood=None, @@ -943,6 +953,8 @@ def date( :param int constr_iterations: The maximum number of constrained least squares iterations to use prior to forcing positive branch lengths. Default: None, treated as 0. + :param float min_branch_length: The minimum distance separating parent and + child ages in the returned tree sequence. Default: None, treated as 1e-8 :param bool set_metadata: Should unconstrained times be stored in table metadata, in the form of ``"mn"`` (mean) and ``"vr"`` (variance) fields? If ``False``, do not store metadata. 
If ``True``, force metadata to be set (if no schema @@ -984,6 +996,7 @@ def date( time_units=time_units, progress=progress, constr_iterations=constr_iterations, + min_branch_length=min_branch_length, return_fit=return_fit, return_likelihood=return_likelihood, allow_unary=allow_unary, From 7fc00c5c0992427c671027ee1c01d15931048b3b Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Sat, 24 Jan 2026 20:11:07 -0800 Subject: [PATCH 2/3] fix cli tests --- tests/test_cli.py | 3 ++- tests/test_functions.py | 4 ++-- tsdate/cli.py | 18 ++++++++++++++---- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index d537dbd3..5da1d6f8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -57,7 +57,8 @@ def test_default_values(self): assert args.population_size is None assert args.mutation_rate is None assert args.recombination_rate is None - assert args.epsilon == 1e-10 + assert args.epsilon == 1e-8 + assert args.min_branch_length == 1e-8 assert args.num_threads is None assert args.probability_space is None # Use the defaults assert args.method == "variational_gamma" diff --git a/tests/test_functions.py b/tests/test_functions.py index 9947df27..d369fd34 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -1546,7 +1546,7 @@ def test_node_metadata_inside_outside(self): ) algorithm = InsideOutsideMethod(ts, mutation_rate=1, population_size=10000) mn_post, *_ = algorithm.run( - eps=1e-10, + eps=1e-8, outside_standardize=True, ignore_oldest_root=False, probability_space=tsdate.node_time_class.LOG_GRID, @@ -1743,7 +1743,7 @@ def test_sites_time_insideoutside(self): dated = tsdate.inside_outside(ts, mutation_rate=1, population_size=1) algorithm = InsideOutsideMethod(ts, mutation_rate=1, population_size=1) mn_post, *_ = algorithm.run( - eps=1e-10, + eps=1e-8, outside_standardize=True, ignore_oldest_root=False, probability_space=LOG_GRID, diff --git a/tsdate/cli.py b/tsdate/cli.py index 6f90fdb5..1e64da35 100644 --- 
a/tsdate/cli.py +++ b/tsdate/cli.py @@ -124,11 +124,20 @@ def tsdate_cli_parser(): type=float, default=core.DEFAULT_EPSILON, help=( - "Specify minimum distance separating time points. Also " - "specifies the error factor in time difference calculations. " + "Specify the error factor in time difference calculations. " f"Default: {core.DEFAULT_EPSILON}" ), ) + parser.add_argument( + "-b", + "--min-branch-length", + type=float, + default=core.DEFAULT_MIN_BRANCH_LENGTH, + help=( + "Specify the minimum difference in age between parent and child nodes. " + f"Default: {core.DEFAULT_MIN_BRANCH_LENGTH}" + ), + ) parser.add_argument( "--method", choices=["inside_outside", "maximization", "variational_gamma"], @@ -169,7 +178,7 @@ def tsdate_cli_parser(): ), default=None, ) - # TODO array specification from file? + # arguments for discrete time methods parser.add_argument( "-n", "--population_size", @@ -283,7 +292,7 @@ def run_date(args): params = dict( recombination_rate=args.recombination_rate, method=args.method, - eps=args.epsilon, + min_branch_length=args.min_branch_length, progress=args.progress, max_iterations=args.max_iterations, rescaling_intervals=args.rescaling_intervals, @@ -299,6 +308,7 @@ def run_date(args): population_size=args.population_size, recombination_rate=args.recombination_rate, method=args.method, + min_branch_length=args.min_branch_length, eps=args.epsilon, progress=args.progress, probability_space=args.probability_space, From 03b80d2a88af6285b28b534c5f8932dd31a0407d Mon Sep 17 00:00:00 2001 From: Nate Pope Date: Sat, 24 Jan 2026 20:27:50 -0800 Subject: [PATCH 3/3] Fix quadpack warnings in tests --- tests/test_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_functions.py b/tests/test_functions.py index d369fd34..2433b28c 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -2064,13 +2064,13 @@ def test_moments_numerically(self): beta = 1.7 demography = PopulationSizeHistory([1000, 2000, 
3000], [500, 2500]) numer_mn, _ = scipy.integrate.quad( - lambda t: demography.to_natural_timescale(np.array([t])) + lambda t: demography.to_natural_timescale(np.array([t])).item() * scipy.stats.gamma.pdf(t, alpha, scale=1 / beta), 0, np.inf, ) numer_va, _ = scipy.integrate.quad( - lambda t: demography.to_natural_timescale(np.array([t])) ** 2 + lambda t: demography.to_natural_timescale(np.array([t])).item() ** 2 * scipy.stats.gamma.pdf(t, alpha, scale=1 / beta), 0, np.inf,