inference check failure in test_rrule, @inferred rrule works #246

Open
tpapp opened this issue May 3, 2022 · 7 comments

@tpapp

tpapp commented May 3, 2022

I wrapped up an MWE in ImplicitAD.jl for an rrule I defined (see the single test). My issue is that

```julia
test_rrule(one_one, 1.0;
           check_inferred = true,
           fdm = forward_fdm(5, 1),
           atol = ϵ, rtol = ϵ)
```

gives an inference failure, while

```julia
@inferred rrule(ChainRulesTestUtils.TestConfig(), one_one, 1.0)
```

is fine.

@mzgubic
Member

mzgubic commented May 3, 2022

test_rrule also checks that the pullback is inferrable; perhaps that's what's giving the error?
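To see how @inferred on the rrule call can pass while the pullback still fails inference, the two can be checked separately. Below is a minimal sketch in plain Julia, assuming a hypothetical hand-written my_rrule that returns the primal value and a pullback closure (it is not the ChainRulesCore rrule API): the outer call is type-stable, but the pullback reads from an Any[...] container, so its return type cannot be inferred.

```julia
using Test  # provides @inferred

# Hypothetical hand-written rule: returns the primal value and a pullback closure.
function my_rrule(f, x)
    y = f(x)
    # Untyped container: the compiler only knows its elements are Any.
    settings = Any[1e-6]
    pullback(ȳ) = (ȳ * settings[1],)
    return y, pullback
end

# The outer call infers: its return type (Float64 plus a concrete closure type) is concrete.
y, pb = @inferred my_rrule(sin, 1.0)

# The pullback does not: inference can only conclude Tuple{Any}.
@show Base.return_types(pb, (Float64,))

# Uncommenting the next line throws, because the actual Tuple{Float64}
# does not match the inferred Tuple{Any}.
# @inferred pb(1.0)
```

So a passing @inferred rrule(...) only covers the forward call; the pullback needs its own check.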

@tpapp
Author

tpapp commented May 3, 2022

Thanks. Indeed it does, but it seems to be coming from the rrule_via_ad call:

```julia
julia> g, pbg = @inferred rrule_via_ad(ChainRulesTestUtils.TestConfig(), one_one_core, 1.0, 2.0)
(162761.18047510285, ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}(ChainRulesTestUtils.TestConfig(FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}(SVector{5,Int64}(-2, -1, 0, 1, 2), SVector{5,Float64}(0.08333333333333333, -0.6666666666666666, 0.0, 0.6666666666666666, -0.08333333333333333), (SVector{5,Float64}(-0.08333333333333333, 0.5, -1.5, 0.8333333333333334, 0.25), SVector{5,Float64}(0.08333333333333333, -0.6666666666666666, 0.0, 0.6666666666666666, -0.08333333333333333), SVector{5,Float64}(-0.25, -0.8333333333333334, 1.5, -0.5, 0.08333333333333333)), 10.0, 1.0, Inf, 0.05555555555555555, 1.4999999999999998, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}(SVector{7,Int64}(-3, -2, -1, 0, 1, 2, 3), SVector{7,Float64}(-0.5, 2.0, -2.5, 0.0, 2.5, -2.0, 0.5), (SVector{7,Float64}(0.5, -4.0, 12.5, -20.0, 17.5, -8.0, 1.5), SVector{7,Float64}(-0.5, 2.0, -2.5, 0.0, 2.5, -2.0, 0.5), SVector{7,Float64}(-1.5, 8.0, -17.5, 20.0, -12.5, 4.0, -0.5)), 10.0, 1.0, Inf, 0.5365079365079365, 10.0))), (true, false, false), (ImplicitAD.one_one_core, 1.0, 2.0), ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())))

julia> @code_warntype pbg(1.0)
MethodInstance for (::ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}})(::Float64)
  from (::ChainRulesTestUtils.var"#f_pb#43")(ȳ) in ChainRulesTestUtils at /home/tamas/.julia/packages/ChainRulesTestUtils/vWKSm/src/rule_config.jl:39
Arguments
  #self#::ChainRulesTestUtils.var"#f_pb#43"{ChainRulesTestUtils.TestConfig, Tuple{Bool, Bool, Bool}, Tuple{typeof(one_one_core), Float64, Float64}, ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}
  ȳ::Float64
Body::Tuple
1 ─ %1 = Core.getfield(#self#, :config)::ChainRulesTestUtils.TestConfig
│   %2 = Base.getproperty(%1, :fdm)::Any
│   %3 = Core.getfield(#self#, :call)::Core.Const(ChainRulesTestUtils.var"#call#42"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}()))
│   %4 = Core.getfield(#self#, :primals)::Tuple{typeof(one_one_core), Float64, Float64}
│   %5 = Core.getfield(#self#, :is_ignored)::Tuple{Bool, Bool, Bool}
│   %6 = ChainRulesTestUtils._make_j′vp_call(%2, %3, ȳ, %4, %5)::Tuple
└──      return %6
```

Given the MWE, I wonder if anyone could please dig into this. I am not familiar with the internals of this package.

@tpapp
Author

tpapp commented May 3, 2022

Wait, I think that just typing the fdm field of TestConfig would do the trick. Testing, then making a PR.

@tpapp
Author

tpapp commented May 3, 2022

Nope, that fixes the Any for %2 above, but _make_j′vp_call is still not inferred, which is not surprising given the Any[...] inside that function.
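The effect of typing a field can be reproduced in isolation. Below is a minimal sketch with hypothetical names (LooseConfig, TightConfig, apply_fdm); these are not the actual TestConfig definition, only an illustration of why an untyped field shows up as ::Any in @code_warntype while a type-parameterised one does not.

```julia
using Test  # provides @inferred

# Untyped field: accessing cfg.fdm is inferred as Any.
struct LooseConfig
    fdm
end

# Parametric field: the concrete type of fdm is part of the struct's type.
struct TightConfig{F}
    fdm::F
end

# Hypothetical helper that uses the field downstream.
apply_fdm(cfg, x) = cfg.fdm(x)

loose = LooseConfig(sin)
tight = TightConfig(sin)

# Passes: the return type is inferred as Float64.
@inferred apply_fdm(tight, 1.0)

# Inferred as Any, so @inferred apply_fdm(loose, 1.0) would throw.
@show Base.return_types(apply_fdm, (LooseConfig, Float64))
```

This only removes the Any at the field access itself; anything further down that builds an untyped container, such as the Any[...] mentioned above, still defeats inference.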

@mzgubic
Member

mzgubic commented May 3, 2022

Thanks for checking. I'll have time next week to dig into this; for now, maybe just set check_inferred=false?

@tpapp
Author

tpapp commented May 3, 2022

Thanks, I appreciate it.

@mzgubic
Member

mzgubic commented May 10, 2022

My current understanding of this issue is:

The inference failure comes from the pullback: specifically the rrule_via_ad call, and more precisely _make_j′vp_call, as you correctly point out. We can't easily make that call infer, because the output (the tangents to xs) depends on whether we are ignoring a particular x, and that information is passed in as a boolean array.
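The problem with runtime ignore flags can be illustrated without ChainRulesTestUtils. In this minimal sketch (hypothetical names tangents_runtime, tangent and tangents_typed; the real _make_j′vp_call is more involved), when the flags are runtime Bool values each output slot may be either nothing or a number, so the compiler can only infer a Union per slot. Moving the flags into the type domain with Val lets dispatch pick the branch at compile time.

```julia
# Flags as runtime Bools: each slot may be nothing or a Float64,
# so the inferred return type has a Union in every slot.
tangents_runtime(xs::Tuple, is_ignored::Tuple) =
    map((x, ign) -> ign ? nothing : 2x, xs, is_ignored)

# Flags as types: dispatch selects the branch at compile time.
tangent(x, ::Val{true}) = nothing   # ignored argument: no tangent
tangent(x, ::Val{false}) = 2x       # stand-in for a real tangent computation

tangents_typed(xs::Tuple, is_ignored::Tuple) = map(tangent, xs, is_ignored)

xs = (1.0, 2.0)
@show tangents_runtime(xs, (false, true))
@show Base.return_types(tangents_runtime, (Tuple{Float64, Float64}, Tuple{Bool, Bool}))
@show tangents_typed(xs, (Val(false), Val(true)))
@show Base.return_types(tangents_typed, (Tuple{Float64, Float64}, Tuple{Val{false}, Val{true}}))
```

The catch is that the flags then have to be known from the argument types, which is likely a larger change than typing a single field.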

I see two ways around the issue:

  1. We can define an rrule for the one_one_core function, which will then be called by rrule_via_ad with the current TestConfig.
  2. We can pass in a different RuleConfig which, unlike the current one (TestConfig), does not call finite-difference methods under the hood to work out the rrule.
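Option 1 would look roughly like the following. This is a minimal sketch using ChainRulesCore, with a hypothetical two-argument function my_core(x, p) = x^2 * p standing in for one_one_core, whose body is not shown in this thread. With a hand-written rule like this in place, rrule_via_ad under TestConfig would call it, as described in option 1, rather than working out the rule with finite differences.

```julia
using ChainRulesCore

# Hypothetical stand-in for a two-argument function like one_one_core(x, p).
my_core(x, p) = x^2 * p

function ChainRulesCore.rrule(::typeof(my_core), x::Real, p::Real)
    y = my_core(x, p)
    function my_core_pullback(ȳ)
        Δ = unthunk(ȳ)  # materialise the cotangent if it arrives as a thunk
        # Hand-written partials: ∂y/∂x = 2xp and ∂y/∂p = x^2.
        # The first entry is the tangent for the function object itself.
        return NoTangent(), Δ * 2x * p, Δ * x^2
    end
    return y, my_core_pullback
end

y, pb = rrule(my_core, 3.0, 2.0)
@show y
@show pb(1.0)
```

A rule like this could then be checked on its own with test_rrule(my_core, 3.0, 2.0) from ChainRulesTestUtils.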

I imagine this package wants to be AD-independent? In that case we could still use a particular AD system as a test dependency and use its RuleConfig if we want to test the inference.
