Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ambiguity in _setindex_zero #669

Merged
merged 3 commits into from
Aug 23, 2022
Merged

Fix ambiguity in _setindex_zero #669

merged 3 commits into from
Aug 23, 2022

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Aug 22, 2022

This should fix the bug from https://discourse.julialang.org/t/forwarddiff-jl-error-loaderror-methoderror-convert-is-ambiguous/86132 , which must be cause by #655

julia> using Zygote

julia> function f3(x)
           A = ones(5,5)*x
           maximum(A)
       end
f3 (generic function with 1 method)

julia> gradient(f3, 0.5)
(1.0,)

julia> hessian(f3, 0.5)
ERROR: MethodError: convert(::Type{ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}}, ::ChainRulesCore.ZeroTangent) is ambiguous.

Candidates:
  convert(::Type{T}, x::ChainRulesCore.AbstractZero) where T<:Number
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/abstract_zero.jl:31
  convert(::Type{ForwardDiff.Dual{T, V, N}}, x) where {T, V, N}
    @ ForwardDiff ~/.julia/packages/ForwardDiff/pDtsf/src/dual.jl:432

Possible fix, define
  convert(::Type{ForwardDiff.Dual{T, V, N}}, ::ChainRulesCore.AbstractZero) where {T, V, N}

Stacktrace:
  [1] fill!(dest::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}}, x::ChainRulesCore.ZeroTangent)
    @ Base ./array.jl:347
  [2] _setindex_zero(::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}}, ::ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}, ::Int64, ::Vararg{Int64})
    @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/indexing.jl:104
  [3] ∇getindex(x::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}}, dy::ForwardDiff.Dual{ForwardDiff.Tag{Zygote.var"#104#105"{typeof(f3)}, Float64}, Float64, 1}, inds::CartesianIndex{2})
    @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/indexing.jl:89

Xref JuliaDiff/ChainRulesCore.jl#448


I think the reason Zygote is calling this rule at all is that its rrule_via_ad function takes a shortcut which doesn't check its own rules:

https://github.com/FluxML/Zygote.jl/blob/99d5a38b14dc842643acfa624b6f0f89061efbbf/src/compiler/chainrules.jl#L243-L246

Edit: maybe not, sorry. The rule for maximum calls ∇getindex directly.


Needs a test. Do we add ForwardDiff just for this?

@github-actions github-actions bot added the needs version bump Version needs to be incremented or set to -DEV in Project.toml label Aug 22, 2022
@oxinabox
Copy link
Member

Needs a test. Do we add ForwardDiff just for this?

I would rather not.
I am not sure how specific we want to be just now with what we promise for this function.
So I think it is fine to just leave this as is -- it passes current tests whicch i assume hit the changed line.

So I am kind of ok having this without additional tests

@mcabbott mcabbott merged commit 39c2d17 into main Aug 23, 2022
@mcabbott mcabbott deleted the mcabbott-patch-3 branch August 23, 2022 11:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
needs version bump Version needs to be incremented or set to -DEV in Project.toml
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants