diff --git a/Project.toml b/Project.toml index 5403c13..4a221ec 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Derive" uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" authors = ["ITensor developers and contributors"] -version = "0.3.3" +version = "0.3.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractarrayinterface.jl b/src/abstractarrayinterface.jl index c10b649..c2a853a 100644 --- a/src/abstractarrayinterface.jl +++ b/src/abstractarrayinterface.jl @@ -67,9 +67,20 @@ using BroadcastMapConversion: map_function, map_args # TODO: Look into `SparseArrays.capturescalars`: # https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102 @interface interface::AbstractArrayInterface function Base.copyto!( - dest::AbstractArray, bc::Broadcast.Broadcasted + a_dest::AbstractArray, bc::Broadcast.Broadcasted ) - return @interface interface map!(map_function(bc), dest, map_args(bc)...) + return @interface interface map!(map_function(bc), a_dest, map_args(bc)...) +end + +# This captures broadcast expressions such as `a .= 2`. +# Ideally this would be handled by `map!(f, a_dest)` but that isn't defined yet: +# https://github.com/JuliaLang/julia/issues/31677 +# https://github.com/JuliaLang/julia/pull/40632 +@interface interface::AbstractArrayInterface function Base.copyto!( + a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}} +) + isempty(map_args(bc)) || error("Bad broadcast expression.") + return @interface interface map!(map_function(bc), a_dest, a_dest) end # This is defined in this way so we can rely on the Broadcast logic @@ -86,11 +97,41 @@ end # `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`. # TODO: Use `MethodError`? @interface ::AbstractArrayInterface function Base.map!( - f, dest::AbstractArray, as::AbstractArray... + f, a_dest::AbstractArray, a_srcs::AbstractArray... ) return error("Not implemented.") end +@interface interface::AbstractArrayInterface function Base.fill!(a::AbstractArray, value) + @interface interface map!(Returns(value), a, a) +end + +using ArrayLayouts: zero! + +# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` +# and is useful for sparse array logic, since it can be used to empty +# the sparse array storage. +# We use a single function definition to minimize method ambiguities. +@interface interface::AbstractArrayInterface function ArrayLayouts.zero!(a::AbstractArray) + # More generally, the first codepath could be taking if `zero(eltype(a))` + # is defined and the elements are immutable. + f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero! + return @interface interface map!(f, a, a) +end + +# Specialized version of `Base.zero` written in terms of `ArrayLayouts.zero!`. +# This is friendlier for sparse arrays since `ArrayLayouts.zero!` makes it easier +# to handle the logic of dropping all elements of the sparse array when possible. +# We use a single function definition to minimize method ambiguities. +@interface interface::AbstractArrayInterface function Base.zero(a::AbstractArray) + # More generally, the first codepath could be taking if `zero(eltype(a))` + # is defined and the elements are immutable. + if eltype(a) <: Number + return @interface interface zero!(similar(a)) + end + return @interface interface map(interface(zero), a) +end + @interface ::AbstractArrayInterface function Base.mapreduce( f, op, as::AbstractArray...; kwargs... ) diff --git a/src/traits.jl b/src/traits.jl index d74d7a0..f833ee7 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -27,8 +27,12 @@ function derive(::Val{:AbstractArrayOps}, type) Base.all(::$type) Base.iszero(::$type) Base.real(::$type) + Base.fill!(::$type, ::Any) + ArrayLayouts.zero!(::$type) + Base.zero(::$type) Base.permutedims!(::Any, ::$type, ::Any) Broadcast.BroadcastStyle(::Type{<:$type}) + Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}) ArrayLayouts.MemoryLayout(::Type{<:$type}) LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number) end diff --git a/test/Project.toml b/test/Project.toml index be2e203..2affdfa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2" Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" diff --git a/test/basics/SparseArrayDOKs.jl b/test/basics/SparseArrayDOKs.jl index 44a3e32..224f470 100644 --- a/test/basics/SparseArrayDOKs.jl +++ b/test/basics/SparseArrayDOKs.jl @@ -80,6 +80,24 @@ end @interface ::SparseArrayInterface function Base.map!( f, a_dest::AbstractArray, as::AbstractArray... ) + # TODO: Define a function `preserves_unstored(a_dest, f, as...)` + # to determine if a function preserves the stored values + # of the destination sparse array. + # The current code may be inefficient since it actually + # accesses an unstored element, which in the case of a + # sparse array of arrays can allocate an array. + # Sparse arrays could be expected to define a cheap + # unstored element allocator, for example + # `get_prototypical_unstored(a::AbstractArray)`. + I = first(eachindex(as...)) + preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...)) + if !preserves_unstored + # Doesn't preserve unstored values, loop over all elements. + for I in eachindex(as...) + a_dest[I] = map(f, map(a -> a[I], as)...) + end + end + # TODO: Define `eachstoredindex(as...)`. for I in union(eachstoredindex.(as)...) a_dest[I] = map(f, map(a -> a[I], as)...) end @@ -230,6 +248,11 @@ end eachstoredindex(a::SparseArrayDOK) = keys(storage(a)) storedlength(a::SparseArrayDOK) = length(eachstoredindex(a)) +function ArrayLayouts.zero!(a::SparseArrayDOK) + empty!(storage(a)) + return a +end + # Specify the interface the type adheres to. Derive.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface() diff --git a/test/basics/test_basics.jl b/test/basics/test_basics.jl index f408320..5ce444a 100644 --- a/test/basics/test_basics.jl +++ b/test/basics/test_basics.jl @@ -1,6 +1,7 @@ -using Test: @test, @testset +using ArrayLayouts: zero! include("SparseArrayDOKs.jl") using .SparseArrayDOKs: SparseArrayDOK, storedlength +using Test: @test, @testset elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "Derive" for elt in elts @@ -89,4 +90,30 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test b == a @test b[1, 2] == 12 @test storedlength(b) == 1 + + a = SparseArrayDOK{elt}(2, 2) + a .= 2 + @test storedlength(a) == length(a) + for I in eachindex(a) + @test a[I] == 2 + end + + a = SparseArrayDOK{elt}(2, 2) + fill!(a, 2) + @test storedlength(a) == length(a) + for I in eachindex(a) + @test a[I] == 2 + end + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + zero!(a) + @test iszero(a) + @test iszero(storedlength(a)) + + a = SparseArrayDOK{elt}(2, 2) + a[1, 2] = 12 + b = zero(a) + @test iszero(b) + @test iszero(storedlength(b)) end