How to implement vect #492

Closed
willtebbutt opened this issue Aug 1, 2021 · 19 comments

Comments

@willtebbutt
Member

#491 implements vect for two special cases, but neglects the general case because I couldn't figure out how to implement it in a type-stable manner for Numbers, or at all in general.

My attempt at an implementation for Numbers was:

# Numbers need to be projected because they don't pass straight through the function.
# More generally, we would ideally project everything.
function rrule(::typeof(Base.vect), X::Vararg{Number, N}) where {N}
    l = length(X)
    projects = map(ProjectTo, X)
    function vect_pullback(ȳ)
        X̄ = ntuple(n -> projects[n](ȳ[n]), l)
        return (NoTangent(), X̄...)
    end
    return Base.vect(X...), vect_pullback
end

but it's type-unstable: the type of X̄ couldn't be inferred (Julia 1.6.2), and I'm not entirely sure why.

The Number specialisation of this function is interesting because it highlights that, when vect constructs the array, it sometimes converts the elements away from their original types and sometimes doesn't. For example, this isn't the case with subtypes of AbstractArray, which essentially pass through unchanged. I'm not sure how to deal with this in general, because

  1. there's no function which explicitly implements the conversion, so there's nothing to hook into
  2. ProjectTo isn't defined for all types (only Numbers and AbstractArrays AFAICT), so we can't rely on it in general.

Any thoughts on either of these issues?
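A quick Base-only check of the conversion described in this comment (a sketch, assuming Julia 1.6 or later; the argument values are arbitrary): numeric arguments to Base.vect are promoted to a common element type, so an Int input ends up stored as a Float64, while arrays that already share a type are stored as the very same objects.

```julia
# Numbers are promoted to a common element type when the vector is built.
v = Base.vect(1, 2.5)          # equivalent to writing [1, 2.5]
println(eltype(v))             # Float64: the Int argument was converted
println(typeof(v[1]))          # Float64, not Int

# Arrays that already share a type are stored without conversion.
a = [1.0, 2.0]
w = Base.vect(a, [3.0])
println(w[1] === a)            # true: the same array object is stored
```

This is why a pullback for the Number case receives a Float64 cotangent entry for an Int input, and has to map it back.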

@simeonschaub
Member

In this case I think you want to move the l = length(X) into the pullback; that way the length can be constant-propagated.

@mcabbott
Member

mcabbott commented Aug 1, 2021

Just using N instead of l seems to solve it, on 1.7 at least. You could also close over valN = Val(N) which sometimes behaves better.
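A sketch of the rule from the top of the thread with this suggestion applied, assuming ChainRulesCore 1.x is installed: the pullback builds X̄ with ntuple(..., Val(N)), using the static parameter N instead of the runtime length l, so the length of X̄ is known at compile time. The inputs in the usage lines are arbitrary.

```julia
using ChainRulesCore

# rrule for Base.vect restricted to Numbers, as in the original attempt, but the
# tuple length comes from the static parameter N (wrapped in Val) rather than
# from a runtime `l = length(X)` captured by the closure.
function ChainRulesCore.rrule(::typeof(Base.vect), X::Vararg{Number,N}) where {N}
    projects = map(ProjectTo, X)  # one projector per input, recording its type
    function vect_pullback(ȳ)
        dy = unthunk(ȳ)
        # With Val(N) the compiler can unroll ntuple, so each projects[n] is inferred separately.
        X̄ = ntuple(n -> projects[n](dy[n]), Val(N))
        return (NoTangent(), X̄...)
    end
    return Base.vect(X...), vect_pullback
end

# Usage: the Float32 input is promoted in the output vector, and its cotangent
# is projected back to Float32 by the pullback.
y, pb = rrule(Base.vect, 1.0, 2.0f0)
println(pb([3.0, 4.0]))
```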

@willtebbutt
Member Author

Good catch, thanks @simeonschaub. I've actually just gone down the Vararg{T, N} where {T, N} route, which seems to solve the problem.

@willtebbutt
Member Author

> Just using N instead of l seems to solve it, on 1.7 at least. You could also close over valN = Val(N) which sometimes behaves better.

Haha you just beat me to it!

@simeonschaub
Member

> Just using N instead of l seems to solve it, on 1.7 at least.

Oh, nice! I actually wasn't aware that we special-cased closures over static parameters like this.

@mcabbott
Member

mcabbott commented Aug 1, 2021

Heh. Arrays also sometimes get promoted, e.g.:

julia> [[1,2], [3,4.0+im]]
2-element Vector{Vector{ComplexF64}}:
 [1.0 + 0.0im, 2.0 + 0.0im]
 [3.0 + 0.0im, 4.0 + 1.0im]

julia> [[1,2]', [3 4]]
2-element Vector{AbstractMatrix{Int64}}:
 [1 2]
 [3 4]

> 2. ProjectTo isn't defined for all types

Yes, the final version thought it safest to leave that undefined, so that any changes would be possible later. The competing idea was to make it ignore anything it doesn't (yet) understand how to project, in which case it could be applied blindly.

That may ultimately be the right thing. For instance, if you have a type (which is not an array) and define your own ProjectTo methods, under the present design they will never be called except in rules you write. In FluxML/Zygote.jl#1044 I made a Zygote._project function which is safe to apply everywhere, but it seems weird to tell anyone to overload that instead.

But for vect for now, numbers & arrays probably cover the vast majority of uses.
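For illustration, a helper in the spirit of the "safe to apply blindly" idea described above. This is a minimal sketch assuming ChainRulesCore 1.x; the name maybe_project and its dispatch rules are hypothetical and are not the actual Zygote._project code. Inputs that ProjectTo is known to handle (numbers and arrays of numbers) are projected, and everything else passes through untouched.

```julia
using ChainRulesCore

# Hypothetical helper, not part of ChainRulesCore or Zygote.
# Project the cotangent `dx` onto the tangent space of `x` when ProjectTo is known
# to handle `x`, i.e. numbers and arrays of numbers.
maybe_project(x::Union{Number,AbstractArray{<:Number}}, dx) = ProjectTo(x)(dx)

# For any other type, make no assumption and return the cotangent unchanged.
maybe_project(x, dx) = dx

println(maybe_project(1.0f0, 2.0))        # projected back to a Float32
println(maybe_project((a = 1,), :dx))     # unknown type: passed through as-is
```

The fallback is what makes it safe to call everywhere; the cost is that a user-defined ProjectTo method for some other type would still never be reached through it.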

@mcabbott
Member

mcabbott commented Aug 1, 2021

Maybe while thinking about vect, we should also think about its typed friends:

julia> Meta.@lower Int[1,2,3]
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Base.getindex(Int, 1, 2, 3)
└──      return %1
))))

julia> Meta.@lower Int[1 2; 3 4]
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope`
1 ─ %1 = Core.tuple(2, 2)
│   %2 = Base.typed_hvcat(Int, %1, 1, 2, 3, 4)
└──      return %2
))))
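Because both forms lower to ordinary function calls, they can be called (and so given rules) directly. A Base-only sketch checking the two lowered calls above against the bracket syntax:

```julia
# `T[a, b, c]` lowers to `getindex(T, a, b, c)`, so calling it directly builds the same vector.
v = getindex(Int, 1, 2, 3)
println(v == Int[1, 2, 3])                    # true

# `T[a b; c d]` lowers to `Base.typed_hvcat(T, rows, a, b, c, d)`, where `rows`
# gives the number of entries in each row.
m = Base.typed_hvcat(Int, (2, 2), 1, 2, 3, 4)
println(m == Int[1 2; 3 4])                   # true

# Unlike vect, the element type is imposed by `T` rather than promoted.
w = getindex(Float64, 1, 2)                   # the Int arguments are converted to Float64
println(eltype(w))                            # Float64
```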

@willtebbutt
Member Author

> Heh. Arrays also sometimes get promoted, e.g.:

I learned something new here.

> Yes, the final version thought it safest to leave that undefined, so that any changes would be possible later. The competing idea was to make it ignore anything it doesn't (yet) understand how to project, in which case it could be applied blindly.

Ah, great. I'm going through the process of getting back up to speed after all of the progress that's been made in the last couple of months (which has been fantastic), so this context is helpful.

> But for vect for now, numbers & arrays probably cover the vast majority of uses.

Good point. In any case, I couldn't actually see any tests for it in Zygote, so my feeling is that it's probably safe to restrict the arguments to Union{Number, AbstractArray} and be done with it.

> Maybe while thinking about vect, we should also think about its typed friends:

I had literally no idea that was how this is implemented. I don't think I'll address it in #491 if that's okay, but I agree that it would be good to cover.

@nickrobinson251
Contributor

nickrobinson251 commented Aug 1, 2021

> I had literally no idea that was how this is implemented

if you're talking about vect, then it is surprisingly simple
https://github.com/JuliaLang/julia/blob/bdacfa21ef7d9657be0ff3220947ceca92a340a1/base/array.jl#L125-L148

but when you start looking at all the *cat(🐱) / typed_*cat (⌨️🐱) stuff in base/abstractarray.jl, there are a lot of different methods (and i suppose hvncat has added to that in v1.7)

Anyway, probably a deep but maybe worthwhile rabbit hole 🐰
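The vect methods in the file linked above amount roughly to the following. This is a simplified re-implementation for illustration only, named my_vect (a hypothetical name) to avoid clashing with Base.vect; it also shows where the array promotion in the earlier [[1,2], [3,4.0+im]] example comes from.

```julia
# Simplified re-implementation of Base.vect for illustration; `my_vect` is a hypothetical name.
my_vect() = Vector{Any}()                                  # `[]` is a Vector{Any}
my_vect(X::T...) where {T} = T[X[i] for i in 1:length(X)]  # all arguments have the same type

function my_vect(X...)
    T = Base.promote_typeof(X...)                   # common element type of the arguments
    return copyto!(Vector{T}(undef, length(X)), X)  # copyto! converts each element to T
end

println(typeof(my_vect(1, 2.5)))                 # numbers are promoted
println(typeof(my_vect([1, 2], [3, 4.0 + im])))  # arrays are promoted too
```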

@willtebbutt
Member Author

willtebbutt commented Aug 1, 2021

> if you're talking about vect, then it is surprisingly simple

Ah, no, I was just referring to the fact that getindex(Int, 5) is a thing 🤯 -- makes sense, but I wouldn't have imagined that was how it worked.

@mcabbott
Member

mcabbott commented Aug 1, 2021

> the fact that getindex(Int, 5) is a thing 🤯

I was oddly pleased when I discovered this. Nobody felt the need to make an AbstractAgnosticSquareBracketsHander in the name of purity.

hvncat is a rabbit-hole I've stayed a safe distance from, for now.

@nickrobinson251
Contributor

nickrobinson251 commented Aug 1, 2021

> getindex(Int, 5) is a thing

oh, yeah, lol -- that is wild -- sometimes i feel better not to know these things (in case i'm tempted to use them)

on the other hand i agree it's oddly pleasing

@mcabbott
Member

mcabbott commented Aug 2, 2021

The next question is how to make collect infer. For a tuple argument this is the same as vect without the splat, but ChainRules wants the pullback to return a Tangent for the tuple, and this doesn't seem to infer:

NumberOrArray{T<:Number} = Union{T,AbstractArray{T}}
_val(::Val{T}) where {T} = T  # helper: unwraps the type stored by Val(typeof(xs)) below

function rrule(::typeof(collect), xs::Tuple{Vararg{NumberOrArray,N}}) where {N}
    TX = Val(typeof(xs))
    projectors = map(ProjectTo, xs)
    function collect_pullback_3(dy_raw)
        @info "3"
        dy = unthunk(dy_raw)
        dxs = ntuple(n -> projectors[n](dy[n]), N)
        return (NoTangent(), Tangent{_val(TX)}(dxs...))
    end
    return collect(xs), collect_pullback_3
end

test_rrule(collect, (1.2, 3.4 + 5.6im))

# return type Tuple{ChainRulesCore.NoTangent, ChainRulesCore.Tangent{Tuple{Float64, ComplexF64}, Tuple{Float64, ComplexF64}}} does not match inferred 
# return type Tuple{ChainRulesCore.NoTangent, ChainRulesCore.Tangent{Tuple{Float64, ComplexF64}}}

@willtebbutt
Member Author

Alas, not making vect sufficiently general has already come back to bite me. This KernelFunctions PR to upgrade to CRC 1 appears to be failing because we're hitting a method of vect which ChainRules doesn't currently support.

The question becomes what to do about it. Should we generalise ProjectTo to handle arbitrary data types? The code is definitely hitting this method, but I'm genuinely at a loss as to how to implement the ChainRule. I suspect we're going to need to call back into AD to make this work, but I'm not sure which function to call back into :(

@willtebbutt
Member Author

One potential edge case that we could add is when T is inferred to be Any -- in that case we know that data will pass straight through without modification.
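A Base-only check of that special case (a sketch; the argument values are arbitrary): when the arguments have no common type, promotion falls back to Any, and Base.vect stores every argument as the same object it was given, so its cotangent could be handed straight back without projection.

```julia
x = [1.0, 2.0]
v = Base.vect(x, "label", :sym)   # no common type: promotion gives Any
println(eltype(v))                # Any
println(v[1] === x)               # true: the array itself is stored, not a converted copy
```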

@mcabbott
Member

mcabbott commented Aug 9, 2021

The short-term solution is, I think, not to project at all when you aren't sure that it's safe. That seems OK: we don't guarantee 100% application of projection, and we still have many rules which don't apply it yet even though they could.

@willtebbutt
Member Author

Okay, cool. I'll add another method which accepts anything and doesn't project.
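A minimal sketch of such a fallback method, assuming ChainRulesCore 1.x (an illustration, not the method that was eventually merged): it accepts arguments of any type and hands each entry of the output cotangent straight back to the corresponding input, with no projection.

```julia
using ChainRulesCore

# Fallback rrule for Base.vect: no ProjectTo, so it works for arguments of any type.
function ChainRulesCore.rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
    function vect_pullback(ȳ)
        dy = unthunk(ȳ)
        # The n-th entry of the output cotangent is the cotangent of the n-th argument.
        return (NoTangent(), ntuple(n -> dy[n], Val(N))...)
    end
    return Base.vect(X...), vect_pullback
end

y, pb = rrule(Base.vect, "label", 1.0)
println(pb(Any[NoTangent(), 2.0]))  # the cotangents are returned as-is
```

If a more specific Number-only method is also defined, dispatch picks it for all-Number arguments, and this fallback only handles the remaining cases.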

Do you think it would be worth starting to assemble a list of rules where we think the implementation is dubious due to a lack of projection?

@mcabbott
Member

mcabbott commented Aug 9, 2021

Might be premature to make a list, there are lots which just nobody has got around to, I have a branch which fixes a dozen somewhere...

@pabloferz was complaining that some of them don't play well with GPUs, the solution to which may also involve weakening slightly the "guarantee" about how universally projection will happen. (And turning more branches into dispatch.)

@willtebbutt
Member Author

> Might be premature to make a list, there are lots which just nobody has got around to, I have a branch which fixes a dozen somewhere...

Sorry, I'm not trying to suggest that we need to act on the list immediately, I just don't want us to risk forgetting that this is a thing that will probably have to be addressed at some point.


4 participants