
Cannot compute gradient with ArrayPartition that holds containers of different types #706

Open
bvdmitri opened this issue Jul 25, 2024 · 3 comments


@bvdmitri

ArrayPartition is a useful structure for concatenating arrays of different types. The type is defined in SciML/RecursiveArrayTools.jl.

ArrayPartitions are used in many places across the SciML ecosystem, as well as in other packages such as Manopt.jl.
It appears, though, that if an ArrayPartition references two containers, one with eltype Float64 and the other with eltype Int64, computing the gradient with ForwardDiff fails.

The MWE is:

julia> using ForwardDiff, RecursiveArrayTools

julia> v = [ 0.0, 1 ]
2-element Vector{Float64}:
 0.0
 1.0

julia> f(v) = sum(v)
f (generic function with 1 method)

julia> ForwardDiff.gradient(f, [ 0.0, 1 ])
2-element Vector{Float64}:
 1.0
 1.0

julia> ForwardDiff.gradient(f, ArrayPartition([ 0.0 ], [ 1 ]))
ERROR: MethodError: no method matching ForwardDiff.Dual{ForwardDiff.Tag{typeof(f), Float64}, Float64, 2}(::Int64, ::ForwardDiff.Partials{2, Float64})

Closest candidates are:
  ForwardDiff.Dual{T, V, N}(::Number) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:78
  ForwardDiff.Dual{T, V, N}(::Any) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:77
  ForwardDiff.Dual{T, V, N}(::V, ::ForwardDiff.Partials{N, V}) where {T, V, N}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:17

Stacktrace:
  [1] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [2] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [3] getindex
    @ ./broadcast.jl:636 [inlined]
  [4] macro expansion
    @ ./broadcast.jl:1004 [inlined]
  [5] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [6] copyto!
    @ ./broadcast.jl:1003 [inlined]
  [7] copyto!
    @ ./broadcast.jl:956 [inlined]
  [8] materialize!
    @ ./broadcast.jl:914 [inlined]
  [9] materialize!
    @ ./broadcast.jl:911 [inlined]
 [10] seed!(duals::ArrayPartition{…}, x::ArrayPartition{…}, seeds::Tuple{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:52
 [11] vector_mode_dual_eval!
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/apiutils.jl:23 [inlined]
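
A possible workaround, sketched here under the assumption that the failure comes from the mixed Float64/Int64 partitions, is to promote everything to a floating-point eltype before calling gradient:

using ForwardDiff, RecursiveArrayTools

f(v) = sum(v)

# ap.x is the tuple of arrays held by the ArrayPartition; promote each
# partition to a floating-point eltype so the whole partition is Float64
ap = ArrayPartition([0.0], [1])
ap_float = ArrayPartition(map(a -> float.(a), ap.x)...)

ForwardDiff.gradient(f, ap_float)  # should behave like the plain-Vector case above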
@mcabbott
Member

I think the reason ForwardDiff is confused is that this ArrayPartition declares itself to have Float64 elements, but in fact sometimes returns an Int:

julia> ap = ArrayPartition([ 0.0 ], [ 1 ])
([0.0], [1])

julia> size(ap), axes(ap), eltype(ap), supertype(typeof(ap))
((2,), (Base.OneTo(2),), Float64, AbstractVector{Float64})

julia> ap[1]  # no surprise
0.0

julia> ap[2]  # very surprising
1

julia> ap[1:2]  # here you get the expected eltype
2-element Vector{Float64}:
 0.0
 1.0

julia> ap[2,1:1]
1-element Vector{Float64}:
 1.0

The usual way to encode that elements of a vector have different types is to have an abstract eltype, which it seems ForwardDiff is able to handle.

(Note that the other example above constructs a Vector{Float64}, promoting to 1.0 when making the array.)

julia> x64 = [ 0.0, 1 ]  # this promotes on construction
2-element Vector{Float64}:
 0.0
 1.0

julia> ForwardDiff.gradient(f, x64)  # as in question
2-element Vector{Float64}:
 1.0
 1.0

julia> xabs = Real[ 0.0, 1 ]  # abstract eltype, could also use  xabs = Union{Float64, Int}[ 0.0, 1 ] 
2-element Vector{Real}:
 0.0
 1

julia> ForwardDiff.gradient(f, xabs)  # also OK, ForwardDiff not confused
2-element Vector{Float64}:
 1.0
 1.0

Making ArrayPartition declare its eltype accurately would be the obvious fix here, and would probably avoid many other weird edge cases. (Or else fixing its getindex definition to convert to the declared eltype.) Although I'm sure there's going to be some reason that consistency is inconvenient for something.
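
For illustration, a rough sketch of that second option (a hypothetical helper, not RecursiveArrayTools' actual getindex code):

using RecursiveArrayTools

# sketch only: convert whatever scalar indexing currently returns
# to the ArrayPartition's declared eltype
eltype_consistent_getindex(ap::ArrayPartition, i::Int) = convert(eltype(ap), ap[i])

ap = ArrayPartition([0.0], [1])
eltype_consistent_getindex(ap, 2)  # 1.0 instead of 1, matching eltype(ap) == Float64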

It's possible that ForwardDiff could be made more robust to misleading signals. For instance, making the ForwardDiff.Dual constructor called above promote its first argument might work here?
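
A sketch of that idea, written as a free-standing helper so it doesn't touch ForwardDiff's method table (the function name is hypothetical):

using ForwardDiff

# hypothetical helper: promote the value to the partials' value type V before
# using the existing Dual{T,V,N}(::V, ::Partials{N,V}) constructor from dual.jl
function dual_promoting_value(::Type{ForwardDiff.Dual{T,V,N}}, value::Number,
                              partials::ForwardDiff.Partials{N,V}) where {T,V,N}
    return ForwardDiff.Dual{T,V,N}(convert(V, value), partials)
end

With promotion like this, the Int value 1 coming out of the mixed ArrayPartition would become 1.0 before dispatching to the (::V, ::Partials{N,V}) method, instead of triggering the MethodError above.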

@bvdmitri
Author

Opened an issue in RecursiveArrayTools as well, though I have a feeling that this behaviour might be by design.

For instance making the ForwardDiff.Dual constructor called above promote its first argument might work here?

To me that seems like an obvious fix that shouldn't break anything, right?

@KristofferC
Collaborator

I have a feeling that this behaviour might be by design.

That just seems broken though. If fixing RecursiveArrayTools also fixes this, then I don't think anything should be done here.
