Skip to content

Commit ddac60f

Browse files
sethaxen and devmotion authored
Add compatibility with MCMCDiagnosticTools v0.3 (#401)
* Bump MCMCDiagnosticTools compat * Update imported/exported methods * Remove type constraint on classifier * Overload and export mcse * Overload and update ess and rhat * Update summarystats * Update tests * Increment major version * Rename ess.jl to ess_rhat.jl * Add back ess_per_sec * Fix bug constructing ess_per_sec * Update ess_rhat tests * Test mcse * Update docs * Remove deprecations * Remove unused import * Revert "Fix MLJDecisionTreeInterface to 0.3.0 (#402)" This reverts commit 991f10b. * Always include ess_per_sec in table * Use isequal to pass with missing values * Use isequal for missing * Remove naive_se Fixes #351 * Test Tables interface before loading StatsPlots DataValues (a StatsPlots dependency) pirates a convert method that causes the Tables equality tests with `missing` to fail. See https://github.com/queryverse/DataValues.jl * Apply suggestions from code review Co-authored-by: David Widmann <devmotion@users.noreply.github.com> --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent 991f10b commit ddac60f

File tree

16 files changed

+268
-151
lines changed

16 files changed

+268
-151
lines changed

Project.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "Chain types and utility functions for MCMC simulations."
6-
version = "5.7.1"
6+
version = "6.0.0"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -21,7 +21,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2121
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
2222
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2323
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
24-
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
2524
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2625
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2726
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
@@ -35,7 +34,7 @@ Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
3534
Formatting = "0.4"
3635
IteratorInterfaceExtensions = "0.1.1, 1"
3736
KernelDensity = "0.6.2"
38-
MCMCDiagnosticTools = "0.2"
37+
MCMCDiagnosticTools = "0.3"
3938
MLJModelInterface = "0.3.5, 0.4, 1.0"
4039
NaturalSort = "1"
4140
OrderedCollections = "1.4"

docs/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ CategoricalArrays = "0.8, 0.9, 0.10"
1515
DataFrames = "0.22, 1"
1616
Documenter = "0.26, 0.27"
1717
Gadfly = "1.3.4"
18-
MCMCChains = "5"
18+
MCMCChains = "6"
1919
MLJBase = "0.19, 0.20, 0.21"
20-
MLJDecisionTreeInterface = "=0.3.0"
20+
MLJDecisionTreeInterface = "0.3"
2121
StatsPlots = "0.14, 0.15"
2222
julia = "1.7"

docs/src/diagnostics.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Pages = [
99
"heideldiag.jl",
1010
"rafterydiag.jl",
1111
"rstar.jl",
12-
"ess.jl"
12+
"ess_rhat.jl",
13+
"mcse.jl",
1314
]
1415
```

src/MCMCChains.jl

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import IteratorInterfaceExtensions
2424

2525
import LinearAlgebra
2626
import Random
27-
import Serialization
2827
import Statistics: std, cor, mean, var, mean!
2928

3029
export Chains, chains, chainscat
@@ -36,13 +35,15 @@ export ChainDataFrame
3635
export summarize
3736

3837
# Reexport diagnostics functions
39-
using MCMCDiagnosticTools: discretediag, ess_rhat, ESSMethod, FFTESSMethod, BDAESSMethod,
40-
gelmandiag, gelmandiag_multivariate, gewekediag, heideldiag, rafterydiag, rstar
38+
using MCMCDiagnosticTools: discretediag, ess, ess_rhat, AutocovMethod, FFTAutocovMethod,
39+
BDAAutocovMethod, gelmandiag, gelmandiag_multivariate, gewekediag, heideldiag, mcse,
40+
rafterydiag, rhat, rstar
4141
export discretediag
42-
export ess_rhat, ESSMethod, FFTESSMethod, BDAESSMethod
42+
export ess, ess_rhat, rhat, AutocovMethod, FFTAutocovMethod, BDAAutocovMethod
4343
export gelmandiag, gelmandiag_multivariate
4444
export gewekediag
4545
export heideldiag
46+
export mcse
4647
export rafterydiag
4748
export rstar
4849

@@ -69,13 +70,14 @@ end
6970
include("utils.jl")
7071
include("chains.jl")
7172
include("constructors.jl")
72-
include("ess.jl")
73+
include("ess_rhat.jl")
7374
include("summarize.jl")
7475
include("discretediag.jl")
7576
include("fileio.jl")
7677
include("gelmandiag.jl")
7778
include("gewekediag.jl")
7879
include("heideldiag.jl")
80+
include("mcse.jl")
7981
include("rafterydiag.jl")
8082
include("sampling.jl")
8183
include("stats.jl")
@@ -84,19 +86,4 @@ include("plot.jl")
8486
include("tables.jl")
8587
include("rstar.jl")
8688

87-
# deprecations
88-
# TODO: Remove dependency on Serialization if this deprecation is removed
89-
# somehow `@deprecate` doesn't work with qualified function names,
90-
# so we use the following hack
91-
const _read = Base.read
92-
const _write = Base.write
93-
Base.@deprecate _read(
94-
f::AbstractString,
95-
::Type{T}
96-
) where {T<:Chains} Serialization.deserialize(f) false
97-
Base.@deprecate _write(
98-
f::AbstractString,
99-
c::Chains
100-
) Serialization.serialize(f, c) false
101-
10289
end # module

src/ess.jl

Lines changed: 0 additions & 33 deletions
This file was deleted.

src/ess_rhat.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
ess(chains::Chains; duration=compute_duration, kwargs...)
3+
4+
Estimate the effective sample size.
5+
6+
ESS per second options include `duration=MCMCChains.compute_duration` (the default)
7+
and `duration=MCMCChains.wall_duration`.
8+
"""
9+
function MCMCDiagnosticTools.ess(
10+
chains::Chains;
11+
sections = _default_sections(chains), duration = compute_duration, kwargs...
12+
)
13+
# Subset the chain
14+
_chains = Chains(chains, _clean_sections(chains, sections))
15+
16+
# Estimate the effective sample size
17+
ess = MCMCDiagnosticTools.ess(
18+
_permutedims_diagnostics(_chains.value.data);
19+
kwargs...,
20+
)
21+
22+
# Calculate the sampling duration used to report ESS per second
23+
dur = duration(chains)
24+
25+
# Compute ESS per second and convert to NamedTuple
26+
ess_per_sec = ess ./ dur
27+
nt = merge((parameters = names(_chains),), (; ess, ess_per_sec))
28+
29+
return ChainDataFrame("ESS", nt)
30+
end
31+
32+
"""
33+
rhat(chains::Chains; kwargs...)
34+
35+
Estimate the ``\\widehat{R}`` diagnostic.
36+
"""
37+
function MCMCDiagnosticTools.rhat(
38+
chains::Chains;
39+
sections = _default_sections(chains), kwargs...
40+
)
41+
# Subset the chain
42+
_chains = Chains(chains, _clean_sections(chains, sections))
43+
44+
# Estimate the rhat
45+
rhat = MCMCDiagnosticTools.rhat(
46+
_permutedims_diagnostics(_chains.value.data);
47+
kwargs...,
48+
)
49+
50+
# Convert to NamedTuple
51+
nt = merge((parameters = names(_chains),), (; rhat))
52+
53+
return ChainDataFrame("R-hat", nt)
54+
end
55+
56+
"""
57+
ess_rhat(chains::Chains; duration=compute_duration, kwargs...)
58+
59+
Estimate the effective sample size and the ``\\widehat{R}`` diagnostic
60+
61+
ESS per second options include `duration=MCMCChains.compute_duration` (the default)
62+
and `duration=MCMCChains.wall_duration`.
63+
"""
64+
function MCMCDiagnosticTools.ess_rhat(
65+
chains::Chains;
66+
sections = _default_sections(chains), duration = compute_duration, kwargs...
67+
)
68+
# Subset the chain
69+
_chains = Chains(chains, _clean_sections(chains, sections))
70+
71+
# Estimate the effective sample size and rhat
72+
ess_rhat = MCMCDiagnosticTools.ess_rhat(
73+
_permutedims_diagnostics(_chains.value.data);
74+
kwargs...,
75+
)
76+
77+
# Calculate the sampling duration used to report ESS per second
78+
dur = duration(chains)
79+
80+
# Compute ESS per second and convert to NamedTuple
81+
ess_per_sec = ess_rhat.ess ./ dur
82+
nt = merge((parameters = names(_chains),), ess_rhat, (; ess_per_sec))
83+
84+
return ChainDataFrame("ESS/R-hat", nt)
85+
end

src/mcse.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
mcse(chains::Chains; kwargs...)
3+
4+
Estimate the Monte Carlo standard error.
5+
"""
6+
function MCMCDiagnosticTools.mcse(
7+
chains::Chains;
8+
sections = _default_sections(chains), kwargs...
9+
)
10+
# Subset the chain
11+
_chains = Chains(chains, _clean_sections(chains, sections))
12+
13+
# Estimate the Monte Carlo standard error
14+
mcse = MCMCDiagnosticTools.mcse(
15+
_permutedims_diagnostics(_chains.value.data);
16+
kwargs...,
17+
)
18+
19+
nt = merge((parameters = names(_chains),), (; mcse))
20+
21+
return ChainDataFrame("MCSE", nt)
22+
end

src/rstar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ true
3838
```
3939
"""
4040
function MCMCDiagnosticTools.rstar(
41-
classif::MLJModelInterface.Supervised, chn::Chains; kwargs...
41+
classif, chn::Chains; kwargs...
4242
)
4343
return MCMCDiagnosticTools.rstar(Random.GLOBAL_RNG, classif, chn; kwargs...)
4444
end
4545

4646
function MCMCDiagnosticTools.rstar(
4747
rng::Random.AbstractRNG,
48-
classif::MLJModelInterface.Supervised,
48+
classif,
4949
chn::Chains;
5050
sections = _default_sections(chn),
5151
kwargs...

src/stats.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -270,14 +270,13 @@ end
270270
chains;
271271
sections = _default_sections(chains),
272272
append_chains= true,
273-
method::AbstractESSMethod = ESSMethod(),
273+
autocov_method::AbstractAutocovMethod = AutocovMethod(),
274274
maxlag = 250,
275-
etype = :bm,
276275
kwargs...
277276
)
278277
279-
Compute the mean, standard deviation, naive standard error, Monte Carlo standard error,
280-
and effective sample size for each parameter in the chain.
278+
Compute the mean, standard deviation, Monte Carlo standard error, bulk and tail effective
279+
sample sizes, ``\\widehat{R}`` diagnostic, and ESS per second for each parameter in the chain.
281280
282281
Setting `append_chains=false` will return a vector of dataframes containing the summary
283282
statistics for each chain.
@@ -288,27 +287,42 @@ function summarystats(
288287
chains::Chains;
289288
sections = _default_sections(chains),
290289
append_chains::Bool = true,
291-
method::MCMCDiagnosticTools.AbstractESSMethod = ESSMethod(),
290+
autocov_method::MCMCDiagnosticTools.AbstractAutocovMethod = AutocovMethod(),
292291
maxlag = 250,
293-
etype = :bm,
294292
kwargs...
295293
)
296294
# Store everything.
297-
funs = [meancskip, stdcskip, semcskip, x -> MCMCDiagnosticTools.mcse(cskip(x); method=etype, kwargs...)]
298-
func_names = [:mean, :std, :naive_se, :mcse]
295+
funs = [meancskip, stdcskip]
296+
func_names = [:mean, :std]
299297

300298
# Subset the chain.
301299
_chains = Chains(chains, _clean_sections(chains, sections))
302300

303-
# Calculate ESS separately.
304-
ess_df = MCMCDiagnosticTools.ess_rhat(_chains; sections = nothing, method = method, maxlag = maxlag)
301+
# Calculate MCSE and ESS/R-hat separately.
302+
mcse_df = MCMCDiagnosticTools.mcse(
303+
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag,
304+
)
305+
ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat(
306+
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank
307+
)
308+
ess_tail_df = MCMCDiagnosticTools.ess(
309+
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail
310+
)
311+
nt_additional = (
312+
mcse=mcse_df.nt.mcse,
313+
ess_bulk=ess_rhat_rank_df.nt.ess,
314+
ess_tail=ess_tail_df.nt.ess,
315+
rhat=ess_rhat_rank_df.nt.rhat,
316+
ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec,
317+
)
318+
additional_df = ChainDataFrame("Additional", nt_additional)
305319

306320
# Summarize.
307321
summary_df = summarize(
308322
_chains, funs...;
309323
func_names = func_names,
310324
append_chains = append_chains,
311-
additional_df = ess_df,
325+
additional_df = additional_df,
312326
name = "Summary Statistics",
313327
sections = nothing
314328
)

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ Documenter = "0.26, 0.27"
3030
FFTW = "1.1"
3131
IteratorInterfaceExtensions = "1"
3232
KernelDensity = "0.6.2"
33-
MCMCChains = "5"
33+
MCMCChains = "6"
3434
MLJBase = "0.18, 0.19, 0.20, 0.21"
35-
MLJDecisionTreeInterface = "=0.3.0"
35+
MLJDecisionTreeInterface = "0.3"
3636
StatsBase = "0.33.2"
3737
StatsPlots = "0.14.17, 0.15"
3838
TableTraits = "1"

0 commit comments

Comments
 (0)