From 3e61731839e5370e4c7aebc1b7222a08e2bb3181 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20M=C3=BCller-Widmann?= Date: Sat, 18 Oct 2025 17:08:47 -0600 Subject: [PATCH] Add unnormalized log pdf functions --- ext/StatsFunsChainRulesCoreExt.jl | 72 ++++++++++++++++++++++++++++++- src/StatsFuns.jl | 32 ++++++++++++++ src/distrs/beta.jl | 10 ++++- src/distrs/binom.jl | 14 +++++- src/distrs/chisq.jl | 10 ++++- src/distrs/fdist.jl | 21 +++++++++ src/distrs/gamma.jl | 26 +++++++++-- src/distrs/hyper.jl | 4 ++ src/distrs/nbeta.jl | 3 ++ src/distrs/nbinom.jl | 3 ++ src/distrs/nchisq.jl | 3 ++ src/distrs/nfdist.jl | 3 ++ src/distrs/norm.jl | 27 ++++++++++++ src/distrs/ntdist.jl | 3 ++ src/distrs/pois.jl | 16 +++++-- src/distrs/signrank.jl | 23 +++++++--- src/distrs/tdist.jl | 14 ++++++ src/distrs/wilcox.jl | 21 ++++++--- test/chainrules.jl | 30 ++++++++++--- test/runtests.jl | 2 +- test/unnormalized.jl | 66 ++++++++++++++++++++++++++++ 21 files changed, 374 insertions(+), 29 deletions(-) create mode 100644 test/unnormalized.jl diff --git a/ext/StatsFunsChainRulesCoreExt.jl b/ext/StatsFunsChainRulesCoreExt.jl index e94e549..fcb9989 100644 --- a/ext/StatsFunsChainRulesCoreExt.jl +++ b/ext/StatsFunsChainRulesCoreExt.jl @@ -13,6 +13,14 @@ ChainRulesCore.@scalar_rule( (α - 1) / x + (1 - β) / (1 - x), ), ) +ChainRulesCore.@scalar_rule( + betalogupdf(α::Real, β::Real, x::Number), + ( + log(x), + log1p(-x), + (α - 1) / x + (1 - β) / (1 - x), + ), +) ChainRulesCore.@scalar_rule( binomlogpdf(n::Real, p::Real, k::Real), @@ -22,12 +30,28 @@ ChainRulesCore.@scalar_rule( ChainRulesCore.NoTangent(), ), ) +ChainRulesCore.@scalar_rule( + binomlogupdf(n::Real, p::Real, k::Real), + ( + ChainRulesCore.NoTangent(), + (k / p - n) / (1 - p), + ChainRulesCore.NoTangent(), + ), +) ChainRulesCore.@scalar_rule( chisqlogpdf(k::Real, x::Number), @setup(hk = k / 2), ( - (log(x) - logtwo - digamma(hk)) / 2, + (log(x / 2) - digamma(hk)) / 2, + (hk - 1) / x - one(hk) / 2, + ), +) 
+ChainRulesCore.@scalar_rule( + chisqlogupdf(k::Real, x::Number), + @setup(hk = k / 2), + ( + log(x) / 2, (hk - 1) / x - one(hk) / 2, ), ) @@ -47,6 +71,19 @@ ChainRulesCore.@scalar_rule( ((ν1 - 2) / x - ν1 * νsum / temp1) / 2, ), ) +ChainRulesCore.@scalar_rule( + fdistlogupdf(ν1::Real, ν2::Real, x::Number), + @setup( + tmp = x * ν1 + ν2, + a = x * (ν1 + ν2) / tmp, + b = (ν1 + ν2) / tmp, + ), + ( + (log(x / tmp) - a) / 2, + -(log(tmp) + b) / 2, + (ν1 * (1 - a) - 2) / (2 * x), + ), +) ChainRulesCore.@scalar_rule( gammalogpdf(k::Real, θ::Real, x::Number), @@ -61,11 +98,32 @@ ChainRulesCore.@scalar_rule( - (1 + z) / x, ), ) +ChainRulesCore.@scalar_rule( + gammalogupdf(k::Real, θ::Real, x::Number), + @setup( + invθ = inv(θ), + xoθ = invθ * x, + z = xoθ - (k - 1), + ), + ( + log(x), + invθ * xoθ, + - z / x, + ), +) ChainRulesCore.@scalar_rule( poislogpdf(λ::Number, x::Number), ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, ChainRulesCore.NoTangent()), ) +ChainRulesCore.@scalar_rule( + poislogupdf(λ::Number, x::Number), + ((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ), ChainRulesCore.NoTangent()), +) +ChainRulesCore.@scalar_rule( + poislogulikelihood(λ::Number, x::Number), + ((iszero(x) && iszero(λ) ? 
zero(x / λ) : x / λ) - 1, ChainRulesCore.NoTangent()), +) ChainRulesCore.@scalar_rule( tdistlogpdf(ν::Real, x::Number), @@ -81,5 +139,17 @@ ChainRulesCore.@scalar_rule( - x * b, ), ) +ChainRulesCore.@scalar_rule( + tdistlogupdf(ν::Real, x::Number), + @setup( + xsq = x^2, + a = xsq / ν, + b = (ν + 1) / (ν + xsq), + ), + ( + (a * b - log1p(a)) / 2, + - x * b, + ), +) end # module diff --git a/src/StatsFuns.jl b/src/StatsFuns.jl index 22fdbc0..7ba8b20 100644 --- a/src/StatsFuns.jl +++ b/src/StatsFuns.jl @@ -55,6 +55,8 @@ export # distrs/beta betapdf, # pdf of beta distribution betalogpdf, # logpdf of beta distribution + betalogupdf, # unnormalized logpdf of beta distribution (parameters constant) + betalogulikelihood, # unnormalized logpdf of beta distribution (data constant) betacdf, # cdf of beta distribution betaccdf, # ccdf of beta distribution betalogcdf, # logcdf of beta distribution @@ -67,6 +69,8 @@ export # distrs/binom binompdf, # pdf of binomial distribution binomlogpdf, # logpdf of binomial distribution + binomlogupdf, # unnormalized logpdf of binomial distribution (parameters constant) + binomlogulikelihood, # unnormalized logpdf of binomial distribution (data constant) binomcdf, # cdf of binomial distribution binomccdf, # ccdf of binomial distribution binomlogcdf, # logcdf of binomial distribution @@ -79,6 +83,8 @@ export # distrs/chisq chisqpdf, # pdf of chi-square distribution chisqlogpdf, # logpdf of chi-square distribution + chisqlogupdf, # unnormalized logpdf of chi-square distribution (parameters constant) + chisqlogulikelihood, # unnormalized logpdf of chi-square distribution (data constant) chisqcdf, # cdf of chi-square distribution chisqccdf, # ccdf of chi-square distribution chisqlogcdf, # logcdf of chi-square distribution @@ -91,6 +97,8 @@ export # distrs/fdist fdistpdf, # pdf of F distribution fdistlogpdf, # logpdf of F distribution + fdistlogupdf, # unnormalized logpdf of F distribution (parameters constant) + fdistlogulikelihood, # 
unnormalized logpdf of F distribution (data constant) fdistcdf, # cdf of F distribution fdistccdf, # ccdf of F distribution fdistlogcdf, # logcdf of F distribution @@ -103,6 +111,8 @@ export # distrs/gamma gammapdf, # pdf of gamma distribution gammalogpdf, # logpdf of gamma distribution + gammalogupdf, # unnormalized logpdf of gamma distribution (parameters constant) + gammalogulikelihood, # unnormalized logpdf of gamma distribution (data constant) gammacdf, # cdf of gamma distribution gammaccdf, # ccdf of gamma distribution gammalogcdf, # logcdf of gamma distribution @@ -115,6 +125,8 @@ export # distrs/hyper hyperpdf, # pdf of hypergeometric distribution hyperlogpdf, # logpdf of hypergeometric distribution + hyperlogupdf, # unnormalized logpdf of hypergeometric distribution (parameters constant) + hyperlogulikelihood, # unnormalized logpdf of hypergeometric distribution (data constant) hypercdf, # cdf of hypergeometric distribution hyperccdf, # ccdf of hypergeometric distribution hyperlogcdf, # logcdf of hypergeometric distribution @@ -127,6 +139,8 @@ export # distrs/nbeta nbetapdf, # pdf of noncentral beta distribution nbetalogpdf, # logpdf of noncentral beta distribution + nbetalogupdf, # unnormalized logpdf of noncentral beta distribution (parameters constant) + nbetalogulikelihood, # unnormalized logpdf of noncentral beta distribution (data constant) nbetacdf, # cdf of noncentral beta distribution nbetaccdf, # ccdf of noncentral beta distribution nbetalogcdf, # logcdf of noncentral beta distribution @@ -139,6 +153,8 @@ export # distrs/nbinom nbinompdf, # pdf of negative nbinomial distribution nbinomlogpdf, # logpdf of negative nbinomial distribution + nbinomlogupdf, # unnormalized logpdf of negative nbinomial distribution (parameters constant) + nbinomlogulikelihood, # unnormalized logpdf of negative nbinomial distribution (data constant) nbinomcdf, # cdf of negative nbinomial distribution nbinomccdf, # ccdf of negative nbinomial distribution nbinomlogcdf, # 
logcdf of negative nbinomial distribution @@ -151,6 +167,8 @@ export # distrs/nchisq nchisqpdf, # pdf of noncentral chi-square distribution nchisqlogpdf, # logpdf of noncentral chi-square distribution + nchisqlogupdf, # unnormalized logpdf of noncentral chi-square distribution (parameters constant) + nchisqlogulikelihood, # unnormalized logpdf of noncentral chi-square distribution (data constant) nchisqcdf, # cdf of noncentral chi-square distribution nchisqccdf, # ccdf of noncentral chi-square distribution nchisqlogcdf, # logcdf of noncentral chi-square distribution @@ -163,6 +181,8 @@ export # distrs/nfdist nfdistpdf, # pdf of noncentral F distribution nfdistlogpdf, # logpdf of noncentral F distribution + nfdistlogupdf, # unnormalized logpdf of noncentral F distribution (parameters constant) + nfdistlogulikelihood, # unnormalized logpdf of noncentral F distribution (data constant) nfdistcdf, # cdf of noncentral F distribution nfdistccdf, # ccdf of noncentral F distribution nfdistlogcdf, # logcdf of noncentral F distribution @@ -175,6 +195,8 @@ export # distrs/norm normpdf, # pdf of normal distribution normlogpdf, # logpdf of normal distribution + normlogupdf, # unnormalized logpdf of normal distribution (parameters constant) + normlogulikelihood, # unnormalized logpdf of normal distribution (data constant) normcdf, # cdf of normal distribution normccdf, # ccdf of normal distribution normlogcdf, # logcdf of normal distribution @@ -187,6 +209,8 @@ export # distrs/ntdist ntdistpdf, # pdf of noncentral t distribution ntdistlogpdf, # logpdf of noncentral t distribution + ntdistlogupdf, # unnormalized logpdf of noncentral t distribution (parameters constant) + ntdistlogulikelihood, # unnormalized logpdf of noncentral t distribution (data constant) ntdistcdf, # cdf of noncentral t distribution ntdistccdf, # ccdf of noncentral t distribution ntdistlogcdf, # logcdf of noncentral t distribution @@ -199,6 +223,8 @@ export # distrs/pois poispdf, # pdf of Poisson distribution 
poislogpdf, # logpdf of Poisson distribution + poislogupdf, # unnormalized logpdf of Poisson distribution (parameters constant) + poislogulikelihood, # unnormalized logpdf of Poisson distribution (data constant) poiscdf, # cdf of Poisson distribution poisccdf, # ccdf of Poisson distribution poislogcdf, # logcdf of Poisson distribution @@ -211,6 +237,8 @@ export # distrs/tdist tdistpdf, # pdf of student's t distribution tdistlogpdf, # logpdf of student's t distribution + tdistlogupdf, # unnormalized logpdf of student's t distribution (parameters constant) + tdistlogulikelihood, # unnormalized logpdf of student's t distribution (data constant) tdistcdf, # cdf of student's t distribution tdistccdf, # ccdf of student's t distribution tdistlogcdf, # logcdf of student's t distribution @@ -223,6 +251,8 @@ export # distrs/signrank signrankpdf, signranklogpdf, + signranklogupdf, + signranklogulikelihood, signrankcdf, signranklogcdf, signrankccdf, @@ -245,6 +275,8 @@ export # distrs/wilcox wilcoxpdf, wilcoxlogpdf, + wilcoxlogupdf, + wilcoxlogulikelihood, wilcoxcdf, wilcoxlogcdf, wilcoxccdf, diff --git a/src/distrs/beta.jl b/src/distrs/beta.jl index c741e6c..10e5b19 100644 --- a/src/distrs/beta.jl +++ b/src/distrs/beta.jl @@ -20,12 +20,20 @@ betapdf(α::Real, β::Real, x::Real) = exp(betalogpdf(α, β, x)) betalogpdf(α::Real, β::Real, x::Real) = betalogpdf(promote(α, β, x)...) function betalogpdf(α::T, β::T, x::T) where {T <: Real} + logupdf = betalogupdf(α, β, x) + return isfinite(logupdf) ? logupdf - logbeta(α, β) : logupdf +end + +betalogupdf(α::Real, β::Real, x::Real) = betalogupdf(promote(α, β, x)...) +function betalogupdf(α::T, β::T, x::T) where {T <: Real} # we ensure that `log(x)` and `log1p(-x)` do not error y = clamp(x, 0, 1) - val = xlogy(α - 1, y) + xlog1py(β - 1, -y) - logbeta(α, β) + val = xlogy(α - 1, y) + xlog1py(β - 1, -y) return x < 0 || x > 1 ? 
oftype(val, -Inf) : val end +betalogulikelihood(α::Real, β::Real, x::Real) = betalogpdf(α, β, x) + function betacdf(α::Real, β::Real, x::Real) # Handle degenerate cases if iszero(α) && β > 0 diff --git a/src/distrs/binom.jl b/src/distrs/binom.jl index a2836f2..acbfefe 100644 --- a/src/distrs/binom.jl +++ b/src/distrs/binom.jl @@ -18,11 +18,23 @@ binompdf(n::Real, p::Real, k::Real) = exp(binomlogpdf(n, p, k)) binomlogpdf(n::Real, p::Real, k::Real) = binomlogpdf(promote(n, p, k)...) function binomlogpdf(n::T, p::T, k::T) where {T <: Real} + logupdf = binomlogupdf(n, p, k) + if isfinite(logupdf) + return min(0, logupdf - log(n + 1)) + else + return logupdf + end +end + +binomlogupdf(n::Real, p::Real, k::Real) = binomlogupdf(promote(n, p, k)...) +function binomlogupdf(n::T, p::T, k::T) where {T <: Real} m = clamp(k, 0, n) - val = min(0, betalogpdf(m + 1, n - m + 1, p) - log(n + 1)) + val = betalogpdf(m + 1, n - m + 1, p) return 0 <= k <= n && isinteger(k) ? val : oftype(val, -Inf) end +binomlogulikelihood(n::Real, p::Real, k::Real) = binomlogpdf(n, p, k) + for l in ("", "log"), compl in (false, true) fbinom = Symbol(string("binom", l, ifelse(compl, "c", ""), "cdf")) fbeta = Symbol(string("beta", l, ifelse(compl, "", "c"), "cdf")) diff --git a/src/distrs/chisq.jl b/src/distrs/chisq.jl index 2204e94..577a88d 100644 --- a/src/distrs/chisq.jl +++ b/src/distrs/chisq.jl @@ -1,7 +1,7 @@ # functions related to chi-square distribution # Just use the Gamma definitions -for f in ("pdf", "logpdf", "cdf", "ccdf", "logcdf", "logccdf", "invcdf", "invccdf", "invlogcdf", "invlogccdf") +for f in ("pdf", "logpdf", "logupdf", "cdf", "ccdf", "logcdf", "logccdf", "invcdf", "invccdf", "invlogcdf", "invlogccdf") _chisqf = Symbol("chisq" * f) _gammaf = Symbol("gamma" * f) @eval begin @@ -9,3 +9,11 @@ for f in ("pdf", "logpdf", "cdf", "ccdf", "logcdf", "logccdf", "invcdf", "invccd $(_chisqf)(k::T, x::T) where {T <: Real} = $(_gammaf)(k / 2, 2, x) end end + +chisqlogulikelihood(k::Real, x::Real) 
= chisqlogulikelihood(promote(k, x)...) +function chisqlogulikelihood(k::T, x::T) where {T <: Real} + y = max(x, 0) + k2 = k / 2 + val = xlogy(k2, y / 2) - loggamma(k2) + return x < 0 ? oftype(val, -Inf) : val +end diff --git a/src/distrs/fdist.jl b/src/distrs/fdist.jl index 856d40f..fad6023 100644 --- a/src/distrs/fdist.jl +++ b/src/distrs/fdist.jl @@ -12,6 +12,27 @@ function fdistlogpdf(ν1::T, ν2::T, x::T) where {T <: Real} return x < 0 ? oftype(val, -Inf) : val end +fdistlogupdf(ν1::Real, ν2::Real, x::Real) = fdistlogupdf(promote(ν1, ν2, x)...) +function fdistlogupdf(ν1::T, ν2::T, x::T) where {T <: Real} + # we ensure that `log(x)` does not error if `x < 0` + y = max(x, 0) + val = (xlogy(ν1 - 2, y) - xlogy(ν1 + ν2, ν1 * y + ν2)) / 2 + return x < 0 ? oftype(val, -Inf) : val +end + +fdistlogulikelihood(ν1::Real, ν2::Real, x::Real) = fdistlogulikelihood(promote(ν1, ν2, x)...) +function fdistlogulikelihood(ν1::T, ν2::T, x::T) where {T <: Real} + # we ensure that `log(x)` does not error if `x < 0` + y = max(x, 0) + tmp = ν1 * y + ν2 + a = ν1 / tmp + b = ν2 / tmp + halfν1 = ν1 / 2 + halfν2 = ν2 / 2 + val = (xlogy(halfν1, a) + xlogy(halfν2, b)) - logbeta(halfν1, halfν2) + return x < 0 ? oftype(val, -Inf) : val +end + for f in ("cdf", "ccdf", "logcdf", "logccdf") ff = Symbol("fdist" * f) bf = Symbol("beta" * f) diff --git a/src/distrs/gamma.jl b/src/distrs/gamma.jl index ee888ac..bd5ffa9 100644 --- a/src/distrs/gamma.jl +++ b/src/distrs/gamma.jl @@ -20,13 +20,31 @@ gammapdf(k::Real, θ::Real, x::Real) = exp(gammalogpdf(k, θ, x)) gammalogpdf(k::Real, θ::Real, x::Real) = gammalogpdf(promote(k, θ, x)...) function gammalogpdf(k::T, θ::T, x::T) where {T <: Real} + logupdf = gammalogupdf(k, θ, x) + return isfinite(logupdf) ? logupdf - loggamma(k) - k * log(θ) : logupdf +end + +gammalogupdf(k::Real, θ::Real, x::Real) = gammalogupdf(promote(k, θ, x)...) 
+function gammalogupdf(k::T, θ::T, x::T) where {T <: Real} # we ensure that `log(x)` does not error if `x < 0` - xθ = max(x, 0) / θ - val = -loggamma(k) - log(θ) - xθ + y = max(x, 0) + val = -float(y / θ) # xlogy(k - 1, xθ) - xθ -> -∞ for xθ -> ∞ so we only add the first term # when it's safe - if isfinite(xθ) - val += xlogy(k - 1, xθ) + if isfinite(val) + val += xlogy(k - 1, y) + end + return x < 0 ? oftype(val, -Inf) : val +end + +function gammalogulikelihood(k::Real, θ::Real, x::Real) + # we ensure that `log(x)` does not error if `x < 0` + xθ = max(x, 0) / θ + val = - xθ - loggamma(k) + # xlogy(k, xθ) - xθ -> -∞ for xθ -> ∞ so we only add the first term + # when it's safe + if isfinite(val) + val += xlogy(k, xθ) end return x < 0 ? oftype(val, -Inf) : val end diff --git a/src/distrs/hyper.jl b/src/distrs/hyper.jl index 6c2869f..f3d8545 100644 --- a/src/distrs/hyper.jl +++ b/src/distrs/hyper.jl @@ -12,3 +12,7 @@ using .RFunctions: hyperinvccdf, hyperinvlogcdf, hyperinvlogccdf + + +hyperlogupdf(ms::Real, mf::Real, n::Real, x::Real) = hyperlogpdf(ms, mf, n, x) +hyperlogulikelihood(ms::Real, mf::Real, n::Real, x::Real) = hyperlogpdf(ms, mf, n, x) diff --git a/src/distrs/nbeta.jl b/src/distrs/nbeta.jl index 33c59cf..c298cbc 100644 --- a/src/distrs/nbeta.jl +++ b/src/distrs/nbeta.jl @@ -12,3 +12,6 @@ using .RFunctions: nbetainvccdf, nbetainvlogcdf, nbetainvlogccdf + +nbetalogupdf(α::Real, β::Real, λ::Real, x::Real) = nbetalogpdf(α, β, λ, x) +nbetalogulikelihood(α::Real, β::Real, λ::Real, x::Real) = nbetalogpdf(α, β, λ, x) diff --git a/src/distrs/nbinom.jl b/src/distrs/nbinom.jl index 5f80b1e..cf45dc2 100644 --- a/src/distrs/nbinom.jl +++ b/src/distrs/nbinom.jl @@ -12,3 +12,6 @@ using .RFunctions: nbinominvccdf, nbinominvlogcdf, nbinominvlogccdf + +nbinomlogupdf(r::Real, p::Real, x::Real) = nbinomlogpdf(r, p, x) +nbinomlogulikelihood(r::Real, p::Real, x::Real) = nbinomlogpdf(r, p, x) diff --git a/src/distrs/nchisq.jl b/src/distrs/nchisq.jl index 9382d04..c871c55 100644 
--- a/src/distrs/nchisq.jl +++ b/src/distrs/nchisq.jl @@ -12,3 +12,6 @@ using .RFunctions: nchisqinvccdf, nchisqinvlogcdf, nchisqinvlogccdf + +nchisqlogupdf(k::Real, λ::Real, x::Real) = nchisqlogpdf(k, λ, x) +nchisqlogulikelihood(k::Real, λ::Real, x::Real) = nchisqlogpdf(k, λ, x) diff --git a/src/distrs/nfdist.jl b/src/distrs/nfdist.jl index 3f2e735..f7392e8 100644 --- a/src/distrs/nfdist.jl +++ b/src/distrs/nfdist.jl @@ -12,3 +12,6 @@ using .RFunctions: nfdistinvccdf, nfdistinvlogcdf, nfdistinvlogccdf + +nfdistlogupdf(k1::Real, k2::Real, λ::Real, x::Real) = nfdistlogpdf(k1, k2, λ, x) +nfdistlogulikelihood(k1::Real, k2::Real, λ::Real, x::Real) = nfdistlogpdf(k1, k2, λ, x) diff --git a/src/distrs/norm.jl b/src/distrs/norm.jl index e79713c..65564a5 100644 --- a/src/distrs/norm.jl +++ b/src/distrs/norm.jl @@ -41,6 +41,33 @@ function normlogpdf(μ::Real, σ::Real, x::Number) return normlogpdf(z) - log(σ) end +# logupdf +normlogupdf(z::Number) = -abs2(z) / 2 +function normlogupdf(μ::Real, σ::Real, x::Number) + if iszero(σ) && x == μ + z = zval(μ, one(σ), x) + else + z = zval(μ, σ, x) + end + return normlogupdf(z) +end + +# logulikelihood +normlogulikelihood(z::Number) = normlogupdf(z) +function normlogulikelihood(μ::Real, σ::Real, x::Number) + if iszero(σ) + if x == μ + z = zval(μ, one(σ), x) + else + z = zval(μ, σ, x) + σ = one(σ) + end + else + z = zval(μ, σ, x) + end + return normlogulikelihood(z) - log(σ) +end + # cdf normcdf(z::Number) = erfc(-z * invsqrt2) / 2 function normcdf(μ::Real, σ::Real, x::Number) diff --git a/src/distrs/ntdist.jl b/src/distrs/ntdist.jl index b1bbb1b..fa7e453 100644 --- a/src/distrs/ntdist.jl +++ b/src/distrs/ntdist.jl @@ -12,3 +12,6 @@ using .RFunctions: ntdistinvccdf, ntdistinvlogcdf, ntdistinvlogccdf + +ntdistlogupdf(k::Real, λ::Real, x::Real) = ntdistlogpdf(k, λ, x) +ntdistlogulikelihood(k::Real, λ::Real, x::Real) = ntdistlogpdf(k, λ, x) diff --git a/src/distrs/pois.jl b/src/distrs/pois.jl index 0928fae..4153f8b 100644 --- 
a/src/distrs/pois.jl +++ b/src/distrs/pois.jl @@ -16,9 +16,19 @@ using .RFunctions: # Julia implementations poispdf(λ::Real, x::Real) = exp(poislogpdf(λ, x)) -poislogpdf(λ::Real, x::Real) = poislogpdf(promote(λ, x)...) -function poislogpdf(λ::T, x::T) where {T <: Real} - val = xlogy(x, λ) - λ - loggamma(x + 1) +function poislogpdf(λ::Real, x::Real) + logupdf = poislogupdf(λ, x) + return isfinite(logupdf) ? logupdf - λ : logupdf +end + +poislogupdf(λ::Real, x::Real) = poislogupdf(promote(λ, x)...) +function poislogupdf(λ::T, x::T) where {T <: Real} + val = xlogy(x, λ) - loggamma(x + 1) + return x >= 0 && isinteger(x) ? val : oftype(val, -Inf) +end + +function poislogulikelihood(λ::Real, x::Real) + val = xlogy(x, λ) - λ return x >= 0 && isinteger(x) ? val : oftype(val, -Inf) end diff --git a/src/distrs/signrank.jl b/src/distrs/signrank.jl index 96936db..64b3224 100644 --- a/src/distrs/signrank.jl +++ b/src/distrs/signrank.jl @@ -25,25 +25,36 @@ the number of ways {1,2,...,j} can sum to W-i+1. return DP end -function signrankpdf(n::Int, W::Float64) - return isinteger(W) ? signrankpdf(n, Int(W)) : 0.0 +function signrankpdf(n::Int, W::Union{Int, Float64}) + numsets = signrank_numsets(n, W) + return iszero(numsets) ? 0.0 : ldexp(float(numsets), -n) end -function signrankpdf(n::Int, W::Int) + +function signrank_numsets(n::Int, W::Float64) + return isinteger(W) ? 
signrank_numsets(n, Int(W)) : 0 +end +function signrank_numsets(n::Int, W::Int) if W < 0 - return 0.0 + return 0 end max_W = (n * (n + 1)) >> 1 W2 = max_W - W if W2 < W - return signrankpdf(n, W2) + return signrank_numsets(n, W2) end DP = signrankDP(n, W) - return ldexp(float(DP[1]), -n) + return DP[1] end function signranklogpdf(n::Int, W::Union{Float64, Int}) return log(signrankpdf(n, W)) end +function signranklogupdf(n::Int, W::Union{Float64, Int}) + return log(signrank_numsets(n, W)) +end +function signranklogulikelihood(n::Int, W::Union{Float64, Int}) + return signranklogpdf(n, W) +end function signrankcdf(n::Int, W::Float64) return signrankcdf(n, round(Int, W, RoundNearestTiesUp)) diff --git a/src/distrs/tdist.jl b/src/distrs/tdist.jl index 8b1fef9..9f2e180 100644 --- a/src/distrs/tdist.jl +++ b/src/distrs/tdist.jl @@ -9,6 +9,20 @@ function tdistlogpdf(ν::T, x::T) where {T <: Real} return loggamma(νp12) - (logπ + log(ν)) / 2 - loggamma(ν / 2) - νp12 * log1p(x^2 / ν) end +tdistlogupdf(ν::Real, x::Real) = tdistlogupdf(promote(ν, x)...) +function tdistlogupdf(ν::T, x::T) where {T <: Real} + isinf(ν) && return normlogupdf(x) + νp12 = (ν + 1) / 2 + return - νp12 * log1p(x^2 / ν) +end + +tdistlogulikelihood(ν::Real, x::Real) = tdistlogulikelihood(promote(ν, x)...) +function tdistlogulikelihood(ν::T, x::T) where {T <: Real} + isinf(ν) && return normlogulikelihood(x) + νp12 = (ν + 1) / 2 + return loggamma(νp12) - log(ν) / 2 - loggamma(ν / 2) - νp12 * log1p(x^2 / ν) +end + function tdistcdf(ν::T, x::T) where {T <: Real} if isinf(ν) return normcdf(x) diff --git a/src/distrs/wilcox.jl b/src/distrs/wilcox.jl index 64511cb..d54d3cc 100644 --- a/src/distrs/wilcox.jl +++ b/src/distrs/wilcox.jl @@ -74,22 +74,33 @@ A. Löffler: "Über eine Partition der nat. Zahlen und ihre Anwendung beim U-Tes return partitions end -function wilcoxpdf(nx::Int, ny::Int, U::Float64) - return isinteger(U) ? 
wilcoxpdf(nx, ny, Int(U)) : 0.0 +function wilcoxpdf(nx::Int, ny::Int, U::Union{Float64, Int}) + numseqs = wilcox_numseqs(nx, ny, U) + return iszero(numseqs) ? 0.0 : numseqs / binomial(nx + ny, nx) end -function wilcoxpdf(nx::Int, ny::Int, U::Int) + +function wilcox_numseqs(nx::Int, ny::Int, U::Float64) + return isinteger(U) ? wilcox_numseqs(nx, ny, Int(U)) : 0 +end +function wilcox_numseqs(nx::Int, ny::Int, U::Int) max_U = nx * ny if !(0 <= U <= max_U) - return 0.0 + return 0 end U = min(U, max_U - U) partitions = wilcox_partitions(nx, ny, U) - return partitions[end] / binomial(nx + ny, nx) + return partitions[end] end function wilcoxlogpdf(nx::Int, ny::Int, U::Union{Float64, Int}) return log(wilcoxpdf(nx, ny, U)) end +function wilcoxlogupdf(nx::Int, ny::Int, U::Union{Float64, Int}) + return log(wilcox_numseqs(nx, ny, U)) +end +function wilcoxlogulikelihood(nx::Int, ny::Int, U::Union{Float64, Int}) + return wilcoxlogpdf(nx, ny, U) +end function wilcoxcdf(nx::Int, ny::Int, U::Float64) return wilcoxcdf(nx, ny, round(Int, U, RoundNearestTiesUp)) diff --git a/test/chainrules.jl b/test/chainrules.jl index d59b5ea..adf618c 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -9,45 +9,63 @@ using Random z = logistic(randn()) test_frule(betalogpdf, x, y, z) test_rrule(betalogpdf, x, y, z) + test_frule(betalogupdf, x, y, z) + test_rrule(betalogupdf, x, y, z) x = exp(randn()) y = exp(randn()) z = exp(randn()) test_frule(gammalogpdf, x, y, z) test_rrule(gammalogpdf, x, y, z) + test_frule(gammalogupdf, x, y, z) + test_rrule(gammalogupdf, x, y, z) x = exp(randn()) y = exp(randn()) test_frule(chisqlogpdf, x, y) test_rrule(chisqlogpdf, x, y) + test_frule(chisqlogupdf, x, y) + test_rrule(chisqlogupdf, x, y) x = exp(randn()) y = exp(randn()) z = exp(randn()) test_frule(fdistlogpdf, x, y, z) test_rrule(fdistlogpdf, x, y, z) + test_frule(fdistlogupdf, x, y, z) + test_rrule(fdistlogupdf, x, y, z) x = exp(randn()) y = randn() test_frule(tdistlogpdf, x, y) test_rrule(tdistlogpdf, 
x, y) + test_frule(tdistlogupdf, x, y) + test_rrule(tdistlogupdf, x, y) x = rand(1:100) y = logistic(randn()) z = rand(1:x) test_frule(binomlogpdf, x, y, z) test_rrule(binomlogpdf, x, y, z) + test_frule(binomlogupdf, x, y, z) + test_rrule(binomlogupdf, x, y, z) x = exp(randn()) y = rand(1:100) test_frule(poislogpdf, x, y) test_rrule(poislogpdf, x, y) + test_frule(poislogupdf, x, y) + test_rrule(poislogupdf, x, y) + test_frule(poislogulikelihood, x, y) + test_rrule(poislogulikelihood, x, y) # test special case λ = 0 - _, pb = rrule(poislogpdf, 0.0, 0) - _, x̄1, _ = pb(1) - @test x̄1 == -1 - _, pb = rrule(poislogpdf, 0.0, 1) - _, x̄1, _ = pb(1) - @test x̄1 == Inf + for f in (poislogpdf, poislogupdf, poislogulikelihood) + _, pb = rrule(f, 0.0, 0) + _, x̄1, _ = pb(1) + @test x̄1 == (f === poislogupdf ? 0 : -1) + _, pb = rrule(f, 0.0, 1) + _, x̄1, _ = pb(1) + @test x̄1 == Inf + end end diff --git a/test/runtests.jl b/test/runtests.jl index 0cd38a1..aeed0cd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -tests = ["rmath", "generic", "misc", "chainrules", "inverse", "tvpack", "qa"] +tests = ["rmath", "generic", "misc", "chainrules", "inverse", "tvpack", "unnormalized", "qa"] for t in tests fp = "$t.jl" diff --git a/test/unnormalized.jl b/test/unnormalized.jl new file mode 100644 index 0000000..2e35adf --- /dev/null +++ b/test/unnormalized.jl @@ -0,0 +1,66 @@ +using StatsFuns +using SpecialFunctions +using Test + +@testset "logupdf + logulikelihood" begin + @testset "optimized" begin + # Beta distribution + α = 0.1 + β = 2.5 + x = 0.7 + @test betalogupdf(α, β, x) ≈ betalogpdf(α, β, x) + logbeta(α, β) + @test betalogulikelihood(α, β, x) == betalogpdf(α, β, x) + + # Binomial distribution + n = 9 + p = 0.4 + k = 6 + @test binomlogupdf(n, p, k) ≈ binomlogpdf(n, p, k) + log(n + 1) + @test binomlogulikelihood(n, p, k) == binomlogpdf(n, p, k) + + # Chi-squared distribution + k = 4.2 + x = 3.1 + @test chisqlogupdf(k, x) ≈ chisqlogpdf(k, x) + k / 2 * log(2) + 
loggamma(k / 2) + @test chisqlogulikelihood(k, x) ≈ chisqlogpdf(k, x) + log(x) + x / 2 + + # F distribution + ν1 = 0.9 + ν2 = 1.5 + x = 2.1 + @test fdistlogupdf(ν1, ν2, x) ≈ fdistlogpdf(ν1, ν2, x) + logbeta(ν1 / 2, ν2 / 2) - (ν1 * log(ν1) + ν2 * log(ν2)) / 2 + @test fdistlogulikelihood(ν1, ν2, x) ≈ fdistlogpdf(ν1, ν2, x) - (ν1 / 2 - 1) * log(x) + + # Gamma distribution + k = 1.4 + θ = 2.3 + x = 1.9 + @test gammalogupdf(k, θ, x) ≈ gammalogpdf(k, θ, x) + loggamma(k) + k * log(θ) + @test gammalogulikelihood(k, θ, x) ≈ gammalogpdf(k, θ, x) + log(x) + end + + @testset "fallback" begin + # Hyper-geometric distribution + ms = 2 + mf = 3 + n = 4 + x = 2 + @test hyperlogupdf(ms, mf, n, x) == hyperlogpdf(ms, mf, n, x) + @test hyperlogulikelihood(ms, mf, n, x) == hyperlogpdf(ms, mf, n, x) + + # Non-central beta distribution + α = 0.8 + β = 2.1 + λ = 1.1 + x = 0.8 + @test nbetalogupdf(α, β, λ, x) == nbetalogpdf(α, β, λ, x) + @test nbetalogulikelihood(α, β, λ, x) == nbetalogpdf(α, β, λ, x) + + # Negative binomial distribution + r = 3 + p = 0.7 + x = 2 + @test nbinomlogupdf(r, p, x) == nbinomlogpdf(r, p, x) + @test nbinomlogulikelihood(r, p, x) == nbinomlogpdf(r, p, x) + end +end