From 75fc2a33eba6220a07a5717e673f4870102321fd Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 26 May 2025 18:02:27 -0500 Subject: [PATCH 1/6] Enzyme: add func_annotation --- ext/OptimizationEnzymeExt.jl | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index a19bfb9..669fa11 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -101,6 +101,12 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, set_runtime_activity2(Enzyme.Forward, adtype.mode) end + func_annot = if adtype.mode isa Nothing + Nothing + else + adtype.mode.function_annotation + end + if g == true && f.grad === nothing function grad(res, θ, p = p) Enzyme.make_zero!(res) @@ -217,6 +223,14 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, # if num_cons > length(x) seeds = Enzyme.onehot(x) Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x)) + basefunc = f.cons + if func_annot <: Enzyme.Const + basefunc = Enzyme.Const(basefunc) + elseif func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated + basefunc = Enzyme.BatchDuplicated(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x))) + elseif func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed + basefunc = Enzyme.BatchDuplicatedNoNeed(basefunc, Tuple(make_zero(basefunc) for i in 1:length(x))) + end # else # seeds = Enzyme.onehot(zeros(eltype(x), num_cons)) # Jaccache = Tuple(zero(x) for i in 1:num_cons) @@ -225,11 +239,16 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, y = zeros(eltype(x), num_cons) function cons_j!(J, θ) - for i in eachindex(Jaccache) - Enzyme.make_zero!(Jaccache[i]) + for jc in Jaccache + Enzyme.make_zero!(jc) end Enzyme.make_zero!(y) - Enzyme.autodiff(fmode, f.cons, BatchDuplicated(y, Jaccache), + if func_annot <: Enzyme.Duplicated || func_annot <: Enzyme.BatchDuplicated || func_annot <: Enzyme.DuplicatedNoNeed || func_annot <: Enzyme.BatchDuplicatedNoNeed + for bf in basefunc.dval + Enzyme.make_zero!(bf) + end + end + Enzyme.autodiff(fmode, basfunc, BatchDuplicated(y, Jaccache), BatchDuplicated(θ, seeds), Const(p)) for i in eachindex(θ) if J isa Vector From 1607b88623d0a64484e968a14313a26e1c630f78 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 26 May 2025 21:18:39 -0500 Subject: [PATCH 2/6] Update OptimizationEnzymeExt.jl --- ext/OptimizationEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 669fa11..b6dfb81 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -104,7 +104,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, func_annot = if adtype.mode isa Nothing Nothing else - adtype.mode.function_annotation + adtype.function_annotation end if g == true && f.grad === nothing From 9130204f729a55249676906e80df7bede328dcf7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 27 May 2025 00:51:23 -0500 Subject: [PATCH 3/6] Update OptimizationEnzymeExt.jl --- ext/OptimizationEnzymeExt.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index b6dfb81..15a246a 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -84,6 +84,8 @@ function set_runtime_activity2( a::Mode1, ::Enzyme.Mode{ABI, Err, RTA}) where {Mode1, ABI, Err, RTA} Enzyme.set_runtime_activity(a, RTA) end +function_annotation(::Nothing) = Nothing +function_annotation(::AutoEnzyme{<:Any, A}) = A function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false, @@ -101,11 +103,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, set_runtime_activity2(Enzyme.Forward, adtype.mode) end - func_annot = if adtype.mode isa Nothing - Nothing - else - adtype.function_annotation - end + func_annot = function_annotation(adtype.function_annotation) if g == true && f.grad === nothing function grad(res, θ, p = p) From c9b6dcaa0fc3ec676b91bc4ed0d84a17aa67ef65 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 28 May 2025 01:01:24 +0000 Subject: [PATCH 4/6] Update ext/OptimizationEnzymeExt.jl --- ext/OptimizationEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 15a246a..4b39534 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -246,7 +246,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, Enzyme.make_zero!(bf) end end - Enzyme.autodiff(fmode, basfunc, BatchDuplicated(y, Jaccache), + Enzyme.autodiff(fmode, basefunc , BatchDuplicated(y, Jaccache), BatchDuplicated(θ, seeds), Const(p)) for i in eachindex(θ) if J isa Vector From a248217855d2cf6bb59864001d19963c510640e3 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 29 May 2025 01:51:17 +0000 Subject: [PATCH 5/6] Update ext/OptimizationEnzymeExt.jl --- ext/OptimizationEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 4b39534..e47cc94 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -85,7 +85,7 @@ function set_runtime_activity2( Enzyme.set_runtime_activity(a, RTA) end function_annotation(::Nothing) = Nothing -function_annotation(::AutoEnzyme{<:Any, A}) = A +function_annotation(::AutoEnzyme{<:Any, A}) where A = A function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoEnzyme, p, num_cons = 0; g = false, h = false, hv = false, fg = false, fgh = false, From d11da6c467c290679c96b659bf3ae6da92b710eb Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 29 May 2025 02:17:20 +0000 Subject: [PATCH 6/6] Update ext/OptimizationEnzymeExt.jl --- ext/OptimizationEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index e47cc94..d53b449 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -103,7 +103,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, set_runtime_activity2(Enzyme.Forward, adtype.mode) end - func_annot = function_annotation(adtype.function_annotation) + func_annot = function_annotation(adtype) if g == true && f.grad === nothing function grad(res, θ, p = p)