From a706ae396048b4ca692d694fd5044686a8ac20bd Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 22 May 2023 13:52:20 +0800 Subject: [PATCH] optimize + --- src/rulesets/Base/indexing.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 156db42b9..7e1befd14 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -111,6 +111,17 @@ Base.size(A::OneElement) = map(length, A.axes) Base.axes(A::OneElement) = A.axes Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T)) +function ChainRulesCore.add!!(xs::AbstractArray{<:Any,N}, oe::OneElement{<:Any,N}) where {N} + if !ChainRulesCore.is_inplaceable_destination(xs) + xs = collect(xs) + end + xs[oe.ind...] += oe.val + return xs +end + +Base.:(+)(xs::AbstractArray, oe::OneElement) = add!!(copy(xs), oe) +Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe) +Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2) """ _setindex_zero(x, dy, inds...)