Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ function set_runtime_activity2(
a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA}
Enzyme.set_runtime_activity(a, RTA)
end
function_annotation(::Nothing) = Nothing
function_annotation(::AutoEnzyme{<:Any, A}) where A = A
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
adtype::AutoEnzyme, p, num_cons = 0;
g = false, h = false, hv = false, fg = false, fgh = false,
Expand All @@ -101,6 +103,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
set_runtime_activity2(Enzyme.Forward, adtype.mode)
end

func_annot = function_annotation(adtype)

if g == true && f.grad === nothing
function grad(res, θ, p = p)
Enzyme.make_zero!(res)
Expand Down Expand Up @@ -217,6 +221,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
# if num_cons > length(x)
seeds = Enzyme.onehot(x)
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
basefunc = f.cons
if func_annot <: Enzyme.Const
basefunc = Enzyme.Const(basefunc)
elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated
basefunc = Enzyme.BatchDuplicated(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x)))
elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
basefunc = Enzyme.BatchDuplicatedNoNeed(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x)))
end
# else
# seeds = Enzyme.onehot(zeros(eltype(x), num_cons))
# Jaccache = Tuple(zero(x) for i in 1:num_cons)
Expand All @@ -225,11 +237,16 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
y = zeros(eltype(x), num_cons)

function cons_j!(J, θ)
for i in eachindex(Jaccache)
Enzyme.make_zero!(Jaccache[i])
for jc in Jaccache
Enzyme.make_zero!(jc)
end
Enzyme.make_zero!(y)
Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache),
if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed
for bf in basefunc.dval
Enzyme.make_zero!(bf)
end
end
Enzyme.autodiff(fmode, basefunc , BatchDuplicated(y, Jaccache),
BatchDuplicated(θ, seeds), Const(p))
for i in eachindex(θ)
if J isa Vector
Expand Down
Loading