From 514c8f4a6018e7d5f0934d0bb21e3005418023e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Mon, 8 Dec 2025 14:31:13 +0100 Subject: [PATCH] add softplus transformation --- src/scalar.jl | 17 +++++++++++++++++ test/runtests.jl | 4 ++++ 2 files changed, 21 insertions(+) diff --git a/src/scalar.jl b/src/scalar.jl index 093a36d..424539e 100644 --- a/src/scalar.jl +++ b/src/scalar.jl @@ -154,6 +154,23 @@ inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T)) inverse(::TVNeg, x::Number) = -x inverse_and_logjac(::TVNeg, x::Number) = -x, logjac_zero(LogJac(), typeof(x)) +""" +$(TYPEDEF) + +“Softplus” transformation `x ↦ log(1+exp(x))`. + +!!! NOTE + This is *experimental* and not part of the API yet. +""" +struct TVSoftPlus <: ScalarTransform end + +transform(::TVSoftPlus, x::Real) = log1pexp(x) +transform_and_logjac(t::TVSoftPlus, x::Real) = transform(t, x), -log1pexp(-x) + +inverse_eltype(::TVSoftPlus, ::Type{T}) where T = _ensure_float(T) +inverse(::TVSoftPlus, y::Number) = logexpm1(y) +inverse_and_logjac(::TVSoftPlus, y::Number) = logexpm1(y), -log1mexp(-y) + #### #### composite scalar transforms #### diff --git a/test/runtests.jl b/test/runtests.jl index fa352f9..326a4a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -986,6 +986,10 @@ end @test_throws InexactError inverse(t, fill(Complex(0, 1), 3)) end +@testset "TVSoftPlus" begin + test_transformation(TransformVariables.TVSoftPlus(), y -> y > 0) +end + #### #### static analysis with JET ####