From d593c41d01b2463ecbce6f363e113b9e63cd9034 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Apr 2022 15:09:53 -0400 Subject: [PATCH 01/76] move towards efl --- .github/workflows/CI.yml | 2 +- Project.toml | 18 +- src/FastDEQ.jl | 62 ++-- src/adjoint.jl | 35 ++ src/layers/chain.jl | 51 +++ src/layers/core.jl | 55 +++ src/layers/deq.jl | 124 +++---- src/layers/jacobian_stabilization.jl | 38 +- src/layers/mdeq.jl | 226 ++++++------ src/layers/sdeq.jl | 130 ------- src/layers/smdeq.jl | 235 ------------- src/layers/utils.jl | 76 ---- src/models/basics.jl | 46 --- src/models/chain.jl | 61 ---- src/operator.jl | 25 ++ src/solve.jl | 324 +----------------- src/solvers/continuous.jl | 147 ++++++++ src/solvers/discrete.jl | 27 ++ src/solvers/{ => discrete}/broyden.jl | 0 .../{ => discrete}/limited_memory_broyden.jl | 0 src/utils.jl | 74 +--- test/runtests.jl | 302 ++++++++-------- 22 files changed, 753 insertions(+), 1305 deletions(-) create mode 100644 src/adjoint.jl create mode 100644 src/layers/chain.jl create mode 100644 src/layers/core.jl delete mode 100644 src/layers/sdeq.jl delete mode 100644 src/layers/smdeq.jl delete mode 100644 src/layers/utils.jl delete mode 100644 src/models/basics.jl delete mode 100644 src/models/chain.jl create mode 100644 src/operator.jl create mode 100644 src/solvers/continuous.jl create mode 100644 src/solvers/discrete.jl rename src/solvers/{ => discrete}/broyden.jl (100%) rename src/solvers/{ => discrete}/limited_memory_broyden.jl (100%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b85bfce8..286a86cd 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -29,7 +29,7 @@ jobs: ${{ runner.os }}-test- ${{ runner.os }}- - name: Install dependencies - run: julia --project=. -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/FluxExperimental.jl"); Pkg.add(url="https://github.com/SciML/DiffEqSensitivity.jl", rev="ap/fastdeq"); Pkg.instantiate()' + run: julia --project=. 
-e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/ExplicitFluxLayers.jl", rev="ap/sparse"); Pkg.instantiate()' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 with: diff --git a/Project.toml b/Project.toml index f133446f..02a925ea 100644 --- a/Project.toml +++ b/Project.toml @@ -5,19 +5,22 @@ version = "0.1.0" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" +ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -FluxExperimental = "c0d22e4d-7f3e-44a4-9c97-37045f84daf2" -FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" @@ -25,14 +28,17 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] CUDA = "3" +ChainRulesCore = "1" DiffEqBase = "6" DiffEqCallbacks = "2.20.1" DiffEqSensitivity = "6.64" -Flux = "0.12" -FluxMPI = "0.1.1" +ExplicitFluxLayers = "0.2" +Flux = "0.13" +Functors = "0.2" LinearSolve = "1" OrdinaryDiffEq = "6" SciMLBase = "1.19" +Setfield = "0.8.2" SteadyStateDiffEq = "1.6" UnPack = "1" Zygote = "0.6.34" @@ -41,10 +47,10 @@ julia = "1.7" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -FluxExperimental = "c0d22e4d-7f3e-44a4-9c97-37045f84daf2" +ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CUDA", "Flux", "FluxExperimental", "LinearAlgebra", "Random", "Test"] \ No newline at end of file +test = ["CUDA", "Flux", "ExplicitFluxLayers", "LinearAlgebra", "Random", "Test"] diff --git a/src/FastDEQ.jl b/src/FastDEQ.jl index 8830f9e7..3d2b538a 100644 --- a/src/FastDEQ.jl +++ b/src/FastDEQ.jl @@ -1,48 +1,50 @@ module FastDEQ -using CUDA, DiffEqBase, DiffEqCallbacks, DiffEqSensitivity, Flux, FluxExperimental, LinearAlgebra, LinearSolve, - OrdinaryDiffEq, SciMLBase, Statistics, SteadyStateDiffEq, UnPack, Zygote - -abstract type AbstractDeepEquilibriumNetwork end - -function Base.show(io::IO, l::AbstractDeepEquilibriumNetwork) - return print(io, string(typeof(l).name.name), "() ", string(length(l.p)), " Trainable Parameters") -end - -Flux.trainable(d::AbstractDeepEquilibriumNetwork) = (d.p,) - -Base.deepcopy(op::DiffEqSensitivity.ZygotePullbackMultiplyOperator) = op +using CUDA, + DiffEqBase, + DiffEqCallbacks, + DiffEqSensitivity, + Flux, + LinearAlgebra, + LinearSolve, + OrdinaryDiffEq, + SciMLBase, + Statistics, + SteadyStateDiffEq, + UnPack, + Zygote, + ExplicitFluxLayers, + Functors, + ChainRulesCore, + Setfield + +import ExplicitFluxLayers: + AbstractExplicitLayer, 
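    # (these EFL methods are extended in src/layers/core.jl so the DEQ layer types plug into the explicit-layer API)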
initialparameters, initialstates, createcache, parameterlength, statelength, cachesize +import Random: AbstractRNG + +include("operator.jl") + +include("solvers/continuous.jl") +include("solvers/discrete.jl") include("solve.jl") include("utils.jl") -include("solvers/broyden.jl") -include("solvers/limited_memory_broyden.jl") - -include("models/basics.jl") - +include("layers/core.jl") include("layers/jacobian_stabilization.jl") -include("layers/utils.jl") include("layers/deq.jl") -include("layers/sdeq.jl") include("layers/mdeq.jl") -include("layers/smdeq.jl") +include("layers/chain.jl") -include("models/chain.jl") - -include("losses.jl") +include("adjoint.jl") # DEQ Solvers export ContinuousDEQSolver, DiscreteDEQSolver, BroydenSolver, LimitedMemoryBroydenSolver # Utils -export NormalInitializer, SteadyStateAdjoint, get_and_clear_nfe!, compute_deq_jacobian_loss, DeepEquilibriumSolution, SupervisedLossContainer - -# Layers -export MultiParallelNet +export NormalInitializer, SteadyStateAdjoint, compute_deq_jacobian_loss, DeepEquilibriumSolution -# DEQ Layers -export DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork -export DEQChain +export DeepEquilibriumNetwork, + SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork, DEQChain end diff --git a/src/adjoint.jl b/src/adjoint.jl new file mode 100644 index 00000000..3c31ac78 --- /dev/null +++ b/src/adjoint.jl @@ -0,0 +1,35 @@ +neg(x) = -x +neg(::Nothing) = nothing + +@noinline function DiffEqSensitivity.SteadyStateAdjointProblem( + sol::EquilibriumSolution, sensealg::DiffEqSensitivity.SteadyStateAdjoint, g::Nothing, dg; save_idxs=nothing +) + @unpack f, p, u0 = sol.prob + + diffcache, y = DiffEqSensitivity.adjointdiffcache(g, sensealg, false, sol, dg, f; quad=false, needs_jac=false) + + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + if dg !== nothing + if typeof(_save_idxs) <: Number + diffcache.dg_val[_save_idxs] = dg[_save_idxs] + elseif typeof(dg) <: Number + @. diffcache.dg_val[_save_idxs] = dg + else + @. diffcache.dg_val[_save_idxs] = dg[_save_idxs] + end + end + + # Solve the Linear Problem + _val, back = Zygote.pullback(x -> f(x, p, nothing), y) + s_val = size(_val) + op = ZygotePullbackMultiplyOperator{eltype(y),typeof(back),typeof(s_val)}(back, s_val) + linear_problem = LinearProblem(op, vec(diffcache.dg_val)) + ## Automatically choose the best algorithm + λ = solve(linear_problem, sensealg.linsolve).u + + # Compute the VJP + _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) + dp = back(vec(λ))[1] + + return dp isa NamedTuple ? fmap(neg, dp) : -vec(dp) +end diff --git a/src/layers/chain.jl b/src/layers/chain.jl new file mode 100644 index 00000000..8d7e935b --- /dev/null +++ b/src/layers/chain.jl @@ -0,0 +1,51 @@ +struct DEQChain{P1,D<:AbstractDeepEquilibriumNetwork,P2} <: AbstractExplicitLayer + pre_deq::P1 + deq::D + post_deq::P2 +end + +function initialparameters(rng::AbstractRNG, c::DEQChain) + return ( + pre_deq=initialparameters(rng, c.pre_deq), + deq=initialparameters(rng, c.deq), + post_deq=initialparameters(rng, c.post_deq), + ) +end + +function initialstates(rng::AbstractRNG, c::DEQChain) + return ( + pre_deq=initialstates(rng, c.pre_deq), deq=initialstates(rng, c.deq), post_deq=initialstates(rng, c.post_deq) + ) +end + +function DEQChain(layers...) 
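+    # Split `layers` into pre-DEQ / DEQ / post-DEQ stages; exactly one
+    # AbstractDeepEquilibriumNetwork layer is expected among `layers`.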
+ pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false + for l in layers + if l isa AbstractDeepEquilibriumNetwork + @assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!" + deq = l + encounter_deq = true + continue + end + push!(encounter_deq ? post_deq : pre_deq, l) + end + @assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain" + pre_deq = length(pre_deq) == 0 ? nothing : ExplicitFluxLayers.Chain(pre_deq...) + post_deq = length(post_deq) == 0 ? nothing : ExplicitFluxLayers.Chain(post_deq...) + return DEQChain(pre_deq, deq, post_deq) +end + +function (deq::DEQChain{P1,D,P2})(x, ps::NamedTuple, st::NamedTuple) where {P1,D,P2} + x1, st1 = if P1 == Nothing + x, st.pre_deq + else + deq.pre_deq(x, ps.pre_deq, st.pre_deq) + end + (x2, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq) + x3, st3 = if P2 == Nothing + x2, st.post_deq + else + deq.post_deq(x2, ps.post_deq, st.post_deq) + end + return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3) +end diff --git a/src/layers/core.jl b/src/layers/core.jl new file mode 100644 index 00000000..0cd9dcba --- /dev/null +++ b/src/layers/core.jl @@ -0,0 +1,55 @@ +abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitLayer end +abstract type AbstractSkipDeepEquilibriumNetwork <: AbstractDeepEquilibriumNetwork end + +initialparameters(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) = initialparameters(rng, deq.model) +initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) = initialstates(rng, deq.model) +createcache(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork, x) = createcache(rng, deq.model, x) + +parameterlength(deq::AbstractDeepEquilibriumNetwork) = parameterlength(deq.model) +statelength(deq::AbstractDeepEquilibriumNetwork) = statelength(deq.model) +cachesize(deq::AbstractDeepEquilibriumNetwork) = cachesize(deq.model) + +function initialparameters(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) + return (model=initialparameters(rng, deq.model), shortcut=initialparameters(rng, deq.shortcut)) +end +function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) + return (model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut)) +end +function createcache(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork, x) + return (model=createcache(rng, deq.model, x), shortcut=createcache(rng, deq.shortcut, x)) +end + +parameterlength(deq::AbstractSkipDeepEquilibriumNetwork) = parameterlength(deq.model) + parameterlength(deq.shortcut) +statelength(deq::AbstractSkipDeepEquilibriumNetwork) = statelength(deq.model) + statelength(deq.shortcut) +cachesize(deq::AbstractSkipDeepEquilibriumNetwork) = cachesize(deq.model) + cachesize(deq.shortcut) + +""" + DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe) + +Stores the solution of a DeepEquilibriumNetwork and its variants. 
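+
+For example, every DEQ layer in this package returns such a solution object
+alongside its output (a minimal usage sketch; `deq`, `x`, `ps`, and `st` are
+assumed to be a DEQ layer and its input, parameters, and states):
+
+```julia
+(ŷ, soln), st = deq(x, ps, st)
+soln.nfe       # number of function evaluations used by the solver
+soln.residual  # z* .- f(z*, x)
+```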
+
+## Fields
+
+* `z_star`: Steady-state solution (or the value reached when `maxiters` is hit)
+* `u₀`: Initial Condition
+* `residual`: Difference between ``z^*`` and ``f(z^*, x)``
+* `jacobian_loss`: Jacobian Stabilization Loss (see the individual networks for how it is computed)
+* `nfe`: Number of Function Evaluations
+"""
+struct DeepEquilibriumSolution{T,R<:AbstractFloat}
+    z_star::T
+    u₀::T
+    residual::T
+    jacobian_loss::R
+    nfe::Int
+end
+
+function Base.show(io::IO, l::DeepEquilibriumSolution)
+    print(io, "DeepEquilibriumSolution(")
+    print(io, "z_star: ", l.z_star)
+    print(io, ", initial_condition: ", l.u₀)
+    print(io, ", residual: ", l.residual)
+    print(io, ", jacobian_loss: ", l.jacobian_loss)
+    print(io, ", NFE: ", l.nfe)
+    print(io, ")")
+    return nothing
+end
\ No newline at end of file
diff --git a/src/layers/deq.jl b/src/layers/deq.jl
index 931c19ef..cbd59b06 100644
--- a/src/layers/deq.jl
+++ b/src/layers/deq.jl
@@ -1,81 +1,81 @@
-"""
-    DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false,
-                           p=nothing, sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10),
-                           kwargs...)
+struct DeepEquilibriumNetwork{J,M,A,S,K} <: AbstractDeepEquilibriumNetwork
+    model::M
+    solver::A
+    sensealg::S
+    kwargs::K
+end

-Deep Equilibrium Network as proposed in [baideep2019](@cite)
+function DeepEquilibriumNetwork(
+    model, solver; jacobian_regularization::Bool=false, sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...
+)
+    return DeepEquilibriumNetwork{
+        jacobian_regularization,typeof(model),typeof(solver),typeof(sensealg),typeof(kwargs)
+    }(
+        model, solver, sensealg, kwargs
+    )
+end

-## Arguments
+function (deq::DeepEquilibriumNetwork{J})(x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple) where {J,T}
+    z = zero(x)

-* `model`: Explicit Neural Network which takes 2 inputs
-* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref))
-* `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref)
-* `p`: Optional parameters for the `model`
-* `sensealg`: See [`SteadyStateAdjoint`](@ref)
-* `kwargs`: Additional Parameters that are directly passed to `solve`
+    function dudt(u, p, t)
+        u_, _ = deq.model((u, x), p, st)
+        return u_ .- u
+    end

-## Example
+    prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps)
+    sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...)
+    z_star, st_ = deq.model((sol.u, x), ps, st)

-```julia
-model = DeepEquilibriumNetwork(
-    Parallel(
-        +,
-        Dense(2, 2; bias=false),
-        Dense(2, 2; bias=false)
-    ),
-    ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0)
-)
+    jac_loss = (J ?
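+        # Jacobian stabilization loss; evaluated only when jacobian_regularization (J) is true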
compute_deq_jacobian_loss(deq.model, ps, st, z_star, x) : T(0)) + residual = Zygote.@ignore z_star .- deq.model((z_star, x), ps, st)[1] -model(rand(Float32, 2, 1)) -``` + return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st_ +end -See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) -""" -struct DeepEquilibriumNetwork{J,M,P,RE,A,S,K} <: AbstractDeepEquilibriumNetwork - jacobian_regularization::Bool +struct SkipDeepEquilibriumNetwork{J,M,Sh,A,S,K} <: AbstractSkipDeepEquilibriumNetwork model::M - p::P - re::RE + shortcut::Sh solver::A - kwargs::K sensealg::S - stats::DEQTrainingStats - - function DeepEquilibriumNetwork(jacobian_regularization, model, p, re, solver, kwargs, sensealg, stats) - _p, re = destructure_parameters(model) - p = p === nothing ? _p : convert(typeof(_p), p) - - return new{jacobian_regularization,typeof(model),typeof(p),typeof(re),typeof(solver), - typeof(sensealg),typeof(kwargs)}(jacobian_regularization, model, p, re, - solver, kwargs, sensealg, stats) - end -end - -Flux.@functor DeepEquilibriumNetwork - -function Base.show(io::IO, l::DeepEquilibriumNetwork{J}) where {J} - return print(io, "DeepEquilibriumNetwork(jacobian_regularization = $J) ", - string(length(l.p)), " Trainable Parameters") + kwargs::K end -function DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false, - p=nothing, sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - return DeepEquilibriumNetwork(jacobian_regularization, model, p, nothing, solver, - kwargs, sensealg, DEQTrainingStats(0)) +function SkipDeepEquilibriumNetwork( + model, + shortcut, + solver; + jacobian_regularization::Bool=false, + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + kwargs..., +) + return SkipDeepEquilibriumNetwork{ + jacobian_regularization,typeof(model),typeof(shortcut),typeof(solver),typeof(sensealg),typeof(kwargs) + }( + model, shortcut, solver, sensealg, kwargs + ) end -function (deq::DeepEquilibriumNetwork)(x::AbstractArray{T}) where {T} - z = zero(x) - Zygote.@ignore deq.re(deq.p)(z, x) - - current_nfe = deq.stats.nfe +function (deq::SkipDeepEquilibriumNetwork{J,M,S})(x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple) where {J,M,S,T} + z, st__ = if S == Nothing + deq.model((zero(x), x), ps.model, st.model) + else + deq.shortcut(x, ps.shortcut, st.shortcut) + end + @set! st.shortcut = st__ - z_star = solve_steady_state_problem(deq.re, deq.p, x, z, deq.sensealg, deq.solver; dudt=nothing, - update_nfe=() -> (deq.stats.nfe += 1), deq.kwargs...) + function dudt(u, p, t) + u_, = deq.model((u, x), p, st.model) + return u_ .- u + end - jac_loss = (deq.jacobian_regularization ? compute_deq_jacobian_loss(deq.re, deq.p, z_star, x) : T(0))::T + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = deq.model((sol.u, x), ps.model, st.model) - residual = Zygote.@ignore z_star .- deq.re(deq.p)(z_star, x) + jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) + residual = Zygote.@ignore z_star .- deq.model((z_star, x), ps.model, st.model)[1] + @set! 
st.model = st_ :: typeof(st.model) - return z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, deq.stats.nfe - current_nfe) + return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/jacobian_stabilization.jl b/src/layers/jacobian_stabilization.jl index ae6bbe89..3f53f452 100644 --- a/src/layers/jacobian_stabilization.jl +++ b/src/layers/jacobian_stabilization.jl @@ -1,32 +1,8 @@ -gaussian_like(p::Array) = randn(eltype(p), size(p)) -gaussian_like(p::CuArray) = CUDA.randn(eltype(p), size(p)) - -Zygote.@nograd gaussian_like - -""" - compute_deq_jacobian_loss(re, p, z, x) - -Computes Jacobian Stabilization Loss ([bai2021stabilizing](@cite)). - -## Arguments - -* `re`: Constructs the model given the parameters `p`. -* `p`: Parameters of the model. -* `z`: Steady State. -* `x`: Input to the model. - -## Current Known Failure Modes - -1. Conv layers error out due to ForwardDiff on GPUs -2. If the model internally uses destructure/restructure eg. `WeightNorm` Layer, then this loss function will error out in the backward pass. -""" -function compute_deq_jacobian_loss(re, p::AbstractVector{T}, z::A, x::A) where {T,A<:AbstractArray} - d = length(z) - v = gaussian_like(z) - model = re(p) - - _, back = Zygote.pullback(model, z, x) - vjp_z, vjp_x = back(v) - # NOTE: This weird sum(zero, ...) ensures that we get zeros instead of nothings - return sum(abs2, vjp_z) / d + sum(zero, vjp_x) +# Doesn't work as of now +function compute_deq_jacobian_loss( + model::AbstractExplicitLayer, ps::NamedTuple, st::NamedTuple, z::AbstractArray, x::AbstractArray +) + l, back = Zygote.pullback(u -> model((u, x), ps, st)[1], z) + vjp_z = back(gaussian_like(l))[1] + return sum(abs2, vjp_z) / length(z) end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index d90227ab..d373b303 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -1,131 +1,149 @@ -""" - MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, solver; - post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - -Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) - -## Arguments - -* `main_layers`: Tuple of Explicit Neural Networks. The first network needs to take 2 inputs, the other ones only take 1 input -* `mapping_layers`: Matrix of Explicit Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` - and passes it to the ``j^{th}`` `main_layer` -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `post_fuse_layers`: Tuple of Explicit Neural Networks. 
Applied after the `mapping_layers` (Default: `nothing`) -* `p`: Optional parameters for the `model` -* `sensealg`: See [`SteadyStateAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` - -## Example - -```julia -model = MultiScaleDeepEquilibriumNetwork( - ( - Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast) - ), - [ - NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer() - ], - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0), -) - -model(rand(Float32, 4, 1)) -``` - -See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) -""" -struct MultiScaleDeepEquilibriumNetwork{N,M1<:Parallel,M2<:Union{Chain,FChain},RE1,RE2,P,A,K,S} <: - AbstractDeepEquilibriumNetwork - main_layers::M1 - mapping_layers::M2 - main_layers_re::RE1 - mapping_layers_re::RE2 - p::P - ordered_split_idxs::NTuple{N,Int} +struct MultiScaleDeepEquilibriumNetwork{N,L,M,A,S,K} <: AbstractDeepEquilibriumNetwork + model::M solver::A - kwargs::K sensealg::S - stats::DEQTrainingStats - - function MultiScaleDeepEquilibriumNetwork(main_layers::Parallel, mapping_layers::Union{Chain,FChain}, re1, re2, - p, ordered_split_idxs, solver::A, kwargs::K, sensealg::S, stats) where {A,K,S} - @assert length(mapping_layers) == 2 - @assert mapping_layers[1] isa MultiParallelNet - - p_main_layers, re_main_layers = destructure_parameters(main_layers) - p_mapping_layers, re_mapping_layers = destructure_parameters(mapping_layers) + scales::NTuple{N,NTuple{L,Int64}} + kwargs::K +end - ordered_split_idxs = tuple(cumsum([0, length(p_main_layers), length(p_mapping_layers)])...) +function initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwork) + return (model=initialstates(rng, deq.model), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...))) +end - p = p === nothing ? vcat(p_main_layers, p_mapping_layers) : convert(typeof(p_main_layers), p) +function MultiScaleDeepEquilibriumNetwork( + main_layers::Tuple, + mapping_layers::Matrix, + post_fuse_layer::Union{Nothing,Tuple}, + solver, + scales; + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + kwargs..., +) + l1 = ExplicitFluxLayers.Parallel(nothing, main_layers...) + l2 = ExplicitFluxLayers.BranchLayer( + ExplicitFluxLayers.Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)... + ) + model = if post_fuse_layer === nothing + ExplicitFluxLayers.Chain(l1, l2) + else + l3 = ExplicitFluxLayers.Parallel(nothing, post_fuse_layer...) 
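+        # the post-fusion layers run per scale, after the cross-scale mixing in l2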
+ ExplicitFluxLayers.Chain(l1, l2, l3) + end + return MultiScaleDeepEquilibriumNetwork(model, solver, sensealg, scales, kwargs) +end - return new{length(ordered_split_idxs), - typeof.((main_layers, mapping_layers, re_main_layers, re_mapping_layers, p))..., - A,K,S}(main_layers, mapping_layers, re_main_layers, re_mapping_layers, p, ordered_split_idxs, - solver, kwargs, sensealg, stats) +function get_initial_condition_mdeq(scales::NTuple, x::AbstractArray{T,N}, st::NamedTuple{fields}) where {T,N,fields} + if hasproperty(st, :initial_condition) && size(st.initial_condition, 2) == size(x, N) + return st.initial_condition, st end + u0 = vcat(map(scale -> fill!(similar(x, prod(scale), size(x, N)), T(0)), scales)...) + st = merge((initial_condition=u0,), st) + return u0, st end -Flux.@functor MultiScaleDeepEquilibriumNetwork +Zygote.@nograd get_initial_condition_mdeq -function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, solver; - post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - mapping_layers = if post_fuse_layers === nothing - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) - FChain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), NoOpLayer()) - else - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) == length(post_fuse_layers) - FChain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), - Parallel(flatten_merge, post_fuse_layers...)) +function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple) where {N,T} + z, st = get_initial_condition_mdeq(deq.scales, x, st) + + function dudt_(u, p, t) + u_split = split_and_reshape(u, st.split_idxs, deq.scales) + u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st.model) + return u_, st_ end - main_layers = Parallel(flatten_merge, main_layers...) + dudt(u, p, t) = vcat(Flux.flatten.(dudt_(u, p, t)[1])...) .- u - return MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, nothing, nothing, p, - nothing, solver, kwargs, sensealg, DEQTrainingStats(0)) -end + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = dudt_(sol.u, ps, nothing)::Tuple{NTuple{N,typeof(x)},typeof(st.model)} -function (mdeq::MultiScaleDeepEquilibriumNetwork)(x::AbstractArray{T}) where {T} - current_nfe = mdeq.stats.nfe + residual = Zygote.@ignore dudt(sol.u, ps, nothing) ::typeof(x) - z = zero(x) - initial_conditions = Zygote.@ignore map(l -> l(z), map(l -> l.layers[1], mdeq.mapping_layers[1].layers)) - u_sizes = Zygote.@ignore size.(initial_conditions) - u_split_idxs = Zygote.@ignore vcat(0, cumsum(length.(initial_conditions) .÷ size(x, ndims(x)))...) - u0 = Zygote.@ignore vcat(Flux.flatten.(initial_conditions)...) + @set! 
st.model = st_ - N = length(u_sizes) - update_is_variational_hidden_dropout_mask_reset_allowed(false) + return ( + (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st + ) +end - function dudt_(u, _p) - mdeq.stats.nfe += 1 +struct MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh,A,S,K} <: AbstractSkipDeepEquilibriumNetwork + model::M + shortcut::Sh + solver::A + sensealg::S + scales::NTuple{N,NTuple{L,Int64}} + kwargs::K +end - uₛ = split_array_by_indices(u, u_split_idxs) - p1, p2 = split_array_by_indices(_p, mdeq.ordered_split_idxs) +function initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwork) + return ( + model=initialstates(rng, deq.model), + shortcut=initialstates(rng, deq.shortcut), + split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), + ) +end - u_reshaped = ntuple(i -> reshape(uₛ[i], u_sizes[i]), N) +function MultiScaleSkipDeepEquilibriumNetwork( + main_layers::Tuple, + mapping_layers::Matrix, + post_fuse_layer::Union{Nothing,Tuple}, + shortcut_layers::Union{Nothing,Tuple}, + solver, + scales; + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + kwargs..., +) + l1 = ExplicitFluxLayers.Parallel(nothing, main_layers...) + l2 = ExplicitFluxLayers.BranchLayer( + ExplicitFluxLayers.Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)... + ) + model = if post_fuse_layer === nothing + ExplicitFluxLayers.Chain(l1, l2) + else + l3 = ExplicitFluxLayers.Parallel(nothing, post_fuse_layer...) + ExplicitFluxLayers.Chain(l1, l2, l3) + end + shortcut = if shortcut_layers === nothing + nothing + else + ExplicitFluxLayers.Parallel(nothing, shortcut_layers...) + end + return MultiScaleSkipDeepEquilibriumNetwork(model, shortcut, solver, sensealg, scales, kwargs) +end - main_layers_output = mdeq.main_layers_re(p1)((u_reshaped[1], x), u_reshaped[2:end]...) +function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( + x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple +) where {N,L,M,Sh,T} + z, st = if Sh == Nothing + u0, st_ = get_initial_condition_mdeq(deq.scales, x, st) + u0_ = split_and_reshape(u0, st.split_idxs, deq.scales) + z0, st__ = deq.model(((u0_[1], x), u0_[2:N]...), ps.model, st_.model) + @set! st_.model = st__ + (vcat(Flux.flatten.(z0)...), st_) + else + z0, st_ = deq.shortcut(x, ps.shortcut, st.shortcut) + @set! st.shortcut = st_ + (vcat(Flux.flatten.(z0)...), st) + end - return mdeq.mapping_layers_re(p2)(main_layers_output) + function dudt_(u, p, t) + u_split = split_and_reshape(u, st.split_idxs, deq.scales) + u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st.model) + return u_, st_ end - dudt(u, _p, t) = vcat(Flux.flatten.(dudt_(u, _p))...) .- u + dudt(u, p, t) = vcat(Flux.flatten.(dudt_(u, p, t)[1])...) .- u - ssprob = SteadyStateProblem(dudt, u0, mdeq.p) - res = solve(ssprob, mdeq.solver; u0=u0, sensealg=mdeq.sensealg, mdeq.kwargs...).u + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) + sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) + z_star, st_ = dudt_(sol.u, ps.model, nothing)::Tuple{NTuple{N,typeof(x)},typeof(st.model)} - x_ = dudt_(res, mdeq.p) + residual = Zygote.@ignore dudt(sol.u, ps.model, nothing) ::typeof(x) - residual = Zygote.@ignore Tuple(map((iu) -> reshape(iu[2], u_sizes[iu[1]]), - enumerate(split_array_by_indices(dudt(res, mdeq.p, nothing), u_split_idxs)))) - update_is_variational_hidden_dropout_mask_reset_allowed(true) + @set! 
st.model = st_ - return x_, DeepEquilibriumSolution(x_, initial_conditions, residual, T(0), mdeq.stats.nfe - current_nfe) + return ( + (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st + ) end diff --git a/src/layers/sdeq.jl b/src/layers/sdeq.jl deleted file mode 100644 index 06728a89..00000000 --- a/src/layers/sdeq.jl +++ /dev/null @@ -1,130 +0,0 @@ -""" - SkipDeepEquilibriumNetwork(model, shortcut, solver; p=nothing, jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - SkipDeepEquilibriumNetwork(model, solver; p=nothing, jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - -Skip Deep Equilibrium Network as proposed in [pal2022mixing](@cite) - -## Arguments - -* `model`: Explicit Neural Network which takes 2 inputs -* `shortcut`: Shortcut for the network (If not given, then we create SkipDEQV2) -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) -* `p`: Optional parameters for the `model` -* `sensealg`: See [`SteadyStateAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` - -## Example - -```julia -# SkipDEQ -model = SkipDeepEquilibriumNetwork( - Parallel( - +, - Dense(2, 2; bias=false), - Dense(2, 2; bias=false) - ), - Dense(2, 2), - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) -) - -model(rand(Float32, 2, 1)) - -# SkipDEQV2 -model = SkipDeepEquilibriumNetwork( - Parallel( - +, - Dense(2, 2; bias=false), - Dense(2, 2; bias=false) - ), - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) -) - -model(rand(Float32, 2, 1)) -``` - -See also: [`DeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) -""" -struct SkipDeepEquilibriumNetwork{M,S,J,P,RE1,RE2,A,Se,K} <: AbstractDeepEquilibriumNetwork - jacobian_regularization::Bool - model::M - shortcut::S - p::P - re1::RE1 - re2::RE2 - split_idx::Int - solver::A - kwargs::K - sensealg::Se - stats::DEQTrainingStats - - function SkipDeepEquilibriumNetwork(jacobian_regularization, model, shortcut, p, re1, re2, - split_idx, solver, kwargs, sensealg, stats) - p1, re1 = destructure_parameters(model) - split_idx = length(p1) - p2, re2 = shortcut === nothing ? ((eltype(p1))[], nothing) : destructure_parameters(shortcut) - - p = p === nothing ? vcat(p1, p2) : eltype(p1).(p) - - return new{typeof(model),typeof(shortcut),jacobian_regularization,typeof(p),typeof(re1), - typeof(re2),typeof(solver),typeof(sensealg),typeof(kwargs)}(jacobian_regularization, model, shortcut, p, - re1, re2, split_idx, solver, kwargs, - sensealg, stats) - end -end - -Flux.@functor SkipDeepEquilibriumNetwork - -function Base.show(io::IO, l::SkipDeepEquilibriumNetwork{M,S,J}) where {M,S,J} - shortcut_ps = l.split_idx == length(l.p) ? 0 : length(l.p) - l.split_idx - return print(io, "SkipDeepEquilibriumNetwork(jacobian_regularization = $J, ", - "shortcut_parameter_count = $shortcut_ps) ", string(length(l.p)), " Trainable Parameters") -end - -function SkipDeepEquilibriumNetwork(model, shortcut, solver; p=nothing, jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) 
- return SkipDeepEquilibriumNetwork(jacobian_regularization, model, shortcut, p, nothing, - nothing, 0, solver, kwargs, sensealg, DEQTrainingStats(0)) -end - -function SkipDeepEquilibriumNetwork(model, solver; p=nothing, jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - return SkipDeepEquilibriumNetwork(jacobian_regularization, model, nothing, p, nothing, - nothing, 0, solver, kwargs, sensealg, DEQTrainingStats(0)) -end - -function (deq::SkipDeepEquilibriumNetwork)(x::AbstractArray{T}) where {T} - p1, p2 = deq.p[1:(deq.split_idx)], deq.p[(deq.split_idx + 1):end] - z = deq.re2(p2)(x)::typeof(x) - - current_nfe = deq.stats.nfe - - # Dummy call to ensure that mask is generated - Zygote.@ignore _ = deq.re1(p1)(z, x) - - z_star = solve_steady_state_problem(deq.re1, p1, x, z, deq.sensealg, deq.solver; dudt=nothing, - update_nfe=() -> (deq.stats.nfe += 1), deq.kwargs...) - - jac_loss = (deq.jacobian_regularization ? compute_deq_jacobian_loss(deq.re1, p1, z_star, x) : T(0)) ::T - - residual = Zygote.@ignore z_star .- deq.re1(p1)(z_star, x) - - return z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, deq.stats.nfe - current_nfe) -end - -function (deq::SkipDeepEquilibriumNetwork{M,Nothing})(x::AbstractArray{T}) where {M,T} - z = deq.re1(deq.p)(zero(x), x)::typeof(x) - - current_nfe = deq.stats.nfe - - z_star = solve_steady_state_problem(deq.re1, deq.p, x, z, deq.sensealg, deq.solver; dudt=nothing, - update_nfe=() -> (deq.stats.nfe += 1), deq.kwargs...) - - jac_loss = (deq.jacobian_regularization ? compute_deq_jacobian_loss(deq.re1, deq.p, z_star, x) : T(0)) ::T - - residual = Zygote.@ignore z_star .- deq.re1(deq.p)(z_star, x) - - return z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, deq.stats.nfe - current_nfe) -end diff --git a/src/layers/smdeq.jl b/src/layers/smdeq.jl deleted file mode 100644 index 7aa47a77..00000000 --- a/src/layers/smdeq.jl +++ /dev/null @@ -1,235 +0,0 @@ -""" - MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, shortcut_layers::Tuple, - solver; post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, solver; - post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - -Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) - -## Arguments - -* `main_layers`: Tuple of Explicit Neural Networks. The first network needs to take 2 inputs, the other ones only take 1 input -* `mapping_layers`: Matrix of Explicit Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` - and passes it to the ``j^{th}`` `main_layer` -* `shortcut_layers`: Shortcuts for the network (If not given, then we create SkipDEQV2) -* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) -* `post_fuse_layers`: Tuple of Explicit Neural Networks. 
Applied after the `mapping_layers` (Default: `nothing`) -* `p`: Optional parameters for the `model` -* `sensealg`: See [`SteadyStateAdjoint`](@ref) -* `kwargs`: Additional Parameters that are directly passed to `solve` - -## Example - -```julia -# MSkipDEQ -model = MultiScaleSkipDeepEquilibriumNetwork( - ( - Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast) - ), - [ - NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer() - ], - ( - Dense(4, 4, tanh_fast), - Dense(4, 3, tanh_fast), - Dense(4, 2, tanh_fast), - Dense(4, 1, tanh_fast) - ), - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0), -) - -model(rand(Float32, 4, 1)) - - -# MSkipDEQV2 -model = MultiScaleSkipDeepEquilibriumNetwork( - ( - Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast) - ), - [ - NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer() - ], - ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0), -) - -model(rand(Float32, 4, 1)) -``` - -See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref) -""" -struct MultiScaleSkipDeepEquilibriumNetwork{M3<:Union{Nothing,Parallel},N,M1<:Parallel,M2<:Union{Chain,FChain},RE1, - RE2,RE3,P,A,K,S} <: AbstractDeepEquilibriumNetwork - main_layers::M1 - mapping_layers::M2 - shortcut_layers::M3 - main_layers_re::RE1 - mapping_layers_re::RE2 - shortcut_layers_re::RE3 - p::P - ordered_split_idxs::NTuple{N,Int} - solver::A - kwargs::K - sensealg::S - stats::DEQTrainingStats - - function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Parallel, - mapping_layers::Union{Chain,FChain}, - shortcut_layers::Union{Nothing,Parallel}, re1, re2, re3, p, - ordered_split_idxs, solver::A, kwargs::K, sensealg::S, - stats) where {A,K,S} - @assert length(mapping_layers) == 2 - @assert mapping_layers[1] isa MultiParallelNet - - p_main_layers, re_main_layers = destructure_parameters(main_layers) - p_mapping_layers, re_mapping_layers = destructure_parameters(mapping_layers) - p_shortcut_layers, re_shortcut_layers = shortcut_layers === nothing ? ([], nothing) : - destructure_parameters(shortcut_layers) - - ordered_split_idxs = tuple(cumsum([0, length(p_main_layers), length(p_mapping_layers), - length(p_shortcut_layers)])...) - - p = p === nothing ? 
vcat(p_main_layers, p_mapping_layers, p_shortcut_layers) : convert(typeof(p_main_layers), p) - - return new{typeof(shortcut_layers),length(ordered_split_idxs), - typeof.((main_layers, mapping_layers, re_main_layers, re_mapping_layers, re_shortcut_layers, p))..., - A,K,S}(main_layers, mapping_layers, shortcut_layers, re_main_layers, - re_mapping_layers, re_shortcut_layers, p, ordered_split_idxs, solver, kwargs, sensealg, stats) - end -end - -Flux.@functor MultiScaleSkipDeepEquilibriumNetwork - -function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, shortcut_layers::Tuple, - solver; post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - mapping_layers = if post_fuse_layers === nothing - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) == length(shortcut_layers) - Chain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), NoOpLayer()) - else - @assert size(mapping_layers, 1) == - size(mapping_layers, 2) == - length(main_layers) == - length(post_fuse_layers) == - length(shortcut_layers) - Chain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), - Parallel(flatten_merge, post_fuse_layers...)) - end - - main_layers = Parallel(flatten_merge, main_layers...) - shortcut_layers = Parallel(flatten_merge, shortcut_layers...) - - return MultiScaleSkipDeepEquilibriumNetwork(main_layers, mapping_layers, shortcut_layers, - nothing, nothing, nothing, p, nothing, solver, kwargs, sensealg, - DEQTrainingStats(0)) -end - -function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, solver; - post_fuse_layers::Union{Tuple,Nothing}=nothing, p=nothing, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs...) - mapping_layers = if post_fuse_layers === nothing - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) - Chain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), NoOpLayer()) - else - @assert size(mapping_layers, 1) == size(mapping_layers, 2) == length(main_layers) == length(post_fuse_layers) - Chain(MultiParallelNet(Parallel.(+, map(x -> tuple(x...), eachcol(mapping_layers)))...), - Parallel(flatten_merge, post_fuse_layers...)) - end - - main_layers = Parallel(flatten_merge, main_layers...) - - return MultiScaleSkipDeepEquilibriumNetwork(main_layers, mapping_layers, nothing, nothing, - nothing, nothing, p, nothing, solver, kwargs, sensealg, - DEQTrainingStats(0)) -end - -function (mdeq::MultiScaleSkipDeepEquilibriumNetwork)(x::AbstractArray{T}) where {T} - current_nfe = mdeq.stats.nfe - - p1, p2, p3 = split_array_by_indices(mdeq.p, mdeq.ordered_split_idxs) - initial_conditions = mdeq.shortcut_layers_re(p3)(x) - u_sizes = size.(initial_conditions) - u_split_idxs = vcat(0, cumsum(length.(initial_conditions) .÷ size(x, ndims(x)))...) - u0 = Zygote.@ignore vcat(Flux.flatten.(initial_conditions)...) - - N = length(u_sizes) - update_is_variational_hidden_dropout_mask_reset_allowed(false) - - function dudt_(u, _p) - mdeq.stats.nfe += 1 - - uₛ = split_array_by_indices(u, u_split_idxs) - p1, p2, _ = split_array_by_indices(_p, mdeq.ordered_split_idxs) - - u_reshaped = ntuple(i -> reshape(uₛ[i], u_sizes[i]), N) - - main_layers_output = mdeq.main_layers_re(p1)((u_reshaped[1], x), u_reshaped[2:end]...) - - return mdeq.mapping_layers_re(p2)(main_layers_output) - end - - dudt(u, _p, t) = vcat(Flux.flatten.(dudt_(u, _p))...) 
.- u - - ssprob = SteadyStateProblem(dudt, u0, mdeq.p) - res = solve(ssprob, mdeq.solver; u0=u0, sensealg=mdeq.sensealg, mdeq.kwargs...).u - x_ = dudt_(res, mdeq.p) - residual = Zygote.@ignore Tuple(map((iu) -> reshape(iu[2], u_sizes[iu[1]]), - enumerate(split_array_by_indices(dudt(res, mdeq.p, nothing), u_split_idxs)))) - - update_is_variational_hidden_dropout_mask_reset_allowed(true) - - return x_, DeepEquilibriumSolution(x_, initial_conditions, residual, T(0), mdeq.stats.nfe - current_nfe) -end - -function (mdeq::MultiScaleSkipDeepEquilibriumNetwork{Nothing})(x::AbstractArray{T}) where {T} - current_nfe = mdeq.stats.nfe - - p1, p2 = split_array_by_indices(mdeq.p, mdeq.ordered_split_idxs) - - _initial_conditions = Zygote.@ignore [l(x) for l in map(l -> l.layers[1], mdeq.mapping_layers[1].layers)] - _initial_conditions = mdeq.mapping_layers_re(p2)((x, zero.(_initial_conditions[2:end])...)) - initial_conditions = mdeq.main_layers_re(p1)((zero(_initial_conditions[1]), _initial_conditions[1]), - _initial_conditions[2:end]...) - u_sizes = size.(initial_conditions) - u_split_idxs = vcat(0, cumsum(length.(initial_conditions) .÷ size(x, ndims(x)))...) - u0 = vcat(Flux.flatten.(initial_conditions)...) - - N = length(u_sizes) - update_is_variational_hidden_dropout_mask_reset_allowed(false) - - function dudt_(u, _p) - mdeq.stats.nfe += 1 - - uₛ = split_array_by_indices(u, u_split_idxs) - p1, p2, _ = split_array_by_indices(_p, mdeq.ordered_split_idxs) - - u_reshaped = ntuple(i -> reshape(uₛ[i], u_sizes[i]), N) - - main_layers_output = mdeq.main_layers_re(p1)((u_reshaped[1], x), u_reshaped[2:end]...) - - return mdeq.mapping_layers_re(p2)(main_layers_output) - end - - dudt(u, _p, t) = vcat(Flux.flatten.(dudt_(u, _p))...) .- u - - ssprob = SteadyStateProblem(dudt, u0, mdeq.p) - res = solve(ssprob, mdeq.solver; u0=u0, sensealg=mdeq.sensealg, mdeq.kwargs...).u - x_ = dudt_(res, mdeq.p) - residual = Zygote.@ignore Tuple(map((iu) -> reshape(iu[2], u_sizes[iu[1]]), - enumerate(split_array_by_indices(dudt(res, mdeq.p, nothing), u_split_idxs)))) - - update_is_variational_hidden_dropout_mask_reset_allowed(true) - - return x_, DeepEquilibriumSolution(x_, initial_conditions, residual, T(0), mdeq.stats.nfe - current_nfe) -end diff --git a/src/layers/utils.jl b/src/layers/utils.jl deleted file mode 100644 index 3be36acf..00000000 --- a/src/layers/utils.jl +++ /dev/null @@ -1,76 +0,0 @@ -""" - DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe) - -Stores the solution of a DeepEquilibriumNetwork and its variants. - -## Fields - -* `z_star`: Steady-State or the value reached due to maxiters -* `u₀`: Initial Condition -* `residual`: Difference of the ``z^*`` and ``f(z^*, x)`` -* `jacobian_loss`: Jacobian Stabilization Loss (see individual networks to see how it can be computed) -* `nfe`: Number of Function Evaluations -""" -struct DeepEquilibriumSolution{T,R<:AbstractFloat} - z_star::T - u₀::T - residual::T - jacobian_loss::R - nfe::Int -end - -function Base.show(io::IO, l::DeepEquilibriumSolution) - println(io, "DeepEquilibriumSolution(") - println(io, "\tz_star: ", l.z_star) - println(io, "\tinitial_condition: ", l.u₀) - println(io, "\tresidual: ", l.residual) - println(io, "\tjacobian_loss: ", l.jacobian_loss) - println(io, "\tNFE: ", l.nfe) - print(io, ")") - return nothing -end - - -function solve_steady_state_problem(re, p, x, u0, sensealg, args...; dudt=nothing, update_nfe=() -> (), kwargs...) 
- # Solving the equation f(u) - u = du = 0 - update_is_variational_hidden_dropout_mask_reset_allowed(false) - - dudt_ = if dudt === nothing - function (u, _p, t) - update_nfe() - return re(_p)(u, x) .- u - end - else - dudt - end - - ssprob = SteadyStateProblem(dudt_, u0, p) - sol = solve(ssprob, args...; u0=u0, sensealg=sensealg, kwargs...) - - z = re(p)(sol.u, x)::typeof(x) - update_nfe() - - update_is_variational_hidden_dropout_mask_reset_allowed(true) - - return z -end - -function solve_depth_k_neural_network(re, p, x, u0, depth) - update_is_variational_hidden_dropout_mask_reset_allowed(false) - model = re(p) - for _ in 1:depth - u0 = model(u0, x) - end - update_is_variational_hidden_dropout_mask_reset_allowed(true) - return u0 -end - - -flatten(x::AbstractArray{T,N}) where {T,N} = reshape(x, :, size(x, N)) - -Zygote.@adjoint function flatten(x::AbstractArray{T,N}) where {T,N} - s = size(x) - res = reshape(x, :, s[N]) - flatten_sensitivity(Δ) = (reshape(Δ, s),) - return res, flatten_sensitivity -end diff --git a/src/models/basics.jl b/src/models/basics.jl deleted file mode 100644 index 7701f7c8..00000000 --- a/src/models/basics.jl +++ /dev/null @@ -1,46 +0,0 @@ -""" - MultiParallelNet(layers...) - MultiParallelNet(layers::Tuple) - MultiParallelNet(layers::Vector) - -Creates a MultiParallelNet mostly used for MultiScale Models. It takes a list of inputs -and passes all of them through each `layer` and returns a tuple of outputs. - -## Example - -``` -Model := MultiParallelNet(L1, L2, L3) - -Model(X1, X2) := (Model.L1(X1, X2), Model.L2(X1, X2), Model.L3(X1, X2)) -``` -""" -struct MultiParallelNet{L} - layers::L - - function MultiParallelNet(args...) - layers = tuple(args...) - return new{typeof(layers)}(layers) - end - - MultiParallelNet(layers::Tuple) = new{typeof(layers)}(layers) - - MultiParallelNet(layers::Vector) = MultiParallelNet(layers...) -end - -Flux.@functor MultiParallelNet - -function (mpn::MultiParallelNet)(x::Union{Tuple,Vector}) - buf = Zygote.Buffer(Vector{Any}(undef, length(mpn.layers))) - for (i, l) in enumerate(mpn.layers) - buf[i] = l(x...) - end - return Tuple(copy(buf)) -end - -function (mpn::MultiParallelNet)(args...) - buf = Zygote.Buffer(Vector{Any}(undef, length(mpn.layers))) - for (i, l) in enumerate(mpn.layers) - buf[i] = l(args...) - end - return Tuple(copy(buf)) -end diff --git a/src/models/chain.jl b/src/models/chain.jl deleted file mode 100644 index 1fe93c92..00000000 --- a/src/models/chain.jl +++ /dev/null @@ -1,61 +0,0 @@ -# Default to nothing happening -reset_mask!(x) = nothing - -""" - DEQChain(pre_deq, deq, post_deq) - DEQChain(layers...) - -A Sequential Model containing a DEQ. - -!!! note - The Model should contain exactly 1 `AbstractDEQ` Layer -""" -struct DEQChain{P1,D<:AbstractDeepEquilibriumNetwork,P2} - pre_deq::P1 - deq::D - post_deq::P2 -end - -function DEQChain(layers...) - pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false - for l in layers - if typeof(l) <: AbstractDeepEquilibriumNetwork - @assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!" - deq = l - encounter_deq = true - continue - end - push!(encounter_deq ? post_deq : pre_deq, l) - end - @assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain" - pre_deq = length(pre_deq) == 0 ? NoOpLayer() : (length(pre_deq) == 1 ? pre_deq[1] : FChain(pre_deq...)) - post_deq = length(post_deq) == 0 ? NoOpLayer() : (length(post_deq) == 1 ? 
post_deq[1] : FChain(post_deq...)) - return DEQChain(pre_deq, deq, post_deq) -end - -Flux.@functor DEQChain - -function (deq::DEQChain)(x; kwargs...) - x1 = deq.pre_deq(x) - x2, deq_soln = deq.deq(x1; kwargs...) - x3 = deq.post_deq(x2) - return x3, deq_soln -end - -function get_and_clear_nfe!(model::DEQChain) - nfe = model.deq.stats.nfe - model.deq.stats.nfe = 0 - return nfe -end - -function Base.show(io::IO, model::DEQChain) - l1 = length(destructure_parameters(model)[1]) - println(io, "DEQChain(") - print(io, "\t") - show(io, model.pre_deq) - print(io, "\n\t") - show(io, model.deq) - print(io, "\n\t") - show(io, model.post_deq) - return print(io, "\n) $l1 Trainable Parameters") -end diff --git a/src/operator.jl b/src/operator.jl new file mode 100644 index 00000000..797fd0cd --- /dev/null +++ b/src/operator.jl @@ -0,0 +1,25 @@ +struct ZygotePullbackMultiplyOperator{T,F,S} + f::F + s::S +end + +Base.deepcopy(op::ZygotePullbackMultiplyOperator) = op + +Base.size(z::ZygotePullbackMultiplyOperator) = (prod(z.s), prod(z.s)) +Base.size(z::ZygotePullbackMultiplyOperator, ::Int64) = prod(z.s) + +Base.eltype(::ZygotePullbackMultiplyOperator{T}) where {T} = T + +function LinearAlgebra.mul!( + du::AbstractVector, + L::ZygotePullbackMultiplyOperator, + x::AbstractVector, +) + du .= vec(L * x) +end + +function Base.:*(L::ZygotePullbackMultiplyOperator, x::AbstractVector) + return L.f(reshape(x, L.s))[1] +end + +SciMLBase.isinplace(z::ZygotePullbackMultiplyOperator, ::Int64) = false \ No newline at end of file diff --git a/src/solve.jl b/src/solve.jl index fafc26ca..8abb424a 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,245 +1,10 @@ -""" - ContinuousDEQSolver(alg=VCABM4(); mode::Symbol=:rel_deq_default, abstol=1e-8, reltol=1e-8, tspan=Inf) - -Solver for Continuous DEQ Problem ([pal2022mixing](@cite)). Similar to `DynamicSS` but provides more flexibility needed -for solving DEQ problems. - -## Arguments - -* `alg`: Algorithm to solve the ODEProblem. (Default: `VCABM4()`) -* `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) -* `abstol`: Absolute tolerance for termination. (Default: `1e-8`) -* `reltol`: Relative tolerance for termination. (Default: `1e-8`) -* `tspan`: Time span. Users should not change this value, instead control termination through `maxiters` in `solve` (Default: `Inf`) - -## Termination Modes - -#### Termination on Absolute Tolerance - -* `:abs`: Terminates if ``all \\left( | \\frac{\\partial u}{\\partial t} | \\leq abstol \\right)`` -* `:abs_norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq abstol`` -* `:abs_deq_default`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) -* `:abs_deq_best`: Same as `:abs_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged - -#### Termination on Relative Tolerance - -* `:rel`: Terminates if ``all \\left(| \\frac{\\partial u}{\\partial t} | \\leq reltol \\times | u | \\right)`` -* `:rel_norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq reltol \\times \\| \\frac{\\partial u}{\\partial t} + u \\|`` -* `:rel_deq_default`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) -* `:rel_deq_best`: Same as `:rel_deq_default` but uses the best solution found so far, i.e. 
deviates only if the solution has not converged - -#### Termination using both Absolute and Relative Tolerances - -* `:norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq reltol \\times \\| \\frac{\\partial u}{\\partial t} + u \\|`` & - ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq abstol`` -* `fallback`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. This is usable for small problems - but doesn't scale well for neural networks, and should be avoided unless absolutely necessary - -See also: [`DiscreteDEQSolver`](@ref) - -!!! note - This will be upstreamed to DiffEqSensitivity in the later releases of the package -""" -struct ContinuousDEQSolver{M,A,AT,RT,TS} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm - alg::A - abstol::AT - reltol::RT - tspan::TS -end - -function ContinuousDEQSolver(alg=VCABM4(); mode::Symbol=:rel_deq_default, abstol=1e-8, reltol=1e-8, tspan=Inf) - return ContinuousDEQSolver{Val(mode),typeof(alg),typeof(abstol),typeof(reltol),typeof(tspan)}(alg, abstol, reltol, tspan) -end - -function terminate_condition_reltol(integrator, abstol, reltol, min_t) - return all(abs.(DiffEqBase.get_du(integrator)) .<= reltol .* abs.(integrator.u)) -end - -function terminate_condition_reltol_norm(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - return norm(du) <= reltol * norm(du .+ integrator.u) -end - -function terminate_condition_abstol(integrator, abstol, reltol, min_t) - return all(abs.(DiffEqBase.get_du(integrator)) .<= abstol) -end - -function terminate_condition_abstol_norm(integrator, abstol, reltol, min_t) - return norm(DiffEqBase.get_du(integrator)) <= abstol -end - -function terminate_condition(integrator, abstol, reltol, min_t) - return all((abs.(DiffEqBase.get_du(integrator)) .<= reltol .* abs.(integrator.u)) .& - (abs.(DiffEqBase.get_du(integrator)) .<= abstol)) -end - -function terminate_condition_norm(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - du_norm = norm(du) - return (du_norm <= reltol * norm(du .+ integrator.u)) && (du_norm <= abstol) -end - -get_terminate_condition(::ContinuousDEQSolver{Val(:abs)}, args...; kwargs...) = terminate_condition_abstol -get_terminate_condition(::ContinuousDEQSolver{Val(:abs_norm)}, args...; kwargs...) = terminate_condition_abstol_norm -get_terminate_condition(::ContinuousDEQSolver{Val(:rel)}, args...; kwargs...) = terminate_condition_reltol -get_terminate_condition(::ContinuousDEQSolver{Val(:rel_norm)}, args...; kwargs...) = terminate_condition_reltol_norm -get_terminate_condition(::ContinuousDEQSolver{Val(:norm)}, args...; kwargs...) = terminate_condition_norm -get_terminate_condition(::ContinuousDEQSolver, args...; kwargs...) = terminate_condition - -# Termination conditions used in the original DEQ Paper -function get_terminate_condition(::ContinuousDEQSolver{Val(:abs_deq_default),A,T}, args...; kwargs...) 
where {A,T} - nstep = 0 - protective_threshold = T(1e6) - objective_values = T[] - function terminate_condition_closure(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - objective = norm(du) - # Main termination condition - objective <= abstol && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * abstol && - nstep >= 30 && - maximum(objective_values[(end - nstep):end]) < 1.3 * minimum(objective_values[(end - nstep):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - return terminate_condition_closure -end - -function get_terminate_condition(::ContinuousDEQSolver{Val(:rel_deq_default),A,T}, args...; kwargs...) where {A,T} - nstep = 0 - protective_threshold = T(1e3) - objective_values = T[] - function terminate_condition_closure(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - u = integrator.u - objective = norm(du) / (norm(du .+ u) + eps(T)) - # Main termination condition - objective <= reltol && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * reltol && - nstep >= 30 && - maximum(objective_values[(end - nstep + 1):end]) < 1.3 * minimum(objective_values[(end - nstep + 1):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - return terminate_condition_closure -end - -function get_terminate_condition(::ContinuousDEQSolver{Val(:rel_deq_best),A,T}, terminate_stats::Dict, args...; - kwargs...) where {A,T} - nstep = 0 - protective_threshold = T(1e3) - objective_values = T[] - - terminate_stats[:best_objective_value] = T(Inf) - terminate_stats[:best_objective_value_iteration] = 0 - - function terminate_condition_closure(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - u = integrator.u - objective = norm(du) / (norm(du .+ u) + eps(T)) - - if objective < terminate_stats[:best_objective_value] - terminate_stats[:best_objective_value] = objective - terminate_stats[:best_objective_value_iteration] = nstep + 1 - end - - # Main termination condition - objective <= reltol && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * reltol && - nstep >= 30 && - maximum(objective_values[(end - nstep + 1):end]) < 1.3 * minimum(objective_values[(end - nstep + 1):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - - return terminate_condition_closure -end - -function get_terminate_condition(::ContinuousDEQSolver{Val(:abs_deq_best),A,T}, terminate_stats::Dict, args...; - kwargs...) 
where {A,T} - nstep = 0 - protective_threshold = T(1e3) - objective_values = T[] - - terminate_stats[:best_objective_value] = T(Inf) - terminate_stats[:best_objective_value_iteration] = 0 - - function terminate_condition_closure(integrator, abstol, reltol, min_t) - du = DiffEqBase.get_du(integrator) - objective = norm(du) - - if objective < terminate_stats[:best_objective_value] - terminate_stats[:best_objective_value] = objective - terminate_stats[:best_objective_value_iteration] = nstep + 1 - end - - # Main termination condition - objective <= reltol && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * reltol && - nstep >= 30 && - maximum(objective_values[(end - nstep + 1):end]) < 1.3 * minimum(objective_values[(end - nstep + 1):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - - return terminate_condition_closure -end - -has_converged(du, u, alg::ContinuousDEQSolver) = all(abs.(du) .<= alg.abstol .& abs.(du) .<= alg.reltol .* abs.(u)) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:norm)}) = norm(du) <= alg.abstol && norm(du) <= alg.reltol * norm(du .+ u) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:rel)}) = all(abs.(du) .<= alg.reltol .* abs.(u)) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:rel_norm)}) = norm(du) <= alg.reltol * norm(du .+ u) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:rel_deq_default)}) = norm(du) <= alg.reltol * norm(du .+ u) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:rel_deq_best)}) = norm(du) <= alg.reltol * norm(du .+ u) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:abs)}) = all(abs.(du) .<= alg.abstol) -has_converged(du, u, alg::ContinuousDEQSolver{Val(:abs_norm)}) = norm(du) <= alg.abstol -has_converged(du, u, alg::ContinuousDEQSolver{Val(:abs_deq_default)}) = norm(du) <= alg.abstol -has_converged(du, u, alg::ContinuousDEQSolver{Val(:abs_deq_best)}) = norm(du) <= alg.abstol - -struct EquilibriumSolution{T,N,uType,R,P,A,TEnd} <: SciMLBase.AbstractNonlinearSolution{T,N} +struct EquilibriumSolution{T,N,uType,P,A,D} <: SciMLBase.AbstractNonlinearSolution{T,N} u::uType - resid::R + resid::uType prob::P alg::A retcode::Symbol - t::TEnd - λₜ::T + destats::D end function transform_solution(soln::EquilibriumSolution) @@ -247,8 +12,7 @@ function transform_solution(soln::EquilibriumSolution) return DiffEqBase.build_solution(soln.prob, soln.alg, soln.u, soln.resid; retcode=soln.retcode) end -function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::ContinuousDEQSolver, args...; - regularize_endpoint=false, kwargs...) +function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem{uType}, alg::ContinuousDEQSolver, args...; kwargs...) where {uType} tspan = alg.tspan isa Tuple ? alg.tspan : convert.(real(eltype(prob.u0)), (zero(alg.tspan), alg.tspan)) _prob = ODEProblem(prob.f, prob.u0, tspan, prob.p) @@ -262,86 +26,10 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::Co (sol.u[terminate_stats[:best_objective_value_iteration] + 1], sol.t[terminate_stats[:best_objective_value_iteration] + 1]) + # Dont count towards NFE since this is mostly a check for convergence du = prob.f(u, prob.p, t) - T = eltype(eltype(u)) - N = ndims(u) retcode = (sol.retcode == :Terminated && has_converged(du, u, alg) ? :Success : :Failure) - _t = regularize_endpoint isa Bool ? 
(regularize_endpoint ? t : nothing) : t - regularize_endpoint = regularize_endpoint isa Bool ? (regularize_endpoint ? T(1e-5) : T(0)) : T(regularize_endpoint) - - return EquilibriumSolution{T,N,typeof(u),typeof(du),typeof(prob),typeof(alg),typeof(_t)}(u, du, prob, alg, retcode, - _t, regularize_endpoint) -end - -function clear_zero(x::T) where T - ϵ = eps(T) - if -ϵ <= x < 0 - return -ϵ - elseif 0 <= x < ϵ - return ϵ - end - return x -end - -@noinline function DiffEqSensitivity.SteadyStateAdjointProblem(sol::EquilibriumSolution, - sensealg::DiffEqSensitivity.SteadyStateAdjoint, g, dg; - save_idxs=nothing) - @unpack f, p, u0 = sol.prob - - discrete = false - - p === DiffEqBase.NullParameters() && - error("Your model does not have parameters, and thus it is impossible to calculate the derivative of the solution with respect to the parameters. Your model must have parameters to use parameter sensitivity calculations!") - - sense = DiffEqSensitivity.SteadyStateAdjointSensitivityFunction(g, sensealg, discrete, sol, dg, f.colorvec, false) - @unpack diffcache, y, sol, λ, vjp, linsolve = sense - - _save_idxs = save_idxs === nothing ? Colon() : save_idxs - if dg !== nothing - if g !== nothing - dg(vec(diffcache.dg_val), y, p, nothing, nothing) - else - if typeof(_save_idxs) <: Number - diffcache.dg_val[_save_idxs] = dg[_save_idxs] - elseif typeof(dg) <: Number - @. diffcache.dg_val[_save_idxs] = dg - else - @. diffcache.dg_val[_save_idxs] = dg[_save_idxs] - end - end - else - if g !== nothing - DiffEqSensitivity.gradient!(vec(diffcache.dg_val), diffcache.g, y, sensealg, diffcache.g_grad_config) - end - end - - _val, back = Zygote.pullback(x -> f(x, p, nothing), y) - s_val = size(_val) - op = DiffEqSensitivity.ZygotePullbackMultiplyOperator{eltype(y),typeof(back),typeof(s_val)}(back, s_val) - - b = vec(diffcache.dg_val) - # println("Original Mean: $(mean(b)) & Residual Mean: $(mean(sol.resid)) Norm: $(norm(sol.resid))") - if sol.t !== nothing - @. b = (b + clamp(sol.λₜ ./ norm(sol.resid), -Inf, mean(b))) / 2 - end - # println("Updated mean: $(mean(b))") - linear_problem = LinearProblem(op, b) - - copyto!(vec(λ), solve(linear_problem, linsolve).u) - _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) - vjp .= -vec(back(λ)[1]) - - if g !== nothing - # compute del g/del p - dg_dp_val = zero(p) - dg_dp = DiffEqSensitivity.ParamGradientWrapper(g, nothing, y) - dg_dp_config = DiffEqSensitivity.build_grad_config(sensealg, dg_dp, p, p) - DiffEqSensitivity.gradient!(dg_dp_val, dg_dp, p, sensealg, dg_dp_config) - @. dg_dp_val = dg_dp_val + vjp - return dg_dp_val - else - return vjp - end + return EquilibriumSolution{eltype(uType),ndims(uType),uType,typeof(prob),typeof(alg),typeof(sol.destats)}(u, du, prob, alg, retcode, sol.destats) end diff --git a/src/solvers/continuous.jl b/src/solvers/continuous.jl new file mode 100644 index 00000000..a6a5dd08 --- /dev/null +++ b/src/solvers/continuous.jl @@ -0,0 +1,147 @@ +""" + ContinuousDEQSolver(alg=VCABM4(); mode::Symbol=:rel_deq_default, abstol=1f-8, reltol=1f-8, abstol_termination=1f-8, reltol_termination=1f-8, tspan=Inf32) + +Solver for Continuous DEQ Problem ([pal2022mixing](@cite)). Similar to `DynamicSS` but provides more flexibility needed +for solving DEQ problems. + +## Arguments + +* `alg`: Algorithm to solve the ODEProblem. (Default: `VCABM4()`) +* `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) +* `abstol`: Absolute tolerance for time stepping. 
(Default: `1f-8`)
+* `reltol`: Relative tolerance for time stepping. (Default: `1f-8`)
+* `abstol_termination`: Absolute tolerance for termination. (Default: `1f-8`)
+* `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`)
+* `tspan`: Time span. Users should not change this value; instead, control termination through `maxiters` in `solve`. (Default: `Inf32`)
+
+## Termination Modes
+
+#### Termination on Absolute Tolerance
+
+* `:abs`: Terminates if ``all \\left( | \\frac{\\partial u}{\\partial t} | \\leq abstol \\right)``
+* `:abs_norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq abstol``
+* `:abs_deq_default`: Essentially `:abs_norm`, but additionally terminates if there has been no improvement for the last 30 steps or if the solution blows up (diverges)
+* `:abs_deq_best`: Same as `:abs_deq_default`, but uses the best solution found so far, i.e. it deviates only if the solution has not converged
+
+#### Termination on Relative Tolerance
+
+* `:rel`: Terminates if ``all \\left(| \\frac{\\partial u}{\\partial t} | \\leq reltol \\times | u | \\right)``
+* `:rel_norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq reltol \\times \\| \\frac{\\partial u}{\\partial t} + u \\|``
+* `:rel_deq_default`: Essentially `:rel_norm`, but additionally terminates if there has been no improvement for the last 30 steps or if the solution blows up (diverges)
+* `:rel_deq_best`: Same as `:rel_deq_default`, but uses the best solution found so far, i.e. it deviates only if the solution has not converged
+
+#### Termination using both Absolute and Relative Tolerances
+
+* `:norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq reltol \\times \\| \\frac{\\partial u}{\\partial t} + u \\|`` &
+  ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq abstol``
+* `fallback`: Used for any other `mode`. Checks whether all values of the derivative are close to zero with respect to both the relative
+  and absolute tolerances. This is usable for small problems but doesn't scale well to neural networks, and should be avoided unless absolutely necessary
+
+See also: [`DiscreteDEQSolver`](@ref)
+
+!!! note
+    This will be upstreamed to DiffEqSensitivity in a later release of the package
+"""
+struct ContinuousDEQSolver{M,A,T,TS} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm
+    alg::A
+    abstol::T
+    reltol::T
+    abstol_termination::T
+    reltol_termination::T
+    tspan::TS
+end
+
+function ContinuousDEQSolver(
+    alg=VCABM4();
+    mode::Symbol=:rel_deq_default,
+    abstol::T=1.0f-8,
+    reltol::T=1.0f-8,
+    abstol_termination::T=1.0f-8,
+    reltol_termination::T=1.0f-8,
+    tspan=Inf32,
+) where {T<:Number}
+    return ContinuousDEQSolver{Val(mode),typeof(alg),T,typeof(tspan)}(
+        alg, abstol, reltol, abstol_termination, reltol_termination, tspan
+    )
+end
+
+get_mode(::Val{mode}) where {mode} = mode
+
+function get_terminate_condition(alg::ContinuousDEQSolver{M,A,T}, args...; kwargs...) where {M,A,T}
+    mode = get_mode(M)
+    if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best)
+        nstep, protective_threshold, objective_values = 0, T(1e3), T[]
+
+        if mode ∈ (:rel_deq_best, :abs_deq_best)
+            @assert length(args) == 1
+
+            args[1][:best_objective_value] = T(Inf)
+            args[1][:best_objective_value_iteration] = 0
+        end
+
+        function terminate_condition_closure_1(integrator, abstol, reltol, min_t)
+            du, u = DiffEqBase.get_du(integrator), integrator.u
+            objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : (norm(du .+ u) + eps(T)))
+            criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? abstol : reltol
+
+            if mode ∈ (:rel_deq_best, :abs_deq_best)
+                if objective < args[1][:best_objective_value]
+                    args[1][:best_objective_value] = objective
+                    args[1][:best_objective_value_iteration] = nstep + 1
+                end
+            end
+
+            # Main termination criterion
+            objective <= criteria && return true
+
+            # Terminate if there has been no improvement for the last 30 steps
+            nstep += 1
+            push!(objective_values, objective)
+
+            objective <= 3 * criteria &&
+                nstep >= 30 &&
+                maximum(objective_values[max(1, length(objective_values) - nstep):end]) <
+                1.3 * minimum(objective_values[max(1, length(objective_values) - nstep):end]) &&
+                return true
+
+            # Protective break
+            objective >= objective_values[1] * protective_threshold * length(du) && return true
+
+            return false
+        end
+        return terminate_condition_closure_1
+    else
+        function terminate_condition_closure_2(integrator, abstol, reltol, min_t)
+            return has_converged(DiffEqBase.get_du(integrator), integrator.u, alg, abstol, reltol)
+        end
+        return terminate_condition_closure_2
+    end
+end
+
+# Convergence Criteria
+function has_converged(
+    du, u, alg::ContinuousDEQSolver{M}, abstol=alg.abstol_termination, reltol=alg.reltol_termination
+) where {M}
+    mode = get_mode(M)
+    if mode == :norm
+        return norm(du) <= abstol && norm(du) <= reltol * norm(du .+ u)
+    elseif mode == :rel
+        return all(abs.(du) .<= reltol .* abs.(u))
+    elseif mode == :rel_norm
+        return norm(du) <= reltol * norm(du .+ u)
+    elseif mode == :rel_deq_default
+        return norm(du) <= reltol * norm(du .+ u)
+    elseif mode == :rel_deq_best
+        return norm(du) <= reltol * norm(du .+ u)
+    elseif mode == :abs
+        return all(abs.(du) .<= abstol)
+    elseif mode == :abs_norm
+        return norm(du) <= abstol
+    elseif mode == :abs_deq_default
+        return norm(du) <= abstol
+    elseif mode == :abs_deq_best
+        return norm(du) <= abstol
+    else
+        return all((abs.(du) .<= abstol) .& (abs.(du) .<= reltol .* abs.(u)))
+    end
+end
diff --git a/src/solvers/discrete.jl b/src/solvers/discrete.jl
new file mode 100644
index 00000000..d0657261
--- /dev/null
+++ b/src/solvers/discrete.jl
@@ -0,0 +1,27 @@
+# Wrapper for Discrete DEQs
+"""
+    DiscreteDEQSolver(solver=LimitedMemoryBroydenSolver; abstol=1e-8, reltol=1e-8, kwargs...)
+
+Solver for Discrete DEQ Problem ([baideep2019](@cite)). A wrapper around `SSRootfind` to mimic the [`ContinuousDEQSolver`](@ref) API.
+
+## Arguments
+
+* `solver`: Nonlinear solver for the DEQ problem. (Default: [`LimitedMemoryBroydenSolver`](@ref))
+* `abstol`: Absolute tolerance for termination. (Default: `1e-8`)
+* `reltol`: Relative tolerance for termination. (Default: `1e-8`)
+* `kwargs`: Additional keyword arguments passed to the solver.
+
+!!! note
+    There is no `mode` kwarg for [`DiscreteDEQSolver`](@ref). Instead, solvers directly define their own termination conditions.
+    For [`BroydenSolver`](@ref) and [`LimitedMemoryBroydenSolver`](@ref), the termination conditions are `:abs_norm` &
+    `:rel_deq_default` respectively.
+
+See also: [`ContinuousDEQSolver`](@ref)
+"""
+function DiscreteDEQSolver(solver=LimitedMemoryBroydenSolver; abstol=1e-8, reltol=1e-8, kwargs...)
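+    # A minimal usage sketch (assuming only the names documented above; the tolerance
+    # values are purely illustrative, not recommended defaults):
+    #
+    #     dsolver = DiscreteDEQSolver(LimitedMemoryBroydenSolver; abstol=1f-3, reltol=1f-3)
+    #     csolver = ContinuousDEQSolver(VCABM4(); mode=:rel_deq_default, abstol=1f-3, reltol=1f-3)
+    #
+    # Either solver can then be passed as the solver argument of a `DeepEquilibriumNetwork`.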
+ solver = solver(; kwargs..., reltol=reltol, abstol=abstol) + return SSRootfind(; nlsolve=(f, u0, abstol) -> solver(f, u0)) +end + +include("discrete/broyden.jl") +include("discrete/limited_memory_broyden.jl") diff --git a/src/solvers/broyden.jl b/src/solvers/discrete/broyden.jl similarity index 100% rename from src/solvers/broyden.jl rename to src/solvers/discrete/broyden.jl diff --git a/src/solvers/limited_memory_broyden.jl b/src/solvers/discrete/limited_memory_broyden.jl similarity index 100% rename from src/solvers/limited_memory_broyden.jl rename to src/solvers/discrete/limited_memory_broyden.jl diff --git a/src/utils.jl b/src/utils.jl index cbcbf80e..da234c7f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,19 +1,4 @@ # General DEQ Utils -mutable struct DEQTrainingStats - nfe::Int -end - -""" - get_and_clear_nfe!(model::AbstractDeepEquilibriumNetwork) - -Return the number of function evaluations (NFE) and clear the counter. -""" -function get_and_clear_nfe!(model::AbstractDeepEquilibriumNetwork) - nfe = model.stats.nfe - model.stats.nfe = 0 - return nfe -end - """ SteadyStateAdjoint(reltol, abstol, maxiters; autojacvec=ZygoteVJP(), linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters)) @@ -41,53 +26,14 @@ Initializes the weights of the network with a normal distribution. For DEQs the if we use this as the Initialization """ function NormalInitializer(μ = 0.0f0, σ² = 0.01f0) - return (dims...) -> randn(dims...) .* σ² .+ μ -end - -# Wrapper for Discrete DEQs -""" - DiscreteDEQSolver(solver=LimitedMemoryBroydenSolver; abstol=1e-8, reltol=1e-8, kwargs...) - -Solver for Discrete DEQ Problem ([baideep2019](@cite)). A wrapper around `SSrootfind` to mimic the [`ContinuousDEQSolver`](@ref) API. - -## Arguments - -* `solver`: NonLinear Solver for the DEQ problem. (Default: [`LimitedMemoryBroydenSolver`](@ref)) -* `abstol`: Absolute tolerance for termination. (Default: `1e-8`) -* `reltol`: Relative tolerance for termination. (Default: `1e-8`) -* `kwargs`: Additional keyword arguments passed to the solver. - -!!! note - There is no `mode` kwarg for [`DiscreteDEQSolver`](@ref). Instead solvers directly define their own termination condition. - For [`BroydenSolver`](@ref) and [`LimitedMemoryBroydenSolver`](@ref), the termination conditions are `:abs_norm` & - `:rel_deq_default` respectively. - -See also: [`ContinuousDEQSolver`](@ref) -""" -function DiscreteDEQSolver(solver=LimitedMemoryBroydenSolver; abstol=1e-8, reltol=1e-8, kwargs...) - solver = solver(; kwargs..., reltol=reltol, abstol=abstol) - return SSRootfind(; nlsolve=(f, u0, abstol) -> solver(f, u0)) + return (rng::AbstractRNG, dims...) -> randn(rng, dims...) 
.* σ² .+ μ end # For MultiScale DEQs -function split_array_by_indices(x::AbstractVector, idxs) - return collect((x[(i + 1):j] for (i, j) in zip(idxs[1:(end - 1)], idxs[2:end]))) -end - -function split_array_by_indices(x::AbstractMatrix, idxs) - return collect((x[(i + 1):j, :] for (i, j) in zip(idxs[1:(end - 1)], idxs[2:end]))) -end - -Zygote.@adjoint function split_array_by_indices(x, idxs) - res = split_array_by_indices(x, idxs) - function split_array_by_indices_sensitivity(Δ) - is_nothings = Δ .=== nothing - if any(is_nothings) - Δ[is_nothings] .= zero.(res[is_nothings]) - end - return (vcat(Δ...), nothing) - end - return res, split_array_by_indices_sensitivity +function split_and_reshape(x::AbstractMatrix, idxs::Tuple, shapes::Tuple) + return Tuple( + @view(x[reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...), :]) for i in 1:(length(idxs) - 1) + ) end # Zygote Fix @@ -125,13 +71,3 @@ end dims = filter(i -> i != except_dim, 1:N) return _norm(x; dims=dims) end - -flatten_merge(x, y) = (x..., y...) -flatten_merge(x::T, y::T) where {T<:AbstractArray} = (x, y) -flatten_merge(x::NTuple{N,T}, y::T) where {N,T<:AbstractArray} = (x..., y) -flatten_merge(x::T, y::NTuple{N,T}) where {N,T<:AbstractArray} = (x, y...) -flatten_merge(x::NTuple{N,T}, y) where {N,T<:AbstractArray} = (x, y...) -flatten_merge(x, y::NTuple{N,T}) where {N,T<:AbstractArray} = (x..., y) -function flatten_merge(x::NTuple{N,T}, y::NTuple{N,T}) where {N,T<:AbstractArray} - return (x, y) -end diff --git a/test/runtests.jl b/test/runtests.jl index 41b61793..31f0972a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,156 +1,186 @@ -using FastDEQ -using CUDA -using Flux -using FluxExperimental -using LinearAlgebra -using Random -using Test +using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux + +const EFL = ExplicitFluxLayers @testset "FastDEQ.jl" begin - mse_loss_function = SupervisedLossContainer(loss_function = Flux.Losses.mse) + seed = 0 @info "Testing DEQ" - Random.seed!(0) - - model = gpu(DEQChain(Dense(2, 2), - DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; bias=false), Dense(2, 2; bias=false)), - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0)))) - x = gpu(rand(Float32, 2, 1)) - y = gpu(rand(Float32, 2, 1)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + model = DEQChain( + EFL.Dense(2, 2), + DeepEquilibriumNetwork( + EFL.Parallel(+, EFL.Dense(2, 2; bias=false), EFL.Dense(2, 2; bias=false)), + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ), + ) + ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) + y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) @info "Testing SkipDEQ" - Random.seed!(0) - - model = gpu(DEQChain(Dense(2, 2), - SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), - Dense(2, 2), - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) - x = gpu(rand(Float32, 2, 1)) - y = gpu(rand(Float32, 2, 1)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end - - @info "Testing SkipDEQ V2" - Random.seed!(0) - - model = gpu(DEQChain(Dense(2, 2), - SkipDeepEquilibriumNetwork(Parallel(+, Dense(2, 2), Dense(2, 2)), - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0)))) 
- x = gpu(rand(Float32, 2, 1)) - y = gpu(rand(Float32, 2, 1)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) + model = DEQChain( + EFL.Dense(2, 2), + SkipDeepEquilibriumNetwork( + EFL.Parallel(+, EFL.Dense(2, 2), EFL.Dense(2, 2)), + EFL.Dense(2, 2), + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + ), + ) + ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) + y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end - @info "Testing Broyden Solver" - Random.seed!(0) - - model = gpu(DEQChain(Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - SkipDeepEquilibriumNetwork(Parallel(+, Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - Conv((3, 3), 1 => 1, relu; pad=1, stride=1)), - Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - DiscreteDEQSolver(BroydenSolver; abstol=0.001f0, - reltol=0.001f0, device=gpu, original_dims=(8 * 8, 1), - batch_size=4); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) - x = gpu(rand(Float32, 8, 8, 1, 4)) - y = gpu(rand(Float32, 8, 8, 1, 4)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) + @info "Testing SkipDEQ " + model = DEQChain( + EFL.Dense(2, 2), + SkipDeepEquilibriumNetwork( + EFL.Parallel(+, EFL.Dense(2, 2), EFL.Dense(2, 2)), + nothing, + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + ), + ) + ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) + y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end - @info "Testing L-Broyden Solver" - Random.seed!(0) - - model = gpu(DEQChain(Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - SkipDeepEquilibriumNetwork(Parallel(+, Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - Conv((3, 3), 1 => 1, relu; pad=1, stride=1)), - Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - DiscreteDEQSolver(LimitedMemoryBroydenSolver; abstol=0.001f0, - reltol=0.001f0, device=gpu, original_dims=(8 * 8, 1), - batch_size=4); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) - x = gpu(rand(Float32, 8, 8, 1, 4)) - y = gpu(rand(Float32, 8, 8, 1, 4)) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end - - @info "Testing MultiScaleDEQ" - Random.seed!(0) - - model = gpu(MultiScaleDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast)), - [NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer()], - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0))) - x = gpu(rand(Float32, 4, 2)) - y = tuple([gpu(rand(Float32, i, 2)) for i in 4:-1:1]...) 
- sol = model(x) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + # @info "Testing Broyden Solver" + # Random.seed!(0) - @info "Testing MultiScaleSkipDEQ" - Random.seed!(0) - - model = gpu(MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - Dense(1, 1, tanh_fast)), - [NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer()], - (Dense(4, 4, tanh_fast), Dense(4, 3, tanh_fast), - Dense(4, 2, tanh_fast), Dense(4, 1, tanh_fast)), - ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0))) - x = gpu(rand(Float32, 4, 2)) - y = tuple([gpu(rand(Float32, i, 2)) for i in 4:-1:1]...) - sol = model(x) - ps = Flux.params(model) - gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - for _p in ps - @test all(isfinite.(gs[_p])) - end + # model = gpu(DEQChain(Conv((3, 3), 1 => 1, relu; pad=1, stride=1), + # SkipDeepEquilibriumNetwork(Parallel(+, Conv((3, 3), 1 => 1, relu; pad=1, stride=1), + # Conv((3, 3), 1 => 1, relu; pad=1, stride=1)), + # Conv((3, 3), 1 => 1, relu; pad=1, stride=1), + # DiscreteDEQSolver(BroydenSolver; abstol=0.001f0, + # reltol=0.001f0, device=gpu, original_dims=(8 * 8, 1), + # batch_size=4); + # sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) + # x = gpu(rand(Float32, 8, 8, 1, 4)) + # y = gpu(rand(Float32, 8, 8, 1, 4)) + # ps = Flux.params(model) + # gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) + # for _p in ps + # @test all(isfinite.(gs[_p])) + # end - # CI gives mutation error though it works locally. - # @info "Testing MultiScaleSkipDEQV2" + # @info "Testing L-Broyden Solver" # Random.seed!(0) - # model = gpu(MultiScaleSkipDeepEquilibriumNetwork((Parallel(+, Dense(4, 4, tanh_fast), Dense(4, 4, tanh_fast)), - # Dense(3, 3, tanh_fast), Dense(2, 2, tanh_fast), - # Dense(1, 1, tanh_fast)), - # [NoOpLayer() Dense(4, 3, tanh_fast) Dense(4, 2, tanh_fast) Dense(4, 1, tanh_fast); - # Dense(3, 4, tanh_fast) NoOpLayer() Dense(3, 2, tanh_fast) Dense(3, 1, tanh_fast); - # Dense(2, 4, tanh_fast) Dense(2, 3, tanh_fast) NoOpLayer() Dense(2, 1, tanh_fast); - # Dense(1, 4, tanh_fast) Dense(1, 3, tanh_fast) Dense(1, 2, tanh_fast) NoOpLayer()], - # ContinuousDEQSolver(;abstol=0.1f0, reltol=0.1f0))) - # x = gpu(rand(Float32, 4, 2)) - # y = tuple([gpu(rand(Float32, i, 2)) for i in 4:-1:1]...) 
- # sol = model(x) + # model = gpu(DEQChain(Conv((3, 3), 1 => 1, relu; pad=1, stride=1), + # SkipDeepEquilibriumNetwork(Parallel(+, Conv((3, 3), 1 => 1, relu; pad=1, stride=1), + # Conv((3, 3), 1 => 1, relu; pad=1, stride=1)), + # Conv((3, 3), 1 => 1, relu; pad=1, stride=1), + # DiscreteDEQSolver(LimitedMemoryBroydenSolver; abstol=0.001f0, + # reltol=0.001f0, device=gpu, original_dims=(8 * 8, 1), + # batch_size=4); + # sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) + # x = gpu(rand(Float32, 8, 8, 1, 4)) + # y = gpu(rand(Float32, 8, 8, 1, 4)) # ps = Flux.params(model) # gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) # for _p in ps # @test all(isfinite.(gs[_p])) # end + + @info "Testing MultiScaleDEQ" + model = MultiScaleDeepEquilibriumNetwork( + ( + EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), + EFL.Dense(3, 3, tanh), + EFL.Dense(2, 2, tanh), + EFL.Dense(1, 1, tanh), + ), + [ + EFL.NoOpLayer() EFL.Dense(4, 3, tanh) EFL.Dense(4, 2, tanh) EFL.Dense(4, 1, tanh) + EFL.Dense(3, 4, tanh) EFL.NoOpLayer() EFL.Dense(3, 2, tanh) EFL.Dense(3, 1, tanh) + EFL.Dense(2, 4, tanh) EFL.Dense(2, 3, tanh) EFL.NoOpLayer() EFL.Dense(2, 1, tanh) + EFL.Dense(1, 4, tanh) EFL.Dense(1, 3, tanh) EFL.Dense(1, 2, tanh) EFL.NoOpLayer() + ], + nothing, + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + ) + + ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) + y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + end + + @info "Testing MultiScaleSkipDEQ" + model = MultiScaleSkipDeepEquilibriumNetwork( + ( + EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), + EFL.Dense(3, 3, tanh), + EFL.Dense(2, 2, tanh), + EFL.Dense(1, 1, tanh), + ), + [ + EFL.NoOpLayer() EFL.Dense(4, 3, tanh) EFL.Dense(4, 2, tanh) EFL.Dense(4, 1, tanh) + EFL.Dense(3, 4, tanh) EFL.NoOpLayer() EFL.Dense(3, 2, tanh) EFL.Dense(3, 1, tanh) + EFL.Dense(2, 4, tanh) EFL.Dense(2, 3, tanh) EFL.NoOpLayer() EFL.Dense(2, 1, tanh) + EFL.Dense(1, 4, tanh) EFL.Dense(1, 3, tanh) EFL.Dense(1, 2, tanh) EFL.NoOpLayer() + ], + nothing, + (EFL.Dense(4, 4, tanh), EFL.Dense(4, 3, tanh), EFL.Dense(4, 2, tanh), EFL.Dense(4, 1, tanh)), + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + ) + + ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) + y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) 
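+
+    # The objective below combines a prediction loss on the multi-scale outputs `ŷ` with the
+    # skip-DEQ consistency penalty `sum(abs2, soln.u₀ .- soln.z_star)`, which pushes the
+    # shortcut's initial guess `u₀` towards the computed fixed point `z_star`.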
+ + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + + @info "Testing MultiScaleSkipDEQ" + model = MultiScaleSkipDeepEquilibriumNetwork( + ( + EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), + EFL.Dense(3, 3, tanh), + EFL.Dense(2, 2, tanh), + EFL.Dense(1, 1, tanh), + ), + [ + EFL.NoOpLayer() EFL.Dense(4, 3, tanh) EFL.Dense(4, 2, tanh) EFL.Dense(4, 1, tanh) + EFL.Dense(3, 4, tanh) EFL.NoOpLayer() EFL.Dense(3, 2, tanh) EFL.Dense(3, 1, tanh) + EFL.Dense(2, 4, tanh) EFL.Dense(2, 3, tanh) EFL.NoOpLayer() EFL.Dense(2, 1, tanh) + EFL.Dense(1, 4, tanh) EFL.Dense(1, 3, tanh) EFL.Dense(1, 2, tanh) EFL.NoOpLayer() + ], + nothing, + nothing, + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + ) + + ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) + y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end end From a032635239f8b68bc46d6a93dce1cc584aab0aa0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Apr 2022 18:17:31 -0400 Subject: [PATCH 02/76] Some model code --- examples/core/core.jl | 9 ++ examples/core/models.jl | 324 ++++++++++++++++++++++++++++++++++++++++ src/layers/mdeq.jl | 8 +- src/operator.jl | 2 +- src/utils.jl | 2 +- test/runtests.jl | 4 +- 6 files changed, 341 insertions(+), 8 deletions(-) create mode 100644 examples/core/core.jl create mode 100644 examples/core/models.jl diff --git a/examples/core/core.jl b/examples/core/core.jl new file mode 100644 index 00000000..2dbf4106 --- /dev/null +++ b/examples/core/core.jl @@ -0,0 +1,9 @@ +module FastDEQExperiments + +using FastDEQ, ExplicitFluxLayers, Random, Flux, OrdinaryDiffEq + +const EFL = ExplicitFluxLayers + +include("models.jl") + +end \ No newline at end of file diff --git a/examples/core/models.jl b/examples/core/models.jl new file mode 100644 index 00000000..4acf3caa --- /dev/null +++ b/examples/core/models.jl @@ -0,0 +1,324 @@ +# Building Blocks +## Helpful Functional Wrappers +function conv1x1(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) + return EFL.Conv((1, 1), mapping, activation; stride=stride, pad=0, bias=bias, kwargs...) +end + +function conv3x3(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) + return EFL.Conv((3, 3), mapping, activation; stride=stride, pad=1, bias=bias, kwargs...) +end + +function conv5x5(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) + return EFL.Conv((5, 5), mapping, activation; stride=stride, pad=2, bias=bias, kwargs...) +end + +reassociate(x::NTuple{2,<:AbstractArray}, y) = (x[1], (x[2], y)) + +## Downsample Module +function downsample_module(mapping, resolution_mapping, activation; group_count=8) + in_resolution, out_resolution = resolution_mapping + in_channels, out_channels = mapping + @assert in_resolution > out_resolution + @assert ispow2(in_resolution ÷ out_resolution) + level_diff = Int(log2(in_resolution ÷ out_resolution)) + + function intermediate_mapping(i) + if in_channels * (2^level_diff) == out_channels + return (in_channels * (2^(i - 1))) => (in_channels * (2^i)) + else + return i == level_diff ? 
in_channels => out_channels : in_channels => in_channels + end + end + + layers = EFL.AbstractExplicitLayer[] + for i in 1:level_diff + inchs, outchs = intermediate_mapping(i) + push!(layers, conv3x3(inchs => outchs; stride=2, initW=NormalInitializer())) + push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + end + return EFL.Chain(layers...) +end + +## Upsample Module +function upsample_module(mapping, resolution_mapping, activation; upsample_mode::Symbol=:nearest, group_count=8) + in_resolution, out_resolution = resolution_mapping + in_channels, out_channels = mapping + @assert in_resolution < out_resolution + @assert ispow2(out_resolution ÷ in_resolution) + level_diff = Int(log2(out_resolution ÷ in_resolution)) + + function intermediate_mapping(i) + if out_channels * (2^level_diff) == in_channels + (in_channels ÷ (2^(i - 1))) => (in_channels ÷ (2^i)) + else + i == level_diff ? in_channels => out_channels : in_channels => in_channels + end + end + + layers = EFL.AbstractExplicitLayer[] + for i in 1:level_diff + inchs, outchs = intermediate_mapping(i) + push!(layers, conv1x1(inchs => outchs; initW=NormalInitializer())) + push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, EFL.Upsample(upsample_mode; scale=2)) + end + return EFL.Chain(layers...) +end + +## Residual Block +function ResidualBlockV1( + mapping; + deq_expand::Int=5, + num_gn_groups::Int=4, + downsample=EFL.NoOpLayer(), + n_big_kernels::Int=0, + dropout_rate::Real=0.0f0, + gn_affine::Bool=true, + weight_norm::Bool=true, + gn_track_stats::Bool=false, +) + inplanes, outplanes = mapping + inner_planes = outplanes * deq_expand + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; initW=NormalInitializer(), bias=true) + conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; initW=NormalInitializer(), bias=true) + + conv1, conv2 = if weight_norm + EFL.WeightNorm(conv1, (:weight,)), EFL.WeightNorm(conv2, (:weight,)) + else + conv1, conv2 + end + + gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, gelu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = EFL.GroupNorm(outplanes, num_gn_groups, gelu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + + dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate) + + return EFL.Chain( + EFL.Parallel( + reassociate, # Reassociate and Merge + EFL.Chain(conv1, gn1, conv2, EFL.BranchLayer(downsample, dropout)), # For x + EFL.NoOpLayer(), # For injection + ), + EFL.Parallel( + +, + EFL.NoOpLayer(), # For y1 + EFL.Chain( + EFL.WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar + gn2, + ), # For (y2, injection) + ), + EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + gn3, + ) +end + +function ResidualBlockV2( + mapping; + deq_expand::Int=5, + num_gn_groups::Int=4, + downsample=EFL.NoOpLayer(), + n_big_kernels::Int=0, + dropout_rate::Real=0.0f0, + gn_affine::Bool=true, + weight_norm::Bool=true, + gn_track_stats::Bool=false, +) + inplanes, outplanes = mapping + inner_planes = outplanes * deq_expand + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; initW=NormalInitializer(), bias=true) + conv2 = (n_big_kernels >= 2 ? 
conv5x5 : conv3x3)(inner_planes => outplanes; initW=NormalInitializer(), bias=true) + + conv1, conv2 = if weight_norm + EFL.WeightNorm(conv1, (:weight,)), EFL.WeightNorm(conv2, (:weight,)) + else + conv1, conv2 + end + + gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, gelu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = EFL.GroupNorm(outplanes, num_gn_groups, gelu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + + dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate) + + return EFL.Chain( + conv1, + gn1, + conv2, + EFL.BranchLayer(downsample, dropout), + EFL.Parallel(+, EFL.NoOpLayer(), gn2), + EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + gn3, + ) +end + +function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) + rescale = if first(mapping) != last(mapping) * expansion + EFL.Chain( + conv1x1(first(mapping) => last(mapping) * expansion; initW=NormalInitializer()), + EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + ) + else + EFL.NoOpLayer() + end + + return EFL.Chain( + EFL.Parallel( + reassociate, EFL.BranchLayer(rescale, conv1x1(mapping; initW=NormalInitializer())), EFL.NoOpLayer() + ), + EFL.Parallel( + +, + EFL.NoOpLayer(), + EFL.Chain( + EFL.WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar + EFL.Chain( + EFL.BatchNorm(last(mapping), gelu; affine=bn_affine, track_stats=bn_track_stats), + conv3x3(last(mapping) => last(mapping) * expansion; initW=NormalInitializer()), + EFL.BatchNorm(last(mapping) * expansion, gelu; track_stats=bn_track_stats, affine=bn_affine), + conv1x1(last(mapping) * expansion => last(mapping) * expansion; initW=NormalInitializer()), + EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + ), + ), + ), + EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + ) +end + +function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) + rescale = if first(mapping) != last(mapping) * expansion + EFL.Chain( + conv1x1(first(mapping) => last(mapping) * expansion; initW=NormalInitializer()), + EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + ) + else + EFL.NoOpLayer() + end + + return EFL.Chain( + EFL.Parallel( + +, + rescale, + EFL.Chain( + conv1x1(mapping; initW=NormalInitializer()), + EFL.BatchNorm(last(mapping), gelu; affine=bn_affine, track_stats=bn_track_stats), + conv3x3(last(mapping) => last(mapping) * expansion; initW=NormalInitializer()), + EFL.BatchNorm(last(mapping) * expansion, gelu; track_stats=bn_track_stats, affine=bn_affine), + conv1x1(last(mapping) * expansion => last(mapping) * expansion; initW=NormalInitializer()), + EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + ), + ), + EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + ) +end + +# Dataset Specific Models +## CIFAR10 -- MultiScaleDEQ +function get_model( + ::Val{:CIFAR10}; + dropout_rate, + group_count=8, + model_type::Symbol, + continuous::Bool=true, + maxiters::Int, + abstol, + reltol, +) + initial_layers = EFL.Chain( + conv3x3(3 => 24; initW=NormalInitializer()), + EFL.BatchNorm(24, gelu; track_stats=true, affine=true), + conv3x3(24 => 24; initW=NormalInitializer()), + EFL.BatchNorm(24, gelu; track_stats=true, affine=true), + ) + + main_layers = ( + ResidualBlockV1(24 => 24; dropout_rate, 
num_gn_groups=group_count), # 32 x 32 + ResidualBlockV1(24 => 24; dropout_rate, num_gn_groups=group_count), # 16 x 16 + ) + + mapping_layers = [ + EFL.NoOpLayer() downsample_module(24 => 24, 32 => 16, gelu; group_count=group_count) + upsample_module(24 => 24, 16 => 32, gelu; group_count=group_count, upsample_mode=:nearest) EFL.NoOpLayer() + ] + + post_fuse_layers = ( + EFL.Chain( + EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + conv1x1(24 => 24; initW=NormalInitializer()), + EFL.GroupNorm(24, group_count ÷ 2; affine=true, track_stats=false), + ), + EFL.Chain( + EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + conv1x1(24 => 24; initW=NormalInitializer()), + EFL.GroupNorm(24, group_count ÷ 2; affine=true, track_stats=false), + ), + ) + + final_layers = EFL.Chain( + EFL.Parallel( + +, + EFL.Chain( + BottleneckBlockV2(24 => 8), + conv3x3(8 * 4 => 16 * 4; stride=2, initW=NormalInitializer()), + EFL.BatchNorm(16 * 4, gelu; track_stats=true, affine=true), + ), + BottleneckBlockV2(24 => 16, 4), + ), + conv1x1(16 * 4 => 200; initW=NormalInitializer()), + EFL.BatchNorm(200, gelu; track_stats=true, affine=true), + EFL.GlobalMeanPool(), + EFL.FlattenLayer(), + EFL.Dense(200, 10), + ) + + solver = if continuous + ContinuousDEQSolver( + VCABM3(); + mode=:rel_deq_best, + abstol=abstol, + reltol=reltol, + abstol_termination=abstol, + reltol_termination=reltol, + ) + else + error("Discrete Solvers have not been updated yet") + end + + sensealg = SteadyStateAdjoint(abstol, reltol, min(maxiters, 15)) + + deq = if model_type ∈ (:skip, :skipv2) + shortcut = if model_type == :skip + ( + ResidualBlockV2(24 => 24; num_gn_groups=group_count), + downsample_module(24 => 24, 32 => 16, gelu; group_count=group_count), + ) + else + nothing + end + MultiScaleSkipDeepEquilibriumNetwork( + main_layers, + mapping_layers, + post_fuse_layers, + shortcut, + solver, + ((32, 32, 24), (16, 16, 24)); + maxiters=maxiters, + sensealg=sensealg, + verbose=false, + ) + elseif model_type == :vanilla + MultiScaleSkipDeepEquilibriumNetwork( + main_layers, + mapping_layers, + post_fuse_layers, + solver, + ((32, 32, 24), (16, 16, 24)); + maxiters=maxiters, + sensealg=sensealg, + verbose=false, + ) + else + throw(ArgumentError("`model_type` must be one of `[:skip, :skipv2, :vanilla]`")) + end + + return DEQChain(initial_layers, deq, final_layers) +end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index d373b303..3d32930b 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -56,9 +56,9 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::Nam prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = dudt_(sol.u, ps, nothing)::Tuple{NTuple{N,typeof(x)},typeof(st.model)} + z_star, st_ = dudt_(sol.u, ps, nothing) - residual = Zygote.@ignore dudt(sol.u, ps, nothing) ::typeof(x) + residual = Zygote.@ignore dudt(sol.u, ps, nothing) @set! st.model = st_ @@ -137,9 +137,9 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = dudt_(sol.u, ps.model, nothing)::Tuple{NTuple{N,typeof(x)},typeof(st.model)} + z_star, st_ = dudt_(sol.u, ps.model, nothing) - residual = Zygote.@ignore dudt(sol.u, ps.model, nothing) ::typeof(x) + residual = Zygote.@ignore dudt(sol.u, ps.model, nothing) @set! 
st.model = st_ diff --git a/src/operator.jl b/src/operator.jl index 797fd0cd..3e98099c 100644 --- a/src/operator.jl +++ b/src/operator.jl @@ -22,4 +22,4 @@ function Base.:*(L::ZygotePullbackMultiplyOperator, x::AbstractVector) return L.f(reshape(x, L.s))[1] end -SciMLBase.isinplace(z::ZygotePullbackMultiplyOperator, ::Int64) = false \ No newline at end of file +SciMLBase.isinplace(z::ZygotePullbackMultiplyOperator, ::Int64) = false diff --git a/src/utils.jl b/src/utils.jl index da234c7f..83336047 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -26,7 +26,7 @@ Initializes the weights of the network with a normal distribution. For DEQs the if we use this as the Initialization """ function NormalInitializer(μ = 0.0f0, σ² = 0.01f0) - return (rng::AbstractRNG, dims...) -> randn(rng, dims...) .* σ² .+ μ + return (rng::AbstractRNG, dims...) -> randn(rng, Float32, dims...) .* σ² .+ μ end # For MultiScale DEQs diff --git a/test/runtests.jl b/test/runtests.jl index 31f0972a..fb9d457d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,7 +38,7 @@ const EFL = ExplicitFluxLayers sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end - @info "Testing SkipDEQ " + @info "Testing SkipDEQV2" model = DEQChain( EFL.Dense(2, 2), SkipDeepEquilibriumNetwork( @@ -154,7 +154,7 @@ const EFL = ExplicitFluxLayers sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end - @info "Testing MultiScaleSkipDEQ" + @info "Testing MultiScaleSkipDEQV2" model = MultiScaleSkipDeepEquilibriumNetwork( ( EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), From 8ec596df822545d77d3f96d4f2ce303a9843c0d6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 01:30:49 -0400 Subject: [PATCH 03/76] Update model --- examples/core/models.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/core/models.jl b/examples/core/models.jl index 4acf3caa..c964c89e 100644 --- a/examples/core/models.jl +++ b/examples/core/models.jl @@ -79,8 +79,8 @@ function ResidualBlockV1( ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; initW=NormalInitializer(), bias=true) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; initW=NormalInitializer(), bias=true) + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; initW=NormalInitializer(), bias=false) + conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; initW=NormalInitializer(), bias=false) conv1, conv2 = if weight_norm EFL.WeightNorm(conv1, (:weight,)), EFL.WeightNorm(conv2, (:weight,)) @@ -126,8 +126,8 @@ function ResidualBlockV2( ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; initW=NormalInitializer(), bias=true) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; initW=NormalInitializer(), bias=true) + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; initW=NormalInitializer(), bias=false) + conv2 = (n_big_kernels >= 2 ? 
conv5x5 : conv3x3)(inner_planes => outplanes; initW=NormalInitializer(), bias=false) conv1, conv2 = if weight_norm EFL.WeightNorm(conv1, (:weight,)), EFL.WeightNorm(conv2, (:weight,)) From c6fd5e51a3c32423eaff029425d54e51e2af3871 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 10:34:35 -0400 Subject: [PATCH 04/76] Update examples --- examples/Manifest.toml | 1631 ++++++++++++++++++++++++++++ examples/Project.toml | 18 + examples/core/core.jl | 9 - examples/src/FastDEQExperiments.jl | 18 + examples/src/dataloaders.jl | 36 + examples/src/logging.jl | 137 +++ examples/{core => src}/models.jl | 25 +- examples/src/train.jl | 24 + 8 files changed, 1888 insertions(+), 10 deletions(-) create mode 100644 examples/Manifest.toml create mode 100644 examples/Project.toml delete mode 100644 examples/core/core.jl create mode 100644 examples/src/FastDEQExperiments.jl create mode 100644 examples/src/dataloaders.jl create mode 100644 examples/src/logging.jl rename examples/{core => src}/models.jl (93%) create mode 100644 examples/src/train.jl diff --git a/examples/Manifest.toml b/examples/Manifest.toml new file mode 100644 index 00000000..a15e4ec8 --- /dev/null +++ b/examples/Manifest.toml @@ -0,0 +1,1631 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.7.2" +manifest_format = "2.0" + +[[deps.AbstractFFTs]] +deps = ["ChainRulesCore", "LinearAlgebra"] +git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.1.0" + +[[deps.Accessors]] +deps = ["Compat", "CompositionsBase", "ConstructionBase", "Future", "LinearAlgebra", "MacroTools", "Requires", "Test"] +git-tree-sha1 = "2bba2aa45df94e95b1a9c2405d7cfc3d60281db8" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.9" + +[[deps.Adapt]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "af92965fb30777147966f58acb05da51c5616b5f" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.3.3" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.2.0" + +[[deps.ArrayInterface]] +deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] +git-tree-sha1 = "c933ce606f6535a7c7b98e1d86d5d1014f730596" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "5.0.7" + +[[deps.ArrayLayouts]] +deps = ["FillArrays", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "8b921542ad44cba67f1487e2226446597e0a90af" +uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +version = "0.8.5" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.AxisAlgorithms]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] +git-tree-sha1 = "66771c8d21c8ff5e3a93379480a2307ac36863f7" +uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950" +version = "1.0.1" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.2.0" + +[[deps.BandedMatrices]] +deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "960ad9a4b34380595500f60add129e178740c3a6" +uuid = 
"aae01518-5342-5314-be14-df237901396f" +version = "0.17.0" + +[[deps.BangBang]] +deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] +git-tree-sha1 = "b15a6bc52594f5e4a3b825858d1089618871bf9d" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.3.36" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.BinDeps]] +deps = ["Libdl", "Pkg", "SHA", "URIParser", "Unicode"] +git-tree-sha1 = "1289b57e8cf019aede076edab0587eb9644175bd" +uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" +version = "1.0.2" + +[[deps.BinaryProvider]] +deps = ["Libdl", "Logging", "SHA"] +git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.10" + +[[deps.BitTwiddlingConvenienceFunctions]] +deps = ["Static"] +git-tree-sha1 = "28bbdbf0354959db89358d1d79d421ff31ef0b5e" +uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" +version = "0.1.3" + +[[deps.BlockArrays]] +deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra"] +git-tree-sha1 = "7278f5ffec86a6c10233bf9c6be1a9c593012299" +uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" +version = "0.16.13" + +[[deps.BlockBandedMatrices]] +deps = ["ArrayLayouts", "BandedMatrices", "BlockArrays", "FillArrays", "LinearAlgebra", "MatrixFactorizations", "SparseArrays", "Statistics"] +git-tree-sha1 = "8aaea69570a48b505383210451cbf36a7237a829" +uuid = "ffab5731-97b5-5995-9138-79e8c1846df0" +version = "0.11.4" + +[[deps.BufferedStreams]] +deps = ["Compat", "Test"] +git-tree-sha1 = "5d55b9486590fdda5905c275bb21ce1f0754020f" +uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" +version = "1.0.0" + +[[deps.CEnum]] +git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.4.1" + +[[deps.CPUSummary]] +deps = ["IfElse", "Static"] +git-tree-sha1 = "48e01b22ef077b07541309652f697595f8decf25" +uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" +version = "0.1.18" + +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings"] +git-tree-sha1 = "873fb188a4b9d76549b81465b1f75c82aaf59238" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.4" + +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] +git-tree-sha1 = "ba75320aaa092b3e17c020a2d8b9e0a572dbfa6a" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "3.9.0" + +[[deps.Cassette]] +git-tree-sha1 = "063b2e77c5537a548c5bf2f44161f1d3e1ab3227" +uuid = "7057c7e9-c182-5462-911a-8362d720325c" +version = "0.3.10" + +[[deps.ChainRules]] +deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] +git-tree-sha1 = "8b887daa6af5daf705081061e36386190204ac87" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.28.1" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.14.0" + 
+[[deps.ChangesOfVariables]] +deps = ["ChainRulesCore", "LinearAlgebra", "Test"] +git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.2" + +[[deps.CloseOpenIntervals]] +deps = ["ArrayInterface", "Static"] +git-tree-sha1 = "f576084239e6bdf801007c80e27e2cc2cd963fe0" +uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" +version = "0.1.6" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.0" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.0" + +[[deps.CommonSolve]] +git-tree-sha1 = "68a0743f578349ada8bc911a5cbd5a2ef6ed6d1f" +uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +version = "0.2.0" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "96b0bc6c52df76506efc8a441c6cf1adcb1babc4" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.42.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" + +[[deps.CompositeTypes]] +git-tree-sha1 = "d5b014b216dc891e81fea299638e4c10c657b582" +uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657" +version = "0.1.2" + +[[deps.CompositionsBase]] +git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.1" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f74e9d5388b8620b4cee35d4c5a618dd4dc547f4" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.3.0" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "8ccaa8c655bc1b83d2da4d569c9b28254ababd6e" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.2" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DEDataArrays]] +deps = ["ArrayInterface", "DocStringExtensions", "LinearAlgebra", "RecursiveArrayTools", "SciMLBase", "StaticArrays"] +git-tree-sha1 = "5e5f8f363c8c9a2415ef9185c4e0ff6966c87d52" +uuid = "754358af-613d-5f8d-9788-280bf1605d4c" +version = "0.2.2" + +[[deps.DataAPI]] +git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.9.0" + +[[deps.DataDeps]] +deps = ["BinaryProvider", "HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"] +git-tree-sha1 = "4f0e41ff461d42cfc62ff0de4f1cd44c6e6b3771" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.7" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "ae02104e835f219b8930c7664b8012c93475c340" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.3.2" + 
+[[deps.DataLoaders]] +deps = ["DocStringExtensions", "LearnBase", "MLDataPattern", "Parameters", "Random", "ThreadPools"] +git-tree-sha1 = "4668e1c3fa50d9b9a91a1810b495b07008a8f6fb" +uuid = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" +version = "0.1.3" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "3daef5523dd2e769dad2365274f760ff5f282c7d" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.11" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[deps.DensityInterface]] +deps = ["InverseFunctions", "Test"] +git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" +uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +version = "0.4.0" + +[[deps.DiffEqBase]] +deps = ["ArrayInterface", "ChainRulesCore", "DEDataArrays", "DataStructures", "Distributions", "DocStringExtensions", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "IterativeSolvers", "LabelledArrays", "LinearAlgebra", "Logging", "MuladdMacro", "NonlinearSolve", "Parameters", "PreallocationTools", "Printf", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "Setfield", "SparseArrays", "StaticArrays", "Statistics", "SuiteSparse", "ZygoteRules"] +git-tree-sha1 = "d1c8d8b645500d7dffec3355d29af6c4f8bfa6df" +uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" +version = "6.82.3" + +[[deps.DiffEqCallbacks]] +deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "NLsolve", "OrdinaryDiffEq", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArrays"] +git-tree-sha1 = "c4b99e3a199e293e7290eea94ba89364d47ee557" +uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def" +version = "2.22.0" + +[[deps.DiffEqJump]] +deps = ["ArrayInterface", "Compat", "DataStructures", "DiffEqBase", "FunctionWrappers", "Graphs", "LinearAlgebra", "PoissonRandom", "Random", "RandomNumbers", "RecursiveArrayTools", "Reexport", "StaticArrays", "TreeViews", "UnPack"] +git-tree-sha1 = "eec5fd03c26dadc6b20f84d815309d060358e95b" +uuid = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12" +version = "8.3.0" + +[[deps.DiffEqNoiseProcess]] +deps = ["DiffEqBase", "Distributions", "LinearAlgebra", "Optim", "PoissonRandom", "QuadGK", "Random", "Random123", "RandomNumbers", "RecipesBase", "RecursiveArrayTools", "Requires", "ResettableStacks", "SciMLBase", "StaticArrays", "Statistics"] +git-tree-sha1 = "d6839a44a268c69ef0ed927b22a6f43c8a4c2e73" +uuid = "77a26b50-5914-5dd7-bc55-306e6241c503" +version = "5.9.0" + +[[deps.DiffEqOperators]] +deps = ["BandedMatrices", "BlockBandedMatrices", "DiffEqBase", "DomainSets", "ForwardDiff", "LazyArrays", "LazyBandedMatrices", "LinearAlgebra", "LoopVectorization", "NNlib", "NonlinearSolve", "Requires", "RuntimeGeneratedFunctions", "SciMLBase", "SparseArrays", "SparseDiffTools", "StaticArrays", "SuiteSparse"] +git-tree-sha1 = "a7a5cfe90dfa64dba88bc17a4e0b208e403885cf" +uuid = "9fdde737-9c7f-55bf-ade8-46b3f136cc48" +version = "4.42.0" + +[[deps.DiffEqSensitivity]] +deps = ["Adapt", "ArrayInterface", "Cassette", "ChainRulesCore", "DiffEqBase", "DiffEqCallbacks", "DiffEqNoiseProcess", "DiffEqOperators", 
"DiffRules", "Distributions", "EllipsisNotation", "Enzyme", "FFTW", "FiniteDiff", "ForwardDiff", "GlobalSensitivity", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Parameters", "QuadGK", "QuasiMonteCarlo", "Random", "RandomNumbers", "RecursiveArrayTools", "Reexport", "Requires", "ReverseDiff", "SciMLBase", "SharedArrays", "Statistics", "StochasticDiffEq", "Tracker", "Zygote", "ZygoteRules"] +git-tree-sha1 = "6c6ef510268d7dff2af69e3d74f6080404639d32" +uuid = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" +version = "6.72.0" + +[[deps.DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.10.0" + +[[deps.Distances]] +deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.7" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.Distributions]] +deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] +git-tree-sha1 = "5a4168170ede913a2cd679e53c2123cb4b889795" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.53" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.8.6" + +[[deps.DomainSets]] +deps = ["CompositeTypes", "IntervalSets", "LinearAlgebra", "StaticArrays", "Statistics"] +git-tree-sha1 = "5f5f0b750ac576bcf2ab1d7782959894b304923e" +uuid = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" +version = "0.5.9" + +[[deps.Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" + +[[deps.EllipsisNotation]] +deps = ["ArrayInterface"] +git-tree-sha1 = "d064b0340db45d48893e7604ec95e7a2dc9da904" +uuid = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" +version = "1.5.0" + +[[deps.Enzyme]] +deps = ["Adapt", "CEnum", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Printf", "Test"] +git-tree-sha1 = "e673706c6fedcac810b678e238c980e89b656968" +uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" +version = "0.9.3" + +[[deps.Enzyme_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "f0a858b1c8b2b103c16f01ab6074e9a83c783781" +uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef" +version = "0.0.29+0" + +[[deps.ExplicitFluxLayers]] +deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] +git-tree-sha1 = "ac921bfb0de25739c6a1bba2a70f5820cca529fd" +repo-rev = "ap/sparse" +repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" +uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" +version = "0.2.0" + +[[deps.ExponentialUtilities]] +deps = ["ArrayInterface", "LinearAlgebra", "Printf", "Requires", "SparseArrays", "libblastrampoline_jll"] +git-tree-sha1 = "b026981973ccbe38682fbb4ccb0732fd6b1e1207" +uuid = "d4d017d3-3776-5f7e-afef-a10c40355c18" +version = "1.13.0" + +[[deps.ExprTools]] 
+git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.8" + +[[deps.FFTW]] +deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] +git-tree-sha1 = "505876577b5481e50d089c1c68899dfb6faebc62" +uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +version = "1.4.6" + +[[deps.FFTW_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" +uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" +version = "3.3.10+0" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "4391d3ed58db9dc5a9883b23a0578316b4798b1f" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.0" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FastBroadcast]] +deps = ["LinearAlgebra", "Polyester", "Static"] +git-tree-sha1 = "b6bf57ec7a3f294c97ae46124705a9e6b906a209" +uuid = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +version = "0.1.15" + +[[deps.FastClosures]] +git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" +uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +version = "0.3.2" + +[[deps.FastDEQ]] +deps = ["CUDA", "ChainRulesCore", "DataLoaders", "DiffEqBase", "DiffEqCallbacks", "DiffEqSensitivity", "ExplicitFluxLayers", "Flux", "Functors", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Random", "Reexport", "Requires", "SciMLBase", "Setfield", "Statistics", "SteadyStateDiffEq", "UnPack", "Zygote"] +git-tree-sha1 = "23164863ce195e94e8e5edc643e8af6b60cc8402" +repo-rev = "ap/efl" +repo-url = ".." 
+uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" +version = "0.1.0" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "80ced645013a5dbdc52cf70329399c35ce007fae" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.13.0" + +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "129b104185df66e408edd6625d480b7f9e9823a0" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.18" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.13.2" + +[[deps.FiniteDiff]] +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays"] +git-tree-sha1 = "56956d1e4c1221000b7781104c58c34019792951" +uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" +version = "2.11.0" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.4" + +[[deps.Flux]] +deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"] +git-tree-sha1 = "e932b26ac243f312af2d9009de08b89be0e01a84" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.13.0" + +[[deps.FluxMPI]] +deps = ["CUDA", "Dates", "Flux", "Functors", "LearnBase", "MLDataUtils", "MPI", "Optimisers", "Setfield", "Zygote"] +git-tree-sha1 = "7e0a62fddea90780bfc1e29c85c2eb36fcffac8d" +repo-rev = "ap/opt" +repo-url = "https://github.com/avik-pal/FluxMPI.jl.git" +uuid = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" +version = "0.3.0" + +[[deps.FoldsThreads]] +deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] +git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" +uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" +version = "0.1.1" + +[[deps.Format]] +git-tree-sha1 = "03bcdf8ab1a5b9e6455ccb45c30910d282aa09f4" +uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" +version = "1.3.2" + +[[deps.Formatting]] +deps = ["Printf"] +git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" +uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" +version = "0.4.2" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "1bd6fc0c344fc0cbee1f42f8d2e7ec8253dda2d2" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.25" + +[[deps.FunctionWrappers]] +git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc" +uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" +version = "1.1.2" + +[[deps.Functors]] +git-tree-sha1 = "223fffa49ca0ff9ce4f875be001ffe173b2b7de4" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.2.8" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "LLVM", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] +git-tree-sha1 = "c783e8883028bf26fb05ed4022c450ef44edd875" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "8.3.2" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "556190e1e0ea3e37d83059fc9aa576f1e2104375" 
+uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.14.1" + +[[deps.GZip]] +deps = ["Libdl"] +git-tree-sha1 = "039be665faf0b8ae36e089cd694233f5dee3f7d6" +uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +version = "0.5.1" + +[[deps.Glob]] +git-tree-sha1 = "4df9f7e06108728ebf00a0a11edee4b29a482bb2" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.0" + +[[deps.GlobalSensitivity]] +deps = ["Distributions", "FFTW", "ForwardDiff", "KernelDensity", "LinearAlgebra", "Parameters", "QuasiMonteCarlo", "Random", "RecursiveArrayTools", "Statistics", "StatsBase", "Trapz"] +git-tree-sha1 = "0324e96625317e8f1cd51196be542de18788e3af" +uuid = "af5da776-676b-467e-8baf-acd8249e4f0f" +version = "1.3.2" + +[[deps.Graphs]] +deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = "57c021de207e234108a6f1454003120a1bf350c4" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.6.0" + +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires"] +git-tree-sha1 = "cdd249512de03cbf8370365a0a08b9a24955dca9" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.16.6" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"] +git-tree-sha1 = "bab67c0d1c4662d2c4be8c6007751b0b6111de5c" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.12.1+0" + +[[deps.HTTP]] +deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"] +git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "0.9.17" + +[[deps.HostCPUFeatures]] +deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"] +git-tree-sha1 = "18be5268cf415b5e27f34980ed25a7d34261aa83" +uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" +version = "0.1.7" + +[[deps.Hwloc]] +deps = ["Hwloc_jll"] +git-tree-sha1 = "92d99146066c5c6888d5a3abc871e6a214388b91" +uuid = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +version = "2.0.0" + +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "303d70c961317c4c20fafaf5dbe0e6d610c38542" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.7.1+0" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "7f43342f8d5fd30ead0ba1b49ab1a3af3b787d24" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.5" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.Inflate]] +git-tree-sha1 = "f5fc07d4e706b84f72d54eedcc1c13d92fb0871c" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" +version = "0.1.2" + +[[deps.IniFile]] +git-tree-sha1 = "f550e6e32074c939295eb5ea6de31849ac2c9625" +uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" +version = "0.5.1" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +deps = ["Parsers"] +git-tree-sha1 = "61feba885fac3a407465726d0c330b3055df897f" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.1.2" + +[[deps.IntelOpenMP_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c" +uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" +version = "2018.0.3+2" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = 
"b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + +[[deps.Interpolations]] +deps = ["AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] +git-tree-sha1 = "b15fc0a95c564ca2e0a7ae12c1f095ca848ceb31" +uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" +version = "0.13.5" + +[[deps.IntervalSets]] +deps = ["Dates", "EllipsisNotation", "Statistics"] +git-tree-sha1 = "bcf640979ee55b652f3b01650444eb7bbe3ea837" +uuid = "8197267c-284f-5f27-9208-e0e47529a953" +version = "0.5.4" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "91b5dcf362c5add98049e6c29ee756910b03051d" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.3" + +[[deps.InvertedIndices]] +git-tree-sha1 = "bee5f1ef5bf65df56bdd2e40447590b272a5471f" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.1.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.1.1" + +[[deps.IterativeSolvers]] +deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] +git-tree-sha1 = "1169632f425f79429f245113b775a0e3d121457c" +uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" +version = "0.9.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "Printf", "Reexport", "TranscodingStreams", "UUIDs"] +git-tree-sha1 = "81b9477b49402b47fbe7f7ae0b252077f53e4a08" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.4.22" + +[[deps.JLLWrappers]] +deps = ["Preferences"] +git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.4.1" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "StructTypes", "UUIDs"] +git-tree-sha1 = "8c1f668b24d999fb47baf80436194fdccec65ad2" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.9.4" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.KLU]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse_jll"] +git-tree-sha1 = "cae5e3dfd89b209e01bcd65b3a25e74462c67ee0" +uuid = "ef3ab10e-7fda-4108-b977-705223b18434" +version = "0.3.0" + +[[deps.KernelDensity]] +deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] +git-tree-sha1 = "591e8dc09ad18386189610acafb970032c519707" +uuid = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" +version = "0.6.3" + +[[deps.Krylov]] +deps = ["LinearAlgebra", "Printf", "SparseArrays"] +git-tree-sha1 = "82f5afb342a5624dc4651981584a841f6088166b" +uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" +version = "0.8.0" + +[[deps.KrylovKit]] +deps = ["LinearAlgebra", "Printf"] +git-tree-sha1 = "0328ad9966ae29ccefb4e1b9bfd8c8867e4360df" +uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +version = "0.5.3" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] +git-tree-sha1 = "c9b86064be5ae0f63e50816a5a90b08c474507ae" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "4.9.1" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", 
"JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "5558ad3c8972d602451efe9d81c78ec14ef4f5ef" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.14+2" + +[[deps.LabelledArrays]] +deps = ["ArrayInterface", "ChainRulesCore", "LinearAlgebra", "MacroTools", "StaticArrays"] +git-tree-sha1 = "fbd884a02f8bf98fd90c53c1c9d2b21f9f30f42a" +uuid = "2ee39098-c373-598a-b85f-a56591580800" +version = "1.8.0" + +[[deps.LatinHypercubeSampling]] +deps = ["Random", "StableRNGs", "StatsBase", "Test"] +git-tree-sha1 = "42938ab65e9ed3c3029a8d2c58382ca75bdab243" +uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" +version = "1.8.0" + +[[deps.LatticeRules]] +deps = ["Random"] +git-tree-sha1 = "7f5b02258a3ca0221a6a9710b0a0a2e8fb4957fe" +uuid = "73f95e8e-ec14-4e6a-8b18-0d2e271c4e55" +version = "0.0.1" + +[[deps.LayoutPointers]] +deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static"] +git-tree-sha1 = "b651f573812d6c36c22c944dd66ef3ab2283dfa1" +uuid = "10f19ff3-798f-405d-979b-55457f8fc047" +version = "0.1.6" + +[[deps.LazyArrays]] +deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "MacroTools", "MatrixFactorizations", "SparseArrays", "StaticArrays"] +git-tree-sha1 = "721bebe4d0f8581c18fccf272c62000e22a80a2d" +uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02" +version = "0.22.10" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LazyBandedMatrices]] +deps = ["ArrayLayouts", "BandedMatrices", "BlockArrays", "BlockBandedMatrices", "FillArrays", "LazyArrays", "LinearAlgebra", "MatrixFactorizations", "SparseArrays", "StaticArrays"] +git-tree-sha1 = "b1708e45e6b4308593904a14d0e5b0970d9ed0bb" +uuid = "d7e5e226-e90b-4449-9968-0f923699bf6f" +version = "0.7.12" + +[[deps.LearnBase]] +deps = ["LinearAlgebra", "StatsBase"] +git-tree-sha1 = "47e6f4623c1db88570c7a7fa66c6528b92ba4725" +uuid = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" +version = "0.3.0" + +[[deps.LevyArea]] +deps = ["LinearAlgebra", "Random", "SpecialFunctions"] +git-tree-sha1 = "56513a09b8e0ae6485f34401ea9e2f31357958ec" +uuid = "2d8b4e74-eb68-11e8-0fb9-d5eb67b50637" +version = "1.0.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "42b62845d70a619f063a7da093d995ec8e15e778" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.16.1+1" + +[[deps.LineSearches]] +deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] +git-tree-sha1 = "f27132e551e959b3667d8c93eae90973225032dd" +uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +version = "7.1.1" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LinearSolve]] +deps = ["ArrayInterface", "DocStringExtensions", "IterativeSolvers", "KLU", "Krylov", "KrylovKit", "LinearAlgebra", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "Setfield", "SparseArrays", "SuiteSparse", "UnPack"] +git-tree-sha1 = 
"6eb8e10ed29b85673495c29bd77ee0dfa8929977" +uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +version = "1.15.0" + +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "58f25e56b706f95125dcb796f39e1fb01d913a71" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.10" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoopVectorization]] +deps = ["ArrayInterface", "CPUSummary", "ChainRulesCore", "CloseOpenIntervals", "DocStringExtensions", "ForwardDiff", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "SIMDDualNumbers", "SLEEFPirates", "SpecialFunctions", "Static", "ThreadingUtilities", "UnPack", "VectorizationBase"] +git-tree-sha1 = "f9d84dcb46419e973872b32c051e5baad2d29de7" +uuid = "bdcacae8-1622-11e9-2a5c-532679323890" +version = "0.12.107" + +[[deps.MAT]] +deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] +git-tree-sha1 = "971be550166fe3f604d28715302b58a3f7293160" +uuid = "23992714-dd62-5051-b70f-ba57cb901cac" +version = "0.10.3" + +[[deps.MKL_jll]] +deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "e595b205efd49508358f7dc670a940c790204629" +uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +version = "2022.0.0+0" + +[[deps.MLDataPattern]] +deps = ["LearnBase", "MLLabelUtils", "Random", "SparseArrays", "StatsBase"] +git-tree-sha1 = "e99514e96e8b8129bb333c69e063a56ab6402b5b" +uuid = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" +version = "0.5.4" + +[[deps.MLDataUtils]] +deps = ["DataFrames", "DelimitedFiles", "LearnBase", "MLDataPattern", "MLLabelUtils", "Statistics", "StatsBase"] +git-tree-sha1 = "ee54803aea12b9c8ee972e78ece11ac6023715e6" +uuid = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" +version = "0.5.4" + +[[deps.MLDatasets]] +deps = ["BinDeps", "CSV", "ColorTypes", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "JLD2", "JSON3", "MAT", "MLUtils", "Pickle", "Requires", "SparseArrays", "Tables"] +git-tree-sha1 = "862c3a31a5a6dfc68e78e2e1634dac1d3b0f654e" +uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" +version = "0.5.16" + +[[deps.MLLabelUtils]] +deps = ["LearnBase", "MappedArrays", "StatsBase"] +git-tree-sha1 = "fd75d4b0c4016e047bbb6263eecf7ae3891af522" +uuid = "66a33bbf-0c2b-5fc8-a008-9da813334f0a" +version = "0.5.7" + +[[deps.MLStyle]] +git-tree-sha1 = "594e189325f66e23a8818e5beb11c43bb0141bcd" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.10" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"] +git-tree-sha1 = "32eeb46fa393ae36a4127c9442ade478c8d01117" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.2.3" + +[[deps.MPI]] +deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "Pkg", "Random", "Requires", "Serialization", "Sockets"] +git-tree-sha1 = "d56a80d8cf8b9dc3050116346b3d83432b1912c0" +uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195" +version = "0.19.2" + +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "8eed51eb836c8f47781cdb493ffd5f56370c0496" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.0.1+0" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" +uuid = 
"1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.9" + +[[deps.ManualMemory]] +git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" +uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" +version = "0.1.8" + +[[deps.MappedArrays]] +git-tree-sha1 = "e8b359ef06ec72e8c030463fe02efe5527ee5142" +uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" +version = "0.4.1" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MatrixFactorizations]] +deps = ["ArrayLayouts", "LinearAlgebra", "Printf", "Random"] +git-tree-sha1 = "2212d36f97e01347adb1460a6914e20f2feee853" +uuid = "a3b82374-2e81-5b9e-98ce-41277c0e4c87" +version = "0.9.1" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "Random", "Sockets"] +git-tree-sha1 = "1c38e51c3d08ef2278062ebceade0e46cefc96fe" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.0.3" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + +[[deps.MicroCollections]] +deps = ["BangBang", "InitialValues", "Setfield"] +git-tree-sha1 = "6bb7786e4f24d44b4e29df03c69add1b63d88f01" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.1.2" + +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "a16aa086d335ed7e0170c5265247db29172af2f9" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.3+2" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.0.2" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[deps.MuladdMacro]] +git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" +uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +version = "0.2.2" + +[[deps.NLSolversBase]] +deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] +git-tree-sha1 = "50310f934e55e5ca3912fb941dec199b49ca9b68" +uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" +version = "7.8.2" + +[[deps.NLsolve]] +deps = ["Distances", "LineSearches", "LinearAlgebra", "NLSolversBase", "Printf", "Reexport"] +git-tree-sha1 = "019f12e9a1a7880459d0173c182e6a99365d7ac1" +uuid = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" +version = "4.5.1" + +[[deps.NNlib]] +deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] +git-tree-sha1 = "a59a614b8b4ea6dc1dcec8c6514e251f13ccbe10" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.8.4" + +[[deps.NNlibCUDA]] +deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "0d18b4c80a92a00d3d96e8f9677511a7422a946e" +uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" +version = "0.2.2" + +[[deps.NaNMath]] +git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.7" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + +[[deps.NonlinearSolve]] +deps = ["ArrayInterface", "FiniteDiff", "ForwardDiff", "IterativeSolvers", "LinearAlgebra", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "SciMLBase", "Setfield", "StaticArrays", "UnPack"] +git-tree-sha1 = "aeebff6a2a23506e5029fd2248a26aca98e477b3" +uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +version = "0.3.16" + +[[deps.ObjectFile]] +deps = 
["Reexport", "StructIO"] +git-tree-sha1 = "55ce61d43409b1fb0279d1781bf3b0f22c83ab3b" +uuid = "d8793406-e978-5875-9003-1fc021f44a92" +version = "0.3.7" + +[[deps.Octavian]] +deps = ["ArrayInterface", "CPUSummary", "IfElse", "LoopVectorization", "ManualMemory", "PolyesterWeave", "Requires", "Static", "ThreadingUtilities", "VectorizationBase"] +git-tree-sha1 = "26c004c96dc634cefe9174cb9180c496f6c7e100" +uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" +version = "0.3.13" + +[[deps.OffsetArrays]] +deps = ["Adapt"] +git-tree-sha1 = "043017e0bdeff61cfbb7afeb558ab29536bbb5ed" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.10.8" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" + +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "6340586e076b2abd41f5ba1a3b9c774ec6b30fde" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "4.1.2+0" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ab05aa4cc89736e95915b01e7279e61b1bfe33b8" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "1.1.14+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optim]] +deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] +git-tree-sha1 = "bc0a748740e8bc5eeb9ea6031e6f050de1fc0ba2" +uuid = "429524aa-4258-5aef-a3af-852621145aeb" +version = "1.6.2" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "e440ecef249dea69e79248857e800e71820d386c" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.2.1" + +[[deps.OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[deps.OrdinaryDiffEq]] +deps = ["Adapt", "ArrayInterface", "DataStructures", "DiffEqBase", "DocStringExtensions", "ExponentialUtilities", "FastClosures", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "LinearSolve", "Logging", "LoopVectorization", "MacroTools", "MuladdMacro", "NLsolve", "NonlinearSolve", "Polyester", "PreallocationTools", "RecursiveArrayTools", "Reexport", "SciMLBase", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"] +git-tree-sha1 = "c5568ed45ee56cb4a5e3cebff3b91541ae016a83" +uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +version = "6.9.0" + +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "e8185b83b9fc56eb6456200e873ce598ebc7f262" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.7" + +[[deps.Parameters]] +deps = ["OrderedCollections", "UnPack"] +git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.3" + +[[deps.Parsers]] +deps = ["Dates"] +git-tree-sha1 = "621f4f3b4977325b9128d5fae7a8b4829a0c2222" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.2.4" + +[[deps.Pickle]] +deps = ["DataStructures", "InternedStrings", "Serialization", "SparseArrays", "Strided", 
"StringEncodings", "ZipFile"] +git-tree-sha1 = "de8165bc4d1c448824cefa98cd5cd281dc01d9b2" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.0" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[deps.PoissonRandom]] +deps = ["Random", "Statistics", "Test"] +git-tree-sha1 = "44d018211a56626288b5d3f8c6497d28c26dc850" +uuid = "e409e4f3-bfea-5376-8464-e040bb5c01ab" +version = "0.4.0" + +[[deps.Polyester]] +deps = ["ArrayInterface", "BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "ManualMemory", "PolyesterWeave", "Requires", "Static", "StrideArraysCore", "ThreadingUtilities"] +git-tree-sha1 = "8d95a735921204f5d551ac300b20d802a150433a" +uuid = "f517fe37-dbe3-4b94-8317-1923a5111588" +version = "0.6.8" + +[[deps.PolyesterWeave]] +deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] +git-tree-sha1 = "7e597df97e46ffb1c8adbaddfa56908a7a20194b" +uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" +version = "0.1.5" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "28ef6c7ce353f0b35d0df0d5930e0d072c1f5b9b" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.1" + +[[deps.PositiveFactorizations]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "17275485f373e6673f7e7f97051f703ed5b15b20" +uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" +version = "0.2.4" + +[[deps.PreallocationTools]] +deps = ["Adapt", "ArrayInterface", "ForwardDiff", "LabelledArrays"] +git-tree-sha1 = "6c138c8510111fa47b5d2ed8ada482d97e279bee" +uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" +version = "0.2.4" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "d3538e7f8a790dc8903519090857ef8e1283eecd" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.2.5" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"] +git-tree-sha1 = "dfb54c4e414caa595a1f2ed759b160f5a3ddcba5" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "1.3.1" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "78aadffb3efd2155af139781b8a8df1ef279ea39" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.4.2" + +[[deps.QuasiMonteCarlo]] +deps = ["Distributions", "LatinHypercubeSampling", "LatticeRules", "Sobol"] +git-tree-sha1 = "bc69c718a83951dcb999404ff267a7b8c39c1c63" +uuid = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" +version = "0.2.4" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "afeacaecf4ed1649555a19cb2cad3c141bbc9474" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.5.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = 
"e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.Ratios]] +deps = ["Requires"] +git-tree-sha1 = "dc84268fe0e3335a62e315a3a7cf2afa7178a734" +uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" +version = "0.4.3" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.RecipesBase]] +git-tree-sha1 = "6bf3f380ff52ce0832ddd3a2a7b9538ed1bcca7d" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.2.1" + +[[deps.RecursiveArrayTools]] +deps = ["Adapt", "ArrayInterface", "ChainRulesCore", "DocStringExtensions", "FillArrays", "LinearAlgebra", "RecipesBase", "Requires", "StaticArrays", "Statistics", "ZygoteRules"] +git-tree-sha1 = "bfe14f127f3e7def02a6c2b1940b39d0dabaa3ef" +uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" +version = "2.26.3" + +[[deps.RecursiveFactorization]] +deps = ["LinearAlgebra", "LoopVectorization", "Polyester", "StrideArraysCore", "TriangularSolve"] +git-tree-sha1 = "7ad4c2ef15b7aecd767b3921c0d255d39b3603ea" +uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" +version = "0.2.9" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.ResettableStacks]] +deps = ["StaticArrays"] +git-tree-sha1 = "256eeeec186fa7f26f2801732774ccf277f05db9" +uuid = "ae5879a3-cd67-5da8-be7f-38c6eb64a37b" +version = "1.1.1" + +[[deps.ReverseDiff]] +deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] +git-tree-sha1 = "8d85c98fc33d4d37d88c8f9ccee4f1f3f98e56f4" +uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +version = "1.12.0" + +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.0" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.3.0+0" + +[[deps.RuntimeGeneratedFunctions]] +deps = ["ExprTools", "SHA", "Serialization"] +git-tree-sha1 = "cdc1e4278e91a6ad530770ebb327f9ed83cf10c4" +uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +version = "0.5.3" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[deps.SIMDDualNumbers]] +deps = ["ForwardDiff", "IfElse", "SLEEFPirates", "VectorizationBase"] +git-tree-sha1 = "62c2da6eb66de8bb88081d20528647140d4daa0e" +uuid = "3cdde19b-5bb0-4aaf-8931-af3e248e098b" +version = "0.1.0" + +[[deps.SIMDTypes]] +git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" +uuid = "94e857df-77ce-4151-89e5-788b33177be4" +version = "0.1.0" + +[[deps.SLEEFPirates]] +deps = ["IfElse", "Static", "VectorizationBase"] +git-tree-sha1 = "d4c366b135fc2e1af7a000473e08edc5afd94819" +uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" +version = "0.6.31" + +[[deps.SciMLBase]] +deps = ["ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "RecipesBase", "RecursiveArrayTools", "StaticArrays", "Statistics", "Tables", "TreeViews"] +git-tree-sha1 = 
"61159e034c4cb36b76ad2926bb5bf8c28cc2fb12" +uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +version = "1.29.0" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "6a2f7d70512d205ca8c7ee31bfa9f142fe74310c" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.3.12" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "Requires"] +git-tree-sha1 = "38d88503f695eb0301479bc9b0d4320b378bafe5" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "0.8.2" + +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sobol]] +deps = ["DelimitedFiles", "Random"] +git-tree-sha1 = "5a74ac22a9daef23705f010f72c81d6925b19df8" +uuid = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4" +version = "1.5.0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.0.1" + +[[deps.SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.SparseDiffTools]] +deps = ["Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays", "VertexSafeGraphs"] +git-tree-sha1 = "314a07e191ea4a5ea5a2f9d6b39f03833bde5e08" +uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +version = "1.21.0" + +[[deps.SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "5ba658aeecaaf96923dce0da9e703bd1fe7666f9" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.1.4" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "39c9f91521de844bad65049efd4f9223e7ed43f9" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.14" + +[[deps.StableRNGs]] +deps = ["Random", "Test"] +git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" +uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" +version = "1.0.0" + +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "87e9954dfa33fd145694e42337bdd3d5b07021a6" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.6.0" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "4f6ec5d99a28e1a749559ef7dd518663c5eca3d5" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "c3d8ba7f3fa0625b062b82853a7d5229cb728b6b" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.2.1" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.33.16" + +[[deps.StatsFuns]] +deps = 
["ChainRulesCore", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "5950925ff997ed6fb3e985dcce8eb1ba42a0bbe7" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.9.18" + +[[deps.SteadyStateDiffEq]] +deps = ["DiffEqBase", "DiffEqCallbacks", "LinearAlgebra", "NLsolve", "Reexport", "SciMLBase"] +git-tree-sha1 = "3e057e1f9f12d18cac32011aed9e61eef6c1c0ce" +uuid = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" +version = "1.6.6" + +[[deps.StochasticDiffEq]] +deps = ["Adapt", "ArrayInterface", "DataStructures", "DiffEqBase", "DiffEqJump", "DiffEqNoiseProcess", "DocStringExtensions", "FillArrays", "FiniteDiff", "ForwardDiff", "LevyArea", "LinearAlgebra", "Logging", "MuladdMacro", "NLsolve", "OrdinaryDiffEq", "Random", "RandomNumbers", "RecursiveArrayTools", "Reexport", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"] +git-tree-sha1 = "4d428684218ac7a3dc54aaeb3f76e03bf892c33c" +uuid = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" +version = "6.46.0" + +[[deps.StrideArraysCore]] +deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "ManualMemory", "Requires", "SIMDTypes", "Static", "ThreadingUtilities"] +git-tree-sha1 = "c7e0392560f15771003cce388fe8471d17941374" +uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" +version = "0.2.19" + +[[deps.Strided]] +deps = ["LinearAlgebra", "TupleTools"] +git-tree-sha1 = "4d581938087ca90eab9bd4bb6d270edaefd70dcd" +uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" +version = "1.1.2" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "50ccd5ddb00d19392577902f0079267a72c5ab04" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.5" + +[[deps.StructIO]] +deps = ["Test"] +git-tree-sha1 = "010dc73c7146869c042b49adcdb6bf528c12e859" +uuid = "53d494c1-5632-5724-8f4c-31dff12d585f" +version = "0.3.0" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "d24a825a95a6d98c385001212dc9020d609f2d4f" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.8.1" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] +git-tree-sha1 = "5ce79ce186cc678bbb5c5681ca3379d1ddae11a1" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.7.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.ThreadPools]] +deps = ["Printf", "RecipesBase", "Statistics"] +git-tree-sha1 = "705ccc29d575b87cceb359dfea19f4653d06df8f" +uuid = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431" +version = "1.2.1" + +[[deps.ThreadingUtilities]] +deps = ["ManualMemory"] +git-tree-sha1 = "f8629df51cab659d70d2e5618a430b4d3f37f2c3" +uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" +version = "0.5.0" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = 
"d60b0c96a16aaa42138d5d38ad386df672cb8bd8" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.16" + +[[deps.Tracker]] +deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] +git-tree-sha1 = "0874c1b5de1b5529b776cfeca3ec0acfada97b1b" +uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +version = "0.2.20" + +[[deps.TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.6" + +[[deps.Transducers]] +deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] +git-tree-sha1 = "c76399a3bbe6f5a88faa33c8f8a65aa631d95013" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.73" + +[[deps.Trapz]] +git-tree-sha1 = "79eb0ed763084a3e7de81fe1838379ac6a23b6a0" +uuid = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1" +version = "2.0.3" + +[[deps.TreeViews]] +deps = ["Test"] +git-tree-sha1 = "8d0d7a3fe2f30d6a7f833a5f19f7c7a5b396eae6" +uuid = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7" +version = "0.3.0" + +[[deps.TriangularSolve]] +deps = ["CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "LoopVectorization", "Polyester", "Static", "VectorizationBase"] +git-tree-sha1 = "b8d08f55b02625770c09615d96927b3a8396925e" +uuid = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf" +version = "0.1.11" + +[[deps.TupleTools]] +git-tree-sha1 = "3c712976c47707ff893cf6ba4354aa14db1d8938" +uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +version = "1.3.0" + +[[deps.URIParser]] +deps = ["Unicode"] +git-tree-sha1 = "53a9f49546b8d2dd2e688d216421d050c9a31d0d" +uuid = "30578b45-9adc-5946-b283-645ec420af67" +version = "0.4.1" + +[[deps.URIs]] +git-tree-sha1 = "97bbe755a53fe859669cd907f2d96aee8d2c1355" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.3.0" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.VectorizationBase]] +deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "Hwloc", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static"] +git-tree-sha1 = "9d1b533f597d87ce9b4abd36a2ce4664f08e08ed" +uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" +version = "0.21.29" + +[[deps.VertexSafeGraphs]] +deps = ["Graphs"] +git-tree-sha1 = "8351f8d73d7e880bfc042a8b6922684ebeafb35c" +uuid = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f" +version = "0.2.0" + +[[deps.WeakRefStrings]] +deps = ["DataAPI", "InlineStrings", "Parsers"] +git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" +uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" +version = "1.4.2" + +[[deps.WoodburyMatrices]] +deps = ["LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "de67fa59e33ad156a590055375a30b23c40299d3" +uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" +version = "0.5.5" + +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.9.4" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[deps.Zygote]] +deps = ["AbstractFFTs", 
"ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "52adc0a505b6421a8668f13dcdb0c4cb498bd72c" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.37" + +[[deps.ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.2" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 00000000..caa8b0b6 --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,18 @@ +name = "FastDEQExperiments" +uuid = "5aa64bb0-ce80-4310-96b1-36313c344f92" +authors = ["Avik Pal "] +version = "0.1.0" + +[deps] +DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" +ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" +FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" +Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" +LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" +MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" +MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/examples/core/core.jl b/examples/core/core.jl deleted file mode 100644 index 2dbf4106..00000000 --- a/examples/core/core.jl +++ /dev/null @@ -1,9 +0,0 @@ -module FastDEQExperiments - -using FastDEQ, ExplicitFluxLayers, Random, Flux, OrdinaryDiffEq - -const EFL = ExplicitFluxLayers - -include("models.jl") - -end \ No newline at end of file diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl new file mode 100644 index 00000000..dd45dfb5 --- /dev/null +++ b/examples/src/FastDEQExperiments.jl @@ -0,0 +1,18 @@ +module FastDEQExperiments + +using FastDEQ, ExplicitFluxLayers, Random, Flux, OrdinaryDiffEq, FluxMPI, Format, MLDatasets, MLDataUtils, DataLoaders, Optimisers +import LearnBase: ObsDim +import MLDataUtils: nobs, getobs + +const EFL = ExplicitFluxLayers + +# get_model +include("models.jl") +# PrettyTableLogger +include("logging.jl") +# get_dataloaders +include("dataloaders.jl") + +include("train.jl") + +end \ No newline at end of file diff --git a/examples/src/dataloaders.jl b/examples/src/dataloaders.jl new file mode 100644 index 00000000..1e1f5b32 --- /dev/null +++ b/examples/src/dataloaders.jl @@ -0,0 +1,36 @@ +struct MLDatasetsImageData + images + labels +end + +MLDatasetsImageData(images::AbstractArray{T,4}, labels::AbstractArray{T,2}) where {T} = + MLDatasetsImageData(collect(eachslice(images, dims=4)), collect(eachslice(labels, dims=2))) + +nobs(d::MLDatasetsImageData) = length(d.images) + +getobs(d::MLDatasetsImageData, i::Int, ::ObsDim.Undefined) = (d.images[i], d.labels[i]) +getobs(d::MLDatasetsImageData, i::Int) = (d.images[i], d.labels[i]) + +function get_dataloaders( + dataset::Symbol; μ=nothing, σ²=nothing, distributed=false, train_batchsize::Int64, 
test_batchsize::Int64
+)
+    (x_train, y_train), (x_test, y_test), μ, σ², nclasses = if dataset == :CIFAR10
+        μ = μ === nothing ? reshape([0.4914, 0.4822, 0.4465], 1, 1, :, 1) : μ
+        σ² = σ² === nothing ? reshape([0.2023, 0.1994, 0.2010], 1, 1, :, 1) : σ²
+        CIFAR10.traindata(Float32), CIFAR10.testdata(Float32), μ, σ², 10
+    else
+        throw(ArgumentError("Not yet implemented for $dataset"))
+    end
+
+    x_train = (x_train .- μ) ./ σ²
+    y_train = Float32.(Flux.onehotbatch(y_train, 0:(nclasses - 1)))
+    x_test = (x_test .- μ) ./ σ²
+    y_test = Float32.(Flux.onehotbatch(y_test, 0:(nclasses - 1)))
+
+    train_dataset = shuffleobs(MLDatasetsImageData(x_train, y_train))
+    train_dataset = distributed ? DistributedDataContainer(train_dataset) : train_dataset
+    test_dataset = MLDatasetsImageData(x_test, y_test)
+    test_dataset = distributed ? DistributedDataContainer(test_dataset) : test_dataset
+
+    return (DataLoader(train_dataset, train_batchsize), DataLoader(test_dataset, test_batchsize))
+end
diff --git a/examples/src/logging.jl b/examples/src/logging.jl
new file mode 100644
index 00000000..ddca1d12
--- /dev/null
+++ b/examples/src/logging.jl
@@ -0,0 +1,137 @@
+
+function _should_log(; logging_rank=0, comm=MPI.COMM_WORLD)
+    FluxMPI.Initialized() || return true # Not using MPI
+    return local_rank() == logging_rank
+end
+
+# Running AverageMeter
+mutable struct AverageMeter{T}
+    last_value::T
+    sum::T
+    count::Int
+
+    AverageMeter(T=Float32) = new{T}(T(0), T(0), 0)
+end
+
+function reset!(am::AverageMeter{T}) where {T}
+    val = am()
+    am.last_value = T(0)
+    am.sum = T(0)
+    am.count = 0
+    return val
+end
+
+function update!(am::AverageMeter{T}, val::T) where {T}
+    am.last_value = val
+    am.sum += val
+    am.count += 1
+    return am.sum / am.count
+end
+
+update!(am::AverageMeter{T}, val) where {T} = update!(am, T(val))
+
+(am::AverageMeter)() = am.sum / am.count
+
+# Simple Table Logger
+struct PrettyTableLogger{N,AM,F,R,FIO}
+    header::NTuple{N,String}
+    average_meters::AM
+    span::Int
+    fmtrfuncs::F
+    records::R
+    fio::FIO
+
+    function PrettyTableLogger(filename::String, header, record=[])
+        fio = _should_log() ? open(filename, "w") : nothing
+
+        N = length(header) + length(record)
+        headers = vcat(header, record)
+        headers_og = headers
+        _c = 0
+        count() = (_c += 1; _c)
+        rsplits = first.(map(x -> length(x) >= 2 ? x : ("__" * string(count()), x), rsplit.(headers, "/"; limit=2)),)
+        headers = string.(last.(rsplit.(headers, "/"; limit=2)))
+        headers = map(x -> length(x) <= 6 ? x * (" "^length(x)) : x, headers)
+        ind_lens = length.(headers)
+        span = sum(ind_lens .+ 3) + 1
+        rsplit_lens = Dict()
+        if fio !== nothing
+            for (i, r) in enumerate(rsplits)
+                _r = string(r)
+                _r ∉ keys(rsplit_lens) && (rsplit_lens[_r] = -3 - length(_r) + 1)
+                rsplit_lens[_r] = rsplit_lens[_r] + ind_lens[i] + 3
+            end
+            rsplits_unique = unique(rsplits)
+            if !(length(rsplits_unique) == 1 && rsplits_unique[1] == "")
+                println("="^span)
+                for r in rsplits_unique
+                    if startswith(r, "__")
+                        print("| " * (" "^length(r)) * (" "^rsplit_lens[string(r)]))
+                    else
+                        print("| $r" * (" "^rsplit_lens[string(r)]))
+                    end
+                end
+                println("|")
+            end
+            println("="^span)
+            for h in headers
+                print("| $h ")
+            end
+            println("|")
+            println("="^span)
+            for h in headers_og[1:(end - 1)]
+                print(fio, "$h,")
+            end
+            println(fio, "$(headers_og[end])")
+        end
+
+        avg_meters = Dict{String,AverageMeter}(rec => AverageMeter() for rec in record)
+
+        patterns = ["%$l.4f" for l in ind_lens]
+        fmtrfuncs = generate_formatter.(patterns)
+
+        record = tuple(record...)
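+        # `headers` now holds the padded column titles, `rsplits` the group
+        # prefixes split off at the last "/" (e.g. "Train/Running"), and
+        # `avg_meters` one running AverageMeter per recorded statistic.
+        # `record` is converted to a tuple so it can be captured in the
+        # struct's type parameters below.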
+ + return new{N,typeof(avg_meters),typeof(fmtrfuncs),typeof(record),typeof(fio)}( + tuple(headers...), avg_meters, span, fmtrfuncs, record, fio + ) + end +end + +function (pl::PrettyTableLogger)(args...; last::Bool=false, records::Dict=Dict()) + _should_log() || return nothing + if length(records) > 0 + for (rec, val) in records + update!(pl.average_meters[rec], val) + end + return nothing + end + if last + str = "="^pl.span + println(str) + return nothing + end + for (i, (fmtrfunc, arg)) in + enumerate(zip(pl.fmtrfuncs, vcat([args...], [reset!(pl.average_meters[rec]) for rec in pl.records]))) + h = fmtrfunc(arg) + print("| $h ") + if i < length(pl.fmtrfuncs) + print(pl.fio, "$arg,") + else + println(pl.fio, "$arg") + end + end + println("|") + flush(pl.fio) + return nothing +end + +function Base.close(pl::PrettyTableLogger) + pl(; last=true) + pl.fio === nothing || close(pl.fio) + return nothing +end + +function Base.show(io, pl::PrettyTableLogger) + print(io, "PrettyTableLogger(", pl.fio, ")") +end diff --git a/examples/core/models.jl b/examples/src/models.jl similarity index 93% rename from examples/core/models.jl rename to examples/src/models.jl index c964c89e..480332bd 100644 --- a/examples/core/models.jl +++ b/examples/src/models.jl @@ -222,6 +222,9 @@ function get_model( maxiters::Int, abstol, reltol, + seed, + device=gpu, + warmup::Bool=true, # Helps reduce time for Zygote to compile gradients first time ) initial_layers = EFL.Chain( conv3x3(3 => 24; initW=NormalInitializer()), @@ -320,5 +323,25 @@ function get_model( throw(ArgumentError("`model_type` must be one of `[:skip, :skipv2, :vanilla]`")) end - return DEQChain(initial_layers, deq, final_layers) + model = DEQChain(initial_layers, deq, final_layers) + ps, st = EFL.setup(MersenneTwister(seed), model) .|> device + + if warmup + clean_println("Starting Model Warmup") + x__ = randn(Float32, 32, 32, 3, 1) |> device + y__ = Float32.(Flux.onehotbatch([1], 0:9)) |> device + model(x__, ps, st) + clean_println("Forward Pass Warmup Completed") + _, back = Flux.pullback(ps) do p + (y, soln), st_ = model(x__, p, st) + return ( + Flux.logitcrossentropy(y, y__) + sum(abs2, soln.z_star .- soln.u₀), + st_ + ) + end + back((1.0f0, nothing)) + clean_println("Backward Pass Warmup Completed") + end + + return model, ps, st end diff --git a/examples/src/train.jl b/examples/src/train.jl new file mode 100644 index 00000000..220a4581 --- /dev/null +++ b/examples/src/train.jl @@ -0,0 +1,24 @@ +function evaluate(model, ps, st, dataloader, device) + matches, total_loss, total_datasize, total_nfe = 0, 0, 0, 0 + for (x, y) in dataloader + x = device(x) + y = device(y) + (ŷ, soln), st = model(x, ps, st) + + total_nfe += soln.nfe * size(x, ndims(x)) + total_loss += Flux.Losses.logitcrossentropy(ŷ, y) * size(x, ndims(x)) + matches += sum(argmax.(eachcol(ŷ)) .== Flux.onecold(cpu(y))) + total_datasize += size(x, ndims(x)) + end + return ( + ( + loss=total_loss / total_datasize, + accuracy=matches / total_datasize, + mean_nfe=total_nfe / total_datasize + ), + st + ) +end + +function train_one_epoch() +end \ No newline at end of file From 9b6c74c9c63a3c0546b375fe819f5ae340d0abed Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 11:57:55 -0400 Subject: [PATCH 05/76] add cifar10 training script --- examples/Project.toml | 1 + examples/cifar10/script.jl | 94 +++++++++++++++++++++++++ examples/src/FastDEQExperiments.jl | 8 +-- examples/src/dataloaders.jl | 8 +-- examples/src/logging.jl | 4 -- examples/src/models.jl | 11 +-- examples/src/train.jl | 109 
++++++++++++++++++++++++++--- 7 files changed, 205 insertions(+), 30 deletions(-) create mode 100644 examples/cifar10/script.jl diff --git a/examples/Project.toml b/examples/Project.toml index caa8b0b6..deb0b941 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -4,6 +4,7 @@ authors = ["Avik Pal "] version = "0.1.0" [deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl new file mode 100644 index 00000000..31f63b0c --- /dev/null +++ b/examples/cifar10/script.jl @@ -0,0 +1,94 @@ +using FastDEQExperiments, Flux, CUDA, Optimisers + +## TODO: Distributed Training + +# Setup +CUDA.versioninfo() +CUDA.math_mode!(CUDA.FAST_MATH) +CUDA.allowscalar(false) + +function invoke_gc() + GC.gc(true) + CUDA.reclaim() + return nothing +end + +# Hyperparameters +config = Dict( + "seed" => 0, + "learning_rate" => 0.001, + "abstol" => 5.0f-2, + "reltol" => 5.0f-2, + "maxiters" => 20, + "epochs" => 50, + "dropout_rate" => 0.25, + "batchsize" => 128, + "eval_batchsize" => 128, + "model_type" => :skip, + "continuous" => true, + "weight_decay" => 0.0000025, +) + +expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])" + +# Training +function train_model(config, expt_name) + # Logger Setup + mkpath("logs/") + lg = FastDEQExperiments.PrettyTableLogger( + joinpath("logs/", expt_name * ".csv"), + [ + "Epoch Number", + "Train/NFE", + "Train/Accuracy", + "Train/Loss", + "Train/Eval Time", + "Train/Time", + "Test/NFE", + "Test/Accuracy", + "Test/Loss", + "Test/Time", + ], + ["Train/Running/NFE", "Train/Running/Loss"], + ); + + # Model Setup + model, ps, st = FastDEQExperiments.get_model( + Val(:CIFAR10); + dropout_rate=config["dropout_rate"], + model_type=config["model_type"], + continuous=config["continuous"], + maxiters=config["maxiters"], + abstol=config["abstol"], + reltol=config["reltol"], + seed=config["seed"], + device=gpu, + warmup=true, + group_count=8, + ) + + # Get Dataloaders + train_dataloader, test_dataloader = FastDEQExperiments.get_dataloaders( + :CIFAR10; distributed=false, train_batchsize=config["batchsize"], eval_batchsize=config["eval_batchsize"] + ) + + # Train + ps, st, st_opt = FastDEQExperiments.train( + model, + ps, + st, + FastDEQExperiments.loss_function(:CIFAR10, config["model_type"]), + Optimisers.ADAM(config["learning_rate"]), + train_dataloader, + nothing, + test_dataloader, + gpu, + config["epochs"], + lg; + cleanup_function=invoke_gc, + distributed=false, + ) + + # Close Logger and Flush Data to disk + return Base.close(lg) +end diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index dd45dfb5..10357b8c 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -6,13 +6,13 @@ import MLDataUtils: nobs, getobs const EFL = ExplicitFluxLayers -# get_model -include("models.jl") # PrettyTableLogger include("logging.jl") +# train, loss_function +include("train.jl") +# get_model +include("models.jl") # get_dataloaders include("dataloaders.jl") -include("train.jl") - end \ No newline at end of file diff --git a/examples/src/dataloaders.jl b/examples/src/dataloaders.jl index 1e1f5b32..bcb27481 100644 --- a/examples/src/dataloaders.jl +++ b/examples/src/dataloaders.jl @@ -12,11 +12,11 @@ getobs(d::MLDatasetsImageData, i::Int, 
::ObsDim.Undefined) = (d.images[i], d.lab getobs(d::MLDatasetsImageData, i::Int) = (d.images[i], d.labels[i]) function get_dataloaders( - dataset::Symbol; μ=nothing, σ²=nothing, distributed=false, train_batchsize::Int64, test_batchsize::Int64 + dataset::Symbol; μ=nothing, σ²=nothing, distributed=false, train_batchsize::Int64, eval_batchsize::Int64 ) (x_train, y_train), (x_test, y_test), μ, σ², nclasses = if dataset == :CIFAR10 - μ = μ === nothing ? reshape([0.4914, 0.4822, 0.4465], 1, 1, :, 1) : μ - σ² = σ² === nothing ? reshape([0.2023, 0.1994, 0.2010], 1, 1, :, 1) : σ² + μ = μ === nothing ? reshape([0.4914f0, 0.4822f0, 0.4465f0], 1, 1, :, 1) : μ + σ² = σ² === nothing ? reshape([0.2023f0, 0.1994f0, 0.2010f0], 1, 1, :, 1) : σ² CIFAR10.traindata(Float32), CIFAR10.testdata(Float32), μ, σ², 10 else throw(ArgumentError("Not yet implemented for $dataset")) @@ -32,5 +32,5 @@ function get_dataloaders( test_dataset = MLDatasetsImageData(x_test, y_test) test_dataset = distributed ? DistributedDataContainer(test_dataset) : test_dataset - return (DataLoader(train_dataset, train_batchsize), DataLoader(test_dataset, test_batchsize)) + return (DataLoader(train_dataset, train_batchsize), DataLoader(test_dataset, eval_batchsize)) end diff --git a/examples/src/logging.jl b/examples/src/logging.jl index ddca1d12..6f6f9690 100644 --- a/examples/src/logging.jl +++ b/examples/src/logging.jl @@ -131,7 +131,3 @@ function Base.close(pl::PrettyTableLogger) pl.fio === nothing || close(pl.fio) return nothing end - -function Base.show(io, pl::PrettyTableLogger) - print(io, "PrettyTableLogger(", pl.fio, ")") -end diff --git a/examples/src/models.jl b/examples/src/models.jl index 480332bd..81da2f3b 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -332,14 +332,9 @@ function get_model( y__ = Float32.(Flux.onehotbatch([1], 0:9)) |> device model(x__, ps, st) clean_println("Forward Pass Warmup Completed") - _, back = Flux.pullback(ps) do p - (y, soln), st_ = model(x__, p, st) - return ( - Flux.logitcrossentropy(y, y__) + sum(abs2, soln.z_star .- soln.u₀), - st_ - ) - end - back((1.0f0, nothing)) + lfn = loss_function(:CIFAR10, model_type) + (l, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) + back((one(l), nothing, nothing)) clean_println("Backward Pass Warmup Completed") end diff --git a/examples/src/train.jl b/examples/src/train.jl index 220a4581..6589c893 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -1,9 +1,14 @@ +evaluate(model, ps, st, ::Nothing, device) = nothing + function evaluate(model, ps, st, dataloader, device) - matches, total_loss, total_datasize, total_nfe = 0, 0, 0, 0 + matches, total_loss, total_datasize, total_nfe, total_time = 0, 0, 0, 0, 0 for (x, y) in dataloader x = device(x) y = device(y) - (ŷ, soln), st = model(x, ps, st) + + start_time = time() + (ŷ, soln), _ = model(x, ps, st) + total_time += time() - start_time total_nfe += soln.nfe * size(x, ndims(x)) total_loss += Flux.Losses.logitcrossentropy(ŷ, y) * size(x, ndims(x)) @@ -11,14 +16,98 @@ function evaluate(model, ps, st, dataloader, device) total_datasize += size(x, ndims(x)) end return ( - ( - loss=total_loss / total_datasize, - accuracy=matches / total_datasize, - mean_nfe=total_nfe / total_datasize - ), - st + loss=total_loss / total_datasize, + accuracy=matches / total_datasize, + mean_nfe=total_nfe / total_datasize, + total_time=total_time, ) end -function train_one_epoch() -end \ No newline at end of file +function train_one_epoch(model, ps, st, loss_function, opt_state, 
dataloader, device, lg::PrettyTableLogger) + total_time = 0 + + for (x, y) in dataloader + x = device(x) + y = device(y) + + # Compute Loss + Backprop + Update + start_time = time() + + (loss, st, nfe), back = Flux.pullback(p -> loss_function(x, y, model, p, st), ps) + gs, = back(one(loss)) + ps, opt_state = Optimisers.update!(opt_state, ps, gs) + + total_time += time() - start_time + + # Logging + lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss)) + end + + return ps, st, opt_state, (total_time=total_time,) +end + +function loss_function(dataset::Symbol, model_type::Symbol) + if dataset ∈ (:CIFAR10,) + function loss_function_closure(x, y, model, ps, st) + (ŷ, soln), st_ = model(x, ps, st) + loss = if model_type == :vanilla + Flux.Losses.logitcrossentropy(ŷ, y) + else + Flux.Losses.logitcrossentropy(ŷ, y) + Flux.Losses.mse(soln.u₀, soln.z_star) + end + return loss, st_, soln.nfe + end + return loss_function_closure + else + throw(ArgumentError("$dataset - $model_type not yet supported")) + end +end + +function train( + model, + ps, + st, + loss_function, + opt, + train_dataloader, + val_dataloader, + test_dataloader, + device, + nepochs, + lg::PrettyTableLogger; + distributed::Bool=false, + cleanup_function=identity +) + # TODO: Saving model weights + opt_state = Optimisers.setup(opt, ps) + + for epoch in 1:nepochs + # Run a cleanup function + cleanup_function() + + # Train 1 epoch + ps, st, opt_state, training_stats = train_one_epoch( + model, ps, st, loss_function, opt_state, train_dataloader, device, lg + ) + + # Evaluate + train_eval_stats = evaluate(model, ps, st, train_dataloader, device) + val_eval_stats = evaluate(model, ps, st, val_dataloader, device) + test_eval_stats = evaluate(model, ps, st, test_dataloader, device) + + train_stats, val_stats, test_stats = if distributed + # TODO: Implement syncing the statistics + error("Distributed Training not yet implemented") + else + ( + (train_eval_stats.mean_nfe, train_eval_stats.accuracy, train_eval_stats.loss, train_eval_stats.total_time), + val_eval_stats === nothing ? () : (val_eval_stats.mean_nfe, val_eval_stats.accuracy, val_eval_stats.loss, val_eval_stats.total_time), + test_eval_stats === nothing ? () : (test_eval_stats.mean_nfe, test_eval_stats.accuracy, test_eval_stats.loss, test_eval_stats.total_time), + ) + end + + lg(epoch, train_stats..., training_stats.total_time, val_stats..., test_stats...) 
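+        # NOTE: `lg` is the callable `PrettyTableLogger` from logging.jl. This
+        # positional call prints one formatted table row and appends the same
+        # values as a CSV line to `lg.fio`, while the running meters updated via
+        # `lg(; records=...)` inside `train_one_epoch` are flushed into the row
+        # and reset (`reset!`).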
+ end + + return ps, st, opt_state +end From cbc2adf6352684a03495c5212a6c24851a9ca45d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 12:22:17 -0400 Subject: [PATCH 06/76] Add dates dep --- examples/Project.toml | 1 + examples/cifar10/script.jl | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/Project.toml b/examples/Project.toml index deb0b941..842b47c3 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -6,6 +6,7 @@ version = "0.1.0" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 31f63b0c..f04a7214 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -1,4 +1,4 @@ -using FastDEQExperiments, Flux, CUDA, Optimisers +using FastDEQExperiments, Flux, CUDA, Optimisers, Dates ## TODO: Distributed Training @@ -29,7 +29,7 @@ config = Dict( "weight_decay" => 0.0000025, ) -expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])" +expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])_now-$(now())" # Training function train_model(config, expt_name) @@ -90,5 +90,9 @@ function train_model(config, expt_name) ) # Close Logger and Flush Data to disk - return Base.close(lg) + Base.close(lg) + + return model, cpu(ps), cpu(st), st_opt end + +model, ps, st, st_opt = train_model(config, expt_name) From 7f7312cdf65782d870437a2fc80d001d19f0c898 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 12:33:58 -0400 Subject: [PATCH 07/76] Fix training --- examples/Manifest.toml | 4 +--- examples/src/train.jl | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index a15e4ec8..98210465 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -443,9 +443,7 @@ version = "0.3.2" [[deps.FastDEQ]] deps = ["CUDA", "ChainRulesCore", "DataLoaders", "DiffEqBase", "DiffEqCallbacks", "DiffEqSensitivity", "ExplicitFluxLayers", "Flux", "Functors", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Random", "Reexport", "Requires", "SciMLBase", "Setfield", "Statistics", "SteadyStateDiffEq", "UnPack", "Zygote"] -git-tree-sha1 = "23164863ce195e94e8e5edc643e8af6b60cc8402" -repo-rev = "ap/efl" -repo-url = ".." +path = ".." 
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" version = "0.1.0" diff --git a/examples/src/train.jl b/examples/src/train.jl index 6589c893..a178c7e0 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -34,7 +34,7 @@ function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, de start_time = time() (loss, st, nfe), back = Flux.pullback(p -> loss_function(x, y, model, p, st), ps) - gs, = back(one(loss)) + gs, = back((one(loss), nothing, nothing)) ps, opt_state = Optimisers.update!(opt_state, ps, gs) total_time += time() - start_time From 2c296a1c575db7332d96300188848bb22c5f41ca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 12:39:38 -0400 Subject: [PATCH 08/76] RIP ordering --- examples/src/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/train.jl b/examples/src/train.jl index a178c7e0..9df61b95 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -35,7 +35,7 @@ function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, de (loss, st, nfe), back = Flux.pullback(p -> loss_function(x, y, model, p, st), ps) gs, = back((one(loss), nothing, nothing)) - ps, opt_state = Optimisers.update!(opt_state, ps, gs) + opt_state, ps = Optimisers.update!(opt_state, ps, gs) total_time += time() - start_time From 98a40031a837278910f0a7b998f6f9dac4fd18ea Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 17:44:17 -0400 Subject: [PATCH 09/76] Some printing --- examples/cifar10/script.jl | 4 ++-- examples/src/train.jl | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index f04a7214..0ec3793c 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -22,8 +22,8 @@ config = Dict( "maxiters" => 20, "epochs" => 50, "dropout_rate" => 0.25, - "batchsize" => 128, - "eval_batchsize" => 128, + "batchsize" => 256, + "eval_batchsize" => 256, "model_type" => :skip, "continuous" => true, "weight_decay" => 0.0000025, diff --git a/examples/src/train.jl b/examples/src/train.jl index 9df61b95..8d779b42 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -24,9 +24,9 @@ function evaluate(model, ps, st, dataloader, device) end function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, device, lg::PrettyTableLogger) - total_time = 0 + total_time, dlen = 0, length(dataloader) - for (x, y) in dataloader + for (i, (x, y)) in enumerate(dataloader) x = device(x) y = device(y) @@ -39,6 +39,10 @@ function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, de total_time += time() - start_time + if local_rank() == 0 && i % 25 == 1 + clean_println(" [$(i)/$(dlen)] data processed. Loss: $(loss). 
Time Taken: $(total_time)") + end + # Logging lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss)) end @@ -82,18 +86,23 @@ function train( opt_state = Optimisers.setup(opt, ps) for epoch in 1:nepochs - # Run a cleanup function - cleanup_function() + if local_rank() == 0 + clean_println("Epoch [$(epoch) / $(nepochs)]") + end # Train 1 epoch ps, st, opt_state, training_stats = train_one_epoch( model, ps, st, loss_function, opt_state, train_dataloader, device, lg ) + cleanup_function() # Evaluate train_eval_stats = evaluate(model, ps, st, train_dataloader, device) + cleanup_function() val_eval_stats = evaluate(model, ps, st, val_dataloader, device) + cleanup_function() test_eval_stats = evaluate(model, ps, st, test_dataloader, device) + cleanup_function() train_stats, val_stats, test_stats = if distributed # TODO: Implement syncing the statistics From b3a6a083d615930a77ae7d13868a8f6ed93698af Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 17:46:38 -0400 Subject: [PATCH 10/76] Some printing --- examples/src/train.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/train.jl b/examples/src/train.jl index 8d779b42..73bc4800 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -39,7 +39,7 @@ function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, de total_time += time() - start_time - if local_rank() == 0 && i % 25 == 1 + if _should_log() && i % 25 == 1 clean_println(" [$(i)/$(dlen)] data processed. Loss: $(loss). Time Taken: $(total_time)") end @@ -86,7 +86,7 @@ function train( opt_state = Optimisers.setup(opt, ps) for epoch in 1:nepochs - if local_rank() == 0 + if _should_log() clean_println("Epoch [$(epoch) / $(nepochs)]") end From 47f9ea3cb6d0d309eaaffd6a7c52dcc18abf6573 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 18:54:33 -0400 Subject: [PATCH 11/76] Stop printing --- examples/cifar10/script.jl | 15 ++++++--------- examples/src/models.jl | 4 ++-- examples/src/train.jl | 21 ++++++--------------- 3 files changed, 14 insertions(+), 26 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 0ec3793c..5c98144c 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -1,6 +1,7 @@ -using FastDEQExperiments, Flux, CUDA, Optimisers, Dates +using FastDEQExperiments, Flux, CUDA, Optimisers, Dates, FluxMPI -## TODO: Distributed Training +## Distributed Training +# FluxMPI.Init(verbose=true) # Setup CUDA.versioninfo() @@ -22,8 +23,8 @@ config = Dict( "maxiters" => 20, "epochs" => 50, "dropout_rate" => 0.25, - "batchsize" => 256, - "eval_batchsize" => 256, + "batchsize" => 128, + "eval_batchsize" => 128, "model_type" => :skip, "continuous" => true, "weight_decay" => 0.0000025, @@ -39,17 +40,13 @@ function train_model(config, expt_name) joinpath("logs/", expt_name * ".csv"), [ "Epoch Number", - "Train/NFE", - "Train/Accuracy", - "Train/Loss", - "Train/Eval Time", "Train/Time", "Test/NFE", "Test/Accuracy", "Test/Loss", "Test/Time", ], - ["Train/Running/NFE", "Train/Running/Loss"], + ["Train/Running/NFE", "Train/Running/Loss", "Train/Running/Accuracy"], ); # Model Setup diff --git a/examples/src/models.jl b/examples/src/models.jl index 81da2f3b..60c0c43e 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -333,8 +333,8 @@ function get_model( model(x__, ps, st) clean_println("Forward Pass Warmup Completed") lfn = loss_function(:CIFAR10, model_type) - (l, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) 
- back((one(l), nothing, nothing)) + (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) + back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass Warmup Completed") end diff --git a/examples/src/train.jl b/examples/src/train.jl index 73bc4800..0f9c65f3 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -33,15 +33,13 @@ function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, de # Compute Loss + Backprop + Update start_time = time() - (loss, st, nfe), back = Flux.pullback(p -> loss_function(x, y, model, p, st), ps) - gs, = back((one(loss), nothing, nothing)) + (loss, ŷ, st, nfe), back = Flux.pullback(p -> loss_function(x, y, model, p, st), ps) + gs, = back((one(loss), nothing, nothing, nothing)) opt_state, ps = Optimisers.update!(opt_state, ps, gs) total_time += time() - start_time - if _should_log() && i % 25 == 1 - clean_println(" [$(i)/$(dlen)] data processed. Loss: $(loss). Time Taken: $(total_time)") - end + matches += sum(argmax.(eachcol(cpu(ŷ))) .== Flux.onecold(cpu(y))) # Logging lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss)) @@ -59,7 +57,7 @@ function loss_function(dataset::Symbol, model_type::Symbol) else Flux.Losses.logitcrossentropy(ŷ, y) + Flux.Losses.mse(soln.u₀, soln.z_star) end - return loss, st_, soln.nfe + return loss, ŷ, st_, soln.nfe end return loss_function_closure else @@ -86,10 +84,6 @@ function train( opt_state = Optimisers.setup(opt, ps) for epoch in 1:nepochs - if _should_log() - clean_println("Epoch [$(epoch) / $(nepochs)]") - end - # Train 1 epoch ps, st, opt_state, training_stats = train_one_epoch( model, ps, st, loss_function, opt_state, train_dataloader, device, lg @@ -97,25 +91,22 @@ function train( cleanup_function() # Evaluate - train_eval_stats = evaluate(model, ps, st, train_dataloader, device) - cleanup_function() val_eval_stats = evaluate(model, ps, st, val_dataloader, device) cleanup_function() test_eval_stats = evaluate(model, ps, st, test_dataloader, device) cleanup_function() - train_stats, val_stats, test_stats = if distributed + val_stats, test_stats = if distributed # TODO: Implement syncing the statistics error("Distributed Training not yet implemented") else ( - (train_eval_stats.mean_nfe, train_eval_stats.accuracy, train_eval_stats.loss, train_eval_stats.total_time), val_eval_stats === nothing ? () : (val_eval_stats.mean_nfe, val_eval_stats.accuracy, val_eval_stats.loss, val_eval_stats.total_time), test_eval_stats === nothing ? () : (test_eval_stats.mean_nfe, test_eval_stats.accuracy, test_eval_stats.loss, test_eval_stats.total_time), ) end - lg(epoch, train_stats..., training_stats.total_time, val_stats..., test_stats...) + lg(epoch, training_stats.total_time, val_stats..., test_stats...) 
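+        # NOTE: the loss closure now returns the 4-tuple (loss, ŷ, st, nfe), so
+        # the Zygote pullback seed below has to mirror that output structure:
+        # `one(l)` seeds the loss entry and `nothing` marks the auxiliary
+        # outputs that are not differentiated.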
end return ps, st, opt_state From 6ec4c7b86eee1d36599cf76d815fd285d97d4d60 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 10 Apr 2022 19:10:08 -0400 Subject: [PATCH 12/76] missing acc --- examples/src/train.jl | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/src/train.jl b/examples/src/train.jl index 0f9c65f3..a9acefe1 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -24,9 +24,9 @@ function evaluate(model, ps, st, dataloader, device) end function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, device, lg::PrettyTableLogger) - total_time, dlen = 0, length(dataloader) + total_time = 0 - for (i, (x, y)) in enumerate(dataloader) + for (x, y) in dataloader x = device(x) y = device(y) @@ -39,10 +39,10 @@ function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, de total_time += time() - start_time - matches += sum(argmax.(eachcol(cpu(ŷ))) .== Flux.onecold(cpu(y))) + acc = sum(argmax.(eachcol(cpu(ŷ))) .== Flux.onecold(cpu(y))) / size(x, 4) # Logging - lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss)) + lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss, "Train/Running/Accuracy" => acc)) end return ps, st, opt_state, (total_time=total_time,) @@ -78,7 +78,7 @@ function train( nepochs, lg::PrettyTableLogger; distributed::Bool=false, - cleanup_function=identity + cleanup_function=identity, ) # TODO: Saving model weights opt_state = Optimisers.setup(opt, ps) @@ -101,8 +101,16 @@ function train( error("Distributed Training not yet implemented") else ( - val_eval_stats === nothing ? () : (val_eval_stats.mean_nfe, val_eval_stats.accuracy, val_eval_stats.loss, val_eval_stats.total_time), - test_eval_stats === nothing ? 
() : (test_eval_stats.mean_nfe, test_eval_stats.accuracy, test_eval_stats.loss, test_eval_stats.total_time), + if val_eval_stats === nothing + () + else + (val_eval_stats.mean_nfe, val_eval_stats.accuracy, val_eval_stats.loss, val_eval_stats.total_time) + end, + if test_eval_stats === nothing + () + else + (test_eval_stats.mean_nfe, test_eval_stats.accuracy, test_eval_stats.loss, test_eval_stats.total_time) + end, ) end From 0844f88dc77c9907891b0ab68807453c981c774a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Apr 2022 10:57:46 -0400 Subject: [PATCH 13/76] Add distributed training --- examples/cifar10/script.jl | 22 +++++----------- examples/src/dataloaders.jl | 6 ++--- examples/src/models.jl | 10 +++++++- examples/src/train.jl | 51 +++++++++++++++++-------------------- 4 files changed, 43 insertions(+), 46 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 5c98144c..ee2ec53e 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -1,7 +1,7 @@ using FastDEQExperiments, Flux, CUDA, Optimisers, Dates, FluxMPI -## Distributed Training -# FluxMPI.Init(verbose=true) +# Distributed Training +FluxMPI.Init(; verbose=true) # Setup CUDA.versioninfo() @@ -23,8 +23,8 @@ config = Dict( "maxiters" => 20, "epochs" => 50, "dropout_rate" => 0.25, - "batchsize" => 128, - "eval_batchsize" => 128, + "batchsize" => 64, + "eval_batchsize" => 64, "model_type" => :skip, "continuous" => true, "weight_decay" => 0.0000025, @@ -38,16 +38,9 @@ function train_model(config, expt_name) mkpath("logs/") lg = FastDEQExperiments.PrettyTableLogger( joinpath("logs/", expt_name * ".csv"), - [ - "Epoch Number", - "Train/Time", - "Test/NFE", - "Test/Accuracy", - "Test/Loss", - "Test/Time", - ], + ["Epoch Number", "Train/Time", "Test/NFE", "Test/Accuracy", "Test/Loss", "Test/Time"], ["Train/Running/NFE", "Train/Running/Loss", "Train/Running/Accuracy"], - ); + ) # Model Setup model, ps, st = FastDEQExperiments.get_model( @@ -66,7 +59,7 @@ function train_model(config, expt_name) # Get Dataloaders train_dataloader, test_dataloader = FastDEQExperiments.get_dataloaders( - :CIFAR10; distributed=false, train_batchsize=config["batchsize"], eval_batchsize=config["eval_batchsize"] + :CIFAR10; train_batchsize=config["batchsize"], eval_batchsize=config["eval_batchsize"] ) # Train @@ -83,7 +76,6 @@ function train_model(config, expt_name) config["epochs"], lg; cleanup_function=invoke_gc, - distributed=false, ) # Close Logger and Flush Data to disk diff --git a/examples/src/dataloaders.jl b/examples/src/dataloaders.jl index bcb27481..09b15f19 100644 --- a/examples/src/dataloaders.jl +++ b/examples/src/dataloaders.jl @@ -12,7 +12,7 @@ getobs(d::MLDatasetsImageData, i::Int, ::ObsDim.Undefined) = (d.images[i], d.lab getobs(d::MLDatasetsImageData, i::Int) = (d.images[i], d.labels[i]) function get_dataloaders( - dataset::Symbol; μ=nothing, σ²=nothing, distributed=false, train_batchsize::Int64, eval_batchsize::Int64 + dataset::Symbol; μ=nothing, σ²=nothing, train_batchsize::Int64, eval_batchsize::Int64 ) (x_train, y_train), (x_test, y_test), μ, σ², nclasses = if dataset == :CIFAR10 μ = μ === nothing ? reshape([0.4914f0, 0.4822f0, 0.4465f0], 1, 1, :, 1) : μ @@ -28,9 +28,9 @@ function get_dataloaders( y_test = Float32.(Flux.onehotbatch(y_test, 0:(nclasses - 1))) train_dataset = shuffleobs(MLDatasetsImageData(x_train, y_train)) - train_dataset = distributed ? DistributedDataContainer(train_dataset) : train_dataset + train_dataset = is_distributed() ? 
DistributedDataContainer(train_dataset) : train_dataset test_dataset = MLDatasetsImageData(x_test, y_test) - test_dataset = distributed ? DistributedDataContainer(test_dataset) : test_dataset + test_dataset = is_distributed() ? DistributedDataContainer(test_dataset) : test_dataset return (DataLoader(train_dataset, train_batchsize), DataLoader(test_dataset, eval_batchsize)) end diff --git a/examples/src/models.jl b/examples/src/models.jl index 60c0c43e..dcb86c68 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -275,7 +275,7 @@ function get_model( solver = if continuous ContinuousDEQSolver( - VCABM3(); + VCABM4(); mode=:rel_deq_best, abstol=abstol, reltol=reltol, @@ -338,5 +338,13 @@ function get_model( clean_println("Backward Pass Warmup Completed") end + ps, st = if is_distributed() + ps_ = FluxMPI.synchronize!(ps; root_rank=0) + st_ = FluxMPI.synchronize!(st; root_rank=0) + ps_, st_ + else + ps, st + end + return model, ps, st end diff --git a/examples/src/train.jl b/examples/src/train.jl index a9acefe1..1b3c25e8 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -1,3 +1,22 @@ +is_distributed() = FluxMPI.Initialized() && total_workers() > 1 + +_get_loggable_stats(::Nothing) = () + +function _get_loggable_stats(stats::NamedTuple) + if is_distributed() + arr = [stats.mean_nfe, stats.accuracy, stats.loss, stats.total_datasize] + FluxMPI.MPIExtensions.Reduce!(arr, +, 0, FluxMPI.MPI.COMM_WORLD) + return ((arr[1:3] ./ arr[4])..., stats.total_time) + else + return ( + stats.mean_nfe / stats.total_datasize, + stats.accuracy / stats.total_datasize, + stats.loss / stats.total_datasize, + stats.total_time, + ) + end +end + evaluate(model, ps, st, ::Nothing, device) = nothing function evaluate(model, ps, st, dataloader, device) @@ -15,12 +34,7 @@ function evaluate(model, ps, st, dataloader, device) matches += sum(argmax.(eachcol(ŷ)) .== Flux.onecold(cpu(y))) total_datasize += size(x, ndims(x)) end - return ( - loss=total_loss / total_datasize, - accuracy=matches / total_datasize, - mean_nfe=total_nfe / total_datasize, - total_time=total_time, - ) + return (loss=total_loss, accuracy=matches, mean_nfe=total_nfe, total_time=total_time, total_datasize=total_datasize) end function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, device, lg::PrettyTableLogger) @@ -77,11 +91,12 @@ function train( device, nepochs, lg::PrettyTableLogger; - distributed::Bool=false, cleanup_function=identity, ) + cleanup_function() # TODO: Saving model weights opt_state = Optimisers.setup(opt, ps) + opt_state = is_distributed() ? 
FluxMPI.synchronize!(opt_state; root_rank=0) : opt_state for epoch in 1:nepochs # Train 1 epoch @@ -91,29 +106,11 @@ function train( cleanup_function() # Evaluate - val_eval_stats = evaluate(model, ps, st, val_dataloader, device) + val_stats = _get_loggable_stats(evaluate(model, ps, st, val_dataloader, device)) cleanup_function() - test_eval_stats = evaluate(model, ps, st, test_dataloader, device) + test_stats = _get_loggable_stats(evaluate(model, ps, st, test_dataloader, device)) cleanup_function() - val_stats, test_stats = if distributed - # TODO: Implement syncing the statistics - error("Distributed Training not yet implemented") - else - ( - if val_eval_stats === nothing - () - else - (val_eval_stats.mean_nfe, val_eval_stats.accuracy, val_eval_stats.loss, val_eval_stats.total_time) - end, - if test_eval_stats === nothing - () - else - (test_eval_stats.mean_nfe, test_eval_stats.accuracy, test_eval_stats.loss, test_eval_stats.total_time) - end, - ) - end - lg(epoch, training_stats.total_time, val_stats..., test_stats...) end From 7f3273a8acac14bf34d28bd9d3f29c825c59544b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Apr 2022 11:31:40 -0400 Subject: [PATCH 14/76] Make float32 --- examples/cifar10/script.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index ee2ec53e..4d202822 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -17,17 +17,17 @@ end # Hyperparameters config = Dict( "seed" => 0, - "learning_rate" => 0.001, + "learning_rate" => 0.001f0, "abstol" => 5.0f-2, "reltol" => 5.0f-2, "maxiters" => 20, "epochs" => 50, - "dropout_rate" => 0.25, + "dropout_rate" => 0.25f0, "batchsize" => 64, "eval_batchsize" => 64, "model_type" => :skip, "continuous" => true, - "weight_decay" => 0.0000025, + "weight_decay" => 0.0000025f0, ) expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])_now-$(now())" From 37f5c87d6deadc5ee22063c6ebc1e06ca9bf11f8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Apr 2022 12:41:06 -0400 Subject: [PATCH 15/76] HPC safe precompile --- examples/cifar10/script.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 4d202822..9a2d496a 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -1,3 +1,10 @@ +# ----------------------------------- # +# ------ Precompilation in HPC ------ # +d = strip(String(read(`mktemp -d`))) +mkdir(joinpath(d, "compiled")) +pushfirst!(DEPOT_PATH, d) +#------------------------------------ # + using FastDEQExperiments, Flux, CUDA, Optimisers, Dates, FluxMPI # Distributed Training From 30da3c5e6916afa940d6561f1f8fe9d8e10d19da Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Apr 2022 16:41:07 -0400 Subject: [PATCH 16/76] Change solver from VCABM4 since it changes inputs to FP64 --- examples/cifar10/script.jl | 9 +-------- examples/src/models.jl | 2 +- src/utils.jl | 7 +++++++ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 9a2d496a..28af88fc 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -1,10 +1,3 @@ -# ----------------------------------- # -# ------ Precompilation in HPC ------ # -d = strip(String(read(`mktemp -d`))) -mkdir(joinpath(d, "compiled")) -pushfirst!(DEPOT_PATH, d) -#------------------------------------ # - using FastDEQExperiments, Flux, CUDA, Optimisers, Dates, FluxMPI 
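# NOTE: a minimal sketch of the data-parallel wiring this script relies on
# (every call below appears elsewhere in this patch series; collected here only
# for orientation):
#
#   FluxMPI.Init()                                  # once per process
#   ps = FluxMPI.synchronize!(ps; root_rank=0)      # broadcast parameters from rank 0
#   dataset = DistributedDataContainer(dataset)     # disjoint per-rank data shard
#   MPI.Reduce!(stats, +, 0, MPI.COMM_WORLD)        # sum eval statistics onto rank 0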
# Distributed Training @@ -75,7 +68,7 @@ function train_model(config, expt_name) ps, st, FastDEQExperiments.loss_function(:CIFAR10, config["model_type"]), - Optimisers.ADAM(config["learning_rate"]), + Optimisers.ADAMW(config["learning_rate"], (0.9f0, 0.999f0), config["weight_decay"]), train_dataloader, nothing, test_dataloader, diff --git a/examples/src/models.jl b/examples/src/models.jl index dcb86c68..c5a7c5a3 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -275,7 +275,7 @@ function get_model( solver = if continuous ContinuousDEQSolver( - VCABM4(); + VCABM3(); mode=:rel_deq_best, abstol=abstol, reltol=reltol, diff --git a/src/utils.jl b/src/utils.jl index 83336047..e88b3a01 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -36,6 +36,13 @@ function split_and_reshape(x::AbstractMatrix, idxs::Tuple, shapes::Tuple) ) end +## Some dispatches for CuArrays are not defined for subarrays +# function split_and_reshape(x::AbstractMatrix, idxs::Tuple, shapes::Tuple) +# return Tuple( +# x[reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...), :] for i in 1:(length(idxs) - 1) +# ) +# end + # Zygote Fix function Zygote.accum(x::NTuple{N,T}, y::AbstractVector{T}) where {N,T<:AbstractArray} return Zygote.accum.(x, y) From b3f301e95c5a024404559c21351f0d4da99aaeac Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 12 Apr 2022 01:52:41 -0400 Subject: [PATCH 17/76] ADAMW syncing not working for MPI --- examples/cifar10/script.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 28af88fc..e963fc6b 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -27,7 +27,6 @@ config = Dict( "eval_batchsize" => 64, "model_type" => :skip, "continuous" => true, - "weight_decay" => 0.0000025f0, ) expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])_now-$(now())" @@ -68,7 +67,7 @@ function train_model(config, expt_name) ps, st, FastDEQExperiments.loss_function(:CIFAR10, config["model_type"]), - Optimisers.ADAMW(config["learning_rate"], (0.9f0, 0.999f0), config["weight_decay"]), + Optimisers.ADAM(config["learning_rate"], (0.9f0, 0.999f0)), train_dataloader, nothing, test_dataloader, From ed4c40f7edc0dcf656c82b61f89356eb34e914b5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 12 Apr 2022 12:08:54 -0400 Subject: [PATCH 18/76] Reduce using MPI --- examples/Manifest.toml | 64 ++++++++++++++++++++---------- examples/Project.toml | 1 + examples/src/FastDEQExperiments.jl | 13 +++++- examples/src/train.jl | 2 +- 4 files changed, 58 insertions(+), 22 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 98210465..78c9552e 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -1,7 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.7.2" +julia_version = "1.8.0-beta1" manifest_format = "2.0" +project_hash = "c980a263b9193d100d02f087d77471d79992ffe5" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -28,6 +29,7 @@ version = "2.3.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" [[deps.ArnoldiMethod]] deps = ["LinearAlgebra", "Random", "StaticArrays"] @@ -195,13 +197,14 @@ version = "0.3.0" [[deps.Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", 
"SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "96b0bc6c52df76506efc8a441c6cf1adcb1babc4" +git-tree-sha1 = "b153278a25dd42c65abbf4e62344f9d22e59191b" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.42.0" +version = "3.43.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.5.0+0" [[deps.CompositeTypes]] git-tree-sha1 = "d5b014b216dc891e81fea299638e4c10c657b582" @@ -291,9 +294,9 @@ version = "0.4.0" [[deps.DiffEqBase]] deps = ["ArrayInterface", "ChainRulesCore", "DEDataArrays", "DataStructures", "Distributions", "DocStringExtensions", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "IterativeSolvers", "LabelledArrays", "LinearAlgebra", "Logging", "MuladdMacro", "NonlinearSolve", "Parameters", "PreallocationTools", "Printf", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "Setfield", "SparseArrays", "StaticArrays", "Statistics", "SuiteSparse", "ZygoteRules"] -git-tree-sha1 = "d1c8d8b645500d7dffec3355d29af6c4f8bfa6df" +git-tree-sha1 = "d19393983b7609b0b7d4caa2bce6b018f663b688" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" -version = "6.82.3" +version = "6.83.0" [[deps.DiffEqCallbacks]] deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "NLsolve", "OrdinaryDiffEq", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArrays"] @@ -366,8 +369,9 @@ uuid = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" version = "0.5.9" [[deps.Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" [[deps.EllipsisNotation]] deps = ["ArrayInterface"] @@ -389,7 +393,7 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] -git-tree-sha1 = "ac921bfb0de25739c6a1bba2a70f5820cca529fd" +git-tree-sha1 = "d8520e4d150fea76c4865a177893fcda429a27fd" repo-rev = "ap/sparse" repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" @@ -459,6 +463,9 @@ git-tree-sha1 = "129b104185df66e408edd6625d480b7f9e9823a0" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" version = "0.9.18" +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" @@ -485,7 +492,7 @@ version = "0.13.0" [[deps.FluxMPI]] deps = ["CUDA", "Dates", "Flux", "Functors", "LearnBase", "MLDataUtils", "MPI", "Optimisers", "Setfield", "Zygote"] -git-tree-sha1 = "7e0a62fddea90780bfc1e29c85c2eb36fcffac8d" +git-tree-sha1 = "be5c7b6c1acf5081be9179f400c78a0bddbd7f0d" repo-rev = "ap/opt" repo-url = "https://github.com/avik-pal/FluxMPI.jl.git" uuid = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" @@ -730,9 +737,9 @@ version = "0.8.0" [[deps.KrylovKit]] deps = ["LinearAlgebra", "Printf"] -git-tree-sha1 = "0328ad9966ae29ccefb4e1b9bfd8c8867e4360df" +git-tree-sha1 = "49b0c1dd5c292870577b8f58c51072bd558febb9" uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.5.3" +version = "0.5.4" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] @@ -801,10 +808,12 @@ version = "1.0.0" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = 
"b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.81.0+0" [[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] @@ -813,6 +822,7 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -841,18 +851,18 @@ version = "1.15.0" [[deps.LogExpFunctions]] deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "58f25e56b706f95125dcb796f39e1fb01d913a71" +git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.10" +version = "0.3.12" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoopVectorization]] deps = ["ArrayInterface", "CPUSummary", "ChainRulesCore", "CloseOpenIntervals", "DocStringExtensions", "ForwardDiff", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "SIMDDualNumbers", "SLEEFPirates", "SpecialFunctions", "Static", "ThreadingUtilities", "UnPack", "VectorizationBase"] -git-tree-sha1 = "f9d84dcb46419e973872b32c051e5baad2d29de7" +git-tree-sha1 = "4acc35e95bf18de5e9562d27735bef0950f2ed74" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.12.107" +version = "0.12.108" [[deps.MAT]] deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] @@ -909,9 +919,9 @@ version = "0.19.2" [[deps.MPICH_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "8eed51eb836c8f47781cdb493ffd5f56370c0496" +git-tree-sha1 = "3dacfc006764fe498515a022c3976b7e133c4008" uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.0.1+0" +version = "4.0.2+0" [[deps.MacroTools]] deps = ["Markdown", "Random"] @@ -948,6 +958,7 @@ version = "1.0.3" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -972,6 +983,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" [[deps.MuladdMacro]] git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" @@ -1015,6 +1027,7 @@ version = "0.1.5" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" [[deps.NonlinearSolve]] deps = ["ArrayInterface", "FiniteDiff", "ForwardDiff", "IterativeSolvers", "LinearAlgebra", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "SciMLBase", "Setfield", "StaticArrays", "UnPack"] @@ -1043,10 +1056,12 @@ version = "1.10.8" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.17+2" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] @@ -1116,6 +1131,7 @@ version = "0.3.0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", 
"UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" [[deps.PoissonRandom]] deps = ["Random", "Statistics", "Test"] @@ -1284,6 +1300,7 @@ version = "0.5.3" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" [[deps.SIMDDualNumbers]] deps = ["ForwardDiff", "IfElse", "SLEEFPirates", "VectorizationBase"] @@ -1399,9 +1416,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[deps.StatsAPI]] deps = ["LinearAlgebra"] -git-tree-sha1 = "c3d8ba7f3fa0625b062b82853a7d5229cb728b6b" +git-tree-sha1 = "8d7530a38dbd2c397be7ddd01a424e4f411dcc41" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.2.1" +version = "1.2.2" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] @@ -1435,9 +1452,9 @@ version = "0.2.19" [[deps.Strided]] deps = ["LinearAlgebra", "TupleTools"] -git-tree-sha1 = "4d581938087ca90eab9bd4bb6d270edaefd70dcd" +git-tree-sha1 = "972de61ae8cb965c516b871b69bb8594463d39a9" uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" -version = "1.1.2" +version = "1.2.0" [[deps.StringEncodings]] deps = ["Libiconv_jll"] @@ -1464,10 +1481,12 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+0" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -1484,6 +1503,7 @@ version = "1.7.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] @@ -1603,6 +1623,7 @@ version = "0.9.4" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.12+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] @@ -1619,11 +1640,14 @@ version = "0.2.2" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.0.1+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.41.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "16.2.1+1" diff --git a/examples/Project.toml b/examples/Project.toml index 842b47c3..479b615a 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -15,6 +15,7 @@ Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index 10357b8c..70301726 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -1,6 +1,17 @@ module FastDEQExperiments -using FastDEQ, ExplicitFluxLayers, Random, Flux, OrdinaryDiffEq, 
FluxMPI, Format, MLDatasets, MLDataUtils, DataLoaders, Optimisers +using FastDEQ, + ExplicitFluxLayers, + Random, + Flux, + OrdinaryDiffEq, + FluxMPI, + Format, + MLDatasets, + MLDataUtils, + DataLoaders, + Optimisers, + MPI import LearnBase: ObsDim import MLDataUtils: nobs, getobs diff --git a/examples/src/train.jl b/examples/src/train.jl index 1b3c25e8..7a2e5005 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -5,7 +5,7 @@ _get_loggable_stats(::Nothing) = () function _get_loggable_stats(stats::NamedTuple) if is_distributed() arr = [stats.mean_nfe, stats.accuracy, stats.loss, stats.total_datasize] - FluxMPI.MPIExtensions.Reduce!(arr, +, 0, FluxMPI.MPI.COMM_WORLD) + MPI.Reduce!(arr, +, 0, MPI.COMM_WORLD) return ((arr[1:3] ./ arr[4])..., stats.total_time) else return ( From e89b4018955470a6b441b3ba6e45d68f92a551e7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 13:27:51 -0400 Subject: [PATCH 19/76] Fixed depth for pretraining --- src/layers/core.jl | 16 ++++++++++------ src/layers/deq.jl | 37 +++++++++++++++++++++++++++++++------ src/layers/mdeq.jl | 34 ++++++++++++++++++++++++++++++---- 3 files changed, 71 insertions(+), 16 deletions(-) diff --git a/src/layers/core.jl b/src/layers/core.jl index 0cd9dcba..932db452 100644 --- a/src/layers/core.jl +++ b/src/layers/core.jl @@ -1,26 +1,30 @@ abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitLayer end abstract type AbstractSkipDeepEquilibriumNetwork <: AbstractDeepEquilibriumNetwork end -initialparameters(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) = initialparameters(rng, deq.model) -initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) = initialstates(rng, deq.model) -createcache(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork, x) = createcache(rng, deq.model, x) +initialparameters(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) = (model=initialparameters(rng, deq.model),) +function initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) + return (model=initialstates(rng, deq.model), fixed_depth=0) +end +createcache(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork, x) = (model=createcache(rng, deq.model, x),) parameterlength(deq::AbstractDeepEquilibriumNetwork) = parameterlength(deq.model) -statelength(deq::AbstractDeepEquilibriumNetwork) = statelength(deq.model) +statelength(deq::AbstractDeepEquilibriumNetwork) = statelength(deq.model) + 2 cachesize(deq::AbstractDeepEquilibriumNetwork) = cachesize(deq.model) function initialparameters(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) return (model=initialparameters(rng, deq.model), shortcut=initialparameters(rng, deq.shortcut)) end function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut)) + return ( + model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), fixed_depth=0 + ) end function createcache(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork, x) return (model=createcache(rng, deq.model, x), shortcut=createcache(rng, deq.shortcut, x)) end parameterlength(deq::AbstractSkipDeepEquilibriumNetwork) = parameterlength(deq.model) + parameterlength(deq.shortcut) -statelength(deq::AbstractSkipDeepEquilibriumNetwork) = statelength(deq.model) + statelength(deq.shortcut) +statelength(deq::AbstractSkipDeepEquilibriumNetwork) = statelength(deq.model) + statelength(deq.shortcut) + 2 cachesize(deq::AbstractSkipDeepEquilibriumNetwork) = 
cachesize(deq.model) + cachesize(deq.shortcut) """ diff --git a/src/layers/deq.jl b/src/layers/deq.jl index cbd59b06..2bbf831b 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -18,19 +18,32 @@ end function (deq::DeepEquilibriumNetwork{J})(x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple) where {J,T} z = zero(x) + if !iszero(st.fixed_depth) + # Pretraining without Fixed Point Solving + st_ = st.model + z_star = z + for _ ∈ 1:st.fixed_depth + z_star, st_ = deq.model((z_star, x), ps.model, st_) + end + @set! st.model = st_ + + return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st + end + function dudt(u, p, t) - u_, _ = deq.model((u, x), p, st) + u_, _ = deq.model((u, x), p, st.model) return u_ .- u end - prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = deq.model((sol.u, x), ps, st) + z_star, st_ = deq.model((sol.u, x), ps.model, st.model) - jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st, z_star, x) : T(0)) - residual = Zygote.@ignore z_star .- deq.model((z_star, x), ps, st)[1] + jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) + residual = Zygote.@ignore z_star .- deq.model((z_star, x), ps.model, st.model)[1] + @set! st.model = st_ :: typeof(st.model) - return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st_ + return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end struct SkipDeepEquilibriumNetwork{J,M,Sh,A,S,K} <: AbstractSkipDeepEquilibriumNetwork @@ -64,6 +77,18 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})(x::AbstractArray{T}, ps::Named end @set! st.shortcut = st__ + if !iszero(st.fixed_depth) + # Pretraining without Fixed Point Solving + st_ = st.model + z_star = z + for _ ∈ 1:st.fixed_depth + z_star, st_ = deq.model((z_star, x), ps.model, st_) + end + @set! st.model = st_ + + return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st + end + function dudt(u, p, t) u_, = deq.model((u, x), p, st.model) return u_ .- u diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 3d32930b..7d68fdba 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -7,7 +7,7 @@ struct MultiScaleDeepEquilibriumNetwork{N,L,M,A,S,K} <: AbstractDeepEquilibriumN end function initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...))) + return (model=initialstates(rng, deq.model), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), fixed_depth=0) end function MultiScaleDeepEquilibriumNetwork( @@ -46,6 +46,19 @@ Zygote.@nograd get_initial_condition_mdeq function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple) where {N,T} z, st = get_initial_condition_mdeq(deq.scales, x, st) + if !iszero(st.fixed_depth) + z_star = split_and_reshape(z, st.split_idxs, deq.scales) + st_ = st.model + + for _ ∈ 1:st.fixed_depth + z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) + end + + @set! 
st.model = st_ + + return (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st + end + function dudt_(u, p, t) u_split = split_and_reshape(u, st.split_idxs, deq.scales) u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st.model) @@ -54,11 +67,11 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::Nam dudt(u, p, t) = vcat(Flux.flatten.(dudt_(u, p, t)[1])...) .- u - prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = dudt_(sol.u, ps, nothing) + z_star, st_ = dudt_(sol.u, ps.model, nothing) - residual = Zygote.@ignore dudt(sol.u, ps, nothing) + residual = Zygote.@ignore dudt(sol.u, ps.model, nothing) @set! st.model = st_ @@ -127,6 +140,19 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( (vcat(Flux.flatten.(z0)...), st) end + if !iszero(st.fixed_depth) + z_star = split_and_reshape(z, st.split_idxs, deq.scales) + st_ = st.model + + for _ ∈ 1:st.fixed_depth + z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) + end + + @set! st.model = st_ + + return (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st + end + function dudt_(u, p, t) u_split = split_and_reshape(u, st.split_idxs, deq.scales) u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st.model) From 07634bc899bfdde3d77f09254b78b3ed9dc15a53 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 13:42:00 -0400 Subject: [PATCH 20/76] Fixed depth testing --- test/runtests.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index fb9d457d..b8659b6c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,5 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux -const EFL = ExplicitFluxLayers - @testset "FastDEQ.jl" begin seed = 0 @@ -19,6 +17,11 @@ const EFL = ExplicitFluxLayers gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + @info "Testing DEQ without Fixed Point Iterations" + st = EFL.update_state(st, :fixed_depth, 5) + + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + @info "Testing SkipDEQ" model = DEQChain( EFL.Dense(2, 2), @@ -38,6 +41,14 @@ const EFL = ExplicitFluxLayers sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @info "Testing SkipDEQ without Fixed Point Iterations" + st = EFL.update_state(st, :fixed_depth, 5) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing SkipDEQV2" model = DEQChain( EFL.Dense(2, 2), @@ -57,6 +68,14 @@ const EFL = ExplicitFluxLayers sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @info "Testing SkipDEQV2 without Fixed Point Iterations" + st = EFL.update_state(st, :fixed_depth, 5) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + # @info "Testing Broyden Solver" # Random.seed!(0) @@ -124,6 +143,14 @@ const EFL = ExplicitFluxLayers sum(Base.Fix1(sum, abs2), ŷ .- y) end + @info "Testing MultiScaleDEQ without Fixed Point Iterations" + st = EFL.update_state(st, :fixed_depth, 5) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + end + @info "Testing MultiScaleSkipDEQ" model = 
MultiScaleSkipDeepEquilibriumNetwork( ( @@ -154,6 +181,14 @@ const EFL = ExplicitFluxLayers sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" + st = EFL.update_state(st, :fixed_depth, 5) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing MultiScaleSkipDEQV2" model = MultiScaleSkipDeepEquilibriumNetwork( ( @@ -183,4 +218,12 @@ const EFL = ExplicitFluxLayers (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + + @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" + st = EFL.update_state(st, :fixed_depth, 5) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end end From 3968f1cc744011c9582c9b731a3f0899f5323998 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 14:02:46 -0400 Subject: [PATCH 21/76] Missing state --- src/layers/mdeq.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 7d68fdba..5dd66689 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -94,6 +94,7 @@ function initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwo model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), + fixed_depth=0 ) end From a951f98718ceec5fe9bb22612319cecf4d423ce0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 15:24:35 -0400 Subject: [PATCH 22/76] Imagenet model configuration --- examples/Manifest.toml | 41 +--- examples/cifar10/script.jl | 41 ++-- examples/src/FastDEQExperiments.jl | 5 +- examples/src/config.jl | 289 +++++++++++++++++++++++++++++ examples/src/models.jl | 235 +++++++++++++---------- examples/src/train.jl | 95 +++++++--- 6 files changed, 521 insertions(+), 185 deletions(-) create mode 100644 examples/src/config.jl diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 78c9552e..3ff589e6 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.8.0-beta1" +julia_version = "1.7.2" manifest_format = "2.0" project_hash = "c980a263b9193d100d02f087d77471d79992ffe5" @@ -29,7 +29,6 @@ version = "2.3.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" [[deps.ArnoldiMethod]] deps = ["LinearAlgebra", "Random", "StaticArrays"] @@ -204,7 +203,6 @@ version = "3.43.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.5.0+0" [[deps.CompositeTypes]] git-tree-sha1 = "d5b014b216dc891e81fea299638e4c10c657b582" @@ -369,9 +367,8 @@ uuid = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" version = "0.5.9" [[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" [[deps.EllipsisNotation]] deps = ["ArrayInterface"] @@ -393,9 +390,7 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] -git-tree-sha1 = "d8520e4d150fea76c4865a177893fcda429a27fd" -repo-rev = "ap/sparse" 
-repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" +path = "/mnt/research/softwares/ExplicitFluxLayers/" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" version = "0.2.0" @@ -463,9 +458,6 @@ git-tree-sha1 = "129b104185df66e408edd6625d480b7f9e9823a0" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" version = "0.9.18" -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" @@ -492,7 +484,7 @@ version = "0.13.0" [[deps.FluxMPI]] deps = ["CUDA", "Dates", "Flux", "Functors", "LearnBase", "MLDataUtils", "MPI", "Optimisers", "Setfield", "Zygote"] -git-tree-sha1 = "be5c7b6c1acf5081be9179f400c78a0bddbd7f0d" +git-tree-sha1 = "0956751f425663d4f468cf4ed97b95249257e202" repo-rev = "ap/opt" repo-url = "https://github.com/avik-pal/FluxMPI.jl.git" uuid = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" @@ -808,12 +800,10 @@ version = "1.0.0" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.81.0+0" [[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] @@ -822,7 +812,6 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -958,7 +947,6 @@ version = "1.0.3" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.0+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -983,7 +971,6 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.2.1" [[deps.MuladdMacro]] git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" @@ -1027,7 +1014,6 @@ version = "0.1.5" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" [[deps.NonlinearSolve]] deps = ["ArrayInterface", "FiniteDiff", "ForwardDiff", "IterativeSolvers", "LinearAlgebra", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "SciMLBase", "Setfield", "StaticArrays", "UnPack"] @@ -1056,12 +1042,10 @@ version = "1.10.8" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.17+2" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] @@ -1131,7 +1115,6 @@ version = "0.3.0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" [[deps.PoissonRandom]] deps = ["Random", "Statistics", "Test"] @@ -1253,9 +1236,9 @@ version = "2.26.3" [[deps.RecursiveFactorization]] deps = ["LinearAlgebra", "LoopVectorization", "Polyester", "StrideArraysCore", "TriangularSolve"] -git-tree-sha1 = "7ad4c2ef15b7aecd767b3921c0d255d39b3603ea" +git-tree-sha1 = "a9a852c7ebb08e2a40e8c0ab9830a744fa283690" uuid = 
"f2c3362d-daeb-58d1-803e-2bc74f2840b4" -version = "0.2.9" +version = "0.2.10" [[deps.Reexport]] git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" @@ -1300,7 +1283,6 @@ version = "0.5.3" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" [[deps.SIMDDualNumbers]] deps = ["ForwardDiff", "IfElse", "SLEEFPirates", "VectorizationBase"] @@ -1446,9 +1428,9 @@ version = "6.46.0" [[deps.StrideArraysCore]] deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "ManualMemory", "Requires", "SIMDTypes", "Static", "ThreadingUtilities"] -git-tree-sha1 = "c7e0392560f15771003cce388fe8471d17941374" +git-tree-sha1 = "df8fc9d0407a77241c529cc2ef97ba2e3436ff51" uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" -version = "0.2.19" +version = "0.3.2" [[deps.Strided]] deps = ["LinearAlgebra", "TupleTools"] @@ -1481,12 +1463,10 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+0" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.0" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -1503,7 +1483,6 @@ version = "1.7.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] @@ -1623,7 +1602,6 @@ version = "0.9.4" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.12+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] @@ -1640,14 +1618,11 @@ version = "0.2.2" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.0.1+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.41.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "16.2.1+1" diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index e963fc6b..2247eaef 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -4,29 +4,17 @@ using FastDEQExperiments, Flux, CUDA, Optimisers, Dates, FluxMPI FluxMPI.Init(; verbose=true) # Setup -CUDA.versioninfo() CUDA.math_mode!(CUDA.FAST_MATH) CUDA.allowscalar(false) -function invoke_gc() - GC.gc(true) - CUDA.reclaim() - return nothing -end - # Hyperparameters config = Dict( "seed" => 0, - "learning_rate" => 0.001f0, "abstol" => 5.0f-2, "reltol" => 5.0f-2, - "maxiters" => 20, - "epochs" => 50, - "dropout_rate" => 0.25f0, - "batchsize" => 64, - "eval_batchsize" => 64, "model_type" => :skip, "continuous" => true, + "model_size" => :TINY, ) expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])_now-$(now())" @@ -41,19 +29,20 @@ function train_model(config, expt_name) ["Train/Running/NFE", "Train/Running/Loss", "Train/Running/Accuracy"], ) + # Experiment Configuration + expt_config = FastDEQExperiments.get_experiment_config( + :CIFAR10, + config["model_size"]; + model_type = config["model_type"], + continuous = config["continuous"], + ) + # Model Setup model, ps, st = 
FastDEQExperiments.get_model( - Val(:CIFAR10); - dropout_rate=config["dropout_rate"], - model_type=config["model_type"], - continuous=config["continuous"], - maxiters=config["maxiters"], - abstol=config["abstol"], - reltol=config["reltol"], + expt_config; seed=config["seed"], device=gpu, warmup=true, - group_count=8, ) # Get Dataloaders @@ -66,15 +55,15 @@ function train_model(config, expt_name) model, ps, st, - FastDEQExperiments.loss_function(:CIFAR10, config["model_type"]), - Optimisers.ADAM(config["learning_rate"], (0.9f0, 0.999f0)), + FastDEQExperiments.loss_function(expt_config), + FastDEQExperiments.construct_optimiser(expt_config), train_dataloader, nothing, test_dataloader, gpu, - config["epochs"], - lg; - cleanup_function=invoke_gc, + expt_config.nepochs, + lg, + expt_config ) # Close Logger and Flush Data to disk diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index 70301726..68c3a31e 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -11,7 +11,8 @@ using FastDEQ, MLDataUtils, DataLoaders, Optimisers, - MPI + MPI, + CUDA import LearnBase: ObsDim import MLDataUtils: nobs, getobs @@ -19,6 +20,8 @@ const EFL = ExplicitFluxLayers # PrettyTableLogger include("logging.jl") +# get_model_config +include("config.jl") # train, loss_function include("train.jl") # get_model diff --git a/examples/src/config.jl b/examples/src/config.jl new file mode 100644 index 00000000..0aea3628 --- /dev/null +++ b/examples/src/config.jl @@ -0,0 +1,289 @@ +abstract type AbstractTaskModelConfiguration end + +# Predefined Image Classification Models +Base.@kwdef struct ImageClassificationModelConfiguration{N} <: AbstractTaskModelConfiguration + num_layers::Int + num_classes::Int + dropout_rate::Float32 + group_count::Int + weight_norm::Bool + downsample_times::Int + expansion_factor::Int + post_gn_affine::Bool + image_size::Tuple{Int,Int} + + num_modules::Int + num_branches::Int + block_type::Symbol + big_kernels::NTuple{N,Int} + head_channels::NTuple{N,Int} + num_blocks::NTuple{N,Int} + num_channels::NTuple{N,Int} + + fuse_method::Symbol + final_channelsize::Int + + fwd_maxiters::Int + bwd_maxiters::Int + model_type::Symbol + continuous::Bool + + # Specific for Continuous Models + abstol::Float32 = 1f-2 + reltol::Float32 = 1f-2 + stop_mode::Symbol = :abs_deq_best + ode_solver = VCABM3() +end + +function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) + if dataset == :CIFAR10 + if model_size == :TINY + return ImageClassificationModelConfiguration{2}(; + num_layers=10, + num_classes=10, + dropout_rate=0.25f0, + group_count=8, + weight_norm=true, + downsample_times=0, + expansion_factor=5, + post_gn_affine=false, + image_size=(32, 32), + num_modules=1, + num_branches=2, + block_type=:basic, + big_kernels=(0, 0), + head_channels=(8, 16), + num_blocks=(1, 1), + num_channels=(24, 24), + fuse_method=:sum, + final_channelsize=200, + fwd_maxiters=18, + bwd_maiters=20, + kwargs... + ) + elseif model_size == :LARGE + return ImageClassificationModelConfiguration{4}(; + num_layers=10, + num_classes=10, + dropout_rate=0.3f0, + group_count=8, + weight_norm=true, + downsample_times=0, + expansion_factor=5, + post_gn_affine=false, + image_size=(32, 32), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(14, 28, 56, 112), + num_blocks=(1, 1, 1, 1), + num_channels=(32, 64, 128, 256), + fuse_method=:sum, + final_channelsize=1680, + fwd_maxiters=18, + bwd_maiters=20, + kwargs... 
+ ) + else + throw(ArgumentError("`model_size` must be one of `[:TINY, :LARGE]`")) + end + elseif dataset == :IMAGENET + if model_size == :SMALL + return ImageClassificationModelConfiguration{4}(; + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(24, 48, 96, 192), + num_blocks=(1, 1, 1, 1), + num_channels=(32, 64, 128, 256), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + kwargs... + ) + elseif model_size == :LARGE + return ImageClassificationModelConfiguration{4}(; + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(32, 64, 128, 256), + num_blocks=(1, 1, 1, 1), + num_channels=(80, 160, 320, 640), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + kwargs... + ) + elseif model_size == :XL + return ImageClassificationModelConfiguration{4}(; + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(32, 64, 128, 256), + num_blocks=(1, 1, 1, 1), + num_channels=(88, 176, 352, 704), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + kwargs... + ) + else + throw(ArgumentError("`model_size` must be one of `[:SMALL, :LARGE, :XL]`")) + end + else + throw(ArgumentError("`dataset` must be one of `[:CIFAR10]`")) + end +end + +function compute_feature_scales(config::ImageClassificationModelConfiguration) + image_size = config.image_size + image_size_downsampled = image_size + for _ in 1:(config.downsample_times) + image_size_downsampled = image_size_downsampled .÷ 2 + end + scales = [(image_size_downsampled..., config.num_channels[1])] + for i in 2:(config.num_branches) + push!(scales, ((scales[end][1:2] .÷ 2)..., config.num_channels[i])) + end + return Tuple(scales) +end + +# Experiment Configuration +Base.@kwdef struct ExperimentConfiguration{M<:AbstractTaskModelConfiguration} + model_config::M + + # Eval + eval_batchsize::Int + + # Train + train_batchsize::Int + nepochs::Int + pretrain_steps::Int + + # Optimiser + lr_scheduler::Symbol + optimiser::Symbol + eta::Float32 + momentum::Float32 + nesterov::Bool + weight_decay::Float32 +end + +function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) + if dataset == :CIFAR10 + if model_size == :TINY + return ExperimentConfiguration( + model_config=get_model_config(dataset, model_size; kwargs...), + eval_batchsize=32, + train_batchsize=32, + nepochs=50, + pretrain_steps=3000, + lr_scheduler=:COSINE, + optimiser=:ADAM, + eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), + weight_decay=0.0f0, + momentum=0.9f0, + nesterov=true + ) + elseif model_size == :LARGE + return ExperimentConfiguration( + model_config=get_model_config(dataset, model_size; kwargs...), + eval_batchsize=32, + train_batchsize=32, + nepochs=220, + pretrain_steps=20000, + lr_scheduler=:COSINE, + optimiser=:ADAM, + eta=0.001f0 / 4 * (is_distributed() ? 
total_workers() : 1), + weight_decay=0.0f0, + momentum=0.9f0, + nesterov=true + ) + else + throw(ArgumentError("`model_size` must be one of `[:TINY, :LARGE]`")) + end + elseif dataset == :IMAGENET + if model_size == :SMALL + return ExperimentConfiguration( + model_config=get_model_config(dataset, model_size; kwargs...), + eval_batchsize=32, + train_batchsize=32, + nepochs=100, + pretrain_steps=510000, + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true + ) + elseif model_size == :LARGE + return ExperimentConfiguration( + model_config=get_model_config(dataset, model_size; kwargs...), + eval_batchsize=32, + train_batchsize=32, + nepochs=100, + pretrain_steps=510000, + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true + ) + elseif model_size == :XL + return ExperimentConfiguration( + model_config=get_model_config(dataset, model_size; kwargs...), + eval_batchsize=32, + train_batchsize=32, + nepochs=100, + pretrain_steps=510000, + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true + ) + else + throw(ArgumentError("`model_size` must be one of `[:SMALL, :LARGE, :XL]`")) + end + else + throw(ArgumentError("`dataset` must be one of `[:CIFAR10]`")) + end +end + diff --git a/examples/src/models.jl b/examples/src/models.jl index c5a7c5a3..7f04ae7e 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -1,26 +1,22 @@ # Building Blocks ## Helpful Functional Wrappers function conv1x1(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return EFL.Conv((1, 1), mapping, activation; stride=stride, pad=0, bias=bias, kwargs...) + return EFL.Conv((1, 1), mapping, activation; stride=stride, pad=0, bias=bias, initW=NormalInitializer(), kwargs...) end function conv3x3(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return EFL.Conv((3, 3), mapping, activation; stride=stride, pad=1, bias=bias, kwargs...) + return EFL.Conv((3, 3), mapping, activation; stride=stride, pad=1, bias=bias, initW=NormalInitializer(), kwargs...) end function conv5x5(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return EFL.Conv((5, 5), mapping, activation; stride=stride, pad=2, bias=bias, kwargs...) + return EFL.Conv((5, 5), mapping, activation; stride=stride, pad=2, bias=bias, initW=NormalInitializer(), kwargs...) 
end reassociate(x::NTuple{2,<:AbstractArray}, y) = (x[1], (x[2], y)) ## Downsample Module -function downsample_module(mapping, resolution_mapping, activation; group_count=8) - in_resolution, out_resolution = resolution_mapping +function downsample_module(mapping, level_diff, activation; group_count=8) in_channels, out_channels = mapping - @assert in_resolution > out_resolution - @assert ispow2(in_resolution ÷ out_resolution) - level_diff = Int(log2(in_resolution ÷ out_resolution)) function intermediate_mapping(i) if in_channels * (2^level_diff) == out_channels @@ -33,19 +29,15 @@ function downsample_module(mapping, resolution_mapping, activation; group_count= layers = EFL.AbstractExplicitLayer[] for i in 1:level_diff inchs, outchs = intermediate_mapping(i) - push!(layers, conv3x3(inchs => outchs; stride=2, initW=NormalInitializer())) + push!(layers, conv3x3(inchs => outchs; stride=2)) push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) end return EFL.Chain(layers...) end ## Upsample Module -function upsample_module(mapping, resolution_mapping, activation; upsample_mode::Symbol=:nearest, group_count=8) - in_resolution, out_resolution = resolution_mapping +function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol=:nearest, group_count=8) in_channels, out_channels = mapping - @assert in_resolution < out_resolution - @assert ispow2(out_resolution ÷ in_resolution) - level_diff = Int(log2(out_resolution ÷ in_resolution)) function intermediate_mapping(i) if out_channels * (2^level_diff) == in_channels @@ -58,7 +50,7 @@ function upsample_module(mapping, resolution_mapping, activation; upsample_mode: layers = EFL.AbstractExplicitLayer[] for i in 1:level_diff inchs, outchs = intermediate_mapping(i) - push!(layers, conv1x1(inchs => outchs; initW=NormalInitializer())) + push!(layers, conv1x1(inchs => outchs)) push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) push!(layers, EFL.Upsample(upsample_mode; scale=2)) end @@ -79,11 +71,11 @@ function ResidualBlockV1( ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; initW=NormalInitializer(), bias=false) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; initW=NormalInitializer(), bias=false) + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false) + conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) conv1, conv2 = if weight_norm - EFL.WeightNorm(conv1, (:weight,)), EFL.WeightNorm(conv2, (:weight,)) + EFL.WeightNorm(conv1, (:weight,), (4,)), EFL.WeightNorm(conv2, (:weight,), (4,)) else conv1, conv2 end @@ -122,15 +114,15 @@ function ResidualBlockV2( dropout_rate::Real=0.0f0, gn_affine::Bool=true, weight_norm::Bool=true, - gn_track_stats::Bool=false, + gn_track_stats::Bool=true, ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; initW=NormalInitializer(), bias=false) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; initW=NormalInitializer(), bias=false) + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false) + conv2 = (n_big_kernels >= 2 ? 
conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) conv1, conv2 = if weight_norm - EFL.WeightNorm(conv1, (:weight,)), EFL.WeightNorm(conv2, (:weight,)) + EFL.WeightNorm(conv1, (:weight,), (4,)), EFL.WeightNorm(conv2, (:weight,), (4,)) else conv1, conv2 end @@ -155,7 +147,7 @@ end function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) rescale = if first(mapping) != last(mapping) * expansion EFL.Chain( - conv1x1(first(mapping) => last(mapping) * expansion; initW=NormalInitializer()), + conv1x1(first(mapping) => last(mapping) * expansion), EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ) else @@ -163,9 +155,7 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool end return EFL.Chain( - EFL.Parallel( - reassociate, EFL.BranchLayer(rescale, conv1x1(mapping; initW=NormalInitializer())), EFL.NoOpLayer() - ), + EFL.Parallel(reassociate, EFL.BranchLayer(rescale, conv1x1(mapping)), EFL.NoOpLayer()), EFL.Parallel( +, EFL.NoOpLayer(), @@ -173,9 +163,9 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool EFL.WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar EFL.Chain( EFL.BatchNorm(last(mapping), gelu; affine=bn_affine, track_stats=bn_track_stats), - conv3x3(last(mapping) => last(mapping) * expansion; initW=NormalInitializer()), + conv3x3(last(mapping) => last(mapping) * expansion), EFL.BatchNorm(last(mapping) * expansion, gelu; track_stats=bn_track_stats, affine=bn_affine), - conv1x1(last(mapping) * expansion => last(mapping) * expansion; initW=NormalInitializer()), + conv1x1(last(mapping) * expansion => last(mapping) * expansion), EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), @@ -187,7 +177,7 @@ end function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) rescale = if first(mapping) != last(mapping) * expansion EFL.Chain( - conv1x1(first(mapping) => last(mapping) * expansion; initW=NormalInitializer()), + conv1x1(first(mapping) => last(mapping) * expansion), EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ) else @@ -199,11 +189,11 @@ function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool +, rescale, EFL.Chain( - conv1x1(mapping; initW=NormalInitializer()), + conv1x1(mapping), EFL.BatchNorm(last(mapping), gelu; affine=bn_affine, track_stats=bn_track_stats), - conv3x3(last(mapping) => last(mapping) * expansion; initW=NormalInitializer()), + conv3x3(last(mapping) => last(mapping) * expansion), EFL.BatchNorm(last(mapping) * expansion, gelu; track_stats=bn_track_stats, affine=bn_affine), - conv1x1(last(mapping) * expansion => last(mapping) * expansion; initW=NormalInitializer()), + conv1x1(last(mapping) * expansion => last(mapping) * expansion), EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), @@ -212,88 +202,131 @@ function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool end # Dataset Specific Models -## CIFAR10 -- MultiScaleDEQ +get_model(econfig::ExperimentConfiguration, args...; kwargs...) = get_model(econfig.model_config, args...; kwargs...) 
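# Illustrative sketch, not part of the diff: the one-line dispatch above keeps the
# training scripts task-agnostic -- an `ExperimentConfiguration` carries a
# task-specific model configuration, and `get_model` unwraps it before building.
# Combined with the constructors from config.jl (kwargs are forwarded down to
# `get_model_config`):
econfig = get_experiment_config(:CIFAR10, :TINY; model_type=:vanilla, continuous=true)
model, ps, st = get_model(econfig; seed=0, device=gpu, warmup=false)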
+ function get_model( - ::Val{:CIFAR10}; - dropout_rate, - group_count=8, - model_type::Symbol, - continuous::Bool=true, - maxiters::Int, - abstol, - reltol, - seed, + config::ImageClassificationModelConfiguration; + seed::Int, device=gpu, - warmup::Bool=true, # Helps reduce time for Zygote to compile gradients first time + warmup::Bool=true, # Helps reduce Zygote compile times ) - initial_layers = EFL.Chain( - conv3x3(3 => 24; initW=NormalInitializer()), - EFL.BatchNorm(24, gelu; track_stats=true, affine=true), - conv3x3(24 => 24; initW=NormalInitializer()), - EFL.BatchNorm(24, gelu; track_stats=true, affine=true), - ) + init_channel_size = config.num_channels[1] - main_layers = ( - ResidualBlockV1(24 => 24; dropout_rate, num_gn_groups=group_count), # 32 x 32 - ResidualBlockV1(24 => 24; dropout_rate, num_gn_groups=group_count), # 16 x 16 - ) - - mapping_layers = [ - EFL.NoOpLayer() downsample_module(24 => 24, 32 => 16, gelu; group_count=group_count) - upsample_module(24 => 24, 16 => 32, gelu; group_count=group_count, upsample_mode=:nearest) EFL.NoOpLayer() + downsample_layers = [ + conv3x3(3 => init_channel_size; stride=config.downsample_times >= 1 ? 2 : 1), + EFL.BatchNorm(init_channel_size, relu; affine=true), + conv3x3(init_channel_size => init_channel_size; stride=config.downsample_times >= 2 ? 2 : 1), + EFL.BatchNorm(init_channel_size, relu; affine=true), ] + for _ in 3:(config.downsample_times) + append!( + downsample_layers, + [ + conv3x3(init_channel_size => init_channel_size; stride=2), + EFL.BatchNorm(init_channel_size, relu; affine=true), + ], + ) + end + downsample = EFL.Chain(downsample_layers...) - post_fuse_layers = ( + stage0 = if config.downsample_times == 0 && config.num_branches <= 2 + EFL.NoOpLayer() + else EFL.Chain( - EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), - conv1x1(24 => 24; initW=NormalInitializer()), - EFL.GroupNorm(24, group_count ÷ 2; affine=true, track_stats=false), - ), + conv1x1(init_channel_size => init_channel_size; bias=false), + EFL.BatchNorm(init_channel_size, relu; affine=true), + ) + end + + initial_layers = EFL.Chain(downsample, stage0) + + main_layers = Tuple( + ResidualBlockV1( + config.num_channels[i] => config.num_channels[i]; + deq_expand=config.expansion_factor, + dropout_rate=config.dropout_rate, + num_gn_groups=config.group_count, + n_big_kernels=config.big_kernels[i], + ) for i in 1:(config.num_branches) + ) + + mapping_layers = Matrix{EFL.AbstractExplicitLayer}(undef, config.num_branches, config.num_branches) + for i in 1:(config.num_branches) + for j in 1:(config.num_branches) + if i == j + mapping_layers[i, j] = EFL.NoOpLayer() + elseif i < j + mapping_layers[i, j] = downsample_module( + config.num_channels[i] => config.num_channels[j], j - i, gelu; group_count=config.group_count + ) + else + mapping_layers[i, j] = upsample_module( + config.num_channels[i] => config.num_channels[j], + i - j, + gelu; + group_count=config.group_count, + upsample_mode=:nearest, + ) + end + end + end + + post_fuse_layers = Tuple( EFL.Chain( EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), - conv1x1(24 => 24; initW=NormalInitializer()), - EFL.GroupNorm(24, group_count ÷ 2; affine=true, track_stats=false), - ), + conv1x1(config.num_channels[i] => config.num_channels[i]), + EFL.GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), + ) for i in 1:(config.num_branches) ) - final_layers = EFL.Chain( - EFL.Parallel( - +, + increment_modules = EFL.Parallel( + nothing, + [BottleneckBlockV2(config.num_channels[i] => 
config.head_channels[i]) for i in 1:(config.num_branches)]..., + ) + + downsample_modules = EFL.PairwiseFusion( + config.fuse_method == :sum ? (+) : error("Only `fuse_method` = `:sum` is supported"), + [ EFL.Chain( - BottleneckBlockV2(24 => 8), - conv3x3(8 * 4 => 16 * 4; stride=2, initW=NormalInitializer()), - EFL.BatchNorm(16 * 4, gelu; track_stats=true, affine=true), - ), - BottleneckBlockV2(24 => 16, 4), - ), - conv1x1(16 * 4 => 200; initW=NormalInitializer()), - EFL.BatchNorm(200, gelu; track_stats=true, affine=true), + conv3x3(config.head_channels[i] * 4 => config.head_channels[i + 1] * 4; stride=2, bias=true), + EFL.BatchNorm(config.head_channels[i + 1] * 4, gelu; track_stats=true, affine=true), + ) for i in 1:(config.num_branches - 1) + ]..., + ) + + final_layers = EFL.Chain( + increment_modules, + downsample_modules, + conv1x1(config.head_channels[config.num_branches] * 4 => config.final_channelsize; bias=true), + EFL.BatchNorm(config.final_channelsize, gelu; track_stats=true, affine=true), EFL.GlobalMeanPool(), EFL.FlattenLayer(), - EFL.Dense(200, 10), + EFL.Dense(config.final_channelsize, config.num_classes), ) - solver = if continuous + solver = if config.continuous ContinuousDEQSolver( - VCABM3(); - mode=:rel_deq_best, - abstol=abstol, - reltol=reltol, - abstol_termination=abstol, - reltol_termination=reltol, + config.ode_solver; + mode=config.stop_mode, + abstol=config.abstol, + reltol=config.reltol, + abstol_termination=config.abstol, + reltol_termination=config.reltol, ) else error("Discrete Solvers have not been updated yet") end - sensealg = SteadyStateAdjoint(abstol, reltol, min(maxiters, 15)) + sensealg = SteadyStateAdjoint(config.abstol, config.reltol, config.bwd_maxiters) - deq = if model_type ∈ (:skip, :skipv2) - shortcut = if model_type == :skip - ( - ResidualBlockV2(24 => 24; num_gn_groups=group_count), - downsample_module(24 => 24, 32 => 16, gelu; group_count=group_count), - ) + deq = if config.model_type ∈ (:skip, :skipv2) + shortcut = if config.model_type == :skip + # ( + # ResidualBlockV2(24 => 24; num_gn_groups=group_count), + # downsample_module(24 => 24, 32 => 16, gelu; group_count=group_count), + # ) + # Not yet implemented for the general case + nothing else nothing end @@ -303,19 +336,19 @@ function get_model( post_fuse_layers, shortcut, solver, - ((32, 32, 24), (16, 16, 24)); - maxiters=maxiters, + compute_feature_scales(config); + maxiters=config.fwd_maxiters, sensealg=sensealg, verbose=false, ) - elseif model_type == :vanilla - MultiScaleSkipDeepEquilibriumNetwork( + elseif config.model_type == :vanilla + MultiScaleDeepEquilibriumNetwork( main_layers, mapping_layers, post_fuse_layers, solver, - ((32, 32, 24), (16, 16, 24)); - maxiters=maxiters, + compute_feature_scales(config); + maxiters=config.fwd_maxiters, sensealg=sensealg, verbose=false, ) @@ -324,15 +357,15 @@ function get_model( end model = DEQChain(initial_layers, deq, final_layers) - ps, st = EFL.setup(MersenneTwister(seed), model) .|> device + ps, st = device.(EFL.setup(MersenneTwister(seed), model)) if warmup clean_println("Starting Model Warmup") - x__ = randn(Float32, 32, 32, 3, 1) |> device - y__ = Float32.(Flux.onehotbatch([1], 0:9)) |> device + x__ = device(randn(Float32, config.image_size..., 3, 1)) + y__ = device(Float32.(Flux.onehotbatch([1], 0:9))) model(x__, ps, st) clean_println("Forward Pass Warmup Completed") - lfn = loss_function(:CIFAR10, model_type) + lfn = loss_function(config, model_type) (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) back((one(l), 
nothing, nothing, nothing)) clean_println("Backward Pass Warmup Completed") diff --git a/examples/src/train.jl b/examples/src/train.jl index 7a2e5005..a638f546 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -1,3 +1,32 @@ +function invoke_gc() + GC.gc(true) + CUDA.reclaim() + return nothing +end + +function construct_optimiser(config::ExperimentConfiguration) + opt = if config.optimiser == :ADAM + Optimisers.ADAM(config.eta) + elseif config.optimiser == :SGD + if config.nesterov + Optimisers.Nesterov(config.eta, config.momentum) + else + if iszero(config.momentum) + Optimisers.Descent(config.eta) + else + Optimisers.Momentum(config.eta, config.momentum) + end + end + else + throw(ArgumentError("`config.optimiser` must be either `:ADAM` or `:SGD`")) + end + if !iszero(config.weight_decay) + opt = Optimisers.OptimiserChain(opt, Optimisers.WeightDecay(config.weight_decay)) + end + + return opt +end + is_distributed() = FluxMPI.Initialized() && total_workers() > 1 _get_loggable_stats(::Nothing) = () @@ -37,10 +66,27 @@ function evaluate(model, ps, st, dataloader, device) return (loss=total_loss, accuracy=matches, mean_nfe=total_nfe, total_time=total_time, total_datasize=total_datasize) end -function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, device, lg::PrettyTableLogger) +function train_one_epoch( + model, + ps, + st, + loss_function, + opt_state, + dataloader, + device, + lg::PrettyTableLogger, + econfig::ExperimentConfiguration, + iteration_count::Int, +) total_time = 0 - for (x, y) in dataloader + for (x, y) in enumerate(dataloader) + # Without this we might frequently run out of memory + # especially with the MPI-UCX CUDA.jl mempool issue + iteration_count += 1 + st = econfig.pretrain_steps == iteration_count ? EFL.update_state(st, :fixed_depth, 0) : st + iteration_count % 25 == 0 && invoke_gc() + x = device(x) y = device(y) @@ -59,24 +105,22 @@ function train_one_epoch(model, ps, st, loss_function, opt_state, dataloader, de lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss, "Train/Running/Accuracy" => acc)) end - return ps, st, opt_state, (total_time=total_time,) + return ps, st, opt_state, iteration_count, (total_time=total_time,) end -function loss_function(dataset::Symbol, model_type::Symbol) - if dataset ∈ (:CIFAR10,) - function loss_function_closure(x, y, model, ps, st) - (ŷ, soln), st_ = model(x, ps, st) - loss = if model_type == :vanilla - Flux.Losses.logitcrossentropy(ŷ, y) - else - Flux.Losses.logitcrossentropy(ŷ, y) + Flux.Losses.mse(soln.u₀, soln.z_star) - end - return loss, ŷ, st_, soln.nfe +loss_function(e::ExperimentConfiguration, args...; kwargs...) = loss_function(e.model_config, args...; kwargs...) 
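# Illustrative sketch, not part of the diff: the closure built below pairs the usual
# classification loss with a skip-consistency penalty that pulls the shortcut's
# initial guess `soln.u₀` towards the solver's fixed point `soln.z_star` (MAE here;
# the tests above use the squared error), while `:vanilla` models keep only the
# cross-entropy. Written out flat:
skip_regularized_loss(ŷ, y, soln; λ_skip=1.0f0) =
    Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mae(soln.u₀, soln.z_star)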
+ +function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f0) + function loss_function_closure(x, y, model, ps, st) + (ŷ, soln), st_ = model(x, ps, st) + loss = if c.model_type == :vanilla + Flux.Losses.logitcrossentropy(ŷ, y) + else + Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mae(soln.u₀, soln.z_star) end - return loss_function_closure - else - throw(ArgumentError("$dataset - $model_type not yet supported")) + return loss, ŷ, st_, soln.nfe end + return loss_function_closure end function train( @@ -90,26 +134,29 @@ function train( test_dataloader, device, nepochs, - lg::PrettyTableLogger; - cleanup_function=identity, + lg::PrettyTableLogger, + econfig::ExperimentConfiguration; ) - cleanup_function() + invoke_gc() # TODO: Saving model weights opt_state = Optimisers.setup(opt, ps) opt_state = is_distributed() ? FluxMPI.synchronize!(opt_state; root_rank=0) : opt_state + iteration_count = 0 + + st = econfig.pretrain_steps != 0 ? EFL.update_state(st, :fixed_depth, econfig.model_config.num_layers) : st for epoch in 1:nepochs # Train 1 epoch - ps, st, opt_state, training_stats = train_one_epoch( - model, ps, st, loss_function, opt_state, train_dataloader, device, lg + ps, st, opt_state, iteration_count, training_stats = train_one_epoch( + model, ps, st, loss_function, opt_state, train_dataloader, device, lg, econfig, iteration_count ) - cleanup_function() + invoke_gc() # Evaluate val_stats = _get_loggable_stats(evaluate(model, ps, st, val_dataloader, device)) - cleanup_function() + invoke_gc() test_stats = _get_loggable_stats(evaluate(model, ps, st, test_dataloader, device)) - cleanup_function() + invoke_gc() lg(epoch, training_stats.total_time, val_stats..., test_stats...) end From 005f35f40927415b2c30fda9f3c8003e353594e6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 15:42:05 -0400 Subject: [PATCH 23/76] Imagenet model configuration --- examples/Manifest.toml | 4 +++- examples/src/config.jl | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 3ff589e6..742c4985 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -390,7 +390,9 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] -path = "/mnt/research/softwares/ExplicitFluxLayers/" +git-tree-sha1 = "e92b9fcc3a30b1d75e312c04db5ac983cd4e31e3" +repo-rev = "ap/sparse" +repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" version = "0.2.0" diff --git a/examples/src/config.jl b/examples/src/config.jl index 0aea3628..61fec777 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -58,7 +58,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=200, fwd_maxiters=18, - bwd_maiters=20, + bwd_maxiters=20, kwargs... ) elseif model_size == :LARGE @@ -82,7 +82,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=1680, fwd_maxiters=18, - bwd_maiters=20, + bwd_maxiters=20, kwargs... 
) else From 3745309828b91931748b7a3ebe5c18e801838eb0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 16:40:23 -0400 Subject: [PATCH 24/76] Update config --- examples/Manifest.toml | 2 +- examples/src/config.jl | 12 ++++++------ examples/src/models.jl | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 742c4985..830a4fa1 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -390,7 +390,7 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] -git-tree-sha1 = "e92b9fcc3a30b1d75e312c04db5ac983cd4e31e3" +git-tree-sha1 = "9c71e6fb85ad4e6fcbcd487be03693e916772474" repo-rev = "ap/sparse" repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" diff --git a/examples/src/config.jl b/examples/src/config.jl index 61fec777..875a161d 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -211,7 +211,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=50, - pretrain_steps=3000, + pretrain_steps=3000 / (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), @@ -225,7 +225,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=220, - pretrain_steps=20000, + pretrain_steps=20000 / (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 / 4 * (is_distributed() ? total_workers() : 1), @@ -243,7 +243,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=100, - pretrain_steps=510000, + pretrain_steps=510000 / (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:SGD, eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), @@ -257,7 +257,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=100, - pretrain_steps=510000, + pretrain_steps=510000 / (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:SGD, eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), @@ -271,10 +271,10 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=100, - pretrain_steps=510000, + pretrain_steps=510000 / (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:SGD, - eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), + eta=0.05f0 / 8 * (is_distributed() ? 
total_workers() : 1), weight_decay=0.00005f0, momentum=0.9f0, nesterov=true diff --git a/examples/src/models.jl b/examples/src/models.jl index 7f04ae7e..3d7f3d6a 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -365,7 +365,7 @@ function get_model( y__ = device(Float32.(Flux.onehotbatch([1], 0:9))) model(x__, ps, st) clean_println("Forward Pass Warmup Completed") - lfn = loss_function(config, model_type) + lfn = loss_function(config) (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass Warmup Completed") From f67a7761879645789801dd52f1bdb2ec2ee688dc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 16:51:01 -0400 Subject: [PATCH 25/76] Int divide --- examples/src/config.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index 875a161d..51833f66 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -211,7 +211,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=50, - pretrain_steps=3000 / (is_distributed() ? total_workers() : 1), + pretrain_steps=3000 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), @@ -225,7 +225,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=220, - pretrain_steps=20000 / (is_distributed() ? total_workers() : 1), + pretrain_steps=20000 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 / 4 * (is_distributed() ? total_workers() : 1), @@ -243,7 +243,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=100, - pretrain_steps=510000 / (is_distributed() ? total_workers() : 1), + pretrain_steps=510000 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:SGD, eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), @@ -257,7 +257,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=100, - pretrain_steps=510000 / (is_distributed() ? total_workers() : 1), + pretrain_steps=510000 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:SGD, eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), @@ -271,7 +271,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=32, train_batchsize=32, nepochs=100, - pretrain_steps=510000 / (is_distributed() ? total_workers() : 1), + pretrain_steps=510000 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:SGD, eta=0.05f0 / 8 * (is_distributed() ? 
total_workers() : 1), From 7a401befdd60b58e299169fe6a39b73adafb65c0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 16:56:55 -0400 Subject: [PATCH 26/76] exptconfig --- examples/cifar10/script.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 2247eaef..833770b7 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -47,7 +47,7 @@ function train_model(config, expt_name) # Get Dataloaders train_dataloader, test_dataloader = FastDEQExperiments.get_dataloaders( - :CIFAR10; train_batchsize=config["batchsize"], eval_batchsize=config["eval_batchsize"] + :CIFAR10; train_batchsize=expt_config.train_batchsize, eval_batchsize=expt_config.eval_batchsize ) # Train From e51a483341610820ffc9a11ad88278a1696430db Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 17:11:02 -0400 Subject: [PATCH 27/76] Oops --- examples/src/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/train.jl b/examples/src/train.jl index a638f546..bab8fd8b 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -80,7 +80,7 @@ function train_one_epoch( ) total_time = 0 - for (x, y) in enumerate(dataloader) + for (x, y) in dataloader # Without this we might frequently run out of memory # especially with the MPI-UCX CUDA.jl mempool issue iteration_count += 1 From e3f411f03f1de21061d4f14eaa635b8183f901e9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 17:34:13 -0400 Subject: [PATCH 28/76] Remove nothing --- examples/src/models.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/src/models.jl b/examples/src/models.jl index 3d7f3d6a..50bfc53d 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -137,8 +137,7 @@ function ResidualBlockV2( conv1, gn1, conv2, - EFL.BranchLayer(downsample, dropout), - EFL.Parallel(+, EFL.NoOpLayer(), gn2), + EFL.Parallel(+, downsample, EFL.Chain(dropout, gn2)), EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), gn3, ) @@ -321,12 +320,14 @@ function get_model( deq = if config.model_type ∈ (:skip, :skipv2) shortcut = if config.model_type == :skip - # ( - # ResidualBlockV2(24 => 24; num_gn_groups=group_count), - # downsample_module(24 => 24, 32 => 16, gelu; group_count=group_count), - # ) - # Not yet implemented for the general case - nothing + slayers = EFL.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1])] + for i in 1:(config.num_branches - 1) + push!( + slayers, + downsample_module(config.num_channels[1] => config.num_channels[i + 1], 1, gelu; group_count=config.group_count), + ) + end + tuple(slayers...) 
else nothing end From 43fd1acdefa5825afc547f53d939609c948d78a8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 13 Apr 2022 17:49:22 -0400 Subject: [PATCH 29/76] Warmup the pretraining mode --- examples/Manifest.toml | 2 +- examples/src/models.jl | 6 ++++++ examples/src/train.jl | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 830a4fa1..5ed7ac93 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -390,7 +390,7 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] -git-tree-sha1 = "9c71e6fb85ad4e6fcbcd487be03693e916772474" +git-tree-sha1 = "e911998f524df4fbccc4f9978d4c686462c62ee2" repo-rev = "ap/sparse" repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" diff --git a/examples/src/models.jl b/examples/src/models.jl index 50bfc53d..88875552 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -366,10 +366,16 @@ function get_model( y__ = device(Float32.(Flux.onehotbatch([1], 0:9))) model(x__, ps, st) clean_println("Forward Pass Warmup Completed") + st_ = EFL.update_state(st, :fixed_depth, config.num_layers) + model(x__, ps, st_) + clean_println("Forward Pass (Pretraining) Warmup Completed") lfn = loss_function(config) (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass Warmup Completed") + (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st_), ps) + back((one(l), nothing, nothing, nothing)) + clean_println("Backward Pass (Pretraining) Warmup Completed") end ps, st = if is_distributed() diff --git a/examples/src/train.jl b/examples/src/train.jl index bab8fd8b..d5816114 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -85,7 +85,7 @@ function train_one_epoch( # especially with the MPI-UCX CUDA.jl mempool issue iteration_count += 1 st = econfig.pretrain_steps == iteration_count ? 
EFL.update_state(st, :fixed_depth, 0) : st - iteration_count % 25 == 0 && invoke_gc() + iteration_count % 5 == 0 && invoke_gc() x = device(x) y = device(y) From 1b57511f69eee137ac7046a92d902ecbb1300626 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 14 Apr 2022 16:11:14 -0400 Subject: [PATCH 30/76] updates --- examples/Manifest.toml | 2 +- examples/src/FastDEQExperiments.jl | 12 ++++++++++++ examples/src/config.jl | 4 ++-- examples/src/models.jl | 5 +++++ examples/src/train.jl | 22 +++++++++------------- 5 files changed, 29 insertions(+), 16 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 5ed7ac93..6a17ea0b 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -390,7 +390,7 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] -git-tree-sha1 = "e911998f524df4fbccc4f9978d4c686462c62ee2" +git-tree-sha1 = "76d1d41d26fd2eec2ca4bc1f5780b51ee8064ef2" repo-rev = "ap/sparse" repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index 68c3a31e..71ec612d 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -18,6 +18,18 @@ import MLDataUtils: nobs, getobs const EFL = ExplicitFluxLayers +# Memory Management +relieve_gc_pressure(::Union{Nothing,<:AbstractArray}) = nothing +relieve_gc_pressure(x::CuArray) = CUDA.unsafe_free!(x) +relieve_gc_pressure(t::Tuple) = relieve_gc_pressure.(t) +relieve_gc_pressure(x::NamedTuple) = fmap(relieve_gc_pressure, x) + +function invoke_gc() + GC.gc(true) + # CUDA.reclaim() + return nothing +end + # PrettyTableLogger include("logging.jl") # get_model_config diff --git a/examples/src/config.jl b/examples/src/config.jl index 51833f66..7fed228b 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -208,8 +208,8 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) if model_size == :TINY return ExperimentConfiguration( model_config=get_model_config(dataset, model_size; kwargs...), - eval_batchsize=32, - train_batchsize=32, + eval_batchsize=64, + train_batchsize=64, nepochs=50, pretrain_steps=3000 ÷ (is_distributed() ? 
total_workers() : 1), lr_scheduler=:COSINE, diff --git a/examples/src/models.jl b/examples/src/models.jl index 88875552..fd90629a 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -366,16 +366,21 @@ function get_model( y__ = device(Float32.(Flux.onehotbatch([1], 0:9))) model(x__, ps, st) clean_println("Forward Pass Warmup Completed") + st_ = EFL.update_state(st, :fixed_depth, config.num_layers) model(x__, ps, st_) clean_println("Forward Pass (Pretraining) Warmup Completed") + lfn = loss_function(config) (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass Warmup Completed") + (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st_), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass (Pretraining) Warmup Completed") + + invoke_gc() end ps, st = if is_distributed() diff --git a/examples/src/train.jl b/examples/src/train.jl index d5816114..11c2e454 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -1,9 +1,3 @@ -function invoke_gc() - GC.gc(true) - CUDA.reclaim() - return nothing -end - function construct_optimiser(config::ExperimentConfiguration) opt = if config.optimiser == :ADAM Optimisers.ADAM(config.eta) @@ -60,7 +54,7 @@ function evaluate(model, ps, st, dataloader, device) total_nfe += soln.nfe * size(x, ndims(x)) total_loss += Flux.Losses.logitcrossentropy(ŷ, y) * size(x, ndims(x)) - matches += sum(argmax.(eachcol(ŷ)) .== Flux.onecold(cpu(y))) + matches += sum(argmax.(eachcol(cpu(ŷ))) .== Flux.onecold(cpu(y))) total_datasize += size(x, ndims(x)) end return (loss=total_loss, accuracy=matches, mean_nfe=total_nfe, total_time=total_time, total_datasize=total_datasize) @@ -81,12 +75,6 @@ function train_one_epoch( total_time = 0 for (x, y) in dataloader - # Without this we might frequently run out of memory - # especially with the MPI-UCX CUDA.jl mempool issue - iteration_count += 1 - st = econfig.pretrain_steps == iteration_count ? EFL.update_state(st, :fixed_depth, 0) : st - iteration_count % 5 == 0 && invoke_gc() - x = device(x) y = device(y) @@ -101,6 +89,14 @@ function train_one_epoch( acc = sum(argmax.(eachcol(cpu(ŷ))) .== Flux.onecold(cpu(y))) / size(x, 4) + # Relieve GC Pressure + relieve_gc_pressure((gs, ŷ, x, y)) + # Without this we might frequently run out of memory + # especially with the MPI-UCX CUDA.jl mempool issue + iteration_count += 1 + st = econfig.pretrain_steps == iteration_count ? 
EFL.update_state(st, :fixed_depth, 0) : st + iteration_count % 25 == 0 && invoke_gc() + # Logging lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss, "Train/Running/Accuracy" => acc)) end From b57102dbd63e0626e7bc613b466862e4fcc79df6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 16 Apr 2022 12:07:55 -0400 Subject: [PATCH 31/76] Update training --- examples/Manifest.toml | 65 +++++++++++++++++++++--------- examples/Project.toml | 2 + examples/cifar10/script.jl | 2 +- examples/src/FastDEQExperiments.jl | 4 +- examples/src/models.jl | 15 +++---- examples/src/train.jl | 47 +++++++++++++++++---- 6 files changed, 100 insertions(+), 35 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 6a17ea0b..b1fe4362 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -125,10 +125,10 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.1" [[deps.CPUSummary]] -deps = ["IfElse", "Static"] -git-tree-sha1 = "48e01b22ef077b07541309652f697595f8decf25" +deps = ["CpuId", "IfElse", "Static"] +git-tree-sha1 = "913b28a04929053e4310d0a4915f1efe195c0ce6" uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" -version = "0.1.18" +version = "0.1.19" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings"] @@ -226,6 +226,12 @@ git-tree-sha1 = "8ccaa8c655bc1b83d2da4d569c9b28254ababd6e" uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" version = "0.1.2" +[[deps.CpuId]] +deps = ["Markdown"] +git-tree-sha1 = "32d125af0fb8ec3f8935896122c5e345709909e5" +uuid = "adafc99b-e345-5852-983c-f28acb93d879" +version = "0.3.0" + [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" @@ -250,9 +256,9 @@ version = "0.7.7" [[deps.DataFrames]] deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "ae02104e835f219b8930c7664b8012c93475c340" +git-tree-sha1 = "6c19003824cbebd804a51211fd3bbd81bf1ecad5" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.3.2" +version = "1.3.3" [[deps.DataLoaders]] deps = ["DocStringExtensions", "LearnBase", "MLDataPattern", "Parameters", "Random", "ThreadPools"] @@ -292,9 +298,9 @@ version = "0.4.0" [[deps.DiffEqBase]] deps = ["ArrayInterface", "ChainRulesCore", "DEDataArrays", "DataStructures", "Distributions", "DocStringExtensions", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "IterativeSolvers", "LabelledArrays", "LinearAlgebra", "Logging", "MuladdMacro", "NonlinearSolve", "Parameters", "PreallocationTools", "Printf", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "Setfield", "SparseArrays", "StaticArrays", "Statistics", "SuiteSparse", "ZygoteRules"] -git-tree-sha1 = "d19393983b7609b0b7d4caa2bce6b018f663b688" +git-tree-sha1 = "cde20558d9a50ebef5f173aaa0e6ece8ca563c93" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" -version = "6.83.0" +version = "6.83.1" [[deps.DiffEqCallbacks]] deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "NLsolve", "OrdinaryDiffEq", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArrays"] @@ -397,10 +403,10 @@ uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" version = "0.2.0" [[deps.ExponentialUtilities]] -deps = 
["ArrayInterface", "LinearAlgebra", "Printf", "Requires", "SparseArrays", "libblastrampoline_jll"] -git-tree-sha1 = "b026981973ccbe38682fbb4ccb0732fd6b1e1207" +deps = ["ArrayInterface", "GenericSchur", "LinearAlgebra", "Printf", "Requires", "SparseArrays", "libblastrampoline_jll"] +git-tree-sha1 = "951c44b4af9d1e061d5cf789a30881471604c14c" uuid = "d4d017d3-3776-5f7e-afef-a10c40355c18" -version = "1.13.0" +version = "1.14.0" [[deps.ExprTools]] git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" @@ -547,6 +553,12 @@ git-tree-sha1 = "039be665faf0b8ae36e089cd694233f5dee3f7d6" uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" version = "0.5.1" +[[deps.GenericSchur]] +deps = ["LinearAlgebra", "Printf"] +git-tree-sha1 = "fb69b2a645fa69ba5f474af09221b9308b160ce6" +uuid = "c145ed77-6b09-5dd9-b285-bf645a82121e" +version = "0.5.3" + [[deps.Glob]] git-tree-sha1 = "4df9f7e06108728ebf00a0a11edee4b29a482bb2" uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" @@ -611,6 +623,17 @@ git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" version = "0.1.1" +[[deps.InfiniteArrays]] +deps = ["ArrayLayouts", "FillArrays", "Infinities", "LazyArrays", "LinearAlgebra", "Statistics"] +git-tree-sha1 = "e84e4ea66ca02755eaf7063e07169bf22370ee2c" +uuid = "4858937d-0d70-526a-a4dd-2d5cb5dd786c" +version = "0.12.6" + +[[deps.Infinities]] +git-tree-sha1 = "b2732e2076cd50639d827f9ae9fc4ea913c927fe" +uuid = "e1ba4f0e-776d-440f-acd9-e1d2e9742647" +version = "0.1.4" + [[deps.Inflate]] git-tree-sha1 = "f5fc07d4e706b84f72d54eedcc1c13d92fb0871c" uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" @@ -1096,6 +1119,12 @@ git-tree-sha1 = "e8185b83b9fc56eb6456200e873ce598ebc7f262" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" version = "0.11.7" +[[deps.ParameterSchedulers]] +deps = ["Flux", "InfiniteArrays"] +git-tree-sha1 = "68f63744d5d3e1714f989a9b4f38182275d3f348" +uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" +version = "0.3.3" + [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" @@ -1110,9 +1139,9 @@ version = "2.2.4" [[deps.Pickle]] deps = ["DataStructures", "InternedStrings", "Serialization", "SparseArrays", "Strided", "StringEncodings", "ZipFile"] -git-tree-sha1 = "de8165bc4d1c448824cefa98cd5cd281dc01d9b2" +git-tree-sha1 = "8e4ba4cb57bedd0289865c65ffedeee910d6a8b6" uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.0" +version = "0.3.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -1261,9 +1290,9 @@ version = "1.1.1" [[deps.ReverseDiff]] deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] -git-tree-sha1 = "8d85c98fc33d4d37d88c8f9ccee4f1f3f98e56f4" +git-tree-sha1 = "559db2c7a28262e9ff1af1ad4ec539aa972c8934" uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -version = "1.12.0" +version = "1.13.0" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] @@ -1299,9 +1328,9 @@ version = "0.1.0" [[deps.SLEEFPirates]] deps = ["IfElse", "Static", "VectorizationBase"] -git-tree-sha1 = "d4c366b135fc2e1af7a000473e08edc5afd94819" +git-tree-sha1 = "ac399b5b163b9140f9c310dfe9e9aaa225617ff6" uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" -version = "0.6.31" +version = "0.6.32" [[deps.SciMLBase]] deps = ["ArrayInterface", "CommonSolve", 
"ConstructionBase", "Distributed", "DocStringExtensions", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "RecipesBase", "RecursiveArrayTools", "StaticArrays", "Statistics", "Tables", "TreeViews"] @@ -1607,9 +1636,9 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "52adc0a505b6421a8668f13dcdb0c4cb498bd72c" +git-tree-sha1 = "8c3e9ae8c2b520200df59d4f683a0dab65685ade" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.37" +version = "0.6.38" [[deps.ZygoteRules]] deps = ["MacroTools"] diff --git a/examples/Project.toml b/examples/Project.toml index 479b615a..5e213565 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -18,4 +18,6 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 833770b7..fdeecd22 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -56,7 +56,7 @@ function train_model(config, expt_name) ps, st, FastDEQExperiments.loss_function(expt_config), - FastDEQExperiments.construct_optimiser(expt_config), + FastDEQExperiments.construct_optimiser(expt_config)..., train_dataloader, nothing, test_dataloader, diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index 71ec612d..af303d6b 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -12,7 +12,9 @@ using FastDEQ, DataLoaders, Optimisers, MPI, - CUDA + CUDA, + Setfield, + ParameterSchedulers import LearnBase: ObsDim import MLDataUtils: nobs, getobs diff --git a/examples/src/models.jl b/examples/src/models.jl index fd90629a..e6504ad7 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -364,22 +364,23 @@ function get_model( clean_println("Starting Model Warmup") x__ = device(randn(Float32, config.image_size..., 3, 1)) y__ = device(Float32.(Flux.onehotbatch([1], 0:9))) - model(x__, ps, st) - clean_println("Forward Pass Warmup Completed") - st_ = EFL.update_state(st, :fixed_depth, config.num_layers) + st_ = EFL.update_state(st, :fixed_depth, 2) model(x__, ps, st_) clean_println("Forward Pass (Pretraining) Warmup Completed") - lfn = loss_function(config) - (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) - back((one(l), nothing, nothing, nothing)) - clean_println("Backward Pass Warmup Completed") + model(x__, ps, st) + clean_println("Forward Pass Warmup Completed") + lfn = loss_function(config) (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st_), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass (Pretraining) Warmup Completed") + (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) + back((one(l), nothing, nothing, nothing)) + clean_println("Backward Pass Warmup Completed") + invoke_gc() end diff --git a/examples/src/train.jl b/examples/src/train.jl index 11c2e454..74f70324 100644 --- a/examples/src/train.jl +++ 
b/examples/src/train.jl @@ -1,3 +1,18 @@ +function update_lr(st::ST, eta) where {ST} + if hasfield(ST, :eta) + @set! st.eta = eta + end + return st +end + +update_lr(st::Optimisers.OptimiserChain, eta) = update_lr.(st.opts, eta) + +function update_lr(st::Optimisers.Leaf, eta) + @set! st.rule = update_lr(st.rule, eta) +end + +update_lr(st_opt::NamedTuple, eta) = fmap(l -> update_lr(l, eta), st_opt) + function construct_optimiser(config::ExperimentConfiguration) opt = if config.optimiser == :ADAM Optimisers.ADAM(config.eta) @@ -18,7 +33,15 @@ function construct_optimiser(config::ExperimentConfiguration) opt = Optimisers.OptimiserChain(opt, Optimisers.WeightDecay(config.weight_decay)) end - return opt + sched = if config.lr_scheduler == :COSINE + ParameterSchedulers.Stateful(ParameterSchedulers.Cos(config.eta, 1.0f-6, config.nepochs)) + elseif config.lr_scheduler == :CONSTANT + ParameterSchedulers.Stateful(ParameterSchedulers.Constant(config.eta)) + else + throw(ArgumentError("`config.lr_scheduler` must be either `:COSINE` or `:CONSTANT`")) + end + + return opt, sched end is_distributed() = FluxMPI.Initialized() && total_workers() > 1 @@ -89,13 +112,16 @@ function train_one_epoch( acc = sum(argmax.(eachcol(cpu(ŷ))) .== Flux.onecold(cpu(y))) / size(x, 4) - # Relieve GC Pressure - relieve_gc_pressure((gs, ŷ, x, y)) - # Without this we might frequently run out of memory - # especially with the MPI-UCX CUDA.jl mempool issue iteration_count += 1 st = econfig.pretrain_steps == iteration_count ? EFL.update_state(st, :fixed_depth, 0) : st - iteration_count % 25 == 0 && invoke_gc() + if iteration_count % 25 == 0 + # Without this we might frequently run out of memory + # Disabling CUDA Mempool with MPI seems to help but it + # also slows down the overall code + # Relieve GC Pressure + relieve_gc_pressure((gs, ŷ, x, y)) + invoke_gc() + end # Logging lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss, "Train/Running/Accuracy" => acc)) @@ -106,13 +132,13 @@ end loss_function(e::ExperimentConfiguration, args...; kwargs...) = loss_function(e.model_config, args...; kwargs...) -function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f0) +function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-2) function loss_function_closure(x, y, model, ps, st) (ŷ, soln), st_ = model(x, ps, st) loss = if c.model_type == :vanilla Flux.Losses.logitcrossentropy(ŷ, y) else - Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mae(soln.u₀, soln.z_star) + Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mae(soln.u₀, Flux.Zygote.dropgrad(soln.z_star)) end return loss, ŷ, st_, soln.nfe end @@ -125,6 +151,7 @@ function train( st, loss_function, opt, + scheduler, train_dataloader, val_dataloader, test_dataloader, @@ -155,6 +182,10 @@ function train( invoke_gc() lg(epoch, training_stats.total_time, val_stats..., test_stats...) 
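
To make the optimiser/scheduler split concrete: `construct_optimiser` now returns the rule and a stateful schedule as a pair, and the training loop (just below) advances the schedule and pushes the new learning rate into the optimiser state. A hedged end-to-end sketch with hypothetical hyperparameters (`ps` is assumed to be the model's parameters):

    using Optimisers, ParameterSchedulers

    opt = Optimisers.ADAM(1.0f-3)
    sched = ParameterSchedulers.Stateful(ParameterSchedulers.Cos(1.0f-3, 1.0f-6, 50))
    st_opt = Optimisers.setup(opt, ps)

    for epoch in 1:50
        # ... gradient steps updating st_opt and ps ...
        eta_new = ParameterSchedulers.next!(sched)  # query the schedule and advance it
        st_opt = update_lr(st_opt, eta_new)         # rewrite every `eta` field in the state tree
    end

`update_lr` simply walks the tree produced by `Optimisers.setup`: `Leaf` nodes delegate to their rule, any rule with an `eta` field gets it rewritten via `@set!`, `OptimiserChain`s broadcast over their members, and `NamedTuple` branches recurse through `fmap`.
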
+ + # Run ParameterScheduler + eta_new = ParameterSchedulers.next!(scheduler) + opt_state = update_lr(opt_state, eta_new) end return ps, st, opt_state From 8e84ae2ed581d62a0f64a0584707e7e727643401 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 17 Apr 2022 19:03:20 -0400 Subject: [PATCH 32/76] Update defaults --- examples/Manifest.toml | 39 +++++++++++++++++++----- examples/cifar10/script.jl | 4 ++- examples/src/config.jl | 8 ++--- examples/src/models.jl | 61 +++++++++++++++++++++----------------- examples/src/train.jl | 5 ++-- src/solve.jl | 39 +++++++++++++++++------- 6 files changed, 102 insertions(+), 54 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index b1fe4362..e58bcd48 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.7.2" +julia_version = "1.8.0-beta3" manifest_format = "2.0" -project_hash = "c980a263b9193d100d02f087d77471d79992ffe5" +project_hash = "55e1c12df8760eecee422230ecbed81b99e0268c" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -29,6 +29,7 @@ version = "2.3.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" [[deps.ArnoldiMethod]] deps = ["LinearAlgebra", "Random", "StaticArrays"] @@ -126,9 +127,9 @@ version = "0.4.1" [[deps.CPUSummary]] deps = ["CpuId", "IfElse", "Static"] -git-tree-sha1 = "913b28a04929053e4310d0a4915f1efe195c0ce6" +git-tree-sha1 = "80f3d536df634cabed8b98ad3f0cea3a715fd254" uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" -version = "0.1.19" +version = "0.1.20" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings"] @@ -203,6 +204,7 @@ version = "3.43.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.5.2+0" [[deps.CompositeTypes]] git-tree-sha1 = "d5b014b216dc891e81fea299638e4c10c657b582" @@ -373,8 +375,9 @@ uuid = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" version = "0.5.9" [[deps.Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" [[deps.EllipsisNotation]] deps = ["ArrayInterface"] @@ -396,7 +399,7 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] -git-tree-sha1 = "76d1d41d26fd2eec2ca4bc1f5780b51ee8064ef2" +git-tree-sha1 = "c1e2783e6ff0a8327b58e869f9c243769f4e89d0" repo-rev = "ap/sparse" repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" @@ -466,6 +469,9 @@ git-tree-sha1 = "129b104185df66e408edd6625d480b7f9e9823a0" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" version = "0.9.18" +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" @@ -673,9 +679,9 @@ version = "0.7.0" [[deps.Interpolations]] deps = ["AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] -git-tree-sha1 = "b15fc0a95c564ca2e0a7ae12c1f095ca848ceb31" 
+git-tree-sha1 = "b7bc05649af456efc75d178846f47006c2c4c3c7" uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -version = "0.13.5" +version = "0.13.6" [[deps.IntervalSets]] deps = ["Dates", "EllipsisNotation", "Statistics"] @@ -825,10 +831,12 @@ version = "1.0.0" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.81.0+0" [[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] @@ -837,6 +845,7 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -972,6 +981,7 @@ version = "1.0.3" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -996,6 +1006,7 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" [[deps.MuladdMacro]] git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" @@ -1039,6 +1050,7 @@ version = "0.1.5" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" [[deps.NonlinearSolve]] deps = ["ArrayInterface", "FiniteDiff", "ForwardDiff", "IterativeSolvers", "LinearAlgebra", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "SciMLBase", "Setfield", "StaticArrays", "UnPack"] @@ -1067,10 +1079,12 @@ version = "1.10.8" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.20+0" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] @@ -1146,6 +1160,7 @@ version = "0.3.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" [[deps.PoissonRandom]] deps = ["Random", "Statistics", "Test"] @@ -1314,6 +1329,7 @@ version = "0.5.3" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" [[deps.SIMDDualNumbers]] deps = ["ForwardDiff", "IfElse", "SLEEFPirates", "VectorizationBase"] @@ -1494,10 +1510,12 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "5.10.1+0" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -1514,6 +1532,7 @@ version = "1.7.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] @@ -1633,6 +1652,7 @@ version = "0.9.4" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.12+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", 
"DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] @@ -1649,11 +1669,14 @@ version = "0.2.2" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.1.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.41.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "16.2.1+1" diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index fdeecd22..e4f1c221 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -12,7 +12,7 @@ config = Dict( "seed" => 0, "abstol" => 5.0f-2, "reltol" => 5.0f-2, - "model_type" => :skip, + "model_type" => :SKIP, "continuous" => true, "model_size" => :TINY, ) @@ -35,6 +35,8 @@ function train_model(config, expt_name) config["model_size"]; model_type = config["model_type"], continuous = config["continuous"], + abstol = config["abstol"], + reltol = config["reltol"], ) # Model Setup diff --git a/examples/src/config.jl b/examples/src/config.jl index 7fed228b..8f30cd1c 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -29,9 +29,9 @@ Base.@kwdef struct ImageClassificationModelConfiguration{N} <: AbstractTaskModel continuous::Bool # Specific for Continuous Models - abstol::Float32 = 1f-2 - reltol::Float32 = 1f-2 - stop_mode::Symbol = :abs_deq_best + abstol::Float32 = 5f-2 + reltol::Float32 = 5f-2 + stop_mode::Symbol = :rel_norm ode_solver = VCABM3() end @@ -211,7 +211,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=64, train_batchsize=64, nepochs=50, - pretrain_steps=3000 ÷ (is_distributed() ? total_workers() : 1), + pretrain_steps=0 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), diff --git a/examples/src/models.jl b/examples/src/models.jl index e6504ad7..3dee5687 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -71,8 +71,8 @@ function ResidualBlockV1( ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=true) + conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=true) conv1, conv2 = if weight_norm EFL.WeightNorm(conv1, (:weight,), (4,)), EFL.WeightNorm(conv2, (:weight,), (4,)) @@ -80,8 +80,8 @@ function ResidualBlockV1( conv1, conv2 end - gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, gelu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = EFL.GroupNorm(outplanes, num_gn_groups, gelu; affine=gn_affine, track_stats=gn_track_stats) + gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) dropout = iszero(dropout_rate) ? 
EFL.NoOpLayer() : EFL.Dropout(dropout_rate) @@ -100,7 +100,7 @@ function ResidualBlockV1( gn2, ), # For (y2, injection) ), - EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + EFL.WrappedFunction(Base.Fix1(broadcast, relu)), gn3, ) end @@ -118,8 +118,8 @@ function ResidualBlockV2( ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=true) + conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=true) conv1, conv2 = if weight_norm EFL.WeightNorm(conv1, (:weight,), (4,)), EFL.WeightNorm(conv2, (:weight,), (4,)) @@ -127,8 +127,8 @@ function ResidualBlockV2( conv1, conv2 end - gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, gelu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = EFL.GroupNorm(outplanes, num_gn_groups, gelu; affine=gn_affine, track_stats=gn_track_stats) + gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate) @@ -138,7 +138,7 @@ function ResidualBlockV2( gn1, conv2, EFL.Parallel(+, downsample, EFL.Chain(dropout, gn2)), - EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + EFL.WrappedFunction(Base.Fix1(broadcast, relu)), gn3, ) end @@ -161,15 +161,15 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool EFL.Chain( EFL.WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar EFL.Chain( - EFL.BatchNorm(last(mapping), gelu; affine=bn_affine, track_stats=bn_track_stats), + EFL.BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), conv3x3(last(mapping) => last(mapping) * expansion), - EFL.BatchNorm(last(mapping) * expansion, gelu; track_stats=bn_track_stats, affine=bn_affine), + EFL.BatchNorm(last(mapping) * expansion, relu; track_stats=bn_track_stats, affine=bn_affine), conv1x1(last(mapping) * expansion => last(mapping) * expansion), EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), ), - EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + EFL.WrappedFunction(Base.Fix1(broadcast, relu)), ) end @@ -189,14 +189,14 @@ function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool rescale, EFL.Chain( conv1x1(mapping), - EFL.BatchNorm(last(mapping), gelu; affine=bn_affine, track_stats=bn_track_stats), + EFL.BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), conv3x3(last(mapping) => last(mapping) * expansion), - EFL.BatchNorm(last(mapping) * expansion, gelu; track_stats=bn_track_stats, affine=bn_affine), + EFL.BatchNorm(last(mapping) * expansion, relu; track_stats=bn_track_stats, affine=bn_affine), conv1x1(last(mapping) * expansion => last(mapping) * expansion), EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), - EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + EFL.WrappedFunction(Base.Fix1(broadcast, relu)), ) end @@ -256,13 +256,13 @@ function get_model( mapping_layers[i, j] = EFL.NoOpLayer() elseif i < j mapping_layers[i, j] = downsample_module( - config.num_channels[i] 
=> config.num_channels[j], j - i, gelu; group_count=config.group_count + config.num_channels[i] => config.num_channels[j], j - i, relu; group_count=config.group_count ) else mapping_layers[i, j] = upsample_module( config.num_channels[i] => config.num_channels[j], i - j, - gelu; + relu; group_count=config.group_count, upsample_mode=:nearest, ) @@ -272,7 +272,7 @@ function get_model( post_fuse_layers = Tuple( EFL.Chain( - EFL.WrappedFunction(Base.Fix1(broadcast, gelu)), + EFL.WrappedFunction(Base.Fix1(broadcast, relu)), conv1x1(config.num_channels[i] => config.num_channels[i]), EFL.GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), ) for i in 1:(config.num_branches) @@ -288,7 +288,7 @@ function get_model( [ EFL.Chain( conv3x3(config.head_channels[i] * 4 => config.head_channels[i + 1] * 4; stride=2, bias=true), - EFL.BatchNorm(config.head_channels[i + 1] * 4, gelu; track_stats=true, affine=true), + EFL.BatchNorm(config.head_channels[i + 1] * 4, relu; track_stats=true, affine=true), ) for i in 1:(config.num_branches - 1) ]..., ) @@ -297,7 +297,7 @@ function get_model( increment_modules, downsample_modules, conv1x1(config.head_channels[config.num_branches] * 4 => config.final_channelsize; bias=true), - EFL.BatchNorm(config.final_channelsize, gelu; track_stats=true, affine=true), + EFL.BatchNorm(config.final_channelsize, relu; track_stats=true, affine=true), EFL.GlobalMeanPool(), EFL.FlattenLayer(), EFL.Dense(config.final_channelsize, config.num_classes), @@ -307,8 +307,8 @@ function get_model( ContinuousDEQSolver( config.ode_solver; mode=config.stop_mode, - abstol=config.abstol, - reltol=config.reltol, + abstol=1f-5, #config.abstol, + reltol=1f-5, #config.reltol, abstol_termination=config.abstol, reltol_termination=config.reltol, ) @@ -318,13 +318,18 @@ function get_model( sensealg = SteadyStateAdjoint(config.abstol, config.reltol, config.bwd_maxiters) - deq = if config.model_type ∈ (:skip, :skipv2) - shortcut = if config.model_type == :skip + deq = if config.model_type ∈ (:SKIP, :SKIPV2) + shortcut = if config.model_type == :SKIP slayers = EFL.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1])] for i in 1:(config.num_branches - 1) push!( slayers, - downsample_module(config.num_channels[1] => config.num_channels[i + 1], 1, gelu; group_count=config.group_count), + downsample_module( + config.num_channels[1] => config.num_channels[i + 1], + 1, + relu; + group_count=config.group_count, + ), ) end tuple(slayers...) 
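
The tolerance rewiring in the `ContinuousDEQSolver` call above deserves a comment: `abstol`/`reltol` govern the accuracy of the underlying ODE integrator, while `abstol_termination`/`reltol_termination` feed the steady-state stopping test, so the two pairs are independent knobs; here the integrator is run tight while termination stays loose. A sketch of the construction with the values used in this configuration:

    solver = ContinuousDEQSolver(
        VCABM3();                    # ODE integrator for du/dt = f(u) - u
        mode=:rel_norm,              # stop_mode, i.e. the termination criterion
        abstol=1.0f-5,               # integrator accuracy
        reltol=1.0f-5,
        abstol_termination=5.0f-2,   # when the steady-state test fires
        reltol_termination=5.0f-2,
    )
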
@@ -342,7 +347,7 @@ function get_model( sensealg=sensealg, verbose=false, ) - elseif config.model_type == :vanilla + elseif config.model_type == :VANILLA MultiScaleDeepEquilibriumNetwork( main_layers, mapping_layers, @@ -354,7 +359,7 @@ function get_model( verbose=false, ) else - throw(ArgumentError("`model_type` must be one of `[:skip, :skipv2, :vanilla]`")) + throw(ArgumentError("`model_type` must be one of `[:SKIP, :SKIPV2, :VANILLA]`")) end model = DEQChain(initial_layers, deq, final_layers) diff --git a/examples/src/train.jl b/examples/src/train.jl index 74f70324..1f711806 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -66,13 +66,14 @@ end evaluate(model, ps, st, ::Nothing, device) = nothing function evaluate(model, ps, st, dataloader, device) + st_eval = EFL.testmode(st) matches, total_loss, total_datasize, total_nfe, total_time = 0, 0, 0, 0, 0 for (x, y) in dataloader x = device(x) y = device(y) start_time = time() - (ŷ, soln), _ = model(x, ps, st) + (ŷ, soln), _ = model(x, ps, st_eval) total_time += time() - start_time total_nfe += soln.nfe * size(x, ndims(x)) @@ -135,7 +136,7 @@ loss_function(e::ExperimentConfiguration, args...; kwargs...) = loss_function(e. function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-2) function loss_function_closure(x, y, model, ps, st) (ŷ, soln), st_ = model(x, ps, st) - loss = if c.model_type == :vanilla + loss = if c.model_type == :VANILLA Flux.Losses.logitcrossentropy(ŷ, y) else Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mae(soln.u₀, Flux.Zygote.dropgrad(soln.z_star)) diff --git a/src/solve.jl b/src/solve.jl index 8abb424a..30c9f541 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -12,24 +12,41 @@ function transform_solution(soln::EquilibriumSolution) return DiffEqBase.build_solution(soln.prob, soln.alg, soln.u, soln.resid; retcode=soln.retcode) end -function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem{uType}, alg::ContinuousDEQSolver, args...; kwargs...) where {uType} +function DiffEqBase.__solve( + prob::DiffEqBase.AbstractSteadyStateProblem{uType}, alg::ContinuousDEQSolver, args...; kwargs... +) where {uType} tspan = alg.tspan isa Tuple ? alg.tspan : convert.(real(eltype(prob.u0)), (zero(alg.tspan), alg.tspan)) _prob = ODEProblem(prob.f, prob.u0, tspan, prob.p) - terminate_stats = Dict{Symbol,Any}(:best_objective_value => real(eltype(prob.u0))(Inf), - :best_objective_value_iteration => nothing) - - sol = solve(_prob, alg.alg, args...; kwargs..., - callback=TerminateSteadyState(alg.abstol, alg.reltol, get_terminate_condition(alg, terminate_stats))) - - u, t = terminate_stats[:best_objective_value_iteration] === nothing ? 
(sol.u[end], sol.t[end]) : - (sol.u[terminate_stats[:best_objective_value_iteration] + 1], - sol.t[terminate_stats[:best_objective_value_iteration] + 1]) + terminate_stats = Dict{Symbol,Any}( + :best_objective_value => real(eltype(prob.u0))(Inf), :best_objective_value_iteration => nothing + ) + + sol = solve( + _prob, + alg.alg, + args...; + kwargs..., + callback=TerminateSteadyState( + alg.abstol_termination, alg.reltol_termination, get_terminate_condition(alg, terminate_stats) + ), + ) + + u, t = if terminate_stats[:best_objective_value_iteration] === nothing + (sol.u[end], sol.t[end]) + else + ( + sol.u[terminate_stats[:best_objective_value_iteration] + 1], + sol.t[terminate_stats[:best_objective_value_iteration] + 1], + ) + end # Dont count towards NFE since this is mostly a check for convergence du = prob.f(u, prob.p, t) retcode = (sol.retcode == :Terminated && has_converged(du, u, alg) ? :Success : :Failure) - return EquilibriumSolution{eltype(uType),ndims(uType),uType,typeof(prob),typeof(alg),typeof(sol.destats)}(u, du, prob, alg, retcode, sol.destats) + return EquilibriumSolution{eltype(uType),ndims(uType),uType,typeof(prob),typeof(alg),typeof(sol.destats)}( + u, du, prob, alg, retcode, sol.destats + ) end From ab2afcda49e9acbab978a8775aa5f8ecb7fa5e8c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Apr 2022 11:31:55 -0400 Subject: [PATCH 33/76] Run for different configurations --- examples/cifar10/script.jl | 36 ++++++++++++++++++++++-------------- examples/src/config.jl | 2 +- examples/src/train.jl | 4 ++-- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index e4f1c221..e83466aa 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -4,21 +4,8 @@ using FastDEQExperiments, Flux, CUDA, Optimisers, Dates, FluxMPI FluxMPI.Init(; verbose=true) # Setup -CUDA.math_mode!(CUDA.FAST_MATH) CUDA.allowscalar(false) -# Hyperparameters -config = Dict( - "seed" => 0, - "abstol" => 5.0f-2, - "reltol" => 5.0f-2, - "model_type" => :SKIP, - "continuous" => true, - "model_size" => :TINY, -) - -expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])_now-$(now())" - # Training function train_model(config, expt_name) # Logger Setup @@ -74,4 +61,25 @@ function train_model(config, expt_name) return model, cpu(ps), cpu(st), st_opt end -model, ps, st, st_opt = train_model(config, expt_name) +# Experiment Configurations +configs = [] +for seed in [6171, 3859, 2961], model_type in [:VANILLA, :SKIP, :SKIPV2], model_size in [:TINY, :LARGE] + push!( + configs, + Dict( + "seed" => seed, + "abstol" => 5.0f-2, + "reltol" => 5.0f-2, + "model_type" => model_type, + "continuous" => true, + "model_size" => model_size, + ) + ) +end + +# Training +for config in configs + expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])_now-$(now())" + FastDEQExperiments._should_log() && println("Starting Experiment: " * expt_name) + model, ps, st, st_opt = train_model(config, expt_name) +end diff --git a/examples/src/config.jl b/examples/src/config.jl index 8f30cd1c..9c672155 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -210,7 +210,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) model_config=get_model_config(dataset, model_size; kwargs...), eval_batchsize=64, train_batchsize=64, - nepochs=50, + nepochs=75, # For 4 GPUs pretrain_steps=0 ÷ (is_distributed() ? 
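
For readers following the `__solve` refactor above: the steady-state problem is integrated as an ODE and cut short by a `TerminateSteadyState` callback built from the termination tolerances, after which the residual is evaluated once more (deliberately not counted toward the NFE) to decide success. A self-contained toy version of the same strategy, with a hypothetical dynamics function:

    using OrdinaryDiffEq, DiffEqCallbacks

    f(u, p, t) = tanh.(p .* u) .- u             # hypothetical: du/dt = g(u) - u
    prob = ODEProblem(f, randn(Float32, 4), (0.0f0, 100.0f0), 0.5f0)
    cb = TerminateSteadyState(5.0f-2, 5.0f-2)   # abstol, reltol for termination
    sol = solve(prob, VCABM3(); callback=cb)

    z_star = sol.u[end]                         # approximate fixed point
    converged = sol.retcode == :Terminated      # callback fired before tspan ran out
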
total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, diff --git a/examples/src/train.jl b/examples/src/train.jl index 1f711806..5ab74456 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -133,13 +133,13 @@ end loss_function(e::ExperimentConfiguration, args...; kwargs...) = loss_function(e.model_config, args...; kwargs...) -function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-2) +function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-3) function loss_function_closure(x, y, model, ps, st) (ŷ, soln), st_ = model(x, ps, st) loss = if c.model_type == :VANILLA Flux.Losses.logitcrossentropy(ŷ, y) else - Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mae(soln.u₀, Flux.Zygote.dropgrad(soln.z_star)) + Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mse(soln.u₀, soln.z_star) end return loss, ŷ, st_, soln.nfe end From a016807c9207e3786f89048e1dd2548d26febcf1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Apr 2022 14:04:48 -0400 Subject: [PATCH 34/76] Fix model construction --- examples/src/config.jl | 2 +- examples/src/models.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index 9c672155..f91db332 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -211,7 +211,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=64, train_batchsize=64, nepochs=75, # For 4 GPUs - pretrain_steps=0 ÷ (is_distributed() ? total_workers() : 1), + pretrain_steps=3000 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), diff --git a/examples/src/models.jl b/examples/src/models.jl index 3dee5687..6e3e7fa2 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -326,7 +326,7 @@ function get_model( slayers, downsample_module( config.num_channels[1] => config.num_channels[i + 1], - 1, + i, relu; group_count=config.group_count, ), From c3cb712964b02b45b8071eb8725c9ca482d4b287 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 19 Apr 2022 13:12:52 -0400 Subject: [PATCH 35/76] Better scheduling --- examples/src/config.jl | 22 +++++++++++++++++----- examples/src/train.jl | 21 +++++++++++++-------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index f91db332..56b02071 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -188,9 +188,11 @@ Base.@kwdef struct ExperimentConfiguration{M<:AbstractTaskModelConfiguration} # Eval eval_batchsize::Int + eval_datasize_per_process::Int # Train train_batchsize::Int + train_datasize_per_process::Int nepochs::Int pretrain_steps::Int @@ -217,7 +219,9 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), weight_decay=0.0f0, momentum=0.9f0, - nesterov=true + nesterov=true, + eval_datasize_per_process=10000 ÷ (is_distributed() ? total_workers() : 1), + train_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), ) elseif model_size == :LARGE return ExperimentConfiguration( @@ -231,7 +235,9 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eta=0.001f0 / 4 * (is_distributed() ? total_workers() : 1), weight_decay=0.0f0, momentum=0.9f0, - nesterov=true + nesterov=true, + eval_datasize_per_process=10000 ÷ (is_distributed() ? 
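
The loss change above is easy to miss: with `Zygote.dropgrad` gone, the auxiliary term now also back-propagates through `z_star`, so the penalty shapes the equilibrium itself rather than only training the shortcut toward a frozen target, and the metric switches from MAE to MSE with a smaller weight. In closure form, mirroring the diff:

    # Skip-DEQ objective: task loss plus a penalty tying the shortcut's initial
    # guess u₀ to the converged fixed point z⋆ (gradients flow through both)
    loss = Flux.Losses.logitcrossentropy(ŷ, y) +
           1.0f-3 * Flux.Losses.mse(soln.u₀, soln.z_star)
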
total_workers() : 1), + train_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), ) else throw(ArgumentError("`model_size` must be one of `[:TINY, :LARGE]`")) @@ -249,7 +255,9 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), weight_decay=0.00005f0, momentum=0.9f0, - nesterov=true + nesterov=true, + eval_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), + train_datasize_per_process=1281166 ÷ (is_distributed() ? total_workers() : 1), ) elseif model_size == :LARGE return ExperimentConfiguration( @@ -263,7 +271,9 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), weight_decay=0.00005f0, momentum=0.9f0, - nesterov=true + nesterov=true, + eval_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), + train_datasize_per_process=1281166 ÷ (is_distributed() ? total_workers() : 1), ) elseif model_size == :XL return ExperimentConfiguration( @@ -277,7 +287,9 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eta=0.05f0 / 8 * (is_distributed() ? total_workers() : 1), weight_decay=0.00005f0, momentum=0.9f0, - nesterov=true + nesterov=true, + eval_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), + train_datasize_per_process=1281166 ÷ (is_distributed() ? total_workers() : 1), ) else throw(ArgumentError("`model_size` must be one of `[:SMALL, :LARGE, :XL]`")) diff --git a/examples/src/train.jl b/examples/src/train.jl index 5ab74456..b733f554 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -34,7 +34,11 @@ function construct_optimiser(config::ExperimentConfiguration) end sched = if config.lr_scheduler == :COSINE - ParameterSchedulers.Stateful(ParameterSchedulers.Cos(config.eta, 1.0f-6, config.nepochs)) + ParameterSchedulers.Stateful( + ParameterSchedulers.Cos( + config.eta, 1.0f-6, config.nepochs * (config.train_datasize_per_process ÷ config.train_batchsize) + ), + ) elseif config.lr_scheduler == :CONSTANT ParameterSchedulers.Stateful(ParameterSchedulers.Constant(config.eta)) else @@ -90,6 +94,7 @@ function train_one_epoch( st, loss_function, opt_state, + scheduler, dataloader, device, lg::PrettyTableLogger, @@ -124,11 +129,15 @@ function train_one_epoch( invoke_gc() end + # Run ParameterScheduler + eta_new = ParameterSchedulers.next!(scheduler) + opt_state = update_lr(opt_state, eta_new) + # Logging lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss, "Train/Running/Accuracy" => acc)) end - return ps, st, opt_state, iteration_count, (total_time=total_time,) + return ps, st, opt_state, scheduler, iteration_count, (total_time=total_time,) end loss_function(e::ExperimentConfiguration, args...; kwargs...) = loss_function(e.model_config, args...; kwargs...) @@ -171,8 +180,8 @@ function train( for epoch in 1:nepochs # Train 1 epoch - ps, st, opt_state, iteration_count, training_stats = train_one_epoch( - model, ps, st, loss_function, opt_state, train_dataloader, device, lg, econfig, iteration_count + ps, st, opt_state, scheduler, iteration_count, training_stats = train_one_epoch( + model, ps, st, loss_function, opt_state, scheduler, train_dataloader, device, lg, econfig, iteration_count ) invoke_gc() @@ -183,10 +192,6 @@ function train( invoke_gc() lg(epoch, training_stats.total_time, val_stats..., test_stats...) 
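
Since the scheduler now ticks once per optimiser step inside `train_one_epoch` rather than once per epoch, the cosine period has to be measured in steps; that is what the `nepochs * (train_datasize_per_process ÷ train_batchsize)` expression computes. A worked example with the CIFAR-10 numbers used here (hypothetical single-process run):

    nepochs, datasize, batchsize = 75, 50_000, 64
    steps_per_epoch = datasize ÷ batchsize    # 781
    total_steps = nepochs * steps_per_epoch   # 58_575
    sched = ParameterSchedulers.Stateful(
        ParameterSchedulers.Cos(1.0f-3, 1.0f-6, total_steps)
    )
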
- - # Run ParameterScheduler - eta_new = ParameterSchedulers.next!(scheduler) - opt_state = update_lr(opt_state, eta_new) end return ps, st, opt_state From 27703a5a71ccdf7aeda6da5d731dfcf403356fcb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 19 Apr 2022 21:19:00 -0400 Subject: [PATCH 36/76] Dont track statistics --- examples/src/models.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/src/models.jl b/examples/src/models.jl index 6e3e7fa2..3b91b3fb 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -114,7 +114,7 @@ function ResidualBlockV2( dropout_rate::Real=0.0f0, gn_affine::Bool=true, weight_norm::Bool=true, - gn_track_stats::Bool=true, + gn_track_stats::Bool=false, ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand @@ -143,7 +143,7 @@ function ResidualBlockV2( ) end -function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) +function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=false, bn_affine::Bool=true) rescale = if first(mapping) != last(mapping) * expansion EFL.Chain( conv1x1(first(mapping) => last(mapping) * expansion), @@ -173,7 +173,7 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool ) end -function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) +function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=false, bn_affine::Bool=true) rescale = if first(mapping) != last(mapping) * expansion EFL.Chain( conv1x1(first(mapping) => last(mapping) * expansion), @@ -288,7 +288,7 @@ function get_model( [ EFL.Chain( conv3x3(config.head_channels[i] * 4 => config.head_channels[i + 1] * 4; stride=2, bias=true), - EFL.BatchNorm(config.head_channels[i + 1] * 4, relu; track_stats=true, affine=true), + EFL.BatchNorm(config.head_channels[i + 1] * 4, relu; track_stats=false, affine=true), ) for i in 1:(config.num_branches - 1) ]..., ) @@ -297,7 +297,7 @@ function get_model( increment_modules, downsample_modules, conv1x1(config.head_channels[config.num_branches] * 4 => config.final_channelsize; bias=true), - EFL.BatchNorm(config.final_channelsize, relu; track_stats=true, affine=true), + EFL.BatchNorm(config.final_channelsize, relu; track_stats=false, affine=true), EFL.GlobalMeanPool(), EFL.FlattenLayer(), EFL.Dense(config.final_channelsize, config.num_classes), From ce8cddf282a7a336e183ed68a7c3f6ebd9a35fee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Apr 2022 11:37:05 -0400 Subject: [PATCH 37/76] Temp fix for keys error --- examples/src/FastDEQExperiments.jl | 3 +++ examples/src/models.jl | 27 ++++++++++++++++++--------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index af303d6b..bdc5c817 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -20,6 +20,9 @@ import MLDataUtils: nobs, getobs const EFL = ExplicitFluxLayers +# FIXME: Remove once FastDEQ has been updated to use latest EFL +Base.keys(::Nothing) = () + # Memory Management relieve_gc_pressure(::Union{Nothing,<:AbstractArray}) = nothing relieve_gc_pressure(x::CuArray) = CUDA.unsafe_free!(x) diff --git a/examples/src/models.jl b/examples/src/models.jl index 3b91b3fb..11c910eb 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -68,6 +68,7 @@ function ResidualBlockV1( gn_affine::Bool=true, weight_norm::Bool=true, 
gn_track_stats::Bool=false, + dropout_seed::UInt64=UInt64(0), ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand @@ -84,7 +85,7 @@ function ResidualBlockV1( gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate) + dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate; initial_seed=dropout_seed) return EFL.Chain( EFL.Parallel( @@ -115,6 +116,7 @@ function ResidualBlockV2( gn_affine::Bool=true, weight_norm::Bool=true, gn_track_stats::Bool=false, + dropout_seed::UInt64=UInt64(0), ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand @@ -131,7 +133,7 @@ function ResidualBlockV2( gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate) + dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate; initial_seed=dropout_seed) return EFL.Chain( conv1, @@ -213,16 +215,16 @@ function get_model( downsample_layers = [ conv3x3(3 => init_channel_size; stride=config.downsample_times >= 1 ? 2 : 1), - EFL.BatchNorm(init_channel_size, relu; affine=true), + EFL.BatchNorm(init_channel_size, relu; affine=true, track_stats=false), conv3x3(init_channel_size => init_channel_size; stride=config.downsample_times >= 2 ? 2 : 1), - EFL.BatchNorm(init_channel_size, relu; affine=true), + EFL.BatchNorm(init_channel_size, relu; affine=true, track_stats=false), ] for _ in 3:(config.downsample_times) append!( downsample_layers, [ conv3x3(init_channel_size => init_channel_size; stride=2), - EFL.BatchNorm(init_channel_size, relu; affine=true), + EFL.BatchNorm(init_channel_size, relu; affine=true, track_stats=false), ], ) end @@ -233,12 +235,14 @@ function get_model( else EFL.Chain( conv1x1(init_channel_size => init_channel_size; bias=false), - EFL.BatchNorm(init_channel_size, relu; affine=true), + EFL.BatchNorm(init_channel_size, relu; affine=true, track_stats=false), ) end initial_layers = EFL.Chain(downsample, stage0) + dropout_seed = UInt64(0) + main_layers = Tuple( ResidualBlockV1( config.num_channels[i] => config.num_channels[i]; @@ -246,9 +250,12 @@ function get_model( dropout_rate=config.dropout_rate, num_gn_groups=config.group_count, n_big_kernels=config.big_kernels[i], + dropout_seed=dropout_seed + (i - 1) * 100, ) for i in 1:(config.num_branches) ) + dropout_seed = dropout_seed + config.num_branches * 100 + mapping_layers = Matrix{EFL.AbstractExplicitLayer}(undef, config.num_branches, config.num_branches) for i in 1:(config.num_branches) for j in 1:(config.num_branches) @@ -307,8 +314,8 @@ function get_model( ContinuousDEQSolver( config.ode_solver; mode=config.stop_mode, - abstol=1f-5, #config.abstol, - reltol=1f-5, #config.reltol, + abstol=1.0f-5, #config.abstol, + reltol=1.0f-5, #config.reltol, abstol_termination=config.abstol, reltol_termination=config.reltol, ) @@ -320,7 +327,9 @@ function get_model( deq = if config.model_type ∈ (:SKIP, :SKIPV2) shortcut = if config.model_type == :SKIP - slayers = EFL.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1])] + slayers = EFL.AbstractExplicitLayer[ResidualBlockV2( + config.num_channels[1] => config.num_channels[1]; dropout_seed=dropout_seed + 
)] for i in 1:(config.num_branches - 1) push!( slayers, From d12347f979188d02a23c10240c9492a6e406897e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Apr 2022 12:51:17 -0400 Subject: [PATCH 38/76] Move to component arrays --- Project.toml | 3 ++- src/FastDEQ.jl | 8 +++++++- src/adjoint.jl | 5 +---- src/layers/chain.jl | 20 +++----------------- src/layers/core.jl | 19 ++----------------- src/layers/deq.jl | 18 +++++++++--------- src/layers/jacobian_stabilization.jl | 2 +- src/layers/mdeq.jl | 14 +++++++------- src/solvers/continuous.jl | 2 +- 9 files changed, 33 insertions(+), 58 deletions(-) diff --git a/Project.toml b/Project.toml index 02a925ea..566b92d6 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.1.0" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" @@ -46,8 +47,8 @@ julia = "1.7" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/FastDEQ.jl b/src/FastDEQ.jl index 3d2b538a..407566f8 100644 --- a/src/FastDEQ.jl +++ b/src/FastDEQ.jl @@ -16,10 +16,16 @@ using CUDA, ExplicitFluxLayers, Functors, ChainRulesCore, + ComponentArrays, Setfield import ExplicitFluxLayers: - AbstractExplicitLayer, initialparameters, initialstates, createcache, parameterlength, statelength, cachesize + AbstractExplicitLayer, + AbstractExplicitContainerLayer, + initialparameters, + initialstates, + parameterlength, + statelength import Random: AbstractRNG include("operator.jl") diff --git a/src/adjoint.jl b/src/adjoint.jl index 3c31ac78..9e809c4c 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -1,6 +1,3 @@ -neg(x) = -x -neg(::Nothing) = nothing - @noinline function DiffEqSensitivity.SteadyStateAdjointProblem( sol::EquilibriumSolution, sensealg::DiffEqSensitivity.SteadyStateAdjoint, g::Nothing, dg; save_idxs=nothing ) @@ -31,5 +28,5 @@ neg(::Nothing) = nothing _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) dp = back(vec(λ))[1] - return dp isa NamedTuple ? fmap(neg, dp) : -vec(dp) + return -dp end diff --git a/src/layers/chain.jl b/src/layers/chain.jl index 8d7e935b..e50d58d4 100644 --- a/src/layers/chain.jl +++ b/src/layers/chain.jl @@ -1,27 +1,13 @@ -struct DEQChain{P1,D<:AbstractDeepEquilibriumNetwork,P2} <: AbstractExplicitLayer +struct DEQChain{P1,D,P2} <: AbstractExplicitContainerLayer{(:pre_deq, :deq, :post_deq)} pre_deq::P1 deq::D post_deq::P2 end -function initialparameters(rng::AbstractRNG, c::DEQChain) - return ( - pre_deq=initialparameters(rng, c.pre_deq), - deq=initialparameters(rng, c.deq), - post_deq=initialparameters(rng, c.post_deq), - ) -end - -function initialstates(rng::AbstractRNG, c::DEQChain) - return ( - pre_deq=initialstates(rng, c.pre_deq), deq=initialstates(rng, c.deq), post_deq=initialstates(rng, c.post_deq) - ) -end - function DEQChain(layers...) 
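
The adjoint simplification above (a plain `return -dp` replacing the `fmap(neg, dp)` special-casing) is the payoff of the ComponentArrays migration: parameters are now one flat array with named views, so ordinary vector arithmetic applies to the whole tree at once. A small illustrative sketch with hypothetical field names:

    using ComponentArrays

    ps = ComponentArray(weight=randn(Float32, 3, 3), bias=zeros(Float32, 3))
    ps.weight     # named, array-shaped view into the flat storage
    neg_ps = -ps  # arithmetic over all 12 underlying parameters at once
    length(ps)    # 12
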
pre_deq, post_deq, deq, encounter_deq = [], [], nothing, false for l in layers - if l isa AbstractDeepEquilibriumNetwork + if l isa AbstractDeepEquilibriumNetwork || l isa AbstractSkipDeepEquilibriumNetwork @assert !encounter_deq "Can have only 1 DEQ Layer in the Chain!!!" deq = l encounter_deq = true @@ -35,7 +21,7 @@ function DEQChain(layers...) return DEQChain(pre_deq, deq, post_deq) end -function (deq::DEQChain{P1,D,P2})(x, ps::NamedTuple, st::NamedTuple) where {P1,D,P2} +function (deq::DEQChain{P1,D,P2})(x, ps::ComponentArray, st::NamedTuple) where {P1,D,P2} x1, st1 = if P1 == Nothing x, st.pre_deq else diff --git a/src/layers/core.jl b/src/layers/core.jl index 932db452..08f98e15 100644 --- a/src/layers/core.jl +++ b/src/layers/core.jl @@ -1,31 +1,16 @@ -abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitLayer end -abstract type AbstractSkipDeepEquilibriumNetwork <: AbstractDeepEquilibriumNetwork end +abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,)} end -initialparameters(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) = (model=initialparameters(rng, deq.model),) function initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) return (model=initialstates(rng, deq.model), fixed_depth=0) end -createcache(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork, x) = (model=createcache(rng, deq.model, x),) -parameterlength(deq::AbstractDeepEquilibriumNetwork) = parameterlength(deq.model) -statelength(deq::AbstractDeepEquilibriumNetwork) = statelength(deq.model) + 2 -cachesize(deq::AbstractDeepEquilibriumNetwork) = cachesize(deq.model) +abstract type AbstractSkipDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,:shortcut)} end -function initialparameters(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) - return (model=initialparameters(rng, deq.model), shortcut=initialparameters(rng, deq.shortcut)) -end function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) return ( model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), fixed_depth=0 ) end -function createcache(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork, x) - return (model=createcache(rng, deq.model, x), shortcut=createcache(rng, deq.shortcut, x)) -end - -parameterlength(deq::AbstractSkipDeepEquilibriumNetwork) = parameterlength(deq.model) + parameterlength(deq.shortcut) -statelength(deq::AbstractSkipDeepEquilibriumNetwork) = statelength(deq.model) + statelength(deq.shortcut) + 2 -cachesize(deq::AbstractSkipDeepEquilibriumNetwork) = cachesize(deq.model) + cachesize(deq.shortcut) """ DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe) diff --git a/src/layers/deq.jl b/src/layers/deq.jl index 2bbf831b..425cd913 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -15,7 +15,7 @@ function DeepEquilibriumNetwork( ) end -function (deq::DeepEquilibriumNetwork{J})(x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple) where {J,T} +function (deq::DeepEquilibriumNetwork{J})(x::AbstractArray{T}, ps::ComponentArray, st::NamedTuple) where {J,T} z = zero(x) if !iszero(st.fixed_depth) @@ -23,7 +23,7 @@ function (deq::DeepEquilibriumNetwork{J})(x::AbstractArray{T}, ps::NamedTuple, s st_ = st.model z_star = z for _ ∈ 1:st.fixed_depth - z_star, st_ = deq.model((z_star, x), ps.model, st_) + z_star, st_ = deq.model((z_star, x), ps, st_) end @set! 
st.model = st_ @@ -35,13 +35,13 @@ function (deq::DeepEquilibriumNetwork{J})(x::AbstractArray{T}, ps::NamedTuple, s return u_ .- u end - prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = deq.model((sol.u, x), ps.model, st.model) + z_star, st_ = deq.model((sol.u, x), ps, st.model) - jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) - residual = Zygote.@ignore z_star .- deq.model((z_star, x), ps.model, st.model)[1] - @set! st.model = st_ :: typeof(st.model) + jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) + residual = z_star .- deq.model((z_star, x), ps, st.model)[1] + @set! st.model = st_ return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end @@ -69,7 +69,7 @@ function SkipDeepEquilibriumNetwork( ) end -function (deq::SkipDeepEquilibriumNetwork{J,M,S})(x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple) where {J,M,S,T} +function (deq::SkipDeepEquilibriumNetwork{J,M,S})(x::AbstractArray{T}, ps::ComponentArray, st::NamedTuple) where {J,M,S,T} z, st__ = if S == Nothing deq.model((zero(x), x), ps.model, st.model) else @@ -99,7 +99,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})(x::AbstractArray{T}, ps::Named z_star, st_ = deq.model((sol.u, x), ps.model, st.model) jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) - residual = Zygote.@ignore z_star .- deq.model((z_star, x), ps.model, st.model)[1] + residual = z_star .- deq.model((z_star, x), ps.model, st.model)[1] @set! st.model = st_ :: typeof(st.model) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st diff --git a/src/layers/jacobian_stabilization.jl b/src/layers/jacobian_stabilization.jl index 3f53f452..d87a64c6 100644 --- a/src/layers/jacobian_stabilization.jl +++ b/src/layers/jacobian_stabilization.jl @@ -1,6 +1,6 @@ # Doesn't work as of now function compute_deq_jacobian_loss( - model::AbstractExplicitLayer, ps::NamedTuple, st::NamedTuple, z::AbstractArray, x::AbstractArray + model::AbstractExplicitLayer, ps::ComponentArray, st::NamedTuple, z::AbstractArray, x::AbstractArray ) l, back = Zygote.pullback(u -> model((u, x), ps, st)[1], z) vjp_z = back(gaussian_like(l))[1] diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 5dd66689..84f96a39 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -43,7 +43,7 @@ end Zygote.@nograd get_initial_condition_mdeq -function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple) where {N,T} +function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::ComponentArray, st::NamedTuple) where {N,T} z, st = get_initial_condition_mdeq(deq.scales, x, st) if !iszero(st.fixed_depth) @@ -51,7 +51,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::Nam st_ = st.model for _ ∈ 1:st.fixed_depth - z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) + z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps, st_) end @set! st.model = st_ @@ -67,11 +67,11 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::Nam dudt(u, p, t) = vcat(Flux.flatten.(dudt_(u, p, t)[1])...) 
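
Stepping back from the diff noise, the DEQ forward pass after this patch reads roughly as follows (a condensed sketch of the `deq.jl` method above, not the verbatim code; state bookkeeping with `@set!` and the Jacobian-loss branch are omitted):

    function deq_forward(deq, x, ps, st)
        z0 = zero(x)
        # Equilibrium is reached when du/dt = model(u, x) - u vanishes
        dudt(u, p, t) = first(deq.model((u, x), p, st.model)) .- u
        prob = SteadyStateProblem(ODEFunction{false}(dudt), z0, ps)
        sol = solve(prob, deq.solver; sensealg=deq.sensealg)
        # One more application of the model re-attaches the computation graph for AD
        z_star, st_model = deq.model((sol.u, x), ps, st.model)
        return z_star, st_model
    end
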
.- u - prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) + prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) - z_star, st_ = dudt_(sol.u, ps.model, nothing) + z_star, st_ = dudt_(sol.u, ps, nothing) - residual = Zygote.@ignore dudt(sol.u, ps.model, nothing) + residual = dudt(sol.u, ps, nothing) @set! st.model = st_ @@ -127,7 +127,7 @@ function MultiScaleSkipDeepEquilibriumNetwork( end function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( - x::AbstractArray{T}, ps::NamedTuple, st::NamedTuple + x::AbstractArray{T}, ps::ComponentArray, st::NamedTuple ) where {N,L,M,Sh,T} z, st = if Sh == Nothing u0, st_ = get_initial_condition_mdeq(deq.scales, x, st) @@ -166,7 +166,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) z_star, st_ = dudt_(sol.u, ps.model, nothing) - residual = Zygote.@ignore dudt(sol.u, ps.model, nothing) + residual = dudt(sol.u, ps.model, nothing) @set! st.model = st_ diff --git a/src/solvers/continuous.jl b/src/solvers/continuous.jl index a6a5dd08..07c7acf3 100644 --- a/src/solvers/continuous.jl +++ b/src/solvers/continuous.jl @@ -52,7 +52,7 @@ struct ContinuousDEQSolver{M,A,T,TS} <: SteadyStateDiffEq.SteadyStateDiffEqAlgor end function ContinuousDEQSolver( - alg=VCABM4(); + alg=VCABM3(); mode::Symbol=:rel_deq_default, abstol::T=1.0f-8, reltol::T=1.0f-8, From c019b20dfe1fd3363c5fed4ec3315b483657c61a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Apr 2022 08:18:20 -0400 Subject: [PATCH 39/76] Use named tuple --- examples/Manifest.toml | 54 +++--- examples/cifar10/script.jl | 21 +-- examples/src/config.jl | 2 +- examples/src/models.jl | 3 + examples/src/train.jl | 60 +++---- src/FastDEQ.jl | 1 + src/adjoint.jl | 5 +- src/layers/chain.jl | 2 +- src/layers/core.jl | 2 +- src/layers/deq.jl | 18 +- src/layers/mdeq.jl | 12 +- src/solve.jl | 27 +++ src/solvers/continuous.jl | 83 +-------- src/solvers/discrete.jl | 13 +- src/solvers/discrete/broyden.jl | 147 ++++++--------- .../discrete/limited_memory_broyden.jl | 168 ++++++------------ src/utils.jl | 7 - test/runtests.jl | 86 ++++----- 18 files changed, 269 insertions(+), 442 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index e58bcd48..b0b0337e 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -104,15 +104,15 @@ version = "0.1.3" [[deps.BlockArrays]] deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra"] -git-tree-sha1 = "7278f5ffec86a6c10233bf9c6be1a9c593012299" +git-tree-sha1 = "28c497806c05326e7cadac0c916980d5a9c0e905" uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -version = "0.16.13" +version = "0.16.14" [[deps.BlockBandedMatrices]] deps = ["ArrayLayouts", "BandedMatrices", "BlockArrays", "FillArrays", "LinearAlgebra", "MatrixFactorizations", "SparseArrays", "Statistics"] -git-tree-sha1 = "8aaea69570a48b505383210451cbf36a7237a829" +git-tree-sha1 = "646a8081a8f7a728b2c01a1d00a9fa07b678900a" uuid = "ffab5731-97b5-5995-9138-79e8c1846df0" -version = "0.11.4" +version = "0.11.5" [[deps.BufferedStreams]] deps = ["Compat", "Test"] @@ -206,6 +206,12 @@ deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "0.5.2+0" +[[deps.ComponentArrays]] +deps = ["ArrayInterface", "ChainRulesCore", "LinearAlgebra", "Requires"] +git-tree-sha1 = "243d8b8afc829a6707bbb1cd00da868703c2ef42" +uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +version = "0.11.15" + [[deps.CompositeTypes]] 
git-tree-sha1 = "d5b014b216dc891e81fea299638e4c10c657b582" uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657" @@ -398,8 +404,8 @@ uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef" version = "0.0.29+0" [[deps.ExplicitFluxLayers]] -deps = ["CUDA", "ChainRulesCore", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Octavian", "Random", "Setfield", "SparseArrays", "Statistics"] -git-tree-sha1 = "c1e2783e6ff0a8327b58e869f9c243769f4e89d0" +deps = ["CUDA", "ChainRulesCore", "ComponentArrays", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Optimisers", "Random", "Setfield", "SparseArrays", "Statistics", "Zygote"] +git-tree-sha1 = "e923ed1f219c9c505faa2c05af29dbb74b4f5760" repo-rev = "ap/sparse" repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" @@ -452,7 +458,7 @@ uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" version = "0.3.2" [[deps.FastDEQ]] -deps = ["CUDA", "ChainRulesCore", "DataLoaders", "DiffEqBase", "DiffEqCallbacks", "DiffEqSensitivity", "ExplicitFluxLayers", "Flux", "Functors", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Random", "Reexport", "Requires", "SciMLBase", "Setfield", "Statistics", "SteadyStateDiffEq", "UnPack", "Zygote"] +deps = ["CUDA", "ChainRulesCore", "ComponentArrays", "DataLoaders", "DiffEqBase", "DiffEqCallbacks", "DiffEqSensitivity", "ExplicitFluxLayers", "Flux", "Functors", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Random", "Reexport", "Requires", "SciMLBase", "Setfield", "Statistics", "SteadyStateDiffEq", "UnPack", "Zygote"] path = ".." uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" version = "0.1.0" @@ -497,12 +503,12 @@ uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" version = "0.13.0" [[deps.FluxMPI]] -deps = ["CUDA", "Dates", "Flux", "Functors", "LearnBase", "MLDataUtils", "MPI", "Optimisers", "Setfield", "Zygote"] -git-tree-sha1 = "0956751f425663d4f468cf4ed97b95249257e202" -repo-rev = "ap/opt" +deps = ["CUDA", "ComponentArrays", "Dates", "Flux", "Functors", "LearnBase", "MLDataUtils", "MPI", "Optimisers", "Setfield", "Zygote"] +git-tree-sha1 = "65e52d4bf8600f15c8e640dce2ba80bb9bbc1f16" +repo-rev = "main" repo-url = "https://github.com/avik-pal/FluxMPI.jl.git" uuid = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" -version = "0.3.0" +version = "0.3.1" [[deps.FoldsThreads]] deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] @@ -1064,12 +1070,6 @@ git-tree-sha1 = "55ce61d43409b1fb0279d1781bf3b0f22c83ab3b" uuid = "d8793406-e978-5875-9003-1fc021f44a92" version = "0.3.7" -[[deps.Octavian]] -deps = ["ArrayInterface", "CPUSummary", "IfElse", "LoopVectorization", "ManualMemory", "PolyesterWeave", "Requires", "Static", "ThreadingUtilities", "VectorizationBase"] -git-tree-sha1 = "26c004c96dc634cefe9174cb9180c496f6c7e100" -uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4" -version = "0.3.13" - [[deps.OffsetArrays]] deps = ["Adapt"] git-tree-sha1 = "043017e0bdeff61cfbb7afeb558ab29536bbb5ed" @@ -1112,9 +1112,9 @@ version = "1.6.2" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "e440ecef249dea69e79248857e800e71820d386c" +git-tree-sha1 = "cfedc2d6990d792e705ade4c458fea5fbe574520" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.1" +version = "0.2.2" [[deps.OrderedCollections]] git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" @@ -1123,9 +1123,9 @@ version = "1.4.1" [[deps.OrdinaryDiffEq]] deps = ["Adapt", "ArrayInterface", "DataStructures", 
"DiffEqBase", "DocStringExtensions", "ExponentialUtilities", "FastClosures", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "LinearSolve", "Logging", "LoopVectorization", "MacroTools", "MuladdMacro", "NLsolve", "NonlinearSolve", "Polyester", "PreallocationTools", "RecursiveArrayTools", "Reexport", "SciMLBase", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"] -git-tree-sha1 = "c5568ed45ee56cb4a5e3cebff3b91541ae016a83" +git-tree-sha1 = "8031a288c9b418664a3dfbac36e464a3f61ace73" uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" -version = "6.9.0" +version = "6.10.0" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] @@ -1200,9 +1200,9 @@ version = "0.2.4" [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "d3538e7f8a790dc8903519090857ef8e1283eecd" +git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.2.5" +version = "1.3.0" [[deps.PrettyPrint]] git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" @@ -1350,9 +1350,9 @@ version = "0.6.32" [[deps.SciMLBase]] deps = ["ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "RecipesBase", "RecursiveArrayTools", "StaticArrays", "Statistics", "Tables", "TreeViews"] -git-tree-sha1 = "61159e034c4cb36b76ad2926bb5bf8c28cc2fb12" +git-tree-sha1 = "f03796a588eba66f6bcc63cfdeda89b4a339ce4e" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "1.29.0" +version = "1.30.0" [[deps.SentinelArrays]] deps = ["Dates", "Random"] @@ -1481,9 +1481,9 @@ version = "0.3.2" [[deps.Strided]] deps = ["LinearAlgebra", "TupleTools"] -git-tree-sha1 = "972de61ae8cb965c516b871b69bb8594463d39a9" +git-tree-sha1 = "7c4bcef07d559776a9e2a009c441547fb9eb5c92" uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" -version = "1.2.0" +version = "1.2.1" [[deps.StringEncodings]] deps = ["Libiconv_jll"] diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index e83466aa..68d42ffd 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -20,19 +20,14 @@ function train_model(config, expt_name) expt_config = FastDEQExperiments.get_experiment_config( :CIFAR10, config["model_size"]; - model_type = config["model_type"], - continuous = config["continuous"], - abstol = config["abstol"], - reltol = config["reltol"], + model_type=config["model_type"], + continuous=config["continuous"], + abstol=config["abstol"], + reltol=config["reltol"], ) # Model Setup - model, ps, st = FastDEQExperiments.get_model( - expt_config; - seed=config["seed"], - device=gpu, - warmup=true, - ) + model, ps, st = FastDEQExperiments.get_model(expt_config; seed=config["seed"], device=gpu, warmup=true) # Get Dataloaders train_dataloader, test_dataloader = FastDEQExperiments.get_dataloaders( @@ -52,7 +47,7 @@ function train_model(config, expt_name) gpu, expt_config.nepochs, lg, - expt_config + expt_config, ) # Close Logger and Flush Data to disk @@ -73,13 +68,13 @@ for seed in [6171, 3859, 2961], model_type in [:VANILLA, :SKIP, :SKIPV2], model_ "model_type" => model_type, "continuous" => true, "model_size" => model_size, - ) + ), ) end # Training for config in configs - expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_continuous-$(config["continuous"])_now-$(now())" + expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_size-$(config["model_size"])_continuous-$(config["continuous"])_now-$(now())" FastDEQExperiments._should_log() && println("Starting 
Experiment: " * expt_name) model, ps, st, st_opt = train_model(config, expt_name) end diff --git a/examples/src/config.jl b/examples/src/config.jl index 56b02071..377b7cd5 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -212,7 +212,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) model_config=get_model_config(dataset, model_size; kwargs...), eval_batchsize=64, train_batchsize=64, - nepochs=75, # For 4 GPUs + nepochs=50, pretrain_steps=3000 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, diff --git a/examples/src/models.jl b/examples/src/models.jl index 11c910eb..00afda13 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -373,6 +373,9 @@ function get_model( model = DEQChain(initial_layers, deq, final_layers) ps, st = device.(EFL.setup(MersenneTwister(seed), model)) + # NOTE: ComponentArrays seem to have some overhead + ps = NamedTuple(ps) + st = NamedTuple(st) if warmup clean_println("Starting Model Warmup") diff --git a/examples/src/train.jl b/examples/src/train.jl index b733f554..c16901ee 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -67,15 +67,12 @@ function _get_loggable_stats(stats::NamedTuple) end end -evaluate(model, ps, st, ::Nothing, device) = nothing +evaluate(model, ps, st, ::Nothing) = nothing -function evaluate(model, ps, st, dataloader, device) +function evaluate(model, ps, st, dataloader) st_eval = EFL.testmode(st) matches, total_loss, total_datasize, total_nfe, total_time = 0, 0, 0, 0, 0 - for (x, y) in dataloader - x = device(x) - y = device(y) - + for (x, y) in CuIterator(dataloader) start_time = time() (ŷ, soln), _ = model(x, ps, st_eval) total_time += time() - start_time @@ -88,6 +85,21 @@ function evaluate(model, ps, st, dataloader, device) return (loss=total_loss, accuracy=matches, mean_nfe=total_nfe, total_time=total_time, total_datasize=total_datasize) end +loss_function(e::ExperimentConfiguration, args...; kwargs...) = loss_function(e.model_config, args...; kwargs...) + +function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-3) + function loss_function_closure(x, y, model, ps, st) + (ŷ, soln), st_ = model(x, ps, st) + loss = if c.model_type == :VANILLA + Flux.Losses.logitcrossentropy(ŷ, y) + else + Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mse(soln.u₀, soln.z_star) + end + return loss, ŷ, st_, soln.nfe + end + return loss_function_closure +end + function train_one_epoch( model, ps, @@ -96,17 +108,13 @@ function train_one_epoch( opt_state, scheduler, dataloader, - device, lg::PrettyTableLogger, econfig::ExperimentConfiguration, iteration_count::Int, ) total_time = 0 - for (x, y) in dataloader - x = device(x) - y = device(y) - + for (x, y) in CuIterator(dataloader) # Compute Loss + Backprop + Update start_time = time() @@ -120,14 +128,6 @@ function train_one_epoch( iteration_count += 1 st = econfig.pretrain_steps == iteration_count ? EFL.update_state(st, :fixed_depth, 0) : st - if iteration_count % 25 == 0 - # Without this we might frequently run out of memory - # Disabling CUDA Mempool with MPI seems to help but it - # also slows down the overall code - # Relieve GC Pressure - relieve_gc_pressure((gs, ŷ, x, y)) - invoke_gc() - end # Run ParameterScheduler eta_new = ParameterSchedulers.next!(scheduler) @@ -140,21 +140,6 @@ function train_one_epoch( return ps, st, opt_state, scheduler, iteration_count, (total_time=total_time,) end -loss_function(e::ExperimentConfiguration, args...; kwargs...) 
= loss_function(e.model_config, args...; kwargs...) - -function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-3) - function loss_function_closure(x, y, model, ps, st) - (ŷ, soln), st_ = model(x, ps, st) - loss = if c.model_type == :VANILLA - Flux.Losses.logitcrossentropy(ŷ, y) - else - Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mse(soln.u₀, soln.z_star) - end - return loss, ŷ, st_, soln.nfe - end - return loss_function_closure -end - function train( model, ps, @@ -165,7 +150,6 @@ function train( train_dataloader, val_dataloader, test_dataloader, - device, nepochs, lg::PrettyTableLogger, econfig::ExperimentConfiguration; @@ -181,14 +165,14 @@ function train( for epoch in 1:nepochs # Train 1 epoch ps, st, opt_state, scheduler, iteration_count, training_stats = train_one_epoch( - model, ps, st, loss_function, opt_state, scheduler, train_dataloader, device, lg, econfig, iteration_count + model, ps, st, loss_function, opt_state, scheduler, train_dataloader, lg, econfig, iteration_count ) invoke_gc() # Evaluate - val_stats = _get_loggable_stats(evaluate(model, ps, st, val_dataloader, device)) + val_stats = _get_loggable_stats(evaluate(model, ps, st, val_dataloader)) invoke_gc() - test_stats = _get_loggable_stats(evaluate(model, ps, st, test_dataloader, device)) + test_stats = _get_loggable_stats(evaluate(model, ps, st, test_dataloader)) invoke_gc() lg(epoch, training_stats.total_time, val_stats..., test_stats...) diff --git a/src/FastDEQ.jl b/src/FastDEQ.jl index 407566f8..831cbf3e 100644 --- a/src/FastDEQ.jl +++ b/src/FastDEQ.jl @@ -32,6 +32,7 @@ include("operator.jl") include("solvers/continuous.jl") include("solvers/discrete.jl") +include("solvers/termination.jl") include("solve.jl") include("utils.jl") diff --git a/src/adjoint.jl b/src/adjoint.jl index 9e809c4c..b54f7d85 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -1,3 +1,6 @@ +neg(x::Any) = hasmethod(-, (typeof(x),)) ? -x : x +neg(nt::NamedTuple) = fmap(neg, nt) + @noinline function DiffEqSensitivity.SteadyStateAdjointProblem( sol::EquilibriumSolution, sensealg::DiffEqSensitivity.SteadyStateAdjoint, g::Nothing, dg; save_idxs=nothing ) @@ -28,5 +31,5 @@ _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) dp = back(vec(λ))[1] - return -dp + return neg(dp) end diff --git a/src/layers/chain.jl b/src/layers/chain.jl index e50d58d4..36bede0d 100644 --- a/src/layers/chain.jl +++ b/src/layers/chain.jl @@ -21,7 +21,7 @@ function DEQChain(layers...) return DEQChain(pre_deq, deq, post_deq) end -function (deq::DEQChain{P1,D,P2})(x, ps::ComponentArray, st::NamedTuple) where {P1,D,P2} +function (deq::DEQChain{P1,D,P2})(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) where {P1,D,P2} x1, st1 = if P1 == Nothing x, st.pre_deq else diff --git a/src/layers/core.jl b/src/layers/core.jl index 08f98e15..ac4a727b 100644 --- a/src/layers/core.jl +++ b/src/layers/core.jl @@ -34,7 +34,7 @@ end function Base.show(io::IO, l::DeepEquilibriumSolution) print(io, "DeepEquilibriumSolution(") - print(io, ", z_star: ", l.z_star) + print(io, "z_star: ", l.z_star) print(io, ", initial_condition: ", l.u₀) print(io, ", residual: ", l.residual) print(io, ", jacobian_loss: ", l.jacobian_loss) diff --git a/src/layers/deq.jl b/src/layers/deq.jl index 425cd913..fdd33f77 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -8,21 +8,21 @@ end function DeepEquilibriumNetwork( model, solver; jacobian_regularization::Bool=false, sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs... 
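# On `neg` in the adjoint.jl hunk above: with NamedTuple parameters the Zygote
# pullback returns a NamedTuple (or `nothing`) gradient, which has no unary `-`,
# so negation is mapped over the leaves via Functors.fmap. Sketch of the behaviour:
neg((weight=[1.0 -2.0], bias=[3.0]))  # -> (weight = [-1.0 2.0], bias = [-3.0])
neg(nothing)                          # -> nothing: no `-` method, passed through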
) - return DeepEquilibriumNetwork{ - jacobian_regularization,typeof(model),typeof(solver),typeof(sensealg),typeof(kwargs) - }( + return DeepEquilibriumNetwork{jacobian_regularization,typeof(model),typeof(solver),typeof(sensealg),typeof(kwargs)}( model, solver, sensealg, kwargs ) end -function (deq::DeepEquilibriumNetwork{J})(x::AbstractArray{T}, ps::ComponentArray, st::NamedTuple) where {J,T} +function (deq::DeepEquilibriumNetwork{J})( + x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple +) where {J,T} z = zero(x) if !iszero(st.fixed_depth) # Pretraining without Fixed Point Solving st_ = st.model z_star = z - for _ ∈ 1:st.fixed_depth + for _ in 1:(st.fixed_depth) z_star, st_ = deq.model((z_star, x), ps, st_) end @set! st.model = st_ @@ -69,7 +69,9 @@ function SkipDeepEquilibriumNetwork( ) end -function (deq::SkipDeepEquilibriumNetwork{J,M,S})(x::AbstractArray{T}, ps::ComponentArray, st::NamedTuple) where {J,M,S,T} +function (deq::SkipDeepEquilibriumNetwork{J,M,S})( + x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple +) where {J,M,S,T} z, st__ = if S == Nothing deq.model((zero(x), x), ps.model, st.model) else @@ -81,7 +83,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})(x::AbstractArray{T}, ps::Compo # Pretraining without Fixed Point Solving st_ = st.model z_star = z - for _ ∈ 1:st.fixed_depth + for _ in 1:(st.fixed_depth) z_star, st_ = deq.model((z_star, x), ps.model, st_) end @set! st.model = st_ @@ -100,7 +102,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})(x::AbstractArray{T}, ps::Compo jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps.model, st.model)[1] - @set! st.model = st_ :: typeof(st.model) + @set! 
st.model = st_::typeof(st.model) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 84f96a39..79cd78f1 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -43,14 +43,16 @@ end Zygote.@nograd get_initial_condition_mdeq -function (deq::MultiScaleDeepEquilibriumNetwork{N})(x::AbstractArray{T}, ps::ComponentArray, st::NamedTuple) where {N,T} +function (deq::MultiScaleDeepEquilibriumNetwork{N})( + x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple +) where {N,T} z, st = get_initial_condition_mdeq(deq.scales, x, st) if !iszero(st.fixed_depth) z_star = split_and_reshape(z, st.split_idxs, deq.scales) st_ = st.model - for _ ∈ 1:st.fixed_depth + for _ in 1:(st.fixed_depth) z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps, st_) end @@ -94,7 +96,7 @@ function initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwo model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), - fixed_depth=0 + fixed_depth=0, ) end @@ -127,7 +129,7 @@ function MultiScaleSkipDeepEquilibriumNetwork( end function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( - x::AbstractArray{T}, ps::ComponentArray, st::NamedTuple + x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple ) where {N,L,M,Sh,T} z, st = if Sh == Nothing u0, st_ = get_initial_condition_mdeq(deq.scales, x, st) @@ -145,7 +147,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( z_star = split_and_reshape(z, st.split_idxs, deq.scales) st_ = st.model - for _ ∈ 1:st.fixed_depth + for _ in 1:(st.fixed_depth) z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) end diff --git a/src/solve.jl b/src/solve.jl index 30c9f541..eae2c0e6 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -50,3 +50,30 @@ function DiffEqBase.__solve( u, du, prob, alg, retcode, sol.destats ) end + +function DiffEqBase.__solve( + prob::DiffEqBase.AbstractSteadyStateProblem{uType}, alg::DiscreteDEQSolver, args...; maxiters=10, kwargs... +) where {uType} + terminate_stats = Dict{Symbol,Any}( + :best_objective_value => real(eltype(prob.u0))(Inf), :best_objective_value_iteration => nothing + ) + + u, stats = nlsolve( + alg.alg, + u -> prob.f(u, prob.p, nothing), + prob.u0; + maxiters=maxiters, + terminate_condition=get_terminate_condition(alg, terminate_stats) + ) + + # Dont count towards NFE since this is mostly a check for convergence + du = prob.f(u, prob.p, nothing) + + retcode = has_converged(du, u, alg) ? :Success : :Failure + + destats = (nf=stats.nf,) + + return EquilibriumSolution{eltype(uType),ndims(uType),uType,typeof(prob),typeof(alg),typeof(destats)}( + u, du, prob, alg, retcode, destats + ) +end diff --git a/src/solvers/continuous.jl b/src/solvers/continuous.jl index 07c7acf3..9119c971 100644 --- a/src/solvers/continuous.jl +++ b/src/solvers/continuous.jl @@ -1,5 +1,5 @@ """ - ContinuousDEQSolver(alg=VCABM4(); mode::Symbol=:rel_deq_default, abstol=1f-8, reltol=1f-8, abstol_termination=1f-8, reltol_termination=1f-8, tspan=Inf32) + ContinuousDEQSolver(alg=VCABM3(); mode::Symbol=:rel_deq_default, abstol=1f-8, reltol=1f-8, abstol_termination=1f-8, reltol_termination=1f-8, tspan=Inf32) Solver for Continuous DEQ Problem ([pal2022mixing](@cite)). Similar to `DynamicSS` but provides more flexibility needed for solving DEQ problems. 
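# With the DiffEqBase.__solve dispatch added in solve.jl above, discrete and
# continuous DEQ solvers share one entry point. An illustrative toy problem
# (assumed usage, not from this patch; whether retcode is :Success depends on
# the termination tolerances):
prob = SteadyStateProblem((u, p, t) -> tanh.(u) .- u, ones(Float32, 2, 2), nothing)
sol = solve(prob, DiscreteDEQSolver(LimitedMemoryBroydenSolver()); maxiters=20)
sol.retcode     # :Success only if has_converged(du, u, alg) holds at the solution
sol.destats.nf  # function-evaluation count reported by nlsolve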
@@ -64,84 +64,3 @@ function ContinuousDEQSolver( alg, abstol, reltol, abstol_termination, reltol_termination, tspan ) end - -get_mode(::Val{mode}) where {mode} = mode - -function get_terminate_condition(alg::ContinuousDEQSolver{M,A,T}, args...; kwargs...) where {M,A,T} - mode = get_mode(M) - if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) - nstep, protective_threshold, objective_values = 0, T(1e3), T[] - - if mode ∈ (:rel_deq_best, :abs_deq_best) - @assert length(args) == 1 - - args[1][:best_objective_value] = T(Inf) - args[1][:best_objective_value_iteration] = 0 - end - - function terminate_condition_closure_1(integrator, abstol, reltol, min_t) - du, u = DiffEqBase.get_du(integrator), integrator.u - objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : (norm(du .+ u) + eps(T))) - criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? abstol : reltol - - if mode ∈ (:rel_deq_best, :abs_deq_best) - if objective < args[1][:best_objective_value] - args[1][:best_objective_value] = objective - args[1][:best_objective_value_iteration] = nstep + 1 - end - end - - # Main Termination Criteria - objective <= criteria && return true - - # Terminate if there has been no improvement for the last 30 steps - nstep += 1 - push!(objective_values, objective) - - objective <= 3 * criteria && - nstep >= 30 && - maximum(objective_values[max(1, length(objective_values) - nstep):end]) < - 1.3 * minimum(objective_values[max(1, length(objective_values) - nstep):end]) && - return true - - # Protective break - objective >= objective_values[1] * protective_threshold * length(du) && return true - - return false - end - return terminate_condition_closure_1 - else - function terminate_condition_closure_2(integrator, abstol, reltol, min_t) - return has_converged(DiffEqBase.get_du(integrator), integrator.u, alg, abstol, reltol) - end - return terminate_condition_closure_2 - end -end - -# Convergence Criterions -function has_converged( - du, u, alg::ContinuousDEQSolver{M}, abstol=alg.abstol_termination, reltol=alg.reltol_termination -) where {M} - mode = get_mode(M) - if mode == :norm - return norm(du) <= abstol && norm(du) <= reltol * norm(du .+ u) - elseif mode == :rel - return all(abs.(du) .<= reltol .* abs.(u)) - elseif mode == :rel_norm - return norm(du) <= reltol * norm(du .+ u) - elseif mode == :rel_deq_default - return norm(du) <= reltol * norm(du .+ u) - elseif mode == :rel_deq_best - return norm(du) <= reltol * norm(du .+ u) - elseif mode == :abs - return all(abs.(du) .<= abstol) - elseif mode == :abs_norm - return norm(du) <= abstol - elseif mode == :abs_deq_default - return norm(du) <= abstol - elseif mode == :abs_deq_best - return norm(du) <= abstol - else - return all(abs.(du) .<= abstol .& abs.(du) .<= reltol .* abs.(u)) - end -end diff --git a/src/solvers/discrete.jl b/src/solvers/discrete.jl index d0657261..8999de77 100644 --- a/src/solvers/discrete.jl +++ b/src/solvers/discrete.jl @@ -18,9 +18,16 @@ Solver for Discrete DEQ Problem ([baideep2019](@cite)). A wrapper around `SSroot See also: [`ContinuousDEQSolver`](@ref) """ -function DiscreteDEQSolver(solver=LimitedMemoryBroydenSolver; abstol=1e-8, reltol=1e-8, kwargs...) 
- solver = solver(; kwargs..., reltol=reltol, abstol=abstol) - return SSRootfind(; nlsolve=(f, u0, abstol) -> solver(f, u0)) +struct DiscreteDEQSolver{M,A,T} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm + alg::A + abstol_termination::T + reltol_termination::T +end + +function DiscreteDEQSolver( + alg; mode::Symbol=:rel_deq_default, abstol_termination::T=1.0f-8, reltol_termination::T=1.0f-8 +) where {T<:Number} + return DiscreteDEQSolver{Val(mode),typeof(alg),T}(alg, abstol_termination, reltol_termination) end include("discrete/broyden.jl") diff --git a/src/solvers/discrete/broyden.jl b/src/solvers/discrete/broyden.jl index c845142c..b1d32416 100644 --- a/src/solvers/discrete/broyden.jl +++ b/src/solvers/discrete/broyden.jl @@ -1,24 +1,3 @@ -# Broyden -## NOTE: For the time being it is better to use `LimitedMemoryBroydenSolver` -struct BroydenCache{J,F,X} - Jinv::J - fx::F - Δfx::F - fx_old::F - x::X - Δx::X - x_old::X -end - -function BroydenCache(x) - fx, Δfx, fx_old = copy(x), copy(x), copy(x) - x, Δx, x_old = copy(x), copy(x), copy(x) - Jinv = _init_identity_matrix(x) - return BroydenCache(Jinv, fx, Δfx, fx_old, x, Δx, x_old) -end - -BroydenCache(vec_length::Int, device) = BroydenCache(device(zeros(vec_length))) - """ BroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, ϵ::Real=1e-6, abstol::Union{Real,Nothing}=nothing, reltol::Union{Real,Nothing}=nothing) @@ -39,64 +18,54 @@ Broyden Solver ([broyden1965class](@cite)) for solving Discrete DEQs. It is reco See also: [`LimitedMemoryBroydenSolver`](@ref) """ -struct BroydenSolver{C<:BroydenCache,T<:Real} - cache::C - maxiters::Int - batch_size::Int - ϵ::T -end - -function BroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, ϵ::Real=1e-6, - abstol::Union{Real,Nothing}=nothing, reltol::Union{Real,Nothing}=nothing) - ϵ = abstol !== nothing ? abstol : ϵ - - if reltol !== nothing - @warn "reltol is set to $reltol, but `BroydenSolver` ignores this value" maxlog=1 - end - - x = device(zeros(T, prod(original_dims) * batch_size)) - cache = BroydenCache(x) - - return BroydenSolver(cache, maxiters, batch_size, T(ϵ)) +struct BroydenSolver end + +function nlsolve( + b::BroydenSolver, f::Function, y::AbstractArray{T}; terminate_condition, maxiters::Int=10 +) where {T} + res, stats = nlsolve( + b, + u -> vec(f(reshape(u, size(y)))), + vec(y); + terminate_condition, + maxiters + ) + return reshape(res, size(y)), stats end -function (broyden::BroydenSolver{C,T})(f!, x_::AbstractVector{T}) where {C,T} - @unpack Jinv, fx, Δfx, fx_old, x, Δx, x_old = broyden.cache - if size(x) != size(x_) - # This might happen when the last batch with insufficient batch_size - # is passed. - @unpack Jinv, fx, Δfx, fx_old, x, Δx, x_old = BroydenCache(x_) - end - x .= x_ - - f!(fx, x) - _init_identity_matrix!(Jinv) - - maybe_stuck = false - max_resets = 3 - resets = 0 - - for i in 1:(broyden.maxiters) - x_old .= x - fx_old .= fx +function nlsolve( + ::BroydenSolver, f::Function, y::AbstractVector{T}; terminate_condition, maxiters::Int=10 +) where {T} + x = copy(y) + x_old = copy(y) + Δx = copy(y) + fx_old = f(y) + Δfx = copy(fx_old) + Jinv = _init_identity_matrix(y) + p = similar(fx_old, (size(Jinv, 1),)) + ρ, σ₂ = T(0.9), T(0.001) - p = -Jinv * fx_old + maybe_stuck, max_resets, resets, nsteps, nf = false, 3, 0, 1, 1 - ρ, σ₂ = T(0.9), T(0.001) + while nsteps <= maxiters + mul!(p, Jinv, fx_old) + p .*= -1 - x .= x_old .+ p - f!(fx, x) + @. 
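# The rewritten solvers share one internal interface (assumed; sketch):
#   x_star, stats = nlsolve(alg, f, x0; terminate_condition, maxiters)
# where f is the residual whose root is sought, terminate_condition(fx, x)
# returns true to stop, and stats.nf counts function evaluations. For example
# (norm from LinearAlgebra):
xstar, stats = nlsolve(
    BroydenSolver(), x -> cos.(x) .- x, randn(Float32, 8);
    terminate_condition=(fx, x) -> norm(fx) < 1.0f-6, maxiters=100,
)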
x = x_old + p + fx = f(x) + nf += 1 if norm(fx, 2) ≤ ρ * norm(fx_old, 2) - σ₂ * norm(p, 2)^2 α = T(1) else - α = _approximate_norm_descent(f!, fx, x, p) - x .= x_old .+ α * p - f!(fx, x) + α, _stats = _approximate_norm_descent(f, x, p) + @. x = x_old + α * p + fx = f(x) + nf += 1 + _stats.nf end - Δx .= x .- x_old - Δfx .= fx .- fx_old + @. Δx = x - x_old + @. Δfx = fx - fx_old maybe_stuck = all(abs.(Δx) .<= eps(T)) || all(abs.(Δfx) .<= eps(T)) if maybe_stuck @@ -109,48 +78,38 @@ function (broyden::BroydenSolver{C,T})(f!, x_::AbstractVector{T}) where {C,T} end maybe_stuck = false + nsteps += 1 + copyto!(fx_old, fx) + copyto!(x_old, x) # Convergence Check - norm(Δfx, 2) ≤ broyden.ϵ && return x + terminate_condition(fx, x) && break end - return x + return x, (nf=nf,) end -# https://doi.org/10.1080/10556780008805782 -# FIXME: We are dropping some robustness tests for now. -function _approximate_norm_descent(f!, fx::AbstractArray{T,N}, x::AbstractArray{T,N}, p; λ₀=T(1), β=T(0.5), σ₁=T(0.001), +function _approximate_norm_descent(f::Function, x::AbstractArray{T,N}, p; λ₀=T(1), β=T(0.5), σ₁=T(0.001), η=T(0.1), max_iter=50) where {T,N} λ₂, λ₁ = λ₀, λ₀ - f!(fx, x) + fx = f(x) fx_norm = norm(fx, 2) + j = 1 + fx = f(x .+ λ₂ .* p) + converged = false - # TODO: Test NaN/Finite - # f!(fx, x .- λ₂ .* p) - # fxλp_norm = norm(fx, 2) - # TODO: nan backtrack - - j = 0 - - f!(fx, x .+ λ₂ .* p) - converged = _test_approximate_norm_descent_convergence(f!, fx, x, fx_norm, p, σ₁, λ₂, η) - - while j < max_iter && !converged + while j <= max_iter && !converged j += 1 λ₁, λ₂ = λ₂, β * λ₂ - converged = _test_approximate_norm_descent_convergence(f!, fx, x, fx_norm, p, σ₁, λ₂, η) + converged = _test_approximate_norm_descent_convergence(f, x, fx_norm, p, σ₁, λ₂, η) end - return λ₂ + return λ₂, (nf=2(j + 1),) end -function _test_approximate_norm_descent_convergence(f!, fx, x, fx_norm, p, σ₁, λ₂, η) - f!(fx, x .+ λ₂ .* p) - n1 = norm(fx, 2) - - f!(fx, x) - n2 = norm(fx, 2) - +function _test_approximate_norm_descent_convergence(f, x, fx_norm, p, σ₁, λ₂, η) + n1 = norm(f(x .+ λ₂ .* p), 2) + n2 = norm(f(x), 2) return n1 ≤ fx_norm - σ₁ * norm(λ₂ .* p, 2) .^ 2 + η * n2 end diff --git a/src/solvers/discrete/limited_memory_broyden.jl b/src/solvers/discrete/limited_memory_broyden.jl index b0324f72..e4e5a2be 100644 --- a/src/solvers/discrete/limited_memory_broyden.jl +++ b/src/solvers/discrete/limited_memory_broyden.jl @@ -1,11 +1,4 @@ # Limited Memory Broyden -struct LimitedMemoryBroydenCache{uT,vT,F,X} - Us::uT - VTs::vT - fx_::F - x::X -end - """ LimitedMemoryBroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, ϵ::Real=1e-6, criteria::Symbol=:reltol, abstol::Union{Real,Nothing}=nothing, @@ -28,142 +21,81 @@ Limited Memory Broyden Solver ([baimultiscale2020](@cite)) for solving Discrete See also: [`BroydenSolver`](@ref) """ -struct LimitedMemoryBroydenSolver{C<:LimitedMemoryBroydenCache,RT<:Union{AbstractFloat,Nothing}, - AT<:Union{AbstractFloat,Nothing}} - cache::C - original_dims::Tuple{Int,Int} - maxiters::Int - batch_size::Int - criteria::Symbol - reltol::RT - abstol::AT -end - -function LimitedMemoryBroydenSolver(; T=Float32, device, original_dims::Tuple{Int,Int}, batch_size, maxiters::Int=50, - ϵ::Real=1e-6, criteria::Symbol=:reltol, abstol::Union{Real,Nothing}=nothing, - reltol::Union{Real,Nothing}=nothing) - @assert criteria ∈ (:abstol, :reltol) - - abstol = abstol !== nothing ? T(abstol) : T(ϵ) - reltol = reltol !== nothing ? 
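# For reference, _test_approximate_norm_descent_convergence above implements the
# derivative-free acceptance test of the Li-Fukushima line search
# (https://doi.org/10.1080/10556780008805782, cited in the removed code):
#   ||F(x + λp)|| ≤ ||F(x)|| - σ₁||λp||² + η||F(x)||
# with λ shrunk geometrically (λ ← βλ, β = 1/2) until the test passes or
# max_iter is reached.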
T(reltol) : T(ϵ) - - LBFGS_threshold = min(maxiters, 27) - - x = device(zeros(T, original_dims..., batch_size)) - fx = device(zeros(T, original_dims..., batch_size)) - - total_hsize, n_elem, batch_size = size(x) - - # L x 2D x C x N - Us = fill!(similar(x, (LBFGS_threshold, total_hsize, n_elem, batch_size)), T(0)) - # 2D x C x L x N - VTs = fill!(similar(x, (total_hsize, n_elem, LBFGS_threshold, batch_size)), T(0)) - - cache = LimitedMemoryBroydenCache(Us, VTs, vec(fx), x) - - return LimitedMemoryBroydenSolver(cache, original_dims, maxiters, batch_size, criteria, reltol, abstol) -end +struct LimitedMemoryBroydenSolver end -function line_search(update, x₀, f₀, f, nstep::Int=0, on::Bool=false) - # TODO: Implement a line search algorithm - x_est = x₀ .+ update - f₀_new = f(x_est) - return (x_est, f₀_new, x_est .- x₀, f₀_new .- f₀, 0) +function nlsolve(l::LimitedMemoryBroydenSolver, f::Function, y::AbstractMatrix; kwargs...) + res, stats = nlsolve(l, f, reshape(y, size(y, 1), 1, size(y, 2)); kwargs...) + return dropdims(res; dims=2), stats end -function (lbroyden::LimitedMemoryBroydenSolver{C,T})(f!, x_::AbstractVector{T}) where {C,T} - @unpack cache, original_dims, batch_size, maxiters, criteria, reltol, abstol = lbroyden - ϵ = getfield(lbroyden, criteria) +function nlsolve( + ::LimitedMemoryBroydenSolver, f::Function, y::AbstractArray{T,3}; terminate_condition, maxiters::Int=10 +) where {T} + LBFGS_threshold = min(maxiters, 27) - nfeatures = prod(original_dims) - if nfeatures * batch_size != length(x_) - # Maybe the last batch is smaller than the others - cache = LimitedMemoryBroydenSolver(; T=T, device=x_ isa CuArray ? gpu : cpu, original_dims=original_dims, - batch_size=length(x_) ÷ nfeatures, maxiters=maxiters, ϵ=ϵ).cache - end + total_hsize, n_elem, batch_size = size(y) - @unpack Us, VTs, fx_, x = cache - x .= reshape(x_, size(x)) - LBFGS_threshold = size(Us, 1) - fill!(Us, T(0)) - fill!(VTs, T(0)) + # Initialize the cache + x₀ = copy(y) + fx₀ = f(x₀) + x₁ = copy(y) + Δx = copy(x₀) + Δfx = copy(x₀) + Us = fill!(similar(y, (LBFGS_threshold, total_hsize, n_elem, batch_size)), T(0)) + VTs = fill!(similar(y, (total_hsize, n_elem, LBFGS_threshold, batch_size)), T(0)) # Counters nstep = 1 - tnstep = 1 - - # Initialize - total_hsize, n_elem, batch_size = actual_size = size(x) - - # Modify the functions - f(x) = (f!(fx_, vec(x)); return reshape(fx_, actual_size)) - fx = f(x) - - update = fx - new_objective = norm(fx) - objective_values = [new_objective] - protect_threshold = (criteria == :abstol ? T(1e6) : T(1e3)) * n_elem - initial_objective = new_objective - lowest_objective = new_objective - lowest_xest = x + # Main Algorithm + update = fx₀ - @inbounds while nstep < maxiters - x, fx, Δx, Δfx, ite = line_search(update, x, fx, f, nstep, false) - nstep += 1 - tnstep += (ite + 1) - - new_objective = criteria == :abstol ? norm(fx) : (norm(fx) / (norm(fx .+ x) + eps(T))) - push!(objective_values, new_objective) - - if new_objective < lowest_objective - lowest_objective = new_objective - lowest_xest = x - end - new_objective < ϵ && break + while nstep <= maxiters + # Update + @. x₁ = x₀ + update + fx₁ = f(x₁) + @. Δx = x₁ - x₀ + @. 
Δfx = fx₁ - fx₀ - new_objective < 3ϵ && - nstep >= 30 && - maximum(objective_values[(end - nstep + 1):end]) < 1.3 * minimum(objective_values[(end - nstep + 1):end]) && - break - - # Prevent Divergence - (new_objective > initial_objective * protect_threshold) && break + # Convergence Check + terminate_condition(fx₁, x₁) && break + # Compute the update @views part_Us = Us[1:min(LBFGS_threshold, nstep), :, :, :] @views part_VTs = VTs[:, :, 1:min(LBFGS_threshold, nstep), :] - vT = rmatvec(part_Us, part_VTs, Δx) # 2D x C x N - u = (Δx .- matvec(part_Us, part_VTs, Δfx)) ./ sum(vT .* Δfx; dims=(1, 2)) # 2D x C x N - vT[.!isfinite.(vT)] .= T(0) - u[.!isfinite.(u)] .= T(0) + vT = rmatvec(part_Us, part_VTs, Δx) # D x C x N + mvec = matvec(part_Us, part_VTs, Δfx) + vTΔfx = sum(vT .* Δfx; dims=(1, 2)) + @. Δx = (Δx - mvec) / (vTΔfx + eps(T)) # D x C x N @views VTs[:, :, mod1(nstep, LBFGS_threshold), :] .= vT - @views Us[mod1(nstep, LBFGS_threshold), :, :, :] .= u + @views Us[mod1(nstep, LBFGS_threshold), :, :, :] .= Δx + + @views update = + -matvec( + Us[1:min(LBFGS_threshold, nstep + 1), :, :, :], VTs[:, :, 1:min(LBFGS_threshold, nstep + 1), :], fx₁ + ) + copyto!(x₀, x₁) + copyto!(fx₀, fx₁) - @views update = -matvec(Us[1:min(LBFGS_threshold, nstep + 1), :, :, :], - VTs[:, :, 1:min(LBFGS_threshold, nstep + 1), :], fx) + # Increment Counter + nstep += 1 end - return vec(lowest_xest) + return x₁, (nf=nstep + 1,) end -function matvec(part_Us::AbstractArray{E,4}, part_VTs::AbstractArray{E,4}, x::AbstractArray{E,3}) where {E} - # part_Us -> (T x D x C x N) - # part_VTs -> (D x C x T x N) - # x -> (D x C x N) - length(part_Us) == 0 && return -x - T, D, C, N = size(part_Us) +@inbounds function matvec(part_Us::AbstractArray{E,4}, part_VTs::AbstractArray{E,4}, x::AbstractArray{E,3}) where {E} + # part_Us -> (T x D x C x N) | part_VTs -> (D x C x T x N) | x -> (D x C x N) + _, D, C, N = size(part_Us) xTU = sum(reshape(x, (1, D, C, N)) .* part_Us; dims=(2, 3)) # T x 1 x 1 x N - return -x .+ reshape(sum(permutedims(xTU, (2, 3, 1, 4)) .* part_VTs; dims=3), (D, C, N)) + return -x .+ dropdims(sum(permutedims(xTU, (2, 3, 1, 4)) .* part_VTs; dims=3); dims=3) end function rmatvec(part_Us::AbstractArray{E,4}, part_VTs::AbstractArray{E,4}, x::AbstractArray{E,3}) where {E} - # part_Us -> (T x D x C x N) - # part_VTs -> (D x C x T x N) - # x -> (D x C x N) - length(part_Us) == 0 && return -x - T, D, C, N = size(part_Us) + # part_Us -> (T x D x C x N) | part_VTs -> (D x C x T x N) | x -> (D x C x N) + _, D, C, N = size(part_Us) VTx = sum(part_VTs .* reshape(x, (D, C, 1, N)); dims=(1, 2)) # 1 x 1 x T x N - return -x .+ reshape(sum(part_Us .* permutedims(VTx, (3, 1, 2, 4)); dims=1), (D, C, N)) + return -x .+ dropdims(sum(part_Us .* permutedims(VTx, (3, 1, 2, 4)); dims=1); dims=1) end diff --git a/src/utils.jl b/src/utils.jl index e88b3a01..83336047 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -36,13 +36,6 @@ function split_and_reshape(x::AbstractMatrix, idxs::Tuple, shapes::Tuple) ) end -## Some dispatches for CuArrays are not defined for subarrays -# function split_and_reshape(x::AbstractMatrix, idxs::Tuple, shapes::Tuple) -# return Tuple( -# x[reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...), :] for i in 1:(length(idxs) - 1) -# ) -# end - # Zygote Fix function Zygote.accum(x::NTuple{N,T}, y::AbstractVector{T}) where {N,T<:AbstractArray} return Zygote.accum.(x, y) diff --git a/test/runtests.jl b/test/runtests.jl index b8659b6c..a6700c0c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,7 @@ using FastDEQ, 
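# What the batched helpers above compute, per batch element (with each D×C slice
# treated as a vector; a sketch of the algebra, sign conventions as in the code):
#   matvec(Us, VTs, x)  = -x + Σₜ ⟨uₜ, x⟩ vₜ,  i.e. (-I + Σₜ vₜuₜᵀ) x
#   rmatvec(Us, VTs, x) = -x + Σₜ ⟨vₜ, x⟩ uₜ,  i.e. the transpose applied to x
# so the limited-memory inverse-Jacobian estimate J⁻¹ ≈ -I + Σₜ vₜuₜᵀ is applied
# in O(T·D·C) per batch element and never materialised.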
CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), ), ) - ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) @@ -32,7 +32,7 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) @@ -59,7 +59,7 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) @@ -76,43 +76,43 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end - # @info "Testing Broyden Solver" - # Random.seed!(0) - - # model = gpu(DEQChain(Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - # SkipDeepEquilibriumNetwork(Parallel(+, Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - # Conv((3, 3), 1 => 1, relu; pad=1, stride=1)), - # Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - # DiscreteDEQSolver(BroydenSolver; abstol=0.001f0, - # reltol=0.001f0, device=gpu, original_dims=(8 * 8, 1), - # batch_size=4); - # sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) - # x = gpu(rand(Float32, 8, 8, 1, 4)) - # y = gpu(rand(Float32, 8, 8, 1, 4)) - # ps = Flux.params(model) - # gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - # for _p in ps - # @test all(isfinite.(gs[_p])) - # end - - # @info "Testing L-Broyden Solver" - # Random.seed!(0) - - # model = gpu(DEQChain(Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - # SkipDeepEquilibriumNetwork(Parallel(+, Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - # Conv((3, 3), 1 => 1, relu; pad=1, stride=1)), - # Conv((3, 3), 1 => 1, relu; pad=1, stride=1), - # DiscreteDEQSolver(LimitedMemoryBroydenSolver; abstol=0.001f0, - # reltol=0.001f0, device=gpu, original_dims=(8 * 8, 1), - # batch_size=4); - # sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10)))) - # x = gpu(rand(Float32, 8, 8, 1, 4)) - # y = gpu(rand(Float32, 8, 8, 1, 4)) - # ps = Flux.params(model) - # gs = Flux.gradient(() -> mse_loss_function(model, x, y), ps) - # for _p in ps - # @test all(isfinite.(gs[_p])) - # end + @info "Testing SkipDEQ with Broyden Solver" + model = DEQChain( + EFL.Dense(2, 2), + SkipDeepEquilibriumNetwork( + EFL.Parallel(+, EFL.Dense(2, 2), EFL.Dense(2, 2)), + EFL.Dense(2, 2), + DiscreteDEQSolver(BroydenSolver(); abstol_termination=0.1f0, reltol_termination=0.1f0); + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + ), + ) + ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) + x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) + y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + + @info "Testing SkipDEQ with L-Broyden Solver" + model = DEQChain( + EFL.Dense(2, 2), + 
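# On the repeated `NamedTuple.(gpu.(EFL.setup(...)))` idiom in these tests: the
# broadcast maps over the (ps, st) tuple, moving both to the GPU and then
# converting the ComponentArray parameters to a plain NamedTuple (st is assumed
# to already be one), matching the overhead note in examples/src/models.jl. Sketch:
ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model)))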
SkipDeepEquilibriumNetwork( + EFL.Parallel(+, EFL.Dense(2, 2), EFL.Dense(2, 2)), + EFL.Dense(2, 2), + DiscreteDEQSolver(LimitedMemoryBroydenSolver(); abstol_termination=0.1f0, reltol_termination=0.1f0); + sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + ), + ) + ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) + x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) + y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + + gs = gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end @info "Testing MultiScaleDEQ" model = MultiScaleDeepEquilibriumNetwork( @@ -134,7 +134,7 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) @@ -172,7 +172,7 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) @@ -210,7 +210,7 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = gpu.(EFL.setup(MersenneTwister(seed), model)) + ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) From 1ba67368effbd34acd3f6cacf8c3df5d4420025f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Apr 2022 08:27:04 -0400 Subject: [PATCH 40/76] Add termination.jl --- src/solvers/termination.jl | 138 +++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 src/solvers/termination.jl diff --git a/src/solvers/termination.jl b/src/solvers/termination.jl new file mode 100644 index 00000000..d812f92c --- /dev/null +++ b/src/solvers/termination.jl @@ -0,0 +1,138 @@ +get_mode(::Val{mode}) where {mode} = mode + +function get_terminate_condition(alg::ContinuousDEQSolver{M,A,T}, args...; kwargs...) where {M,A,T} + mode = get_mode(M) + if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) + nstep, protective_threshold, objective_values = 0, T(1e3), T[] + + if mode ∈ (:rel_deq_best, :abs_deq_best) + @assert length(args) == 1 + + args[1][:best_objective_value] = T(Inf) + args[1][:best_objective_value_iteration] = 0 + end + + function terminate_condition_closure_1(integrator, abstol, reltol, min_t) + du, u = DiffEqBase.get_du(integrator), integrator.u + objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : (norm(du .+ u) + eps(T))) + criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? 
abstol : reltol + + if mode ∈ (:rel_deq_best, :abs_deq_best) + if objective < args[1][:best_objective_value] + args[1][:best_objective_value] = objective + args[1][:best_objective_value_iteration] = nstep + 1 + end + end + + # Main Termination Criteria + objective <= criteria && return true + + # Terminate if there has been no improvement for the last 30 steps + nstep += 1 + push!(objective_values, objective) + + objective <= 3 * criteria && + nstep >= 30 && + maximum(objective_values[max(1, length(objective_values) - nstep):end]) < + 1.3 * minimum(objective_values[max(1, length(objective_values) - nstep):end]) && + return true + + # Protective break + objective >= objective_values[1] * protective_threshold * length(du) && return true + + return false + end + return terminate_condition_closure_1 + else + function terminate_condition_closure_2(integrator, abstol, reltol, min_t) + return has_converged(DiffEqBase.get_du(integrator), integrator.u, M, abstol, reltol) + end + return terminate_condition_closure_2 + end +end + +function get_terminate_condition(alg::DiscreteDEQSolver{M,A,T}, args...; kwargs...) where {M,A,T} + mode = get_mode(M) + if mode ∈ (:abs_deq_default, :rel_deq_default, :abs_deq_best, :rel_deq_best) + nstep, protective_threshold, objective_values = 0, T(1e3), T[] + + if mode ∈ (:rel_deq_best, :abs_deq_best) + @assert length(args) == 1 + + args[1][:best_objective_value] = T(Inf) + args[1][:best_objective_value_iteration] = 0 + end + + function terminate_condition_closure_1(du, u) + objective = norm(du) / (mode ∈ (:abs_deq_default, :abs_deq_best) ? 1 : (norm(du .+ u) + eps(T))) + criteria = mode ∈ (:abs_deq_default, :abs_deq_best) ? alg.abstol_termination : alg.reltol_termination + + if mode ∈ (:rel_deq_best, :abs_deq_best) + if objective < args[1][:best_objective_value] + args[1][:best_objective_value] = objective + args[1][:best_objective_value_iteration] = nstep + 1 + end + end + + # Main Termination Criteria + objective <= criteria && return true + + # Terminate if there has been no improvement for the last 30 steps + nstep += 1 + push!(objective_values, objective) + + objective <= 3 * criteria && + nstep >= 30 && + maximum(objective_values[max(1, length(objective_values) - nstep):end]) < + 1.3 * minimum(objective_values[max(1, length(objective_values) - nstep):end]) && + return true + + # Protective break + objective >= objective_values[1] * protective_threshold * length(du) && return true + + return false + end + return terminate_condition_closure_1 + else + function terminate_condition_closure_2(du, u) + return has_converged(du, u, M, alg.abstol_termination, alg.reltol_termination) + end + return terminate_condition_closure_2 + end +end + +# Convergence Criterions +function has_converged( + du, + u, + alg::Union{ContinuousDEQSolver{M},DiscreteDEQSolver{M}}, + abstol=alg.abstol_termination, + reltol=alg.reltol_termination, +) where {M} + return has_converged(du, u, M, abstol, reltol) +end + +function has_converged(du, u, M, abstol, reltol) + mode = get_mode(M) + if mode == :norm + return norm(du) <= abstol && norm(du) <= reltol * norm(du .+ u) + elseif mode == :rel + return all(abs.(du) .<= reltol .* abs.(u)) + elseif mode == :rel_norm + return norm(du) <= reltol * norm(du .+ u) + elseif mode == :rel_deq_default + return norm(du) <= reltol * norm(du .+ u) + elseif mode == :rel_deq_best + return norm(du) <= reltol * norm(du .+ u) + elseif mode == :abs + return all(abs.(du) .<= abstol) + elseif mode == :abs_norm + return norm(du) <= abstol + elseif mode == :abs_deq_default 
+ return norm(du) <= abstol + elseif mode == :abs_deq_best + return norm(du) <= abstol + else + return all(abs.(du) .<= abstol .& abs.(du) .<= reltol .* abs.(u)) + end +end From 94f8fb8af1ec186547df938ce78372c6881b0c8f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Apr 2022 09:26:27 -0400 Subject: [PATCH 41/76] extra in train --- examples/cifar10/script.jl | 1 - examples/src/models.jl | 13 ++++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index 68d42ffd..ab9a36bc 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -44,7 +44,6 @@ function train_model(config, expt_name) train_dataloader, nothing, test_dataloader, - gpu, expt_config.nepochs, lg, expt_config, diff --git a/examples/src/models.jl b/examples/src/models.jl index 00afda13..b247e838 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -381,22 +381,21 @@ function get_model( clean_println("Starting Model Warmup") x__ = device(randn(Float32, config.image_size..., 3, 1)) y__ = device(Float32.(Flux.onehotbatch([1], 0:9))) + model(x__, ps, st) + clean_println("Forward Pass Warmup Completed") st_ = EFL.update_state(st, :fixed_depth, 2) model(x__, ps, st_) clean_println("Forward Pass (Pretraining) Warmup Completed") - model(x__, ps, st) - clean_println("Forward Pass Warmup Completed") - lfn = loss_function(config) - (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st_), ps) - back((one(l), nothing, nothing, nothing)) - clean_println("Backward Pass (Pretraining) Warmup Completed") - (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass Warmup Completed") + + (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st_), ps) + back((one(l), nothing, nothing, nothing)) + clean_println("Backward Pass (Pretraining) Warmup Completed") invoke_gc() end From 5e45c4fbb929cf95b9cb63dcc6f9f97958d1686e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 21 Apr 2022 13:49:37 -0400 Subject: [PATCH 42/76] Update config --- examples/src/config.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index 377b7cd5..9725bb1b 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -210,8 +210,8 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) if model_size == :TINY return ExperimentConfiguration( model_config=get_model_config(dataset, model_size; kwargs...), - eval_batchsize=64, - train_batchsize=64, + eval_batchsize=128, + train_batchsize=128, nepochs=50, pretrain_steps=3000 ÷ (is_distributed() ? 
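# A condensed key for the `mode` symbols handled by has_converged in
# termination.jl above (sketch):
#   :abs, :abs_norm, :abs_deq_default, :abs_deq_best   -> absolute: ||du|| ≤ abstol
#       (:abs is elementwise, |du| .≤ abstol)
#   :rel, :rel_norm, :rel_deq_default, :rel_deq_best   -> relative:
#       ||du|| ≤ reltol * ||du + u|| (:rel is elementwise, |du| .≤ reltol * |u|)
#   :norm -> both ||du|| ≤ abstol and ||du|| ≤ reltol * ||du + u||
# The *_deq_default and *_deq_best closures also add the 30-step no-improvement
# exit and the protective divergence break, and the *_deq_best modes record the
# best objective seen in the passed-in stats Dict.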
total_workers() : 1), lr_scheduler=:COSINE, From 79348600aae933709946e55f37c3cdb2eb4b1ff2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 23 Apr 2022 13:29:32 -0400 Subject: [PATCH 43/76] Modify architecture to make it GPU friendly --- examples/Manifest.toml | 87 ++++++------- examples/src/models.jl | 83 ++++++------ src/layers/deq.jl | 14 +- src/layers/mdeq.jl | 10 +- src/solvers/discrete.jl | 5 +- .../discrete/limited_memory_broyden.jl | 49 ++++--- src/solvers/termination.jl | 12 +- test/runtests.jl | 121 ++++++++++++++---- 8 files changed, 238 insertions(+), 143 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index b0b0337e..355bde6a 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.8.0-beta3" +julia_version = "1.7.2" manifest_format = "2.0" project_hash = "55e1c12df8760eecee422230ecbed81b99e0268c" @@ -29,7 +29,6 @@ version = "2.3.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" [[deps.ArnoldiMethod]] deps = ["LinearAlgebra", "Random", "StaticArrays"] @@ -121,15 +120,15 @@ uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" version = "1.0.0" [[deps.CEnum]] -git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9" +git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.1" +version = "0.4.2" [[deps.CPUSummary]] deps = ["CpuId", "IfElse", "Static"] -git-tree-sha1 = "80f3d536df634cabed8b98ad3f0cea3a715fd254" +git-tree-sha1 = "baaac45b4462b3b0be16726f38b789bf330fcb7a" uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" -version = "0.1.20" +version = "0.1.21" [[deps.CSV]] deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings"] @@ -150,9 +149,9 @@ version = "0.3.10" [[deps.ChainRules]] deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] -git-tree-sha1 = "8b887daa6af5daf705081061e36386190204ac87" +git-tree-sha1 = "f1e926b37a2e1f64388be59b1baff4152eae67b9" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.1" +version = "1.28.2" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] @@ -204,7 +203,6 @@ version = "3.43.0" [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.5.2+0" [[deps.ComponentArrays]] deps = ["ArrayInterface", "ChainRulesCore", "LinearAlgebra", "Requires"] @@ -381,9 +379,8 @@ uuid = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" version = "0.5.9" [[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" [[deps.EllipsisNotation]] deps = ["ArrayInterface"] @@ -404,8 +401,8 @@ uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef" version = "0.0.29+0" [[deps.ExplicitFluxLayers]] -deps = ["CUDA", "ChainRulesCore", "ComponentArrays", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Optimisers", "Random", "Setfield", "SparseArrays", "Statistics", "Zygote"] -git-tree-sha1 = "e923ed1f219c9c505faa2c05af29dbb74b4f5760" +deps = ["CUDA", "ChainRulesCore", "ComponentArrays", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Optimisers", "Random", "Setfield", "SparseArrays", "Statistics", "Yota", 
"Zygote"] +git-tree-sha1 = "ea89d52a2be10118b7b5373ef11c277a0c206a16" repo-rev = "ap/sparse" repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" @@ -475,9 +472,6 @@ git-tree-sha1 = "129b104185df66e408edd6625d480b7f9e9823a0" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" version = "0.9.18" -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" @@ -490,6 +484,12 @@ git-tree-sha1 = "56956d1e4c1221000b7781104c58c34019792951" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" version = "2.11.0" +[[deps.FiniteDifferences]] +deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "SparseArrays", "StaticArrays"] +git-tree-sha1 = "0ee1275eb003b6fc7325cb14301665d1072abda1" +uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" +version = "0.12.24" + [[deps.FixedPointNumbers]] deps = ["Statistics"] git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" @@ -590,9 +590,9 @@ version = "1.6.0" [[deps.HDF5]] deps = ["Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires"] -git-tree-sha1 = "cdd249512de03cbf8370365a0a08b9a24955dca9" +git-tree-sha1 = "36df177c1ce5f399a8de959e5f4b75216fe6c834" uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.16.6" +version = "0.16.7" [[deps.HDF5_jll]] deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"] @@ -837,12 +837,10 @@ version = "1.0.0" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.81.0+0" [[deps.LibGit2]] deps = ["Base64", "NetworkOptions", "Printf", "SHA"] @@ -851,7 +849,6 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -987,7 +984,6 @@ version = "1.0.3" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.0+0" [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] @@ -1012,7 +1008,6 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.2.1" [[deps.MuladdMacro]] git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" @@ -1056,7 +1051,6 @@ version = "0.1.5" [[deps.NetworkOptions]] uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" [[deps.NonlinearSolve]] deps = ["ArrayInterface", "FiniteDiff", "ForwardDiff", "IterativeSolvers", "LinearAlgebra", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "SciMLBase", "Setfield", "StaticArrays", "UnPack"] @@ -1079,12 +1073,10 @@ version = "1.10.8" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.20+0" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] @@ -1147,9 +1139,9 @@ version = "0.12.3" [[deps.Parsers]] deps = ["Dates"] 
-git-tree-sha1 = "621f4f3b4977325b9128d5fae7a8b4829a0c2222" +git-tree-sha1 = "3b429f37de37f1fc603cc1de4a799dc7fbe4c0b6" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.2.4" +version = "2.3.0" [[deps.Pickle]] deps = ["DataStructures", "InternedStrings", "Serialization", "SparseArrays", "Strided", "StringEncodings", "ZipFile"] @@ -1160,7 +1152,6 @@ version = "0.3.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.8.0" [[deps.PoissonRandom]] deps = ["Random", "Statistics", "Test"] @@ -1309,6 +1300,12 @@ git-tree-sha1 = "559db2c7a28262e9ff1af1ad4ec539aa972c8934" uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" version = "1.13.0" +[[deps.Richardson]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949" +uuid = "708f8203-808e-40c0-ba2d-98a6953ed40d" +version = "1.4.0" + [[deps.Rmath]] deps = ["Random", "Rmath_jll"] git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" @@ -1329,7 +1326,6 @@ version = "0.5.3" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" [[deps.SIMDDualNumbers]] deps = ["ForwardDiff", "IfElse", "SLEEFPirates", "VectorizationBase"] @@ -1435,9 +1431,9 @@ version = "0.6.0" [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "4f6ec5d99a28e1a749559ef7dd518663c5eca3d5" +git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.4.3" +version = "1.4.4" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -1445,9 +1441,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[deps.StatsAPI]] deps = ["LinearAlgebra"] -git-tree-sha1 = "8d7530a38dbd2c397be7ddd01a424e4f411dcc41" +git-tree-sha1 = "c82aaa13b44ea00134f8c9c89819477bd3986ecd" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.2.2" +version = "1.3.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] @@ -1510,12 +1506,10 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+0" [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.0" [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] @@ -1532,7 +1526,6 @@ version = "1.7.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] @@ -1552,9 +1545,9 @@ version = "0.5.0" [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "d60b0c96a16aaa42138d5d38ad386df672cb8bd8" +git-tree-sha1 = "11db03dd5bbc0d2b57a570d228a0f34538c586b1" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.16" +version = "0.5.17" [[deps.Tracker]] deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] @@ -1611,6 +1604,12 @@ version = "1.3.0" deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +[[deps.Umlaut]] +deps = ["LinearAlgebra", "Statistics", "Test"] +git-tree-sha1 = "1428bb6784d43298b29503b4a08b8a51b13e4c07" 
+uuid = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841" +version = "0.2.4" + [[deps.UnPack]] git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" @@ -1643,6 +1642,12 @@ git-tree-sha1 = "de67fa59e33ad156a590055375a30b23c40299d3" uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" version = "0.5.5" +[[deps.Yota]] +deps = ["ChainRules", "ChainRulesCore", "FiniteDifferences", "LinearAlgebra", "NNlib", "Random", "Statistics", "Test", "UUIDs", "Umlaut"] +git-tree-sha1 = "b4eef79929bab5503cbc6ca495aa205bdab98978" +uuid = "cd998857-8626-517d-b929-70ad188a48f0" +version = "0.7.3" + [[deps.ZipFile]] deps = ["Libdl", "Printf", "Zlib_jll"] git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de" @@ -1652,7 +1657,6 @@ version = "0.9.4" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.12+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] @@ -1669,14 +1673,11 @@ version = "0.2.2" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.1.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.41.0+1" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "16.2.1+1" diff --git a/examples/src/models.jl b/examples/src/models.jl index b247e838..b0745abd 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -30,7 +30,8 @@ function downsample_module(mapping, level_diff, activation; group_count=8) for i in 1:level_diff inchs, outchs = intermediate_mapping(i) push!(layers, conv3x3(inchs => outchs; stride=2)) - push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + # push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, EFL.BatchNorm(outchs, activation; affine=true, track_stats=false)) end return EFL.Chain(layers...) end @@ -51,7 +52,8 @@ function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol= for i in 1:level_diff inchs, outchs = intermediate_mapping(i) push!(layers, conv1x1(inchs => outchs)) - push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + # push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, EFL.BatchNorm(outchs, activation; affine=true, track_stats=false)) push!(layers, EFL.Upsample(upsample_mode; scale=2)) end return EFL.Chain(layers...) @@ -68,12 +70,11 @@ function ResidualBlockV1( gn_affine::Bool=true, weight_norm::Bool=true, gn_track_stats::Bool=false, - dropout_seed::UInt64=UInt64(0), ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=true) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=true) + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false) + conv2 = (n_big_kernels >= 2 ? 
conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) conv1, conv2 = if weight_norm EFL.WeightNorm(conv1, (:weight,), (4,)), EFL.WeightNorm(conv2, (:weight,), (4,)) @@ -81,11 +82,14 @@ function ResidualBlockV1( conv1, conv2 end - gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + # gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + gn1 = EFL.BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = EFL.BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = EFL.BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) - dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate; initial_seed=dropout_seed) + dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.VariationalHiddenDropout(dropout_rate) return EFL.Chain( EFL.Parallel( @@ -101,7 +105,8 @@ function ResidualBlockV1( gn2, ), # For (y2, injection) ), - EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + EFL.ActivationFunction(relu), gn3, ) end @@ -116,12 +121,11 @@ function ResidualBlockV2( gn_affine::Bool=true, weight_norm::Bool=true, gn_track_stats::Bool=false, - dropout_seed::UInt64=UInt64(0), ) inplanes, outplanes = mapping inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=true) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=true) + conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false) + conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) conv1, conv2 = if weight_norm EFL.WeightNorm(conv1, (:weight,), (4,)), EFL.WeightNorm(conv2, (:weight,), (4,)) @@ -129,18 +133,22 @@ function ResidualBlockV2( conv1, conv2 end - gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + # gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + gn1 = EFL.BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = EFL.BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = EFL.BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) - dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.Dropout(dropout_rate; initial_seed=dropout_seed) + dropout = iszero(dropout_rate) ? 
EFL.NoOpLayer() : EFL.VariationalHiddenDropout(dropout_rate) return EFL.Chain( conv1, gn1, conv2, EFL.Parallel(+, downsample, EFL.Chain(dropout, gn2)), - EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + EFL.ActivationFunction(relu), gn3, ) end @@ -171,7 +179,8 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool ), ), ), - EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + EFL.ActivationFunction(relu), ) end @@ -198,7 +207,8 @@ function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), - EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + EFL.ActivationFunction(relu), ) end @@ -241,8 +251,6 @@ function get_model( initial_layers = EFL.Chain(downsample, stage0) - dropout_seed = UInt64(0) - main_layers = Tuple( ResidualBlockV1( config.num_channels[i] => config.num_channels[i]; @@ -250,12 +258,9 @@ function get_model( dropout_rate=config.dropout_rate, num_gn_groups=config.group_count, n_big_kernels=config.big_kernels[i], - dropout_seed=dropout_seed + (i - 1) * 100, ) for i in 1:(config.num_branches) ) - dropout_seed = dropout_seed + config.num_branches * 100 - mapping_layers = Matrix{EFL.AbstractExplicitLayer}(undef, config.num_branches, config.num_branches) for i in 1:(config.num_branches) for j in 1:(config.num_branches) @@ -279,9 +284,11 @@ function get_model( post_fuse_layers = Tuple( EFL.Chain( - EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), + EFL.ActivationFunction(relu), conv1x1(config.num_channels[i] => config.num_channels[i]), - EFL.GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), + # EFL.GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), + EFL.BatchNorm(config.num_channels[i]; affine=true, track_stats=false), ) for i in 1:(config.num_branches) ) @@ -314,8 +321,8 @@ function get_model( ContinuousDEQSolver( config.ode_solver; mode=config.stop_mode, - abstol=1.0f-5, #config.abstol, - reltol=1.0f-5, #config.reltol, + abstol=1.0f-5, + reltol=1.0f-5, abstol_termination=config.abstol, reltol_termination=config.reltol, ) @@ -327,9 +334,7 @@ function get_model( deq = if config.model_type ∈ (:SKIP, :SKIPV2) shortcut = if config.model_type == :SKIP - slayers = EFL.AbstractExplicitLayer[ResidualBlockV2( - config.num_channels[1] => config.num_channels[1]; dropout_seed=dropout_seed - )] + slayers = EFL.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1])] for i in 1:(config.num_branches - 1) push!( slayers, @@ -372,15 +377,17 @@ function get_model( end model = DEQChain(initial_layers, deq, final_layers) - ps, st = device.(EFL.setup(MersenneTwister(seed), model)) - # NOTE: ComponentArrays seem to have some overhead - ps = NamedTuple(ps) - st = NamedTuple(st) + rng = Random.default_rng() + Random.seed!(rng, seed) + ps, st = device.(EFL.setup(rng, model)) + + # Temporary Fix: CUDA.RNG giving errors on Julia 1.7 + st = EFL.update_state(st, :rng, rng) if warmup clean_println("Starting Model Warmup") x__ = device(randn(Float32, config.image_size..., 3, 1)) - y__ = device(Float32.(Flux.onehotbatch([1], 0:9))) + y__ = device(Float32.(Flux.onehotbatch([1], 0:(config.num_classes - 1)))) model(x__, ps, st) clean_println("Forward Pass 
Warmup Completed") @@ -392,7 +399,7 @@ function get_model( (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass Warmup Completed") - + (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st_), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass (Pretraining) Warmup Completed") diff --git a/src/layers/deq.jl b/src/layers/deq.jl index fdd33f77..8c58c840 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -30,8 +30,10 @@ function (deq::DeepEquilibriumNetwork{J})( return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st end + st_ = st.model + function dudt(u, p, t) - u_, _ = deq.model((u, x), p, st.model) + u_, st_ = deq.model((u, x), p, st_) return u_ .- u end @@ -41,6 +43,8 @@ function (deq::DeepEquilibriumNetwork{J})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps, st.model)[1] + + st_ = EFL.update_state(st_, :update_mask, true) @set! st.model = st_ return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st @@ -91,8 +95,10 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st end + st_ = st.model + function dudt(u, p, t) - u_, = deq.model((u, x), p, st.model) + u_, st_ = deq.model((u, x), p, st_) return u_ .- u end @@ -102,7 +108,9 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps.model, st.model)[1] - @set! st.model = st_::typeof(st.model) + + st_ = EFL.update_state(st_, :update_mask, true) + @set! st.model = st_ return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 79cd78f1..fe35d71d 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -61,9 +61,11 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( return (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st end + st_ = st.model + function dudt_(u, p, t) u_split = split_and_reshape(u, st.split_idxs, deq.scales) - u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st.model) + u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st_) return u_, st_ end @@ -75,6 +77,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( residual = dudt(sol.u, ps, nothing) + st_ = EFL.update_state(st_, :update_mask, true) @set! st.model = st_ return ( @@ -156,9 +159,11 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( return (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st end + st_ = st.model + function dudt_(u, p, t) u_split = split_and_reshape(u, st.split_idxs, deq.scales) - u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st.model) + u_, st_ = deq.model(((u_split[1], x), u_split[2:N]...), p, st_) return u_, st_ end @@ -170,6 +175,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( residual = dudt(sol.u, ps.model, nothing) + st_ = EFL.update_state(st_, :update_mask, true) @set! 
st.model = st_ return ( diff --git a/src/solvers/discrete.jl b/src/solvers/discrete.jl index 8999de77..aecf447f 100644 --- a/src/solvers/discrete.jl +++ b/src/solvers/discrete.jl @@ -25,7 +25,10 @@ struct DiscreteDEQSolver{M,A,T} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm end function DiscreteDEQSolver( - alg; mode::Symbol=:rel_deq_default, abstol_termination::T=1.0f-8, reltol_termination::T=1.0f-8 + alg=LimitedMemoryBroydenSolver(); + mode::Symbol=:rel_deq_default, + abstol_termination::T=1.0f-8, + reltol_termination::T=1.0f-8 ) where {T<:Number} return DiscreteDEQSolver{Val(mode),typeof(alg),T}(alg, abstol_termination, reltol_termination) end diff --git a/src/solvers/discrete/limited_memory_broyden.jl b/src/solvers/discrete/limited_memory_broyden.jl index e4e5a2be..d9904f4a 100644 --- a/src/solvers/discrete/limited_memory_broyden.jl +++ b/src/solvers/discrete/limited_memory_broyden.jl @@ -23,17 +23,12 @@ See also: [`BroydenSolver`](@ref) """ struct LimitedMemoryBroydenSolver end -function nlsolve(l::LimitedMemoryBroydenSolver, f::Function, y::AbstractMatrix; kwargs...) - res, stats = nlsolve(l, f, reshape(y, size(y, 1), 1, size(y, 2)); kwargs...) - return dropdims(res; dims=2), stats -end - -function nlsolve( - ::LimitedMemoryBroydenSolver, f::Function, y::AbstractArray{T,3}; terminate_condition, maxiters::Int=10 +@inbounds @views function nlsolve( + ::LimitedMemoryBroydenSolver, f::Function, y::AbstractMatrix{T}; terminate_condition, maxiters::Int=10 ) where {T} LBFGS_threshold = min(maxiters, 27) - total_hsize, n_elem, batch_size = size(y) + total_hsize, batch_size = size(y) # Initialize the cache x₀ = copy(y) @@ -41,8 +36,8 @@ function nlsolve( x₁ = copy(y) Δx = copy(x₀) Δfx = copy(x₀) - Us = fill!(similar(y, (LBFGS_threshold, total_hsize, n_elem, batch_size)), T(0)) - VTs = fill!(similar(y, (total_hsize, n_elem, LBFGS_threshold, batch_size)), T(0)) + Us = fill!(similar(y, (LBFGS_threshold, total_hsize, batch_size)), T(0)) + VTs = fill!(similar(y, (total_hsize, LBFGS_threshold, batch_size)), T(0)) # Counters nstep = 1 @@ -61,20 +56,20 @@ function nlsolve( terminate_condition(fx₁, x₁) && break # Compute the update - @views part_Us = Us[1:min(LBFGS_threshold, nstep), :, :, :] - @views part_VTs = VTs[:, :, 1:min(LBFGS_threshold, nstep), :] + part_Us = Us[1:min(LBFGS_threshold, nstep), :, :] + part_VTs = VTs[:, 1:min(LBFGS_threshold, nstep), :] vT = rmatvec(part_Us, part_VTs, Δx) # D x C x N mvec = matvec(part_Us, part_VTs, Δfx) vTΔfx = sum(vT .* Δfx; dims=(1, 2)) @. 
Δx = (Δx - mvec) / (vTΔfx + eps(T)) # D x C x N - @views VTs[:, :, mod1(nstep, LBFGS_threshold), :] .= vT - @views Us[mod1(nstep, LBFGS_threshold), :, :, :] .= Δx + VTs[:, mod1(nstep, LBFGS_threshold), :] .= vT + Us[mod1(nstep, LBFGS_threshold), :, :] .= Δx - @views update = + update = -matvec( - Us[1:min(LBFGS_threshold, nstep + 1), :, :, :], VTs[:, :, 1:min(LBFGS_threshold, nstep + 1), :], fx₁ + Us[1:min(LBFGS_threshold, nstep + 1), :, :], VTs[:, 1:min(LBFGS_threshold, nstep + 1), :], fx₁ ) copyto!(x₀, x₁) copyto!(fx₀, fx₁) @@ -86,16 +81,18 @@ function nlsolve( return x₁, (nf=nstep + 1,) end -@inbounds function matvec(part_Us::AbstractArray{E,4}, part_VTs::AbstractArray{E,4}, x::AbstractArray{E,3}) where {E} - # part_Us -> (T x D x C x N) | part_VTs -> (D x C x T x N) | x -> (D x C x N) - _, D, C, N = size(part_Us) - xTU = sum(reshape(x, (1, D, C, N)) .* part_Us; dims=(2, 3)) # T x 1 x 1 x N - return -x .+ dropdims(sum(permutedims(xTU, (2, 3, 1, 4)) .* part_VTs; dims=3); dims=3) +@inbounds @views function matvec( + part_Us::AbstractArray{E,3}, part_VTs::AbstractArray{E,3}, x::AbstractArray{E,2} +) where {E} + # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) + xTU = sum(Flux.unsqueeze(x; dims=1) .* part_Us; dims=2) # T x 1 x N + return -x .+ dropdims(sum(permutedims(xTU, (2, 1, 3)) .* part_VTs; dims=2); dims=2) end -function rmatvec(part_Us::AbstractArray{E,4}, part_VTs::AbstractArray{E,4}, x::AbstractArray{E,3}) where {E} - # part_Us -> (T x D x C x N) | part_VTs -> (D x C x T x N) | x -> (D x C x N) - _, D, C, N = size(part_Us) - VTx = sum(part_VTs .* reshape(x, (D, C, 1, N)); dims=(1, 2)) # 1 x 1 x T x N - return -x .+ dropdims(sum(part_Us .* permutedims(VTx, (3, 1, 2, 4)); dims=1); dims=1) +@inbounds @views function rmatvec( + part_Us::AbstractArray{E,3}, part_VTs::AbstractArray{E,3}, x::AbstractArray{E,2} +) where {E} + # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) + VTx = sum(part_VTs .* Flux.unsqueeze(x; dims=2); dims=1) # 1 x T x N + return -x .+ dropdims(sum(part_Us .* permutedims(VTx, (2, 1, 3)); dims=1); dims=1) end diff --git a/src/solvers/termination.jl b/src/solvers/termination.jl index d812f92c..31f21cc2 100644 --- a/src/solvers/termination.jl +++ b/src/solvers/termination.jl @@ -102,7 +102,7 @@ function get_terminate_condition(alg::DiscreteDEQSolver{M,A,T}, args...; kwargs. 
end # Convergence Criterions -function has_converged( +@inline function has_converged( du, u, alg::Union{ContinuousDEQSolver{M},DiscreteDEQSolver{M}}, @@ -112,18 +112,18 @@ function has_converged( return has_converged(du, u, M, abstol, reltol) end -function has_converged(du, u, M, abstol, reltol) +@inline @inbounds function has_converged(du, u, M, abstol, reltol) mode = get_mode(M) if mode == :norm - return norm(du) <= abstol && norm(du) <= reltol * norm(du .+ u) + return norm(du) <= abstol && norm(du) <= reltol * norm(du + u) elseif mode == :rel return all(abs.(du) .<= reltol .* abs.(u)) elseif mode == :rel_norm - return norm(du) <= reltol * norm(du .+ u) + return norm(du) <= reltol * norm(du + u) elseif mode == :rel_deq_default - return norm(du) <= reltol * norm(du .+ u) + return norm(du) <= reltol * norm(du + u) elseif mode == :rel_deq_best - return norm(du) <= reltol * norm(du .+ u) + return norm(du) <= reltol * norm(du + u) elseif mode == :abs return all(abs.(du) .<= abstol) elseif mode == :abs_norm diff --git a/test/runtests.jl b/test/runtests.jl index a6700c0c..c831b61c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,8 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux @testset "FastDEQ.jl" begin seed = 0 + rng = Random.default_rng() + Random.seed!(rng, seed) @info "Testing DEQ" model = DEQChain( @@ -11,18 +13,23 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), ), ) - ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) - x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) - y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + ps, st = gpu.(EFL.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + @time gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + @info "Testing DEQ without Fixed Point Iterations" st = EFL.update_state(st, :fixed_depth, 5) gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + @time gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + @info "Testing SkipDEQ" + Random.seed!(rng, seed) model = DEQChain( EFL.Dense(2, 2), SkipDeepEquilibriumNetwork( @@ -32,15 +39,20 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) - x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) - y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + ps, st = gpu.(EFL.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing SkipDEQ without Fixed Point Iterations" st = EFL.update_state(st, :fixed_depth, 5) @@ -49,7 +61,13 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing SkipDEQV2" + Random.seed!(rng, seed) model = DEQChain( EFL.Dense(2, 2), SkipDeepEquilibriumNetwork( @@ -59,15 +77,20 @@ using FastDEQ, CUDA, 
LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) - x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) - y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + ps, st = gpu.(EFL.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing SkipDEQV2 without Fixed Point Iterations" st = EFL.update_state(st, :fixed_depth, 5) @@ -76,7 +99,13 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing SkipDEQ with Broyden Solver" + Random.seed!(rng, seed) model = DEQChain( EFL.Dense(2, 2), SkipDeepEquilibriumNetwork( @@ -86,16 +115,22 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) - x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) - y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + ps, st = gpu.(EFL.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing SkipDEQ with L-Broyden Solver" + Random.seed!(rng, seed) model = DEQChain( EFL.Dense(2, 2), SkipDeepEquilibriumNetwork( @@ -105,16 +140,22 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) - x = gpu(rand(MersenneTwister(seed + 1), Float32, 2, 1)) - y = gpu(rand(MersenneTwister(seed + 2), Float32, 2, 1)) + ps, st = gpu.(EFL.setup(rng, model)) + x = gpu(rand(rng, Float32, 2, 1)) + y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing MultiScaleDEQ" + Random.seed!(rng, seed) model = MultiScaleDeepEquilibriumNetwork( ( EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), @@ -134,15 +175,20 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) - x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) - y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) + ps, st = gpu.(EFL.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) 
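    # Note: `y` above is one target per scale -- four arrays with 4, 3, 2 and 1
    # features (batch size 2), matching the model's four branches; the closure
    # below sums the squared error over all scales.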
gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + end + @info "Testing MultiScaleDEQ without Fixed Point Iterations" st = EFL.update_state(st, :fixed_depth, 5) @@ -151,7 +197,13 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sum(Base.Fix1(sum, abs2), ŷ .- y) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + end + @info "Testing MultiScaleSkipDEQ" + Random.seed!(rng, seed) model = MultiScaleSkipDeepEquilibriumNetwork( ( EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), @@ -172,15 +224,20 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) - x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) - y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) + ps, st = gpu.(EFL.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" st = EFL.update_state(st, :fixed_depth, 5) @@ -189,7 +246,13 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing MultiScaleSkipDEQV2" + Random.seed!(rng, seed) model = MultiScaleSkipDeepEquilibriumNetwork( ( EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), @@ -210,15 +273,20 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = NamedTuple.(gpu.(EFL.setup(MersenneTwister(seed), model))) - x = gpu(rand(MersenneTwister(seed + 1), Float32, 4, 2)) - y = tuple([gpu(rand(MersenneTwister(seed + 1 + i), Float32, i, 2)) for i in 4:-1:1]...) + ps, st = gpu.(EFL.setup(rng, model)) + x = gpu(rand(rng, Float32, 4, 2)) + y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) 
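    # Note on the loss below: the extra `sum(abs2, soln.u₀ .- soln.z_star)` term
    # is the skip-DEQ consistency penalty, nudging the shortcut's initial guess
    # `u₀` towards the converged fixed point `z_star`.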
gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end + @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" st = EFL.update_state(st, :fixed_depth, 5) @@ -226,4 +294,9 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) end + + @time gradient(ps) do p + (ŷ, soln), _ = model(x, p, st) + sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) + end end From 495308789458feb8ffe4cc9a346390b5a346bacf Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 23 Apr 2022 17:11:54 -0400 Subject: [PATCH 44/76] Update --- examples/cifar10/script.jl | 2 +- examples/src/FastDEQExperiments.jl | 3 --- examples/src/config.jl | 14 +++++++------- examples/src/train.jl | 21 +++++++++++++-------- src/adjoint.jl | 1 - 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index ab9a36bc..d8e0eda1 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -57,7 +57,7 @@ end # Experiment Configurations configs = [] -for seed in [6171, 3859, 2961], model_type in [:VANILLA, :SKIP, :SKIPV2], model_size in [:TINY, :LARGE] +for seed in [6171, 3859, 2961], model_type in [:VANILLA, :SKIP, :SKIPV2], model_size in [:TINY] #, :LARGE] push!( configs, Dict( diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index bdc5c817..af303d6b 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -20,9 +20,6 @@ import MLDataUtils: nobs, getobs const EFL = ExplicitFluxLayers -# FIXME: Remove once FastDEQ has been updated to use latest EFL -Base.keys(::Nothing) = () - # Memory Management relieve_gc_pressure(::Union{Nothing,<:AbstractArray}) = nothing relieve_gc_pressure(x::CuArray) = CUDA.unsafe_free!(x) diff --git a/examples/src/config.jl b/examples/src/config.jl index 9725bb1b..aaba30bf 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -32,7 +32,7 @@ Base.@kwdef struct ImageClassificationModelConfiguration{N} <: AbstractTaskModel abstol::Float32 = 5f-2 reltol::Float32 = 5f-2 stop_mode::Symbol = :rel_norm - ode_solver = VCABM3() + ode_solver = VCAB3() end function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) @@ -58,7 +58,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=200, fwd_maxiters=18, - bwd_maxiters=20, + bwd_maxiters=10, kwargs... ) elseif model_size == :LARGE @@ -82,7 +82,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=1680, fwd_maxiters=18, - bwd_maxiters=20, + bwd_maxiters=10, kwargs... ) else @@ -110,7 +110,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=2048, fwd_maxiters=27, - bwd_maxiters=28, + bwd_maxiters=15, kwargs... ) elseif model_size == :LARGE @@ -134,7 +134,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=2048, fwd_maxiters=27, - bwd_maxiters=28, + bwd_maxiters=15, kwargs... ) elseif model_size == :XL @@ -158,7 +158,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) 
fuse_method=:sum, final_channelsize=2048, fwd_maxiters=27, - bwd_maxiters=28, + bwd_maxiters=15, kwargs... ) else @@ -213,7 +213,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=128, train_batchsize=128, nepochs=50, - pretrain_steps=3000 ÷ (is_distributed() ? total_workers() : 1), + pretrain_steps=0 ÷ (is_distributed() ? total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), diff --git a/examples/src/train.jl b/examples/src/train.jl index c16901ee..0602ae15 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -88,16 +88,21 @@ end loss_function(e::ExperimentConfiguration, args...; kwargs...) = loss_function(e.model_config, args...; kwargs...) function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-3) - function loss_function_closure(x, y, model, ps, st) - (ŷ, soln), st_ = model(x, ps, st) - loss = if c.model_type == :VANILLA - Flux.Losses.logitcrossentropy(ŷ, y) - else - Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mse(soln.u₀, soln.z_star) + if c.model_type == :VANILLA + function loss_function_closure_1(x, y, model, ps, st) + (ŷ, soln), st_ = model(x, ps, st) + loss = Flux.Losses.logitcrossentropy(ŷ, y) + return loss, ŷ, st_, soln.nfe + end + return loss_function_closure_1 + else + function loss_function_closure_2(x, y, model, ps, st) + (ŷ, soln), st_ = model(x, ps, st) + loss = Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mse(soln.u₀, soln.z_star) + return loss, ŷ, st_, soln.nfe end - return loss, ŷ, st_, soln.nfe + return loss_function_closure_2 end - return loss_function_closure end function train_one_epoch( diff --git a/src/adjoint.jl b/src/adjoint.jl index b54f7d85..5968ff17 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -24,7 +24,6 @@ neg(nt::NamedTuple) = fmap(neg, nt) s_val = size(_val) op = ZygotePullbackMultiplyOperator{eltype(y),typeof(back),typeof(s_val)}(back, s_val) linear_problem = LinearProblem(op, vec(diffcache.dg_val)) - ## Automatically choose the best algorithm λ = solve(linear_problem, sensealg.linsolve).u # Compute the VJP From 0aef6720b9aa14ade9038493823b0bb6a48b5dae Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 24 Apr 2022 09:31:45 -0400 Subject: [PATCH 45/76] Use pretraining --- examples/Manifest.toml | 4 +--- examples/src/config.jl | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 355bde6a..5648adcd 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -402,9 +402,7 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "ComponentArrays", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Optimisers", "Random", "Setfield", "SparseArrays", "Statistics", "Yota", "Zygote"] -git-tree-sha1 = "ea89d52a2be10118b7b5373ef11c277a0c206a16" -repo-rev = "ap/sparse" -repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" +path = "/mnt/research/softwares/ExplicitFluxLayers/" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" version = "0.2.0" diff --git a/examples/src/config.jl b/examples/src/config.jl index aaba30bf..257e90c3 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -213,7 +213,7 @@ function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) eval_batchsize=128, train_batchsize=128, nepochs=50, - pretrain_steps=0 ÷ (is_distributed() ? total_workers() : 1), + pretrain_steps=3000 ÷ (is_distributed() ? 
total_workers() : 1), lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), From 3a12e0a390fb68fd651996573cab35f322ef2e10 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 24 Apr 2022 15:03:26 -0400 Subject: [PATCH 46/76] Fix backpass for variational autoencoder --- examples/src/config.jl | 10 +++++----- src/layers/deq.jl | 10 ++++------ src/layers/mdeq.jl | 10 ++++------ 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index 257e90c3..e894ba04 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -58,7 +58,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=200, fwd_maxiters=18, - bwd_maxiters=10, + bwd_maxiters=20, kwargs... ) elseif model_size == :LARGE @@ -82,7 +82,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=1680, fwd_maxiters=18, - bwd_maxiters=10, + bwd_maxiters=20, kwargs... ) else @@ -110,7 +110,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=2048, fwd_maxiters=27, - bwd_maxiters=15, + bwd_maxiters=28, kwargs... ) elseif model_size == :LARGE @@ -134,7 +134,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=2048, fwd_maxiters=27, - bwd_maxiters=15, + bwd_maxiters=28, kwargs... ) elseif model_size == :XL @@ -158,7 +158,7 @@ function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) fuse_method=:sum, final_channelsize=2048, fwd_maxiters=27, - bwd_maxiters=15, + bwd_maxiters=28, kwargs... ) else diff --git a/src/layers/deq.jl b/src/layers/deq.jl index 8c58c840..f1688e88 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -25,7 +25,7 @@ function (deq::DeepEquilibriumNetwork{J})( for _ in 1:(st.fixed_depth) z_star, st_ = deq.model((z_star, x), ps, st_) end - @set! st.model = st_ + @set! st.model = EFL.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st end @@ -44,8 +44,7 @@ function (deq::DeepEquilibriumNetwork{J})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps, st.model)[1] - st_ = EFL.update_state(st_, :update_mask, true) - @set! st.model = st_ + @set! st.model = EFL.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end @@ -90,7 +89,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( for _ in 1:(st.fixed_depth) z_star, st_ = deq.model((z_star, x), ps.model, st_) end - @set! st.model = st_ + @set! st.model = EFL.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st end @@ -109,8 +108,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps.model, st.model)[1] - st_ = EFL.update_state(st_, :update_mask, true) - @set! st.model = st_ + @set! 
st.model = EFL.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index fe35d71d..04614419 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -56,7 +56,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps, st_) end - @set! st.model = st_ + @set! st.model = EFL.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st end @@ -77,8 +77,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( residual = dudt(sol.u, ps, nothing) - st_ = EFL.update_state(st_, :update_mask, true) - @set! st.model = st_ + @set! st.model = EFL.update_state(st_, :update_mask, true) return ( (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st @@ -154,7 +153,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) end - @set! st.model = st_ + @set! st.model = EFL.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st end @@ -175,8 +174,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( residual = dudt(sol.u, ps.model, nothing) - st_ = EFL.update_state(st_, :update_mask, true) - @set! st.model = st_ + @set! st.model = EFL.update_state(st_, :update_mask, true) return ( (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st From 9c2626d6080bbda3c06b81d2463744f5e395003f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 24 Apr 2022 15:04:13 -0400 Subject: [PATCH 47/76] Fix dep --- examples/Manifest.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 5648adcd..bdc1b7ff 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -402,7 +402,9 @@ version = "0.0.29+0" [[deps.ExplicitFluxLayers]] deps = ["CUDA", "ChainRulesCore", "ComponentArrays", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Optimisers", "Random", "Setfield", "SparseArrays", "Statistics", "Yota", "Zygote"] -path = "/mnt/research/softwares/ExplicitFluxLayers/" +git-tree-sha1 = "a1182b44d76f7df48f9e4201ba3f8c9d918a15d6" +repo-rev = "ap/sparse" +repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" version = "0.2.0" From 0b6a927b344bd5ef6978fbfacf035b9a63eb7403 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 May 2022 15:30:31 -0400 Subject: [PATCH 48/76] Partial update to Lux --- .github/workflows/CI.yml | 3 +- .gitignore | 7 +- Project.toml | 13 +- src/FastDEQ.jl | 30 ++- src/layers/chain.jl | 4 +- src/layers/deq.jl | 8 +- src/layers/jacobian_stabilization.jl | 2 +- src/layers/mdeq.jl | 50 ++--- .../discrete/limited_memory_broyden.jl | 4 +- test/runtests.jl | 210 ++++++++---------- 10 files changed, 153 insertions(+), 178 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 286a86cd..4ae2716e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -29,7 +29,8 @@ jobs: ${{ runner.os }}-test- ${{ runner.os }}- - name: Install dependencies - run: julia --project=. 
-e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/ExplicitFluxLayers.jl", rev="ap/sparse"); Pkg.instantiate()' + # FIXME: Remove once Lux.jl is registered + run: julia --project=. -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.instantiate()' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 with: diff --git a/.gitignore b/.gitignore index 3ebae6b8..2f54b33e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,9 @@ wandb/ .vscode data/ /Manifest.toml -build \ No newline at end of file +docs/Manifest.toml +build +statprof +profs +logs +benchmarking \ No newline at end of file diff --git a/Project.toml b/Project.toml index 566b92d6..9f38cdf6 100644 --- a/Project.toml +++ b/Project.toml @@ -11,11 +11,11 @@ DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" -ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -33,10 +33,10 @@ ChainRulesCore = "1" DiffEqBase = "6" DiffEqCallbacks = "2.20.1" DiffEqSensitivity = "6.64" -ExplicitFluxLayers = "0.2" -Flux = "0.13" Functors = "0.2" LinearSolve = "1" +Lux = "0.3" +MLUtils = "0.2" OrdinaryDiffEq = "6" SciMLBase = "1.19" Setfield = "0.8.2" @@ -47,11 +47,10 @@ julia = "1.7" [extras] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["CUDA", "Flux", "ExplicitFluxLayers", "LinearAlgebra", "Random", "Test"] +test = ["CUDA", "LinearAlgebra", "Lux", "Random", "Test"] diff --git a/src/FastDEQ.jl b/src/FastDEQ.jl index 831cbf3e..f6849bd2 100644 --- a/src/FastDEQ.jl +++ b/src/FastDEQ.jl @@ -1,33 +1,31 @@ module FastDEQ -using CUDA, +using ChainRulesCore, + ComponentArrays, + CUDA, DiffEqBase, DiffEqCallbacks, DiffEqSensitivity, - Flux, + Functors, LinearAlgebra, LinearSolve, + Lux, + MLUtils, OrdinaryDiffEq, SciMLBase, + Setfield, Statistics, SteadyStateDiffEq, UnPack, - Zygote, - ExplicitFluxLayers, - Functors, - ChainRulesCore, - ComponentArrays, - Setfield - -import ExplicitFluxLayers: - AbstractExplicitLayer, - AbstractExplicitContainerLayer, - initialparameters, - initialstates, - parameterlength, - statelength + Zygote + +import Lux: AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength import Random: AbstractRNG +# This shouldn't be put in Lux since it is not true in the general case +# However for our usecase gradients dont propagate through the state +ChainRulesCore.@non_differentiable Lux.update_state(::Any...) 
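# A minimal sketch of the effect (hypothetical `loss`, not part of this module):
# in
#     gradient(p -> loss(model(x, p, Lux.update_state(st, :fixed_depth, 5))), ps)
# Zygote treats `update_state` as a constant, so only the parameter path `p`
# contributes to the pullback and the state update produces no tangent.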
+ include("operator.jl") include("solvers/continuous.jl") diff --git a/src/layers/chain.jl b/src/layers/chain.jl index 36bede0d..9706e972 100644 --- a/src/layers/chain.jl +++ b/src/layers/chain.jl @@ -16,8 +16,8 @@ function DEQChain(layers...) push!(encounter_deq ? post_deq : pre_deq, l) end @assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain" - pre_deq = length(pre_deq) == 0 ? nothing : ExplicitFluxLayers.Chain(pre_deq...) - post_deq = length(post_deq) == 0 ? nothing : ExplicitFluxLayers.Chain(post_deq...) + pre_deq = length(pre_deq) == 0 ? nothing : Chain(pre_deq...) + post_deq = length(post_deq) == 0 ? nothing : Chain(post_deq...) return DEQChain(pre_deq, deq, post_deq) end diff --git a/src/layers/deq.jl b/src/layers/deq.jl index f1688e88..e4a83081 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -25,7 +25,7 @@ function (deq::DeepEquilibriumNetwork{J})( for _ in 1:(st.fixed_depth) z_star, st_ = deq.model((z_star, x), ps, st_) end - @set! st.model = EFL.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st end @@ -44,7 +44,7 @@ function (deq::DeepEquilibriumNetwork{J})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps, st.model)[1] - @set! st.model = EFL.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end @@ -89,7 +89,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( for _ in 1:(st.fixed_depth) z_star, st_ = deq.model((z_star, x), ps.model, st_) end - @set! st.model = EFL.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st end @@ -108,7 +108,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps.model, st.model)[1] - @set! st.model = EFL.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, true) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/jacobian_stabilization.jl b/src/layers/jacobian_stabilization.jl index d87a64c6..dc2439c0 100644 --- a/src/layers/jacobian_stabilization.jl +++ b/src/layers/jacobian_stabilization.jl @@ -1,6 +1,6 @@ # Doesn't work as of now function compute_deq_jacobian_loss( - model::AbstractExplicitLayer, ps::ComponentArray, st::NamedTuple, z::AbstractArray, x::AbstractArray + model, ps::ComponentArray, st::NamedTuple, z::AbstractArray, x::AbstractArray ) l, back = Zygote.pullback(u -> model((u, x), ps, st)[1], z) vjp_z = back(gaussian_like(l))[1] diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 04614419..1a5455e3 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -19,15 +19,13 @@ function MultiScaleDeepEquilibriumNetwork( sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs..., ) - l1 = ExplicitFluxLayers.Parallel(nothing, main_layers...) - l2 = ExplicitFluxLayers.BranchLayer( - ExplicitFluxLayers.Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)... - ) + l1 = Parallel(nothing, main_layers...) 
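    # Layout sketch (assuming Lux semantics): with `nothing` as the connection,
    # `l1` applies each scale's main layer to its own input and returns a tuple;
    # below, each row of `mapping_layers` becomes a `Parallel(+, ...)` that maps
    # every scale to one resolution and sums, and `BranchLayer` runs all rows on
    # the same tuple, yielding one fused output per scale.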
+ l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) model = if post_fuse_layer === nothing - ExplicitFluxLayers.Chain(l1, l2) + Chain(l1, l2) else - l3 = ExplicitFluxLayers.Parallel(nothing, post_fuse_layer...) - ExplicitFluxLayers.Chain(l1, l2, l3) + l3 = Parallel(nothing, post_fuse_layer...) + Chain(l1, l2, l3) end return MultiScaleDeepEquilibriumNetwork(model, solver, sensealg, scales, kwargs) end @@ -56,9 +54,9 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps, st_) end - @set! st.model = EFL.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, true) - return (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st + return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st end st_ = st.model @@ -69,7 +67,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( return u_, st_ end - dudt(u, p, t) = vcat(Flux.flatten.(dudt_(u, p, t)[1])...) .- u + dudt(u, p, t) = vcat(flatten.(dudt_(u, p, t)[1])...) .- u prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps) sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) @@ -77,10 +75,10 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( residual = dudt(sol.u, ps, nothing) - @set! st.model = EFL.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, true) return ( - (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st + (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st ) end @@ -112,20 +110,18 @@ function MultiScaleSkipDeepEquilibriumNetwork( sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs..., ) - l1 = ExplicitFluxLayers.Parallel(nothing, main_layers...) - l2 = ExplicitFluxLayers.BranchLayer( - ExplicitFluxLayers.Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)... - ) + l1 = Parallel(nothing, main_layers...) + l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) model = if post_fuse_layer === nothing - ExplicitFluxLayers.Chain(l1, l2) + Chain(l1, l2) else - l3 = ExplicitFluxLayers.Parallel(nothing, post_fuse_layer...) - ExplicitFluxLayers.Chain(l1, l2, l3) + l3 = Parallel(nothing, post_fuse_layer...) + Chain(l1, l2, l3) end shortcut = if shortcut_layers === nothing nothing else - ExplicitFluxLayers.Parallel(nothing, shortcut_layers...) + Parallel(nothing, shortcut_layers...) end return MultiScaleSkipDeepEquilibriumNetwork(model, shortcut, solver, sensealg, scales, kwargs) end @@ -138,11 +134,11 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( u0_ = split_and_reshape(u0, st.split_idxs, deq.scales) z0, st__ = deq.model(((u0_[1], x), u0_[2:N]...), ps.model, st_.model) @set! st_.model = st__ - (vcat(Flux.flatten.(z0)...), st_) + (vcat(flatten.(z0)...), st_) else z0, st_ = deq.shortcut(x, ps.shortcut, st.shortcut) @set! st.shortcut = st_ - (vcat(Flux.flatten.(z0)...), st) + (vcat(flatten.(z0)...), st) end if !iszero(st.fixed_depth) @@ -153,9 +149,9 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) end - @set! st.model = EFL.update_state(st_, :update_mask, true) + @set! 
st.model = Lux.update_state(st_, :update_mask, true) - return (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st + return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st end st_ = st.model @@ -166,7 +162,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( return u_, st_ end - dudt(u, p, t) = vcat(Flux.flatten.(dudt_(u, p, t)[1])...) .- u + dudt(u, p, t) = vcat(flatten.(dudt_(u, p, t)[1])...) .- u prob = SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model) sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) @@ -174,9 +170,9 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( residual = dudt(sol.u, ps.model, nothing) - @set! st.model = EFL.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, true) return ( - (z_star, DeepEquilibriumSolution(vcat(Flux.flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st + (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st ) end diff --git a/src/solvers/discrete/limited_memory_broyden.jl b/src/solvers/discrete/limited_memory_broyden.jl index d9904f4a..63c896cc 100644 --- a/src/solvers/discrete/limited_memory_broyden.jl +++ b/src/solvers/discrete/limited_memory_broyden.jl @@ -85,7 +85,7 @@ end part_Us::AbstractArray{E,3}, part_VTs::AbstractArray{E,3}, x::AbstractArray{E,2} ) where {E} # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) - xTU = sum(Flux.unsqueeze(x; dims=1) .* part_Us; dims=2) # T x 1 x N + xTU = sum(unsqueeze(x; dims=1) .* part_Us; dims=2) # T x 1 x N return -x .+ dropdims(sum(permutedims(xTU, (2, 1, 3)) .* part_VTs; dims=2); dims=2) end @@ -93,6 +93,6 @@ end part_Us::AbstractArray{E,3}, part_VTs::AbstractArray{E,3}, x::AbstractArray{E,2} ) where {E} # part_Us -> (T x D x N) | part_VTs -> (D x T x N) | x -> (D x N) - VTx = sum(part_VTs .* Flux.unsqueeze(x; dims=2); dims=1) # 1 x T x N + VTx = sum(part_VTs .* unsqueeze(x; dims=2); dims=1) # 1 x T x N return -x .+ dropdims(sum(part_Us .* permutedims(VTx, (2, 1, 3)); dims=1); dims=1) end diff --git a/test/runtests.jl b/test/runtests.jl index c831b61c..8bfb7b8f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,16 @@ -using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux +using CUDA, FastDEQ, Functors, LinearAlgebra, Lux, Random, Test, Zygote + +function test_gradient_isfinite(gs::NamedTuple) + gradient_is_finite = [true] + function is_gradient_finite(x) + if !isnothing(x) && !all(isfinite, x) + gradient_is_finite[1] = false + end + return x + end + fmap(is_gradient_finite, gs) + return gradient_is_finite[1] +end @testset "FastDEQ.jl" begin seed = 0 @@ -7,167 +19,149 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux @info "Testing DEQ" model = DEQChain( - EFL.Dense(2, 2), + Dense(2, 2), DeepEquilibriumNetwork( - EFL.Parallel(+, EFL.Dense(2, 2; bias=false), EFL.Dense(2, 2; bias=false)), + Parallel(+, Dense(2, 2; bias=false), Dense(2, 2; bias=false)), ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), ), ) - ps, st = gpu.(EFL.setup(rng, model)) + ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) - gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] - @time gradient(p -> sum(abs2, model(x, p, 
st)[1][1] .- y), ps) + @test test_gradient_isfinite(gs) @info "Testing DEQ without Fixed Point Iterations" - st = EFL.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, 5) - gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] - @time gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps) + @test test_gradient_isfinite(gs) @info "Testing SkipDEQ" Random.seed!(rng, seed) model = DEQChain( - EFL.Dense(2, 2), + Dense(2, 2), SkipDeepEquilibriumNetwork( - EFL.Parallel(+, EFL.Dense(2, 2), EFL.Dense(2, 2)), - EFL.Dense(2, 2), + Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = gpu.(EFL.setup(rng, model)) + ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing SkipDEQ without Fixed Point Iterations" - st = EFL.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, 5) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing SkipDEQV2" Random.seed!(rng, seed) model = DEQChain( - EFL.Dense(2, 2), + Dense(2, 2), SkipDeepEquilibriumNetwork( - EFL.Parallel(+, EFL.Dense(2, 2), EFL.Dense(2, 2)), + Parallel(+, Dense(2, 2), Dense(2, 2)), nothing, ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = gpu.(EFL.setup(rng, model)) + ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing SkipDEQV2 without Fixed Point Iterations" - st = EFL.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, 5) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing SkipDEQ with Broyden Solver" Random.seed!(rng, seed) model = DEQChain( - EFL.Dense(2, 2), + Dense(2, 2), SkipDeepEquilibriumNetwork( - EFL.Parallel(+, EFL.Dense(2, 2), EFL.Dense(2, 2)), - EFL.Dense(2, 2), + Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), DiscreteDEQSolver(BroydenSolver(); abstol_termination=0.1f0, reltol_termination=0.1f0); sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = gpu.(EFL.setup(rng, model)) + ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ 
.- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing SkipDEQ with L-Broyden Solver" Random.seed!(rng, seed) model = DEQChain( - EFL.Dense(2, 2), + Dense(2, 2), SkipDeepEquilibriumNetwork( - EFL.Parallel(+, EFL.Dense(2, 2), EFL.Dense(2, 2)), - EFL.Dense(2, 2), + Parallel(+, Dense(2, 2), Dense(2, 2)), + Dense(2, 2), DiscreteDEQSolver(LimitedMemoryBroydenSolver(); abstol_termination=0.1f0, reltol_termination=0.1f0); sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ), ) - ps, st = gpu.(EFL.setup(rng, model)) + ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing MultiScaleDEQ" Random.seed!(rng, seed) model = MultiScaleDeepEquilibriumNetwork( ( - EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), - EFL.Dense(3, 3, tanh), - EFL.Dense(2, 2, tanh), - EFL.Dense(1, 1, tanh), + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), ), [ - EFL.NoOpLayer() EFL.Dense(4, 3, tanh) EFL.Dense(4, 2, tanh) EFL.Dense(4, 1, tanh) - EFL.Dense(3, 4, tanh) EFL.NoOpLayer() EFL.Dense(3, 2, tanh) EFL.Dense(3, 1, tanh) - EFL.Dense(2, 4, tanh) EFL.Dense(2, 3, tanh) EFL.NoOpLayer() EFL.Dense(2, 1, tanh) - EFL.Dense(1, 4, tanh) EFL.Dense(1, 3, tanh) EFL.Dense(1, 2, tanh) EFL.NoOpLayer() + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() ], nothing, ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), @@ -175,96 +169,84 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = gpu.(EFL.setup(rng, model)) + ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) 
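# ---- Editorial aside (not part of the patch; plain-Julia sketch) ----
# How the 4x4 `mapping_layers` matrix above is consumed: in the constructor from
# src/layers/mdeq.jl, entry (i, j) maps the state at scale i to the resolution of
# scale j, and column j is fused with `+` to form the new state at scale j, via
# BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(M))...)...). A plain-array
# analogue with hypothetical names (W, zs, fused), not the library's code path:
using Random
rng_ = MersenneTwister(0)
scales_ = (4, 3, 2, 1)
W = [randn(rng_, Float32, scales_[j], scales_[i]) for i in 1:4, j in 1:4]  # W[i, j]: scale i -> scale j
zs = Tuple(randn(rng_, Float32, scales_[i]) for i in 1:4)                  # one state array per scale
fused = Tuple(sum(W[i, j] * zs[i] for i in 1:4) for j in 1:4)              # column-wise `+` fusion
@assert all(length(fused[j]) == scales_[j] for j in 1:4)
# The NoOpLayer() diagonal in the test matrix is each scale's identity map onto itself.
# ----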
gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) - end + @test test_gradient_isfinite(gs) @info "Testing MultiScaleDEQ without Fixed Point Iterations" - st = EFL.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, 5) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) - end + @test test_gradient_isfinite(gs) @info "Testing MultiScaleSkipDEQ" Random.seed!(rng, seed) model = MultiScaleSkipDeepEquilibriumNetwork( ( - EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), - EFL.Dense(3, 3, tanh), - EFL.Dense(2, 2, tanh), - EFL.Dense(1, 1, tanh), + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), ), [ - EFL.NoOpLayer() EFL.Dense(4, 3, tanh) EFL.Dense(4, 2, tanh) EFL.Dense(4, 1, tanh) - EFL.Dense(3, 4, tanh) EFL.NoOpLayer() EFL.Dense(3, 2, tanh) EFL.Dense(3, 1, tanh) - EFL.Dense(2, 4, tanh) EFL.Dense(2, 3, tanh) EFL.NoOpLayer() EFL.Dense(2, 1, tanh) - EFL.Dense(1, 4, tanh) EFL.Dense(1, 3, tanh) EFL.Dense(1, 2, tanh) EFL.NoOpLayer() + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() ], nothing, - (EFL.Dense(4, 4, tanh), EFL.Dense(4, 3, tanh), EFL.Dense(4, 2, tanh), EFL.Dense(4, 1, tanh)), + (Dense(4, 4, tanh), Dense(4, 3, tanh), Dense(4, 2, tanh), Dense(4, 1, tanh)), ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), ((4,), (3,), (2,), (1,)); sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = gpu.(EFL.setup(rng, model)) + ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) 
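# ---- Editorial aside (not part of the patch; plain-Julia sketch) ----
# The steady-state solvers operate on one flat matrix (features x batch), so the
# tuple of per-scale states is packed with `vcat(flatten.(z)...)` and unpacked by
# `split_and_reshape(u0, st.split_idxs, deq.scales)` in src/layers/mdeq.jl. A
# hand-rolled round trip of that packing, where `offsets_` is a stand-in for the
# precomputed `st.split_idxs`:
scales_ = ((4,), (3,), (2,), (1,))
batch_ = 2
zs = Tuple(rand(Float32, s..., batch_) for s in scales_)
flat = vcat(map(z -> reshape(z, :, batch_), zs)...)            # pack: features x batch
offsets_ = cumsum([0; prod.(collect(scales_))])
zs_back = Tuple(reshape(flat[(offsets_[i] + 1):offsets_[i + 1], :], scales_[i]..., batch_)
                for i in 1:length(scales_))
@assert all(zs_back[i] == zs[i] for i in 1:length(scales_))    # unpacking restores every scale
# ----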
gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" - st = EFL.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, 5) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing MultiScaleSkipDEQV2" Random.seed!(rng, seed) model = MultiScaleSkipDeepEquilibriumNetwork( ( - EFL.Parallel(+, EFL.Dense(4, 4, tanh), EFL.Dense(4, 4, tanh)), - EFL.Dense(3, 3, tanh), - EFL.Dense(2, 2, tanh), - EFL.Dense(1, 1, tanh), + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), ), [ - EFL.NoOpLayer() EFL.Dense(4, 3, tanh) EFL.Dense(4, 2, tanh) EFL.Dense(4, 1, tanh) - EFL.Dense(3, 4, tanh) EFL.NoOpLayer() EFL.Dense(3, 2, tanh) EFL.Dense(3, 1, tanh) - EFL.Dense(2, 4, tanh) EFL.Dense(2, 3, tanh) EFL.NoOpLayer() EFL.Dense(2, 1, tanh) - EFL.Dense(1, 4, tanh) EFL.Dense(1, 3, tanh) EFL.Dense(1, 2, tanh) EFL.NoOpLayer() + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() ], nothing, nothing, @@ -273,30 +255,24 @@ using FastDEQ, CUDA, LinearAlgebra, Random, Test, ExplicitFluxLayers, Flux sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), ) - ps, st = gpu.(EFL.setup(rng, model)) + ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) 
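# ---- Editorial aside (not part of the patch; runnable sketch) ----
# The core trick behind every ContinuousDEQSolver test here: the fixed point
# z* = f(z*, x) is found by integrating du/dt = f(u, x) - u to steady state,
# where the residual `dudt` vanishes. A standalone linear version, with
# DynamicSS(Tsit5()) standing in for the patch's ContinuousDEQSolver and f
# chosen as a contraction so the steady state exists:
using OrdinaryDiffEq, SteadyStateDiffEq
A_ = Float32[0.5 0.1; -0.2 0.3]                  # spectral radius < 1
x_ = Float32[1.0, -1.0]
f_(u, p, t) = A_ * u .+ x_ .- u                  # du/dt = f(u) - u
prob_ = SteadyStateProblem(ODEFunction{false}(f_), zeros(Float32, 2))
sol_ = solve(prob_, DynamicSS(Tsit5(); abstol=1.0f-6, reltol=1.0f-6))
@assert isapprox(A_ * sol_.u .+ x_, sol_.u; atol=1.0f-3)       # z* ≈ f(z*)
# ----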
gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" - st = EFL.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, 5) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + end[1] - @time gradient(ps) do p - (ŷ, soln), _ = model(x, p, st) - sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) - end + @test test_gradient_isfinite(gs) end From 336a19aea5e2ce2ca5eb2787952fad9af15ae8a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 May 2022 17:07:47 -0400 Subject: [PATCH 49/76] Update FastDEQExperiments --- .gitignore | 3 +- examples/Manifest.toml | 1683 ---------------------------- examples/Project.toml | 6 +- examples/src/FastDEQExperiments.jl | 15 +- examples/src/dataloaders.jl | 10 +- examples/src/models.jl | 201 ++-- examples/src/train.jl | 16 +- 7 files changed, 122 insertions(+), 1812 deletions(-) delete mode 100644 examples/Manifest.toml diff --git a/.gitignore b/.gitignore index 2f54b33e..b346baea 100644 --- a/.gitignore +++ b/.gitignore @@ -2,8 +2,7 @@ wandb/ .vscode data/ -/Manifest.toml -docs/Manifest.toml +Manifest.toml build statprof profs diff --git a/examples/Manifest.toml b/examples/Manifest.toml deleted file mode 100644 index bdc1b7ff..00000000 --- a/examples/Manifest.toml +++ /dev/null @@ -1,1683 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.7.2" -manifest_format = "2.0" -project_hash = "55e1c12df8760eecee422230ecbed81b99e0268c" - -[[deps.AbstractFFTs]] -deps = ["ChainRulesCore", "LinearAlgebra"] -git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.1.0" - -[[deps.Accessors]] -deps = ["Compat", "CompositionsBase", "ConstructionBase", "Future", "LinearAlgebra", "MacroTools", "Requires", "Test"] -git-tree-sha1 = "2bba2aa45df94e95b1a9c2405d7cfc3d60281db8" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.9" - -[[deps.Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "af92965fb30777147966f58acb05da51c5616b5f" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.3.3" - -[[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.2.0" - -[[deps.ArrayInterface]] -deps = ["Compat", "IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"] -git-tree-sha1 = "c933ce606f6535a7c7b98e1d86d5d1014f730596" -uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "5.0.7" - -[[deps.ArrayLayouts]] -deps = ["FillArrays", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "8b921542ad44cba67f1487e2226446597e0a90af" -uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "0.8.5" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.AxisAlgorithms]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] -git-tree-sha1 = 
"66771c8d21c8ff5e3a93379480a2307ac36863f7" -uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950" -version = "1.0.1" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.2.0" - -[[deps.BandedMatrices]] -deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "960ad9a4b34380595500f60add129e178740c3a6" -uuid = "aae01518-5342-5314-be14-df237901396f" -version = "0.17.0" - -[[deps.BangBang]] -deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] -git-tree-sha1 = "b15a6bc52594f5e4a3b825858d1089618871bf9d" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.36" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.BinDeps]] -deps = ["Libdl", "Pkg", "SHA", "URIParser", "Unicode"] -git-tree-sha1 = "1289b57e8cf019aede076edab0587eb9644175bd" -uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" -version = "1.0.2" - -[[deps.BinaryProvider]] -deps = ["Libdl", "Logging", "SHA"] -git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058" -uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.10" - -[[deps.BitTwiddlingConvenienceFunctions]] -deps = ["Static"] -git-tree-sha1 = "28bbdbf0354959db89358d1d79d421ff31ef0b5e" -uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" -version = "0.1.3" - -[[deps.BlockArrays]] -deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra"] -git-tree-sha1 = "28c497806c05326e7cadac0c916980d5a9c0e905" -uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -version = "0.16.14" - -[[deps.BlockBandedMatrices]] -deps = ["ArrayLayouts", "BandedMatrices", "BlockArrays", "FillArrays", "LinearAlgebra", "MatrixFactorizations", "SparseArrays", "Statistics"] -git-tree-sha1 = "646a8081a8f7a728b2c01a1d00a9fa07b678900a" -uuid = "ffab5731-97b5-5995-9138-79e8c1846df0" -version = "0.11.5" - -[[deps.BufferedStreams]] -deps = ["Compat", "Test"] -git-tree-sha1 = "5d55b9486590fdda5905c275bb21ce1f0754020f" -uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.0.0" - -[[deps.CEnum]] -git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.2" - -[[deps.CPUSummary]] -deps = ["CpuId", "IfElse", "Static"] -git-tree-sha1 = "baaac45b4462b3b0be16726f38b789bf330fcb7a" -uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" -version = "0.1.21" - -[[deps.CSV]] -deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings"] -git-tree-sha1 = "873fb188a4b9d76549b81465b1f75c82aaf59238" -uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.4" - -[[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] -git-tree-sha1 = "ba75320aaa092b3e17c020a2d8b9e0a572dbfa6a" -uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.9.0" - -[[deps.Cassette]] -git-tree-sha1 = "063b2e77c5537a548c5bf2f44161f1d3e1ab3227" -uuid = "7057c7e9-c182-5462-911a-8362d720325c" -version = "0.3.10" - 
-[[deps.ChainRules]] -deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"] -git-tree-sha1 = "f1e926b37a2e1f64388be59b1baff4152eae67b9" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.28.2" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.14.0" - -[[deps.ChangesOfVariables]] -deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "bf98fa45a0a4cee295de98d4c1462be26345b9a1" -uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.2" - -[[deps.CloseOpenIntervals]] -deps = ["ArrayInterface", "Static"] -git-tree-sha1 = "f576084239e6bdf801007c80e27e2cc2cd963fe0" -uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" -version = "0.1.6" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.0" - -[[deps.CommonSolve]] -git-tree-sha1 = "68a0743f578349ada8bc911a5cbd5a2ef6ed6d1f" -uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" -version = "0.2.0" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "b153278a25dd42c65abbf4e62344f9d22e59191b" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.43.0" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" - -[[deps.ComponentArrays]] -deps = ["ArrayInterface", "ChainRulesCore", "LinearAlgebra", "Requires"] -git-tree-sha1 = "243d8b8afc829a6707bbb1cd00da868703c2ef42" -uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -version = "0.11.15" - -[[deps.CompositeTypes]] -git-tree-sha1 = "d5b014b216dc891e81fea299638e4c10c657b582" -uuid = "b152e2b5-7a66-4b01-a709-34e65c35f657" -version = "0.1.2" - -[[deps.CompositionsBase]] -git-tree-sha1 = "455419f7e328a1a2493cabc6428d79e951349769" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.1" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f74e9d5388b8620b4cee35d4c5a618dd4dc547f4" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.3.0" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "8ccaa8c655bc1b83d2da4d569c9b28254ababd6e" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.2" - -[[deps.CpuId]] -deps = ["Markdown"] -git-tree-sha1 = "32d125af0fb8ec3f8935896122c5e345709909e5" -uuid = "adafc99b-e345-5852-983c-f28acb93d879" -version = "0.3.0" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DEDataArrays]] -deps = ["ArrayInterface", "DocStringExtensions", "LinearAlgebra", "RecursiveArrayTools", "SciMLBase", "StaticArrays"] -git-tree-sha1 = 
"5e5f8f363c8c9a2415ef9185c4e0ff6966c87d52" -uuid = "754358af-613d-5f8d-9788-280bf1605d4c" -version = "0.2.2" - -[[deps.DataAPI]] -git-tree-sha1 = "cc70b17275652eb47bc9e5f81635981f13cea5c8" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.9.0" - -[[deps.DataDeps]] -deps = ["BinaryProvider", "HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"] -git-tree-sha1 = "4f0e41ff461d42cfc62ff0de4f1cd44c6e6b3771" -uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.7" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "6c19003824cbebd804a51211fd3bbd81bf1ecad5" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.3.3" - -[[deps.DataLoaders]] -deps = ["DocStringExtensions", "LearnBase", "MLDataPattern", "Parameters", "Random", "ThreadPools"] -git-tree-sha1 = "4668e1c3fa50d9b9a91a1810b495b07008a8f6fb" -uuid = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" -version = "0.1.3" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3daef5523dd2e769dad2365274f760ff5f282c7d" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.11" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[deps.DensityInterface]] -deps = ["InverseFunctions", "Test"] -git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" -uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" -version = "0.4.0" - -[[deps.DiffEqBase]] -deps = ["ArrayInterface", "ChainRulesCore", "DEDataArrays", "DataStructures", "Distributions", "DocStringExtensions", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "IterativeSolvers", "LabelledArrays", "LinearAlgebra", "Logging", "MuladdMacro", "NonlinearSolve", "Parameters", "PreallocationTools", "Printf", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "Setfield", "SparseArrays", "StaticArrays", "Statistics", "SuiteSparse", "ZygoteRules"] -git-tree-sha1 = "cde20558d9a50ebef5f173aaa0e6ece8ca563c93" -uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" -version = "6.83.1" - -[[deps.DiffEqCallbacks]] -deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "NLsolve", "OrdinaryDiffEq", "Parameters", "RecipesBase", "RecursiveArrayTools", "SciMLBase", "StaticArrays"] -git-tree-sha1 = "c4b99e3a199e293e7290eea94ba89364d47ee557" -uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def" -version = "2.22.0" - -[[deps.DiffEqJump]] -deps = ["ArrayInterface", "Compat", "DataStructures", "DiffEqBase", "FunctionWrappers", "Graphs", "LinearAlgebra", "PoissonRandom", "Random", "RandomNumbers", "RecursiveArrayTools", "Reexport", "StaticArrays", "TreeViews", "UnPack"] -git-tree-sha1 = "eec5fd03c26dadc6b20f84d815309d060358e95b" -uuid = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12" -version = "8.3.0" - -[[deps.DiffEqNoiseProcess]] -deps = ["DiffEqBase", "Distributions", "LinearAlgebra", "Optim", "PoissonRandom", "QuadGK", "Random", "Random123", 
"RandomNumbers", "RecipesBase", "RecursiveArrayTools", "Requires", "ResettableStacks", "SciMLBase", "StaticArrays", "Statistics"] -git-tree-sha1 = "d6839a44a268c69ef0ed927b22a6f43c8a4c2e73" -uuid = "77a26b50-5914-5dd7-bc55-306e6241c503" -version = "5.9.0" - -[[deps.DiffEqOperators]] -deps = ["BandedMatrices", "BlockBandedMatrices", "DiffEqBase", "DomainSets", "ForwardDiff", "LazyArrays", "LazyBandedMatrices", "LinearAlgebra", "LoopVectorization", "NNlib", "NonlinearSolve", "Requires", "RuntimeGeneratedFunctions", "SciMLBase", "SparseArrays", "SparseDiffTools", "StaticArrays", "SuiteSparse"] -git-tree-sha1 = "a7a5cfe90dfa64dba88bc17a4e0b208e403885cf" -uuid = "9fdde737-9c7f-55bf-ade8-46b3f136cc48" -version = "4.42.0" - -[[deps.DiffEqSensitivity]] -deps = ["Adapt", "ArrayInterface", "Cassette", "ChainRulesCore", "DiffEqBase", "DiffEqCallbacks", "DiffEqNoiseProcess", "DiffEqOperators", "DiffRules", "Distributions", "EllipsisNotation", "Enzyme", "FFTW", "FiniteDiff", "ForwardDiff", "GlobalSensitivity", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Parameters", "QuadGK", "QuasiMonteCarlo", "Random", "RandomNumbers", "RecursiveArrayTools", "Reexport", "Requires", "ReverseDiff", "SciMLBase", "SharedArrays", "Statistics", "StochasticDiffEq", "Tracker", "Zygote", "ZygoteRules"] -git-tree-sha1 = "6c6ef510268d7dff2af69e3d74f6080404639d32" -uuid = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" -version = "6.72.0" - -[[deps.DiffResults]] -deps = ["StaticArrays"] -git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.0.3" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "dd933c4ef7b4c270aacd4eb88fa64c147492acf0" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.10.0" - -[[deps.Distances]] -deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.7" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.Distributions]] -deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "5a4168170ede913a2cd679e53c2123cb4b889795" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.53" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "b19534d1895d702889b219c382a6e18010797f0b" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.6" - -[[deps.DomainSets]] -deps = ["CompositeTypes", "IntervalSets", "LinearAlgebra", "StaticArrays", "Statistics"] -git-tree-sha1 = "5f5f0b750ac576bcf2ab1d7782959894b304923e" -uuid = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" -version = "0.5.9" - -[[deps.Downloads]] -deps = ["ArgTools", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" - -[[deps.EllipsisNotation]] -deps = ["ArrayInterface"] -git-tree-sha1 = "d064b0340db45d48893e7604ec95e7a2dc9da904" -uuid = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" -version = "1.5.0" - -[[deps.Enzyme]] -deps = ["Adapt", "CEnum", "Enzyme_jll", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "ObjectFile", "Printf", "Test"] -git-tree-sha1 = "e673706c6fedcac810b678e238c980e89b656968" -uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" -version = "0.9.3" - -[[deps.Enzyme_jll]] -deps = 
["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "f0a858b1c8b2b103c16f01ab6074e9a83c783781" -uuid = "7cc45869-7501-5eee-bdea-0790c847d4ef" -version = "0.0.29+0" - -[[deps.ExplicitFluxLayers]] -deps = ["CUDA", "ChainRulesCore", "ComponentArrays", "FillArrays", "Flux", "Functors", "LinearAlgebra", "NNlib", "NNlibCUDA", "Optimisers", "Random", "Setfield", "SparseArrays", "Statistics", "Yota", "Zygote"] -git-tree-sha1 = "a1182b44d76f7df48f9e4201ba3f8c9d918a15d6" -repo-rev = "ap/sparse" -repo-url = "https://github.com/avik-pal/ExplicitFluxLayers.jl.git" -uuid = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" -version = "0.2.0" - -[[deps.ExponentialUtilities]] -deps = ["ArrayInterface", "GenericSchur", "LinearAlgebra", "Printf", "Requires", "SparseArrays", "libblastrampoline_jll"] -git-tree-sha1 = "951c44b4af9d1e061d5cf789a30881471604c14c" -uuid = "d4d017d3-3776-5f7e-afef-a10c40355c18" -version = "1.14.0" - -[[deps.ExprTools]] -git-tree-sha1 = "56559bbef6ca5ea0c0818fa5c90320398a6fbf8d" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.8" - -[[deps.FFTW]] -deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] -git-tree-sha1 = "505876577b5481e50d089c1c68899dfb6faebc62" -uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.4.6" - -[[deps.FFTW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" -uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" -version = "3.3.10+0" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "4391d3ed58db9dc5a9883b23a0578316b4798b1f" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.0" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FastBroadcast]] -deps = ["LinearAlgebra", "Polyester", "Static"] -git-tree-sha1 = "b6bf57ec7a3f294c97ae46124705a9e6b906a209" -uuid = "7034ab61-46d4-4ed7-9d0f-46aef9175898" -version = "0.1.15" - -[[deps.FastClosures]] -git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" -uuid = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" -version = "0.3.2" - -[[deps.FastDEQ]] -deps = ["CUDA", "ChainRulesCore", "ComponentArrays", "DataLoaders", "DiffEqBase", "DiffEqCallbacks", "DiffEqSensitivity", "ExplicitFluxLayers", "Flux", "Functors", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Random", "Reexport", "Requires", "SciMLBase", "Setfield", "Statistics", "SteadyStateDiffEq", "UnPack", "Zygote"] -path = ".." 
-uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" -version = "0.1.0" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "80ced645013a5dbdc52cf70329399c35ce007fae" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.13.0" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "129b104185df66e408edd6625d480b7f9e9823a0" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.18" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "246621d23d1f43e3b9c368bf3b72b2331a27c286" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.2" - -[[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "56956d1e4c1221000b7781104c58c34019792951" -uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.11.0" - -[[deps.FiniteDifferences]] -deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "0ee1275eb003b6fc7325cb14301665d1072abda1" -uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.24" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" - -[[deps.Flux]] -deps = ["Adapt", "ArrayInterface", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Test", "Zygote"] -git-tree-sha1 = "e932b26ac243f312af2d9009de08b89be0e01a84" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.0" - -[[deps.FluxMPI]] -deps = ["CUDA", "ComponentArrays", "Dates", "Flux", "Functors", "LearnBase", "MLDataUtils", "MPI", "Optimisers", "Setfield", "Zygote"] -git-tree-sha1 = "65e52d4bf8600f15c8e640dce2ba80bb9bbc1f16" -repo-rev = "main" -repo-url = "https://github.com/avik-pal/FluxMPI.jl.git" -uuid = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" -version = "0.3.1" - -[[deps.FoldsThreads]] -deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] -git-tree-sha1 = "eb8e1989b9028f7e0985b4268dabe94682249025" -uuid = "9c68100b-dfe1-47cf-94c8-95104e173443" -version = "0.1.1" - -[[deps.Format]] -git-tree-sha1 = "03bcdf8ab1a5b9e6455ccb45c30910d282aa09f4" -uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" -version = "1.3.2" - -[[deps.Formatting]] -deps = ["Printf"] -git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" -uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" -version = "0.4.2" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "1bd6fc0c344fc0cbee1f42f8d2e7ec8253dda2d2" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.25" - -[[deps.FunctionWrappers]] -git-tree-sha1 = "241552bc2209f0fa068b6415b1942cc0aa486bcc" -uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" -version = "1.1.2" - -[[deps.Functors]] -git-tree-sha1 = "223fffa49ca0ff9ce4f875be001ffe173b2b7de4" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.2.8" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GPUArrays]] -deps = ["Adapt", "LLVM", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"] -git-tree-sha1 = 
"c783e8883028bf26fb05ed4022c450ef44edd875" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.3.2" - -[[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "556190e1e0ea3e37d83059fc9aa576f1e2104375" -uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.14.1" - -[[deps.GZip]] -deps = ["Libdl"] -git-tree-sha1 = "039be665faf0b8ae36e089cd694233f5dee3f7d6" -uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.5.1" - -[[deps.GenericSchur]] -deps = ["LinearAlgebra", "Printf"] -git-tree-sha1 = "fb69b2a645fa69ba5f474af09221b9308b160ce6" -uuid = "c145ed77-6b09-5dd9-b285-bf645a82121e" -version = "0.5.3" - -[[deps.Glob]] -git-tree-sha1 = "4df9f7e06108728ebf00a0a11edee4b29a482bb2" -uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.0" - -[[deps.GlobalSensitivity]] -deps = ["Distributions", "FFTW", "ForwardDiff", "KernelDensity", "LinearAlgebra", "Parameters", "QuasiMonteCarlo", "Random", "RecursiveArrayTools", "Statistics", "StatsBase", "Trapz"] -git-tree-sha1 = "0324e96625317e8f1cd51196be542de18788e3af" -uuid = "af5da776-676b-467e-8baf-acd8249e4f0f" -version = "1.3.2" - -[[deps.Graphs]] -deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "57c021de207e234108a6f1454003120a1bf350c4" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.6.0" - -[[deps.HDF5]] -deps = ["Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires"] -git-tree-sha1 = "36df177c1ce5f399a8de959e5f4b75216fe6c834" -uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.16.7" - -[[deps.HDF5_jll]] -deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "bab67c0d1c4662d2c4be8c6007751b0b6111de5c" -uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" -version = "1.12.1+0" - -[[deps.HTTP]] -deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"] -git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "0.9.17" - -[[deps.HostCPUFeatures]] -deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"] -git-tree-sha1 = "18be5268cf415b5e27f34980ed25a7d34261aa83" -uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" -version = "0.1.7" - -[[deps.Hwloc]] -deps = ["Hwloc_jll"] -git-tree-sha1 = "92d99146066c5c6888d5a3abc871e6a214388b91" -uuid = "0e44f5e4-bd66-52a0-8798-143a42290a1d" -version = "2.0.0" - -[[deps.Hwloc_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "303d70c961317c4c20fafaf5dbe0e6d610c38542" -uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" -version = "2.7.1+0" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "7f43342f8d5fd30ead0ba1b49ab1a3af3b787d24" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.5" - -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" - -[[deps.InfiniteArrays]] -deps = ["ArrayLayouts", "FillArrays", "Infinities", "LazyArrays", "LinearAlgebra", "Statistics"] -git-tree-sha1 = "e84e4ea66ca02755eaf7063e07169bf22370ee2c" -uuid = "4858937d-0d70-526a-a4dd-2d5cb5dd786c" -version = "0.12.6" - -[[deps.Infinities]] -git-tree-sha1 = "b2732e2076cd50639d827f9ae9fc4ea913c927fe" -uuid = "e1ba4f0e-776d-440f-acd9-e1d2e9742647" -version = "0.1.4" - 
-[[deps.Inflate]] -git-tree-sha1 = "f5fc07d4e706b84f72d54eedcc1c13d92fb0871c" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.2" - -[[deps.IniFile]] -git-tree-sha1 = "f550e6e32074c939295eb5ea6de31849ac2c9625" -uuid = "83e8ac13-25f8-5344-8a64-a9f2b223428f" -version = "0.5.1" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InlineStrings]] -deps = ["Parsers"] -git-tree-sha1 = "61feba885fac3a407465726d0c330b3055df897f" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.1.2" - -[[deps.IntelOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c" -uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2018.0.3+2" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.InternedStrings]] -deps = ["Random", "Test"] -git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" -uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" -version = "0.7.0" - -[[deps.Interpolations]] -deps = ["AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] -git-tree-sha1 = "b7bc05649af456efc75d178846f47006c2c4c3c7" -uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -version = "0.13.6" - -[[deps.IntervalSets]] -deps = ["Dates", "EllipsisNotation", "Statistics"] -git-tree-sha1 = "bcf640979ee55b652f3b01650444eb7bbe3ea837" -uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.5.4" - -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "91b5dcf362c5add98049e6c29ee756910b03051d" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.3" - -[[deps.InvertedIndices]] -git-tree-sha1 = "bee5f1ef5bf65df56bdd2e40447590b272a5471f" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.1.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.1.1" - -[[deps.IterativeSolvers]] -deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] -git-tree-sha1 = "1169632f425f79429f245113b775a0e3d121457c" -uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" -version = "0.9.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLD2]] -deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "Pkg", "Printf", "Reexport", "TranscodingStreams", "UUIDs"] -git-tree-sha1 = "81b9477b49402b47fbe7f7ae0b252077f53e4a08" -uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.22" - -[[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" - -[[deps.JSON3]] -deps = ["Dates", "Mmap", "Parsers", "StructTypes", "UUIDs"] -git-tree-sha1 = "8c1f668b24d999fb47baf80436194fdccec65ad2" -uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -version = "1.9.4" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KLU]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse_jll"] -git-tree-sha1 = "cae5e3dfd89b209e01bcd65b3a25e74462c67ee0" -uuid = "ef3ab10e-7fda-4108-b977-705223b18434" -version = 
"0.3.0" - -[[deps.KernelDensity]] -deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] -git-tree-sha1 = "591e8dc09ad18386189610acafb970032c519707" -uuid = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" -version = "0.6.3" - -[[deps.Krylov]] -deps = ["LinearAlgebra", "Printf", "SparseArrays"] -git-tree-sha1 = "82f5afb342a5624dc4651981584a841f6088166b" -uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" -version = "0.8.0" - -[[deps.KrylovKit]] -deps = ["LinearAlgebra", "Printf"] -git-tree-sha1 = "49b0c1dd5c292870577b8f58c51072bd558febb9" -uuid = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" -version = "0.5.4" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "c9b86064be5ae0f63e50816a5a90b08c474507ae" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "4.9.1" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "5558ad3c8972d602451efe9d81c78ec14ef4f5ef" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.14+2" - -[[deps.LabelledArrays]] -deps = ["ArrayInterface", "ChainRulesCore", "LinearAlgebra", "MacroTools", "StaticArrays"] -git-tree-sha1 = "fbd884a02f8bf98fd90c53c1c9d2b21f9f30f42a" -uuid = "2ee39098-c373-598a-b85f-a56591580800" -version = "1.8.0" - -[[deps.LatinHypercubeSampling]] -deps = ["Random", "StableRNGs", "StatsBase", "Test"] -git-tree-sha1 = "42938ab65e9ed3c3029a8d2c58382ca75bdab243" -uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" -version = "1.8.0" - -[[deps.LatticeRules]] -deps = ["Random"] -git-tree-sha1 = "7f5b02258a3ca0221a6a9710b0a0a2e8fb4957fe" -uuid = "73f95e8e-ec14-4e6a-8b18-0d2e271c4e55" -version = "0.0.1" - -[[deps.LayoutPointers]] -deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static"] -git-tree-sha1 = "b651f573812d6c36c22c944dd66ef3ab2283dfa1" -uuid = "10f19ff3-798f-405d-979b-55457f8fc047" -version = "0.1.6" - -[[deps.LazyArrays]] -deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "MacroTools", "MatrixFactorizations", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "721bebe4d0f8581c18fccf272c62000e22a80a2d" -uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02" -version = "0.22.10" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LazyBandedMatrices]] -deps = ["ArrayLayouts", "BandedMatrices", "BlockArrays", "BlockBandedMatrices", "FillArrays", "LazyArrays", "LinearAlgebra", "MatrixFactorizations", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "b1708e45e6b4308593904a14d0e5b0970d9ed0bb" -uuid = "d7e5e226-e90b-4449-9968-0f923699bf6f" -version = "0.7.12" - -[[deps.LearnBase]] -deps = ["LinearAlgebra", "StatsBase"] -git-tree-sha1 = "47e6f4623c1db88570c7a7fa66c6528b92ba4725" -uuid = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" -version = "0.3.0" - -[[deps.LevyArea]] -deps = ["LinearAlgebra", "Random", "SpecialFunctions"] -git-tree-sha1 = "56513a09b8e0ae6485f34401ea9e2f31357958ec" -uuid = "2d8b4e74-eb68-11e8-0fb9-d5eb67b50637" -version = "1.0.0" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" - -[[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" - 
-[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "42b62845d70a619f063a7da093d995ec8e15e778" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.16.1+1" - -[[deps.LineSearches]] -deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] -git-tree-sha1 = "f27132e551e959b3667d8c93eae90973225032dd" -uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" -version = "7.1.1" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LinearSolve]] -deps = ["ArrayInterface", "DocStringExtensions", "IterativeSolvers", "KLU", "Krylov", "KrylovKit", "LinearAlgebra", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "Setfield", "SparseArrays", "SuiteSparse", "UnPack"] -git-tree-sha1 = "6eb8e10ed29b85673495c29bd77ee0dfa8929977" -uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" -version = "1.15.0" - -[[deps.LogExpFunctions]] -deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a970d55c2ad8084ca317a4658ba6ce99b7523571" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.12" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoopVectorization]] -deps = ["ArrayInterface", "CPUSummary", "ChainRulesCore", "CloseOpenIntervals", "DocStringExtensions", "ForwardDiff", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "SIMDDualNumbers", "SLEEFPirates", "SpecialFunctions", "Static", "ThreadingUtilities", "UnPack", "VectorizationBase"] -git-tree-sha1 = "4acc35e95bf18de5e9562d27735bef0950f2ed74" -uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.12.108" - -[[deps.MAT]] -deps = ["BufferedStreams", "CodecZlib", "HDF5", "SparseArrays"] -git-tree-sha1 = "971be550166fe3f604d28715302b58a3f7293160" -uuid = "23992714-dd62-5051-b70f-ba57cb901cac" -version = "0.10.3" - -[[deps.MKL_jll]] -deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "e595b205efd49508358f7dc670a940c790204629" -uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2022.0.0+0" - -[[deps.MLDataPattern]] -deps = ["LearnBase", "MLLabelUtils", "Random", "SparseArrays", "StatsBase"] -git-tree-sha1 = "e99514e96e8b8129bb333c69e063a56ab6402b5b" -uuid = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" -version = "0.5.4" - -[[deps.MLDataUtils]] -deps = ["DataFrames", "DelimitedFiles", "LearnBase", "MLDataPattern", "MLLabelUtils", "Statistics", "StatsBase"] -git-tree-sha1 = "ee54803aea12b9c8ee972e78ece11ac6023715e6" -uuid = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" -version = "0.5.4" - -[[deps.MLDatasets]] -deps = ["BinDeps", "CSV", "ColorTypes", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "JLD2", "JSON3", "MAT", "MLUtils", "Pickle", "Requires", "SparseArrays", "Tables"] -git-tree-sha1 = "862c3a31a5a6dfc68e78e2e1634dac1d3b0f654e" -uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.5.16" - -[[deps.MLLabelUtils]] -deps = ["LearnBase", "MappedArrays", "StatsBase"] -git-tree-sha1 = "fd75d4b0c4016e047bbb6263eecf7ae3891af522" -uuid = "66a33bbf-0c2b-5fc8-a008-9da813334f0a" -version = "0.5.7" - -[[deps.MLStyle]] -git-tree-sha1 = "594e189325f66e23a8818e5beb11c43bb0141bcd" -uuid = "d8e11817-5142-5d16-987a-aa16d5891078" -version = "0.4.10" - -[[deps.MLUtils]] -deps = 
["ChainRulesCore", "DelimitedFiles", "FLoops", "FoldsThreads", "Random", "ShowCases", "Statistics", "StatsBase"] -git-tree-sha1 = "32eeb46fa393ae36a4127c9442ade478c8d01117" -uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" -version = "0.2.3" - -[[deps.MPI]] -deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "Pkg", "Random", "Requires", "Serialization", "Sockets"] -git-tree-sha1 = "d56a80d8cf8b9dc3050116346b3d83432b1912c0" -uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195" -version = "0.19.2" - -[[deps.MPICH_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "3dacfc006764fe498515a022c3976b7e133c4008" -uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" -version = "4.0.2+0" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.9" - -[[deps.ManualMemory]] -git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" -uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" -version = "0.1.8" - -[[deps.MappedArrays]] -git-tree-sha1 = "e8b359ef06ec72e8c030463fe02efe5527ee5142" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.1" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MatrixFactorizations]] -deps = ["ArrayLayouts", "LinearAlgebra", "Printf", "Random"] -git-tree-sha1 = "2212d36f97e01347adb1460a6914e20f2feee853" -uuid = "a3b82374-2e81-5b9e-98ce-41277c0e4c87" -version = "0.9.1" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "Random", "Sockets"] -git-tree-sha1 = "1c38e51c3d08ef2278062ebceade0e46cefc96fe" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.0.3" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" - -[[deps.MicroCollections]] -deps = ["BangBang", "InitialValues", "Setfield"] -git-tree-sha1 = "6bb7786e4f24d44b4e29df03c69add1b63d88f01" -uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.2" - -[[deps.MicrosoftMPI_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "a16aa086d335ed7e0170c5265247db29172af2f9" -uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.3+2" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "bf210ce90b6c9eed32d25dbcae1ebc565df2687f" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.0.2" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" - -[[deps.MuladdMacro]] -git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" -uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -version = "0.2.2" - -[[deps.NLSolversBase]] -deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] -git-tree-sha1 = "50310f934e55e5ca3912fb941dec199b49ca9b68" -uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" -version = "7.8.2" - -[[deps.NLsolve]] -deps = ["Distances", "LineSearches", "LinearAlgebra", "NLSolversBase", "Printf", "Reexport"] -git-tree-sha1 = "019f12e9a1a7880459d0173c182e6a99365d7ac1" -uuid = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" -version = "4.5.1" - -[[deps.NNlib]] -deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "a59a614b8b4ea6dc1dcec8c6514e251f13ccbe10" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.4" - -[[deps.NNlibCUDA]] -deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] -git-tree-sha1 = 
"0d18b4c80a92a00d3d96e8f9677511a7422a946e" -uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" -version = "0.2.2" - -[[deps.NaNMath]] -git-tree-sha1 = "b086b7ea07f8e38cf122f5016af580881ac914fe" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.7" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" - -[[deps.NonlinearSolve]] -deps = ["ArrayInterface", "FiniteDiff", "ForwardDiff", "IterativeSolvers", "LinearAlgebra", "RecursiveArrayTools", "RecursiveFactorization", "Reexport", "SciMLBase", "Setfield", "StaticArrays", "UnPack"] -git-tree-sha1 = "aeebff6a2a23506e5029fd2248a26aca98e477b3" -uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" -version = "0.3.16" - -[[deps.ObjectFile]] -deps = ["Reexport", "StructIO"] -git-tree-sha1 = "55ce61d43409b1fb0279d1781bf3b0f22c83ab3b" -uuid = "d8793406-e978-5875-9003-1fc021f44a92" -version = "0.3.7" - -[[deps.OffsetArrays]] -deps = ["Adapt"] -git-tree-sha1 = "043017e0bdeff61cfbb7afeb558ab29536bbb5ed" -uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.10.8" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" - -[[deps.OpenMPI_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "6340586e076b2abd41f5ba1a3b9c774ec6b30fde" -uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.2+0" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ab05aa4cc89736e95915b01e7279e61b1bfe33b8" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "1.1.14+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optim]] -deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] -git-tree-sha1 = "bc0a748740e8bc5eeb9ea6031e6f050de1fc0ba2" -uuid = "429524aa-4258-5aef-a3af-852621145aeb" -version = "1.6.2" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "cfedc2d6990d792e705ade4c458fea5fbe574520" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.2" - -[[deps.OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" - -[[deps.OrdinaryDiffEq]] -deps = ["Adapt", "ArrayInterface", "DataStructures", "DiffEqBase", "DocStringExtensions", "ExponentialUtilities", "FastClosures", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "LinearSolve", "Logging", "LoopVectorization", "MacroTools", "MuladdMacro", "NLsolve", "NonlinearSolve", "Polyester", "PreallocationTools", "RecursiveArrayTools", "Reexport", "SciMLBase", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"] -git-tree-sha1 = "8031a288c9b418664a3dfbac36e464a3f61ace73" -uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" -version = "6.10.0" - -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = 
"e8185b83b9fc56eb6456200e873ce598ebc7f262" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.7" - -[[deps.ParameterSchedulers]] -deps = ["Flux", "InfiniteArrays"] -git-tree-sha1 = "68f63744d5d3e1714f989a9b4f38182275d3f348" -uuid = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" -version = "0.3.3" - -[[deps.Parameters]] -deps = ["OrderedCollections", "UnPack"] -git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" -uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" -version = "0.12.3" - -[[deps.Parsers]] -deps = ["Dates"] -git-tree-sha1 = "3b429f37de37f1fc603cc1de4a799dc7fbe4c0b6" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.3.0" - -[[deps.Pickle]] -deps = ["DataStructures", "InternedStrings", "Serialization", "SparseArrays", "Strided", "StringEncodings", "ZipFile"] -git-tree-sha1 = "8e4ba4cb57bedd0289865c65ffedeee910d6a8b6" -uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" -version = "0.3.1" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[deps.PoissonRandom]] -deps = ["Random", "Statistics", "Test"] -git-tree-sha1 = "44d018211a56626288b5d3f8c6497d28c26dc850" -uuid = "e409e4f3-bfea-5376-8464-e040bb5c01ab" -version = "0.4.0" - -[[deps.Polyester]] -deps = ["ArrayInterface", "BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "ManualMemory", "PolyesterWeave", "Requires", "Static", "StrideArraysCore", "ThreadingUtilities"] -git-tree-sha1 = "8d95a735921204f5d551ac300b20d802a150433a" -uuid = "f517fe37-dbe3-4b94-8317-1923a5111588" -version = "0.6.8" - -[[deps.PolyesterWeave]] -deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] -git-tree-sha1 = "7e597df97e46ffb1c8adbaddfa56908a7a20194b" -uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" -version = "0.1.5" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "28ef6c7ce353f0b35d0df0d5930e0d072c1f5b9b" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.1" - -[[deps.PositiveFactorizations]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "17275485f373e6673f7e7f97051f703ed5b15b20" -uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" -version = "0.2.4" - -[[deps.PreallocationTools]] -deps = ["Adapt", "ArrayInterface", "ForwardDiff", "LabelledArrays"] -git-tree-sha1 = "6c138c8510111fa47b5d2ed8ada482d97e279bee" -uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" -version = "0.2.4" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.3.0" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.PrettyTables]] -deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"] -git-tree-sha1 = "dfb54c4e414caa595a1f2ed759b160f5a3ddcba5" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "1.3.1" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "78aadffb3efd2155af139781b8a8df1ef279ea39" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.4.2" - 
-[[deps.QuasiMonteCarlo]] -deps = ["Distributions", "LatinHypercubeSampling", "LatticeRules", "Sobol"] -git-tree-sha1 = "bc69c718a83951dcb999404ff267a7b8c39c1c63" -uuid = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" -version = "0.2.4" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.Random123]] -deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "afeacaecf4ed1649555a19cb2cad3c141bbc9474" -uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.5.0" - -[[deps.RandomNumbers]] -deps = ["Random", "Requires"] -git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" -uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" -version = "1.5.3" - -[[deps.Ratios]] -deps = ["Requires"] -git-tree-sha1 = "dc84268fe0e3335a62e315a3a7cf2afa7178a734" -uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" -version = "0.4.3" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.RecipesBase]] -git-tree-sha1 = "6bf3f380ff52ce0832ddd3a2a7b9538ed1bcca7d" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.2.1" - -[[deps.RecursiveArrayTools]] -deps = ["Adapt", "ArrayInterface", "ChainRulesCore", "DocStringExtensions", "FillArrays", "LinearAlgebra", "RecipesBase", "Requires", "StaticArrays", "Statistics", "ZygoteRules"] -git-tree-sha1 = "bfe14f127f3e7def02a6c2b1940b39d0dabaa3ef" -uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "2.26.3" - -[[deps.RecursiveFactorization]] -deps = ["LinearAlgebra", "LoopVectorization", "Polyester", "StrideArraysCore", "TriangularSolve"] -git-tree-sha1 = "a9a852c7ebb08e2a40e8c0ab9830a744fa283690" -uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4" -version = "0.2.10" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.ResettableStacks]] -deps = ["StaticArrays"] -git-tree-sha1 = "256eeeec186fa7f26f2801732774ccf277f05db9" -uuid = "ae5879a3-cd67-5da8-be7f-38c6eb64a37b" -version = "1.1.1" - -[[deps.ReverseDiff]] -deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"] -git-tree-sha1 = "559db2c7a28262e9ff1af1ad4ec539aa972c8934" -uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -version = "1.13.0" - -[[deps.Richardson]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949" -uuid = "708f8203-808e-40c0-ba2d-98a6953ed40d" -version = "1.4.0" - -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "bf3188feca147ce108c76ad82c2792c57abe7b1f" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.0" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "68db32dff12bb6127bac73c209881191bf0efbb7" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.3.0+0" - -[[deps.RuntimeGeneratedFunctions]] -deps = ["ExprTools", "SHA", "Serialization"] -git-tree-sha1 = "cdc1e4278e91a6ad530770ebb327f9ed83cf10c4" -uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" -version = "0.5.3" - -[[deps.SHA]] -uuid = 
"ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[deps.SIMDDualNumbers]] -deps = ["ForwardDiff", "IfElse", "SLEEFPirates", "VectorizationBase"] -git-tree-sha1 = "62c2da6eb66de8bb88081d20528647140d4daa0e" -uuid = "3cdde19b-5bb0-4aaf-8931-af3e248e098b" -version = "0.1.0" - -[[deps.SIMDTypes]] -git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" -uuid = "94e857df-77ce-4151-89e5-788b33177be4" -version = "0.1.0" - -[[deps.SLEEFPirates]] -deps = ["IfElse", "Static", "VectorizationBase"] -git-tree-sha1 = "ac399b5b163b9140f9c310dfe9e9aaa225617ff6" -uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" -version = "0.6.32" - -[[deps.SciMLBase]] -deps = ["ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "RecipesBase", "RecursiveArrayTools", "StaticArrays", "Statistics", "Tables", "TreeViews"] -git-tree-sha1 = "f03796a588eba66f6bcc63cfdeda89b4a339ce4e" -uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "1.30.0" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "6a2f7d70512d205ca8c7ee31bfa9f142fe74310c" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.3.12" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "Requires"] -git-tree-sha1 = "38d88503f695eb0301479bc9b0d4320b378bafe5" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "0.8.2" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sobol]] -deps = ["DelimitedFiles", "Random"] -git-tree-sha1 = "5a74ac22a9daef23705f010f72c81d6925b19df8" -uuid = "ed01d8cd-4d21-5b2a-85b4-cc3bdc58bad4" -version = "1.5.0" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.0.1" - -[[deps.SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.SparseDiffTools]] -deps = ["Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays", "VertexSafeGraphs"] -git-tree-sha1 = "314a07e191ea4a5ea5a2f9d6b39f03833bde5e08" -uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -version = "1.21.0" - -[[deps.SpecialFunctions]] -deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "5ba658aeecaaf96923dce0da9e703bd1fe7666f9" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.1.4" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "39c9f91521de844bad65049efd4f9223e7ed43f9" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.14" - -[[deps.StableRNGs]] -deps = ["Random", "Test"] -git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" -uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" -version = "1.0.0" - -[[deps.Static]] -deps = ["IfElse"] -git-tree-sha1 = "87e9954dfa33fd145694e42337bdd3d5b07021a6" -uuid 
= "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.6.0" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "cd56bf18ed715e8b09f06ef8c6b781e6cdc49911" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.4.4" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "c82aaa13b44ea00134f8c9c89819477bd3986ecd" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.3.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "8977b17906b0a1cc74ab2e3a05faa16cf08a8291" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.16" - -[[deps.StatsFuns]] -deps = ["ChainRulesCore", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "5950925ff997ed6fb3e985dcce8eb1ba42a0bbe7" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "0.9.18" - -[[deps.SteadyStateDiffEq]] -deps = ["DiffEqBase", "DiffEqCallbacks", "LinearAlgebra", "NLsolve", "Reexport", "SciMLBase"] -git-tree-sha1 = "3e057e1f9f12d18cac32011aed9e61eef6c1c0ce" -uuid = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" -version = "1.6.6" - -[[deps.StochasticDiffEq]] -deps = ["Adapt", "ArrayInterface", "DataStructures", "DiffEqBase", "DiffEqJump", "DiffEqNoiseProcess", "DocStringExtensions", "FillArrays", "FiniteDiff", "ForwardDiff", "LevyArea", "LinearAlgebra", "Logging", "MuladdMacro", "NLsolve", "OrdinaryDiffEq", "Random", "RandomNumbers", "RecursiveArrayTools", "Reexport", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"] -git-tree-sha1 = "4d428684218ac7a3dc54aaeb3f76e03bf892c33c" -uuid = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" -version = "6.46.0" - -[[deps.StrideArraysCore]] -deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "ManualMemory", "Requires", "SIMDTypes", "Static", "ThreadingUtilities"] -git-tree-sha1 = "df8fc9d0407a77241c529cc2ef97ba2e3436ff51" -uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da" -version = "0.3.2" - -[[deps.Strided]] -deps = ["LinearAlgebra", "TupleTools"] -git-tree-sha1 = "7c4bcef07d559776a9e2a009c441547fb9eb5c92" -uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" -version = "1.2.1" - -[[deps.StringEncodings]] -deps = ["Libiconv_jll"] -git-tree-sha1 = "50ccd5ddb00d19392577902f0079267a72c5ab04" -uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" -version = "0.3.5" - -[[deps.StructIO]] -deps = ["Test"] -git-tree-sha1 = "010dc73c7146869c042b49adcdb6bf528c12e859" -uuid = "53d494c1-5632-5724-8f4c-31dff12d585f" -version = "0.3.0" - -[[deps.StructTypes]] -deps = ["Dates", "UUIDs"] -git-tree-sha1 = "d24a825a95a6d98c385001212dc9020d609f2d4f" -uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" -version = "1.8.1" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", 
"IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "5ce79ce186cc678bbb5c5681ca3379d1ddae11a1" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.7.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.ThreadPools]] -deps = ["Printf", "RecipesBase", "Statistics"] -git-tree-sha1 = "705ccc29d575b87cceb359dfea19f4653d06df8f" -uuid = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431" -version = "1.2.1" - -[[deps.ThreadingUtilities]] -deps = ["ManualMemory"] -git-tree-sha1 = "f8629df51cab659d70d2e5618a430b4d3f37f2c3" -uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" -version = "0.5.0" - -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "11db03dd5bbc0d2b57a570d228a0f34538c586b1" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.17" - -[[deps.Tracker]] -deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] -git-tree-sha1 = "0874c1b5de1b5529b776cfeca3ec0acfada97b1b" -uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.20" - -[[deps.TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = "216b95ea110b5972db65aa90f88d8d89dcb8851c" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.6" - -[[deps.Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "c76399a3bbe6f5a88faa33c8f8a65aa631d95013" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.73" - -[[deps.Trapz]] -git-tree-sha1 = "79eb0ed763084a3e7de81fe1838379ac6a23b6a0" -uuid = "592b5752-818d-11e9-1e9a-2b8ca4a44cd1" -version = "2.0.3" - -[[deps.TreeViews]] -deps = ["Test"] -git-tree-sha1 = "8d0d7a3fe2f30d6a7f833a5f19f7c7a5b396eae6" -uuid = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7" -version = "0.3.0" - -[[deps.TriangularSolve]] -deps = ["CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "LoopVectorization", "Polyester", "Static", "VectorizationBase"] -git-tree-sha1 = "b8d08f55b02625770c09615d96927b3a8396925e" -uuid = "d5829a12-d9aa-46ab-831f-fb7c9ab06edf" -version = "0.1.11" - -[[deps.TupleTools]] -git-tree-sha1 = "3c712976c47707ff893cf6ba4354aa14db1d8938" -uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" -version = "1.3.0" - -[[deps.URIParser]] -deps = ["Unicode"] -git-tree-sha1 = "53a9f49546b8d2dd2e688d216421d050c9a31d0d" -uuid = "30578b45-9adc-5946-b283-645ec420af67" -version = "0.4.1" - -[[deps.URIs]] -git-tree-sha1 = "97bbe755a53fe859669cd907f2d96aee8d2c1355" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.3.0" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Umlaut]] -deps = ["LinearAlgebra", "Statistics", "Test"] -git-tree-sha1 = "1428bb6784d43298b29503b4a08b8a51b13e4c07" -uuid = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841" -version = "0.2.4" - -[[deps.UnPack]] -git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" -uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.2" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.VectorizationBase]] -deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "Hwloc", 
"IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static"] -git-tree-sha1 = "9d1b533f597d87ce9b4abd36a2ce4664f08e08ed" -uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" -version = "0.21.29" - -[[deps.VertexSafeGraphs]] -deps = ["Graphs"] -git-tree-sha1 = "8351f8d73d7e880bfc042a8b6922684ebeafb35c" -uuid = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f" -version = "0.2.0" - -[[deps.WeakRefStrings]] -deps = ["DataAPI", "InlineStrings", "Parsers"] -git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" -uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" -version = "1.4.2" - -[[deps.WoodburyMatrices]] -deps = ["LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "de67fa59e33ad156a590055375a30b23c40299d3" -uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" -version = "0.5.5" - -[[deps.Yota]] -deps = ["ChainRules", "ChainRulesCore", "FiniteDifferences", "LinearAlgebra", "NNlib", "Random", "Statistics", "Test", "UUIDs", "Umlaut"] -git-tree-sha1 = "b4eef79929bab5503cbc6ca495aa205bdab98978" -uuid = "cd998857-8626-517d-b929-70ad188a48f0" -version = "0.7.3" - -[[deps.ZipFile]] -deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "3593e69e469d2111389a9bd06bac1f3d730ac6de" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.9.4" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" - -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "8c3e9ae8c2b520200df59d4f683a0dab65685ade" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.38" - -[[deps.ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.2" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/examples/Project.toml b/examples/Project.toml index 5e213565..a512ba55 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -7,17 +7,17 @@ version = "0.1.0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" -ExplicitFluxLayers = "aafdbc67-5cd5-409a-82a0-ebc47ac8091e" FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" -LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6" -MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index 
af303d6b..c6147a22 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -1,24 +1,23 @@ module FastDEQExperiments using FastDEQ, - ExplicitFluxLayers, + DataLoaders, Random, - Flux, OrdinaryDiffEq, FluxMPI, Format, + Lux, MLDatasets, - MLDataUtils, - DataLoaders, Optimisers, MPI, CUDA, Setfield, - ParameterSchedulers -import LearnBase: ObsDim -import MLDataUtils: nobs, getobs + ParameterSchedulers, + NNlib, + Zygote -const EFL = ExplicitFluxLayers +import Flux: OneHotArray, onecold, onehotbatch, onehot +import Flux.Losses: logitcrossentropy, mse # Memory Management relieve_gc_pressure(::Union{Nothing,<:AbstractArray}) = nothing diff --git a/examples/src/dataloaders.jl b/examples/src/dataloaders.jl index 09b15f19..2c9effe7 100644 --- a/examples/src/dataloaders.jl +++ b/examples/src/dataloaders.jl @@ -6,10 +6,8 @@ end MLDatasetsImageData(images::AbstractArray{T,4}, labels::AbstractArray{T,2}) where {T} = MLDatasetsImageData(collect(eachslice(images, dims=4)), collect(eachslice(labels, dims=2))) -nobs(d::MLDatasetsImageData) = length(d.images) - -getobs(d::MLDatasetsImageData, i::Int, ::ObsDim.Undefined) = (d.images[i], d.labels[i]) -getobs(d::MLDatasetsImageData, i::Int) = (d.images[i], d.labels[i]) +Base.length(d::MLDatasetsImageData) = length(d.images) +Base.getindex(d::MLDatasetsImageData, i::Int) = (d.images[i], d.labels[i]) function get_dataloaders( dataset::Symbol; μ=nothing, σ²=nothing, train_batchsize::Int64, eval_batchsize::Int64 @@ -23,9 +21,9 @@ function get_dataloaders( end x_train = (x_train .- μ) ./ σ² - y_train = Float32.(Flux.onehotbatch(y_train, 0:(nclasses - 1))) + y_train = Float32.(onehotbatch(y_train, 0:(nclasses - 1))) x_test = (x_test .- μ) ./ σ² - y_test = Float32.(Flux.onehotbatch(y_test, 0:(nclasses - 1))) + y_test = Float32.(onehotbatch(y_test, 0:(nclasses - 1))) train_dataset = shuffleobs(MLDatasetsImageData(x_train, y_train)) train_dataset = is_distributed() ? DistributedDataContainer(train_dataset) : train_dataset diff --git a/examples/src/models.jl b/examples/src/models.jl index b0745abd..426ee771 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -1,15 +1,15 @@ # Building Blocks ## Helpful Functional Wrappers function conv1x1(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return EFL.Conv((1, 1), mapping, activation; stride=stride, pad=0, bias=bias, initW=NormalInitializer(), kwargs...) + return Conv((1, 1), mapping, activation; stride=stride, pad=0, bias=bias, initW=NormalInitializer(), kwargs...) end function conv3x3(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return EFL.Conv((3, 3), mapping, activation; stride=stride, pad=1, bias=bias, initW=NormalInitializer(), kwargs...) + return Conv((3, 3), mapping, activation; stride=stride, pad=1, bias=bias, initW=NormalInitializer(), kwargs...) end function conv5x5(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return EFL.Conv((5, 5), mapping, activation; stride=stride, pad=2, bias=bias, initW=NormalInitializer(), kwargs...) + return Conv((5, 5), mapping, activation; stride=stride, pad=2, bias=bias, initW=NormalInitializer(), kwargs...) 
end reassociate(x::NTuple{2,<:AbstractArray}, y) = (x[1], (x[2], y)) @@ -26,14 +26,14 @@ function downsample_module(mapping, level_diff, activation; group_count=8) end end - layers = EFL.AbstractExplicitLayer[] + layers = Lux.AbstractExplicitLayer[] for i in 1:level_diff inchs, outchs = intermediate_mapping(i) push!(layers, conv3x3(inchs => outchs; stride=2)) - # push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) - push!(layers, EFL.BatchNorm(outchs, activation; affine=true, track_stats=false)) + # push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, BatchNorm(outchs, activation; affine=true, track_stats=false)) end - return EFL.Chain(layers...) + return Chain(layers...) end ## Upsample Module @@ -48,15 +48,15 @@ function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol= end end - layers = EFL.AbstractExplicitLayer[] + layers = Lux.AbstractExplicitLayer[] for i in 1:level_diff inchs, outchs = intermediate_mapping(i) push!(layers, conv1x1(inchs => outchs)) - # push!(layers, EFL.GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) - push!(layers, EFL.BatchNorm(outchs, activation; affine=true, track_stats=false)) - push!(layers, EFL.Upsample(upsample_mode; scale=2)) + # push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, BatchNorm(outchs, activation; affine=true, track_stats=false)) + push!(layers, Upsample(upsample_mode; scale=2)) end - return EFL.Chain(layers...) + return Chain(layers...) end ## Residual Block @@ -64,7 +64,7 @@ function ResidualBlockV1( mapping; deq_expand::Int=5, num_gn_groups::Int=4, - downsample=EFL.NoOpLayer(), + downsample=NoOpLayer(), n_big_kernels::Int=0, dropout_rate::Real=0.0f0, gn_affine::Bool=true, @@ -77,36 +77,35 @@ function ResidualBlockV1( conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) conv1, conv2 = if weight_norm - EFL.WeightNorm(conv1, (:weight,), (4,)), EFL.WeightNorm(conv2, (:weight,), (4,)) + WeightNorm(conv1, (:weight,), (4,)), WeightNorm(conv2, (:weight,), (4,)) else conv1, conv2 end - # gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - gn1 = EFL.BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = EFL.BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn3 = EFL.BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) + # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) - dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.VariationalHiddenDropout(dropout_rate) + dropout = iszero(dropout_rate) ? 
NoOpLayer() : VariationalHiddenDropout(dropout_rate) - return EFL.Chain( - EFL.Parallel( + return Chain( + Parallel( reassociate, # Reassociate and Merge - EFL.Chain(conv1, gn1, conv2, EFL.BranchLayer(downsample, dropout)), # For x - EFL.NoOpLayer(), # For injection + Chain(conv1, gn1, conv2, BranchLayer(downsample, dropout)), # For x + NoOpLayer(), # For injection ), - EFL.Parallel( + Parallel( +, - EFL.NoOpLayer(), # For y1 - EFL.Chain( - EFL.WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar + NoOpLayer(), # For y1 + Chain( + WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar gn2, ), # For (y2, injection) ), - # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), - EFL.ActivationFunction(relu), + ActivationFunction(relu), gn3, ) end @@ -115,7 +114,7 @@ function ResidualBlockV2( mapping; deq_expand::Int=5, num_gn_groups::Int=4, - downsample=EFL.NoOpLayer(), + downsample=NoOpLayer(), n_big_kernels::Int=0, dropout_rate::Real=0.0f0, gn_affine::Bool=true, @@ -128,87 +127,85 @@ function ResidualBlockV2( conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) conv1, conv2 = if weight_norm - EFL.WeightNorm(conv1, (:weight,), (4,)), EFL.WeightNorm(conv2, (:weight,), (4,)) + WeightNorm(conv1, (:weight,), (4,)), WeightNorm(conv2, (:weight,), (4,)) else conv1, conv2 end - # gn1 = EFL.GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn2 = EFL.GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn3 = EFL.GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - gn1 = EFL.BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = EFL.BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn3 = EFL.BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) + # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) - dropout = iszero(dropout_rate) ? EFL.NoOpLayer() : EFL.VariationalHiddenDropout(dropout_rate) + dropout = iszero(dropout_rate) ? 
NoOpLayer() : VariationalHiddenDropout(dropout_rate) - return EFL.Chain( + return Chain( conv1, gn1, conv2, - EFL.Parallel(+, downsample, EFL.Chain(dropout, gn2)), - # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), - EFL.ActivationFunction(relu), + Parallel(+, downsample, Chain(dropout, gn2)), + # WrappedFunction(Base.Fix1(broadcast, relu)), + ActivationFunction(relu), gn3, ) end function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=false, bn_affine::Bool=true) rescale = if first(mapping) != last(mapping) * expansion - EFL.Chain( + Chain( conv1x1(first(mapping) => last(mapping) * expansion), - EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ) else - EFL.NoOpLayer() + NoOpLayer() end - return EFL.Chain( - EFL.Parallel(reassociate, EFL.BranchLayer(rescale, conv1x1(mapping)), EFL.NoOpLayer()), - EFL.Parallel( + return Chain( + Parallel(reassociate, BranchLayer(rescale, conv1x1(mapping)), NoOpLayer()), + Parallel( +, - EFL.NoOpLayer(), - EFL.Chain( - EFL.WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar - EFL.Chain( - EFL.BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), + NoOpLayer(), + Chain( + WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar + Chain( + BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), conv3x3(last(mapping) => last(mapping) * expansion), - EFL.BatchNorm(last(mapping) * expansion, relu; track_stats=bn_track_stats, affine=bn_affine), + BatchNorm(last(mapping) * expansion, relu; track_stats=bn_track_stats, affine=bn_affine), conv1x1(last(mapping) * expansion => last(mapping) * expansion), - EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), ), - # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), - EFL.ActivationFunction(relu), + ActivationFunction(relu), ) end function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=false, bn_affine::Bool=true) rescale = if first(mapping) != last(mapping) * expansion - EFL.Chain( + Chain( conv1x1(first(mapping) => last(mapping) * expansion), - EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ) else - EFL.NoOpLayer() + NoOpLayer() end - return EFL.Chain( - EFL.Parallel( + return Chain( + Parallel( +, rescale, - EFL.Chain( + Chain( conv1x1(mapping), - EFL.BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), + BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), conv3x3(last(mapping) => last(mapping) * expansion), - EFL.BatchNorm(last(mapping) * expansion, relu; track_stats=bn_track_stats, affine=bn_affine), + BatchNorm(last(mapping) * expansion, relu; track_stats=bn_track_stats, affine=bn_affine), conv1x1(last(mapping) * expansion => last(mapping) * expansion), - EFL.BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), + BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), - # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), - EFL.ActivationFunction(relu), + ActivationFunction(relu), ) end @@ -225,31 +222,31 @@ function get_model( downsample_layers = [ conv3x3(3 => init_channel_size; 
stride=config.downsample_times >= 1 ? 2 : 1), - EFL.BatchNorm(init_channel_size, relu; affine=true, track_stats=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=false), conv3x3(init_channel_size => init_channel_size; stride=config.downsample_times >= 2 ? 2 : 1), - EFL.BatchNorm(init_channel_size, relu; affine=true, track_stats=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=false), ] for _ in 3:(config.downsample_times) append!( downsample_layers, [ conv3x3(init_channel_size => init_channel_size; stride=2), - EFL.BatchNorm(init_channel_size, relu; affine=true, track_stats=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=false), ], ) end - downsample = EFL.Chain(downsample_layers...) + downsample = Chain(downsample_layers...) stage0 = if config.downsample_times == 0 && config.num_branches <= 2 - EFL.NoOpLayer() + NoOpLayer() else - EFL.Chain( + Chain( conv1x1(init_channel_size => init_channel_size; bias=false), - EFL.BatchNorm(init_channel_size, relu; affine=true, track_stats=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=false), ) end - initial_layers = EFL.Chain(downsample, stage0) + initial_layers = Chain(downsample, stage0) main_layers = Tuple( ResidualBlockV1( @@ -261,11 +258,11 @@ function get_model( ) for i in 1:(config.num_branches) ) - mapping_layers = Matrix{EFL.AbstractExplicitLayer}(undef, config.num_branches, config.num_branches) + mapping_layers = Matrix{Lux.AbstractExplicitLayer}(undef, config.num_branches, config.num_branches) for i in 1:(config.num_branches) for j in 1:(config.num_branches) if i == j - mapping_layers[i, j] = EFL.NoOpLayer() + mapping_layers[i, j] = NoOpLayer() elseif i < j mapping_layers[i, j] = downsample_module( config.num_channels[i] => config.num_channels[j], j - i, relu; group_count=config.group_count @@ -283,38 +280,37 @@ function get_model( end post_fuse_layers = Tuple( - EFL.Chain( - # EFL.WrappedFunction(Base.Fix1(broadcast, relu)), - EFL.ActivationFunction(relu), + Chain( + ActivationFunction(relu), conv1x1(config.num_channels[i] => config.num_channels[i]), - # EFL.GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), - EFL.BatchNorm(config.num_channels[i]; affine=true, track_stats=false), + # GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), + BatchNorm(config.num_channels[i]; affine=true, track_stats=false), ) for i in 1:(config.num_branches) ) - increment_modules = EFL.Parallel( + increment_modules = Parallel( nothing, [BottleneckBlockV2(config.num_channels[i] => config.head_channels[i]) for i in 1:(config.num_branches)]..., ) - downsample_modules = EFL.PairwiseFusion( + downsample_modules = PairwiseFusion( config.fuse_method == :sum ? 
(+) : error("Only `fuse_method` = `:sum` is supported"), [ - EFL.Chain( + Chain( conv3x3(config.head_channels[i] * 4 => config.head_channels[i + 1] * 4; stride=2, bias=true), - EFL.BatchNorm(config.head_channels[i + 1] * 4, relu; track_stats=false, affine=true), + BatchNorm(config.head_channels[i + 1] * 4, relu; track_stats=false, affine=true), ) for i in 1:(config.num_branches - 1) ]..., ) - final_layers = EFL.Chain( + final_layers = Chain( increment_modules, downsample_modules, conv1x1(config.head_channels[config.num_branches] * 4 => config.final_channelsize; bias=true), - EFL.BatchNorm(config.final_channelsize, relu; track_stats=false, affine=true), - EFL.GlobalMeanPool(), - EFL.FlattenLayer(), - EFL.Dense(config.final_channelsize, config.num_classes), + BatchNorm(config.final_channelsize, relu; track_stats=false, affine=true), + GlobalMeanPool(), + FlattenLayer(), + Dense(config.final_channelsize, config.num_classes), ) solver = if config.continuous @@ -334,7 +330,7 @@ function get_model( deq = if config.model_type ∈ (:SKIP, :SKIPV2) shortcut = if config.model_type == :SKIP - slayers = EFL.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1])] + slayers = Lux.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1])] for i in 1:(config.num_branches - 1) push!( slayers, @@ -379,28 +375,25 @@ function get_model( model = DEQChain(initial_layers, deq, final_layers) rng = Random.default_rng() Random.seed!(rng, seed) - ps, st = device.(EFL.setup(rng, model)) - - # Temporary Fix: CUDA.RNG giving errors on Julia 1.7 - st = EFL.update_state(st, :rng, rng) + ps, st = device.(Lux.setup(rng, model)) if warmup clean_println("Starting Model Warmup") - x__ = device(randn(Float32, config.image_size..., 3, 1)) - y__ = device(Float32.(Flux.onehotbatch([1], 0:(config.num_classes - 1)))) + x__ = device(randn(Float32, config.image_size..., 3, 2)) + y__ = device(Float32.(onehotbatch([1, 2], 0:(config.num_classes - 1)))) model(x__, ps, st) clean_println("Forward Pass Warmup Completed") - st_ = EFL.update_state(st, :fixed_depth, 2) + st_ = Lux.update_state(st, :fixed_depth, 2) model(x__, ps, st_) clean_println("Forward Pass (Pretraining) Warmup Completed") lfn = loss_function(config) - (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st), ps) + (l, _, _, _), back = pullback(p -> lfn(x__, y__, model, p, st), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass Warmup Completed") - (l, _, _, _), back = Flux.pullback(p -> lfn(x__, y__, model, p, st_), ps) + (l, _, _, _), back = pullback(p -> lfn(x__, y__, model, p, st_), ps) back((one(l), nothing, nothing, nothing)) clean_println("Backward Pass (Pretraining) Warmup Completed") diff --git a/examples/src/train.jl b/examples/src/train.jl index 0602ae15..8c48c7bb 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -33,6 +33,10 @@ function construct_optimiser(config::ExperimentConfiguration) opt = Optimisers.OptimiserChain(opt, Optimisers.WeightDecay(config.weight_decay)) end + if is_distributed() + opt = DistributedOptimiser(opt) + end + sched = if config.lr_scheduler == :COSINE ParameterSchedulers.Stateful( ParameterSchedulers.Cos( @@ -78,8 +82,8 @@ function evaluate(model, ps, st, dataloader) total_time += time() - start_time total_nfe += soln.nfe * size(x, ndims(x)) - total_loss += Flux.Losses.logitcrossentropy(ŷ, y) * size(x, ndims(x)) - matches += sum(argmax.(eachcol(cpu(ŷ))) .== Flux.onecold(cpu(y))) + total_loss += logitcrossentropy(ŷ, y) * 
size(x, ndims(x)) + matches += sum(argmax.(eachcol(cpu(ŷ))) .== onecold(cpu(y))) total_datasize += size(x, ndims(x)) end return (loss=total_loss, accuracy=matches, mean_nfe=total_nfe, total_time=total_time, total_datasize=total_datasize) @@ -91,14 +95,14 @@ function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-3) if c.model_type == :VANILLA function loss_function_closure_1(x, y, model, ps, st) (ŷ, soln), st_ = model(x, ps, st) - loss = Flux.Losses.logitcrossentropy(ŷ, y) + loss = logitcrossentropy(ŷ, y) return loss, ŷ, st_, soln.nfe end return loss_function_closure_1 else function loss_function_closure_2(x, y, model, ps, st) (ŷ, soln), st_ = model(x, ps, st) - loss = Flux.Losses.logitcrossentropy(ŷ, y) + λ_skip * Flux.Losses.mse(soln.u₀, soln.z_star) + loss = logitcrossentropy(ŷ, y) + λ_skip * mse(soln.u₀, soln.z_star) return loss, ŷ, st_, soln.nfe end return loss_function_closure_2 @@ -123,13 +127,13 @@ function train_one_epoch( # Compute Loss + Backprop + Update start_time = time() - (loss, ŷ, st, nfe), back = Flux.pullback(p -> loss_function(x, y, model, p, st), ps) + (loss, ŷ, st, nfe), back = pullback(p -> loss_function(x, y, model, p, st), ps) gs, = back((one(loss), nothing, nothing, nothing)) opt_state, ps = Optimisers.update!(opt_state, ps, gs) total_time += time() - start_time - acc = sum(argmax.(eachcol(cpu(ŷ))) .== Flux.onecold(cpu(y))) / size(x, 4) + acc = sum(argmax.(eachcol(cpu(ŷ))) .== onecold(cpu(y))) / size(x, 4) iteration_count += 1 st = econfig.pretrain_steps == iteration_count ? EFL.update_state(st, :fixed_depth, 0) : st From cad992e5a92697472c2a0d9729f931a4ebe51a3c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 3 May 2022 18:42:36 -0400 Subject: [PATCH 50/76] Compat entries --- examples/Project.toml | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/examples/Project.toml b/examples/Project.toml index a512ba55..840fe02e 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -8,7 +8,6 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -21,3 +20,19 @@ ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +CUDA = "3" +DataLoaders = "0.1" +FluxMPI = "0.4" +Format = "1.3" +Lux = "0.3" +MLDatasets = "0.5" +MPI = "0.19" +NNlib = "0.8" +Optimisers = "0.2" +OrdinaryDiffEq = "6" +ParameterSchedulers = "0.3" +Setfield = "0.8" +Zygote = "0.6" +julia = "1.6" From db985464c15fee9252da1bd5abdb1f334e69b024 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 May 2022 20:32:54 -0400 Subject: [PATCH 51/76] Dep fixes --- examples/Project.toml | 6 ++++++ examples/cifar10/script.jl | 2 +- examples/src/FastDEQExperiments.jl | 10 ++++++++++ examples/src/train.jl | 6 +++--- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/Project.toml b/examples/Project.toml index 840fe02e..8221432a 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -8,10 +8,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" 
FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" @@ -24,10 +28,12 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] CUDA = "3" DataLoaders = "0.1" +Flux = "0.13" FluxMPI = "0.4" Format = "1.3" Lux = "0.3" MLDatasets = "0.5" +MLUtils = "0.2" MPI = "0.19" NNlib = "0.8" Optimisers = "0.2" diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl index d8e0eda1..db8b05f0 100644 --- a/examples/cifar10/script.jl +++ b/examples/cifar10/script.jl @@ -1,4 +1,4 @@ -using FastDEQExperiments, Flux, CUDA, Optimisers, Dates, FluxMPI +using FastDEQExperiments, Lux, CUDA, Optimisers, Dates, FluxMPI # Distributed Training FluxMPI.Init(; verbose=true) diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index c6147a22..dbcbb68c 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -6,6 +6,7 @@ using FastDEQ, OrdinaryDiffEq, FluxMPI, Format, + Funtors, Lux, MLDatasets, Optimisers, @@ -18,6 +19,8 @@ using FastDEQ, import Flux: OneHotArray, onecold, onehotbatch, onehot import Flux.Losses: logitcrossentropy, mse +import MLUtils: shuffleobs +import MLDataPattern, MLUtils # Memory Management relieve_gc_pressure(::Union{Nothing,<:AbstractArray}) = nothing @@ -42,4 +45,11 @@ include("models.jl") # get_dataloaders include("dataloaders.jl") + +# Fallback since DataLoaders.jl still relies on MLDataPattern +MLDataPattern.nobs(x) = MLUtils.numobs(x) +MLDataPattern.getobs(d::Union{MLUtils.ObsView,MLDatasetsImageData,DistributedDataContainer}, i::Int64) = + MLUtils.getobs(d, i) + + end \ No newline at end of file diff --git a/examples/src/train.jl b/examples/src/train.jl index 8c48c7bb..31468fcf 100644 --- a/examples/src/train.jl +++ b/examples/src/train.jl @@ -74,7 +74,7 @@ end evaluate(model, ps, st, ::Nothing) = nothing function evaluate(model, ps, st, dataloader) - st_eval = EFL.testmode(st) + st_eval = Lux.testmode(st) matches, total_loss, total_datasize, total_nfe, total_time = 0, 0, 0, 0, 0 for (x, y) in CuIterator(dataloader) start_time = time() @@ -136,7 +136,7 @@ function train_one_epoch( acc = sum(argmax.(eachcol(cpu(ŷ))) .== onecold(cpu(y))) / size(x, 4) iteration_count += 1 - st = econfig.pretrain_steps == iteration_count ? EFL.update_state(st, :fixed_depth, 0) : st + st = econfig.pretrain_steps == iteration_count ? Lux.update_state(st, :fixed_depth, 0) : st # Run ParameterScheduler eta_new = ParameterSchedulers.next!(scheduler) @@ -169,7 +169,7 @@ function train( opt_state = is_distributed() ? FluxMPI.synchronize!(opt_state; root_rank=0) : opt_state iteration_count = 0 - st = econfig.pretrain_steps != 0 ? EFL.update_state(st, :fixed_depth, econfig.model_config.num_layers) : st + st = econfig.pretrain_steps != 0 ? 
Lux.update_state(st, :fixed_depth, econfig.model_config.num_layers) : st for epoch in 1:nepochs # Train 1 epoch From 19f2911ff90897a44a021cda1889925085121d71 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 May 2022 20:46:34 -0400 Subject: [PATCH 52/76] Dep fixes --- examples/src/FastDEQExperiments.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index dbcbb68c..218790ff 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -6,7 +6,7 @@ using FastDEQ, OrdinaryDiffEq, FluxMPI, Format, - Funtors, + Functors, Lux, MLDatasets, Optimisers, From 0eca98f7f1b3a17b9668c427769f00b0e486c694 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 May 2022 17:00:46 -0400 Subject: [PATCH 53/76] Update src --- .github/workflows/CI.yml | 4 ++- Project.toml | 2 +- README.md | 5 +++- src/FastDEQ.jl | 3 +- src/adjoint.jl | 63 +++++++++++++++++++++++++++++++++++----- src/layers/core.jl | 10 +++++-- src/layers/deq.jl | 16 +++++----- src/layers/mdeq.jl | 20 ++++++------- src/losses.jl | 43 --------------------------- src/utils.jl | 41 +++++++++++++++++++------- test/runtests.jl | 26 ++++++++--------- 11 files changed, 136 insertions(+), 97 deletions(-) delete mode 100644 src/losses.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4ae2716e..747b614b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -33,5 +33,7 @@ jobs: run: julia --project=. -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.instantiate()' - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v2 with: - coverage: false \ No newline at end of file + files: lcov.info \ No newline at end of file diff --git a/Project.toml b/Project.toml index 9f38cdf6..a0264923 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ DiffEqCallbacks = "2.20.1" DiffEqSensitivity = "6.64" Functors = "0.2" LinearSolve = "1" -Lux = "0.3" +Lux = "0.4" MLUtils = "0.2" OrdinaryDiffEq = "6" SciMLBase = "1.19" diff --git a/README.md b/README.md index 0e239780..11402f15 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,9 @@ # FastDEQ -![Dynamics Overview](assets/dynamics_overview.gif) [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://fastdeq.sciml.ai/dev/) [![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://fastdeq.sciml.ai/stable/) +[![codecov](https://codecov.io/gh/SciML/FastDEQ.jl/branch/main/graph/badge.svg?token=plksEh6pUG)](https://codecov.io/gh/SciML/FastDEQ.jl) +[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) + +Deep Equilibrium Networks using [Lux.jl](https://lux.csail.mit.edu/dev) and [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/) diff --git a/src/FastDEQ.jl b/src/FastDEQ.jl index f6849bd2..6fca4152 100644 --- a/src/FastDEQ.jl +++ b/src/FastDEQ.jl @@ -19,6 +19,7 @@ using ChainRulesCore, UnPack, Zygote +import DiffEqSensitivity: AbstractAdjointSensitivityAlgorithm import Lux: AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength import Random: AbstractRNG @@ -47,7 +48,7 @@ include("adjoint.jl") export ContinuousDEQSolver, DiscreteDEQSolver, BroydenSolver, LimitedMemoryBroydenSolver # Utils -export NormalInitializer, 
SteadyStateAdjoint, compute_deq_jacobian_loss, DeepEquilibriumSolution +export NormalInitializer, DeepEquilibriumAdjoint, compute_deq_jacobian_loss, DeepEquilibriumSolution export DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork, MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork, DEQChain diff --git a/src/adjoint.jl b/src/adjoint.jl index 5968ff17..75c0987b 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -2,7 +2,7 @@ neg(x::Any) = hasmethod(-, (typeof(x),)) ? -x : x neg(nt::NamedTuple) = fmap(neg, nt) @noinline function DiffEqSensitivity.SteadyStateAdjointProblem( - sol::EquilibriumSolution, sensealg::DiffEqSensitivity.SteadyStateAdjoint, g::Nothing, dg; save_idxs=nothing + sol::EquilibriumSolution, sensealg::DeepEquilibriumAdjoint, g::Nothing, dg; save_idxs=nothing ) @unpack f, p, u0 = sol.prob @@ -19,12 +19,19 @@ neg(nt::NamedTuple) = fmap(neg, nt) end end - # Solve the Linear Problem - _val, back = Zygote.pullback(x -> f(x, p, nothing), y) - s_val = size(_val) - op = ZygotePullbackMultiplyOperator{eltype(y),typeof(back),typeof(s_val)}(back, s_val) - linear_problem = LinearProblem(op, vec(diffcache.dg_val)) - λ = solve(linear_problem, sensealg.linsolve).u + if check_adjoint_mode(sensealg, Val(:vanilla)) + # Solve the Linear Problem + _val, back = Zygote.pullback(x -> f(x, p, nothing), y) + s_val = size(_val) + op = ZygotePullbackMultiplyOperator{eltype(y),typeof(back),typeof(s_val)}(back, s_val) + linear_problem = LinearProblem(op, vec(diffcache.dg_val)) + λ = solve(linear_problem, sensealg.linsolve).u + elseif check_adjoint_mode(sensealg, Val(:jfb)) + # Jacobian Free Backpropagation + λ = diffcache.dg_val + else + error("Unknown adjoint mode") + end # Compute the VJP _, back = Zygote.pullback(p -> vec(f(y, p, nothing)), p) @@ -32,3 +39,45 @@ neg(nt::NamedTuple) = fmap(neg, nt) return neg(dp) end + +function DiffEqBase._concrete_solve_adjoint( + prob::SteadyStateProblem, alg, sensealg::DeepEquilibriumAdjoint, u0, p, args...; save_idxs=nothing, kwargs... +) + _prob = remake(prob; u0=u0, p=p) + sol = solve(_prob, alg, args...; kwargs...) + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + + if save_idxs === nothing + out = sol + else + out = DiffEqBase.sensitivity_solution(sol, sol[_save_idxs]) + end + + function steadystatebackpass(Δ) + # Δ = dg/dx or diffcache.dg_val + # del g/del p = 0 + dp = adjoint_sensitivities(sol, alg; sensealg=sensealg, g=nothing, dg=Δ, save_idxs=save_idxs) + return ( + NoTangent(), + NoTangent(), + NoTangent(), + NoTangent(), + dp, + NoTangent(), + ntuple(_ -> NoTangent(), length(args))..., + ) + end + return out, steadystatebackpass +end + +function DiffEqSensitivity._adjoint_sensitivities( + sol, sensealg::DeepEquilibriumAdjoint, alg, g, dg=nothing; abstol=1e-6, reltol=1e-3, kwargs... +) + return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) +end + +function DiffEqSensitivity._adjoint_sensitivities( + sol, sensealg::DeepEquilibriumAdjoint, alg; g=nothing, dg=nothing, abstol=1e-6, reltol=1e-3, kwargs... +) + return DiffEqSensitivity.SteadyStateAdjointProblem(sol, sensealg, g, dg; kwargs...) 
+end diff --git a/src/layers/core.jl b/src/layers/core.jl index ac4a727b..fb94d9f5 100644 --- a/src/layers/core.jl +++ b/src/layers/core.jl @@ -1,17 +1,23 @@ abstract type AbstractDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,)} end function initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), fixed_depth=0) + return (model=initialstates(rng, deq.model), fixed_depth=Val(0)) end abstract type AbstractSkipDeepEquilibriumNetwork <: AbstractExplicitContainerLayer{(:model,:shortcut)} end function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork) return ( - model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), fixed_depth=0 + model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), fixed_depth=Val(0) ) end +check_unrolled_mode(::Val{0}) = false +check_unrolled_mode(::Val{d}) where {d} = d >= 1 +check_unrolled_mode(st::NamedTuple) = check_unrolled_mode(st.fixed_depth) +get_unrolled_depth(::Val{d}) where {d} = d +get_unrolled_depth(st::NamedTuple) = get_unrolled_depth(st.fixed_depth) + """ DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe) diff --git a/src/layers/deq.jl b/src/layers/deq.jl index e4a83081..6c3ade3f 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -6,7 +6,7 @@ struct DeepEquilibriumNetwork{J,M,A,S,K} <: AbstractDeepEquilibriumNetwork end function DeepEquilibriumNetwork( - model, solver; jacobian_regularization::Bool=false, sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), kwargs... + model, solver; jacobian_regularization::Bool=false, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs... ) return DeepEquilibriumNetwork{jacobian_regularization,typeof(model),typeof(solver),typeof(sensealg),typeof(kwargs)}( model, solver, sensealg, kwargs @@ -18,16 +18,16 @@ function (deq::DeepEquilibriumNetwork{J})( ) where {J,T} z = zero(x) - if !iszero(st.fixed_depth) + if check_unrolled_mode(st) # Pretraining without Fixed Point Solving st_ = st.model z_star = z - for _ in 1:(st.fixed_depth) + for _ in 1:get_unrolled_depth(st) z_star, st_ = deq.model((z_star, x), ps, st_) end @set! st.model = Lux.update_state(st_, :update_mask, true) - return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st + return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, get_unrolled_depth(st))), st end st_ = st.model @@ -62,7 +62,7 @@ function SkipDeepEquilibriumNetwork( shortcut, solver; jacobian_regularization::Bool=false, - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs..., ) return SkipDeepEquilibriumNetwork{ @@ -82,16 +82,16 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( end @set! st.shortcut = st__ - if !iszero(st.fixed_depth) + if check_unrolled_mode(st) # Pretraining without Fixed Point Solving st_ = st.model z_star = z - for _ in 1:(st.fixed_depth) + for _ in 1:get_unrolled_depth(st) z_star, st_ = deq.model((z_star, x), ps.model, st_) end @set! 
st.model = Lux.update_state(st_, :update_mask, true) - return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, st.fixed_depth)), st + return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, get_unrolled_depth(st))), st end st_ = st.model diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 1a5455e3..287abd07 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -7,7 +7,7 @@ struct MultiScaleDeepEquilibriumNetwork{N,L,M,A,S,K} <: AbstractDeepEquilibriumN end function initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), fixed_depth=0) + return (model=initialstates(rng, deq.model), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), fixed_depth=Val(0)) end function MultiScaleDeepEquilibriumNetwork( @@ -16,7 +16,7 @@ function MultiScaleDeepEquilibriumNetwork( post_fuse_layer::Union{Nothing,Tuple}, solver, scales; - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs..., ) l1 = Parallel(nothing, main_layers...) @@ -46,17 +46,17 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( ) where {N,T} z, st = get_initial_condition_mdeq(deq.scales, x, st) - if !iszero(st.fixed_depth) + if check_unrolled_mode(st) z_star = split_and_reshape(z, st.split_idxs, deq.scales) st_ = st.model - for _ in 1:(st.fixed_depth) + for _ in 1:get_unrolled_depth(st) z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps, st_) end @set! st.model = Lux.update_state(st_, :update_mask, true) - return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st + return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, get_unrolled_depth(st))), st end st_ = st.model @@ -96,7 +96,7 @@ function initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwo model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), - fixed_depth=0, + fixed_depth=Val(0), ) end @@ -107,7 +107,7 @@ function MultiScaleSkipDeepEquilibriumNetwork( shortcut_layers::Union{Nothing,Tuple}, solver, scales; - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs..., ) l1 = Parallel(nothing, main_layers...) @@ -141,17 +141,17 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( (vcat(flatten.(z0)...), st) end - if !iszero(st.fixed_depth) + if check_unrolled_mode(st) z_star = split_and_reshape(z, st.split_idxs, deq.scales) st_ = st.model - for _ in 1:(st.fixed_depth) + for _ in 1:get_unrolled_depth(st) z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) end @set! st.model = Lux.update_state(st_, :update_mask, true) - return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, st.fixed_depth)), st + return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, get_unrolled_depth(st))), st end st_ = st.model diff --git a/src/losses.jl b/src/losses.jl deleted file mode 100644 index a7b81277..00000000 --- a/src/losses.jl +++ /dev/null @@ -1,43 +0,0 @@ -""" - SupervisedLossContainer(loss_function) - SupervisedLossContainer(loss_function, λ, λⱼ) - -A container class for supervised loss functions. 
-""" -Base.@kwdef struct SupervisedLossContainer{L,T} - loss_function::L - λ::T = 0.0f0 - λⱼ::T = 0.0f0 -end - -function (lc::SupervisedLossContainer)(soln::DeepEquilibriumSolution) - return lc.λ * mean(abs, soln.u₀ .- soln.z_star) + lc.λⱼ * soln.jacobian_loss -end - -function (lc::SupervisedLossContainer)(soln::DeepEquilibriumSolution{T}) where {T<:Tuple} - return lc.λ * mapreduce((x, y) -> mean(abs, x .- y), +, soln.u₀, soln.z_star) + - lc.λⱼ * soln.jacobian_loss -end - -function (lc::SupervisedLossContainer)(model::Union{DeepEquilibriumNetwork,SkipDeepEquilibriumNetwork,DEQChain}, x, y; - kwargs...) - ŷ, soln = model(x; kwargs...) - return lc.loss_function(ŷ, y) + lc(soln) -end - -function (lc::SupervisedLossContainer)(model::Union{MultiScaleDeepEquilibriumNetwork, - MultiScaleSkipDeepEquilibriumNetwork}, x, ys::Tuple; kwargs...) - yŝ, soln = model(x; kwargs...) - return mapreduce(lc.loss_function, +, ys, yŝ) + lc(soln) -end - -function (lc::SupervisedLossContainer)(model::Union{MultiScaleDeepEquilibriumNetwork, - MultiScaleSkipDeepEquilibriumNetwork}, x, y; kwargs...) - yŝ, soln = model(x; kwargs...) - return sum(Base.Fix2(lc.loss_function, y), yŝ) + lc(soln) -end - -# Default fallback -function (lc::SupervisedLossContainer)(model, x, y; kwargs...) - return lc.loss_function(model(x; kwargs...), y) -end diff --git a/src/utils.jl b/src/utils.jl index 83336047..96405cda 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,9 +1,10 @@ # General DEQ Utils """ - SteadyStateAdjoint(reltol, abstol, maxiters; autojacvec=ZygoteVJP(), - linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters)) + DeepEquilibriumAdjoint(reltol, abstol, maxiters; autojacvec=ZygoteVJP(), + linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters), + mode=:vanilla) -Creates SteadyStateAdjoint ([johnson2012notes](@cite)) with sensible defaults. +Creates DeepEquilibriumAdjoint ([johnson2012notes](@cite)) with sensible defaults. ## Arguments @@ -12,10 +13,32 @@ Creates SteadyStateAdjoint ([johnson2012notes](@cite)) with sensible defaults. * `maxiters`: Maximum number of iterations. * `autojacvec`: Which backend to use for VJP. * `linsolve`: Linear Solver from [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl). +* `mode`: Adjoint mode. Currently only `:vanilla` & `:jfb` are supported. """ -function DiffEqSensitivity.SteadyStateAdjoint(reltol, abstol, maxiters; autojacvec=ZygoteVJP(), - linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters)) - return SteadyStateAdjoint(; autodiff=true, autojacvec=autojacvec, linsolve=linsolve) +struct DeepEquilibriumAdjoint{CS,AD,FDT,M,VJP,LS} <: AbstractAdjointSensitivityAlgorithm{CS,AD,FDT} + autojacvec::VJP + linsolve::LS +end + +@inline check_adjoint_mode(::DeepEquilibriumAdjoint{CS,AD,FDT,M}, ::Val{M}) where {CS,AD,FDT,M} = true +@inline check_adjoint_mode(::DeepEquilibriumAdjoint, ::Val) = false + +Base.@pure function DeepEquilibriumAdjoint( + reltol, + abstol, + maxiters; + autojacvec=ZygoteVJP(), + linsolve=KrylovJL_GMRES(; rtol=reltol, atol=abstol, itmax=maxiters), + autodiff=true, + chunk_size=0, + diff_type=Val{:central}, + mode::Symbol=:vanilla, +) + return DeepEquilibriumAdjoint{ + chunk_size,autodiff,diff_type,mode,typeof(autojacvec),typeof(linsolve) + }( + autojacvec, linsolve + ) end # Initialization @@ -25,15 +48,13 @@ end Initializes the weights of the network with a normal distribution. 
For DEQs the training is stable if we use this as the Initialization """ -function NormalInitializer(μ = 0.0f0, σ² = 0.01f0) +function NormalInitializer(μ=0.0f0, σ²=0.01f0) return (rng::AbstractRNG, dims...) -> randn(rng, Float32, dims...) .* σ² .+ μ end # For MultiScale DEQs function split_and_reshape(x::AbstractMatrix, idxs::Tuple, shapes::Tuple) - return Tuple( - @view(x[reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...), :]) for i in 1:(length(idxs) - 1) - ) + return Tuple(@view(x[reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...), :]) for i in 1:(length(idxs) - 1)) end # Zygote Fix diff --git a/test/runtests.jl b/test/runtests.jl index 8bfb7b8f..609bf56c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,7 +34,7 @@ end @test test_gradient_isfinite(gs) @info "Testing DEQ without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, Val(5)) gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] @@ -48,7 +48,7 @@ end Parallel(+, Dense(2, 2), Dense(2, 2)), Dense(2, 2), ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), ), ) ps, st = gpu.(Lux.setup(rng, model)) @@ -63,7 +63,7 @@ end @test test_gradient_isfinite(gs) @info "Testing SkipDEQ without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, Val(5)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -80,7 +80,7 @@ end Parallel(+, Dense(2, 2), Dense(2, 2)), nothing, ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), ), ) ps, st = gpu.(Lux.setup(rng, model)) @@ -95,7 +95,7 @@ end @test test_gradient_isfinite(gs) @info "Testing SkipDEQV2 without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, Val(5)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -112,7 +112,7 @@ end Parallel(+, Dense(2, 2), Dense(2, 2)), Dense(2, 2), DiscreteDEQSolver(BroydenSolver(); abstol_termination=0.1f0, reltol_termination=0.1f0); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), ), ) ps, st = gpu.(Lux.setup(rng, model)) @@ -134,7 +134,7 @@ end Parallel(+, Dense(2, 2), Dense(2, 2)), Dense(2, 2), DiscreteDEQSolver(LimitedMemoryBroydenSolver(); abstol_termination=0.1f0, reltol_termination=0.1f0); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), ), ) ps, st = gpu.(Lux.setup(rng, model)) @@ -166,7 +166,7 @@ end nothing, ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), ((4,), (3,), (2,), (1,)); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), ) ps, st = gpu.(Lux.setup(rng, model)) @@ -181,7 +181,7 @@ end @test test_gradient_isfinite(gs) @info "Testing MultiScaleDEQ without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, Val(5)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -209,7 +209,7 @@ end (Dense(4, 4, tanh), Dense(4, 3, tanh), Dense(4, 2, tanh), Dense(4, 1, tanh)), ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), ((4,), 
(3,), (2,), (1,)); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), ) ps, st = gpu.(Lux.setup(rng, model)) @@ -224,7 +224,7 @@ end @test test_gradient_isfinite(gs) @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, Val(5)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -252,7 +252,7 @@ end nothing, ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), ((4,), (3,), (2,), (1,)); - sensealg=SteadyStateAdjoint(0.1f0, 0.1f0, 10), + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), ) ps, st = gpu.(Lux.setup(rng, model)) @@ -267,7 +267,7 @@ end @test test_gradient_isfinite(gs) @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" - st = Lux.update_state(st, :fixed_depth, 5) + st = Lux.update_state(st, :fixed_depth, Val(5)) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) From 25cea0f3cf82ef9c44222507aeba2190ad388219 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 May 2022 17:31:01 -0400 Subject: [PATCH 54/76] Docs --- .github/workflows/Documentation.yml | 3 ++- docs/src/index.md | 17 +++-------------- src/layers/core.jl | 10 +++++----- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml index fe19c974..215bb2ac 100644 --- a/.github/workflows/Documentation.yml +++ b/.github/workflows/Documentation.yml @@ -16,7 +16,8 @@ jobs: with: version: '1' - name: Install dependencies - run: julia --project=docs -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/FluxExperimental.jl"); Pkg.add(url="https://github.com/SciML/DiffEqSensitivity.jl", rev="ap/fastdeq"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' + # FIXME: Remove once Lux.jl is registered + run: julia --project=. -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.instantiate()' - name: Build and deploy env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token diff --git a/docs/src/index.md b/docs/src/index.md index b45cc7c2..150f7276 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,14 +1,13 @@ # FastDEQ: (Fast) Deep Equlibrium Networks -FastDEQ.jl is a framework built on top of [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/) and [Flux.jl](https://fluxml.ai) enabling the efficient training and inference for Deep Equilibrium Networks (Infinitely Deep Neural Networks). +FastDEQ.jl is a framework built on top of [DifferentialEquations.jl](https://diffeq.sciml.ai/stable/) and [Lux.jl](https://lux.csail.mit.edu/dev/) enabling the efficient training and inference for Deep Equilibrium Networks (Infinitely Deep Neural Networks). ## Installation Currently the package is not registered and requires manually installing a few dependencies. We are working towards upstream fixes which will make installation easier ```julia -] add https://github.com/SciML/DiffEqSensitivity.jl.git#ap/fastdeq -] add https://github.com/avik-pal/FluxExperimental.jl.git#main +] add https://github.com/avik-pal/Lux.jl.git#main ] add https://github.com/SciML/FastDEQ.jl ``` @@ -27,14 +26,4 @@ If you are using this project for research or other academic purposes consider c } ``` -For specific algorithms, check the respective documentations and cite the corresponding papers. 
-
-## FAQs
-
-#### How do I reproduce the experiments in the paper
- *Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural ODEs (Continuous DEQs)*?
-
-Check out the `ap/paper` branch for the code corresponding to that paper.
-
-#### Are there some tutorials?
-
-We are working on adding some in the near future. In the meantime, please checkout the `experiments` directory in the `ap/paper` branch. You can also check `test/runtests.jl` for some simple examples.
\ No newline at end of file
+For specific algorithms, check the respective documentation and cite the corresponding papers.
\ No newline at end of file
diff --git a/src/layers/core.jl b/src/layers/core.jl
index fb94d9f5..47f5e927 100644
--- a/src/layers/core.jl
+++ b/src/layers/core.jl
@@ -12,11 +12,11 @@ function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork
     )
 end
 
-check_unrolled_mode(::Val{0}) = false
-check_unrolled_mode(::Val{d}) where {d} = d >= 1
-check_unrolled_mode(st::NamedTuple) = check_unrolled_mode(st.fixed_depth)
-get_unrolled_depth(::Val{d}) where {d} = d
-get_unrolled_depth(st::NamedTuple) = get_unrolled_depth(st.fixed_depth)
+@inline check_unrolled_mode(::Val{0}) = false
+@inline check_unrolled_mode(::Val{d}) where {d} = d >= 1
+@inline check_unrolled_mode(st::NamedTuple) = check_unrolled_mode(st.fixed_depth)
+@inline get_unrolled_depth(::Val{d}) where {d} = d
+@inline get_unrolled_depth(st::NamedTuple) = get_unrolled_depth(st.fixed_depth)
 
 """
     DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe)

From c0ed127341d4d538d52d8d886d233405c987aeb3 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 12 May 2022 17:41:58 -0400
Subject: [PATCH 55/76] Docs

---
 .github/workflows/Documentation.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml
index 215bb2ac..7f68b361 100644
--- a/.github/workflows/Documentation.yml
+++ b/.github/workflows/Documentation.yml
@@ -17,7 +17,7 @@ jobs:
         version: '1'
       - name: Install dependencies
         # FIXME: Remove once Lux.jl is registered
-        run: julia --project=. 
-e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.instantiate()' + run: julia --project=docs -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.instantiate()' - name: Build and deploy env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token From 7b716125eaf18ca535f92c846999753d33462831 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 May 2022 18:31:36 -0400 Subject: [PATCH 56/76] Docs --- .github/workflows/Documentation.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml index 7f68b361..456ea6a9 100644 --- a/.github/workflows/Documentation.yml +++ b/.github/workflows/Documentation.yml @@ -4,7 +4,7 @@ on: push: branches: - main - tags: '*' + tags: "*" pull_request: jobs: @@ -14,10 +14,10 @@ jobs: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: - version: '1' + version: "1" - name: Install dependencies # FIXME: Remove once Lux.jl is registered - run: julia --project=docs -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.instantiate()' + run: julia --project=docs -e 'using Pkg; Pkg.add(url="https://github.com/avik-pal/Lux.jl"); Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - name: Build and deploy env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # For authentication with GitHub Actions token From ea24bfca17ecafc97ebff32479e42354f2df5be3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 May 2022 20:13:16 -0400 Subject: [PATCH 57/76] make it val --- src/layers/deq.jl | 8 ++++---- src/layers/mdeq.jl | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/layers/deq.jl b/src/layers/deq.jl index 6c3ade3f..f2309676 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -25,7 +25,7 @@ function (deq::DeepEquilibriumNetwork{J})( for _ in 1:get_unrolled_depth(st) z_star, st_ = deq.model((z_star, x), ps, st_) end - @set! st.model = Lux.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, get_unrolled_depth(st))), st end @@ -44,7 +44,7 @@ function (deq::DeepEquilibriumNetwork{J})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps, st.model)[1] - @set! st.model = Lux.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end @@ -89,7 +89,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( for _ in 1:get_unrolled_depth(st) z_star, st_ = deq.model((z_star, x), ps.model, st_) end - @set! st.model = Lux.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, get_unrolled_depth(st))), st end @@ -108,7 +108,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) residual = z_star .- deq.model((z_star, x), ps.model, st.model)[1] - @set! st.model = Lux.update_state(st_, :update_mask, true) + @set! 
st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 287abd07..225297d6 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -54,7 +54,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps, st_) end - @set! st.model = Lux.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, get_unrolled_depth(st))), st end @@ -75,7 +75,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( residual = dudt(sol.u, ps, nothing) - @set! st.model = Lux.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return ( (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st @@ -149,7 +149,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) end - @set! st.model = Lux.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, get_unrolled_depth(st))), st end @@ -170,7 +170,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( residual = dudt(sol.u, ps.model, nothing) - @set! st.model = Lux.update_state(st_, :update_mask, true) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return ( (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st From afc8b7999e24616cb97599c187fe395b4390f476 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 May 2022 16:34:56 -0400 Subject: [PATCH 58/76] Type Inference fixes --- Project.toml | 1 + docs/make.jl | 12 +- docs/src/api/solvers.md | 15 -- docs/src/{api => manual}/deqs.md | 0 docs/src/{api => manual}/layers.md | 1 - docs/src/{api => manual}/misc.md | 0 docs/src/{api => manual}/nlsolve.md | 2 +- docs/src/manual/solvers.md | 38 ++++ src/FastDEQ.jl | 1 + src/adjoint.jl | 7 +- src/layers/chain.jl | 11 ++ src/layers/core.jl | 13 +- src/layers/deq.jl | 106 ++++++++++- src/layers/mdeq.jl | 264 +++++++++++++++++++++------- src/solvers/continuous.jl | 28 +-- src/solvers/discrete.jl | 20 +-- src/utils.jl | 43 ++--- test/runtests.jl | 28 +++ 18 files changed, 426 insertions(+), 164 deletions(-) delete mode 100644 docs/src/api/solvers.md rename docs/src/{api => manual}/deqs.md (100%) rename docs/src/{api => manual}/layers.md (64%) rename docs/src/{api => manual}/misc.md (100%) rename docs/src/{api => manual}/nlsolve.md (98%) create mode 100644 docs/src/manual/solvers.md diff --git a/Project.toml b/Project.toml index a0264923..0c102336 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff --git a/docs/make.jl b/docs/make.jl index 475c377e..7a372005 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -14,12 +14,12 @@ makedocs( 
canonical="https://fastdeq.sciml.ai/stable/"),
     pages = [
         "FastDEQ: Fast Deep Equilibrium Networks" => "index.md",
-        "API" => [
-            "Dynamical Systems" => "api/solvers.md",
-            "Non Linear Solvers" => "api/nlsolve.md",
-            "General Purpose Layers" => "api/layers.md",
-            "DEQ Layers" => "api/deqs.md",
-            "Miscellaneous" => "api/misc.md",
+        "Manual" => [
+            "Dynamical Systems" => "manual/solvers.md",
+            "Non Linear Solvers" => "manual/nlsolve.md",
+            "General Purpose Layers" => "manual/layers.md",
+            "DEQ Layers" => "manual/deqs.md",
+            "Miscellaneous" => "manual/misc.md",
         ],
         "References" => "references.md",
     ]
diff --git a/docs/src/api/solvers.md b/docs/src/api/solvers.md
deleted file mode 100644
index 13cd0d05..00000000
--- a/docs/src/api/solvers.md
+++ /dev/null
@@ -1,15 +0,0 @@
-# Dynamical System Variants
-
-[baideep2019](@cite) introduced Discrete Deep Equilibrium Models which drives a Discrete Dynamical System to its steady-state. [pal2022mixing](@cite) extends this framework to Continuous Dynamical Systems which converge to the steady-stable in a more stable fashion. For a detailed discussion refer to [pal2022mixing](@cite).
-
-## Continuous DEQs (Infinite Time Neural ODEs)
-
-```@docs
-ContinuousDEQSolver
-```
-
-## Discrete DEQs
-
-```@docs
-DiscreteDEQSolver
-```
\ No newline at end of file
diff --git a/docs/src/api/deqs.md b/docs/src/manual/deqs.md
similarity index 100%
rename from docs/src/api/deqs.md
rename to docs/src/manual/deqs.md
diff --git a/docs/src/api/layers.md b/docs/src/manual/layers.md
similarity index 64%
rename from docs/src/api/layers.md
rename to docs/src/manual/layers.md
index c9029ca9..cf15fbcf 100644
--- a/docs/src/api/layers.md
+++ b/docs/src/manual/layers.md
@@ -2,5 +2,4 @@
 
 ```@docs
 DEQChain
-MultiParallelNet
 ```
\ No newline at end of file
diff --git a/docs/src/api/misc.md b/docs/src/manual/misc.md
similarity index 100%
rename from docs/src/api/misc.md
rename to docs/src/manual/misc.md
diff --git a/docs/src/api/nlsolve.md b/docs/src/manual/nlsolve.md
similarity index 98%
rename from docs/src/api/nlsolve.md
rename to docs/src/manual/nlsolve.md
index 4140c1cf..28a41d0c 100644
--- a/docs/src/api/nlsolve.md
+++ b/docs/src/manual/nlsolve.md
@@ -9,4 +9,4 @@ We provide the following NonLinear Solvers for DEQs. These are compatible with G
 ```@docs
 BroydenSolver
 LimitedMemoryBroydenSolver
-```
\ No newline at end of file
+```
diff --git a/docs/src/manual/solvers.md b/docs/src/manual/solvers.md
new file mode 100644
index 00000000..ff0ee4c4
--- /dev/null
+++ b/docs/src/manual/solvers.md
@@ -0,0 +1,38 @@
+# Dynamical System Variants
+
+[baideep2019](@cite) introduced Discrete Deep Equilibrium Models, which drive a Discrete Dynamical System to its steady state. [pal2022mixing](@cite) extends this framework to Continuous Dynamical Systems, which converge to the steady state in a more stable fashion. For a detailed discussion refer to [pal2022mixing](@cite).
+
+## Continuous DEQs
+
+```@docs
+ContinuousDEQSolver
+```
+
+## Discrete DEQs
+
+```@docs
+DiscreteDEQSolver
+```
+
+## Termination Conditions
+
+#### Termination on Absolute Tolerance
+
+* `:abs`: Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)``
+* `:abs_norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol``
+* `:abs_deq_default`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges)
+* `:abs_deq_best`: Same as `:abs_deq_default` but uses the best solution found so far, i.e. 
deviates only if the solution has not converged
+
+#### Termination on Relative Tolerance
+
+* `:rel`: Terminates if ``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)``
+* `:rel_norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|``
+* `:rel_deq_default`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges)
+* `:rel_deq_best`: Same as `:rel_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged
+
+#### Termination using both Absolute and Relative Tolerances
+
+* `:norm`: Terminates if ``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|`` &
+  ``\| \frac{\partial u}{\partial t} \| \leq abstol``
+* `fallback`: Check whether all values of the derivative are close to zero with respect to both the relative and absolute tolerances. This is usable for small problems
+  but doesn't scale well for neural networks, and should be avoided unless absolutely necessary
\ No newline at end of file
diff --git a/src/FastDEQ.jl b/src/FastDEQ.jl
index 6fca4152..0b4a7a3e 100644
--- a/src/FastDEQ.jl
+++ b/src/FastDEQ.jl
@@ -14,6 +14,7 @@ using ChainRulesCore,
     OrdinaryDiffEq,
     SciMLBase,
     Setfield,
+    Static,
     Statistics,
     SteadyStateDiffEq,
     UnPack,
diff --git a/src/adjoint.jl b/src/adjoint.jl
index 75c0987b..2a1f9ac4 100644
--- a/src/adjoint.jl
+++ b/src/adjoint.jl
@@ -47,11 +47,7 @@ function DiffEqBase._concrete_solve_adjoint(
     sol = solve(_prob, alg, args...; kwargs...)
     _save_idxs = save_idxs === nothing ? Colon() : save_idxs
 
-    if save_idxs === nothing
-        out = sol
-    else
-        out = DiffEqBase.sensitivity_solution(sol, sol[_save_idxs])
-    end
+    out = save_idxs === nothing ? sol : DiffEqBase.sensitivity_solution(sol, sol[_save_idxs])
 
     function steadystatebackpass(Δ)
         # Δ = dg/dx or diffcache.dg_val
@@ -67,6 +63,7 @@ function DiffEqBase._concrete_solve_adjoint(
             ntuple(_ -> NoTangent(), length(args))...,
         )
     end
+
     return out, steadystatebackpass
 end
diff --git a/src/layers/chain.jl b/src/layers/chain.jl
index 9706e972..d1c7d1ff 100644
--- a/src/layers/chain.jl
+++ b/src/layers/chain.jl
@@ -1,3 +1,14 @@
+"""
+    DEQChain(layers...) 
+ +Sequence of layers divided into 3 chunks -- + +* `pre_deq` -- layers that are executed before DEQ is applied +* `deq` -- The Deep Equilibrium Layer +* `post_deq` -- layers that are executed after DEQ is applied + +Constraint: Must have one DEQ layer in `layers` +""" struct DEQChain{P1,D,P2} <: AbstractExplicitContainerLayer{(:pre_deq, :deq, :post_deq)} pre_deq::P1 deq::D diff --git a/src/layers/core.jl b/src/layers/core.jl index 47f5e927..13745da2 100644 --- a/src/layers/core.jl +++ b/src/layers/core.jl @@ -12,11 +12,14 @@ function initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork ) end -@inline check_unrolled_mode(::Val{0}) = false -@inline check_unrolled_mode(::Val{d}) where {d} = d >= 1 -@inline check_unrolled_mode(st::NamedTuple) = check_unrolled_mode(st.fixed_depth) -@inline get_unrolled_depth(::Val{d}) where {d} = d -@inline get_unrolled_depth(st::NamedTuple) = get_unrolled_depth(st.fixed_depth) +@inline check_unrolled_mode(::Val{0})::Bool = false +@inline check_unrolled_mode(::Val{d}) where {d} = (d >= 1)::Bool +@inline check_unrolled_mode(st::NamedTuple)::Bool = check_unrolled_mode(st.fixed_depth) +@inline get_unrolled_depth(::Val{d}) where {d} = d::Int +@inline get_unrolled_depth(st::NamedTuple)::Int = get_unrolled_depth(st.fixed_depth) + +ChainRulesCore.@non_differentiable check_unrolled_mode(::Any) +ChainRulesCore.@non_differentiable get_unrolled_depth(::Any) """ DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe) diff --git a/src/layers/deq.jl b/src/layers/deq.jl index f2309676..83c8449e 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -1,3 +1,36 @@ +""" + DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) + +Deep Equilibrium Network as proposed in [baideep2019](@cite) + +## Arguments + +* `model`: Neural Network +* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) +* `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) +* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) +* `kwargs`: Additional Parameters that are directly passed to `solve` + +## Example + +```julia +model = DeepEquilibriumNetwork( + Parallel( + +, + Dense(2, 2; bias=false), + Dense(2, 2; bias=false) + ), + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) +) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) + +model(rand(Float32, 2, 1), ps, st) +``` + +See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) +""" struct DeepEquilibriumNetwork{J,M,A,S,K} <: AbstractDeepEquilibriumNetwork model::M solver::A @@ -25,9 +58,11 @@ function (deq::DeepEquilibriumNetwork{J})( for _ in 1:get_unrolled_depth(st) z_star, st_ = deq.model((z_star, x), ps, st_) end - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) - return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, get_unrolled_depth(st))), st + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true))::typeof(st.model) + + return (z_star, DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st end st_ = st.model @@ -42,13 +77,66 @@ function (deq::DeepEquilibriumNetwork{J})( z_star, st_ = deq.model((sol.u, x), ps, st.model) jac_loss = (J ? 
compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) - residual = z_star .- deq.model((z_star, x), ps, st.model)[1] + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true))::typeof(st.model) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end + +""" + SkipDeepEquilibriumNetwork(model, shortcut, solver; jacobian_regularization::Bool=false, sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) + +Skip Deep Equilibrium Network as proposed in [pal2022mixing](@cite) + +## Arguments + +* `model`: Neural Network +* `shortcut`: Shortcut for the network (pass `nothing` for SkipDEQV2) +* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) +* `jacobian_regularization`: If true, Jacobian Loss is computed and stored in the [`DeepEquilibriumSolution`](@ref) +* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) +* `kwargs`: Additional Parameters that are directly passed to `solve` + +## Example + +```julia +# SkipDEQ +model = SkipDeepEquilibriumNetwork( + Parallel( + +, + Dense(2, 2; bias=false), + Dense(2, 2; bias=false) + ), + Dense(2, 2), + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) +) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) + +model(rand(Float32, 2, 1), ps, st) + +# SkipDEQV2 +model = SkipDeepEquilibriumNetwork( + Parallel( + +, + Dense(2, 2; bias=false), + Dense(2, 2; bias=false) + ), + nothing, + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0) +) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) + +model(rand(Float32, 2, 1), ps, st) +``` + +See also: [`DeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) +""" struct SkipDeepEquilibriumNetwork{J,M,Sh,A,S,K} <: AbstractSkipDeepEquilibriumNetwork model::M shortcut::Sh @@ -89,9 +177,11 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( for _ in 1:get_unrolled_depth(st) z_star, st_ = deq.model((z_star, x), ps.model, st_) end - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) - return (z_star, DeepEquilibriumSolution(z_star, z, z, 0.0f0, get_unrolled_depth(st))), st + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true))::typeof(st.model) + + return (z_star, DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st end st_ = st.model @@ -106,9 +196,9 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( z_star, st_ = deq.model((sol.u, x), ps.model, st.model) jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) - residual = z_star .- deq.model((z_star, x), ps.model, st.model)[1] + residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) + @set! 
st.model = Lux.update_state(st_, :update_mask, Val(true))::typeof(st.model) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 225297d6..84dfff38 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -1,13 +1,72 @@ -struct MultiScaleDeepEquilibriumNetwork{N,L,M,A,S,K} <: AbstractDeepEquilibriumNetwork +@generated function evaluate_unrolled_deq(model, z_star::NTuple{N}, x, ps, st, ::Val{depth}) where {N,depth} + calls = [] + for _ in 1:depth + push!(calls, :((z_star, st) = model(((z_star[1], x), z_star[2:($N)]...), ps, st))) + end + push!(calls, :(return z_star, st)) + return Expr(:block, calls...) +end + +""" + MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing,Tuple}, solver, scales; sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) + +Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) + +## Arguments + +* `main_layers`: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays, the other ones only take 1 input +* `mapping_layers`: Matrix of Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` and passes it to the ``j^{th}`` `main_layer` +* `post_fuse_layer`: Tuple of Neural Networks. Each of the scales are passed through this layer +* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) +* `scales`: Output scales +* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) +* `kwargs`: Additional Parameters that are directly passed to `solve` + +## Example + +```julia +model = MultiScaleDeepEquilibriumNetwork( + ( + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh) + ), + [ + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh); + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh); + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh); + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() + ], + nothing, + ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0), + ((4,), (3,), (2,), (1,)), +) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) +x = rand(rng, Float32, 4, 1) + +model(x, ps, st) +``` + +See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) +""" +struct MultiScaleDeepEquilibriumNetwork{N,Sc,M,A,S,K} <: AbstractDeepEquilibriumNetwork model::M solver::A sensealg::S - scales::NTuple{N,NTuple{L,Int64}} + scales::Sc kwargs::K end function initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwork) - return (model=initialstates(rng, deq.model), split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), fixed_depth=Val(0)) + return ( + model=initialstates(rng, deq.model), + split_idxs=static(Tuple(vcat(0, cumsum(prod.(deq.scales))...))), + fixed_depth=Val(0), + initial_condition=zeros(Float32, 1, 1), + ) end function MultiScaleDeepEquilibriumNetwork( @@ -15,31 +74,36 @@ function MultiScaleDeepEquilibriumNetwork( mapping_layers::Matrix, post_fuse_layer::Union{Nothing,Tuple}, solver, - scales; + scales::NTuple{N,NTuple{L,Int64}}; sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs..., -) +) where {N,L} l1 = Parallel(nothing, main_layers...) l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) 
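+    # Wiring note: `l1` applies each scale's main layer in parallel, and in `l2`
+    # the `j`-th branch computes `+(mapping_layers[1, j](y_1), ..., mapping_layers[N, j](y_N))`,
+    # i.e. scale `j` receives the summed contributions of every scale.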
- model = if post_fuse_layer === nothing - Chain(l1, l2) - else - l3 = Parallel(nothing, post_fuse_layer...) - Chain(l1, l2, l3) - end - return MultiScaleDeepEquilibriumNetwork(model, solver, sensealg, scales, kwargs) + model = post_fuse_layer === nothing ? Chain(l1, l2) : Chain(l1, l2, Parallel(nothing, post_fuse_layer...)) + scales = static(scales) + return MultiScaleDeepEquilibriumNetwork{ + N,typeof(scales),typeof(model),typeof(solver),typeof(sensealg),typeof(kwargs) + }( + model, solver, sensealg, scales, kwargs + ) end -function get_initial_condition_mdeq(scales::NTuple, x::AbstractArray{T,N}, st::NamedTuple{fields}) where {T,N,fields} - if hasproperty(st, :initial_condition) && size(st.initial_condition, 2) == size(x, N) - return st.initial_condition, st +@generated function get_initial_condition_mdeq(::S, x::AbstractArray{T,N}, st::NamedTuple{fields}) where {S,T,N,fields} + scales = known(S) + sz = sum(prod.(scales)) + calls = [] + if :initial_condition ∈ fields + push!(calls, :(u0 = st[:initial_condition])) + push!(calls, :(($sz, size(x, $N)) == size(u0) && return u0, st)) end - u0 = vcat(map(scale -> fill!(similar(x, prod(scale), size(x, N)), T(0)), scales)...) - st = merge((initial_condition=u0,), st) - return u0, st + push!(calls, :(u0 = fill!(similar(x, $(sz), size(x, N)), $(T(0))))) + push!(calls, :(st = merge(st, (initial_condition=u0,))::typeof(st))) + push!(calls, :(return u0, st)) + return Expr(:block, calls...) end -Zygote.@nograd get_initial_condition_mdeq +ChainRulesCore.@non_differentiable get_initial_condition_mdeq(::Any...) function (deq::MultiScaleDeepEquilibriumNetwork{N})( x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple @@ -48,15 +112,18 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( if check_unrolled_mode(st) z_star = split_and_reshape(z, st.split_idxs, deq.scales) - st_ = st.model + z_star, st_ = evaluate_unrolled_deq(deq.model, z_star, x, ps, st.model, st.fixed_depth) - for _ in 1:get_unrolled_depth(st) - z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps, st_) - end + residual = ignore_derivatives( + vcat(flatten.(z_star)...) .- + vcat(flatten.(evaluate_unrolled_deq(deq.model, z_star, x, ps, st_, Val(1))[1])...), + ) + st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),))::typeof(st) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) - - return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, get_unrolled_depth(st))), st + return ( + (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, get_unrolled_depth(st))), + st__, + ) end st_ = st.model @@ -73,21 +140,94 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) z_star, st_ = dudt_(sol.u, ps, nothing) - residual = dudt(sol.u, ps, nothing) + residual = ignore_derivatives(dudt(sol.u, ps, nothing)) + st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),))::typeof(st) - @set! 
st.model = Lux.update_state(st_, :update_mask, Val(true)) + return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st__) +end - return ( - (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st +""" + MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing,Tuple}, shortcut_layers::Union{Nothing,Tuple}, solver, scales; sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), kwargs...) + +Multiscale Deep Equilibrium Network as proposed in [baimultiscale2020](@cite) combined with Skip Deep Equilibrium Network as proposed in [pal2022mixing](@cite) + +## Arguments + +* `main_layers`: Tuple of Neural Networks. The first network needs to take a tuple of 2 arrays, the other ones only take 1 input +* `mapping_layers`: Matrix of Neural Networks. The ``(i, j)^{th}`` network takes the output of ``i^{th}`` `main_layer` and passes it to the ``j^{th}`` `main_layer` +* `post_fuse_layer`: Tuple of Neural Networks. Each of the scales are passed through this layer +* `shortcut_layers`: Shortcut for the network (pass `nothing` for SkipDEQV2) +* `solver`: Solver for the optimization problem (See: [`ContinuousDEQSolver`](@ref) & [`DiscreteDEQSolver`](@ref)) +* `scales`: Output scales +* `sensealg`: See [`DeepEquilibriumAdjoint`](@ref) +* `kwargs`: Additional Parameters that are directly passed to `solve` + +## Example + +```julia +# MSkipDEQ +model = MultiScaleSkipDeepEquilibriumNetwork( + ( + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), + ), + [ + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() + ], + nothing, + (Dense(4, 4, tanh), Dense(4, 3, tanh), Dense(4, 2, tanh), Dense(4, 1, tanh)), + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), +) + +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) +x = rand(rng, Float32, 4, 2) + +model(x, ps, st) + +# MSkipDEQV2 +model = MultiScaleSkipDeepEquilibriumNetwork( + ( + Parallel(+, Dense(4, 4, tanh), Dense(4, 4, tanh)), + Dense(3, 3, tanh), + Dense(2, 2, tanh), + Dense(1, 1, tanh), + ), + [ + NoOpLayer() Dense(4, 3, tanh) Dense(4, 2, tanh) Dense(4, 1, tanh) + Dense(3, 4, tanh) NoOpLayer() Dense(3, 2, tanh) Dense(3, 1, tanh) + Dense(2, 4, tanh) Dense(2, 3, tanh) NoOpLayer() Dense(2, 1, tanh) + Dense(1, 4, tanh) Dense(1, 3, tanh) Dense(1, 2, tanh) NoOpLayer() + ], + nothing, + nothing, + ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0, reltol_termination=0.1f0), + ((4,), (3,), (2,), (1,)); + sensealg=DeepEquilibriumAdjoint(0.1f0, 0.1f0, 10), ) -end -struct MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh,A,S,K} <: AbstractSkipDeepEquilibriumNetwork +rng = Random.default_rng() +ps, st = Lux.setup(rng, model) +x = rand(rng, Float32, 4, 2) + +model(x, ps, st) +``` + +See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref) +""" +struct MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh,A,S,K} <: AbstractSkipDeepEquilibriumNetwork model::M shortcut::Sh solver::A sensealg::S - scales::NTuple{N,NTuple{L,Int64}} 
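+    # `Sc` holds the scales as a Static.jl compile-time constant (see the
+    # `static(scales)` call in the constructor below), so `@generated` helpers
+    # such as `split_and_reshape` can unroll over shapes known at compile time.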
+ scales::Sc kwargs::K end @@ -95,8 +235,9 @@ function initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwo return ( model=initialstates(rng, deq.model), shortcut=initialstates(rng, deq.shortcut), - split_idxs=Tuple(vcat(0, cumsum(prod.(deq.scales))...)), + split_idxs=static(Tuple(vcat(0, cumsum(prod.(deq.scales))...))), fixed_depth=Val(0), + initial_condition=zeros(Float32, 1, 1), ) end @@ -112,46 +253,43 @@ function MultiScaleSkipDeepEquilibriumNetwork( ) l1 = Parallel(nothing, main_layers...) l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) - model = if post_fuse_layer === nothing - Chain(l1, l2) - else - l3 = Parallel(nothing, post_fuse_layer...) - Chain(l1, l2, l3) - end - shortcut = if shortcut_layers === nothing - nothing - else - Parallel(nothing, shortcut_layers...) - end - return MultiScaleSkipDeepEquilibriumNetwork(model, shortcut, solver, sensealg, scales, kwargs) + model = post_fuse_layer === nothing ? Chain(l1, l2) : Chain(l1, l2, Parallel(nothing, post_fuse_layer...)) + shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...) + scales = static(scales) + return MultiScaleSkipDeepEquilibriumNetwork{ + length(scales),typeof(scales),typeof(model),typeof(shortcut),typeof(solver),typeof(sensealg),typeof(kwargs) + }( + model, shortcut, solver, sensealg, scales, kwargs + ) end -function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( +function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh})( x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple -) where {N,L,M,Sh,T} +) where {N,Sc,M,Sh,T} z, st = if Sh == Nothing u0, st_ = get_initial_condition_mdeq(deq.scales, x, st) u0_ = split_and_reshape(u0, st.split_idxs, deq.scales) z0, st__ = deq.model(((u0_[1], x), u0_[2:N]...), ps.model, st_.model) - @set! st_.model = st__ - (vcat(flatten.(z0)...), st_) + (vcat(flatten.(z0)...), merge(st_, (model=st__,))::typeof(st)) else z0, st_ = deq.shortcut(x, ps.shortcut, st.shortcut) - @set! st.shortcut = st_ - (vcat(flatten.(z0)...), st) + (vcat(flatten.(z0)...), merge(st, (shortcut=st_,))::typeof(st)) end if check_unrolled_mode(st) z_star = split_and_reshape(z, st.split_idxs, deq.scales) - st_ = st.model - - for _ in 1:get_unrolled_depth(st) - z_star, st_ = deq.model(((z_star[1], x), z_star[2:N]...), ps.model, st_) - end + z_star, st_ = evaluate_unrolled_deq(deq.model, z_star, x, ps.model, st.model, st.fixed_depth) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) + residual = ignore_derivatives( + vcat(flatten.(z_star)...) .- + vcat(flatten.(evaluate_unrolled_deq(deq.model, z_star, x, ps.model, st_, Val(1))[1])...), + ) + st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),))::typeof(st) - return (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, z, 0.0f0, get_unrolled_depth(st))), st + return ( + (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, get_unrolled_depth(st))), + st__, + ) end st_ = st.model @@ -168,11 +306,9 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,L,M,Sh})( sol = solve(prob, deq.solver; sensealg=deq.sensealg, deq.kwargs...) z_star, st_ = dudt_(sol.u, ps.model, nothing) - residual = dudt(sol.u, ps.model, nothing) + residual = ignore_derivatives(dudt(sol.u, ps.model, nothing)) - @set! 
st.model = Lux.update_state(st_, :update_mask, Val(true)) + st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),))::typeof(st) - return ( - (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st - ) + return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st) end diff --git a/src/solvers/continuous.jl b/src/solvers/continuous.jl index 9119c971..1803b8e7 100644 --- a/src/solvers/continuous.jl +++ b/src/solvers/continuous.jl @@ -6,7 +6,7 @@ for solving DEQ problems. ## Arguments -* `alg`: Algorithm to solve the ODEProblem. (Default: `VCABM4()`) +* `alg`: Algorithm to solve the ODEProblem. (Default: `VCABM3()`) * `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) * `abstol`: Absolute tolerance for time stepping. (Default: `1f-8`) * `reltol`: Relative tolerance for time stepping. (Default: `1f-8`) @@ -14,33 +14,7 @@ for solving DEQ problems. * `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`) * `tspan`: Time span. Users should not change this value, instead control termination through `maxiters` in `solve` (Default: `Inf32`) -## Termination Modes - -#### Termination on Absolute Tolerance - -* `:abs`: Terminates if ``all \\left( | \\frac{\\partial u}{\\partial t} | \\leq abstol \\right)`` -* `:abs_norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq abstol`` -* `:abs_deq_default`: Essentially `abs_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) -* `:abs_deq_best`: Same as `:abs_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged - -#### Termination on Relative Tolerance - -* `:rel`: Terminates if ``all \\left(| \\frac{\\partial u}{\\partial t} | \\leq reltol \\times | u | \\right)`` -* `:rel_norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq reltol \\times \\| \\frac{\\partial u}{\\partial t} + u \\|`` -* `:rel_deq_default`: Essentially `rel_norm` + terminate if there has been no improvement for the last 30 steps + terminate if the solution blows up (diverges) -* `:rel_deq_best`: Same as `:rel_deq_default` but uses the best solution found so far, i.e. deviates only if the solution has not converged - -#### Termination using both Absolute and Relative Tolerances - -* `:norm`: Terminates if ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq reltol \\times \\| \\frac{\\partial u}{\\partial t} + u \\|`` & - ``\\| \\frac{\\partial u}{\\partial t} \\| \\leq abstol`` -* `fallback`: Check if all values of the derivative is close to zero wrt both relative and absolute tolerance. This is usable for small problems - but doesn't scale well for neural networks, and should be avoided unless absolutely necessary - See also: [`DiscreteDEQSolver`](@ref) - -!!! note - This will be upstreamed to DiffEqSensitivity in the later releases of the package """ struct ContinuousDEQSolver{M,A,T,TS} <: SteadyStateDiffEq.SteadyStateDiffEqAlgorithm alg::A diff --git a/src/solvers/discrete.jl b/src/solvers/discrete.jl index aecf447f..05cd9b50 100644 --- a/src/solvers/discrete.jl +++ b/src/solvers/discrete.jl @@ -1,20 +1,16 @@ # Wrapper for Discrete DEQs """ - DiscreteDEQSolver(solver=LimitedMemoryBroydenSolver; abstol=1e-8, reltol=1e-8, kwargs...) - -Solver for Discrete DEQ Problem ([baideep2019](@cite)). 
A wrapper around `SSrootfind` to mimic the [`ContinuousDEQSolver`](@ref) API. + DiscreteDEQSolver(alg=LimitedMemoryBroydenSolver(); mode::Symbol=:rel_deq_default, abstol_termination::T=1.0f-8, reltol_termination::T=1.0f-8) +Solver for Discrete DEQ Problem ([baideep2019](@cite)). Similar to `SSrootfind` but provides more flexibility needed + for solving DEQ problems. + ## Arguments -* `solver`: NonLinear Solver for the DEQ problem. (Default: [`LimitedMemoryBroydenSolver`](@ref)) -* `abstol`: Absolute tolerance for termination. (Default: `1e-8`) -* `reltol`: Relative tolerance for termination. (Default: `1e-8`) -* `kwargs`: Additional keyword arguments passed to the solver. - -!!! note - There is no `mode` kwarg for [`DiscreteDEQSolver`](@ref). Instead solvers directly define their own termination condition. - For [`BroydenSolver`](@ref) and [`LimitedMemoryBroydenSolver`](@ref), the termination conditions are `:abs_norm` & - `:rel_deq_default` respectively. +* `alg`: Algorithm to solve the Nonlinear Problem (Default: [`LimitedMemoryBroydenSolver`](@ref)) +* `mode`: Termination Mode of the solver. See below for a description of the various termination conditions (Default: `:rel_deq_default`) +* `abstol_termination`: Absolute tolerance for termination. (Default: `1f-8`) +* `reltol_termination`: Relative tolerance for termination. (Default: `1f-8`) See also: [`ContinuousDEQSolver`](@ref) """ diff --git a/src/utils.jl b/src/utils.jl index 96405cda..f0b1b23b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -53,22 +53,30 @@ function NormalInitializer(μ=0.0f0, σ²=0.01f0) end # For MultiScale DEQs -function split_and_reshape(x::AbstractMatrix, idxs::Tuple, shapes::Tuple) - return Tuple(@view(x[reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...), :]) for i in 1:(length(idxs) - 1)) +@generated function split_and_reshape(x::AbstractMatrix, ::T, ::S) where {T,S} + idxs, shapes = known(T), known(S) + dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) for i in 1:(length(idxs) - 1)] + varnames = [gensym("x_view") for _ in dims] + calls = [] + for (i, dim) in enumerate(dims) + push!(calls, :($(varnames[i]) = view(x, $dim, :))) + end + push!(calls, :(return tuple($(Tuple(varnames)...)))) + return Expr(:block, calls...) end # Zygote Fix -function Zygote.accum(x::NTuple{N,T}, y::AbstractVector{T}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) -end +# function Zygote.accum(x::NTuple{N,T}, y::AbstractVector{T}) where {N,T<:AbstractArray} +# return Zygote.accum.(x, y) +# end -function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,T}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) -end +# function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,T}) where {N,T<:AbstractArray} +# return Zygote.accum.(x, y) +# end -function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,Nothing}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) -end +# function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,Nothing}) where {N,T<:AbstractArray} +# return Zygote.accum.(x, y) +# end # General Utils @inline function _init_identity_matrix(x::AbstractArray{T}, scale::T=T(1)) where {T} @@ -78,17 +86,12 @@ end @inline function _init_identity_matrix!(x::AbstractMatrix{T}, scale::T=T(1)) where {T} x .= zero(T) - idxs = diagind(x) - @. 
@view(x[idxs]) = scale * true + view(x, diagind(x)) .= scale .* true return x end -@inline function _norm(x; dims=Colon()) - return sqrt.(sum(abs2, x; dims=dims)) -end +@inline _norm(x; dims=Colon()) = sqrt.(sum(abs2, x; dims=dims)) # Compute norm over all dimensions except `except_dim` -@inline function _norm(x::AbstractArray{T,N}, except_dim) where {T,N} - dims = filter(i -> i != except_dim, 1:N) - return _norm(x; dims=dims) -end +@inline _norm(x::AbstractArray{T,N}, except_dim) where {T,N} = + _norm(x; dims=filter(i -> i != except_dim, 1:N)) diff --git a/test/runtests.jl b/test/runtests.jl index 609bf56c..9145b832 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,6 +29,8 @@ end x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) + @inferred model(x, ps, st) + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] @test test_gradient_isfinite(gs) @@ -36,6 +38,8 @@ end @info "Testing DEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) + @inferred model(x, ps, st) + gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] @test test_gradient_isfinite(gs) @@ -54,6 +58,8 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -64,6 +70,8 @@ end @info "Testing SkipDEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -86,6 +94,8 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -96,6 +106,8 @@ end @info "Testing SkipDEQV2 without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -118,6 +130,8 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -140,6 +154,8 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -172,6 +188,8 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -182,6 +200,8 @@ end @info "Testing MultiScaleDEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -215,6 +235,8 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -225,6 +247,8 @@ end @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -258,6 +282,8 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) 
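+    # `Test.@inferred` below throws if the compiler cannot concretely infer the
+    # return type of the call, guarding the type-stability fixes in this patch.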
+ + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) @@ -268,6 +294,8 @@ end @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) + + @inferred model(x, ps, st) gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) From 34517ea39e86b45ae63e4cd90aaf356067be44a5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 May 2022 22:57:53 -0400 Subject: [PATCH 59/76] Cant typeassert for state --- docs/src/manual/misc.md | 5 +---- src/layers/deq.jl | 8 ++++---- src/layers/mdeq.jl | 9 +++++---- src/utils.jl | 13 ------------- 4 files changed, 10 insertions(+), 25 deletions(-) diff --git a/docs/src/manual/misc.md b/docs/src/manual/misc.md index 00b93abc..8a3363ed 100644 --- a/docs/src/manual/misc.md +++ b/docs/src/manual/misc.md @@ -1,10 +1,7 @@ # Miscellaneous ```@docs -SteadyStateAdjoint +DeepEquilibriumAdjoint DeepEquilibriumSolution -get_and_clear_nfe! -compute_deq_jacobian_loss NormalInitializer -SupervisedLossContainer ``` \ No newline at end of file diff --git a/src/layers/deq.jl b/src/layers/deq.jl index 83c8449e..33192299 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -60,7 +60,7 @@ function (deq::DeepEquilibriumNetwork{J})( end residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true))::typeof(st.model) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st end @@ -79,7 +79,7 @@ function (deq::DeepEquilibriumNetwork{J})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true))::typeof(st.model) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end @@ -179,7 +179,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( end residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true))::typeof(st.model) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st end @@ -198,7 +198,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true))::typeof(st.model) + @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 84dfff38..3db7d830 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -118,7 +118,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( vcat(flatten.(z_star)...) 
.- vcat(flatten.(evaluate_unrolled_deq(deq.model, z_star, x, ps, st_, Val(1))[1])...), ) - st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),))::typeof(st) + st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) return ( (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, get_unrolled_depth(st))), @@ -141,7 +141,8 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( z_star, st_ = dudt_(sol.u, ps, nothing) residual = ignore_derivatives(dudt(sol.u, ps, nothing)) - st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),))::typeof(st) + + st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st__) end @@ -284,7 +285,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh})( vcat(flatten.(z_star)...) .- vcat(flatten.(evaluate_unrolled_deq(deq.model, z_star, x, ps.model, st_, Val(1))[1])...), ) - st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),))::typeof(st) + st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) return ( (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, get_unrolled_depth(st))), @@ -308,7 +309,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh})( residual = ignore_derivatives(dudt(sol.u, ps.model, nothing)) - st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),))::typeof(st) + st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st) end diff --git a/src/utils.jl b/src/utils.jl index f0b1b23b..70bacf40 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -65,19 +65,6 @@ end return Expr(:block, calls...) 
end -# Zygote Fix -# function Zygote.accum(x::NTuple{N,T}, y::AbstractVector{T}) where {N,T<:AbstractArray} -# return Zygote.accum.(x, y) -# end - -# function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,T}) where {N,T<:AbstractArray} -# return Zygote.accum.(x, y) -# end - -# function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,Nothing}) where {N,T<:AbstractArray} -# return Zygote.accum.(x, y) -# end - # General Utils @inline function _init_identity_matrix(x::AbstractArray{T}, scale::T=T(1)) where {T} x_ = vec(x) From fdc460bc2ce50939eae1364a6bfb63bc0ca91fd8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 14 May 2022 09:31:10 -0400 Subject: [PATCH 60/76] Remove typeasserts --- src/layers/mdeq.jl | 14 +++++++------- src/utils.jl | 13 +++++++++++++ test/runtests.jl | 28 ---------------------------- 3 files changed, 20 insertions(+), 35 deletions(-) diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 3db7d830..78fe34ec 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -1,4 +1,4 @@ -@generated function evaluate_unrolled_deq(model, z_star::NTuple{N}, x, ps, st, ::Val{depth}) where {N,depth} +@generated function evaluate_unrolled_mdeq(model, z_star::NTuple{N}, x, ps, st, ::Val{depth}) where {N,depth} calls = [] for _ in 1:depth push!(calls, :((z_star, st) = model(((z_star[1], x), z_star[2:($N)]...), ps, st))) @@ -112,11 +112,11 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( if check_unrolled_mode(st) z_star = split_and_reshape(z, st.split_idxs, deq.scales) - z_star, st_ = evaluate_unrolled_deq(deq.model, z_star, x, ps, st.model, st.fixed_depth) + z_star, st_ = evaluate_unrolled_mdeq(deq.model, z_star, x, ps, st.model, st.fixed_depth) residual = ignore_derivatives( vcat(flatten.(z_star)...) .- - vcat(flatten.(evaluate_unrolled_deq(deq.model, z_star, x, ps, st_, Val(1))[1])...), + vcat(flatten.(evaluate_unrolled_mdeq(deq.model, z_star, x, ps, st_, Val(1))[1])...), ) st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) @@ -271,19 +271,19 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh})( u0, st_ = get_initial_condition_mdeq(deq.scales, x, st) u0_ = split_and_reshape(u0, st.split_idxs, deq.scales) z0, st__ = deq.model(((u0_[1], x), u0_[2:N]...), ps.model, st_.model) - (vcat(flatten.(z0)...), merge(st_, (model=st__,))::typeof(st)) + (vcat(flatten.(z0)...), merge(st_, (model=st__,))) else z0, st_ = deq.shortcut(x, ps.shortcut, st.shortcut) - (vcat(flatten.(z0)...), merge(st, (shortcut=st_,))::typeof(st)) + (vcat(flatten.(z0)...), merge(st, (shortcut=st_,))) end if check_unrolled_mode(st) z_star = split_and_reshape(z, st.split_idxs, deq.scales) - z_star, st_ = evaluate_unrolled_deq(deq.model, z_star, x, ps.model, st.model, st.fixed_depth) + z_star, st_ = evaluate_unrolled_mdeq(deq.model, z_star, x, ps.model, st.model, st.fixed_depth) residual = ignore_derivatives( vcat(flatten.(z_star)...) .- - vcat(flatten.(evaluate_unrolled_deq(deq.model, z_star, x, ps.model, st_, Val(1))[1])...), + vcat(flatten.(evaluate_unrolled_mdeq(deq.model, z_star, x, ps.model, st_, Val(1))[1])...), ) st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) diff --git a/src/utils.jl b/src/utils.jl index 70bacf40..09d19b23 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -52,6 +52,19 @@ function NormalInitializer(μ=0.0f0, σ²=0.01f0) return (rng::AbstractRNG, dims...) -> randn(rng, Float32, dims...) 
.* σ² .+ μ end +# Zygote Fix +function Zygote.accum(x::NTuple{N,T}, y::AbstractVector{T}) where {N,T<:AbstractArray} + return Zygote.accum.(x, y) +end + +function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,T}) where {N,T<:AbstractArray} + return Zygote.accum.(x, y) +end + +function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,Nothing}) where {N,T<:AbstractArray} + return Zygote.accum.(x, y) +end + # For MultiScale DEQs @generated function split_and_reshape(x::AbstractMatrix, ::T, ::S) where {T,S} idxs, shapes = known(T), known(S) diff --git a/test/runtests.jl b/test/runtests.jl index 9145b832..b5d4a6b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,8 +29,6 @@ end x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) - @inferred model(x, ps, st) - gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] @test test_gradient_isfinite(gs) @@ -38,8 +36,6 @@ end @info "Testing DEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - @inferred model(x, ps, st) - gs = gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1] @test test_gradient_isfinite(gs) @@ -59,8 +55,6 @@ end x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -71,8 +65,6 @@ end @info "Testing SkipDEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -95,8 +87,6 @@ end x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -107,8 +97,6 @@ end @info "Testing SkipDEQV2 without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -131,8 +119,6 @@ end x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -155,8 +141,6 @@ end x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -189,8 +173,6 @@ end x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) @@ -201,8 +183,6 @@ end @info "Testing MultiScaleDEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) @@ -236,8 +216,6 @@ end x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) 
- @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -248,8 +226,6 @@ end @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -283,8 +259,6 @@ end x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -295,8 +269,6 @@ end @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - @inferred model(x, ps, st) - gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) From be6f72d50c8da74b95637ff331a9b5ba61a37faa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 16 May 2022 14:46:43 -0400 Subject: [PATCH 61/76] Update script --- examples/Project.toml | 13 +- examples/cifar10/main.jl | 377 +++++++++++++++++++++++ examples/cifar10/options.jl | 73 +++++ examples/cifar10/script.jl | 79 ----- examples/imagenet/main.jl | 0 examples/src/FastDEQExperiments.jl | 76 ++--- examples/src/config.jl | 477 ++++++++++++----------------- examples/src/dataloaders.jl | 34 -- examples/src/logging.jl | 166 +++------- examples/src/models.jl | 117 +++++-- examples/src/train.jl | 191 ------------ examples/src/utils.jl | 58 ++++ src/FastDEQ.jl | 4 - src/layers/deq.jl | 17 +- src/layers/mdeq.jl | 8 +- src/utils.jl | 13 - test/runtests.jl | 10 +- 17 files changed, 899 insertions(+), 814 deletions(-) create mode 100644 examples/cifar10/main.jl create mode 100644 examples/cifar10/options.jl delete mode 100644 examples/cifar10/script.jl create mode 100644 examples/imagenet/main.jl delete mode 100644 examples/src/dataloaders.jl delete mode 100644 examples/src/train.jl create mode 100644 examples/src/utils.jl diff --git a/examples/Project.toml b/examples/Project.toml index 8221432a..6d5dcd50 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -4,25 +4,36 @@ authors = ["Avik Pal "] version = "0.1.0" [deps] +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" +Augmentor = "02898b10-1f73-11ea-317c-6393d7073e15" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e" DataLoaders = "2e981812-ef13-4a9c-bfa0-ab13047b12a9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" FastDEQ = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" FluxMPI = "acf642fa-ee0e-513a-b9d5-bcd150f7ec3b" Format = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8" +Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b" +MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrdinaryDiffEq = 
"1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] @@ -31,7 +42,7 @@ DataLoaders = "0.1" Flux = "0.13" FluxMPI = "0.4" Format = "1.3" -Lux = "0.3" +Lux = "0.4" MLDatasets = "0.5" MLUtils = "0.2" MPI = "0.19" diff --git a/examples/cifar10/main.jl b/examples/cifar10/main.jl new file mode 100644 index 00000000..60307673 --- /dev/null +++ b/examples/cifar10/main.jl @@ -0,0 +1,377 @@ +# Adapted from https://github.com/avik-pal/Lux.jl/tree/main/examples/ImageNet/main.jl + +using ArgParse # Parse Arguments from Commandline +using DataAugmentation # Image Augmentation +using CUDA # GPUs <3 +using DataLoaders # Pytorch like DataLoaders +using Dates # Printing current time +using FastDEQ # Deep Equilibrium Model +using FastDEQExperiments # Models built using FastDEQ +using FluxMPI # Distibuted Training +using Formatting # Pretty Printing +using Functors # Parameter Manipulation +using Images # Image Processing +using LinearAlgebra # Linear Algebra +using Lux # Neural Network Framework +using MLDataPattern # Data Pattern +using MLDatasets # CIFAR10 +using MLDataUtils # Shuffling and Splitting Data +using MLUtils # Data Processing +using NNlib # Neural Network Backend +using OneHotArrays # One Hot Encoding +using Optimisers # Collection of Gradient Based Optimisers +using ParameterSchedulers # Collection of Schedulers for Parameter Updates +using Random # Make things less Random +using Serialization # Serialize Models +using Setfield # Easy Parameter Manipulation +using Statistics # Statistics +using ValueHistories # Storing Value Histories +using Zygote # Our AD Engine + +# Distributed Training +FluxMPI.Init(; verbose=true) +CUDA.allowscalar(false) + +# Training Options +include("options.jl") + +function get_experiment_config(args) + return get_experiment_configuration( + Val(:CIFAR10), + Val(Symbol(args["model-size"])); + model_type=Symbol(args["model-type"]), + continuous=!args["discrete"], + abstol=args["abstol"], + reltol=args["reltol"], + jfb=args["jfb"], + train_batchsize=args["train-batchsize"], + eval_batchsize=args["eval-batchsize"], + seed=args["seed"], + w_skip=args["w-skip"], + ) +end + +create_model(expt_config, args) = get_model(expt_config; device=gpu, warmup=true, loss_function=get_loss_function(args)) + +function get_loss_function(args) + if args["model-type"] == "VANILLA" + function loss_function_closure_vanilla(x, y, model, ps, st) + (ŷ, soln), st_ = model(x, ps, st) + celoss = logitcrossentropy(ŷ, y) + skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star) + loss = celoss + return loss, st_, (ŷ, soln.nfe, celoss, skiploss, soln.residual) + end + return loss_function_closure_vanilla + else + function loss_function_closure_skip(x, y, model, ps, st) + (ŷ, soln), st_ = model(x, ps, st) + celoss = logitcrossentropy(ŷ, y) + skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star) + loss = celoss + args["w-skip"] * skiploss + return loss, st_, (ŷ, soln.nfe, celoss, skiploss, soln.residual) + end + return loss_function_closure_skip + end +end + +# Checkpointing +function save_checkpoint(state, is_best, filename) + if should_log() + isdir(dirname(filename)) || mkpath(dirname(filename)) + serialize(filename, state) + is_best && 
cp(filename, joinpath(dirname(filename), "model_best.jls"); force=true) + end +end + +# DataLoading +struct CIFARDataContainer + images + labels + transform +end + +function get_dataloaders(expt_config::NamedTuple) + x_train, y_train = CIFAR10.traindata(Float32) + x_test, y_test = CIFAR10.testdata(Float32) + + x_train_images = map(x -> Image(colorview(RGB, permutedims(x, (3, 2, 1)))), eachslice(x_train; dims=4)) + y_train = collect(eachslice(Float32.(onehotbatch(y_train, 0:9)); dims=2)) + + x_test_images = map(x -> Image(colorview(RGB, permutedims(x, (3, 2, 1)))), eachslice(x_test; dims=4)) + y_test = collect(eachslice(Float32.(onehotbatch(y_test, 0:9)); dims=2)) + + base_transform = ImageToTensor() |> Normalize((0.4914f0, 0.4822f0, 0.4465f0), (0.2023f0, 0.1994f0, 0.2010f0)) + + if expt_config.augment + train_transform = ScaleKeepAspect((36, 36)) |> RandomResizeCrop((32, 32)) |> Maybe(FlipX()) |> base_transform + else + train_transform = base_transform + end + + train_dataset = MLUtils.shuffleobs(CIFARDataContainer(x_train_images, y_train, train_transform)) + train_dataset = is_distributed() ? DistributedDataContainer(train_dataset) : train_dataset + test_dataset = CIFARDataContainer(x_test_images, y_test, base_transform) + test_dataset = is_distributed() ? DistributedDataContainer(test_dataset) : test_dataset + + return ( + DataLoaders.DataLoader(train_dataset, expt_config.train_batchsize), + DataLoaders.DataLoader(test_dataset, expt_config.eval_batchsize), + ) +end + +Base.length(d::CIFARDataContainer) = length(d.images) +Base.getindex(d::CIFARDataContainer, i::Int) = (Array(itemdata(apply(d.transform, d.images[i]))), d.labels[i]) +MLDataPattern.getobs(d::CIFARDataContainer, i::Int64) = MLUtils.getobs(d, i) + +# Validation +function validate(val_loader, model, ps, st, loss_function, args) + batch_time = AverageMeter("Batch Time", "6.3f") + data_time = AverageMeter("Data Time", "6.3f") + losses = AverageMeter("Net Loss", "6.3f") + loss1 = AverageMeter("Cross Entropy Loss", "6.3e") + loss2 = AverageMeter("Skip Loss", "6.3e") + residual = AverageMeter("Residual", "6.3e") + top1 = AverageMeter("Accuracy", "3.2f") + nfe = AverageMeter("NFE", "3.2f") + + progress = ProgressMeter( + length(val_loader), (batch_time, data_time, losses, loss1, loss2, residual, top1, nfe), "Test:" + ) + + st_ = Lux.testmode(st) + t = time() + for (i, (x, y)) in enumerate(CUDA.functional() ? 
CuIterator(val_loader) : val_loader) + B = size(x, ndims(x)) + data_time(time() - t, B) + + # Compute Output + loss, st_, (ŷ, nfe_, celoss, skiploss, resi) = loss_function(x, y, model, ps, st_) + st_ = Lux.update_state(st_, :update_mask, Val(true)) + + # Measure Elapsed Time + batch_time(time() - t, B) + + # Metrics + acc1 = accuracy(cpu(ŷ), cpu(y)) + top1(acc1, B) + nfe(nfe_, B) + losses(loss, B) + loss1(celoss, B) + loss2(skiploss, B) + residual(norm(resi), B) + + # Print Progress + if i % args["print-freq"] == 0 || i == length(val_loader) + should_log() && print_meter(progress, i) + end + + t = time() + end + + return ( + batch_time.sum, + data_time.sum, + loss1.sum, + loss2.sum, + losses.sum, + nfe.sum, + top1.sum, + residual.sum, + top1.count, + ) +end + +# Training +function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, args) + batch_time = AverageMeter("Batch Time", "6.3f") + data_time = AverageMeter("Data Time", "6.3f") + forward_pass_time = AverageMeter("Forward Pass Time", "6.3f") + backward_pass_time = AverageMeter("Backward Pass Time", "6.3f") + losses = AverageMeter("Net Loss", "6.3f") + loss1 = AverageMeter("Cross Entropy Loss", "6.3e") + loss2 = AverageMeter("Skip Loss", "6.3e") + residual = AverageMeter("Residual", "6.3e") + top1 = AverageMeter("Accuracy", "6.2f") + nfe = AverageMeter("NFE", "6.2f") + + progress = ProgressMeter( + length(train_loader), + (batch_time, data_time, forward_pass_time, backward_pass_time, losses, loss1, loss2, residual, top1, nfe), + "Epoch: [$epoch]", + ) + + st = Lux.trainmode(st) + t = time() + for (i, (x, y)) in enumerate(CuIterator(train_loader)) + B = size(x, ndims(x)) + data_time(time() - t, B) + + # Gradients and Update + _t = time() + (loss, st, (ŷ, nfe_, celoss, skiploss, resi)), back = Zygote.pullback( + p -> loss_function(x, y, model, p, st), ps + ) + forward_pass_time(time() - _t, B) + _t = time() + gs = back((one(loss), nothing, nothing))[1] + backward_pass_time(time() - _t, B) + st = Lux.update_state(st, :update_mask, Val(true)) + optimiser_state, ps = Optimisers.update(optimiser_state, ps, gs) + + # Measure Elapsed Time + batch_time(time() - t, B) + + # Metrics + acc1 = accuracy(cpu(ŷ), cpu(y)) + top1(acc1, B) + nfe(nfe_, B) + losses(loss, B) + loss1(celoss, B) + loss2(skiploss, B) + residual(norm(resi), B) + + # Print Progress + if i % args["print-freq"] == 0 || i == length(train_loader) + should_log() && print_meter(progress, i) + end + + t = time() + end + + return ( + ps, + st, + optimiser_state, + ( + batch_time.sum, + data_time.sum, + forward_pass_time.sum, + backward_pass_time.sum, + loss1.sum, + loss2.sum, + losses.sum, + nfe.sum, + top1.sum, + residual.sum, + top1.count, + ), + ) +end + +# Main Function +function get_base_experiment_name(args) + return "data-CIFAR10_type-$(args["model-type"])_size-$(args["model-size"])_discrete-$(args["discrete"])_jfb-$(args["jfb"])" +end + +function get_loggable_stats(stats) + v = [stats...] 
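+    # `v` holds the per-process running sums, with the sample count in the last
+    # slot; the reduction below sums them across ranks so that rank 0 can
+    # report globally averaged metrics.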
+ is_distributed() && MPI.Reduce!(v, +, 0, MPI.COMM_WORLD) + return v[1:end-1] ./ v[end] +end + +function main(args) + best_acc1 = 0 + + # Seeding + rng = Random.default_rng() + Random.seed!(rng, args["seed"]) + + # Model Construction + expt_config = get_experiment_config(args) + should_log() && println("$(now()) => creating model") + model, ps, st = create_model(expt_config, args) + + should_log() && println("$(now()) => setting up dataloaders") + train_loader, test_loader = get_dataloaders(expt_config) + + # Optimizer and Scheduler + should_log() && println("$(now()) => creating optimiser") + optimiser, scheduler = construct_optimiser(expt_config) + optimiser_state = Optimisers.setup(optimiser, ps) + if is_distributed() + optimiser_state = FluxMPI.synchronize!(optimiser_state) + should_log() && println("$(now()) ==> synced optimiser state across all ranks") + end + + if args["resume"] != "" + if isfile(args["resume"]) + checkpoint = deserialize(args["resume"]) + args["start-epoch"] = checkpoint["epoch"] + optimiser_state = gpu(checkpoint["optimiser_state"]) + ps = gpu(checkpoint["model_parameters"]) + st = gpu(checkpoint["model_states"]) + should_log() && println("$(now()) => loaded checkpoint `$(args["resume"])` (epoch $(args["start-epoch"]))") + else + should_log() && println("$(now()) => no checkpoint found at `$(args["resume"])`. Starting from scratch.") + end + end + + loss_function = get_loss_function(args) + + if args["evaluate"] + validate(test_loader, model, ps, st, loss_function, args) + return nothing + end + + invoke_gc() + + expt_name = get_base_experiment_name(args) + store_in = string(now()) + + ckpt_dir = joinpath(args["checkpoint-dir"], expt_name, store_in) + log_path = joinpath(args["log-dir"], expt_name, store_in, "results.csv") + + should_log() && println("$(now()) => checkpoint directory `$(ckpt_dir)`") + + csv_logger = CSVLogger(log_path, ["Epoch", "Train/Batch Time", "Train/Data Time", "Train/Forward Pass Time", "Train/Backward Pass Time", "Train/Cross Entropy Loss", "Train/Skip Loss", "Train/Net Loss", "Train/NFE", "Train/Accuracy", "Train/Residual", "Test/Batch Time", "Test/Data Time", "Test/Cross Entropy Loss", "Test/Skip Loss", "Test/Net Loss", "Test/NFE", "Test/Accuracy", "Test/Residual"]) + + should_log() && println("$(now()) => logging results to `$(log_path)`") + + should_log() && serialize(joinpath(dirname(log_path), "setup.jls"), Dict("config" => expt_config, "args" => args)) + + for epoch in args["start-epoch"]:(expt_config.nepochs) + # Train for 1 epoch + ps, st, optimiser_state, train_stats = train_one_epoch( + train_loader, model, ps, st, optimiser_state, epoch, loss_function, args + ) + train_stats = get_loggable_stats(train_stats) + + should_log() && println() + + # Some Housekeeping + invoke_gc() + + # Evaluate on validation set + val_stats = validate(test_loader, model, ps, st, loss_function, args) + val_stats = get_loggable_stats(val_stats) + + should_log() && println() + + csv_logger(epoch, train_stats..., val_stats...) 
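+        # One CSV row per epoch; the column order matches the header passed to
+        # `CSVLogger` above (epoch, train metrics, then test metrics).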
+        should_log() && println("$(now()) => logged intermediate results to the CSV file\n")
+
+        # Step the Learning Rate Scheduler
+        eta_new = ParameterSchedulers.next!(scheduler)
+        optimiser_state = update_lr(optimiser_state, eta_new)
+
+        # Some Housekeeping
+        invoke_gc()
+
+        # Remember Best Accuracy and Save Checkpoint
+        is_best = val_stats[7] > best_acc1
+        best_acc1 = max(val_stats[7], best_acc1)
+
+        save_state = Dict(
+            "epoch" => epoch,
+            "config" => expt_config,
+            "accuracy" => val_stats[7],
+            "model_states" => cpu(st),
+            "model_parameters" => cpu(ps),
+            "optimiser_state" => cpu(optimiser_state),
+        )
+        save_checkpoint(save_state, is_best, joinpath(ckpt_dir, "checkpoint.jls"))
+    end
+end
+
+# main(parse_commandline_arguments())
diff --git a/examples/cifar10/options.jl b/examples/cifar10/options.jl
new file mode 100644
index 00000000..e8b3471d
--- /dev/null
+++ b/examples/cifar10/options.jl
@@ -0,0 +1,73 @@
+using ArgParse
+
+
+# Parse Training Arguments
+function parse_commandline_arguments()
+    parse_settings = ArgParseSettings("FastDEQ CIFAR-10 Training")
+
+    @add_arg_table! parse_settings begin
+        "--model-size"
+            default = "TINY"
+            range_tester = x -> x ∈ ("TINY", "LARGE")
+            help = "model size: `TINY` or `LARGE`"
+        "--model-type"
+            default = "VANILLA"
+            range_tester = x -> x ∈ ("VANILLA", "SKIP", "SKIPV2")
+            help = "model type: `VANILLA`, `SKIP` or `SKIPV2`"
+        "--eval-batchsize"
+            help = "batch size for evaluation (per process)"
+            arg_type = Int
+            default = 32
+        "--train-batchsize"
+            help = "batch size for training (per process)"
+            arg_type = Int
+            default = 32
+        "--discrete"
+            help = "use discrete DEQ"
+            action = :store_true
+        "--jfb"
+            help = "enable Jacobian-free backpropagation"
+            action = :store_true
+        "--abstol"
+            default = 0.25f0
+            arg_type = Float32
+            help = "absolute tolerance for termination"
+        "--reltol"
+            default = 0.25f0
+            arg_type = Float32
+            help = "relative tolerance for termination"
+        "--w-skip"
+            default = 1.0f0
+            arg_type = Float32
+            help = "weight for skip DEQ loss"
+        "--start-epoch"
+            help = "manual epoch number (useful on restarts)"
+            arg_type = Int
+            default = 0
+        "--print-freq"
+            help = "print frequency"
+            arg_type = Int
+            default = 100
+        "--resume"
+            help = "resume from checkpoint"
+            arg_type = String
+            default = ""
+        "--evaluate"
+            help = "evaluate model on validation set"
+            action = :store_true
+        "--seed"
+            help = "seed for initializing training. 
" + arg_type = Int + default = 0 + "--checkpoint-dir" + help = "directory to save checkpoints" + arg_type = String + default = "checkpoints/" + "--log-dir" + help = "directory to save logs" + arg_type = String + default = "logs/" + end + + return parse_args(parse_settings) +end \ No newline at end of file diff --git a/examples/cifar10/script.jl b/examples/cifar10/script.jl deleted file mode 100644 index db8b05f0..00000000 --- a/examples/cifar10/script.jl +++ /dev/null @@ -1,79 +0,0 @@ -using FastDEQExperiments, Lux, CUDA, Optimisers, Dates, FluxMPI - -# Distributed Training -FluxMPI.Init(; verbose=true) - -# Setup -CUDA.allowscalar(false) - -# Training -function train_model(config, expt_name) - # Logger Setup - mkpath("logs/") - lg = FastDEQExperiments.PrettyTableLogger( - joinpath("logs/", expt_name * ".csv"), - ["Epoch Number", "Train/Time", "Test/NFE", "Test/Accuracy", "Test/Loss", "Test/Time"], - ["Train/Running/NFE", "Train/Running/Loss", "Train/Running/Accuracy"], - ) - - # Experiment Configuration - expt_config = FastDEQExperiments.get_experiment_config( - :CIFAR10, - config["model_size"]; - model_type=config["model_type"], - continuous=config["continuous"], - abstol=config["abstol"], - reltol=config["reltol"], - ) - - # Model Setup - model, ps, st = FastDEQExperiments.get_model(expt_config; seed=config["seed"], device=gpu, warmup=true) - - # Get Dataloaders - train_dataloader, test_dataloader = FastDEQExperiments.get_dataloaders( - :CIFAR10; train_batchsize=expt_config.train_batchsize, eval_batchsize=expt_config.eval_batchsize - ) - - # Train - ps, st, st_opt = FastDEQExperiments.train( - model, - ps, - st, - FastDEQExperiments.loss_function(expt_config), - FastDEQExperiments.construct_optimiser(expt_config)..., - train_dataloader, - nothing, - test_dataloader, - expt_config.nepochs, - lg, - expt_config, - ) - - # Close Logger and Flush Data to disk - Base.close(lg) - - return model, cpu(ps), cpu(st), st_opt -end - -# Experiment Configurations -configs = [] -for seed in [6171, 3859, 2961], model_type in [:VANILLA, :SKIP, :SKIPV2], model_size in [:TINY] #, :LARGE] - push!( - configs, - Dict( - "seed" => seed, - "abstol" => 5.0f-2, - "reltol" => 5.0f-2, - "model_type" => model_type, - "continuous" => true, - "model_size" => model_size, - ), - ) -end - -# Training -for config in configs - expt_name = "cifar-10_seed-$(config["seed"])_model-$(config["model_type"])_size-$(config["model_size"])_continuous-$(config["continuous"])_now-$(now())" - FastDEQExperiments._should_log() && println("Starting Experiment: " * expt_name) - model, ps, st, st_opt = train_model(config, expt_name) -end diff --git a/examples/imagenet/main.jl b/examples/imagenet/main.jl new file mode 100644 index 00000000..e69de29b diff --git a/examples/src/FastDEQExperiments.jl b/examples/src/FastDEQExperiments.jl index 218790ff..2230a668 100644 --- a/examples/src/FastDEQExperiments.jl +++ b/examples/src/FastDEQExperiments.jl @@ -1,55 +1,45 @@ module FastDEQExperiments -using FastDEQ, - DataLoaders, - Random, - OrdinaryDiffEq, - FluxMPI, - Format, - Functors, - Lux, - MLDatasets, - Optimisers, - MPI, - CUDA, - Setfield, - ParameterSchedulers, - NNlib, - Zygote - -import Flux: OneHotArray, onecold, onehotbatch, onehot -import Flux.Losses: logitcrossentropy, mse -import MLUtils: shuffleobs -import MLDataPattern, MLUtils - -# Memory Management -relieve_gc_pressure(::Union{Nothing,<:AbstractArray}) = nothing -relieve_gc_pressure(x::CuArray) = CUDA.unsafe_free!(x) -relieve_gc_pressure(t::Tuple) = relieve_gc_pressure.(t) 
-relieve_gc_pressure(x::NamedTuple) = fmap(relieve_gc_pressure, x) - -function invoke_gc() - GC.gc(true) - # CUDA.reclaim() - return nothing -end - -# PrettyTableLogger +using CUDA +using Dates +using FastBroadcast +using FastDEQ +using FluxMPI +using Format +using Formatting +using Functors +using Lux +using MPI +using NNlib +using OneHotArrays +using Optimisers +using OrdinaryDiffEq +using ParameterSchedulers +using Random +using Setfield +using Statistics +using Zygote + +import DataLoaders: LearnBase +import MLDataPattern +import MLUtils + +# logging utilities include("logging.jl") # get_model_config include("config.jl") -# train, loss_function -include("train.jl") # get_model include("models.jl") -# get_dataloaders -include("dataloaders.jl") +# random utilities +include("utils.jl") +# Exports +export AverageMeter, CSVLogger, ProgressMeter, print_meter -# Fallback since DataLoaders.jl still relies on MLDataPattern -MLDataPattern.nobs(x) = MLUtils.numobs(x) -MLDataPattern.getobs(d::Union{MLUtils.ObsView,MLDatasetsImageData,DistributedDataContainer}, i::Int64) = - MLUtils.getobs(d, i) +export get_experiment_configuration +export construct_optimiser, get_model + +export accuracy, invoke_gc, is_distributed, logitcrossentropy, mae, mse, relieve_gc_pressure, should_log, update_lr end \ No newline at end of file diff --git a/examples/src/config.jl b/examples/src/config.jl index e894ba04..2eb6be6c 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -1,175 +1,4 @@ -abstract type AbstractTaskModelConfiguration end - -# Predefined Image Classification Models -Base.@kwdef struct ImageClassificationModelConfiguration{N} <: AbstractTaskModelConfiguration - num_layers::Int - num_classes::Int - dropout_rate::Float32 - group_count::Int - weight_norm::Bool - downsample_times::Int - expansion_factor::Int - post_gn_affine::Bool - image_size::Tuple{Int,Int} - - num_modules::Int - num_branches::Int - block_type::Symbol - big_kernels::NTuple{N,Int} - head_channels::NTuple{N,Int} - num_blocks::NTuple{N,Int} - num_channels::NTuple{N,Int} - - fuse_method::Symbol - final_channelsize::Int - - fwd_maxiters::Int - bwd_maxiters::Int - model_type::Symbol - continuous::Bool - - # Specific for Continuous Models - abstol::Float32 = 5f-2 - reltol::Float32 = 5f-2 - stop_mode::Symbol = :rel_norm - ode_solver = VCAB3() -end - -function get_model_config(dataset::Symbol, model_size::Symbol; kwargs...) - if dataset == :CIFAR10 - if model_size == :TINY - return ImageClassificationModelConfiguration{2}(; - num_layers=10, - num_classes=10, - dropout_rate=0.25f0, - group_count=8, - weight_norm=true, - downsample_times=0, - expansion_factor=5, - post_gn_affine=false, - image_size=(32, 32), - num_modules=1, - num_branches=2, - block_type=:basic, - big_kernels=(0, 0), - head_channels=(8, 16), - num_blocks=(1, 1), - num_channels=(24, 24), - fuse_method=:sum, - final_channelsize=200, - fwd_maxiters=18, - bwd_maxiters=20, - kwargs... - ) - elseif model_size == :LARGE - return ImageClassificationModelConfiguration{4}(; - num_layers=10, - num_classes=10, - dropout_rate=0.3f0, - group_count=8, - weight_norm=true, - downsample_times=0, - expansion_factor=5, - post_gn_affine=false, - image_size=(32, 32), - num_modules=1, - num_branches=4, - block_type=:basic, - big_kernels=(0, 0, 0, 0), - head_channels=(14, 28, 56, 112), - num_blocks=(1, 1, 1, 1), - num_channels=(32, 64, 128, 256), - fuse_method=:sum, - final_channelsize=1680, - fwd_maxiters=18, - bwd_maxiters=20, - kwargs... 
- ) - else - throw(ArgumentError("`model_size` must be one of `[:TINY, :LARGE]`")) - end - elseif dataset == :IMAGENET - if model_size == :SMALL - return ImageClassificationModelConfiguration{4}(; - num_layers=4, - num_classes=1000, - dropout_rate=0.0f0, - group_count=8, - weight_norm=true, - downsample_times=2, - expansion_factor=5, - post_gn_affine=true, - image_size=(224, 224), - num_modules=1, - num_branches=4, - block_type=:basic, - big_kernels=(0, 0, 0, 0), - head_channels=(24, 48, 96, 192), - num_blocks=(1, 1, 1, 1), - num_channels=(32, 64, 128, 256), - fuse_method=:sum, - final_channelsize=2048, - fwd_maxiters=27, - bwd_maxiters=28, - kwargs... - ) - elseif model_size == :LARGE - return ImageClassificationModelConfiguration{4}(; - num_layers=4, - num_classes=1000, - dropout_rate=0.0f0, - group_count=8, - weight_norm=true, - downsample_times=2, - expansion_factor=5, - post_gn_affine=true, - image_size=(224, 224), - num_modules=1, - num_branches=4, - block_type=:basic, - big_kernels=(0, 0, 0, 0), - head_channels=(32, 64, 128, 256), - num_blocks=(1, 1, 1, 1), - num_channels=(80, 160, 320, 640), - fuse_method=:sum, - final_channelsize=2048, - fwd_maxiters=27, - bwd_maxiters=28, - kwargs... - ) - elseif model_size == :XL - return ImageClassificationModelConfiguration{4}(; - num_layers=4, - num_classes=1000, - dropout_rate=0.0f0, - group_count=8, - weight_norm=true, - downsample_times=2, - expansion_factor=5, - post_gn_affine=true, - image_size=(224, 224), - num_modules=1, - num_branches=4, - block_type=:basic, - big_kernels=(0, 0, 0, 0), - head_channels=(32, 64, 128, 256), - num_blocks=(1, 1, 1, 1), - num_channels=(88, 176, 352, 704), - fuse_method=:sum, - final_channelsize=2048, - fwd_maxiters=27, - bwd_maxiters=28, - kwargs... - ) - else - throw(ArgumentError("`model_size` must be one of `[:SMALL, :LARGE, :XL]`")) - end - else - throw(ArgumentError("`dataset` must be one of `[:CIFAR10]`")) - end -end - -function compute_feature_scales(config::ImageClassificationModelConfiguration) +function compute_feature_scales(config) image_size = config.image_size image_size_downsampled = image_size for _ in 1:(config.downsample_times) @@ -182,120 +11,204 @@ function compute_feature_scales(config::ImageClassificationModelConfiguration) return Tuple(scales) end -# Experiment Configuration -Base.@kwdef struct ExperimentConfiguration{M<:AbstractTaskModelConfiguration} - model_config::M +function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) + return ( + num_layers=10, + num_classes=10, + dropout_rate=0.25f0, + group_count=8, + weight_norm=true, + downsample_times=0, + expansion_factor=5, + post_gn_affine=false, + image_size=(32, 32), + num_modules=1, + num_branches=2, + block_type=:basic, + big_kernels=(0, 0), + head_channels=(8, 16), + num_blocks=(1, 1), + num_channels=(24, 24), + fuse_method=:sum, + final_channelsize=200, + fwd_maxiters=18, + bwd_maxiters=20, + continuous=true, + stop_mode=:rel_norm, + nepochs=50, + jfb=false, + augment=false, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_steps=3000 ÷ scaling_factor(), + lr_scheduler=:COSINE, + optimiser=:ADAM, + eta=0.001f0 * scaling_factor(), + eval_datasize_per_process=10000 ÷ scaling_factor(), + train_datasize_per_process=50000 ÷ scaling_factor(), + ) +end - # Eval - eval_batchsize::Int - eval_datasize_per_process::Int +function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:LARGE}) + return ( + num_layers=10, + num_classes=10, + dropout_rate=0.3f0, + 
group_count=8, + weight_norm=true, + downsample_times=0, + expansion_factor=5, + post_gn_affine=false, + image_size=(32, 32), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(14, 28, 56, 112), + num_blocks=(1, 1, 1, 1), + num_channels=(32, 64, 128, 256), + fuse_method=:sum, + final_channelsize=1680, + fwd_maxiters=18, + bwd_maxiters=20, + continuous=true, + stop_mode=:rel_norm, + nepochs=220, + jfb=false, + augment=true, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_steps=20000 ÷ scaling_factor(), + lr_scheduler=:COSINE, + optimiser=:ADAM, + eta=0.001f0 * scaling_factor(), + ) +end - # Train - train_batchsize::Int - train_datasize_per_process::Int - nepochs::Int - pretrain_steps::Int +function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:SMALL}) + return ( + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(24, 48, 96, 192), + num_blocks=(1, 1, 1, 1), + num_channels=(32, 64, 128, 256), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + continuous=true, + stop_mode=:rel_norm, + nepochs=100, + jfb=false, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_steps=510000 ÷ scaling_factor(), + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 * scaling_factor(), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true, + ) +end - # Optimiser - lr_scheduler::Symbol - optimiser::Symbol - eta::Float32 - momentum::Float32 - nesterov::Bool - weight_decay::Float32 +function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:LARGE}) + return ( + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(32, 64, 128, 256), + num_blocks=(1, 1, 1, 1), + num_channels=(80, 160, 320, 640), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + continuous=true, + stop_mode=:rel_norm, + nepochs=100, + jfb=false, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_steps=510000 ÷ scaling_factor(), + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 * scaling_factor(), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true, + ) end -function get_experiment_config(dataset::Symbol, model_size::Symbol; kwargs...) - if dataset == :CIFAR10 - if model_size == :TINY - return ExperimentConfiguration( - model_config=get_model_config(dataset, model_size; kwargs...), - eval_batchsize=128, - train_batchsize=128, - nepochs=50, - pretrain_steps=3000 ÷ (is_distributed() ? total_workers() : 1), - lr_scheduler=:COSINE, - optimiser=:ADAM, - eta=0.001f0 / 2 * (is_distributed() ? total_workers() : 1), - weight_decay=0.0f0, - momentum=0.9f0, - nesterov=true, - eval_datasize_per_process=10000 ÷ (is_distributed() ? total_workers() : 1), - train_datasize_per_process=50000 ÷ (is_distributed() ? 
total_workers() : 1), - ) - elseif model_size == :LARGE - return ExperimentConfiguration( - model_config=get_model_config(dataset, model_size; kwargs...), - eval_batchsize=32, - train_batchsize=32, - nepochs=220, - pretrain_steps=20000 ÷ (is_distributed() ? total_workers() : 1), - lr_scheduler=:COSINE, - optimiser=:ADAM, - eta=0.001f0 / 4 * (is_distributed() ? total_workers() : 1), - weight_decay=0.0f0, - momentum=0.9f0, - nesterov=true, - eval_datasize_per_process=10000 ÷ (is_distributed() ? total_workers() : 1), - train_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), - ) - else - throw(ArgumentError("`model_size` must be one of `[:TINY, :LARGE]`")) - end - elseif dataset == :IMAGENET - if model_size == :SMALL - return ExperimentConfiguration( - model_config=get_model_config(dataset, model_size; kwargs...), - eval_batchsize=32, - train_batchsize=32, - nepochs=100, - pretrain_steps=510000 ÷ (is_distributed() ? total_workers() : 1), - lr_scheduler=:COSINE, - optimiser=:SGD, - eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), - weight_decay=0.00005f0, - momentum=0.9f0, - nesterov=true, - eval_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), - train_datasize_per_process=1281166 ÷ (is_distributed() ? total_workers() : 1), - ) - elseif model_size == :LARGE - return ExperimentConfiguration( - model_config=get_model_config(dataset, model_size; kwargs...), - eval_batchsize=32, - train_batchsize=32, - nepochs=100, - pretrain_steps=510000 ÷ (is_distributed() ? total_workers() : 1), - lr_scheduler=:COSINE, - optimiser=:SGD, - eta=0.05f0 / 4 * (is_distributed() ? total_workers() : 1), - weight_decay=0.00005f0, - momentum=0.9f0, - nesterov=true, - eval_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), - train_datasize_per_process=1281166 ÷ (is_distributed() ? total_workers() : 1), - ) - elseif model_size == :XL - return ExperimentConfiguration( - model_config=get_model_config(dataset, model_size; kwargs...), - eval_batchsize=32, - train_batchsize=32, - nepochs=100, - pretrain_steps=510000 ÷ (is_distributed() ? total_workers() : 1), - lr_scheduler=:COSINE, - optimiser=:SGD, - eta=0.05f0 / 8 * (is_distributed() ? total_workers() : 1), - weight_decay=0.00005f0, - momentum=0.9f0, - nesterov=true, - eval_datasize_per_process=50000 ÷ (is_distributed() ? total_workers() : 1), - train_datasize_per_process=1281166 ÷ (is_distributed() ? 
total_workers() : 1), - ) - else - throw(ArgumentError("`model_size` must be one of `[:SMALL, :LARGE, :XL]`")) - end - else - throw(ArgumentError("`dataset` must be one of `[:CIFAR10]`")) - end +function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:XL}) + return ( + num_layers=4, + num_classes=1000, + dropout_rate=0.0f0, + group_count=8, + weight_norm=true, + downsample_times=2, + expansion_factor=5, + post_gn_affine=true, + image_size=(224, 224), + num_modules=1, + num_branches=4, + block_type=:basic, + big_kernels=(0, 0, 0, 0), + head_channels=(32, 64, 128, 256), + num_blocks=(1, 1, 1, 1), + num_channels=(88, 176, 352, 704), + fuse_method=:sum, + final_channelsize=2048, + fwd_maxiters=27, + bwd_maxiters=28, + continuous=true, + stop_mode=:rel_norm, + nepochs=100, + jfb=false, + model_type=:VANILLA, + abstol=5.0f-2, + reltol=5.0f-2, + ode_solver=VCABM3(), + pretrain_steps=510000 ÷ scaling_factor(), + lr_scheduler=:COSINE, + optimiser=:SGD, + eta=0.05f0 * scaling_factor(), + weight_decay=0.00005f0, + momentum=0.9f0, + nesterov=true, + ) end +function get_experiment_configuration(dataset::Val, model_size::Val; kwargs...) + return merge(get_default_experiment_configuration(dataset, model_size), kwargs) +end diff --git a/examples/src/dataloaders.jl b/examples/src/dataloaders.jl deleted file mode 100644 index 2c9effe7..00000000 --- a/examples/src/dataloaders.jl +++ /dev/null @@ -1,34 +0,0 @@ -struct MLDatasetsImageData - images - labels -end - -MLDatasetsImageData(images::AbstractArray{T,4}, labels::AbstractArray{T,2}) where {T} = - MLDatasetsImageData(collect(eachslice(images, dims=4)), collect(eachslice(labels, dims=2))) - -Base.length(d::MLDatasetsImageData) = length(d.images) -Base.getindex(d::MLDatasetsImageData, i::Int) = (d.images[i], d.labels[i]) - -function get_dataloaders( - dataset::Symbol; μ=nothing, σ²=nothing, train_batchsize::Int64, eval_batchsize::Int64 -) - (x_train, y_train), (x_test, y_test), μ, σ², nclasses = if dataset == :CIFAR10 - μ = μ === nothing ? reshape([0.4914f0, 0.4822f0, 0.4465f0], 1, 1, :, 1) : μ - σ² = σ² === nothing ? reshape([0.2023f0, 0.1994f0, 0.2010f0], 1, 1, :, 1) : σ² - CIFAR10.traindata(Float32), CIFAR10.testdata(Float32), μ, σ², 10 - else - throw(ArgumentError("Not yet implemented for $dataset")) - end - - x_train = (x_train .- μ) ./ σ² - y_train = Float32.(onehotbatch(y_train, 0:(nclasses - 1))) - x_test = (x_test .- μ) ./ σ² - y_test = Float32.(onehotbatch(y_test, 0:(nclasses - 1))) - - train_dataset = shuffleobs(MLDatasetsImageData(x_train, y_train)) - train_dataset = is_distributed() ? DistributedDataContainer(train_dataset) : train_dataset - test_dataset = MLDatasetsImageData(x_test, y_test) - test_dataset = is_distributed() ? 
DistributedDataContainer(test_dataset) : test_dataset - - return (DataLoader(train_dataset, train_batchsize), DataLoader(test_dataset, eval_batchsize)) -end diff --git a/examples/src/logging.jl b/examples/src/logging.jl index 6f6f9690..d18abf00 100644 --- a/examples/src/logging.jl +++ b/examples/src/logging.jl @@ -1,133 +1,67 @@ - -function _should_log(; logging_rank=0, comm=MPI.COMM_WORLD) - FluxMPI.Initialized() || return true # Not using MPI - return local_rank() == logging_rank -end - -# Running AverageMeter -mutable struct AverageMeter{T} - last_value::T - sum::T - count::Int - - AverageMeter(T=Float32) = new{T}(T(0), T(0), 0) +Base.@kwdef mutable struct AverageMeter + fmtstr + val::Float64 = 0.0 + sum::Float64 = 0.0 + count::Int = 0 + average::Float64 = 0 end -function reset!(am::AverageMeter{T}) where {T} - val = am() - am.last_value = T(0) - am.sum = T(0) - am.count = 0 - return val +function AverageMeter(name::String, fmt::String) + fmtstr = Formatting.FormatExpr("$name {1:$fmt} ({2:$fmt})") + return AverageMeter(; fmtstr=fmtstr) end -function update!(am::AverageMeter{T}, val::T) where {T} - am.last_value = val - am.sum += val - am.count += 1 - return am.sum / am.count +function (meter::AverageMeter)(val, n::Int) + meter.val = val + meter.sum += val * n + meter.count += n + meter.average = meter.sum / meter.count + return meter.average end -update!(am::AverageMeter{T}, val) where {T} = update!(am, T(val)) +print_meter(meter::AverageMeter) = Formatting.printfmt(meter.fmtstr, meter.val, meter.average) -(am::AverageMeter)() = am.sum / am.count - -# Simple Table Logger -struct PrettyTableLogger{N,AM,F,R,FIO} - header::NTuple{N,String} - average_meters::AM - span::Int - fmtrfuncs::F - records::R - fio::FIO - - function PrettyTableLogger(filename::String, header, record=[]) - fio = _should_log() ? open(filename, "w") : nothing - - N = length(header) + length(record) - headers = vcat(header, record) - headers_og = headers - _c = 0 - count() = (_c += 1; _c) - rsplits = first.(map(x -> length(x) >= 2 ? x : ("__" * string(count()), x), rsplit.(headers, "/"; limit=2)),) - headers = string.(last.(rsplit.(headers, "/"; limit=2))) - headers = map(x -> length(x) <= 6 ? x * (" "^length(x)) : x, headers) - ind_lens = length.(headers) - span = sum(ind_lens .+ 3) + 1 - rsplit_lens = Dict() - if fio !== nothing - for (i, r) in enumerate(rsplits) - _r = string(r) - _r ∉ keys(rsplit_lens) && (rsplit_lens[_r] = -3 - length(_r) + 1) - rsplit_lens[_r] = rsplit_lens[_r] + ind_lens[i] + 3 - end - rsplits_unique = unique(rsplits) - if !(length(rsplits_unique) == 1 && rsplits_unique[0] == "") - println("="^span) - for r in rsplits_unique - if startswith(r, "__") - print("| " * (" "^length(r)) * (" "^rsplit_lens[string(r)])) - else - print("| $r" * (" "^rsplit_lens[string(r)])) - end - end - println("|") - end - println("="^span) - for h in headers - print("| $h ") - end - println("|") - println("="^span) - for h in headers_og[1:(end - 1)] - print(fio, "$h,") - end - println(fio, "$(headers_og[end])") - end +struct ProgressMeter{N} + batch_fmtstr + meters::NTuple{N,AverageMeter} +end - avg_meters = Dict{String,AverageMeter}(rec => AverageMeter() for rec in record) +function ProgressMeter(num_batches::Int, meters::NTuple{N}, prefix::String="") where {N} + fmt = "%" * string(length(string(num_batches))) * "d" + prefix = prefix != "" ? endswith(prefix, " ") ? 
prefix : prefix * " " : "" + batch_fmtstr = Formatting.generate_formatter("$prefix[$fmt/" * sprintf1(fmt, num_batches) * "]") + return ProgressMeter{N}(batch_fmtstr, meters) +end - patterns = ["%$l.4f" for l in ind_lens] - fmtrfuncs = generate_formatter.(patterns) +function print_meter(meter::ProgressMeter, batch::Int) + base_str = meter.batch_fmtstr(batch) + print(base_str) + foreach(x -> (print("\t"); print_meter(x)), meter.meters[1:end]) + return println() +end - record = tuple(record...) +struct CSVLogger{N} + filename + fio +end - return new{N,typeof(avg_meters),typeof(fmtrfuncs),typeof(record),typeof(fio)}( - tuple(headers...), avg_meters, span, fmtrfuncs, record, fio - ) - end +function CSVLogger(filename, header) + should_log() && !isdir(dirname(filename)) && mkpath(dirname(filename)) + fio = should_log() ? open(filename, "w") : nothing + N = length(header) + should_log() && println(fio, join(header, ",")) + return CSVLogger{N}(filename, fio) end -function (pl::PrettyTableLogger)(args...; last::Bool=false, records::Dict=Dict()) - _should_log() || return nothing - if length(records) > 0 - for (rec, val) in records - update!(pl.average_meters[rec], val) - end - return nothing +function (csv::CSVLogger)(args...) + if should_log() + println(csv.fio, join(args, ",")) + flush(csv.fio) end - if last - str = "="^pl.span - println(str) - return nothing - end - for (i, (fmtrfunc, arg)) in - enumerate(zip(pl.fmtrfuncs, vcat([args...], [reset!(pl.average_meters[rec]) for rec in pl.records]))) - h = fmtrfunc(arg) - print("| $h ") - if i < length(pl.fmtrfuncs) - print(pl.fio, "$arg,") - else - println(pl.fio, "$arg") - end - end - println("|") - flush(pl.fio) - return nothing end -function Base.close(pl::PrettyTableLogger) - pl(; last=true) - pl.fio === nothing || close(pl.fio) - return nothing +function Base.close(csv::CSVLogger) + if should_log() + close(csv.fio) + end end diff --git a/examples/src/models.jl b/examples/src/models.jl index 426ee771..06d3e3c2 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -1,19 +1,29 @@ # Building Blocks ## Helpful Functional Wrappers function conv1x1(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return Conv((1, 1), mapping, activation; stride=stride, pad=0, bias=bias, initW=NormalInitializer(), kwargs...) + return Conv( + (1, 1), mapping, activation; stride=stride, pad=0, bias=bias, init_weight=NormalInitializer(), kwargs... + ) end function conv3x3(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return Conv((3, 3), mapping, activation; stride=stride, pad=1, bias=bias, initW=NormalInitializer(), kwargs...) + return Conv( + (3, 3), mapping, activation; stride=stride, pad=1, bias=bias, init_weight=NormalInitializer(), kwargs... + ) end function conv5x5(mapping, activation=identity; stride::Int=1, bias=false, kwargs...) - return Conv((5, 5), mapping, activation; stride=stride, pad=2, bias=bias, initW=NormalInitializer(), kwargs...) + return Conv( + (5, 5), mapping, activation; stride=stride, pad=2, bias=bias, init_weight=NormalInitializer(), kwargs... + ) end +addrelu(x, y) = @. 
relu(x + y) + reassociate(x::NTuple{2,<:AbstractArray}, y) = (x[1], (x[2], y)) +addtuple(y) = y[1] .+ y[2] + ## Downsample Module function downsample_module(mapping, level_diff, activation; group_count=8) in_channels, out_channels = mapping @@ -98,21 +108,20 @@ function ResidualBlockV1( NoOpLayer(), # For injection ), Parallel( - +, + addrelu, NoOpLayer(), # For y1 Chain( - WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar + WrappedFunction(addtuple), # Since injection could be a scalar gn2, ), # For (y2, injection) ), - ActivationFunction(relu), gn3, ) end function ResidualBlockV2( mapping; - deq_expand::Int=5, + deq_expand::Int=1, num_gn_groups::Int=4, downsample=NoOpLayer(), n_big_kernels::Int=0, @@ -145,9 +154,7 @@ function ResidualBlockV2( conv1, gn1, conv2, - Parallel(+, downsample, Chain(dropout, gn2)), - # WrappedFunction(Base.Fix1(broadcast, relu)), - ActivationFunction(relu), + Parallel(addrelu, downsample, Chain(dropout, gn2)), gn3, ) end @@ -165,10 +172,10 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool return Chain( Parallel(reassociate, BranchLayer(rescale, conv1x1(mapping)), NoOpLayer()), Parallel( - +, + addrelu, NoOpLayer(), Chain( - WrappedFunction(y2i -> y2i[1] .+ y2i[2]), # Since injection could be a scalar + WrappedFunction(addtuple), # Since injection could be a scalar Chain( BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), conv3x3(last(mapping) => last(mapping) * expansion), @@ -178,7 +185,6 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool ), ), ), - ActivationFunction(relu), ) end @@ -194,7 +200,7 @@ function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool return Chain( Parallel( - +, + addrelu, rescale, Chain( conv1x1(mapping), @@ -205,19 +211,18 @@ function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), - ActivationFunction(relu), ) end # Dataset Specific Models -get_model(econfig::ExperimentConfiguration, args...; kwargs...) = get_model(econfig.model_config, args...; kwargs...) - function get_model( - config::ImageClassificationModelConfiguration; - seed::Int, + config::NamedTuple; device=gpu, warmup::Bool=true, # Helps reduce Zygote compile times + loss_function=nothing ) + @assert !warmup || loss_function !== nothing + init_channel_size = config.num_channels[1] downsample_layers = [ @@ -323,14 +328,21 @@ function get_model( reltol_termination=config.reltol, ) else - error("Discrete Solvers have not been updated yet") + DiscreteDEQSolver( + LimitedMemoryBroydenSolver(); + mode=config.stop_mode, + abstol_termination=config.abstol, + reltol_termination=config.reltol, + ) end - sensealg = SteadyStateAdjoint(config.abstol, config.reltol, config.bwd_maxiters) + sensealg = DeepEquilibriumAdjoint( + config.abstol, config.reltol, config.bwd_maxiters; mode=config.jfb ? 
:jfb : :vanilla + ) deq = if config.model_type ∈ (:SKIP, :SKIPV2) shortcut = if config.model_type == :SKIP - slayers = Lux.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1])] + slayers = Lux.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1], weight_norm=false)] for i in 1:(config.num_branches - 1) push!( slayers, @@ -374,35 +386,36 @@ function get_model( model = DEQChain(initial_layers, deq, final_layers) rng = Random.default_rng() - Random.seed!(rng, seed) + Random.seed!(rng, config.seed) ps, st = device.(Lux.setup(rng, model)) if warmup - clean_println("Starting Model Warmup") + should_log() && println("$(now()) ==> starting model warmup") x__ = device(randn(Float32, config.image_size..., 3, 2)) y__ = device(Float32.(onehotbatch([1, 2], 0:(config.num_classes - 1)))) model(x__, ps, st) - clean_println("Forward Pass Warmup Completed") + should_log() && println("$(now()) ==> forward pass warmup completed") - st_ = Lux.update_state(st, :fixed_depth, 2) + st_ = Lux.update_state(st, :fixed_depth, Val(2)) model(x__, ps, st_) - clean_println("Forward Pass (Pretraining) Warmup Completed") + should_log() && println("$(now()) ==> forward pass (pretraining) warmup completed") - lfn = loss_function(config) - (l, _, _, _), back = pullback(p -> lfn(x__, y__, model, p, st), ps) - back((one(l), nothing, nothing, nothing)) - clean_println("Backward Pass Warmup Completed") + (l, _, _), back = pullback(p -> loss_function(x__, y__, model, p, st), ps) + back((one(l), nothing, nothing)) + should_log() && println("$(now()) ==> backward pass warmup completed") - (l, _, _, _), back = pullback(p -> lfn(x__, y__, model, p, st_), ps) - back((one(l), nothing, nothing, nothing)) - clean_println("Backward Pass (Pretraining) Warmup Completed") + (l, _, _), back = pullback(p -> loss_function(x__, y__, model, p, st_), ps) + back((one(l), nothing, nothing)) + should_log() && println("$(now()) ==> backward pass (pretraining) warmup completed") invoke_gc() end ps, st = if is_distributed() ps_ = FluxMPI.synchronize!(ps; root_rank=0) + should_log() && println("$(now()) ===> synchronized model parameters across all processes") st_ = FluxMPI.synchronize!(st; root_rank=0) + should_log() && println("$(now()) ===> synchronized model state across all processes") ps_, st_ else ps, st @@ -410,3 +423,39 @@ function get_model( return model, ps, st end + +# Optimisers +function construct_optimiser(config::NamedTuple) + opt = if config.optimiser == :ADAM + Optimisers.ADAM(config.eta) + elseif config.optimiser == :SGD + if config.nesterov + Optimisers.Nesterov(config.eta, config.momentum) + else + if iszero(config.momentum) + Optimisers.Descent(config.eta) + else + Optimisers.Momentum(config.eta, config.momentum) + end + end + else + throw(ArgumentError("`config.optimiser` must be either `:ADAM` or `:SGD`")) + end + if hasproperty(config, :weight_decay) && !iszero(config.weight_decay) + opt = Optimisers.OptimiserChain(opt, Optimisers.WeightDecay(config.weight_decay)) + end + + if is_distributed() + opt = DistributedOptimiser(opt) + end + + sched = if config.lr_scheduler == :COSINE + ParameterSchedulers.Stateful(ParameterSchedulers.Cos(config.eta, 1.0f-6, config.nepochs)) + elseif config.lr_scheduler == :CONSTANT + ParameterSchedulers.Stateful(ParameterSchedulers.Constant(config.eta)) + else + throw(ArgumentError("`config.lr_scheduler` must be either `:COSINE` or `:CONSTANT`")) + end + + return opt, sched +end diff --git a/examples/src/train.jl b/examples/src/train.jl 
deleted file mode 100644 index 31468fcf..00000000 --- a/examples/src/train.jl +++ /dev/null @@ -1,191 +0,0 @@ -function update_lr(st::ST, eta) where {ST} - if hasfield(ST, :eta) - @set! st.eta = eta - end - return st -end - -update_lr(st::Optimisers.OptimiserChain, eta) = update_lr.(st.opts, eta) - -function update_lr(st::Optimisers.Leaf, eta) - @set! st.rule = update_lr(st.rule, eta) -end - -update_lr(st_opt::NamedTuple, eta) = fmap(l -> update_lr(l, eta), st_opt) - -function construct_optimiser(config::ExperimentConfiguration) - opt = if config.optimiser == :ADAM - Optimisers.ADAM(config.eta) - elseif config.optimiser == :SGD - if config.nesterov - Optimisers.Nesterov(config.eta, config.momentum) - else - if iszero(config.momentum) - Optimisers.Descent(config.eta) - else - Optimisers.Momentum(config.eta, config.momentum) - end - end - else - throw(ArgumentError("`config.optimiser` must be either `:ADAM` or `:SGD`")) - end - if !iszero(config.weight_decay) - opt = Optimisers.OptimiserChain(opt, Optimisers.WeightDecay(config.weight_decay)) - end - - if is_distributed() - opt = DistributedOptimiser(opt) - end - - sched = if config.lr_scheduler == :COSINE - ParameterSchedulers.Stateful( - ParameterSchedulers.Cos( - config.eta, 1.0f-6, config.nepochs * (config.train_datasize_per_process ÷ config.train_batchsize) - ), - ) - elseif config.lr_scheduler == :CONSTANT - ParameterSchedulers.Stateful(ParameterSchedulers.Constant(config.eta)) - else - throw(ArgumentError("`config.lr_scheduler` must be either `:COSINE` or `:CONSTANT`")) - end - - return opt, sched -end - -is_distributed() = FluxMPI.Initialized() && total_workers() > 1 - -_get_loggable_stats(::Nothing) = () - -function _get_loggable_stats(stats::NamedTuple) - if is_distributed() - arr = [stats.mean_nfe, stats.accuracy, stats.loss, stats.total_datasize] - MPI.Reduce!(arr, +, 0, MPI.COMM_WORLD) - return ((arr[1:3] ./ arr[4])..., stats.total_time) - else - return ( - stats.mean_nfe / stats.total_datasize, - stats.accuracy / stats.total_datasize, - stats.loss / stats.total_datasize, - stats.total_time, - ) - end -end - -evaluate(model, ps, st, ::Nothing) = nothing - -function evaluate(model, ps, st, dataloader) - st_eval = Lux.testmode(st) - matches, total_loss, total_datasize, total_nfe, total_time = 0, 0, 0, 0, 0 - for (x, y) in CuIterator(dataloader) - start_time = time() - (ŷ, soln), _ = model(x, ps, st_eval) - total_time += time() - start_time - - total_nfe += soln.nfe * size(x, ndims(x)) - total_loss += logitcrossentropy(ŷ, y) * size(x, ndims(x)) - matches += sum(argmax.(eachcol(cpu(ŷ))) .== onecold(cpu(y))) - total_datasize += size(x, ndims(x)) - end - return (loss=total_loss, accuracy=matches, mean_nfe=total_nfe, total_time=total_time, total_datasize=total_datasize) -end - -loss_function(e::ExperimentConfiguration, args...; kwargs...) = loss_function(e.model_config, args...; kwargs...) 
- -function loss_function(c::ImageClassificationModelConfiguration; λ_skip=1.0f-3) - if c.model_type == :VANILLA - function loss_function_closure_1(x, y, model, ps, st) - (ŷ, soln), st_ = model(x, ps, st) - loss = logitcrossentropy(ŷ, y) - return loss, ŷ, st_, soln.nfe - end - return loss_function_closure_1 - else - function loss_function_closure_2(x, y, model, ps, st) - (ŷ, soln), st_ = model(x, ps, st) - loss = logitcrossentropy(ŷ, y) + λ_skip * mse(soln.u₀, soln.z_star) - return loss, ŷ, st_, soln.nfe - end - return loss_function_closure_2 - end -end - -function train_one_epoch( - model, - ps, - st, - loss_function, - opt_state, - scheduler, - dataloader, - lg::PrettyTableLogger, - econfig::ExperimentConfiguration, - iteration_count::Int, -) - total_time = 0 - - for (x, y) in CuIterator(dataloader) - # Compute Loss + Backprop + Update - start_time = time() - - (loss, ŷ, st, nfe), back = pullback(p -> loss_function(x, y, model, p, st), ps) - gs, = back((one(loss), nothing, nothing, nothing)) - opt_state, ps = Optimisers.update!(opt_state, ps, gs) - - total_time += time() - start_time - - acc = sum(argmax.(eachcol(cpu(ŷ))) .== onecold(cpu(y))) / size(x, 4) - - iteration_count += 1 - st = econfig.pretrain_steps == iteration_count ? Lux.update_state(st, :fixed_depth, 0) : st - - # Run ParameterScheduler - eta_new = ParameterSchedulers.next!(scheduler) - opt_state = update_lr(opt_state, eta_new) - - # Logging - lg(; records=Dict("Train/Running/NFE" => nfe, "Train/Running/Loss" => loss, "Train/Running/Accuracy" => acc)) - end - - return ps, st, opt_state, scheduler, iteration_count, (total_time=total_time,) -end - -function train( - model, - ps, - st, - loss_function, - opt, - scheduler, - train_dataloader, - val_dataloader, - test_dataloader, - nepochs, - lg::PrettyTableLogger, - econfig::ExperimentConfiguration; -) - invoke_gc() - # TODO: Saving model weights - opt_state = Optimisers.setup(opt, ps) - opt_state = is_distributed() ? FluxMPI.synchronize!(opt_state; root_rank=0) : opt_state - iteration_count = 0 - - st = econfig.pretrain_steps != 0 ? Lux.update_state(st, :fixed_depth, econfig.model_config.num_layers) : st - - for epoch in 1:nepochs - # Train 1 epoch - ps, st, opt_state, scheduler, iteration_count, training_stats = train_one_epoch( - model, ps, st, loss_function, opt_state, scheduler, train_dataloader, lg, econfig, iteration_count - ) - invoke_gc() - - # Evaluate - val_stats = _get_loggable_stats(evaluate(model, ps, st, val_dataloader)) - invoke_gc() - test_stats = _get_loggable_stats(evaluate(model, ps, st, test_dataloader)) - invoke_gc() - - lg(epoch, training_stats.total_time, val_stats..., test_stats...) - end - - return ps, st, opt_state -end diff --git a/examples/src/utils.jl b/examples/src/utils.jl new file mode 100644 index 00000000..f88f3cfc --- /dev/null +++ b/examples/src/utils.jl @@ -0,0 +1,58 @@ +# unsafe_free OneHotArrays +CUDA.unsafe_free!(x::OneHotArray) = CUDA.unsafe_free!(x.indices) + +# Memory Management +relieve_gc_pressure(::Union{Nothing,<:AbstractArray}) = nothing +relieve_gc_pressure(x::CuArray) = CUDA.unsafe_free!(x) +relieve_gc_pressure(t::Tuple) = relieve_gc_pressure.(t) +relieve_gc_pressure(x::NamedTuple) = fmap(relieve_gc_pressure, x) + +function invoke_gc() + GC.gc(true) + CUDA.reclaim() + return nothing +end + +# Optimisers / Parameter Schedulers +function update_lr(st::ST, eta) where {ST} + if hasfield(ST, :eta) + @set! 
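The two closures above differ only in the λ_skip * mse(soln.u₀, soln.z_star) term, which pulls the shortcut's initial guess toward the computed fixed point. A condensed single-closure sketch using the same helpers (λ_skip = 0 recovers the vanilla loss, at the cost of always evaluating the penalty):

    function make_loss(λ_skip::Float32)
        function loss(x, y, model, ps, st)
            (ŷ, soln), st_ = model(x, ps, st)
            l = logitcrossentropy(ŷ, y) + λ_skip * mse(soln.u₀, soln.z_star)
            return l, ŷ, st_, soln.nfe
        end
        return loss
    end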
st.eta = eta + end + return st +end +update_lr(st::Optimisers.OptimiserChain, eta) = update_lr.(st.opts, eta) +function update_lr(st::Optimisers.Leaf, eta) + @set! st.rule = update_lr(st.rule, eta) +end +update_lr(st_opt::NamedTuple, eta) = fmap(l -> update_lr(l, eta), st_opt) + +# Metrics +accuracy(ŷ, y) = sum(argmax.(eachcol(ŷ)) .== onecold(y)) * 100 / size(y, ndims(y)) + +function accuracy(ŷ, y, topk::NTuple{N,<:Int}) where {N} + maxk = maximum(topk) + + pred_labels = partialsortperm.(eachcol(ŷ), (1:maxk,), rev=true) + true_labels = onecold(y) + + accuracies = Tuple(sum(map((a, b) -> sum(view(a, 1:k) .== b), pred_labels, true_labels)) for k in topk) + + return accuracies .* 100 ./ size(y, ndims(y)) +end + +# Distributed Utils +@inline is_distributed() = FluxMPI.Initialized() && total_workers() > 1 +@inline should_log() = !FluxMPI.Initialized() || local_rank() == 0 +@inline scaling_factor() = (is_distributed() ? total_workers() : 1) + +# Loss Function +@inline logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) +@inline mae(ŷ, y) = mean(abs, ŷ .- y) +@inline mse(ŷ, y) = mean(abs2, ŷ .- y) + +# DataLoaders doesn't yet work with MLUtils +LearnBase.nobs(data::DistributedDataContainer) = MLUtils.numobs(data) +LearnBase.getobs(data::DistributedDataContainer, i::Int) = MLUtils.getobs(data, i) +MLDataPattern.nobs(x) = MLUtils.numobs(x) +MLDataPattern.getobs(d::Union{MLUtils.ObsView,DistributedDataContainer}, i::Int64) = + MLUtils.getobs(d, i) diff --git a/src/FastDEQ.jl b/src/FastDEQ.jl index 0b4a7a3e..7384ecb7 100644 --- a/src/FastDEQ.jl +++ b/src/FastDEQ.jl @@ -24,10 +24,6 @@ import DiffEqSensitivity: AbstractAdjointSensitivityAlgorithm import Lux: AbstractExplicitContainerLayer, initialparameters, initialstates, parameterlength, statelength import Random: AbstractRNG -# This shouldn't be put in Lux since it is not true in the general case -# However for our usecase gradients dont propagate through the state -ChainRulesCore.@non_differentiable Lux.update_state(::Any...) - include("operator.jl") include("solvers/continuous.jl") diff --git a/src/layers/deq.jl b/src/layers/deq.jl index 33192299..dfaceec0 100644 --- a/src/layers/deq.jl +++ b/src/layers/deq.jl @@ -60,7 +60,7 @@ function (deq::DeepEquilibriumNetwork{J})( end residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) + st = merge(st, (model=st_,)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st end @@ -79,7 +79,7 @@ function (deq::DeepEquilibriumNetwork{J})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps, st.model, z_star, x) : T(0)) residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) + st = merge(st, (model=st_,)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end @@ -163,12 +163,13 @@ end function (deq::SkipDeepEquilibriumNetwork{J,M,S})( x::AbstractArray{T}, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple ) where {J,M,S,T} - z, st__ = if S == Nothing - deq.model((zero(x), x), ps.model, st.model) + z, st = if S == Nothing + z__, st__ = deq.model((zero(x), x), ps.model, st.model) + z__, merge(st, (model=st__,)) else - deq.shortcut(x, ps.shortcut, st.shortcut) + z__, st__ = deq.shortcut(x, ps.shortcut, st.shortcut) + z__, merge(st, (shortcut=st__,)) end - @set! 
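A small worked call of the top-k accuracy method above, assuming onecold/onehotbatch from Flux as in the rest of these examples:

    ŷ = Float32[0.1 0.7;
                0.3 0.2;
                0.6 0.1]                # 3 classes × 2 samples
    y = Flux.onehotbatch([3, 1], 1:3)   # true labels: class 3, then class 1
    accuracy(ŷ, y, (1, 2))              # (100.0, 100.0): both samples hit already at k = 1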
st.shortcut = st__ if check_unrolled_mode(st) # Pretraining without Fixed Point Solving @@ -179,7 +180,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( end residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) + st = merge(st, (model=st_,)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, 0.0f0, get_unrolled_depth(st))), st end @@ -198,7 +199,7 @@ function (deq::SkipDeepEquilibriumNetwork{J,M,S})( jac_loss = (J ? compute_deq_jacobian_loss(deq.model, ps.model, st.model, z_star, x) : T(0)) residual = ignore_derivatives(z_star .- deq.model((z_star, x), ps.model, st.model)[1]) - @set! st.model = Lux.update_state(st_, :update_mask, Val(true)) + st = merge(st, (model=st_,)) return (z_star, DeepEquilibriumSolution(z_star, z, residual, jac_loss, sol.destats.nf + 1 + J)), st end diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 78fe34ec..f6f9081a 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -118,7 +118,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( vcat(flatten.(z_star)...) .- vcat(flatten.(evaluate_unrolled_mdeq(deq.model, z_star, x, ps, st_, Val(1))[1])...), ) - st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) + st__ = merge(st, (model=st_,)) return ( (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, get_unrolled_depth(st))), @@ -142,7 +142,7 @@ function (deq::MultiScaleDeepEquilibriumNetwork{N})( residual = ignore_derivatives(dudt(sol.u, ps, nothing)) - st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) + st__ = merge(st, (model=st_,)) return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st__) end @@ -285,7 +285,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh})( vcat(flatten.(z_star)...) .- vcat(flatten.(evaluate_unrolled_mdeq(deq.model, z_star, x, ps.model, st_, Val(1))[1])...), ) - st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) + st__ = merge(st, (model=st_,)) return ( (z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, get_unrolled_depth(st))), @@ -309,7 +309,7 @@ function (deq::MultiScaleSkipDeepEquilibriumNetwork{N,Sc,M,Sh})( residual = ignore_derivatives(dudt(sol.u, ps.model, nothing)) - st__ = merge(st, (model=Lux.update_state(st_, :update_mask, Val(true)),)) + st__ = merge(st, (model=st_,)) return ((z_star, DeepEquilibriumSolution(vcat(flatten.(z_star)...), z, residual, 0.0f0, sol.destats.nf + 1)), st) end diff --git a/src/utils.jl b/src/utils.jl index 09d19b23..70bacf40 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -52,19 +52,6 @@ function NormalInitializer(μ=0.0f0, σ²=0.01f0) return (rng::AbstractRNG, dims...) -> randn(rng, Float32, dims...) 
.* σ² .+ μ end -# Zygote Fix -function Zygote.accum(x::NTuple{N,T}, y::AbstractVector{T}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) -end - -function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,T}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) -end - -function Zygote.accum(x::AbstractVector{T}, y::NTuple{N,Nothing}) where {N,T<:AbstractArray} - return Zygote.accum.(x, y) -end - # For MultiScale DEQs @generated function split_and_reshape(x::AbstractMatrix, ::T, ::S) where {T,S} idxs, shapes = known(T), known(S) diff --git a/test/runtests.jl b/test/runtests.jl index b5d4a6b8..1d0f4189 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -54,7 +54,7 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 2, 1)) y = gpu(rand(rng, Float32, 2, 1)) - + gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(abs2, ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -172,7 +172,7 @@ end ps, st = gpu.(Lux.setup(rng, model)) x = gpu(rand(rng, Float32, 4, 2)) y = tuple([gpu(rand(rng, Float32, i, 2)) for i in 4:-1:1]...) - + gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) @@ -182,7 +182,7 @@ end @info "Testing MultiScaleDEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - + gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) @@ -225,7 +225,7 @@ end @info "Testing MultiScaleSkipDEQ without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - + gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) @@ -268,7 +268,7 @@ end @info "Testing MultiScaleSkipDEQV2 without Fixed Point Iterations" st = Lux.update_state(st, :fixed_depth, Val(5)) - + gs = gradient(ps) do p (ŷ, soln), _ = model(x, p, st) sum(Base.Fix1(sum, abs2), ŷ .- y) + sum(abs2, soln.u₀ .- soln.z_star) From ff9462c89c7b0257f5c0dbee967db9e5f7996ea8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 16 May 2022 20:45:44 -0400 Subject: [PATCH 62/76] Pretraining --- examples/cifar10/main.jl | 10 ++++++++-- examples/cifar10/options.jl | 4 ++++ examples/src/config.jl | 22 ++++++++++------------ examples/src/models.jl | 4 ++-- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/examples/cifar10/main.jl b/examples/cifar10/main.jl index 60307673..83141009 100644 --- a/examples/cifar10/main.jl +++ b/examples/cifar10/main.jl @@ -317,7 +317,7 @@ function main(args) invoke_gc() expt_name = get_base_experiment_name(args) - store_in = string(now()) + store_in = args["expt-subdir"] == "" ? string(now()) : args["expt-subdir"] ckpt_dir = joinpath(args["checkpoint-dir"], expt_name, store_in) log_path = joinpath(args["log-dir"], expt_name, store_in, "results.csv") @@ -330,6 +330,8 @@ function main(args) should_log() && serialize(joinpath(dirname(log_path), "setup.jls"), Dict("config" => expt_config, "args" => args)) + st = hasproperty(expt_config, :pretrain_epochs) && getproperty(expt_config, :pretrain_epochs) > 0 ? 
Lux.update_state(st, :fixed_depth, Val(getproperty(expt_config, :num_layers))) : st + for epoch in args["start-epoch"]:(expt_config.nepochs) # Train for 1 epoch ps, st, optimiser_state, train_stats = train_one_epoch( @@ -354,6 +356,10 @@ function main(args) # ParameterSchedulers eta_new = ParameterSchedulers.next!(scheduler) optimiser_state = update_lr(optimiser_state, eta_new) + if hasproperty(expt_config, :pretrain_epochs) && getproperty(expt_config, :pretrain_epochs) == epoch + should_log() && println("$(now()) => pretraining completed\n") + st = Lux.update_state(st, :fixed_depth, Val(0)) + end # Some Housekeeping invoke_gc() @@ -374,4 +380,4 @@ function main(args) end end -# main(parse_commandline_arguments()) +main(parse_commandline_arguments()) diff --git a/examples/cifar10/options.jl b/examples/cifar10/options.jl index e8b3471d..31b4c81b 100644 --- a/examples/cifar10/options.jl +++ b/examples/cifar10/options.jl @@ -67,6 +67,10 @@ function parse_commandline_arguments() help = "directory to save logs" arg_type = String default = "logs/" + "--expt-subdir" + help = "subdirectory name" + arg_type = String + default = "" end return parse_args(parse_settings) diff --git a/examples/src/config.jl b/examples/src/config.jl index 2eb6be6c..18c3a7f8 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -34,7 +34,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) fwd_maxiters=18, bwd_maxiters=20, continuous=true, - stop_mode=:rel_norm, + stop_mode=:rel_deq_best, nepochs=50, jfb=false, augment=false, @@ -42,12 +42,10 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_steps=3000 ÷ scaling_factor(), + pretrain_epochs=8, lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 * scaling_factor(), - eval_datasize_per_process=10000 ÷ scaling_factor(), - train_datasize_per_process=50000 ÷ scaling_factor(), ) end @@ -74,7 +72,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:LARGE}) fwd_maxiters=18, bwd_maxiters=20, continuous=true, - stop_mode=:rel_norm, + stop_mode=:rel_deq_best, nepochs=220, jfb=false, augment=true, @@ -82,7 +80,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:LARGE}) abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_steps=20000 ÷ scaling_factor(), + pretrain_epochs=13, lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 * scaling_factor(), @@ -112,14 +110,14 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:SMALL}) fwd_maxiters=27, bwd_maxiters=28, continuous=true, - stop_mode=:rel_norm, + stop_mode=:rel_deq_best, nepochs=100, jfb=false, model_type=:VANILLA, abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_steps=510000 ÷ scaling_factor(), + pretrain_epochs=18, lr_scheduler=:COSINE, optimiser=:SGD, eta=0.05f0 * scaling_factor(), @@ -152,14 +150,14 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:LARGE}) fwd_maxiters=27, bwd_maxiters=28, continuous=true, - stop_mode=:rel_norm, + stop_mode=:rel_deq_best, nepochs=100, jfb=false, model_type=:VANILLA, abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_steps=510000 ÷ scaling_factor(), + pretrain_epochs=18, lr_scheduler=:COSINE, optimiser=:SGD, eta=0.05f0 * scaling_factor(), @@ -192,14 +190,14 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:XL}) fwd_maxiters=27, bwd_maxiters=28, continuous=true, - stop_mode=:rel_norm, + stop_mode=:rel_deq_best, 
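The pretraining schedule above hinges on a single piece of layer state; Lux.update_state rewrites that one entry in the otherwise immutable state tree. Roughly:

    # unroll the DEQ for a fixed number of layers ("pretraining")
    st = Lux.update_state(st, :fixed_depth, Val(expt_config.num_layers))
    # ... after pretrain_epochs epochs, return to fixed-point solving
    st = Lux.update_state(st, :fixed_depth, Val(0))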
nepochs=100, jfb=false, model_type=:VANILLA, abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_steps=510000 ÷ scaling_factor(), + pretrain_epochs=18, lr_scheduler=:COSINE, optimiser=:SGD, eta=0.05f0 * scaling_factor(), diff --git a/examples/src/models.jl b/examples/src/models.jl index 06d3e3c2..447e938b 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -322,8 +322,8 @@ function get_model( ContinuousDEQSolver( config.ode_solver; mode=config.stop_mode, - abstol=1.0f-5, - reltol=1.0f-5, + abstol=1.0f-3, + reltol=1.0f-3, abstol_termination=config.abstol, reltol_termination=config.reltol, ) From 57a7abb305644ca534ed51e3f26064443fe1cca6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 17 May 2022 10:50:01 -0400 Subject: [PATCH 63/76] Config --- examples/cifar10/options.jl | 2 +- examples/src/config.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/cifar10/options.jl b/examples/cifar10/options.jl index 31b4c81b..d5b7fe2a 100644 --- a/examples/cifar10/options.jl +++ b/examples/cifar10/options.jl @@ -43,7 +43,7 @@ function parse_commandline_arguments() "--start-epoch" help = "manual epoch number (useful on restarts)" arg_type = Int - default = 0 + default = 1 "--print-freq" help = "print frequency" arg_type = Int diff --git a/examples/src/config.jl b/examples/src/config.jl index 18c3a7f8..3396b4d1 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -42,7 +42,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_epochs=8, + pretrain_epochs=3, lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 * scaling_factor(), @@ -80,7 +80,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:LARGE}) abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_epochs=13, + pretrain_epochs=3, lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 * scaling_factor(), From 513172abfe4dcac287515b6d85a4d4984ab0de64 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 18 May 2022 17:37:37 -0400 Subject: [PATCH 64/76] Fix data augmentation --- examples/cifar10/main.jl | 6 ++++-- src/solve.jl | 8 +++++++- src/solvers/discrete/broyden.jl | 7 ++++++- src/solvers/discrete/limited_memory_broyden.jl | 7 ++++++- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/examples/cifar10/main.jl b/examples/cifar10/main.jl index 83141009..5fdd5f23 100644 --- a/examples/cifar10/main.jl +++ b/examples/cifar10/main.jl @@ -104,7 +104,7 @@ function get_dataloaders(expt_config::NamedTuple) base_transform = ImageToTensor() |> Normalize((0.4914f0, 0.4822f0, 0.4465f0), (0.2023f0, 0.1994f0, 0.2010f0)) if expt_config.augment - train_transform = ScaleKeepAspect((36, 36)) |> RandomResizeCrop((32, 32)) |> Maybe(FlipX()) |> base_transform + train_transform = Maybe(FlipX()) |> ScaleKeepAspect((36, 36)) |> RandomResizeCrop((32, 32)) |> base_transform else train_transform = base_transform end @@ -165,6 +165,7 @@ function validate(val_loader, model, ps, st, loss_function, args) if i % args["print-freq"] == 0 || i == length(val_loader) should_log() && print_meter(progress, i) end + i % 10 == 0 && invoke_gc() t = time() end @@ -235,6 +236,7 @@ function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, lo if i % args["print-freq"] == 0 || i == length(train_loader) should_log() && print_meter(progress, i) end + i % 10 == 0 && invoke_gc() t = time() end @@ -380,4 +382,4 @@ function main(args) end end -main(parse_commandline_arguments()) +# 
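The augmentation fix above moves the random horizontal flip to the front of the pipeline, before rescaling and cropping. A sketch of the composed transform, assuming DataAugmentation.jl (the source of FlipX, ScaleKeepAspect, and friends used here):

    using DataAugmentation

    base_transform = ImageToTensor() |>
                     Normalize((0.4914f0, 0.4822f0, 0.4465f0), (0.2023f0, 0.1994f0, 0.2010f0))
    train_transform = Maybe(FlipX()) |>              # flip roughly half of the images
                      ScaleKeepAspect((36, 36)) |>   # resize while preserving aspect ratio
                      RandomResizeCrop((32, 32)) |>  # random 32×32 crop
                      base_transform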
main(parse_commandline_arguments()) diff --git a/src/solve.jl b/src/solve.jl index eae2c0e6..056ac41b 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -58,7 +58,7 @@ function DiffEqBase.__solve( :best_objective_value => real(eltype(prob.u0))(Inf), :best_objective_value_iteration => nothing ) - u, stats = nlsolve( + us, stats = nlsolve( alg.alg, u -> prob.f(u, prob.p, nothing), prob.u0; @@ -66,6 +66,12 @@ function DiffEqBase.__solve( terminate_condition=get_terminate_condition(alg, terminate_stats) ) + u = if terminate_stats[:best_objective_value_iteration] === nothing + us[end] + else + us[terminate_stats[:best_objective_value_iteration] + 1] + end + # Dont count towards NFE since this is mostly a check for convergence du = prob.f(u, prob.p, nothing) diff --git a/src/solvers/discrete/broyden.jl b/src/solvers/discrete/broyden.jl index b1d32416..41e5f204 100644 --- a/src/solvers/discrete/broyden.jl +++ b/src/solvers/discrete/broyden.jl @@ -45,6 +45,9 @@ function nlsolve( p = similar(fx_old, (size(Jinv, 1),)) ρ, σ₂ = T(0.9), T(0.001) + # Store the trajectory + xs = [x] + maybe_stuck, max_resets, resets, nsteps, nf = false, 3, 0, 1, 1 while nsteps <= maxiters @@ -82,11 +85,13 @@ function nlsolve( copyto!(fx_old, fx) copyto!(x_old, x) + push!(xs, x) + # Convergence Check terminate_condition(fx, x) && break end - return x, (nf=nf,) + return xs, (nf=nf,) end function _approximate_norm_descent(f::Function, x::AbstractArray{T,N}, p; λ₀=T(1), β=T(0.5), σ₁=T(0.001), diff --git a/src/solvers/discrete/limited_memory_broyden.jl b/src/solvers/discrete/limited_memory_broyden.jl index 63c896cc..3c84402d 100644 --- a/src/solvers/discrete/limited_memory_broyden.jl +++ b/src/solvers/discrete/limited_memory_broyden.jl @@ -39,6 +39,9 @@ struct LimitedMemoryBroydenSolver end Us = fill!(similar(y, (LBFGS_threshold, total_hsize, batch_size)), T(0)) VTs = fill!(similar(y, (total_hsize, LBFGS_threshold, batch_size)), T(0)) + # Store the trajectory + xs = [x₀] + # Counters nstep = 1 @@ -52,6 +55,8 @@ struct LimitedMemoryBroydenSolver end @. Δx = x₁ - x₀ @. 
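The solver change above returns the whole iterate trajectory us instead of only the final point, which lets __solve recover the iterate that achieved the best termination objective; trajectories store the initial guess at index 1, hence the + 1. In sketch form:

    best_iter = terminate_stats[:best_objective_value_iteration]
    u = best_iter === nothing ? us[end] : us[best_iter + 1]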
Δfx = fx₁ - fx₀ + push!(xs, x₁) + # Convergence Check terminate_condition(fx₁, x₁) && break @@ -78,7 +83,7 @@ struct LimitedMemoryBroydenSolver end nstep += 1 end - return x₁, (nf=nstep + 1,) + return xs, (nf=nstep + 1,) end @inbounds @views function matvec( From d7098461e7b1c2711229ddc1108bd2b0b4bade4c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 May 2022 16:48:01 -0400 Subject: [PATCH 65/76] Wandb logging --- examples/Project.toml | 4 +++- examples/cifar10/main.jl | 32 ++++++++++++++++++++++++++------ examples/src/models.jl | 24 ++++++++++-------------- 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/examples/Project.toml b/examples/Project.toml index 6d5dcd50..d44b8ad6 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -34,13 +34,14 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7" +Wandb = "ad70616a-06c9-5745-b1f1-6a5f42545108" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] CUDA = "3" DataLoaders = "0.1" Flux = "0.13" -FluxMPI = "0.4" +FluxMPI = "0.5.3" Format = "1.3" Lux = "0.4" MLDatasets = "0.5" @@ -51,5 +52,6 @@ Optimisers = "0.2" OrdinaryDiffEq = "6" ParameterSchedulers = "0.3" Setfield = "0.8" +Wandb = "0.4.3" Zygote = "0.6" julia = "1.6" diff --git a/examples/cifar10/main.jl b/examples/cifar10/main.jl index 5fdd5f23..8dd175fc 100644 --- a/examples/cifar10/main.jl +++ b/examples/cifar10/main.jl @@ -26,6 +26,7 @@ using Serialization # Serialize Models using Setfield # Easy Parameter Manipulation using Statistics # Statistics using ValueHistories # Storing Value Histories +using Wandb # Logging to Weights and Biases using Zygote # Our AD Engine # Distributed Training @@ -165,7 +166,7 @@ function validate(val_loader, model, ps, st, loss_function, args) if i % args["print-freq"] == 0 || i == length(val_loader) should_log() && print_meter(progress, i) end - i % 10 == 0 && invoke_gc() + i == length(val_loader) - 1 && invoke_gc() # Needed since the last batch size is different t = time() end @@ -215,9 +216,12 @@ function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, lo ) forward_pass_time(time() - _t, B) _t = time() - gs = back((one(loss), nothing, nothing))[1] + gs = back((one(loss) / total_workers(), nothing, nothing))[1] backward_pass_time(time() - _t, B) st = Lux.update_state(st, :update_mask, Val(true)) + if is_distributed() + gs = allreduce_gradients(gs) + end optimiser_state, ps = Optimisers.update(optimiser_state, ps, gs) # Measure Elapsed Time @@ -236,7 +240,7 @@ function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, lo if i % args["print-freq"] == 0 || i == length(train_loader) should_log() && print_meter(progress, i) end - i % 10 == 0 && invoke_gc() + i == length(train_loader) - 1 && invoke_gc() # Needed since the last batch size is different t = time() end @@ -272,6 +276,14 @@ function get_loggable_stats(stats) return v[1:end-1] ./ v[end] end +function convert_config_to_loggable(expt_config::NamedTuple) + config = Dict() + for (k, v) in pairs(expt_config) + config[k] = isprimitivetype(typeof(v)) ? 
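Gradient synchronization now happens in the training step rather than inside a wrapped optimiser: each rank seeds its pullback with 1 / total_workers(), so the cross-process allreduce (a sum) produces the average gradient. A sketch, with loss_fn standing in for the loss closure and FluxMPI assumed:

    (loss, ŷ, st_), back = Zygote.pullback(p -> loss_fn(x, y, model, p, st), ps)
    gs = back((one(loss) / total_workers(), nothing, nothing))[1]
    is_distributed() && (gs = allreduce_gradients(gs))
    opt_state, ps = Optimisers.update(opt_state, ps, gs)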
v : string(v) + end + return config +end + function main(args) best_acc1 = 0 @@ -281,6 +293,7 @@ function main(args) # Model Construction expt_config = get_experiment_config(args) + loggable_config = convert_config_to_loggable(expt_config) should_log() && println("$(now()) => creating model") model, ps, st = create_model(expt_config, args) @@ -326,7 +339,13 @@ function main(args) should_log() && println("$(now()) => checkpoint directory `$(ckpt_dir)`") - csv_logger = CSVLogger(log_path, ["Epoch", "Train/Batch Time", "Train/Data Time", "Train/Forward Pass Time", "Train/Backward Pass Time", "Train/Cross Entropy Loss", "Train/Skip Loss", "Train/Net Loss", "Train/NFE", "Train/Accuracy", "Train/Residual", "Test/Batch Time", "Test/Data Time", "Test/Cross Entropy Loss", "Test/Skip Loss", "Test/Net Loss", "Test/NFE", "Test/Accuracy", "Test/Residual"]) + logging_header = ["Epoch", "Train/Batch Time", "Train/Data Time", "Train/Forward Pass Time", "Train/Backward Pass Time", "Train/Cross Entropy Loss", "Train/Skip Loss", "Train/Net Loss", "Train/NFE", "Train/Accuracy", "Train/Residual", "Test/Batch Time", "Test/Data Time", "Test/Cross Entropy Loss", "Test/Skip Loss", "Test/Net Loss", "Test/NFE", "Test/Accuracy", "Test/Residual"] + csv_logger = CSVLogger(log_path, logging_header) + wandb_logger = WandbLoggerMPI(project="deep_equilibrium_models", + name=store_in, + config=loggable_config) + + values_to_loggable_dict(args...) = Dict(zip(logging_header, args)) should_log() && println("$(now()) => logging results to `$(log_path)`") @@ -353,6 +372,7 @@ function main(args) should_log() && println() csv_logger(epoch, train_stats..., val_stats...) + Wandb.log(wandb_logger, values_to_loggable_dict(epoch, train_stats..., val_stats...)) should_log() && println("$(now()) => logged intermediated results to csv file\n") # ParameterSchedulers @@ -372,7 +392,7 @@ function main(args) save_state = Dict( "epoch" => epoch, - "config" => expt_config, + "config" => loggable_config, "accuracy" => accuracy, "model_states" => cpu(st), "model_parameters" => cpu(ps), @@ -382,4 +402,4 @@ function main(args) end end -# main(parse_commandline_arguments()) +main(parse_commandline_arguments()) diff --git a/examples/src/models.jl b/examples/src/models.jl index 447e938b..d5249ecc 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -389,6 +389,16 @@ function get_model( Random.seed!(rng, config.seed) ps, st = device.(Lux.setup(rng, model)) + ps, st = if is_distributed() + ps_ = FluxMPI.synchronize!(ps; root_rank=0) + should_log() && println("$(now()) ===> synchronized model parameters across all processes") + st_ = FluxMPI.synchronize!(st; root_rank=0) + should_log() && println("$(now()) ===> synchronized model state across all processes") + ps_, st_ + else + ps, st + end + if warmup should_log() && println("$(now()) ==> starting model warmup") x__ = device(randn(Float32, config.image_size..., 3, 2)) @@ -411,16 +421,6 @@ function get_model( invoke_gc() end - ps, st = if is_distributed() - ps_ = FluxMPI.synchronize!(ps; root_rank=0) - should_log() && println("$(now()) ===> synchronized model parameters across all processes") - st_ = FluxMPI.synchronize!(st; root_rank=0) - should_log() && println("$(now()) ===> synchronized model state across all processes") - ps_, st_ - else - ps, st - end - return model, ps, st end @@ -445,10 +445,6 @@ function construct_optimiser(config::NamedTuple) opt = Optimisers.OptimiserChain(opt, Optimisers.WeightDecay(config.weight_decay)) end - if is_distributed() - opt = 
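convert_config_to_loggable keeps primitive values and stringifies the rest (symbols, the ODE solver object, and so on) so the run configuration can be shipped to the Weights & Biases backend. A sketch of the logging path, assuming Wandb.jl's MPI-aware logger as imported above:

    wandb_logger = WandbLoggerMPI(; project="deep_equilibrium_models", name=store_in,
                                  config=convert_config_to_loggable(expt_config))
    # reuse the CSV header so both logging sinks stay aligned
    Wandb.log(wandb_logger, Dict(zip(logging_header, (epoch, train_stats..., val_stats...))))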
DistributedOptimiser(opt) - end - sched = if config.lr_scheduler == :COSINE ParameterSchedulers.Stateful(ParameterSchedulers.Cos(config.eta, 1.0f-6, config.nepochs)) elseif config.lr_scheduler == :CONSTANT From 50b879d29cbbaa4a5ce5019b046f7852be559a57 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 May 2022 19:55:19 -0400 Subject: [PATCH 66/76] modify normalization layers --- examples/src/config.jl | 4 ++-- examples/src/models.jl | 37 ++++++++++++++----------------------- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index 3396b4d1..b1b739e9 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -42,7 +42,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_epochs=3, + pretrain_epochs=8, lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 * scaling_factor(), @@ -80,7 +80,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:LARGE}) abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_epochs=3, + pretrain_epochs=8, lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 * scaling_factor(), diff --git a/examples/src/models.jl b/examples/src/models.jl index d5249ecc..ed4110e4 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -40,8 +40,7 @@ function downsample_module(mapping, level_diff, activation; group_count=8) for i in 1:level_diff inchs, outchs = intermediate_mapping(i) push!(layers, conv3x3(inchs => outchs; stride=2)) - # push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) - push!(layers, BatchNorm(outchs, activation; affine=true, track_stats=false)) + push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) end return Chain(layers...) end @@ -62,8 +61,7 @@ function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol= for i in 1:level_diff inchs, outchs = intermediate_mapping(i) push!(layers, conv1x1(inchs => outchs)) - # push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) - push!(layers, BatchNorm(outchs, activation; affine=true, track_stats=false)) + push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) push!(layers, Upsample(upsample_mode; scale=2)) end return Chain(layers...) @@ -92,12 +90,9 @@ function ResidualBlockV1( conv1, conv2 end - # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) + gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) dropout = iszero(dropout_rate) ? 
NoOpLayer() : VariationalHiddenDropout(dropout_rate) @@ -141,12 +136,9 @@ function ResidualBlockV2( conv1, conv2 end - # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) + gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) dropout = iszero(dropout_rate) ? NoOpLayer() : VariationalHiddenDropout(dropout_rate) @@ -227,16 +219,16 @@ function get_model( downsample_layers = [ conv3x3(3 => init_channel_size; stride=config.downsample_times >= 1 ? 2 : 1), - BatchNorm(init_channel_size, relu; affine=true, track_stats=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=true), conv3x3(init_channel_size => init_channel_size; stride=config.downsample_times >= 2 ? 2 : 1), - BatchNorm(init_channel_size, relu; affine=true, track_stats=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=true), ] for _ in 3:(config.downsample_times) append!( downsample_layers, [ conv3x3(init_channel_size => init_channel_size; stride=2), - BatchNorm(init_channel_size, relu; affine=true, track_stats=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=true), ], ) end @@ -247,7 +239,7 @@ function get_model( else Chain( conv1x1(init_channel_size => init_channel_size; bias=false), - BatchNorm(init_channel_size, relu; affine=true, track_stats=false), + BatchNorm(init_channel_size, relu; affine=true, track_stats=true), ) end @@ -288,8 +280,7 @@ function get_model( Chain( ActivationFunction(relu), conv1x1(config.num_channels[i] => config.num_channels[i]), - # GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), - BatchNorm(config.num_channels[i]; affine=true, track_stats=false), + GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), ) for i in 1:(config.num_branches) ) @@ -312,7 +303,7 @@ function get_model( increment_modules, downsample_modules, conv1x1(config.head_channels[config.num_branches] * 4 => config.final_channelsize; bias=true), - BatchNorm(config.final_channelsize, relu; track_stats=false, affine=true), + BatchNorm(config.final_channelsize, relu; track_stats=true, affine=true), GlobalMeanPool(), FlattenLayer(), Dense(config.final_channelsize, config.num_classes), From d709d7f14cbb47bf36a40b1373d6c6661c35b388 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 May 2022 20:51:41 -0400 Subject: [PATCH 67/76] modify normalization layers --- examples/src/config.jl | 2 +- examples/src/models.jl | 27 ++++++++++++++++++--------- examples/src/utils.jl | 5 +---- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index b1b739e9..5cff04dd 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -35,7 +35,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) bwd_maxiters=20, continuous=true, 
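This patch and the follow-up below toggle between the two normalization flavors; the trade-off, roughly, is that GroupNorm statistics are computed per sample over channel groups (batch-size independent, which suits the implicit block), while BatchNorm with track_stats=true keeps running statistics for inference in the stem and head. The two constructors as used here, in Lux:

    gn = GroupNorm(64, 8, relu; affine=true, track_stats=false)  # 64 channels in 8 groups
    bn = BatchNorm(64, relu; affine=true, track_stats=true)      # keeps running mean/var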
stop_mode=:rel_deq_best, - nepochs=50, + nepochs=20, jfb=false, augment=false, model_type=:VANILLA, diff --git a/examples/src/models.jl b/examples/src/models.jl index ed4110e4..3f89897d 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -40,7 +40,8 @@ function downsample_module(mapping, level_diff, activation; group_count=8) for i in 1:level_diff inchs, outchs = intermediate_mapping(i) push!(layers, conv3x3(inchs => outchs; stride=2)) - push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + # push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, BatchNorm(outchs, activation; affine=true, track_stats=false)) end return Chain(layers...) end @@ -61,7 +62,8 @@ function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol= for i in 1:level_diff inchs, outchs = intermediate_mapping(i) push!(layers, conv1x1(inchs => outchs)) - push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + # push!(layers, GroupNorm(outchs, group_count, activation; affine=true, track_stats=false)) + push!(layers, BatchNorm(outchs, activation; affine=true, track_stats=false)) push!(layers, Upsample(upsample_mode; scale=2)) end return Chain(layers...) @@ -90,9 +92,12 @@ function ResidualBlockV1( conv1, conv2 end - gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) dropout = iszero(dropout_rate) ? NoOpLayer() : VariationalHiddenDropout(dropout_rate) @@ -136,9 +141,12 @@ function ResidualBlockV2( conv1, conv2 end - gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) dropout = iszero(dropout_rate) ? 
NoOpLayer() : VariationalHiddenDropout(dropout_rate) @@ -280,7 +288,8 @@ function get_model( Chain( ActivationFunction(relu), conv1x1(config.num_channels[i] => config.num_channels[i]), - GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), + # GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), + BatchNorm(config.num_channels[i]; affine=true, track_stats=false), ) for i in 1:(config.num_branches) ) diff --git a/examples/src/utils.jl b/examples/src/utils.jl index f88f3cfc..e68ac402 100644 --- a/examples/src/utils.jl +++ b/examples/src/utils.jl @@ -51,8 +51,5 @@ end @inline mse(ŷ, y) = mean(abs2, ŷ .- y) # DataLoaders doesn't yet work with MLUtils -LearnBase.nobs(data::DistributedDataContainer) = MLUtils.numobs(data) -LearnBase.getobs(data::DistributedDataContainer, i::Int) = MLUtils.getobs(data, i) MLDataPattern.nobs(x) = MLUtils.numobs(x) -MLDataPattern.getobs(d::Union{MLUtils.ObsView,DistributedDataContainer}, i::Int64) = - MLUtils.getobs(d, i) +MLDataPattern.getobs(d::MLUtils.ObsView, i::Int64) = MLUtils.getobs(d, i) From 6897ec38cf3e7814e00862fd0cbee6a2f2c48093 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 25 May 2022 23:40:10 -0400 Subject: [PATCH 68/76] Modify model architecture --- Project.toml | 2 +- examples/Project.toml | 2 +- examples/src/models.jl | 32 ++++++++++++++++---------------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index 0c102336..d1f70bb1 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ Lux = "0.4" MLUtils = "0.2" OrdinaryDiffEq = "6" SciMLBase = "1.19" -Setfield = "0.8.2" +Setfield = "0.8, 1" SteadyStateDiffEq = "1.6" UnPack = "1" Zygote = "0.6.34" diff --git a/examples/Project.toml b/examples/Project.toml index d44b8ad6..a7d4144a 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -51,7 +51,7 @@ NNlib = "0.8" Optimisers = "0.2" OrdinaryDiffEq = "6" ParameterSchedulers = "0.3" -Setfield = "0.8" +Setfield = "0.8, 1" Wandb = "0.4.3" Zygote = "0.6" julia = "1.6" diff --git a/examples/src/models.jl b/examples/src/models.jl index 3f89897d..4fe61863 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -78,7 +78,7 @@ function ResidualBlockV1( n_big_kernels::Int=0, dropout_rate::Real=0.0f0, gn_affine::Bool=true, - weight_norm::Bool=true, + weight_norm::Bool=false, gn_track_stats::Bool=false, ) inplanes, outplanes = mapping @@ -93,10 +93,10 @@ function ResidualBlockV1( end # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) dropout = iszero(dropout_rate) ? 
NoOpLayer() : VariationalHiddenDropout(dropout_rate) @@ -127,7 +127,7 @@ function ResidualBlockV2( n_big_kernels::Int=0, dropout_rate::Real=0.0f0, gn_affine::Bool=true, - weight_norm::Bool=true, + weight_norm::Bool=false, gn_track_stats::Bool=false, ) inplanes, outplanes = mapping @@ -142,10 +142,10 @@ function ResidualBlockV2( end # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn2 = GroupNorm(outplanes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # gn2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = BatchNorm(outplanes, relu; affine=gn_affine, track_stats=gn_track_stats) + gn2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) dropout = iszero(dropout_rate) ? NoOpLayer() : VariationalHiddenDropout(dropout_rate) @@ -178,9 +178,9 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool WrappedFunction(addtuple), # Since injection could be a scalar Chain( BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), - conv3x3(last(mapping) => last(mapping) * expansion), - BatchNorm(last(mapping) * expansion, relu; track_stats=bn_track_stats, affine=bn_affine), - conv1x1(last(mapping) * expansion => last(mapping) * expansion), + conv3x3(last(mapping) => last(mapping)), + BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine), + conv1x1(last(mapping) => last(mapping) * expansion), BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), @@ -188,7 +188,7 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool ) end -function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=false, bn_affine::Bool=true) +function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) rescale = if first(mapping) != last(mapping) * expansion Chain( conv1x1(first(mapping) => last(mapping) * expansion), @@ -205,9 +205,9 @@ function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool Chain( conv1x1(mapping), BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), - conv3x3(last(mapping) => last(mapping) * expansion), - BatchNorm(last(mapping) * expansion, relu; track_stats=bn_track_stats, affine=bn_affine), - conv1x1(last(mapping) * expansion => last(mapping) * expansion), + conv3x3(last(mapping) => last(mapping)), + BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine), + conv1x1(last(mapping) => last(mapping) * expansion), BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), ), ), @@ -288,8 +288,8 @@ function get_model( Chain( ActivationFunction(relu), conv1x1(config.num_channels[i] => config.num_channels[i]), - # GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=true, track_stats=false), - BatchNorm(config.num_channels[i]; affine=true, track_stats=false), + # GroupNorm(config.num_channels[i], config.group_count ÷ 2; affine=false, track_stats=false), + BatchNorm(config.num_channels[i]; affine=false, track_stats=false), ) for i in 1:(config.num_branches) ) @@ -303,7 +303,7 @@ function get_model( [ Chain( conv3x3(config.head_channels[i] * 
4 => config.head_channels[i + 1] * 4; stride=2, bias=true), - BatchNorm(config.head_channels[i + 1] * 4, relu; track_stats=false, affine=true), + BatchNorm(config.head_channels[i + 1] * 4, relu; track_stats=true, affine=true), ) for i in 1:(config.num_branches - 1) ]..., ) From 4555e10a1f1e1a0dc3d013c64f8b659bb3294dec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 25 May 2022 23:40:43 -0400 Subject: [PATCH 69/76] Modify model architecture --- examples/src/config.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index 5cff04dd..b1b739e9 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -35,7 +35,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) bwd_maxiters=20, continuous=true, stop_mode=:rel_deq_best, - nepochs=20, + nepochs=50, jfb=false, augment=false, model_type=:VANILLA, From 9fa1c32a8879314627c8d345a3ec68449b275c13 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 May 2022 00:44:16 -0400 Subject: [PATCH 70/76] Use weight norm --- examples/src/models.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/src/models.jl b/examples/src/models.jl index 4fe61863..c3ec5aaa 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -78,7 +78,7 @@ function ResidualBlockV1( n_big_kernels::Int=0, dropout_rate::Real=0.0f0, gn_affine::Bool=true, - weight_norm::Bool=false, + weight_norm::Bool=true, gn_track_stats::Bool=false, ) inplanes, outplanes = mapping @@ -127,7 +127,7 @@ function ResidualBlockV2( n_big_kernels::Int=0, dropout_rate::Real=0.0f0, gn_affine::Bool=true, - weight_norm::Bool=false, + weight_norm::Bool=true, gn_track_stats::Bool=false, ) inplanes, outplanes = mapping @@ -342,7 +342,7 @@ function get_model( deq = if config.model_type ∈ (:SKIP, :SKIPV2) shortcut = if config.model_type == :SKIP - slayers = Lux.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1], weight_norm=false)] + slayers = Lux.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1], weight_norm=true)] for i in 1:(config.num_branches - 1) push!( slayers, From 0ef3ffa14bea1e215b6c4f5f91a096291ccbc9c4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 May 2022 13:24:32 -0400 Subject: [PATCH 71/76] Update model --- examples/src/config.jl | 12 +-- examples/src/models.jl | 226 ++++++++++++++++++++++------------------- 2 files changed, 129 insertions(+), 109 deletions(-) diff --git a/examples/src/config.jl b/examples/src/config.jl index b1b739e9..c8a4097f 100644 --- a/examples/src/config.jl +++ b/examples/src/config.jl @@ -34,7 +34,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) fwd_maxiters=18, bwd_maxiters=20, continuous=true, - stop_mode=:rel_deq_best, + stop_mode=:rel_norm, nepochs=50, jfb=false, augment=false, @@ -42,7 +42,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY}) abstol=5.0f-2, reltol=5.0f-2, ode_solver=VCABM3(), - pretrain_epochs=8, + pretrain_epochs=5, lr_scheduler=:COSINE, optimiser=:ADAM, eta=0.001f0 * scaling_factor(), @@ -72,7 +72,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:LARGE}) fwd_maxiters=18, bwd_maxiters=20, continuous=true, - stop_mode=:rel_deq_best, + stop_mode=:rel_norm, nepochs=220, jfb=false, augment=true, @@ -110,7 +110,7 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:SMALL}) fwd_maxiters=27, bwd_maxiters=28, continuous=true, - 
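weight_norm=true (restored here) wraps each convolution in a WeightNorm reparameterization, w = g * v / |v|, with the norm taken over all weight dimensions except dim 4 (the output channels), matching the WeightNorm(conv, (:weight,), (4,)) calls in these blocks. A minimal sketch in Lux:

    c  = Conv((3, 3), 16 => 16; bias=false)
    wc = WeightNorm(c, (:weight,), (4,))   # per-output-channel magnitude g, direction v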
stop_mode=:rel_deq_best, + stop_mode=:rel_norm, nepochs=100, jfb=false, model_type=:VANILLA, @@ -150,7 +150,7 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:LARGE}) fwd_maxiters=27, bwd_maxiters=28, continuous=true, - stop_mode=:rel_deq_best, + stop_mode=:rel_norm, nepochs=100, jfb=false, model_type=:VANILLA, @@ -190,7 +190,7 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:XL}) fwd_maxiters=27, bwd_maxiters=28, continuous=true, - stop_mode=:rel_deq_best, + stop_mode=:rel_norm, nepochs=100, jfb=false, model_type=:VANILLA, diff --git a/examples/src/models.jl b/examples/src/models.jl index c3ec5aaa..a0a6da74 100644 --- a/examples/src/models.jl +++ b/examples/src/models.jl @@ -70,7 +70,18 @@ function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol= end ## Residual Block -function ResidualBlockV1( +struct ResidualBlock{C1,C2,Dr,Do,N1,N2,N3} <: + Lux.AbstractExplicitContainerLayer{(:conv1, :conv2, :dropout, :downsample, :norm1, :norm2, :norm3)} + conv1::C1 + conv2::C2 + dropout::Dr + downsample::Do + norm1::N1 + norm2::N2 + norm3::N3 +end + +function ResidualBlock( mapping; deq_expand::Int=5, num_gn_groups::Int=4, @@ -92,74 +103,81 @@ function ResidualBlockV1( conv1, conv2 end - # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) - gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) + # norm1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) + # norm2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + # norm3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) + norm1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) + norm2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) + norm3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) - dropout = iszero(dropout_rate) ? 
NoOpLayer() : VariationalHiddenDropout(dropout_rate) + dropout = VariationalHiddenDropout(dropout_rate) - return Chain( - Parallel( - reassociate, # Reassociate and Merge - Chain(conv1, gn1, conv2, BranchLayer(downsample, dropout)), # For x - NoOpLayer(), # For injection - ), - Parallel( - addrelu, - NoOpLayer(), # For y1 - Chain( - WrappedFunction(addtuple), # Since injection could be a scalar - gn2, - ), # For (y2, injection) + return ResidualBlock(conv1, conv2, dropout, downsample, norm1, norm2, norm3) +end + +function (rb::ResidualBlock)((x, y)::NTuple{2,<:AbstractArray}, ps, st) + x, st_conv1 = rb.conv1(x, ps.conv1, st.conv1) + x, st_norm1 = rb.norm1(x, ps.norm1, st.norm1) + x, st_conv2 = rb.conv2(x, ps.conv2, st.conv2) + + x_do, st_downsample = rb.downsample(x, ps.downsample, st.downsample) + x_dr, st_dropout = rb.dropout(x, ps.dropout, st.dropout) + + y_ = x_dr .+ y + y_, st_norm2 = rb.norm2(y_, ps.norm2, st.norm2) + + y__ = relu.(y_ .+ x_do) + y__, st_norm3 = rb.norm3(y__, ps.norm3, st.norm3) + + return ( + y__, + ( + conv1=st_conv1, + conv2=st_conv2, + dropout=st_dropout, + downsample=st_downsample, + norm1=st_norm1, + norm2=st_norm2, + norm3=st_norm3, ), - gn3, ) end -function ResidualBlockV2( - mapping; - deq_expand::Int=1, - num_gn_groups::Int=4, - downsample=NoOpLayer(), - n_big_kernels::Int=0, - dropout_rate::Real=0.0f0, - gn_affine::Bool=true, - weight_norm::Bool=true, - gn_track_stats::Bool=false, -) - inplanes, outplanes = mapping - inner_planes = outplanes * deq_expand - conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false) - conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false) - - conv1, conv2 = if weight_norm - WeightNorm(conv1, (:weight,), (4,)), WeightNorm(conv2, (:weight,), (4,)) - else - conv1, conv2 - end - - # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats) - # gn2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats) - gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats) - gn2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) - gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats) - - dropout = iszero(dropout_rate) ? 
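ResidualBlock above follows the standard Lux container pattern: the tuple in AbstractExplicitContainerLayer{(:conv1, ...)} names the child fields, Lux derives the matching ps/st namespaces from those names, and every method returns its output together with the reassembled state. A stripped-down instance of the same pattern, using a hypothetical GatedSum layer:

    struct GatedSum{F,G} <: Lux.AbstractExplicitContainerLayer{(:f, :g)}
        f::F
        g::G
    end

    function (m::GatedSum)(x, ps, st)
        y, st_f = m.f(x, ps.f, st.f)
        z, st_g = m.g(x, ps.g, st.g)
        return y .* sigmoid.(z), (f=st_f, g=st_g)
    end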
NoOpLayer() : VariationalHiddenDropout(dropout_rate) - - return Chain( - conv1, - gn1, - conv2, - Parallel(addrelu, downsample, Chain(dropout, gn2)), - gn3, +function (rb::ResidualBlock)(x::AbstractArray, ps, st) + x, st_conv1 = rb.conv1(x, ps.conv1, st.conv1) + x, st_norm1 = rb.norm1(x, ps.norm1, st.norm1) + x, st_conv2 = rb.conv2(x, ps.conv2, st.conv2) + + x_do, st_downsample = rb.downsample(x, ps.downsample, st.downsample) + + x_dr, st_dropout = rb.dropout(x, ps.dropout, st.dropout) + x_dr, st_norm2 = rb.norm2(x_dr, ps.norm2, st.norm2) + + y__ = relu.(x_dr .+ x_do) + y__, st_norm3 = rb.norm3(y__, ps.norm3, st.norm3) + + return ( + y__, + ( + conv1=st_conv1, + conv2=st_conv2, + dropout=st_dropout, + downsample=st_downsample, + norm1=st_norm1, + norm2=st_norm2, + norm3=st_norm3, + ), ) end -function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=false, bn_affine::Bool=true) +# Bottleneck Block +struct BottleneckBlock{R,C,M} <: Lux.AbstractExplicitContainerLayer{(:rescale, :conv, :mapping)} + rescale::R + conv::C + mapping::M +end + +function BottleneckBlock(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) rescale = if first(mapping) != last(mapping) * expansion Chain( conv1x1(first(mapping) => last(mapping) * expansion), @@ -169,48 +187,48 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool NoOpLayer() end - return Chain( - Parallel(reassociate, BranchLayer(rescale, conv1x1(mapping)), NoOpLayer()), - Parallel( - addrelu, - NoOpLayer(), - Chain( - WrappedFunction(addtuple), # Since injection could be a scalar - Chain( - BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), - conv3x3(last(mapping) => last(mapping)), - BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine), - conv1x1(last(mapping) => last(mapping) * expansion), - BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), - ), - ), - ), + return BottleneckBlock( + rescale, + conv1x1(mapping), + Chain( + BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), + conv3x3(last(mapping) => last(mapping)), + BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine), + conv1x1(last(mapping) => last(mapping) * expansion), + BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine) + ) ) end -function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true) - rescale = if first(mapping) != last(mapping) * expansion - Chain( - conv1x1(first(mapping) => last(mapping) * expansion), - BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), +function (bn::BottleneckBlock)((x, y)::NTuple{2,<:AbstractArray}, ps, st) + x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale) + x_m, st_conv1 = bn.conv(x_r, ps.conv, st.conv) + + x_m = y .+ x_m + x_m, st_mapping = bn.mapping(x_m, ps.mapping, st.mapping) + + return ( + relu.(x_m .+ x_r), + ( + rescale=st_rescale, + conv=st_conv1, + mapping=st_mapping, ) - else - NoOpLayer() - end + ) +end - return Chain( - Parallel( - addrelu, - rescale, - Chain( - conv1x1(mapping), - BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats), - conv3x3(last(mapping) => last(mapping)), - BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine), - conv1x1(last(mapping) => last(mapping) * expansion), - BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine), - 
From 3df322d876d7fa494075addcb7bcc161e5eba811 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 26 May 2022 13:37:55 -0400
Subject: [PATCH 72/76] Relax types

---
 examples/src/models.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/src/models.jl b/examples/src/models.jl
index a0a6da74..6cd70da7 100644
--- a/examples/src/models.jl
+++ b/examples/src/models.jl
@@ -115,7 +115,7 @@ function ResidualBlock(
     return ResidualBlock(conv1, conv2, dropout, downsample, norm1, norm2, norm3)
 end
 
-function (rb::ResidualBlock)((x, y)::NTuple{2,<:AbstractArray}, ps, st)
+function (rb::ResidualBlock)((x, y)::Tuple{<:AbstractArray,<:AbstractArray}, ps, st)
     x, st_conv1 = rb.conv1(x, ps.conv1, st.conv1)
     x, st_norm1 = rb.norm1(x, ps.norm1, st.norm1)
     x, st_conv2 = rb.conv2(x, ps.conv2, st.conv2)
@@ -200,7 +200,7 @@ function BottleneckBlock(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=t
     )
 end
 
-function (bn::BottleneckBlock)((x, y)::NTuple{2,<:AbstractArray}, ps, st)
+function (bn::BottleneckBlock)((x, y)::Tuple{<:AbstractArray,<:AbstractArray}, ps, st)
     x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale)
     x_m, st_conv1 = bn.conv(x_r, ps.conv, st.conv)
 
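The relaxed signature matters because of Julia's diagonal rule: NTuple{2,<:AbstractArray} lowers to Tuple{T,T} where T<:AbstractArray, and a type variable repeated in covariant position must bind to a single concrete type, so a pair of arrays with different element types (or different array wrappers, as can happen under AD) fails to dispatch. A quick illustrative check, not part of the diff:

a, b = rand(Float32, 2, 2), rand(Float64, 2, 2)
(a, b) isa NTuple{2,<:AbstractArray}               # false: Matrix{Float32} and Matrix{Float64} differ
(a, b) isa Tuple{<:AbstractArray,<:AbstractArray}  # true: each slot is constrained independently
(a, a) isa NTuple{2,<:AbstractArray}               # true: both elements share one concrete type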
From 1cd9a6e97439c04221c399aef6d8bafb1269662c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 26 May 2022 13:47:34 -0400
Subject: [PATCH 73/76] Minor fix

---
 examples/src/models.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/src/models.jl b/examples/src/models.jl
index 6cd70da7..356c3d9d 100644
--- a/examples/src/models.jl
+++ b/examples/src/models.jl
@@ -202,7 +202,7 @@ end
 
 function (bn::BottleneckBlock)((x, y)::Tuple{<:AbstractArray,<:AbstractArray}, ps, st)
     x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale)
-    x_m, st_conv1 = bn.conv(x_r, ps.conv, st.conv)
+    x_m, st_conv1 = bn.conv(x, ps.conv, st.conv)
 
     x_m = y .+ x_m
     x_m, st_mapping = bn.mapping(x_m, ps.mapping, st.mapping)
@@ -219,7 +219,7 @@ end
 
 function (bn::BottleneckBlock)(x::AbstractArray, ps, st)
     x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale)
-    x_m, st_conv1 = bn.conv(x_r, ps.conv, st.conv)
+    x_m, st_conv1 = bn.conv(x, ps.conv, st.conv)
     x_m, st_mapping = bn.mapping(x_m, ps.mapping, st.mapping)
 
     return (

From edbb7a33e99f0bfa660cb3238f845859249b6412 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 26 May 2022 15:21:55 -0400
Subject: [PATCH 74/76] Make post_deq type-stable

---
 src/layers/chain.jl | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/src/layers/chain.jl b/src/layers/chain.jl
index d1c7d1ff..581d3a78 100644
--- a/src/layers/chain.jl
+++ b/src/layers/chain.jl
@@ -27,22 +27,22 @@ function DEQChain(layers...)
         push!(encounter_deq ? post_deq : pre_deq, l)
     end
     @assert encounter_deq "No DEQ Layer in the Chain!!! Maybe you wanted to use Chain"
-    pre_deq = length(pre_deq) == 0 ? nothing : Chain(pre_deq...)
-    post_deq = length(post_deq) == 0 ? nothing : Chain(post_deq...)
+    pre_deq = length(pre_deq) == 0 ? NoOpLayer() : Chain(pre_deq...)
+    post_deq = length(post_deq) == 0 ? NoOpLayer() : Chain(post_deq...)
     return DEQChain(pre_deq, deq, post_deq)
 end
 
-function (deq::DEQChain{P1,D,P2})(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple) where {P1,D,P2}
-    x1, st1 = if P1 == Nothing
-        x, st.pre_deq
-    else
-        deq.pre_deq(x, ps.pre_deq, st.pre_deq)
-    end
-    (x2, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq)
-    x3, st3 = if P2 == Nothing
-        x2, st.post_deq
-    else
-        deq.post_deq(x2, ps.post_deq, st.post_deq)
-    end
+function get_deq_return_type(
+    deq::DEQChain{P1,<:Union{MultiScaleDeepEquilibriumNetwork,MultiScaleSkipDeepEquilibriumNetwork}}, ::T
+) where {P1,T}
+    return NTuple{length(deq.deq.scales),T}
+end
+get_deq_return_type(::DEQChain, ::T) where {T} = T
+
+function (deq::DEQChain)(x, ps::Union{ComponentArray,NamedTuple}, st::NamedTuple)
+    T = get_deq_return_type(deq, x)
+    x1, st1 = deq.pre_deq(x, ps.pre_deq, st.pre_deq)
+    (x2::T, deq_soln), st2 = deq.deq(x1, ps.deq, st.deq)
+    x3, st3 = deq.post_deq(x2, ps.post_deq, st.post_deq)
     return (x3, deq_soln), (pre_deq=st1, deq=st2, post_deq=st3)
 end
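The helper above is a type-assertion trick: the DEQ's return type is computed from static information (the number of scales), and asserting x2::T lets inference recover a concrete type through a call it would otherwise see as Any. A minimal illustrative sketch of the same idea (opaque and stable_wrapper are hypothetical names, not from this package):

opaque(x) = Any[x][1]      # inference can only conclude Any here

function stable_wrapper(x::T) where {T}
    y = opaque(x)::T       # assertion: downstream code now sees a concrete T
    return y + one(T)      # inferred precisely instead of via dynamic dispatch
end

stable_wrapper(1.0f0)      # 2.0f0, with a concretely inferred return type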
Loss", "Train/Skip Loss", "Train/Net Loss", "Train/NFE", "Train/Accuracy", "Train/Residual", "Test/Batch Time", "Test/Data Time", "Test/Cross Entropy Loss", "Test/Skip Loss", "Test/Net Loss", "Test/NFE", "Test/Accuracy", "Test/Residual"] csv_logger = CSVLogger(log_path, logging_header) - wandb_logger = WandbLoggerMPI(project="deep_equilibrium_models", - name=store_in, - config=loggable_config) + wandb_logger = WandbLogger(project="deep_equilibrium_models", + name=store_in, + config=loggable_config) values_to_loggable_dict(args...) = Dict(zip(logging_header, args)) From 7e42c90b9540488902dbeaecbea1f89a644e0d17 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 May 2022 20:49:12 -0400 Subject: [PATCH 76/76] Decay the weight of skip --- examples/cifar10/main.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/cifar10/main.jl b/examples/cifar10/main.jl index a443e494..1a8f8aea 100644 --- a/examples/cifar10/main.jl +++ b/examples/cifar10/main.jl @@ -56,7 +56,7 @@ create_model(expt_config, args) = get_model(expt_config; device=gpu, warmup=true function get_loss_function(args) if args["model-type"] == "VANILLA" - function loss_function_closure_vanilla(x, y, model, ps, st) + function loss_function_closure_vanilla(x, y, model, ps, st, w_skip=args["w-skip"]) (ŷ, soln), st_ = model(x, ps, st) celoss = logitcrossentropy(ŷ, y) skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star) @@ -65,11 +65,11 @@ function get_loss_function(args) end return loss_function_closure_vanilla else - function loss_function_closure_skip(x, y, model, ps, st) + function loss_function_closure_skip(x, y, model, ps, st, w_skip=args["w-skip"]) (ŷ, soln), st_ = model(x, ps, st) celoss = logitcrossentropy(ŷ, y) skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star) - loss = celoss + args["w-skip"] * skiploss + loss = celoss + w_skip * skiploss return loss, st_, (ŷ, soln.nfe, celoss, skiploss, soln.residual) end return loss_function_closure_skip @@ -185,7 +185,7 @@ function validate(val_loader, model, ps, st, loss_function, args) end # Training -function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, args) +function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, w_skip, args) batch_time = AverageMeter("Batch Time", "6.3f") data_time = AverageMeter("Data Time", "6.3f") forward_pass_time = AverageMeter("Forward Pass Time", "6.3f") @@ -212,7 +212,7 @@ function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, lo # Gradients and Update _t = time() (loss, st, (ŷ, nfe_, celoss, skiploss, resi)), back = Zygote.pullback( - p -> loss_function(x, y, model, p, st), ps + p -> loss_function(x, y, model, p, st, w_skip), ps ) forward_pass_time(time() - _t, B) _t = time() @@ -353,10 +353,12 @@ function main(args) st = hasproperty(expt_config, :pretrain_epochs) && getproperty(expt_config, :pretrain_epochs) > 0 ? Lux.update_state(st, :fixed_depth, Val(getproperty(expt_config, :num_layers))) : st + wskip_sched = ParameterSchedulers.Exp(args["w-skip"], 0.92f0) + for epoch in args["start-epoch"]:(expt_config.nepochs) # Train for 1 epoch ps, st, optimiser_state, train_stats = train_one_epoch( - train_loader, model, ps, st, optimiser_state, epoch, loss_function, args + train_loader, model, ps, st, optimiser_state, epoch, loss_function, wskip_sched(epoch), args ) train_stats = get_loggable_stats(train_stats)