diff --git a/src/scalar.jl b/src/scalar.jl index 093a36d..424539e 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -154,6 +154,23 @@ inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T)) inverse(::TVNeg, x::Number) = -x inverse_and_logjac(::TVNeg, x::Number) = -x, logjac_zero(LogJac(), typeof(x)) +""" +$(TYPEDEF) + +“Softplus” transformation `x ↦ log(1+exp(x))`. + +!!! NOTE + This is *experimental* and not part of the API yet. +""" +struct TVSoftPlus <: ScalarTransform end + +transform(::TVSoftPlus, x::Real) = log1pexp(x) +transform_and_logjac(t::TVSoftPlus, x::Real) = transform(t, x), -log1pexp(-x) + +inverse_eltype(::TVSoftPlus, ::Type{T}) where T = _ensure_float(T) +inverse(::TVSoftPlus, y::Number) = logexpm1(y) +inverse_and_logjac(::TVSoftPlus, y::Number) = logexpm1(y), -log1mexp(-y) + #### #### composite scalar transforms #### diff --git a/test/runtests.jl b/test/runtests.jl index fa352f9..326a4a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -986,6 +986,10 @@ end @test_throws InexactError inverse(t, fill(Complex(0, 1), 3)) end +@testset "TVSoftPlus" begin + test_transformation(TransformVariables.TVSoftPlus(), y -> y > 0) +end + #### #### static analysis with JET ####