
Differentiate push! with implicit Params #992

Merged
merged 7 commits into FluxML:master from dg/941 on Jun 24, 2021

Conversation

@DhairyaLGandhi (Member)

No description provided.

@CarloLucibello (Member)

Do we have Flux's integration tests in this repo? If so, it would be good to copy and adapt the tests from FluxML/Flux.jl#1614 here. It also needs some push!-specific tests.
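For example, a push!-specific test might look roughly like this (a hypothetical sketch of what such a test could check, not taken from the PR):

using Zygote, Test
using Zygote: Params

w = rand(2)
ps = Params([w])
gs = gradient(ps) do
    inner = Params()
    push!(inner, w)             # exercise push! inside the gradient context
    sum(sum(p) for p in inner)  # take the loss through the pushed params
end
@test gs[w] ≈ ones(2)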

@darsnack (Member) left a review:

Looks like various places relied on the old constructor which iterated any iterable xs to build the Params object.
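For context, the old behaviour was roughly the following (a sketch inferred from this comment, not the actual source):

# Old behaviour, sketched: any iterable xs was consumed element by element
# to populate the Params object.
function Params(xs)
  ps = Params()
  for x in xs
    push!(ps, x)
  end
  return ps
end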

Review thread on src/compiler/interface.jl (outdated, resolved)
@darsnack (Member) left a review:

Looks good pending some extraneous deleted lines

Review thread on src/compiler/interface.jl (outdated, resolved)
@CarloLucibello merged commit 87e2f12 into FluxML:master on Jun 24, 2021
@CarloLucibello (Member)

Something is wrong with this PR. I tried the tests in FluxML/Flux.jl#1614 and they fail due to some double-counting:

julia> using Flux, Test

julia> @testset "use params in gradient context" begin
           m = Chain(Dense(3,2), Dense(2,2))
           ps = Flux.params(m)
           gs = gradient(() -> sum(sum(p) for p in Flux.params(m)), ps)
           for p in ps
               @test gs[p] ≈ ones(size(p))
           end

           w1, w2 = rand(2), rand(2)
           ps = Flux.params(w1, w2)
           gs = gradient(() -> sum(sum(p) for p in Flux.params(w1, w2)), ps)
           for p in ps
               @test gs[p] ≈ ones(size(p))
           end
       end
use params in gradient context: Test Failed at REPL[4]:6
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2×3 Fill{Float32}: entries equal to 2.0 ≈ [1.0 1.0 1.0; 1.0 1.0 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:6 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:6
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2-element Fill{Float32}: entries equal to 2.0 ≈ [1.0, 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:6 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:6
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2×2 Fill{Float32}: entries equal to 2.0 ≈ [1.0 1.0; 1.0 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:6 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:6
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2-element Fill{Float32}: entries equal to 2.0 ≈ [1.0, 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:6 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:13
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2-element Fill{Float64}: entries equal to 2.0 ≈ [1.0, 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:13 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
use params in gradient context: Test Failed at REPL[4]:13
  Expression: gs[p] ≈ ones(size(p))
   Evaluated: 2-element Fill{Float64}: entries equal to 2.0 ≈ [1.0, 1.0]
Stacktrace:
 [1] macro expansion
   @ REPL[4]:13 [inlined]
 [2] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [3] top-level scope
   @ REPL[4]:2
Test Summary:                  | Fail  Total
use params in gradient context |    6      6
ERROR: Some tests did not pass: 0 passed, 6 failed, 0 errored, 0 broken.

@DhairyaLGandhi deleted the dg/941 branch June 24, 2021 05:55
@CarloLucibello (Member)

@DhairyaLGandhi this PR introduced a regression; can you spot where the issue is?

@darsnack (Member)

My guess is there is a discrepancy between Zygote.Params and Flux.params, because there is a nearly identical test in this PR that was passing.
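One way to probe that hypothesis (a hypothetical check, not from the thread) is to build the parameter collection both ways and compare the resulting gradients:

using Flux, Zygote

m = Chain(Dense(3, 2), Dense(2, 2))
loss() = sum(sum(p) for p in Flux.params(m))

gs_flux = gradient(loss, Flux.params(m))                          # Flux's collection
gs_zy   = gradient(loss, Zygote.Params(collect(Flux.params(m))))  # plain Zygote Params
# If the discrepancy hypothesis holds, the two Grads should disagree
# for at least one parameter p in Flux.params(m).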

@DhairyaLGandhi (Member, Author)

I'll look into it over the weekend

@ToucheSir (Member)

The adjoint added at https://github.com/FluxML/Zygote.jl/blob/v0.6.23/src/compiler/interface.jl#L193-L198 seems to be causing some unintended spooky action at a distance, ref. https://discourse.julialang.org/t/calling-flux-params-inside-gradient-changes-output/68829/2.

MWE:

julia> a = rand(10);

julia> gradient(Params([a])) do
         # What Flux.params! does under the hood
         ps = Params()
         push!(ps, a)
         1
       end.grads
IdDict{Any, Any} with 2 entries:
  :(Main.a)                                                                                         => Ones(10)
  [0.462122, 0.162373, 0.579594, 0.403881, 0.762125, 0.7762, 0.961503, 0.812436, 0.747902, 0.38289] => Ones(10)

julia> gradient(Params([a])) do
         buf = Zygote.Buffer([], false)
         ps = Zygote.IdSet()
         # What push!(::Params) does under the hood
         if !(a in ps)
           push!(buf, a)
           push!(ps, a)
         end
         1
       end.grads
IdDict{Any, Any} with 2 entries:
  [0.462122, 0.162373, 0.579594, 0.403881, 0.762125, 0.7762, 0.961503, 0.812436, 0.747902, 0.38289]                                       => nothing
  Buffer{Any, Vector{Any}}(Any[[0.462122, 0.162373, 0.579594, 0.403881, 0.762125, 0.7762, 0.961503, 0.812436, 0.747902, 0.38289]], false) => Any[]

@CarloLucibello (Member)

I guess this deserves an issue; otherwise it will get lost. Maybe #1025 or #823 will help, at least by erroring out?

@DhairyaLGandhi (Member, Author) commented Sep 29, 2021

This is because we defined push!, etc. on Params, so this is definitely expected. With #1025 we would no longer be able to do higher-order AD with implicit params.

@DhairyaLGandhi (Member, Author)

We should try with:

# Forward the push! unchanged; the pullback returns `nothing` for the
# container and for each of the `l` pushed values, so the push itself
# contributes no gradient.
@adjoint! function push!(xs, x...)
  l = length(x)
  push!(xs, x...), Δ -> begin
    (nothing, ntuple(_ -> nothing, l)...)
  end
end

which is the basic definition, the same as we do for Buffer.
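If that no-op pullback behaves as intended, the earlier MWE should stop double-counting (an untested sketch, assuming the adjoint above is in place):

using Zygote
using Zygote: Params

a = rand(10)
gs = gradient(Params([a])) do
    ps = Params()
    push!(ps, a)  # with the no-op pullback, the push contributes no gradient
    sum(a)
end
gs[a]  # expected: a gradient of ones, with no doubled entry in gs.grads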

@ToucheSir (Member)

AFAICT neither the higher-order example in #823 nor the nested params examples in FluxML/Flux.jl#1614 (comment) was impacted by #1025? Testing the former:

using Zygote: Zygote, @ignore, @showgrad

a, b = rand(10), rand(10)
ps = Params([a, b])

julia> gradient(ps) do
           gs = gradient(ps) do
               sum(a) + sum(b)
           end
           @ignore @show (gs.grads[a], gs.grads[b])
           sum(ps) do x
               @ignore @show x
               sum(gs[x])
           end
       end
(gs.grads[a], gs.grads[b]) = (Fill(1.0, 10), Fill(1.0, 10))
x = [0.577428448803366, 0.4326501245243599, 0.3512628071864099, 0.2544220253047734, 0.22847370951600943, 0.40533823390761825, 0.39367034770674414, 0.4059778114213479, 0.0409136160980863, 0.43590306185598293]
x = [0.6360355211704295, 0.5288173379484092, 0.9710812452894872, 0.6658959378156822, 0.21704879179732628, 0.4477013857799639, 0.9278207943412708, 0.9960756858428664, 0.03741588026431253, 0.5159463044375548]
ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ ./iddict.jl:102 [inlined]
  [3] (::typeof((get)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
  [4] Pullback
    @ ~/.julia/packages/Zygote/SDLsG/src/lib/lib.jl:68 [inlined]
  [5] (::typeof((accum_global)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
  [6] Pullback
    @ ~/.julia/packages/Zygote/SDLsG/src/lib/lib.jl:79 [inlined]
  [7] (::typeof((λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
  [9] (::typeof((λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [10] getindex
    @ ./tuple.jl:29 [inlined]
 [11] (::typeof((λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/Zygote/SDLsG/src/compiler/interface.jl:343 [inlined]
 [13] (::typeof((λ)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/Zygote/SDLsG/src/compiler/interface.jl:76 [inlined]
 [15] (::typeof((gradient)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [16] Pullback
    @ ./REPL[33]:2 [inlined]
 [17] (::typeof((#109)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#90#91"{Params, typeof((#109)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface.jl:343
 [19] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/SDLsG/src/compiler/interface.jl:76
 [20] top-level scope
    @ REPL[33]:1

The failure seems to stem from a missing get adjoint (addressed in #823) and not anything push!-related.
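A minimal way to see that it is the get path (an assumption on my part, not verified in this thread): differentiate through an IdDict lookup directly, on a Zygote version without the #823 adjoint.

using Zygote

d = IdDict{Any,Any}(:x => 1.0)
# Base's get(::IdDict, ...) bottoms out in a ccall (iddict.jl), which
# Zygote cannot differentiate without an explicit adjoint.
Zygote.gradient(dd -> get(dd, :x, 0.0), d)
# expected, per the stacktrace above: ERROR: Can't differentiate foreigncall expression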
