Differentiate push! with implicit Params #992

Conversation
Do we have Flux's integration tests in this repo? If so, it would be good to copy and adapt the tests from FluxML/Flux.jl#1614 here. Also needs some
Looks like various places relied on the old constructor, which iterated any iterable `xs` to build the `Params` object.
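As a rough illustration of the behavioral difference (the helper name and exact signatures below are assumptions for the sketch, not the actual diff): the old constructor would accept any iterable and collect its elements, so call sites could pass tuples or generators directly, while the stricter constructor expects an explicit collection of arrays.

```julia
using Zygote: Params

# Hypothetical stand-in for the old, permissive behavior:
# collect the elements of any iterable into a Params.
function params_from_iterable(xs)
  ps = Params()
  for x in xs
    push!(ps, x)
  end
  return ps
end

w1, w2 = rand(2), rand(3)
ps  = Params([w1, w2])                # explicit vector of arrays still works
ps2 = params_from_iterable((w1, w2))  # old-style call sites need a shim like this
```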
Looks good pending some extraneous deleted lines
Something is wrong with this PR. I tried the tests in FluxML/Flux.jl#1614:

```julia
julia> using Flux, Test

julia> @testset "use params in gradient context" begin
         m = Chain(Dense(3,2), Dense(2,2))
         ps = Flux.params(m)
         gs = gradient(() -> sum(sum(p) for p in Flux.params(m)), ps)
         for p in ps
           @test gs[p] ≈ ones(size(p))
         end

         w1, w2 = rand(2), rand(2)
         ps = Flux.params(w1, w2)
         gs = gradient(() -> sum(sum(p) for p in Flux.params(w1, w2)), ps)
         for p in ps
           @test gs[p] ≈ ones(size(p))
         end
       end
```
```
use params in gradient context: Test Failed at REPL[4]:6
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2×3 Fill{Float32}: entries equal to 2.0 ≈ [1.0 1.0 1.0; 1.0 1.0 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:6 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:6
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2-element Fill{Float32}: entries equal to 2.0 ≈ [1.0, 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:6 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:6
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2×2 Fill{Float32}: entries equal to 2.0 ≈ [1.0 1.0; 1.0 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:6 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:6
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2-element Fill{Float32}: entries equal to 2.0 ≈ [1.0, 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:6 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:13
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2-element Fill{Float64}: entries equal to 2.0 ≈ [1.0, 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:13 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:13
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2-element Fill{Float64}: entries equal to 2.0 ≈ [1.0, 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:13 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
Test Summary:                  | Fail  Total
use params in gradient context |    6      6
ERROR: Some tests did not pass: 0 passed, 6 failed, 0 errored, 0 broken.
```
@DhairyaLGandhi this PR introduced a regression, can you spot where the issue is?
My guess is there is a discrepancy between
I'll look into it over the weekend.
The adjoint added at https://github.com/FluxML/Zygote.jl/blob/v0.6.23/src/compiler/interface.jl#L193-L198 seems to be causing some unintended spooky action at a distance, ref. https://discourse.julialang.org/t/calling-flux-params-inside-gradient-changes-output/68829/2. MWE:

```julia
julia> a = rand(10);

julia> gradient(Params([a])) do
         # What Flux.params! does under the hood
         ps = Params()
         push!(ps, a)
         1
       end.grads
IdDict{Any, Any} with 2 entries:
  :(Main.a) => Ones(10)
  [0.462122, 0.162373, 0.579594, 0.403881, 0.762125, 0.7762, 0.961503, 0.812436, 0.747902, 0.38289] => Ones(10)

julia> gradient(Params([a])) do
         buf = Zygote.Buffer([], false)
         ps = Zygote.IdSet()
         # What push!(::Params) does under the hood
         if !(a in ps)
           push!(buf, a)
           push!(ps, a)
         end
         1
       end.grads
IdDict{Any, Any} with 2 entries:
  [0.462122, 0.162373, 0.579594, 0.403881, 0.762125, 0.7762, 0.961503, 0.812436, 0.747902, 0.38289] => nothing
  Buffer{Any, Vector{Any}}(Any[[0.462122, 0.162373, 0.579594, 0.403881, 0.762125, 0.7762, 0.961503, 0.812436, 0.747902, 0.38289]], false) => Any[]
```
This is because we defined `push!`, etc. on `Params`, so this is definitely expected. With #1025 we would no longer be able to do higher order AD with implicit params.
We should try with

```julia
@adjoint! function push!(xs, x...)
  l = length(x)
  push!(xs, x...), Δ -> begin
    (nothing, ntuple(_ -> nothing, l)...)
  end
end
```

which is the basic definition, same as we do for Buffer.
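For reference, the pattern this mirrors — a mutating `push!` whose pullback returns `nothing` for every argument — can be sketched for `Buffer` like so (a sketch only, assuming the shape of Zygote's actual `Buffer` rule rather than quoting it):

```julia
using Zygote: @adjoint!, Buffer

# Sketch: push! on a Buffer is treated as pure mutation, so the pullback
# propagates no gradient to the buffer or to any of the pushed values.
@adjoint! function push!(b::Buffer, x...)
  push!(b, x...), Δ -> (nothing, ntuple(_ -> nothing, length(x))...)
end
```

The design point is that `@adjoint!` (unlike `@adjoint`) marks the primal as mutating, and returning all-`nothing` cotangents keeps the mutation from injecting spurious entries into the gradient cache.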
AFAICT neither the higher order example in #823 nor the nested params examples in FluxML/Flux.jl#1614 (comment) were impacted by #1025? Testing the former:

```julia
using Zygote: Zygote, @ignore, @showgrad

a, b = rand(10), rand(10)
ps = Params([a, b])

julia> gradient(ps) do
         gs = gradient(ps) do
           sum(a) + sum(b)
         end
         @ignore @show (gs.grads[a], gs.grads[b])
         sum(ps) do x
           @ignore @show x
           sum(gs[x])
         end
       end
(gs.grads[a], gs.grads[b]) = (Fill(1.0, 10), Fill(1.0, 10))
x = [0.577428448803366, 0.4326501245243599, 0.3512628071864099, 0.2544220253047734, 0.22847370951600943, 0.40533823390761825, 0.39367034770674414, 0.4059778114213479, 0.0409136160980863, 0.43590306185598293]
x = [0.6360355211704295, 0.5288173379484092, 0.9710812452894872, 0.6658959378156822, 0.21704879179732628, 0.4477013857799639, 0.9278207943412708, 0.9960756858428664, 0.03741588026431253, 0.5159463044375548]
ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./iddict.jl:102 [inlined]
  [3] (::typeof(∂(get)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/packages/Zygote/SDLsG/src/lib/lib.jl:68 [inlined]
  [5] (::typeof(∂(accum_global)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/Zygote/SDLsG/src/lib/lib.jl:79 [inlined]
  [7] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [10] getindex
    @ ./tuple.jl:29 [inlined]
 [11] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/Zygote/SDLsG/src/compiler/interface.jl:343 [inlined]
 [13] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/Zygote/SDLsG/src/compiler/interface.jl:76 [inlined]
 [15] (::typeof(∂(gradient)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./REPL[33]:2 [inlined]
 [17] (::typeof(∂(#109)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#90#91"{Params, typeof(∂(#109)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface.jl:343
 [19] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface.jl:76
 [20] top-level scope
    @ REPL[33]:1
```

The failure seems to stem from a missing