diff --git a/src/special_arrays.jl b/src/special_arrays.jl index d8b9ac6..552ca14 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -377,18 +377,32 @@ function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, x::AbstractVect UpperTriangular(U), ℓ, index′ end -function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFactor{D,S}, - x::AbstractVector{T}, index) where {D,S,T} - # NOTE: add an unrolled version for small sizes - E = _ensure_float(eltype(x)) - U = if isbitstype(E) - zero(MMatrix{S,S,E}) - else - # NOTE: currently allocating because non-bitstype based AD (eg ReverseDiff) does not work with MMatrix - zeros(E, S, S) +@generated function calculate_corr_cholesky_factor(::Type{T}, flag::LogJacFlag, + t::StaticCorrCholeskyFactor{D,S}, + x::AbstractVector, index) where {T,D,S} + exprs = [:(ℓ = logjac_zero(flag, T)), :(z = zero(T))] + u(row, col) = row ≤ col ? Symbol("u_", row, "_", col) : :z + for col in 1:S + push!(exprs, :(log_r = z)) + # above diagonal + for row in 1:(col-1) + push!(exprs, + :(($(u(row, col)), log_r, Δℓ) = l2_remainder_transform(flag, x[index], log_r)), + :(ℓ += Δℓ), + :(index += 1)) + end + # diagonal + push!(exprs, :($(u(col, col)) = exp(log_r / 2))) end - U, ℓ, index′ = calculate_corr_cholesky_factor!(U, flag, x, index) - UpperTriangular(SMatrix{S,S}(U)), ℓ, index′ + U_elements = (u(row, col) for row in 1:S, col in 1:S) + push!(exprs, :(UpperTriangular(SMatrix{$S,$S}($(U_elements...))), ℓ, index)) + Expr(:block, exprs...) +end + +function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFactor{D,S}, + x::AbstractVector, index) where {D,S} + T = _ensure_float(eltype(x)) + calculate_corr_cholesky_factor(T, flag, transformation, x, index) end function inverse_eltype(t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor}, diff --git a/test/runtests.jl b/test/runtests.jl index fa352f9..60a55e1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -819,6 +819,11 @@ end @test inverse(t, y) ≈ x end end + + # allocations + t7 = corr_cholesky_factor(SMatrix{7,7}) + z7 = zeros(dimension(t7)) + @test @allocations(transform(t7, z7)) == 0 end @testset "corr cholesky factor large inputs" begin