From ca2e5acd10388f87a7882aa8197fc6952b2dc38a Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 1 Dec 2022 07:08:58 +0100
Subject: [PATCH 1/2] Dict constructor rrule

---
 src/lib/base.jl  | 16 ++++++++++++++++
 test/features.jl | 29 +++++++++++++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/src/lib/base.jl b/src/lib/base.jl
index e259a999d..0f2d663ef 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -231,3 +231,19 @@ end
     fallback_Fix2(y) = f(y, x)
     return _pullback(__context__, fallback_Fix2, y)
 end
+
+function ChainRulesCore.rrule(::typeof(Dict), xs::Pair...)
+    function Dict_pullback(Δ)
+        return (NoTangent(), ((first=ZeroTangent(), second=get(Δ, x[1], ZeroTangent())) for x in xs)...)
+    end
+    return Dict(xs...), Dict_pullback
+end
+
+# xs iterable of pairs
+function ChainRulesCore.rrule(::typeof(Dict), xs)
+    function Dict_pullback(Δ)
+        x̄s = [(first=ZeroTangent(), second=get(Δ, x[1], ZeroTangent())) for x in xs]
+        return (NoTangent(), x̄s)
+    end
+    return Dict(xs), Dict_pullback
+end
diff --git a/test/features.jl b/test/features.jl
index e4fe61140..98e26de64 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -835,3 +835,32 @@ end
     end
     @test gradient(f760, 3)[1] ≈ 123.93054835019153
 end
+
+@test "Dict constructors" begin
+    # pair
+    g = gradient(1 => 2) do x
+        d = Dict(x)
+        d[1]
+    end[1]
+    @test g == (first = nothing, second = 1)
+
+    # pairs
+    g = gradient(1 => 2, 2 => 3, 4=>10) do x1, x2, x3
+        d = Dict(x1, x2, x3)
+        d[1] + 2*d[4]
+    end
+    @test g == ((first = nothing, second = 1), nothing, (first = nothing, second = 2.0))
+
+    # array of pairs
+    g = gradient(2) do c
+        d = Dict([i => i*c for i in 1:3])
+        d[1] + 2*d[2]
+    end[1]
+    @test g == 5
+
+    # generator of pairs
+    @test_broken gradient(2) do c
+        d = Dict(i => i*c for i in 1:3)
+        d[1] + 2*d[2]
+    end[1]
+end
\ No newline at end of file

From 939c8e4898cfbdc2688402e3a011bc0f4996f6ed Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 1 Dec 2022 08:08:26 +0100
Subject: [PATCH 2/2] Dict constructor pullback

---
 src/lib/base.jl  | 40 ++++++++++++++++++++++++++++++++++------
 test/features.jl | 20 ++++++++++++++++++--
 2 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/src/lib/base.jl b/src/lib/base.jl
index 0f2d663ef..88738afb1 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -232,18 +232,46 @@ end
     return _pullback(__context__, fallback_Fix2, y)
 end
 
-function ChainRulesCore.rrule(::typeof(Dict), xs::Pair...)
+# function ChainRulesCore.rrule(::typeof(Dict), xs::Pair...)
+#     function Dict_pullback(Δ)
+#         return (NoTangent(), ((first=ZeroTangent(), second=get(Δ, x[1], ZeroTangent())) for x in xs)...)
+#     end
+#     return Dict(xs...), Dict_pullback
+# end
+
+# function ChainRulesCore.rrule(::typeof(Dict), xs::AbstractVector{<:Pair})
+#     function Dict_pullback(Δ)
+#         x̄s = [(first=ZeroTangent(), second=get(Δ, x[1], ZeroTangent())) for x in xs]
+#         return (NoTangent(), x̄s)
+#     end
+#     return Dict(xs), Dict_pullback
+# end
+
+
+function Zygote._pullback(::AContext, ::typeof(Dict), xs::Pair...)
     function Dict_pullback(Δ)
-        return (NoTangent(), ((first=ZeroTangent(), second=get(Δ, x[1], ZeroTangent())) for x in xs)...)
+        return (nothing, ((first=nothing, second=get(Δ, x[1], nothing)) for x in xs)...)
     end
     return Dict(xs...), Dict_pullback
 end
 
-# xs iterable of pairs
-function ChainRulesCore.rrule(::typeof(Dict), xs)
+function Zygote._pullback(::AContext, ::typeof(Dict), xs::AbstractVector{<:Pair})
     function Dict_pullback(Δ)
-        x̄s = [(first=ZeroTangent(), second=get(Δ, x[1], ZeroTangent())) for x in xs]
-        return (NoTangent(), x̄s)
+        x̄s = [(first=nothing, second=get(Δ, x[1], nothing)) for x in xs]
+        return (nothing, x̄s)
     end
     return Dict(xs), Dict_pullback
 end
+
+# iterable of pairs / generator
+function _pullback(cx::AContext, ::typeof(Dict), xs)
+    a, pba = _pullback(cx, collect, xs)
+    y, pby = _pullback(cx, Dict, a)
+    function Dict_pullback(Δ)
+        Δa = pby(Δ)[2]
+        # NOTE(review): pba(Δa) looks like a (∂collect, ∂xs) tuple; may need [2] here — confirm
+        Δxs = pba(Δa)
+        return (nothing, Δxs)
+    end
+    return y, Dict_pullback
+end
diff --git a/test/features.jl b/test/features.jl
index 98e26de64..85317103d 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -836,7 +836,7 @@ end
     @test gradient(f760, 3)[1] ≈ 123.93054835019153
 end
 
-@test "Dict constructors" begin
+@testset "Dict constructors" begin
     # pair
     g = gradient(1 => 2) do x
         d = Dict(x)
@@ -863,4 +863,20 @@ end
         d = Dict(i => i*c for i in 1:3)
         d[1] + 2*d[2]
     end[1]
-end
\ No newline at end of file
+end
+
+# pullback(Dict, 1 => 2)
+
+# Zygote.refresh()
+# y, pb = Zygote._pullback(Zygote.Context(), Dict, 1 => 2)
+# pb(Dict(1 => 5))
+
+# gradient(2) do c
+#     d = Dict(i => i*c for i in 1:3)
+#     d[1] + 2*d[2]
+# end[1]
+
+# gradient(2) do c
+#     d = collect(i => i*c for i in 1:3)
+#     d[1][2] + 2*d[2][2]
+# end[1]
\ No newline at end of file