-
-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: mapreduce type stability, VoA broadcast adjoints #325
Conversation
src/vector_of_array.jl
Outdated
@inline Base.sum(VA::AbstractVectorOfArray; kwargs...) = sum(identity, VA; kwargs...) | ||
@inline function Base.sum(f, VA::AbstractVectorOfArray; kwargs...) | ||
if hasproperty(kwargs, :dims) | ||
sum(Array(VA); kwargs...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be a potentially big allocation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That was the default implementation, but I just realized two things:
- that method is probably not necessary
- mapreduce is implemented wrong
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in bdb5830
src/vector_of_array.jl
Outdated
@@ -638,8 +653,9 @@ end | |||
end | |||
|
|||
Base.map(f, A::RecursiveArrayTools.AbstractVectorOfArray) = map(f, A.u) | |||
|
|||
function Base.mapreduce(f, op, A::AbstractVectorOfArray) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably need to have an init
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in bdb5830
2a55225
to
1f9b577
Compare
_minus(Δ) = .-Δ | ||
_minus(::Nothing) = nothing | ||
|
||
@adjoint function Broadcast.broadcasted(::typeof(-), x::AbstractVectorOfArray, y::Union{AbstractVectorOfArray, Zygote.Numeric}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you're never supposed to define specific broadcasts like this, and I don't think digging into Zygote's broadcast system is the answer. It just relies on the Julia-level broadcast
This will need a real fix, but let's at least get downstream tests passing. |
Close #323
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
Add any other context about the problem here.