From 0cc569dd966dd5768432f3455f63db2a59a05c35 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Mon, 12 May 2025 16:05:44 +0200 Subject: [PATCH 1/3] replace model closure by Callable Object --- dev/doubleMM.jl | 7 --- src/AbstractHybridProblem.jl | 52 ++++++++++++++++++++- src/DoubleMM/f_doubleMM.jl | 87 +++++++++++++++--------------------- src/HybridProblem.jl | 12 ++--- test/test_doubleMM.jl | 1 + 5 files changed, 94 insertions(+), 65 deletions(-) diff --git a/dev/doubleMM.jl b/dev/doubleMM.jl index 6bcf7f1..3a9455b 100644 --- a/dev/doubleMM.jl +++ b/dev/doubleMM.jl @@ -197,10 +197,6 @@ prob2o = probo; fname_probos = "intermediate/probos800_$(last(HVI._val_value(scenario))).jld2" JLD2.save(fname_probos, Dict("prob1o" => prob1o, "prob2o" => prob2o)) tmp = JLD2.load(fname_probos) - # TODO replace function closure by Callable to store - # closure function could not be restored with JLD2 - prob1o = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader); - prob2o = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader); end () -> begin # load the non-covar scenario @@ -208,9 +204,6 @@ end #fname_probos = "intermediate/probos_$(last(_val_value(scenario))).jld2" fname_probos = "intermediate/probos800_omit_r0.jld2" tmp = JLD2.load(fname_probos) - # get_train_loader function could not be restored with JLD2 - prob1o_indep = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader); - prob2o_indep = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader); # test predicting correct obs-uncertainty of predictive posterior n_sample_pred = 400 (; θ, y, entropy_ζ) = predict_hvi(rng, prob2o_indep, xM, xP; scenario, n_sample_pred); diff --git a/src/AbstractHybridProblem.jl b/src/AbstractHybridProblem.jl index 4705c4a..1d4a898 100644 --- a/src/AbstractHybridProblem.jl +++ b/src/AbstractHybridProblem.jl @@ -152,7 +152,7 @@ function gen_hybridproblem_synthetic end Determine the FloatType for given Case and scenario, defaults to 
Float32 """ -function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario = ()) +function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario) return eltype(get_hybridproblem_par_templates(prob; scenario).θM) end @@ -263,4 +263,52 @@ function setup_PBMpar_interpreter(θP, θM, θall = vcat(θP, θM)) θFix = θall[keys_fixed] intθ = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix))) intθ, θFix -end \ No newline at end of file +end + +struct PBmodelClosure{θFixT, θFix_devT, AX, pos_xPT} + θFix::θFixT + θFix_dev::θFix_devT + intθ::StaticComponentArrayInterpreter{AX} + isP::Matrix{Int} + n_site_batch::Int + pos_xP::pos_xPT +end + +function PBmodelClosure(prob::AbstractHybridProblem; scenario::Val{scen}, + use_all_sites = false, + gdev = :f_on_gpu ∈ _val_value(scenario) ? gpu_device() : identity, + θall, int_xP1, +) where {scen} + n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) + n_site_batch = use_all_sites ? n_site : n_batch + #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers + par_templates = get_hybridproblem_par_templates(prob; scenario) + intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall) + θFix = repeat(θFix1', n_site_batch) + θFix_dev = gdev(θFix) + intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1)) + #int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1) + isP = repeat(axes(par_templates.θP, 1)', n_site_batch) + pos_xP = get_positions(int_xP1) + PBmodelClosure(;θFix, θFix_dev, intθ, isP, n_site_batch, pos_xP) +end + +function PBmodelClosure(; + θFix::θFixT, + θFix_dev::θFix_devT, + intθ::StaticComponentArrayInterpreter{AX}, + isP::Matrix{Int}, + n_site_batch::Int, + pos_xP::pos_xPT, +) where {θFixT, θFix_devT, AX, pos_xPT} + PBmodelClosure{θFixT, θFix_devT, AX, pos_xPT}( + θFix::AbstractArray, θFix_dev, intθ, isP, n_site_batch, pos_xP) +end + + + + + + + + diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl 
index bbdb986..6915237 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -181,55 +181,43 @@ end # end # end -function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::Val{scen}, - use_all_sites = false, - gdev = :f_on_gpu ∈ HVI._val_value(scenario) ? gpu_device() : identity -) where {scen} - n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario) - n_site_batch = use_all_sites ? n_site : n_batch - #fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers - par_templates = get_hybridproblem_par_templates(prob; scenario) - intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall) - θFix = repeat(θFix1', n_site_batch) - intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1)) - #int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1) - isP = repeat(axes(par_templates.θP, 1)', n_site_batch) - let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP = isP, - n_site_batch = n_site_batch, - #int_xPb=get_concrete(int_xPb), - pos_xP = get_positions(int_xP1) +# defining the PBmodel as a clousre with let leads to problems of JLD2 reloading +# Define all the varaibles additional to the ones passed curing the call by +# a dedicated Closure object and define the PBmodel as a callable +struct DoubleMMCaller{CLT} + cl::CLT +end - function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP) - @assert size(xP, 2) == n_site_batch - @assert size(θMs, 1) == n_site_batch - # # convert vector of tuples to tuple of matricesByRows - # # need to supply xP as vectorOfTuples to work with DataLoader - # # k = first(keys(xP[1])) - # xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k - # #stack(map(r -> r[k], xP))' - # stack(map(r -> r[k], xP); dims = 1) - # end)...) 
- #xPM = map(transpose, xPM1) - #xPc = int_xPb(CA.getdata(xP)) - #xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote - # make sure the same order of columns as in intθ - # reshape big matrix into NamedTuple of drivers S1 and S2 - # for broadcasting need sites in rows - #xPM = map(p -> CA.getdata(xP[p,:])', pos_xP) - xPM = map(p -> CA.getdata(xP)'[:, p], pos_xP) - θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? θFix_dev : θFix - θ = hcat(CA.getdata(θP[isP]), CA.getdata(θMs), θFixd) - pred_sites = f_doubleMM(θ, xPM; intθ)' - pred_global = eltype(pred_sites)[] - return pred_global, pred_sites - end - # function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP) - # # TODO - # pred_sites = f_doubleMM(θMs, θP, θFix, xP, intθ) - # pred_global = eltype(pred_sites)[] - # return pred_global, pred_sites - # end - end +function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario, kwargs...) + # θall defined in this module above + cl = HVI.PBmodelClosure(prob; scenario, θall, int_xP1, kwargs...) + return DoubleMMCaller{typeof(cl)}(cl) +end + +function(caller::DoubleMMCaller)(θP::AbstractVector, θMs::AbstractMatrix, xP) + cl = caller.cl + @assert size(xP, 2) == cl.n_site_batch + @assert size(θMs, 1) == cl.n_site_batch + # # convert vector of tuples to tuple of matricesByRows + # # need to supply xP as vectorOfTuples to work with DataLoader + # # k = first(keys(xP[1])) + # xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k + # #stack(map(r -> r[k], xP))' + # stack(map(r -> r[k], xP); dims = 1) + # end)...) + #xPM = map(transpose, xPM1) + #xPc = int_xPb(CA.getdata(xP)) + #xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote + # make sure the same order of columns as in intθ + # reshape big matrix into NamedTuple of drivers S1 and S2 + # for broadcasting need sites in rows + #xPM = map(p -> CA.getdata(xP[p,:])', pos_xP) + xPM = map(p -> CA.getdata(xP)'[:, p], cl.pos_xP) + θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? 
cl.θFix_dev : cl.θFix + θ = hcat(CA.getdata(θP[cl.isP]), CA.getdata(θMs), θFixd) + pred_sites = f_doubleMM(θ, xPM; cl.intθ)' + pred_global = eltype(pred_sites)[] + return pred_global, pred_sites end function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::Val) @@ -284,8 +272,7 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase; xP = int_xP_sites(vcat(repeat(xP_S1, 1, n_site), repeat(xP_S2, 1, n_site))) #xP[:S1,:] θP = par_templates.θP - #θint = ComponentArrayInterpreter( (size(θMs_true,2),), CA.getaxes(vcat(θP, θMs_true[:,1]))) - y_global_true, y_true = f(θP, θMs_true', xP) + y_global_true, y_true = f(θP, θMs_true', xP) σ_o = FloatType(0.01) #σ_o = FloatType(0.002) logσ2_o = FloatType(2) .* log.(σ_o) diff --git a/src/HybridProblem.jl b/src/HybridProblem.jl index a2c30eb..98d4fc9 100644 --- a/src/HybridProblem.jl +++ b/src/HybridProblem.jl @@ -21,8 +21,8 @@ struct HybridProblem <: AbstractHybridProblem θP::CA.ComponentVector, θM::CA.ComponentVector, g::AbstractModelApplicator, ϕg::AbstractVector, ϕunc::CA.ComponentVector, - f_batch::Function, - f_allsites::Function, + f_batch, + f_allsites, priors::AbstractDict, py, transM::Stacked, @@ -43,7 +43,7 @@ end function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector, # note no ϕg argument and g_chain unconstrained - g_chain, f_batch::Function, + g_chain, f_batch, args...; rng = Random.default_rng(), kwargs...) 
# dispatches on type of g_chain g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM)) @@ -74,10 +74,10 @@ function update(prob::HybridProblem; g::AbstractModelApplicator = prob.g, ϕg::AbstractVector = prob.ϕg, ϕunc::CA.ComponentVector = prob.ϕunc, - f_batch::Function = prob.f_batch, - f_allsites::Function = prob.f_allsites, + f_batch = prob.f_batch, + f_allsites = prob.f_allsites, priors::AbstractDict = prob.priors, - py::Function = prob.py, + py = prob.py, # transM::Union{Function, Bijectors.Transform} = prob.transM, # transP::Union{Function, Bijectors.Transform} = prob.transP, transM = prob.transM, diff --git a/test/test_doubleMM.jl b/test/test_doubleMM.jl index 4988866..1d3407f 100644 --- a/test/test_doubleMM.jl +++ b/test/test_doubleMM.jl @@ -25,6 +25,7 @@ prob = DoubleMM.DoubleMMCase() scenario = Val((:default,)) #using Flux #scenario = Val((:use_Flux,)) +#scenario = Val((:use_Flux,:f_on_gpu)) par_templates = get_hybridproblem_par_templates(prob; scenario) From 7707f4aba063de8e5eb5eb9bbdd0a568400113f1 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Mon, 12 May 2025 17:44:26 +0200 Subject: [PATCH 2/3] fix test on correlations the parameter specified for the cholesky-decomposition is only a transformation of the correlation --- src/cholesky.jl | 13 +++++++++++++ src/elbo.jl | 2 ++ test/test_elbo.jl | 21 +++++++++++++-------- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/cholesky.jl b/src/cholesky.jl index 52a94b6..14af734 100644 --- a/src/cholesky.jl +++ b/src/cholesky.jl @@ -176,6 +176,19 @@ end # U = _create_blockdiag(v[first(keys(v))], blocks) # v only for dispatch: plain matrix for gpu # end +""" +Compute the cholesky-factor parameter for a given single +correlation in a 2x2 matrix. +Invert the transformation of cholesky-factor parameterization. 
+""" +function compute_cholcor_coefficient_single(ρ) + # invert ρ = a / sqrt(a^2 + 1) + sign(ρ) * sqrt(ρ^2/(1 - ρ^2)) +end + + + + """ get_ca_starts(vc::ComponentVector) diff --git a/src/elbo.jl b/src/elbo.jl index 8db37c0..66c4ce4 100644 --- a/src/elbo.jl +++ b/src/elbo.jl @@ -411,6 +411,8 @@ function sample_ζresid_norm(urandn::AbstractMatrix, ζP::TP, ζMs::TM, ρsP = isempty(ϕuncc.ρsP) ? similar(ϕuncc.ρsP) : ϕuncc.ρsP # required by zygote UP = transformU_block_cholesky1(ρsP, cor_ends.P) ρsM = isempty(ϕuncc.ρsM) ? similar(ϕuncc.ρsM) : ϕuncc.ρsM # required by zygote + # cholesky factor of the correlation: diag(UM' * UM) .== 1 + # coefficients ρsM can be larger than 1, still yielding correlations <1 in UM' * UM UM = transformU_block_cholesky1(ρsM, cor_ends.M) cf = ϕuncc.coef_logσ2_ζMs logσ2_logMs = vec(cf[1, :] .+ cf[2, :] .* ζMs) diff --git a/test/test_elbo.jl b/test/test_elbo.jl index 4e70253..c5b5a56 100644 --- a/test/test_elbo.jl +++ b/test/test_elbo.jl @@ -130,8 +130,13 @@ test_scenario = (scenario) -> begin ϕunc_true.logσ2_ζP = (log ∘ abs2).(sd_ζP_true) ϕunc_true.coef_logσ2_ζMs[1,:] = (log ∘ abs2).(sd_ζMs_a_true) ϕunc_true.coef_logσ2_ζMs[2,:] = logσ2_ζMs_b_true - ϕunc_true.ρsP = ρsP_true - ϕunc_true.ρsM = ρsM_true + # note that the parameterization contains a transformation that + # here only inverted for the single correlation case + ϕunc_true.ρsP = CP.compute_cholcor_coefficient_single.(ρsP_true) + ϕunc_true.ρsM = CP.compute_cholcor_coefficient_single.(ρsM_true) + # check that ρsM_true = -0.6 recovered with params ϕunc_true.ρsM = -0.75 + UC = CP.transformU_cholesky1(ϕunc_true.ρsM); Σ = UC' * UC + @test Σ[1,2] ≈ ρsM_true[1] probd = CP.update(probc; ϕunc=ϕunc_true); _ϕ = vcat(ϕ_ini.μP, probc.ϕg, probd.ϕunc) @@ -189,12 +194,12 @@ test_scenario = (scenario) -> begin residPMst = vcat(residP, reshape(residMst, size(residMst,1)*size(residMst,2), size(residMst,3))) cor_PMs = cor(residPMst') - @test cor_PMs[1,2] ≈ ρsP_true[1] atol=0.2 - @test all(.≈(cor_PMs[1:2,3:end], 
0.0, atol=0.2)) # no correlations P,M - @test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.2 - @test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.2)) # no correlations M1, M2 - @test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.2 - @test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.2)) # no correlations M1, M2 + @test cor_PMs[1,2] ≈ ρsP_true[1] atol=0.02 + @test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.02)) # no correlations P,M + @test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.02 + @test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.02)) # no correlations M1, M2 + @test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.02 + @test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.02)) # no correlations M1, M2 end test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g) @testset "predict_hvi check sd" begin From 80d0f4d2ca6bdd9aeef32277e2236dd6fc10c2b1 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Mon, 12 May 2025 17:46:26 +0200 Subject: [PATCH 3/3] typos --- src/DoubleMM/f_doubleMM.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/DoubleMM/f_doubleMM.jl b/src/DoubleMM/f_doubleMM.jl index 6915237..01b58b2 100644 --- a/src/DoubleMM/f_doubleMM.jl +++ b/src/DoubleMM/f_doubleMM.jl @@ -181,8 +181,8 @@ end # end # end -# defining the PBmodel as a clousre with let leads to problems of JLD2 reloading -# Define all the varaibles additional to the ones passed curing the call by +# defining the PBmodel as a closure with let leads to problems of JLD2 reloading +# Define all the variables additional to the ones passed during the call by # a dedicated Closure object and define the PBmodel as a callable struct DoubleMMCaller{CLT} cl::CLT