Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,18 +377,32 @@ function transform_with(flag::LogJacFlag, t::CorrCholeskyFactor, x::AbstractVect
UpperTriangular(U), ℓ, index′
end

function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFactor{D,S},
x::AbstractVector{T}, index) where {D,S,T}
# NOTE: add an unrolled version for small sizes
E = _ensure_float(eltype(x))
U = if isbitstype(E)
zero(MMatrix{S,S,E})
else
# NOTE: currently allocating because non-bitstype based AD (eg ReverseDiff) does not work with MMatrix
zeros(E, S, S)
@generated function calculate_corr_cholesky_factor(::Type{T}, flag::LogJacFlag,
t::StaticCorrCholeskyFactor{D,S},
x::AbstractVector, index) where {T,D,S}
exprs = [:(ℓ = logjac_zero(flag, T)), :(z = zero(T))]
u(row, col) = row ≤ col ? Symbol("u_", row, "_", col) : :z
for col in 1:S
push!(exprs, :(log_r = z))
# above diagonal
for row in 1:(col-1)
push!(exprs,
:(($(u(row, col)), log_r, Δℓ) = l2_remainder_transform(flag, x[index], log_r)),
:(ℓ += Δℓ),
:(index += 1))
end
# diagonal
push!(exprs, :($(u(col, col)) = exp(log_r / 2)))
end
U, ℓ, index′ = calculate_corr_cholesky_factor!(U, flag, x, index)
UpperTriangular(SMatrix{S,S}(U)), ℓ, index′
U_elements = (u(row, col) for row in 1:S, col in 1:S)
push!(exprs, :(UpperTriangular(SMatrix{$S,$S}($(U_elements...))), ℓ, index))
Expr(:block, exprs...)
end

function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFactor{D,S},
x::AbstractVector, index) where {D,S}
T = _ensure_float(eltype(x))
calculate_corr_cholesky_factor(T, flag, transformation, x, index)
end

function inverse_eltype(t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor},
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,11 @@ end
@test inverse(t, y) ≈ x
end
end

# allocations
t7 = corr_cholesky_factor(SMatrix{7,7})
z7 = zeros(dimension(t7))
@test @allocations(transform(t7, z7)) == 0
end

@testset "corr cholesky factor large inputs" begin
Expand Down
Loading