Skip to content
Merged

Dev #23

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions dev/doubleMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,13 @@ prob2o = probo;
fname_probos = "intermediate/probos800_$(last(HVI._val_value(scenario))).jld2"
JLD2.save(fname_probos, Dict("prob1o" => prob1o, "prob2o" => prob2o))
tmp = JLD2.load(fname_probos)
# TODO replace function closure by Callable to store
# closure function could not be restored with JLD2
prob1o = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader);
prob2o = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader);
end

() -> begin # load the non-covar scenario
using JLD2
#fname_probos = "intermediate/probos_$(last(_val_value(scenario))).jld2"
fname_probos = "intermediate/probos800_omit_r0.jld2"
tmp = JLD2.load(fname_probos)
# get_train_loader function could not be restored with JLD2
prob1o_indep = HVI.update(tmp["prob1o"], get_train_loader = prob0.get_train_loader);
prob2o_indep = HVI.update(tmp["prob2o"], get_train_loader = prob0.get_train_loader);
# test predicting correct obs-uncertainty of predictive posterior
n_sample_pred = 400
(; θ, y, entropy_ζ) = predict_hvi(rng, prob2o_indep, xM, xP; scenario, n_sample_pred);
Expand Down
52 changes: 50 additions & 2 deletions src/AbstractHybridProblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ function gen_hybridproblem_synthetic end

Determine the FloatType for given Case and scenario, defaults to Float32
"""
function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario = ())
function get_hybridproblem_float_type(prob::AbstractHybridProblem; scenario)
return eltype(get_hybridproblem_par_templates(prob; scenario).θM)
end

Expand Down Expand Up @@ -263,4 +263,52 @@ function setup_PBMpar_interpreter(θP, θM, θall = vcat(θP, θM))
θFix = θall[keys_fixed]
intθ = ComponentArrayInterpreter(flatten1(CA.ComponentVector(; θP, θM, θFix)))
intθ, θFix
end
end

struct PBmodelClosure{θFixT, θFix_devT, AX, pos_xPT}
θFix::θFixT
θFix_dev::θFix_devT
intθ::StaticComponentArrayInterpreter{AX}
isP::Matrix{Int}
n_site_batch::Int
pos_xP::pos_xPT
end

function PBmodelClosure(prob::AbstractHybridProblem; scenario::Val{scen},
use_all_sites = false,
gdev = :f_on_gpu ∈ _val_value(scenario) ? gpu_device() : identity,
θall, int_xP1,
) where {scen}
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
n_site_batch = use_all_sites ? n_site : n_batch
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
par_templates = get_hybridproblem_par_templates(prob; scenario)
intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
θFix = repeat(θFix1', n_site_batch)
θFix_dev = gdev(θFix)
intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1))
#int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1)
isP = repeat(axes(par_templates.θP, 1)', n_site_batch)
pos_xP = get_positions(int_xP1)
PBmodelClosure(;θFix, θFix_dev, intθ, isP, n_site_batch, pos_xP)
end

function PBmodelClosure(;
θFix::θFixT,
θFix_dev::θFix_devT,
intθ::StaticComponentArrayInterpreter{AX},
isP::Matrix{Int},
n_site_batch::Int,
pos_xP::pos_xPT,
) where {θFixT, θFix_devT, AX, pos_xPT}
PBmodelClosure{θFixT, θFix_devT, AX, pos_xPT}(
θFix::AbstractArray, θFix_dev, intθ, isP, n_site_batch, pos_xP)
end








87 changes: 37 additions & 50 deletions src/DoubleMM/f_doubleMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,55 +181,43 @@ end
# end
# end

function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario::Val{scen},
use_all_sites = false,
gdev = :f_on_gpu ∈ HVI._val_value(scenario) ? gpu_device() : identity
) where {scen}
n_site, n_batch = get_hybridproblem_n_site_and_batch(prob; scenario)
n_site_batch = use_all_sites ? n_site : n_batch
#fsite = (θ, x_site) -> f_doubleMM(θ) # omit x_site drivers
par_templates = get_hybridproblem_par_templates(prob; scenario)
intθ1, θFix1 = setup_PBMpar_interpreter(par_templates.θP, par_templates.θM, θall)
θFix = repeat(θFix1', n_site_batch)
intθ = get_concrete(ComponentArrayInterpreter((n_site_batch,), intθ1))
#int_xPb = ComponentArrayInterpreter((n_site_batch,), int_xP1)
isP = repeat(axes(par_templates.θP, 1)', n_site_batch)
let θFix = θFix, θFix_dev = gdev(θFix), intθ = get_concrete(intθ), isP = isP,
n_site_batch = n_site_batch,
#int_xPb=get_concrete(int_xPb),
pos_xP = get_positions(int_xP1)
# defining the PBmodel as a closure with let leads to problems of JLD2 reloading
# Define all the variables additional to the ones passed curing the call by
# a dedicated Closure object and define the PBmodel as a callable
struct DoubleMMCaller{CLT}
cl::CLT
end

function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
@assert size(xP, 2) == n_site_batch
@assert size(θMs, 1) == n_site_batch
# # convert vector of tuples to tuple of matricesByRows
# # need to supply xP as vectorOfTuples to work with DataLoader
# # k = first(keys(xP[1]))
# xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
# #stack(map(r -> r[k], xP))'
# stack(map(r -> r[k], xP); dims = 1)
# end)...)
#xPM = map(transpose, xPM1)
#xPc = int_xPb(CA.getdata(xP))
#xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote
# make sure the same order of columns as in intθ
# reshape big matrix into NamedTuple of drivers S1 and S2
# for broadcasting need sites in rows
#xPM = map(p -> CA.getdata(xP[p,:])', pos_xP)
xPM = map(p -> CA.getdata(xP)'[:, p], pos_xP)
θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? θFix_dev : θFix
θ = hcat(CA.getdata(θP[isP]), CA.getdata(θMs), θFixd)
pred_sites = f_doubleMM(θ, xPM; intθ)'
pred_global = eltype(pred_sites)[]
return pred_global, pred_sites
end
# function f_doubleMM_with_global(θP::AbstractVector, θMs::AbstractMatrix, xP)
# # TODO
# pred_sites = f_doubleMM(θMs, θP, θFix, xP, intθ)
# pred_global = eltype(pred_sites)[]
# return pred_global, pred_sites
# end
end
function HVI.get_hybridproblem_PBmodel(prob::DoubleMMCase; scenario, kwargs...)
# θall defined in this module above
cl = HVI.PBmodelClosure(prob; scenario, θall, int_xP1, kwargs...)
return DoubleMMCaller{typeof(cl)}(cl)
end

function(caller::DoubleMMCaller)(θP::AbstractVector, θMs::AbstractMatrix, xP)
cl = caller.cl
@assert size(xP, 2) == cl.n_site_batch
@assert size(θMs, 1) == cl.n_site_batch
# # convert vector of tuples to tuple of matricesByRows
# # need to supply xP as vectorOfTuples to work with DataLoader
# # k = first(keys(xP[1]))
# xPM = (; zip(keys(xP[1]), map(keys(xP[1])) do k
# #stack(map(r -> r[k], xP))'
# stack(map(r -> r[k], xP); dims = 1)
# end)...)
#xPM = map(transpose, xPM1)
#xPc = int_xPb(CA.getdata(xP))
#xPM = (S1 = xPc[:,:S1], S2 = xPc[:,:S2]) # problems with Zygote
# make sure the same order of columns as in intθ
# reshape big matrix into NamedTuple of drivers S1 and S2
# for broadcasting need sites in rows
#xPM = map(p -> CA.getdata(xP[p,:])', pos_xP)
xPM = map(p -> CA.getdata(xP)'[:, p], cl.pos_xP)
θFixd = (θP isa GPUArraysCore.AbstractGPUVector) ? cl.θFix_dev : cl.θFix
θ = hcat(CA.getdata(θP[cl.isP]), CA.getdata(θMs), θFixd)
pred_sites = f_doubleMM(θ, xPM; cl.intθ)'
pred_global = eltype(pred_sites)[]
return pred_global, pred_sites
end

function HVI.get_hybridproblem_neg_logden_obs(::DoubleMMCase; scenario::Val)
Expand Down Expand Up @@ -284,8 +272,7 @@ function HVI.gen_hybridproblem_synthetic(rng::AbstractRNG, prob::DoubleMMCase;
xP = int_xP_sites(vcat(repeat(xP_S1, 1, n_site), repeat(xP_S2, 1, n_site)))
#xP[:S1,:]
θP = par_templates.θP
#θint = ComponentArrayInterpreter( (size(θMs_true,2),), CA.getaxes(vcat(θP, θMs_true[:,1])))
y_global_true, y_true = f(θP, θMs_true', xP)
y_global_true, y_true = f(θP, θMs_true', xP)
σ_o = FloatType(0.01)
#σ_o = FloatType(0.002)
logσ2_o = FloatType(2) .* log.(σ_o)
Expand Down
12 changes: 6 additions & 6 deletions src/HybridProblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ struct HybridProblem <: AbstractHybridProblem
θP::CA.ComponentVector, θM::CA.ComponentVector,
g::AbstractModelApplicator, ϕg::AbstractVector,
ϕunc::CA.ComponentVector,
f_batch::Function,
f_allsites::Function,
f_batch,
f_allsites,
priors::AbstractDict,
py,
transM::Stacked,
Expand All @@ -43,7 +43,7 @@ end

function HybridProblem(θP::CA.ComponentVector, θM::CA.ComponentVector,
# note no ϕg argument and g_chain unconstrained
g_chain, f_batch::Function,
g_chain, f_batch,
args...; rng = Random.default_rng(), kwargs...)
# dispatches on type of g_chain
g, ϕg = construct_ChainsApplicator(rng, g_chain, eltype(θM))
Expand Down Expand Up @@ -74,10 +74,10 @@ function update(prob::HybridProblem;
g::AbstractModelApplicator = prob.g,
ϕg::AbstractVector = prob.ϕg,
ϕunc::CA.ComponentVector = prob.ϕunc,
f_batch::Function = prob.f_batch,
f_allsites::Function = prob.f_allsites,
f_batch = prob.f_batch,
f_allsites = prob.f_allsites,
priors::AbstractDict = prob.priors,
py::Function = prob.py,
py = prob.py,
# transM::Union{Function, Bijectors.Transform} = prob.transM,
# transP::Union{Function, Bijectors.Transform} = prob.transP,
transM = prob.transM,
Expand Down
13 changes: 13 additions & 0 deletions src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,19 @@ end
# U = _create_blockdiag(v[first(keys(v))], blocks) # v only for dispatch: plain matrix for gpu
# end

"""
Compute the cholesky-factor parameter for a given single
correlation in a 2x2 matrix.
Invert the transformation of cholesky-factor parameterization.
"""
function compute_cholcor_coefficient_single(ρ)
# invert ρ = a / sqrt(a^2 + 1)
sign(ρ) * sqrt(ρ^2/(1 - ρ^2))
end




"""
get_ca_starts(vc::ComponentVector)

Expand Down
2 changes: 2 additions & 0 deletions src/elbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ function sample_ζresid_norm(urandn::AbstractMatrix, ζP::TP, ζMs::TM,
ρsP = isempty(ϕuncc.ρsP) ? similar(ϕuncc.ρsP) : ϕuncc.ρsP # required by zygote
UP = transformU_block_cholesky1(ρsP, cor_ends.P)
ρsM = isempty(ϕuncc.ρsM) ? similar(ϕuncc.ρsM) : ϕuncc.ρsM # required by zygote
# cholesky factor of the correlation: diag(UM' * UM) .== 1
# coefficients ρsM can be larger than 1, still yielding correlations <1 in UM' * UM
UM = transformU_block_cholesky1(ρsM, cor_ends.M)
cf = ϕuncc.coef_logσ2_ζMs
logσ2_logMs = vec(cf[1, :] .+ cf[2, :] .* ζMs)
Expand Down
1 change: 1 addition & 0 deletions test/test_doubleMM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ prob = DoubleMM.DoubleMMCase()
scenario = Val((:default,))
#using Flux
#scenario = Val((:use_Flux,))
#scenario = Val((:use_Flux,:f_on_gpu))

par_templates = get_hybridproblem_par_templates(prob; scenario)

Expand Down
21 changes: 13 additions & 8 deletions test/test_elbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,13 @@ test_scenario = (scenario) -> begin
ϕunc_true.logσ2_ζP = (log ∘ abs2).(sd_ζP_true)
ϕunc_true.coef_logσ2_ζMs[1,:] = (log ∘ abs2).(sd_ζMs_a_true)
ϕunc_true.coef_logσ2_ζMs[2,:] = logσ2_ζMs_b_true
ϕunc_true.ρsP = ρsP_true
ϕunc_true.ρsM = ρsM_true
# note that the parameterization contains a transformation that
# here only inverted for the single correlation case
ϕunc_true.ρsP = CP.compute_cholcor_coefficient_single.(ρsP_true)
ϕunc_true.ρsM = CP.compute_cholcor_coefficient_single.(ρsM_true)
# check that ρsM_true = -0.6 recovered with params ϕunc_true.ρsM = -0.75
UC = CP.transformU_cholesky1(ϕunc_true.ρsM); Σ = UC' * UC
@test Σ[1,2] ≈ ρsM_true[1]

probd = CP.update(probc; ϕunc=ϕunc_true);
_ϕ = vcat(ϕ_ini.μP, probc.ϕg, probd.ϕunc)
Expand Down Expand Up @@ -189,12 +194,12 @@ test_scenario = (scenario) -> begin
residPMst = vcat(residP,
reshape(residMst, size(residMst,1)*size(residMst,2), size(residMst,3)))
cor_PMs = cor(residPMst')
@test cor_PMs[1,2] ≈ ρsP_true[1] atol=0.2
@test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.2)) # no correlations P,M
@test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.2
@test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.2)) # no correlations M1, M2
@test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.2
@test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.2)) # no correlations M1, M2
@test cor_PMs[1,2] ≈ ρsP_true[1] atol=0.02
@test all(.≈(cor_PMs[1:2,3:end], 0.0, atol=0.02)) # no correlations P,M
@test cor_PMs[3,4] ≈ ρsM_true[1] atol=0.02
@test all(.≈(cor_PMs[3:4,5:end], 0.0, atol=0.02)) # no correlations M1, M2
@test cor_PMs[5,6] ≈ ρsM_true[1] atol=0.02
@test all(.≈(cor_PMs[5:6,7:end], 0.0, atol=0.02)) # no correlations M1, M2
end
test_distζ(_ζsP, _ζsMs, ϕunc_true, ζMs_g)
@testset "predict_hvi check sd" begin
Expand Down
Loading