Don't try and convert to FloatX except if Integer or AbstractFloat (option 1) #417

Merged
4 commits merged on Jul 27, 2021

oxinabox
Member

One option to fix the issues with trying to use ForwardDiff on an rrule pullback that makes use of ProjectTo.
See discussion on JuliaDiff/ForwardDiff.jl#538.

This fixes the following test, which fails in FluxML/Zygote.jl#1035:

  using Zygote
  dx, dy = diaghessian(f34, xs, y)
  @test size(dx) == size(xs)
  @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
  @test dy ≈ hessian(y -> f34(xs,y), y)

This option basically says:
All subtypes of Real represent the same subspace: all of Real.
So if you give me one I am not familiar with (e.g. Dual), I am just going to leave it alone.
But if you give me an AbstractFloat subtype, I will fix the precision to match what I saw in the primal pass.
And if you give me an Integer, I will make it a Float.

So for this particular case, if you call (::ProjectTo{Float64})(::Dual{Int}), it will return a Dual{Int}, because it has no idea what a Dual is or whether it is Integer-like or Float-like, but it trusts that it is a good Real.
This works fine.
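
The three rules above can be sketched in plain Julia without loading ChainRulesCore or ForwardDiff. This is only an illustration under stated assumptions: `SketchProject` and `ToyDual` are hypothetical stand-ins for `ProjectTo` and `Dual`, and the methods mirror the description in this PR rather than the merged source.

```julia
# Hypothetical stand-in for ProjectTo on real scalars; not the ChainRulesCore code.
struct SketchProject{T<:AbstractFloat} end

# AbstractFloat: fix the precision to the float type seen in the primal pass.
(::SketchProject{T})(dx::AbstractFloat) where {T<:AbstractFloat} = convert(T, dx)
# Integer: make it a float of the primal's type.
(::SketchProject{T})(dx::Integer) where {T<:AbstractFloat} = convert(T, dx)
# Any other Real: trust that it is a good Real and leave it alone.
(::SketchProject)(dx::Real) = dx

# Toy dual number, a hypothetical stand-in for ForwardDiff.Dual.
struct ToyDual{V<:Real} <: Real
    value::V
    partial::V
end

project = SketchProject{Float32}()
project(1.0)            # Float64 input comes back as a Float32
project(2)              # Int input comes back as a Float32
project(ToyDual(1, 2))  # ToyDual{Int} comes back unchanged
```

The last method is deliberately the least specific one, so it only catches Real subtypes that the projector knows nothing about.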

It is a bit more code than I expected, as I need to make sure complex works.

codecov-commenter commented Jul 27, 2021

Codecov Report

Merging #417 (127bae5) into master (424a0b7) will increase coverage by 0.06%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #417      +/-   ##
==========================================
+ Coverage   92.25%   92.32%   +0.06%     
==========================================
  Files          14       14              
  Lines         775      782       +7     
==========================================
+ Hits          715      722       +7     
  Misses         60       60              
| Impacted Files | Coverage Δ |
|---|---|
| src/projection.jl | 95.71% <100.00%> (+0.14%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 424a0b7...127bae5.

src/projection.jl Outdated
Comment on lines 167 to 170
# For non-AbstractFloat other types, pass through to project each component
function (::ProjectTo{<:Complex{T}})(dx::Complex) where T
    project = ProjectTo(zero(T))
    return Complex(project(real(dx)), project(imag(dx)))
end
Member

@mcabbott mcabbott Jul 27, 2021

What's this aimed at? For the real types we preserve, we already preserve their complex versions too.

Maybe recursing into Complex can wait until someone has a stranger real type which needs it. Then we will also know whether ProjectTo(zero(T)) is sufficient for this strange type.

Member Author

I thought we needed it for ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0))).
But that gets caught by (::ProjectTo{<:Number})(dx::Number) = dx.

We would need it if we wanted ProjectTo(1.0 + 1im)(Complex(Dual(1, 2), Dual(1, 2))) to return a Complex{Dual{Float64}}
rather than the Complex{Dual{Int}} which it currently returns.
But that is not needed right now, so I guess we can add it back later.
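
To make the fallback behaviour concrete, here is a minimal sketch in plain Julia. The names `ComplexSketch`, `MiniDual`, and `project_components` are hypothetical and exist in neither ChainRulesCore nor ForwardDiff; the sketch only assumes the pass-through method quoted above, plus a component-wise helper modelled on the method under review.

```julia
# Hypothetical stand-ins; neither type exists in ChainRulesCore or ForwardDiff.
struct ComplexSketch{T<:Number} end

struct MiniDual{V<:Real} <: Real
    value::V
    partial::V
end

# Floats and integers are converted to the primal's float type.
(::ComplexSketch{T})(dx::Union{AbstractFloat,Integer}) where {T<:AbstractFloat} = convert(T, dx)
# Fallback quoted in the thread: every other Number is returned as-is.
(::ComplexSketch{<:Number})(dx::Number) = dx

# Component-wise projection, modelled on the method under review.
function project_components(::ComplexSketch{<:Complex{T}}, dx::Complex) where {T}
    project = ComplexSketch{T}()
    return Complex(project(real(dx)), project(imag(dx)))
end

p = ComplexSketch{ComplexF64}()
z = Complex(MiniDual(1, 2), MiniDual(1, 2))
p(z)                                  # fallback: still a Complex{MiniDual{Int}}
project_components(p, Complex(1, 2))  # plain Int components become Float64
```

In this sketch the dual-valued input never reaches a converting method, which matches the observation that the pass-through fallback already handles it.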

@oxinabox oxinabox requested a review from mcabbott July 27, 2021 13:50
Comment on lines +157 to +158
# The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different
# Number type that might not be a subtype of the `project_type`.
Member

Is this true? project_type should be Real, Complex, or Number. I get super-confused by dispatch with type parameters on LHS like {T}(.... But not important really.

Member Author

For e.g. Float32 or ComplexF64, it is not the case that project_type is Real, Complex, or Number.
I could add those in as alternative special cases instead of this.
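
The dispatch point in this thread can be checked with a small sketch. `SameTypeOnly` is a hypothetical type that defines only a `(::P{T})(::T)` style method, so it shows why such a method alone does not cover an input whose type differs from the project type.

```julia
# Hypothetical stand-in that defines only the "same type" method.
struct SameTypeOnly{T} end
(::SameTypeOnly{T})(dx::T) where {T} = dx

p32 = SameTypeOnly{Float32}()
applicable(p32, 1.0f0)  # true: dx is a Float32, matching the project type
applicable(p32, 1.0)    # false: Float64 is not a subtype of Float32

# With an abstract project type the same method does match.
applicable(SameTypeOnly{Real}(), 1.0)  # true: Float64 <: Real
```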

Member Author

@oxinabox oxinabox Jul 27, 2021

I would not be shocked if this comes back to bite us, but we can remove it if and when that happens.

Member

Oh right, I wasn't thinking clearly, that's exactly the use case.

@oxinabox
Member Author

@mcabbott do you approve this PR?
Per ColPrac, I need someone to approve so I can merge it and finish the Zygote PR.

Member

@mcabbott mcabbott left a comment

This seems fine for now, to make Zygote work. We knew these details may need tweaking.

They may still need more. I'm not so sure whether or not we should do JuliaDiff/ForwardDiff.jl#538 (and possibly things like it elsewhere); if we do then we may also want to make this more strict again, so that x::Float32 forces the numbers inside the Dual to be Float32. Or maybe not.

@oxinabox oxinabox merged commit f6ed7de into master Jul 27, 2021
@oxinabox oxinabox deleted the ox/lkf1 branch July 27, 2021 14:19