diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index a19bfb9..d53b449 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -84,6 +84,8 @@ function set_runtime_activity2( a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA} Enzyme.set_runtime_activity(a, RTA) end +function_annotation(::Nothing) = Nothing +function_annotation(::AutoEnzyme{<:Any, A}) where A = A function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false, @@ -101,6 +103,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, set_runtime_activity2(Enzyme.Forward, adtype.mode) end + func_annot = function_annotation(adtype) + if g == true && f.grad === nothing function grad(res, θ, p = p) Enzyme.make_zero!(res) @@ -217,6 +221,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, # if num_cons > length(x) seeds = Enzyme.onehot(x) Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) + basefunc = f.cons + if func_annot <: Enzyme.Const + basefunc = Enzyme.Const(basefunc) + elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated + basefunc = Enzyme.BatchDuplicated(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x))) + elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed + basefunc = Enzyme.BatchDuplicatedNoNeed(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x))) + end # else # seeds = Enzyme.onehot(zeros(eltype(x), num_cons)) # Jaccache = Tuple(zero(x) for i in 1:num_cons) @@ -225,11 +237,16 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, y = zeros(eltype(x), num_cons) function cons_j!(J, θ) - for i in eachindex(Jaccache) - Enzyme.make_zero!(Jaccache[i]) + for jc in Jaccache + Enzyme.make_zero!(jc) end Enzyme.make_zero!(y) - Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache), + if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed + for bf in basefunc.dval + Enzyme.make_zero!(bf) + end + end + Enzyme.autodiff(fmode, basefunc , BatchDuplicated(y, Jaccache), BatchDuplicated(θ, seeds), Const(p)) for i in eachindex(θ) if J isa Vector