Skip to content

Commit

Permalink
Merge pull request #350 from FluxML/cl/scatter
Browse files Browse the repository at this point in the history
simplify scatter mean
  • Loading branch information
CarloLucibello authored Sep 7, 2021
2 parents 5256d0b + 8962dc7 commit 20c110b
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.28"
version = "0.7.29"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
1 change: 0 additions & 1 deletion src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ function gather(src::AbstractArray{Tsrc, Nsrc},
return gather!(dst, src, idx)
end


∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx)

function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
Expand Down
13 changes: 11 additions & 2 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,17 @@ end
function ∇scatter_src(::typeof(mean), Δ, dst,
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
divide_by_counts!(gather(Δ, idx), idx, dims)

M = typelength(Tidx)
num = gather(Δ, idx)
counts = fill!(similar(Δ, Int, size(Δ)[end-M+1:end]), 0)
scatter!(+, counts, fill!(similar(idx, Int), 1), idx)
den = gather(counts, idx)
# make num and den broadcast compatible
for i in 1:ndims(num)-ndims(den)
den = unsqueeze(den)
end
return safe_div.(num, den)
end

∇scatter_src(op, Δ, dst, src, idx) =
Expand Down
17 changes: 1 addition & 16 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,4 @@ function reverse_indices(idx::AbstractArray{<:Any,N}) where N
return reverse_indices!(rev, idx)
end

function count_indices(idx::AbstractArray)
counts = zero.(idx)
for i in unique(idx)
counts += sum(idx .== i) * (idx .== i)
end
return counts
end

function divide_by_counts!(xs, idx::AbstractArray, dims)
colons = Base.ntuple(_->Colon(), dims)
counts = count_indices(idx)
for i in CartesianIndices(counts)
view(xs, colons..., i) ./= counts[i]
end
return xs
end
unsqueeze(x) = reshape(x, 1, size(x)...)
3 changes: 1 addition & 2 deletions test/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ res = Dict(
4. 4. 6. 5. 5.],
)

types = [UInt8, UInt32, UInt128,
Int16, Int64, BigInt,
types = [UInt8, UInt32, Int64,
Float16, Float32, Float64, BigFloat, Rational]

@testset "scatter" begin
Expand Down

2 comments on commit 20c110b

@CarloLucibello
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/44368

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.29 -m "<description of version>" 20c110b8c22f66b84e43ce57022822d07c618d3f
git push origin v0.7.29

Please sign in to comment.