diff --git a/Project.toml b/Project.toml index b459dfccd..75649eafa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "BAT" uuid = "c0cd4b16-88b7-57fa-983b-ab80aecada7e" -version = "4.0.4" +version = "4.0.5" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -73,7 +73,7 @@ HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" MGVI = "fdae7790-d271-4276-880d-f72bbddf129c" NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" -Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" +OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" UltraNest = "6822f173-b0be-4018-9ee2-28bf56348d09" @@ -86,7 +86,7 @@ BATHDF5Ext = "HDF5" BATMGVIExt = "MGVI" BATNestedSamplersExt = "NestedSamplers" BATOptimExt = "Optim" -BATOptimizationExt = ["Optimization"] +BATOptimizationBaseExt = ["OptimizationBase"] BATPlotsExt = "Plots" BATUltraNestExt = "UltraNest" @@ -140,7 +140,7 @@ NamedArrays = "0.9.3, 0.10" NestedSamplers = "0.8" OneTwoMany = "0.1.2" Optim = "1.12" -Optimization = "3, 4" +OptimizationBase = "3.3" PDMats = "0.9, 0.10, 0.11" ParallelProcessingTools = "0.4" Parameters = "0.12, 0.13" diff --git a/docs/src/list_of_algorithms.md b/docs/src/list_of_algorithms.md index 8b419ed13..88ee9957d 100644 --- a/docs/src/list_of_algorithms.md +++ b/docs/src/list_of_algorithms.md @@ -167,7 +167,7 @@ bat_findmode(target, OptimAlg(optalg = Optim.LBFGS())) Requires the [Optim](https://github.com/JuliaNLSolvers/Optim.jl) Julia package to be loaded explicitly. -### Optimization.jl Optimization Algorithms +### OptimizationBase.jl Optimization Algorithms BAT mode finding algorithm type: [`OptimizationAlg`](@ref). @@ -181,7 +181,7 @@ alg = OptimizationAlg(; ) bat_findmode(target, alg) ``` -Requires one of the [Optimization.jl](https://github.com/SciML/Optimization.jl) packages to be loaded explicitly. +Requires the desired package that implements the [OptimizationBase.jl](https://github.com/SciML/OptimizationBase.jl) interface to be loaded (e.g. via `import OptimizationOptimJL`). ### Maximum Sample Estimator diff --git a/ext/BATOptimizationExt.jl b/ext/BATOptimizationBaseExt.jl similarity index 75% rename from ext/BATOptimizationExt.jl rename to ext/BATOptimizationBaseExt.jl index c0198c01e..21b0f32ca 100644 --- a/ext/BATOptimizationExt.jl +++ b/ext/BATOptimizationBaseExt.jl @@ -1,11 +1,11 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -module BATOptimizationExt +module BATOptimizationBaseExt -import Optimization +import OptimizationBase using BAT -BAT.pkgext(::Val{:Optimization}) = BAT.PackageExtension{:Optimization}() +BAT.pkgext(::Val{:OptimizationBase}) = BAT.PackageExtension{:OptimizationBase}() using BAT: MeasureLike, unevaluated using BAT: get_context, get_adselector @@ -18,7 +18,7 @@ using AutoDiffOperators: AbstractADType, NoAutoDiff, reverse_adtype AbstractModeEstimator(optalg::Any) = OptimizationAlg(optalg) Base.convert(::Type{AbstractModeEstimator}, alg::OptimizationAlg) = alg.optalg -BAT.ext_default(::BAT.PackageExtension{:Optimization}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead() +BAT.ext_default(::BAT.PackageExtension{:OptimizationBase}, ::Val{:DEFAULT_OPTALG}) = nothing #Optim.NelderMead() struct _OptimizationTargetFunc{F} <: Function @@ -29,8 +29,8 @@ _OptimizationTargetFunc(::Type{F}) where F = _OptimizationTargetFunc{Type{F}}(F) (ft::_OptimizationTargetFunc)(x, ::Any) = ft.f(x) -build_optimizationfunction(f, ad::AbstractADType) = Optimization.OptimizationFunction(f, ad) -build_optimizationfunction(f, ::NoAutoDiff) = Optimization.OptimizationFunction(f) +build_optimizationfunction(f, ad::AbstractADType) = OptimizationBase.OptimizationFunction(f, ad) +build_optimizationfunction(f, ::NoAutoDiff) = OptimizationBase.OptimizationFunction(f) function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg, context::BATContext) @@ -47,12 +47,12 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg, f_target = _OptimizationTargetFunc(f) ad = reverse_adtype(get_adselector(context)) optimization_function = build_optimizationfunction(f_target, ad) - optimization_problem = Optimization.OptimizationProblem(optimization_function, x_init) + optimization_problem = OptimizationBase.OptimizationProblem(optimization_function, x_init) algopts = (maxiters = algorithm.maxiters, maxtime = algorithm.maxtime, abstol = algorithm.abstol, reltol = algorithm.reltol) # Not all algorithms support abstol, just filter all NaN-valued opts out: filtered_algopts = NamedTuple(filter(p -> !isnan(p[2]), pairs(algopts))) - optimization_result = Optimization.solve(optimization_problem, algorithm.optalg; filtered_algopts..., algorithm.kwargs...) + optimization_result = OptimizationBase.solve(optimization_problem, algorithm.optalg; filtered_algopts..., algorithm.kwargs...) transformed_mode = optimization_result.u result_mode = inv_trafo(transformed_mode) @@ -61,4 +61,4 @@ function BAT.bat_findmode_impl(target::MeasureLike, algorithm::OptimizationAlg, end -end # module BATOptimizationExt +end # module BATOptimizationBaseExt diff --git a/src/extdefs/optimization_defs.jl b/src/extdefs/optimization_defs.jl index fb107955b..c162eb6ab 100644 --- a/src/extdefs/optimization_defs.jl +++ b/src/extdefs/optimization_defs.jl @@ -4,7 +4,7 @@ struct OptimizationAlg Selects an optimization algorithm from the -[Optimization.jl](https://github.com/SciML/Optimization.jl) +[OptimizationBase.jl](https://github.com/SciML/OptimizationBase.jl) package. Note that when using first order algorithms like `OptimizationOptimJL.LBFGS`, your [`BATContext`](@ref) needs to have `ad` set to an automatic differentiation @@ -12,21 +12,21 @@ backend. Constructors: * ```$(FUNCTIONNAME)(; fields...)``` -`optalg` must be an `Optimization.AbstractOptimizer`. +`optalg` must be an `OptimizationBase.AbstractOptimizer`. The field `kwargs` can be used to pass additional keywords to the optimizers -See the [Optimization.jl documentation](https://docs.sciml.ai/Optimization/stable/) for the available keyword arguments. +See the [OptimizationBase.jl documentation](https://docs.sciml.ai/Optimization/stable/) for the available keyword arguments. Fields: $(TYPEDFIELDS) !!! note - This algorithm is only available if the `Optimization` package or any of its submodules, like `OptimizationOptimJL`, is loaded (e.g. via - `import Optimization`). + This algorithm is only available if the `OptimizationBase` package or any of its submodules, like `OptimizationOptimJL`, is loaded (e.g. via + `import OptimizationOptimJL`). """ @with_kw struct OptimizationAlg{ ALG, TR<:AbstractTransformTarget, IA<:InitvalAlgorithm } <: AbstractModeEstimator - optalg::ALG = ext_default(pkgext(Val(:Optimization)), Val(:DEFAULT_OPTALG)) + optalg::ALG = ext_default(pkgext(Val(:OptimizationBase)), Val(:DEFAULT_OPTALG)) pretransform::TR = PriorToNormal() init::IA = InitFromTarget() maxiters::Int64 = 1_000 diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index 53734a673..1b815a93a 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -303,8 +303,8 @@ end function mcmc_update_z_position!!(mc_state::MCMCChainState) f_inv = inverse(mc_state.f_transform) - current_z_new::typeof(mc_state.current.z) = transform_samples(f_inv, mc_state.current.x) - proposed_z_new::typeof(mc_state.proposed.z) = transform_samples(f_inv, mc_state.proposed.x) + current_z_new = _transform_dsv!!(f_inv, mc_state.current.z, mc_state.current.x) + proposed_z_new = _transform_dsv!!(f_inv, mc_state.proposed.z, mc_state.proposed.x) mc_state_new::typeof(mc_state) = @set mc_state.current.z = current_z_new mc_state_new = @set mc_state_new.proposed.z = proposed_z_new diff --git a/src/transforms/trafo_utils.jl b/src/transforms/trafo_utils.jl index 677057403..d1e77cad4 100644 --- a/src/transforms/trafo_utils.jl +++ b/src/transforms/trafo_utils.jl @@ -271,3 +271,17 @@ function _trafo_create_unshaped_ys(f, xs, y_shape::AbstractValueShape) n = length(eachindex(xs)) return nestedview(allocate_array(cpunit, R, (m, n))) end + + +function _transform_dsv!!(f, dsv_y::DensitySampleVector, dsv_x::DensitySampleVector) + xs = dsv_x.v + ys_ladjs = with_logabsdet_jacobian.(f, xs) + + dsv_y.v .= first.(ys_ladjs) + dsv_y.logd .= dsv_x.logd .- getsecond.(ys_ladjs) + dsv_y.weight .= dsv_x.weight + dsv_y.info .= dsv_x.info + dsv_y.aux .= dsv_x.aux + + return dsv_y +end diff --git a/test/Project.toml b/test/Project.toml index 8d6742fbb..21a1b3a54 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,6 +30,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MGVI = "fdae7790-d271-4276-880d-f72bbddf129c" NestedSamplers = "41ceaf6f-1696-4a54-9b49-2e7a9ec3782e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" +OptimizationLBFGSB = "22f7324a-a79d-40f2-bebe-3af60c77bd15" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" ParallelProcessingTools = "8e8a01fc-6193-5ca1-a2f1-20776dae4199" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" @@ -50,4 +51,4 @@ ValueShapes = "136a8f8c-c49b-4edb-8b98-f3d64d48be8f" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -BATTestCases = "0.1" +BATTestCases = "0.2" diff --git a/test/optimization/test_mode_estimators.jl b/test/optimization/test_mode_estimators.jl index c2f8df778..222fde9ca 100644 --- a/test/optimization/test_mode_estimators.jl +++ b/test/optimization/test_mode_estimators.jl @@ -5,7 +5,7 @@ using AutoDiffOperators using LinearAlgebra, Distributions, StatsBase, ValueShapes, Random123, DensityInterface using UnPack, InverseFunctions import ForwardDiff -using Optim, OptimizationOptimJL +using Optim, OptimizationOptimJL, OptimizationLBFGSB @testset "mode_estimators" begin prior = NamedTupleDist( @@ -98,17 +98,17 @@ using Optim, OptimizationOptimJL end - @testset "Optimization.jl - NelderMead" begin + @testset "OptimizationBase.jl" begin context = BATContext(rng = Philox4x((0, 0))) # result is not type-stable: test_findmode(posterior, OptimizationAlg(optalg = OptimizationOptimJL.NelderMead(), pretransform = DoNotTransform()), 0.01, context, inferred = false) context = BATContext(rng = Philox4x((0, 0)), ad = ADSelector(ForwardDiff)) # result is not type-stable: - test_findmode(posterior, OptimizationAlg(optalg = Optimization.LBFGS(), pretransform = DoNotTransform()), 0.01, context, inferred = false) + test_findmode(posterior, OptimizationAlg(optalg = OptimizationLBFGSB.LBFGSB(), pretransform = DoNotTransform()), 0.01, context, inferred = false) end - @testset "Optimization.jl with custom options" begin # checks that options are correctly passed to Optimization.jl + @testset "OptimizationBase.jl with custom options" begin # checks that options are correctly passed to OptimizationBase.jl context = BATContext(rng = Philox4x((0, 0))) optimizer = OptimizationAlg(optalg = OptimizationOptimJL.ParticleSwarm(n_particles=10), maxiters=200, kwargs=(f_calls_limit=500,), pretransform=DoNotTransform())