Better handling for nesting Params #823
Conversation
```julia
        d[k]
    else
        hk[] = false
        d[k] = default
```
`get` should not mutate; maybe you want to define the adjoint for `get!` instead?
We aren't mutating anything in the user-defined objects, so it should be fine. `get!` would still mutate the gradient dictionary, so it's not any better.
You are defining an adjoint for `get(d::AbstractDict, k, default)` that mutates `d`; this is not fine at all.
OK, I see what you mean.
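For reference, a minimal sketch of what a non-mutating adjoint for `get` could look like (an illustration of the reviewer's suggestion, not the code that was merged; the dict-gradient convention shown here, a fresh `Dict` of `key => gradient` pairs, is an assumption):

```julia
using Zygote: @adjoint

# Sketch: a non-mutating adjoint for get(d::AbstractDict, k, default).
# The pullback allocates a fresh gradient dict instead of writing into `d`.
@adjoint function Base.get(d::AbstractDict, k, default)
    found = haskey(d, k)
    val = found ? d[k] : default
    function get_pullback(Δ)
        if found
            # Gradient flows to the stored value; `d` itself is untouched.
            return (Dict(k => Δ), nothing, nothing)
        else
            # The default was returned, so the gradient flows to `default`.
            return (nothing, nothing, Δ)
        end
    end
    return val, get_pullback
end
```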
This resolves the conflicts and moves iteration over to the dictionary, which more honestly captures gradients from globals. This can break the current assumption of algebra on `Grads`.
```diff
@@ -171,15 +172,15 @@ const ADictOrGrads = Union{AbstractDict, Grads}
 # Dictionary interface.
 # Don't use the IdDict directly since it may contain some spurious pairs.
 Base.haskey(gs::Grads, x) = x ∈ gs.params
-Base.keys(gs::Grads) = gs.params
+# Base.keys(gs::Grads) = gs.params
 Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)
```
Forwarding `keys` to `gs.grads` while relying on `gs.params` for the values is not good. Either we base the dictionary interface entirely on `gs.params` or we base it on `gs.grads`; we can't have mixed stuff. The comment above, `# Don't use the IdDict directly since it may contain some spurious pairs`, suggests we should think the changes through thoroughly or leave things as they are (which seems the best option to me).
What is the spurious stuff you refer to? It contains references to the objects that had gradients along the way.
I think Carlo is referring to the existing comment at https://github.com/FluxML/Zygote.jl/pull/823/files#diff-7511b224d7f3ebb56465690de8e307422e3c9798a22bdd4e960d5c86ba6528aaR173. My understanding of that is that `Base.keys(Grads.grads)` may contain items not in `Base.keys(Grads.params)`. This seems like it should never happen though, so is the comment out of date, or am I missing some scenario where it could?
The comment is not outdated:
```julia
julia> using Flux

julia> m = Chain(Dense(2,2), x->relu.(x), BatchNorm(2))
Chain(Dense(2, 2), #1, BatchNorm(2))

julia> gs = gradient(() -> sum(m(rand(2,2))), Flux.params(m))
Grads(...)

julia> gs.grads
IdDict{Any, Any} with 8 entries:
  Float32[0.0, 0.0]                              => [0.0, 0.0]
  BatchNorm(2)                                   => RefValue{Any}((λ = nothing, β = nothing, γ = nothing, μ = nothing, σ² = nothing, ϵ = 0.0, momentum = nothing, affi…
  Float32[-0.63824 0.222623; -0.785237 0.536415] => [0.0 0.0; 0.0 0.0]
  :(Main.m)                                      => (layers = (nothing, nothing, RefValue{Any}((λ = nothing, β = nothing, γ = nothing, μ = nothing, σ² = nothing, ϵ = …
  Box([0.0; 0.0])                                => RefValue{Any}((contents = nothing,))
  Float32[0.0, 0.0]                              => [2.0, 2.0]
  Box([0.0; 0.0])                                => RefValue{Any}((contents = nothing,))
  Float32[1.0, 1.0]                              => [0.0, 0.0]
```
We have to keep the current dict interface based on `gs.params`.
None of this is spurious, though; there is no prior knowledge of what needs to be tracked at the beginning of the differentiation. The grads dictionary returns everything it needed to track, even if those entities weren't present in the params; they may have been indirectly needed to get the grads of the params. What we can guarantee is that the grads dictionary will always have the params as keys. So the defensive thing is to return the entire dict, so that these values for the intermediaries are available to multiple levels of differentiation.
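Concretely, continuing the REPL session above, the guarantee is only that every param can be looked up, while `gs.grads` may hold extra intermediaries (a sketch of the assumed access pattern):

```julia
# Every param is guaranteed to be a key, so per-param lookup is always safe:
for p in Flux.params(m)
    @assert haskey(gs.grads, p)
    g = gs[p]   # equivalent to gs.grads[p]
end

# ...but the raw IdDict can be strictly larger than the params:
length(gs.grads) >= length(Flux.params(m))   # true in the session above
```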
Then, per Carlo's point, `Base.values(gs::Grads)` should forward to `.grads` as well. Having the two be differently sized is unexpected (i.e. potentially subtly breaking), and arguably breaks the contract of `keys` + `values`.
Yes, I'll add that.
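For illustration, here is what a self-consistent, params-based dictionary interface could look like (a sketch using a stand-in `Grads` struct, not Zygote's actual definition):

```julia
# Stand-in for Zygote's Grads, just to make the sketch self-contained.
struct Grads
    grads::IdDict{Any,Any}   # may contain extra intermediate entries
    params::Vector{Any}      # the params the user asked gradients for
end

# Base the whole dictionary interface on `params`, so that `keys` and
# `values` always agree in length and order:
Base.haskey(gs::Grads, x) = x in gs.params
Base.keys(gs::Grads) = gs.params
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)
Base.length(gs::Grads) = length(gs.params)
Base.getindex(gs::Grads, x) = gs.grads[x]
```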
This needs some tests of the features it adds, and for #941 if it actually fixes it.
#941 may well require extra handling. I'd still prefer to solve the underlying problems first.
```julia
# the adjoint jacobian of an FFT with respect to its input is the reverse FFT of the
# gradient of its inputs, but with a different normalization factor
@adjoint function fft(xs)
    return AbstractFFTs.fft(xs), function(Δ)
        return (AbstractFFTs.bfft(Δ),)
    end
end
```
? These empty lines shouldn't be removed.
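As an aside, the rule the quoted comment states (the pullback of `fft` is `bfft`, the unnormalized inverse FFT) can be sanity-checked directly; a sketch, assuming FFTW is loaded as the AbstractFFTs backend:

```julia
using Zygote, FFTW

x = rand(ComplexF64, 8)
Δ = rand(ComplexF64, 8)

# Pull Δ back through fft and compare with the unnormalized inverse FFT:
y, back = Zygote.pullback(fft, x)
back(Δ)[1] ≈ bfft(Δ)   # should be true
```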
```diff
@@ -44,6 +44,27 @@ end
 end
 end

+@adjoint function Base._oidd_nextind(a, i)
```
Do we need to define an adjoint for an internal function of Base?
Unfortunately I couldn't see a different way at the time. I'm with you on internal functions. Without it we ended up dropping some grads, which shouldn't happen.
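For context, `Base._oidd_nextind` is the internal helper that `IdDict` iteration uses to find the next occupied slot. Since it only computes an index, a plausible adjoint (a sketch of the pattern, not necessarily the PR's exact definition) simply declares it non-differentiable:

```julia
using Zygote: @adjoint

# Index arithmetic carries no gradient, so the pullback returns `nothing`
# for both arguments, letting differentiation step over IdDict iteration.
@adjoint Base._oidd_nextind(a, i) = Base._oidd_nextind(a, i), Δ -> (nothing, nothing)
```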
Fixes #941
This does not fix #941.
Yeah, the most direct fix for this is to define …
Since this was mentioned in #1035 (comment), here's what I'd like to see before a merge: …
Well, some of this is orthogonal. Handling …
Coming back to this with more clarity about the whole `push!` story (i.e. the understanding that this is completely unrelated), I think the only changes left are a non-mutating …
Closing as stale.
This allows for differentiating through the Params code, so that gradient calls can be nested with `Params` instead of the usual arguments.
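For illustration, the kind of usage this aims to enable looks like the following (a hypothetical sketch, not a test from the PR; whether it actually works depends on the unresolved issues discussed in the thread above):

```julia
using Zygote

W = rand(2, 2)
ps = Zygote.Params([W])

# Outer gradient differentiates through an inner Params-based gradient call.
gs = Zygote.gradient(ps) do
    inner = Zygote.gradient(() -> sum(W * W), ps)  # first-order implicit grads
    sum(inner[W])                                  # reduce to a scalar
end
```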