
Better handling for nesting Params #823

Closed
wants to merge 9 commits

Conversation

DhairyaLGandhi
Member

This allows differentiating through the Params code so that gradient calls can be nested with Params instead of the usual explicit arguments.

using Flux  # assumed: the example relies on Flux re-exporting Zygote's gradient/pullback

nn = Chain(Dense(3, 3), Dense(3, 1))
ip = rand(Float32, 3, 2)
ps = Flux.params(nn)  # assumed definition; `ps` was implied by the original snippet

gradient(ps) do
  # Inner differentiation over the same Params.
  y, b = pullback(ps) do
    sum(nn(ip))
  end
  _gs = b(y)  # seed the pullback with the primal output
  # Reduce the first-order gradients to a scalar for the outer gradient call.
  sum(x -> sum(_gs[x]), ps)
end

    d[k]
  else
    hk[] = false
    d[k] = default

Member

get should not mutate; maybe you want to define the adjoint for get! instead?
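
For reference, a non-mutating version could route the cotangent without ever writing into d. This is only a shape-level sketch of the idea (Zygote's real dictionary adjoints go through its gradient-accumulation context; the gradient representation here is illustrative):

using Zygote: @adjoint

# Sketch: three-argument get only reads from `d`, so the pullback can send the
# incoming gradient Δ to the stored entry when the key exists, and to the
# `default` argument otherwise. No writes to `d` in either pass.
@adjoint function Base.get(d::AbstractDict, k, default)
  hit = haskey(d, k)
  val = hit ? d[k] : default
  back(Δ) = hit ? (Dict(k => Δ), nothing, nothing) : (nothing, nothing, Δ)
  return val, back
end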

Member Author

We aren't mutating anything in the user-defined objects, so it should be fine. get! would still mutate the gradient dictionary, so it's not any better.

Member

you are defining an adjoint for get(d::AbstractDict, k, default) that mutates d; this is not fine at all

Member Author

Ok, I see what you mean

@DhairyaLGandhi
Member Author

This resolves the conflicts and moves iteration over to the dictionary, which more honestly captures gradients from globals. This can break the current assumption of algebra on Grads, but that's alright since most cases access the params fields directly, and those that don't can be fixed up.

@@ -171,15 +172,15 @@ const ADictOrGrads = Union{AbstractDict, Grads}
 # Dictionary interface.
 # Don't use the IdDict directly since it may contain some spurious pairs.
 Base.haskey(gs::Grads, x) = x ∈ gs.params
-Base.keys(gs::Grads) = gs.params
+# Base.keys(gs::Grads) = gs.params
 Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)

Member

Forwarding keys to gs.grads while relying on gs.params for the values is not good. Either we base the dictionary interface entirely on gs.params or we base it on gs.grads; we can't have mixed stuff.

The comment above

# Don't use the IdDict directly since it may contain some spurious pairs

suggests that we should think changes through thoroughly or leave things as they are (which seems the best option to me)
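
For illustration, the two self-consistent choices would look roughly like this inside Zygote (a sketch, not code from the PR):

using Zygote: Grads  # sketch assumes Zygote's Grads type

# Option A: drive the whole interface from gs.params (the status quo).
Base.keys(gs::Grads)   = gs.params
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)

# Option B: drive the whole interface from gs.grads (exposes intermediaries).
Base.keys(gs::Grads)   = keys(gs.grads)
Base.values(gs::Grads) = values(gs.grads)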

Member Author

What is the spurious stuff you refer to? It contains references to the objects that had gradients along the way.

Member

I think Carlo is referring to the existing comment at https://github.com/FluxML/Zygote.jl/pull/823/files#diff-7511b224d7f3ebb56465690de8e307422e3c9798a22bdd4e960d5c86ba6528aaR173. My understanding of that is that Base.keys(Grads.grads) may contain items not in Base.keys(Grads.params). This seems like it should never happen, though, so is the comment out of date, or am I missing some scenario where it could happen?

Member

the comment is not outdated:

julia> using Flux

julia> m = Chain(Dense(2,2), x->relu.(x), BatchNorm(2)) 
Chain(Dense(2, 2), #1, BatchNorm(2))

julia> gs = gradient(() -> sum(m(rand(2,2))), Flux.params(m))
Grads(...)

julia> gs.grads
IdDict{Any, Any} with 8 entries:
  Float32[0.0, 0.0]                              => [0.0, 0.0]
  BatchNorm(2)                                   => RefValue{Any}((λ = nothing, β = nothing, γ = nothing, μ = nothing, σ² = nothing, ϵ = 0.0, momentum = nothing, affi…
  Float32[-0.63824 0.222623; -0.785237 0.536415] => [0.0 0.0; 0.0 0.0]
  :(Main.m)                                      => (layers = (nothing, nothing, RefValue{Any}((λ = nothing, β = nothing, γ = nothing, μ = nothing, σ² = nothing, ϵ = …
  Box([0.0; 0.0])                                => RefValue{Any}((contents = nothing,))
  Float32[0.0, 0.0]                              => [2.0, 2.0]
  Box([0.0; 0.0])                                => RefValue{Any}((contents = nothing,))
  Float32[1.0, 1.0]                              => [0.0, 0.0]

Member

we have to keep the current dict interface based on gs.params

Member Author

None of this is spurious, though; there is no prior knowledge of what needs to be tracked at the beginning of differentiation. The grads dictionary returns all the stuff it needed to track, even if those entities weren't present in the params. They may have been indirectly needed to get the grads of the params. What we can guarantee is that the grads dictionary will always have the params as keys. So the defensive thing is to return the entire dict, so that these values for the intermediaries are available to multiple levels of differentiation.
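
In code form, the guarantee being claimed is roughly the following (a hypothetical check, recreating the setup from Carlo's REPL session above):

using Flux  # same setup as the REPL session above
m = Chain(Dense(2, 2), x -> relu.(x), BatchNorm(2))
gs = gradient(() -> sum(m(rand(2, 2))), Flux.params(m))

# Every param is guaranteed an entry in gs.grads...
@assert all(p -> haskey(gs.grads, p), gs.params)
# ...while gs.grads may hold extra entries for intermediaries (globals, Boxes).
@assert length(gs.grads) >= length(gs.params)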

Member

Then per Carlo's point, Base.values(gs::Grads) should forward to .grads as well. Having the two be differently sized is unexpected (i.e. potentially subtly breaking), and arguably breaks the contract of keys + values.
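
The contract at stake, sketched with the gs from the snippet above: consumers assume keys and values zip up pairwise, so differently sized iterators misalign silently.

# Rebuilding a dictionary from keys/values assumes they align pairwise. If
# keys(gs) iterated gs.grads while values(gs) iterated gs.params, this would
# silently pair the wrong entries (hypothetical failure mode).
rebuilt = IdDict(zip(keys(gs), values(gs)))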

Member Author

Yes, I'll add that

@CarloLucibello
Member

This needs some tests of the features it adds, and for #941 if it actually fixes it.

@DhairyaLGandhi
Member Author

#941 may well require extra handling. I'd still prefer to solve the underlying problems first.

# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
# gradient of its inputs, but with a different normalization factor
@adjoint function fft(xs)
  return AbstractFFTs.fft(xs), function(Δ)
    return (AbstractFFTs.bfft(Δ),)
  end
end
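
As a quick sanity check of the normalization remark in that comment, the adjoint can be compared against the normalized inverse FFT (a sketch; assumes FFTW is loaded to provide the transforms and that the fft adjoint above is in effect):

using FFTW, Zygote

x = rand(ComplexF64, 8)
y, back = Zygote.pullback(fft, x)
Δ = rand(ComplexF64, length(x))
# bfft is the unnormalized inverse, so bfft(Δ) == length(Δ) .* ifft(Δ)
@assert back(Δ)[1] ≈ bfft(Δ) ≈ length(Δ) .* ifft(Δ)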

@CarloLucibello
Member, Jun 6, 2021

? these empty lines shouldn't be removed

@@ -44,6 +44,27 @@ end
   end
 end

+@adjoint function Base._oidd_nextind(a, i)

Member

do we need to define an adjoint for an internal function of Base?

Member Author

Unfortunately I couldn't see a different way at the time. I'm with you on internal functions. Without it we ended up dropping some grads, which shouldn't happen.
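
For context, the adjoint being discussed is presumably a pass-through, since an iteration index carries no gradient. A sketch, mirroring the diff above rather than any public API (Base._oidd_nextind is internal and undocumented):

using Zygote: @adjoint

# Sketch: stepping an IdDict's internal index is non-differentiable, so the
# pullback returns nothing for both arguments.
@adjoint Base._oidd_nextind(a, i) =
  Base._oidd_nextind(a, i), _ -> (nothing, nothing)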

@DhairyaLGandhi
Member Author

Fixes #941

@CarloLucibello
Member

This is not fixing #941

@ToucheSir
Member

Yeah, the most direct fix for this is to define @adjoint Base.push!(ps::Params, x...), because that's what params calls and what #876 (implicitly/inadvertently?) eliminated.
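
A shape-level sketch of that suggested adjoint (hypothetical; not necessarily how it would finally land):

using Zygote: @adjoint, Params

# Sketch: push! on Params only registers parameters, so differentiation can
# treat it as transparent, returning no gradient for the Params or the values.
@adjoint function Base.push!(ps::Params, x...)
  return push!(ps, x...), Δ -> (nothing, map(_ -> nothing, x)...)
end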

@ToucheSir
Member

Since this was mentioned in #1035 (comment), here's what I'd like to see before a merge:

@DhairyaLGandhi
Member Author

Well, some of this is orthogonal. Handling Params in nested differentiation is more than just handling push!, so I wouldn't block it on push!. push! was needed for construction; this is necessary for nesting and tracking parameters in higher-order differentiation.

@ToucheSir
Member

Coming back to this with more clarity about the whole push! story (i.e. the understanding that this is completely unrelated), I think the only changes left are a non-mutating get adjoint and having Grads forward Base.values to its inner IdDict? Most of the CI failures seem unrelated, so I presume they'll go away after a rebase and we can safely land this.

@CarloLucibello
Member

closing as stale
