-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Don't try and convert to FloatX except if Integer or AbstractFloat (option 1) #417
Conversation
Codecov Report
@@ Coverage Diff @@
## master #417 +/- ##
==========================================
+ Coverage 92.25% 92.32% +0.06%
==========================================
Files 14 14
Lines 775 782 +7
==========================================
+ Hits 715 722 +7
Misses 60 60
Continue to review full report at Codecov.
|
src/projection.jl
Outdated
# For on-AbstractFloat other types pass though to project each component | ||
function (::ProjectTo{<:Complex{T}})(dx::Complex) where T | ||
project = ProjectTo(zero(T)) | ||
return Complex(project(real(dx)), project(imag(dx))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's this aimed at? The real types we preserve, we also already preserve their complex versions.
Maybe recursing into Complex can wait until someone has a stranger real type which needs it. Then we will also know whether ProjectTo(zero(T))
is sufficient for this strange type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we needed it for ProjectTo(1.0 + 1im)(Complex(Dual(1.0, 2.0), Dual(1.0, 2.0)))
But that gets caught by (::ProjectTo{<:Number})(dx::Number) = dx
We would need it if we wanted ProjectTo(1.0 + 1im)(Complex(Dual(1, 2), Dual(1, 2)))
to return a Complex{Dual{Float64}}
rather than a Complex{Dual{Int}}
which it currently does.
but that is not needed right now, so I guess we can add it back later
# The (::ProjectTo{T})(::T) method doesn't work because we are allowing a different | ||
# Number type that might not be a subtype of the `project_type`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this true? project_type
should be Real, Complex, or Number. I get super-confused by dispatch with type parameters on LHS like {T}(...
. But not important really.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for e.g. Float32
or ComplexF64
etc it is not the case that it is Real
, Complex
, or Number
.
I could add those in as alternative special cases instead of this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would not be shocked if this come backs to bite us, but we can remove it if and when that happens
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh right, I wasn't thinking clearly, that's exactly the use case.
@mcabbott do you approve this PR? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems fine for now, to make Zygote work. We knew these details may need tweaking.
They may still need more. I'm not so sure whether or not we should do JuliaDiff/ForwardDiff.jl#538 (and possibly things like it elsewhere); if we do then we may also want to make this more strict again, so that x::Float32
forces the numbers inside the Dual
to be Float32. Or maybe not.
One option to fix the issues with tying to use ForwardDiff on a
rrule
pullback that makes use ofProjectTo
See discussion on JuliaDiff/ForwardDiff.jl#538.
Which fixes the following test that fails in FluxML/Zygote.jl#1035
This option basically says:
All subtypes of
Real
represent the same subspace: all of Real.So if you give me one I am not familar with (e.g.
Dual
) I am just going to leave it alone.But if you give a
AbstractFloat
subtype, I will fix the precision to match what i saw in the primal pass.And if you give me an Integer I will make it a Float.
So for this particular case if you give a
ProjectTo{Float64}(::Dual{Int})
it will return aDual{Int}
because it has no idea what a Dual is or if it is Integery or Floaty, but it trust that it is a good Real.This works fine.
It is a bit more code than I expected as need to make sure complex works.