diff --git a/.gitignore b/.gitignore index 1010547..a960a8b 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ docs/src/**/*_files/libs docs/src/**/*.html docs/src/**/*.ipynb docs/src/**/*Manifest.toml +docs/src_stash/*.ipynb diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index 3f1a813..021b936 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -137,7 +137,7 @@ end () -> begin # optimized loss is indeed lower than with true parameters int_ϕθP = ComponentArrayInterpreter(CA.ComponentVector( ϕg = 1:length(prob0.ϕg), θP = prob0.θP)) - loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.transP, prob0.f, Float32[], int_ϕθP) + loss_gf = get_loss_gf(prob0.g, prob0.transM, prob0.transP, prob0.f, Float32[], py, int_ϕθP) loss_gf(vcat(prob3.ϕg, prob3.θP), xM, xP, y_o, y_unc, i_sites)[1] loss_gf(vcat(prob3o.ϕg, prob3o.θP), xM, xP, y_o, y_unc, i_sites)[1] # diff --git a/docs/make.jl b/docs/make.jl index 93bf49c..9ff370c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,6 +24,7 @@ makedocs(; ], "How to" => [ ".. use GPU" => "tutorials/lux_gpu.md", + ".. specify log-Likelihood" => "tutorials/logden_user.md", ".. model independent parameters" => "tutorials/blocks_corr.md", ".. model site-global corr" => "tutorials/corr_site_global.md", ], diff --git a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-11-output-1.png b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-11-output-1.png index deb0d1d..b5f5d20 100644 Binary files a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-11-output-1.png and b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-11-output-1.png differ diff --git a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-12-output-1.png b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-12-output-1.png index 7225faa..80db561 100644 Binary files a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-12-output-1.png and b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-12-output-1.png differ diff --git a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-13-output-1.png b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-13-output-1.png index a5ab7fb..d1f156b 100644 Binary files a/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-13-output-1.png and b/docs/src/tutorials/blocks_corr_files/figure-commonmark/cell-13-output-1.png differ diff --git a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-10-output-1.png b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-10-output-1.png index 9d2324e..8ee22fb 100644 Binary files a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-10-output-1.png and b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-10-output-1.png differ diff --git a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-11-output-1.png b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-11-output-1.png index 123a022..b8d8e64 100644 Binary files a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-11-output-1.png and b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-11-output-1.png differ diff --git a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-12-output-1.png b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-12-output-1.png index 11ba18f..7aa7d4d 100644 Binary files a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-12-output-1.png and 
b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-12-output-1.png differ
diff --git a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-9-output-1.png b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-9-output-1.png
index ab8eb5c..1b863b6 100644
Binary files a/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-9-output-1.png and b/docs/src/tutorials/corr_site_global_files/figure-commonmark/cell-9-output-1.png differ
diff --git a/docs/src/tutorials/logden_user.md b/docs/src/tutorials/logden_user.md
new file mode 100644
index 0000000..240f4a1
--- /dev/null
+++ b/docs/src/tutorials/logden_user.md
@@ -0,0 +1,225 @@
+# How to specify a custom LogLikelihood of the observations
+
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+This guide shows how the user can specify a customized log-density function.
+
+## Motivation
+
+The log-likelihood function assigns a cost to the mismatch between predictions and
+observations. This often needs to be customized to the specific inversion.
+
+This guide walks through the specification of such a function and inspects
+differences between two log-likelihood functions.
+Specifically, it will assume observation errors to be independently distributed
+according to a LogNormal distribution with a specified fixed relative error,
+compared to an inversion assuming independently normally distributed observation errors.
+
+First load necessary packages.
+
+``` julia
+using HybridVariationalInference
+using ComponentArrays: ComponentArrays as CA
+using Bijectors
+using SimpleChains
+using MLUtils
+using JLD2
+using Random
+using CairoMakie
+using PairPlots # scatterplot matrices
+```
+
+This tutorial reuses and modifies the fitted object saved at the end of the
+[Basic workflow without GPU](@ref) tutorial, which used a log-likelihood
+function assuming independently normally distributed observation errors.
+
+``` julia
+fname = "intermediate/basic_cpu_results.jld2"
+print(abspath(fname))
+prob = probo_normal = load(fname, "probo");
+```
+
+## Write the LogLikelihood Function
+
+The function signature corresponds to that of
+[`neg_logden_indep_normal`](@ref):
+
+`neg_log_den_user(y_pred, y_obs, y_unc; kwargs...)`
+
+It takes predictions, `y_pred`, observations, `y_obs`,
+and uncertainty parameters, `y_unc`, and returns the negative logarithm of the
+likelihood up to a constant.
+
+All of the arguments are vectors of the same length specifying predictions and
+observations for one site.
+If `y_pred` and `y_obs` are given as matrices of several column vectors, their
+negative log-likelihoods are summed.
+
+The density of a LogNormal distribution is
+
+$$
+\frac{ 1 }{ x \sqrt{2 \pi \sigma^2} } \exp\left( -\frac{ (\ln x-\mu)^2 }{2 \sigma^2} \right)$$
+
+where x is the observation, μ is the log of the prediction, and σ is the scale
+parameter, which is related to the relative error $c_v$ by $\sigma = \sqrt{\ln(c^2_v + 1)}$.
+
+Taking the log:
+
+$$
+ -\ln x -\frac{1}{2} \ln \sigma^2 -\frac{1}{2} \ln (2 \pi) -\frac{ (\ln x-\mu)^2 }{2 \sigma^2}$$
+
+Negating and dropping the constants $-\frac{1}{2} \ln (2 \pi)$ and $-\frac{1}{2} \ln \sigma^2$ (constant here because the relative error is fixed) yields
+
+$$
+ \ln x + \frac{1}{2} \left(\frac{ (\ln x-\mu)^2 }{\sigma^2} \right)$$
+
+``` julia
+function neg_logden_lognormalep_lognormal(y_pred, y_obs::AbstractArray{ET}, y_unc;
+    σ2 = log(abs2(ET(0.02)) + ET(1))) where ET
+    lnx = log.(CA.getdata(y_obs))
+    μ = log.(CA.getdata(y_pred))
+    nlogL = sum(lnx .+ abs2.(lnx .- μ) ./ (ET(2) .* σ2))
+    #nlogL = sum(lnx + (log(σ2) .+ abs2.(lnx .- μ) ./ σ2) ./ ET(2)) # nonconstant σ2
+    return (nlogL)
+end
+```
+
+If information on the different relative error by observation was available,
+we could pass that information using the DataLoader with `y_unc`, rather than
+assuming a constant relative error across observations.
+
+## Update the problem and redo the inversion
+
+`HybridProblem` has the keyword argument `py` to specify the negative log-likelihood function.
+
+``` julia
+prob_lognormal = HybridProblem(prob; py = neg_logden_lognormalep_lognormal)
+
+using OptimizationOptimisers
+import Zygote
+
+solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
+
+(; probo) = solve(prob_lognormal, solver;
+    callback = callback_loss(100), # output during fitting
+    epochs = 20,
+); probo_lognormal = probo;
+```
+
+## Compare results between assumptions of observation error
+
+First, draw a sample from the inversion assuming normally distributed observation
+errors and a sample from the inversion assuming lognormally distributed observation errors.
+
+``` julia
+n_sample_pred = 400
+(y_normal, θsP_normal, θsMs_normal) = (; y, θsP, θsMs) = predict_hvi(
+    Random.default_rng(), probo_normal; n_sample_pred)
+(y_lognormal, θsP_lognormal, θsMs_lognormal) = (; y, θsP, θsMs) = predict_hvi(
+    Random.default_rng(), probo_lognormal; n_sample_pred)
+```
+
+Get the original observations from the DataLoader of the problem, and
+compute the residuals.
+
+``` julia
+train_loader = get_hybridproblem_train_dataloader(probo_normal; scenario=())
+y_o = train_loader.data[3]
+resid_normal = y_o .- y_normal
+resid_lognormal = y_o .- y_lognormal
+```
+
+And compare plots of some of the results.
+
+``` julia
+i_out = 4
+i_site = 1
+fig = Figure(); ax = Axis(fig[1,1], xlabel="observations error (y_obs - y_pred)",ylabel="probability density")
+#hist!(ax, resid_normal[i_out,i_site,:], label="normal", normalization=:pdf)
+density!(ax, resid_normal[i_out,i_site,:], alpha = 0.8, label="normal")
+density!(ax, resid_lognormal[i_out,i_site,:], alpha = 0.8, label="lognormal")
+axislegend(ax, unique=true)
+fig
+```
+
+![](logden_user_files/figure-commonmark/cell-8-output-1.png)
+
+The density plot of the observation residuals does not show the lognormal shape.
+The synthetic observations used here were actually normally
+distributed around the predictions with the true parameters.
+
+How does the wrong assumption of observation error influence the parameter
+posterior?
+
+``` julia
+i_site = 1
+fig = Figure(); ax = Axis(fig[1,1], xlabel="global parameter K2",ylabel="probability density")
+#hist!(ax, resid_normal[i_out,i_site,:], label="normal", normalization=:pdf)
+density!(ax, θsP_normal[:K2,:], alpha = 0.8, label="normal")
+density!(ax, θsP_lognormal[:K2,:], alpha = 0.8, label="lognormal")
+axislegend(ax, unique=true)
+fig
+```
+
+![](logden_user_files/figure-commonmark/cell-9-output-1.png)
+
+The marginal posterior of the global parameters is also similar, with a small
+trend towards lower values.
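+
+As an aside to the earlier remark on `y_unc`: if a per-observation relative error
+were available, it could be passed through the DataLoader as `y_unc` instead of
+fixing `σ2`. A minimal sketch of such a variant is given below; the name
+`neg_logden_lognormal_relerr` and the convention that `y_unc` holds the relative
+error $c_v$ of each observation are illustrative assumptions, not part of the package.
+
+``` julia
+# Sketch only: assumes y_unc carries the relative error c_v per observation.
+function neg_logden_lognormal_relerr(y_pred, y_obs::AbstractArray{ET}, y_unc) where ET
+    lnx = log.(CA.getdata(y_obs))
+    μ = log.(CA.getdata(y_pred))
+    σ2 = log.(abs2.(CA.getdata(y_unc)) .+ ET(1))  # σ² = ln(c_v² + 1), elementwise
+    # keep the log(σ²) term, because σ² now differs between observations
+    sum(lnx .+ (log.(σ2) .+ abs2.(lnx .- μ) ./ σ2) ./ ET(2))
+end
+```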
+
+``` julia
+i_site = 1
+θln = vcat(θsP_lognormal, θsMs_lognormal[i_site,:,:])
+θln_nt = NamedTuple(Symbol("$(k)_lognormal") => CA.getdata(θln[k,:]) for k in keys(θln[:,1])) #
+#θn = vcat(θsP_normal, θsMs_normal[i_site,:,:])
+#θn_nt = NamedTuple(Symbol("$(k)_normal") => CA.getdata(θn[k,:]) for k in keys(θn[:,1])) #
+# ntc = (;θn_nt..., θln_nt...)
+plt = pairplot(θln_nt)
+```
+
+![](logden_user_files/figure-commonmark/cell-10-output-1.png)
+
+The corner plot of the parameter sample from the lognormal inversion
+looks similar and shows correlations between the site parameters, $r_1$ and $K_1$.
+
+``` julia
+i_out = 4
+fig = Figure(); ax = Axis(fig[1,1], xlabel="mean(y)",ylabel="sd(y)")
+ymean_normal = [mean(y_normal[i_out,s,:]) for s in axes(y_normal, 2)]
+ysd_normal = [std(y_normal[i_out,s,:]) for s in axes(y_normal, 2)]
+scatter!(ax, ymean_normal, ysd_normal, label="normal")
+ymean_lognormal = [mean(y_lognormal[i_out,s,:]) for s in axes(y_lognormal, 2)]
+ysd_lognormal = [std(y_lognormal[i_out,s,:]) for s in axes(y_lognormal, 2)]
+scatter!(ax, ymean_lognormal, ysd_lognormal, label="lognormal")
+axislegend(ax, unique=true)
+fig
+```
+
+![](logden_user_files/figure-commonmark/cell-11-output-1.png)
+
+The predicted uncertainty of the fourth observation across sites
+is of similar magnitude,
+and still shows a (although weaker) pattern of decreasing uncertainty with
+increasing value.
+
+``` julia
+plot_sd_vs_mean = (par) -> begin
+    fig = Figure(); ax = Axis(fig[1,1], xlabel="mean($par)",ylabel="sd($par)")
+    θmean_normal = [mean(θsMs_normal[s,par,:]) for s in axes(θsMs_normal, 1)]
+    θsd_normal = [std(θsMs_normal[s,par,:]) for s in axes(θsMs_normal, 1)]
+    scatter!(ax, θmean_normal, θsd_normal, label="normal")
+    θmean_lognormal = [mean(θsMs_lognormal[s,par,:]) for s in axes(θsMs_lognormal, 1)]
+    θsd_lognormal = [std(θsMs_lognormal[s,par,:]) for s in axes(θsMs_lognormal, 1)]
+    scatter!(ax, θmean_lognormal, θsd_lognormal, label="lognormal")
+    axislegend(ax, unique=true)
+    fig
+end
+plot_sd_vs_mean(:K1)
+```
+
+![](logden_user_files/figure-commonmark/cell-12-output-1.png)
+
+For the assumed fixed relative error, the uncertainty in the model
+parameter, $K_1$, across sites is similar to the uncertainty with the normal log-likelihood.
diff --git a/docs/src/tutorials/logden_user.qmd b/docs/src/tutorials/logden_user.qmd
new file mode 100644
index 0000000..06f3043
--- /dev/null
+++ b/docs/src/tutorials/logden_user.qmd
@@ -0,0 +1,231 @@
+---
+title: "How to specify a custom LogLikelihood of the observations"
+engine: julia
+execute:
+  echo: true
+  output: false
+  daemon: 3600
+format:
+  commonmark:
+    variant: -raw_html+tex_math_dollars
+    wrap: preserve
+bibliography: twutz_txt.bib
+---
+
+``` @meta
+CurrentModule = HybridVariationalInference
+```
+
+This guide shows how the user can specify a customized log-Likelihood function.
+
+## Motivation
+The log-likelihood function assigns a cost to the mismatch between predictions and
+observations. This often needs to be customized to the specific inversion.
+
+This guide walks through the specification of such a function and inspects
+differences between two log-likelihood functions.
+Specifically, it will assume observation errors to be independently distributed
+according to a LogNormal distribution with a specified fixed relative error,
+compared to an inversion assuming independently normally distributed observation errors.
+
+First load necessary packages.
+```{julia}
+using HybridVariationalInference
+using ComponentArrays: ComponentArrays as CA
+using Bijectors
+using SimpleChains
+using MLUtils
+using JLD2
+using Random
+using CairoMakie
+using PairPlots # scatterplot matrices
+```
+
+This tutorial reuses and modifies the fitted object saved at the end of the
+[Basic workflow without GPU](@ref) tutorial, which used a log-likelihood
+function assuming independently normally distributed observation errors.
+
+```{julia}
+fname = "intermediate/basic_cpu_results.jld2"
+print(abspath(fname))
+prob = probo_normal = load(fname, "probo");
+```
+
+## Write the LogLikelihood Function
+
+The function signature corresponds to that of
+[`neg_logden_indep_normal`](@ref):
+
+`neg_log_den_user(y_pred, y_obs, y_unc; kwargs...)`
+
+It takes predictions, `y_pred`, observations, `y_obs`,
+and uncertainty parameters, `y_unc`, and returns the negative logarithm of the
+likelihood up to a constant.
+
+All of the arguments are vectors of the same length specifying predictions and
+observations for one site.
+If `y_pred` and `y_obs` are given as matrices of several column vectors, their
+negative log-likelihoods are summed.
+
+The density of a LogNormal distribution is
+
+$$
+\frac{ 1 }{ x \sqrt{2 \pi \sigma^2} } \exp\left( -\frac{ (\ln x-\mu)^2 }{2 \sigma^2} \right)$$
+
+where x is the observation, μ is the log of the prediction, and σ is the scale
+parameter, which is related to the relative error $c_v$ by $\sigma = \sqrt{\ln(c^2_v + 1)}$.
+
+Taking the log:
+
+$$
+ -\ln x -\frac{1}{2} \ln \sigma^2 -\frac{1}{2} \ln (2 \pi) -\frac{ (\ln x-\mu)^2 }{2 \sigma^2}$$
+
+Negating and dropping the constants $-\frac{1}{2} \ln (2 \pi)$ and $-\frac{1}{2} \ln \sigma^2$ (constant here because the relative error is fixed) yields
+
+$$
+ \ln x + \frac{1}{2} \left(\frac{ (\ln x-\mu)^2 }{\sigma^2} \right)$$
+
+```{julia}
+function neg_logden_lognormalep_lognormal(y_pred, y_obs::AbstractArray{ET}, y_unc;
+    σ2 = log(abs2(ET(0.02)) + ET(1))) where ET
+    lnx = log.(CA.getdata(y_obs))
+    μ = log.(CA.getdata(y_pred))
+    nlogL = sum(lnx .+ abs2.(lnx .- μ) ./ (ET(2) .* σ2))
+    #nlogL = sum(lnx + (log(σ2) .+ abs2.(lnx .- μ) ./ σ2) ./ ET(2)) # nonconstant σ2
+    return (nlogL)
+end
+```
+
+If information on the different relative error by observation was available,
+we could pass that information using the DataLoader with `y_unc`, rather than
+assuming a constant relative error across observations.
+
+## Update the problem and redo the inversion
+
+`HybridProblem` has the keyword argument `py` to specify the negative log-likelihood function.
+
+```{julia}
+prob_lognormal = HybridProblem(prob; py = neg_logden_lognormalep_lognormal)
+
+using OptimizationOptimisers
+import Zygote
+
+solver = HybridPosteriorSolver(; alg=Adam(0.02), n_MC=3)
+
+(; probo) = solve(prob_lognormal, solver;
+    callback = callback_loss(100), # output during fitting
+    epochs = 20,
+); probo_lognormal = probo;
+```
+
+## Compare results between assumptions of observation error
+
+First, draw a sample from the inversion assuming normally distributed observation
+errors and a sample from the inversion assuming lognormally distributed observation errors.
+
+```{julia}
+n_sample_pred = 400
+(y_normal, θsP_normal, θsMs_normal) = (; y, θsP, θsMs) = predict_hvi(
+    Random.default_rng(), probo_normal; n_sample_pred)
+(y_lognormal, θsP_lognormal, θsMs_lognormal) = (; y, θsP, θsMs) = predict_hvi(
+    Random.default_rng(), probo_lognormal; n_sample_pred)
+```
+
+Get the original observations from the DataLoader of the problem, and
+compute the residuals.
+
+```{julia}
+train_loader = get_hybridproblem_train_dataloader(probo_normal; scenario=())
+y_o = train_loader.data[3]
+resid_normal = y_o .- y_normal
+resid_lognormal = y_o .- y_lognormal
+```
+
+And compare plots of some of the results.
+
+```{julia}
+#| output: true
+i_out = 4
+i_site = 1
+fig = Figure(); ax = Axis(fig[1,1], xlabel="observations error (y_obs - y_pred)",ylabel="probability density")
+#hist!(ax, resid_normal[i_out,i_site,:], label="normal", normalization=:pdf)
+density!(ax, resid_normal[i_out,i_site,:], alpha = 0.8, label="normal")
+density!(ax, resid_lognormal[i_out,i_site,:], alpha = 0.8, label="lognormal")
+axislegend(ax, unique=true)
+fig
+```
+
+The density plot of the observation residuals does not show the lognormal shape.
+The synthetic observations used here were actually normally
+distributed around the predictions with the true parameters.
+
+How does the wrong assumption of observation error influence the parameter
+posterior?
+
+```{julia}
+#| output: true
+i_site = 1
+fig = Figure(); ax = Axis(fig[1,1], xlabel="global parameter K2",ylabel="probability density")
+#hist!(ax, resid_normal[i_out,i_site,:], label="normal", normalization=:pdf)
+density!(ax, θsP_normal[:K2,:], alpha = 0.8, label="normal")
+density!(ax, θsP_lognormal[:K2,:], alpha = 0.8, label="lognormal")
+axislegend(ax, unique=true)
+fig
+```
+
+The marginal posterior of the global parameters is also similar, with a small
+trend towards lower values.
+
+```{julia}
+#| output: true
+i_site = 1
+θln = vcat(θsP_lognormal, θsMs_lognormal[i_site,:,:])
+θln_nt = NamedTuple(Symbol("$(k)_lognormal") => CA.getdata(θln[k,:]) for k in keys(θln[:,1])) #
+#θn = vcat(θsP_normal, θsMs_normal[i_site,:,:])
+#θn_nt = NamedTuple(Symbol("$(k)_normal") => CA.getdata(θn[k,:]) for k in keys(θn[:,1])) #
+# ntc = (;θn_nt..., θln_nt...)
+plt = pairplot(θln_nt)
+```
+The corner plot of the parameter sample from the lognormal inversion
+looks similar and shows correlations between the site parameters, $r_1$ and $K_1$.
+
+```{julia}
+#| output: true
+i_out = 4
+fig = Figure(); ax = Axis(fig[1,1], xlabel="mean(y)",ylabel="sd(y)")
+ymean_normal = [mean(y_normal[i_out,s,:]) for s in axes(y_normal, 2)]
+ysd_normal = [std(y_normal[i_out,s,:]) for s in axes(y_normal, 2)]
+scatter!(ax, ymean_normal, ysd_normal, label="normal")
+ymean_lognormal = [mean(y_lognormal[i_out,s,:]) for s in axes(y_lognormal, 2)]
+ysd_lognormal = [std(y_lognormal[i_out,s,:]) for s in axes(y_lognormal, 2)]
+scatter!(ax, ymean_lognormal, ysd_lognormal, label="lognormal")
+axislegend(ax, unique=true)
+fig
+```
+
+The predicted uncertainty of the fourth observation across sites
+is of similar magnitude,
+and still shows a (although weaker) pattern of decreasing uncertainty with
+increasing value.
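+
+As noted in the section on writing the log-likelihood function, a per-observation
+relative error could be passed through the DataLoader as `y_unc` instead of fixing
+`σ2`. A minimal, non-executed sketch of such a variant follows; the name
+`neg_logden_lognormal_relerr` and the convention that `y_unc` holds the relative
+error $c_v$ of each observation are illustrative assumptions, not part of the package.
+
+``` julia
+# Sketch only: assumes y_unc carries the relative error c_v per observation.
+function neg_logden_lognormal_relerr(y_pred, y_obs::AbstractArray{ET}, y_unc) where ET
+    lnx = log.(CA.getdata(y_obs))
+    μ = log.(CA.getdata(y_pred))
+    σ2 = log.(abs2.(CA.getdata(y_unc)) .+ ET(1))  # σ² = ln(c_v² + 1), elementwise
+    # keep the log(σ²) term, because σ² now differs between observations
+    sum(lnx .+ (log.(σ2) .+ abs2.(lnx .- μ) ./ σ2) ./ ET(2))
+end
+```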
+
+```{julia}
+#| output: true
+plot_sd_vs_mean = (par) -> begin
+    fig = Figure(); ax = Axis(fig[1,1], xlabel="mean($par)",ylabel="sd($par)")
+    θmean_normal = [mean(θsMs_normal[s,par,:]) for s in axes(θsMs_normal, 1)]
+    θsd_normal = [std(θsMs_normal[s,par,:]) for s in axes(θsMs_normal, 1)]
+    scatter!(ax, θmean_normal, θsd_normal, label="normal")
+    θmean_lognormal = [mean(θsMs_lognormal[s,par,:]) for s in axes(θsMs_lognormal, 1)]
+    θsd_lognormal = [std(θsMs_lognormal[s,par,:]) for s in axes(θsMs_lognormal, 1)]
+    scatter!(ax, θmean_lognormal, θsd_lognormal, label="lognormal")
+    axislegend(ax, unique=true)
+    fig
+end
+plot_sd_vs_mean(:K1)
+```
+
+For the assumed fixed relative error, the uncertainty in the model
+parameter, $K_1$, across sites is similar to the uncertainty with the normal log-likelihood.
+
+
+
diff --git a/docs/src/tutorials/logden_user_files/figure-commonmark/cell-10-output-1.png b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-10-output-1.png
new file mode 100644
index 0000000..dd132d4
Binary files /dev/null and b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-10-output-1.png differ
diff --git a/docs/src/tutorials/logden_user_files/figure-commonmark/cell-11-output-1.png b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-11-output-1.png
new file mode 100644
index 0000000..64b3b19
Binary files /dev/null and b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-11-output-1.png differ
diff --git a/docs/src/tutorials/logden_user_files/figure-commonmark/cell-12-output-1.png b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-12-output-1.png
new file mode 100644
index 0000000..26aae9f
Binary files /dev/null and b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-12-output-1.png differ
diff --git a/docs/src/tutorials/logden_user_files/figure-commonmark/cell-8-output-1.png b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-8-output-1.png
new file mode 100644
index 0000000..0ab0e99
Binary files /dev/null and b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-8-output-1.png differ
diff --git a/docs/src/tutorials/logden_user_files/figure-commonmark/cell-9-output-1.png b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-9-output-1.png
new file mode 100644
index 0000000..8142630
Binary files /dev/null and b/docs/src/tutorials/logden_user_files/figure-commonmark/cell-9-output-1.png differ
diff --git a/ext/HybridVariationalInferenceLuxExt.jl b/ext/HybridVariationalInferenceLuxExt.jl
index 0329c1c..4c66175 100644
--- a/ext/HybridVariationalInferenceLuxExt.jl
+++ b/ext/HybridVariationalInferenceLuxExt.jl
@@ -7,7 +7,10 @@ using Random
 
 using StatsFuns: logistic
 
-
+# AbstractModelApplicator that stores a Lux.StatefulLuxLayer, so that
+# it can be applied with given inputs and parameters.
+# The `int_ϕ` ComponentArrayInterpreter attaches the correct axes to the
+# supplied parameters, which do not need to keep the Axis information.
 struct LuxApplicator{MT, IT} <: AbstractModelApplicator
     stateful_layer::MT
     int_ϕ::IT
@@ -24,7 +27,7 @@ function HVI.construct_ChainsApplicator(rng::AbstractRNG, m::Chain, float_type=F
 end
 
 function HVI.apply_model(app::LuxApplicator, x, ϕ)
-    ϕc = app.int_ϕ(ϕ)
+    ϕc = app.int_ϕ(CA.getdata(ϕ))
     app.stateful_layer(x, ϕc)
 end
diff --git a/src/HybridSolver.jl b/src/HybridSolver.jl
index 87cbaf6..a7c9cb0 100644
--- a/src/HybridSolver.jl
+++ b/src/HybridSolver.jl
@@ -32,13 +32,14 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve
         train_loader_dev = train_loader
     end
     f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites=false)
+    py = get_hybridproblem_neg_logden_obs(prob; scenario)
     pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario)
     n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
     priors = get_hybridproblem_priors(prob; scenario)
     priorsP = [priors[k] for k in keys(par_templates.θP)]
     priorsM = [priors[k] for k in keys(par_templates.θM)]
     #intP = ComponentArrayInterpreter(par_templates.θP)
-    loss_gf = get_loss_gf(g_dev, transM, transP, f, intϕ;
+    loss_gf = get_loss_gf(g_dev, transM, transP, f, py, intϕ;
         cdev=infer_cdev(gdevs), pbm_covars, n_site_batch=n_batch, priorsP, priorsM,)
     # call loss function once
     l1 = is_infer ?
@@ -54,6 +55,8 @@ function CommonSolve.solve(prob::AbstractHybridProblem, solver::HybridPointSolve
     # Zygote.gradient(ϕ0_dev -> loss_gf(ϕ0_dev, data1...)[1], ϕ0_dev)
     optf = Optimization.OptimizationFunction((ϕ, data) -> loss_gf(ϕ, data...)[1],
         Optimization.AutoZygote())
+    # use CA.getdata(ϕ0_dev), i.e. the plain vector to avoid recompiling for specific CA
+    # loss_gf re-attaches the axes
     optprob = OptimizationProblem(optf, CA.getdata(ϕ0_dev), train_loader_dev)
     res = Optimization.solve(optprob, solver.alg; kwargs...)
     ϕ = intϕ(res.u)
diff --git a/src/ModelApplicator.jl b/src/ModelApplicator.jl
index 943a147..4d08fc0 100644
--- a/src/ModelApplicator.jl
+++ b/src/ModelApplicator.jl
@@ -11,10 +11,10 @@ a tuple with two components:
 - The applicator
 - a sample parameter vector (type depends on the used ML-framework)
 
-Implemented are
-- `construct_SimpleChainsApplicator`
-- `construct_FluxApplicator`
-- `construct_LuxApplicator`
+Implemented overloads of `construct_ChainsApplicator` for layers of
+- `SimpleChains.SimpleChain`
+- `Flux.Chain`
+- `Lux.Chain`
 """
 abstract type AbstractModelApplicator end
diff --git a/src/gf.jl b/src/gf.jl
index 4322992..30b5a76 100644
--- a/src/gf.jl
+++ b/src/gf.jl
@@ -137,9 +137,11 @@ Create a loss function for given
 - g(x, ϕ): machine learning model
 - transM: transforamtion of parameters at unconstrained space
 - f(θMs, θP): mechanistic model
+- py: `function(y_pred, y_obs, y_unc)` to compute negative log-likelihood, i.e. cost
 - intϕ: interpreter attaching axis with components ϕg and ϕP
-- intP: interpreter attaching axis to ζP = ϕP with components used by f
-- kwargs: additional keyword arguments passed to gf, such as gdev or pbm_covars
+- intP: interpreter attaching axis to ζP = ϕP with components used by f,
+  The default uses `intϕ(ϕ)` as a template
+- kwargs: additional keyword arguments passed to `gf`, such as `gdev` or `pbm_covars`
 
 The loss function `loss_gf(ϕ, xM, xP, y_o, y_unc, i_sites)` takes
 - parameter vector ϕ
@@ -147,6 +149,8 @@ The loss function `loss_gf(ϕ, xM, xP, y_o, y_unc, i_sites)` takes
 - xP: iteration of drivers for each site
 - y_o: matrix of observations, sites in columns
 - y_unc: vector of uncertainty information for each observation
+  Currently, hardcodes squared error loss of `(y_pred .- y_o) ./ σ`,
+  with `σ = exp.(y_unc ./ 2)`.
 - i_sites: index of sites in the batch
 
 and returns a NamedTuple of
@@ -157,7 +161,7 @@ and returns a NamedTuple of
 - `neg_log_prior`: negative log-prior of `θMs` and `θP`
 - `neg_log_prior`: negative log-prior of `θMs` and `θP`
 """
-function get_loss_gf(g, transM, transP, f,
+function get_loss_gf(g, transM, transP, f, py,
         intϕ::AbstractComponentArrayInterpreter,
         intP::AbstractComponentArrayInterpreter = ComponentArrayInterpreter(
             intϕ(1:length(intϕ)).ϕP);
@@ -175,7 +179,6 @@ function get_loss_gf(g, transM, transP, f,
     #, intP = get_concrete(intP)
     #inv_transP = inverse(transP),
     kwargs = kwargs
     function loss_gf(ϕ, xM, xP, y_o, y_unc, i_sites)
-        σ = exp.(y_unc ./ 2)
         ϕc = intϕ(ϕ)
         # μ_ζP = ϕc.ϕP
         # xMP = _append_each_covars(xM, CA.getdata(μ_ζP), pbm_covar_indices)
@@ -187,7 +190,9 @@ function get_loss_gf(g, transM, transP, f,
         y_pred, θMs_pred, θP_pred = gf(
             g, transMs, transP, f, xM, xP, CA.getdata(ϕc.ϕg), CA.getdata(ϕc.ϕP),
             pbm_covar_indices; cdev, kwargs...)
-        nLy = sum(abs2, (y_pred .- y_o) ./ σ)
+        #σ = exp.(y_unc ./ 2)
+        #nLy = sum(abs2, (y_pred .- y_o) ./ σ)
+        nLy = py( y_pred, y_o, y_unc)
         # logpdf is not typestable for Distribution{Univariate, Continuous}
         logpdf_t = (prior, θ) -> logpdf(prior, θ)::eltype(θP_pred)
         logpdf_tv = (prior, θ::AbstractVector) -> begin
diff --git a/src/logden_normal.jl b/src/logden_normal.jl
index 19b7d3a..d3cd814 100644
--- a/src/logden_normal.jl
+++ b/src/logden_normal.jl
@@ -1,7 +1,7 @@
 """
     neg_logden_indep_normal(obs, μ, logσ2s; σfac=1.0)
 
-Compute the negative Log-density of `θM` for multiple independent normal distributions,
+Compute the negative Log-density of `obs` for multiple independent normal distributions,
 given estimated means `μ` and estimated log of variance parameters `logσ2s`.
 
 All the arguments should be vectors of the same length.
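
For reference, the `py` argument threaded through `get_loss_gf` above is expected to
follow the documented signature `py(y_pred, y_obs, y_unc)` and return a negative
log-likelihood up to a constant. A minimal sketch that reproduces the previously
hardcoded weighted squared-error loss (treating `y_unc` as the log-variance, as in the
removed lines) could look as follows; the name `neg_logden_sqerr` is illustrative and
not part of the package:

``` julia
# Sketch only: mirrors the removed inline loss in loss_gf, where y_unc holds log(σ²).
function neg_logden_sqerr(y_pred, y_obs, y_unc)
    σ = exp.(y_unc ./ 2)               # per-observation standard deviation
    sum(abs2, (y_pred .- y_obs) ./ σ)  # weighted sum of squared residuals
end
```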
diff --git a/test/test_HybridProblem.jl b/test/test_HybridProblem.jl index ef105bf..c8b9c2b 100644 --- a/test/test_HybridProblem.jl +++ b/test/test_HybridProblem.jl @@ -141,6 +141,7 @@ test_without_flux = (scenario) -> begin train_loader = get_hybridproblem_train_dataloader(prob; scenario) (xM, xP, y_o, y_unc, i_sites) = first(train_loader) f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) + py = get_hybridproblem_neg_logden_obs(prob; scenario) par_templates = get_hybridproblem_par_templates(prob; scenario) #f(par_templates.θP, hcat(par_templates.θM, par_templates.θM), xP[1:2]) (; transM, transP) = get_hybridproblem_transforms(prob; scenario) @@ -154,7 +155,7 @@ test_without_flux = (scenario) -> begin p = p0 = vcat(ϕg0, par_templates.θP .* convert(eltype(ϕg0), 0.8)) # Pass the site-data for the batches as separate vectors wrapped in a tuple - loss_gf = get_loss_gf(g, transM, transP, f, intϕ; + loss_gf = get_loss_gf(g, transM, transP, f, py, intϕ; pbm_covars, n_site_batch = n_batch, priorsP, priorsM) (_xM, _xP, _y_o, _y_unc, _i_sites) = first(train_loader) l1 = loss_gf(p0, _xM, _xP, _y_o, _y_unc, _i_sites) diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index 9f16a4b..b4d6936 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -199,6 +199,7 @@ end n_site, n_site_batch = get_hybridproblem_n_site_and_batch(prob; scenario) f = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = false) f2 = get_hybridproblem_PBmodel(prob; scenario, use_all_sites = true) + py = get_hybridproblem_neg_logden_obs(prob; scenario) priors = get_hybridproblem_priors(prob; scenario) priorsP = [priors[k] for k in keys(par_templates.θP)] priorsM = [priors[k] for k in keys(par_templates.θM)] @@ -218,9 +219,9 @@ end pbm_covars = get_hybridproblem_pbmpar_covars(prob; scenario) #loss_gf = get_loss_gf(g, transM, f, intϕ; gdev = identity) - loss_gf = get_loss_gf(g, transM, transP, f, intϕ; + loss_gf = get_loss_gf(g, transM, transP, f, py, intϕ; pbm_covars, n_site_batch = n_batch, priorsP, priorsM) - loss_gf2 = get_loss_gf(g, transM, transP, f2, intϕ; + loss_gf_site = get_loss_gf(g, transM, transP, f2, py, intϕ; pbm_covars, n_site_batch = n_site, priorsP, priorsM) nLjoint = @inferred first(loss_gf(p0, first(train_loader)...)) (xM_batch, xP_batch, y_o_batch, y_unc_batch, i_sites_batch) = first(train_loader) @@ -237,7 +238,7 @@ end #optprob, Adam(0.02), callback = callback_loss(100), maxiters = 5000); optprob, Adam(0.02), maxiters = 2000) - (;nLjoint_pen, y_pred, θMs_pred, θP_pred, nLy, neg_log_prior, loss_penalty) = loss_gf2( + (;nLjoint_pen, y_pred, θMs_pred, θP_pred, nLy, neg_log_prior, loss_penalty) = loss_gf_site( res.u, train_loader.data...) 
#(nLjoint, y_pred, θMs_pred, θP, nLy, neg_log_prior, loss_penalty) = loss_gf(p0, xM, xP, y_o, y_unc); θMs_pred = CA.ComponentArray(θMs_pred, CA.getaxes(θMs_true')) diff --git a/test/test_util_gpu.jl b/test/test_util_gpu.jl index 75162f8..04a2b8e 100644 --- a/test/test_util_gpu.jl +++ b/test/test_util_gpu.jl @@ -15,11 +15,11 @@ gdev = gpu_device() if gdev isa MLDataDevices.CUDADevice @testset "ones_similar_x" begin B = CUDA.rand(Float32, 5, 2); # GPU matrix - @test HVI.ones_similar_x(B, size(B,1)) isa CuArray - @test HVI.ones_similar_x(ComponentVector(b=B), size(B,1)) isa CuArray - @test HVI.ones_similar_x(B', size(B,1)) isa CuArray - @test HVI.ones_similar_x(@view(B[:,2]), size(B,1)) isa CuArray - @test HVI.ones_similar_x(ComponentVector(b=B)[:,1], size(B,1)) isa CuArray + @test HVI.ones_similar_x(B, size(B,1)) isa CUDA.CuArray + @test HVI.ones_similar_x(ComponentVector(b=B), size(B,1)) isa CUDA.CuArray + @test HVI.ones_similar_x(B', size(B,1)) isa CUDA.CuArray + @test HVI.ones_similar_x(@view(B[:,2]), size(B,1)) isa CUDA.CuArray + @test HVI.ones_similar_x(ComponentVector(b=B)[:,1], size(B,1)) isa CUDA.CuArray end end