_project error #1078

Closed
willtebbutt opened this issue Sep 24, 2021 · 6 comments · Fixed by #1079
Labels
bug Something isn't working

Comments

@willtebbutt
Member

willtebbutt commented Sep 24, 2021

The problem occurs when one tries to take the Zygote.gradient of a function of an AbstractArray that has a structural tangent.

ProjectTo(my_array) produces a ProjectTo{AbstractArray}, and Zygote winds up producing a Tangent{Any} for it. This means that this method doesn't apply.

MWE:

julia> using ChainRules: Tangent, ProjectTo, NoTangent

julia> using KernelFunctions: MOInputIsotopicByFeatures

julia> x = MOInputIsotopicByFeatures(randn(3), 2);

julia> dx = Tangent{Any}(x=randn(3), out_dim=NoTangent());

julia> ProjectTo(x)(dx)
ERROR: MethodError: no method matching (::ProjectTo{AbstractArray, NamedTuple{(:elements, :axes), Tuple{Vector{typeof(identity)}, Tuple{Base.OneTo{Int64}}}}})(::Tangent{Any, NamedTuple{(:x, :out_dim), Tuple{Vector{Float64}, NoTangent}}})
Closest candidates are:
  (::ProjectTo{T, D} where D<:NamedTuple)(::Tangent{var"#s15", T} where {var"#s15"<:T, T}) where T at /Users/willtebbutt/.julia/packages/ChainRulesCore/ChM7X/src/projection.jl:139
  (::ProjectTo{AbstractArray, D} where D<:NamedTuple)(::Union{LinearAlgebra.Adjoint{T, var"#s832"}, LinearAlgebra.Transpose{T, var"#s832"}} where {T, var"#s832"<:(AbstractVector{T} where T)}) at /Users/willtebbutt/.julia/packages/ChainRulesCore/ChM7X/src/projection.jl:239
  (::ProjectTo{AbstractArray, D} where D<:NamedTuple)(::Number) at /Users/willtebbutt/.julia/packages/ChainRulesCore/ChM7X/src/projection.jl:245
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[6]:1
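The failure comes down to a subtype check during method dispatch. A minimal sketch, using a hypothetical stand-in type `MyTangent{P}` (not the real ChainRulesCore `Tangent`), shows that a tangent whose primal-type parameter is `Any` cannot match a signature requiring `<:AbstractArray`, whereas one carrying the concrete primal type can:

```julia
# Hypothetical stand-in for ChainRulesCore's Tangent{P, T}: the first parameter
# records the primal type, the second the type of the stored field tangents.
struct MyTangent{P,T<:NamedTuple}
    backing::T
end
MyTangent{P}(nt::NamedTuple) where {P} = MyTangent{P,typeof(nt)}(nt)

# The pass-through method discussed in this thread has the shape
#   (::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx
# With T = AbstractArray, it only applies if the primal type is <: AbstractArray.
typed = MyTangent{Vector{Float64}}((x = randn(3),))  # primal type known
untyped = MyTangent{Any}((x = randn(3),))            # primal type lost

println(typed isa MyTangent{<:AbstractArray})    # true
println(untyped isa MyTangent{<:AbstractArray})  # false: Any is not <: AbstractArray
```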

What is the appropriate fix, @mcabbott @oxinabox @mzgubic? I can work around this by using pullback or _pullback directly, but it's obviously suboptimal not to be able to use gradient.

Where's the discussion about adding the _project call here? I feel like I've seen it somewhere, but I'm not sure where...

willtebbutt added the bug label Sep 24, 2021
@DhairyaLGandhi
Member

DhairyaLGandhi commented Sep 24, 2021

See #1075

And #1044 for ProjectTo; I definitely agree that that should go. I feel that projecting gradients onto some space may, in some ways, limit how the compiler can optimise code.

@willtebbutt
Member Author

> I feel projecting gradients to some space may be limiting in some ways the compiler can optimise code.

To be clear, I have no problem whatsoever with projecting things (e.g. dropping the imaginary part of a complex number that represents the tangent of a real number); projection matters for both correctness and performance, especially where linear algebra is concerned. I just don't think the implementation is correct here 🤷.

I would be interested to know where you think projections could get in the way of optimisations because my experience has consistently been that failing to project is the thing that causes problems, but this probably isn't the place for that discussion.

@mcabbott
Member

mcabbott commented Sep 24, 2021

Should this be a CRC issue? You will get the same error regardless of which rule or function applies the projection. The assumption behind projection is that it may be applied widely.

It's likely that ProjectTo{AbstractArray} should pass all Tangents through, which I think is what you're arguing should happen.

Or are you arguing that _project should be more permissive than ProjectTo? Are there cases which it should handle, which will never be seen by an rrule applying projection?

@willtebbutt
Member Author

I was hoping you would suggest the appropriate fix :) Sounds like you're in favour of modifying CRC? Happy to open a PR if so.

@oxinabox
Member

Possibly Zygote should type-pirate:

(::ProjectTo)(dx::Tangent) = dx

to work around the fact that it loses track of primal types, and so can't manage to hit:

(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx

This is what it currently does for Complex numbers, for the same reason (it can't hit the Tangent{<:Complex} method):

(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))
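A minimal sketch of how such a catch-all method changes dispatch, using hypothetical stand-ins (`MyProjectTo`, `MyTangent`) rather than the real ChainRulesCore types, so no actual type piracy is involved. Because the typed method is more specific, it still wins whenever the primal type is known; the catch-all only handles tangents such as `MyTangent{Any}` whose primal type has been lost:

```julia
# Hypothetical stand-ins for ChainRulesCore's ProjectTo{P} and Tangent{P, T}.
struct MyProjectTo{P} end
struct MyTangent{P,T<:NamedTuple}
    backing::T
end
MyTangent{P}(nt::NamedTuple) where {P} = MyTangent{P,typeof(nt)}(nt)

# Mirrors (::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx
(::MyProjectTo{T})(dx::MyTangent{<:T}) where {T} = dx

project = MyProjectTo{AbstractArray}()
typed = MyTangent{Vector{Float64}}((x = randn(3),))
untyped = MyTangent{Any}((x = randn(3),))

# Without a fallback, the Tangent{Any}-style input has no matching method.
err = try
    project(untyped)
    nothing
catch e
    e
end
println(err isa MethodError)  # true

# Mirrors the proposed catch-all: (::ProjectTo)(dx::Tangent) = dx
(::MyProjectTo)(dx::MyTangent) = dx

println(project(untyped) === untyped)  # true: handled by the catch-all
println(project(typed) === typed)      # true: still handled by the typed method
```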

@willtebbutt
Member Author

Seems reasonable to me. I'll open a PR.
