@@ -71,22 +71,24 @@ function __solve(prob::OptimizationProblem, opt, data = DEFAULT_DATA;
7171 # here should be build_solution to create the output message
7272end
7373
74- function Flux. update! (x :: AbstractArray , x̄ :: AbstractArray{<:ForwardDiff.Dual} )
75- x .- = x̄
74+ function Flux. update! (opt, xs :: Flux.Zygote.Params , gs )
75+ update! (opt, xs[ 1 ], gs)
7676end
7777
78- function Flux. update! (x:: AbstractArray , x̄)
79- x .- = getindex .(ForwardDiff. partials .(x̄),1 )
80- end
78+ @require ForwardDiff= " f6369f11-7733-5829-9624-2563aa707210" begin
79+ function Flux. update! (x:: AbstractArray , x̄:: AbstractArray{<:ForwardDiff.Dual} )
80+ x .- = x̄
81+ end
8182
82- function Flux. update! (opt, x , x̄)
83- x .- = Flux . Optimise . apply! (opt, x, x̄ )
84- end
83+ function Flux. update! (x :: AbstractArray , x̄)
84+ x .- = getindex .(ForwardDiff . partials .(x̄), 1 )
85+ end
8586
86- function Flux. update! (opt, x, x̄:: AbstractArray{<:ForwardDiff.Dual} )
87- x .- = Flux. Optimise. apply! (opt, x, getindex .(ForwardDiff . partials .(x̄), 1 ) )
88- end
87+ function Flux. update! (opt, x, x̄)
88+ x .- = Flux. Optimise. apply! (opt, x, x̄ )
89+ end
8990
90- function Flux. update! (opt, xs:: Flux.Zygote.Params , gs)
91- update! (opt, xs[1 ], gs)
91+ function Flux. update! (opt, x, x̄:: AbstractArray{<:ForwardDiff.Dual} )
92+ x .- = Flux. Optimise. apply! (opt, x, getindex .(ForwardDiff. partials .(x̄),1 ))
93+ end
9294end
0 commit comments