Skip to content

Commit

Permalink
Fill, zero, etc. (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Dec 12, 2024
1 parent 61786d0 commit f30e57d
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Derive"
uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.3"
version = "0.3.4"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
47 changes: 44 additions & 3 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,20 @@ using BroadcastMapConversion: map_function, map_args
# TODO: Look into `SparseArrays.capturescalars`:
# https://github.com/JuliaSparse/SparseArrays.jl/blob/1beb0e4a4618b0399907b0000c43d9f66d34accc/src/higherorderfns.jl#L1092-L1102
@interface interface::AbstractArrayInterface function Base.copyto!(
dest::AbstractArray, bc::Broadcast.Broadcasted
a_dest::AbstractArray, bc::Broadcast.Broadcasted
)
return @interface interface map!(map_function(bc), dest, map_args(bc)...)
return @interface interface map!(map_function(bc), a_dest, map_args(bc)...)
end

# This captures broadcast expressions such as `a .= 2`.
# Ideally this would be handled by `map!(f, a_dest)` but that isn't defined yet:
# https://github.com/JuliaLang/julia/issues/31677
# https://github.com/JuliaLang/julia/pull/40632
@interface interface::AbstractArrayInterface function Base.copyto!(
a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}
)
isempty(map_args(bc)) || error("Bad broadcast expression.")
return @interface interface map!(map_function(bc), a_dest, a_dest)
end

# This is defined in this way so we can rely on the Broadcast logic
Expand All @@ -86,11 +97,41 @@ end
# `invoke(Base.map!, Tuple{Any,AbstractArray,Vararg{AbstractArray}}, f, dest, as...)`.
# TODO: Use `MethodError`?
@interface ::AbstractArrayInterface function Base.map!(
f, dest::AbstractArray, as::AbstractArray...
f, a_dest::AbstractArray, a_srcs::AbstractArray...
)
return error("Not implemented.")
end

@interface interface::AbstractArrayInterface function Base.fill!(a::AbstractArray, value)
@interface interface map!(Returns(value), a, a)
end

using ArrayLayouts: zero!

# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
# and is useful for sparse array logic, since it can be used to empty
# the sparse array storage.
# We use a single function definition to minimize method ambiguities.
@interface interface::AbstractArrayInterface function ArrayLayouts.zero!(a::AbstractArray)
# More generally, the first codepath could be taking if `zero(eltype(a))`
# is defined and the elements are immutable.
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
return @interface interface map!(f, a, a)
end

# Specialized version of `Base.zero` written in terms of `ArrayLayouts.zero!`.
# This is friendlier for sparse arrays since `ArrayLayouts.zero!` makes it easier
# to handle the logic of dropping all elements of the sparse array when possible.
# We use a single function definition to minimize method ambiguities.
@interface interface::AbstractArrayInterface function Base.zero(a::AbstractArray)
# More generally, the first codepath could be taking if `zero(eltype(a))`
# is defined and the elements are immutable.
if eltype(a) <: Number
return @interface interface zero!(similar(a))
end
return @interface interface map(interface(zero), a)
end

@interface ::AbstractArrayInterface function Base.mapreduce(
f, op, as::AbstractArray...; kwargs...
)
Expand Down
4 changes: 4 additions & 0 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ function derive(::Val{:AbstractArrayOps}, type)
Base.all(::$type)
Base.iszero(::$type)
Base.real(::$type)
Base.fill!(::$type, ::Any)
ArrayLayouts.zero!(::$type)
Base.zero(::$type)
Base.permutedims!(::Any, ::$type, ::Any)
Broadcast.BroadcastStyle(::Type{<:$type})
Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}})
ArrayLayouts.MemoryLayout(::Type{<:$type})
LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number)
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand Down
23 changes: 23 additions & 0 deletions test/basics/SparseArrayDOKs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,24 @@ end
@interface ::SparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, as::AbstractArray...
)
# TODO: Define a function `preserves_unstored(a_dest, f, as...)`
# to determine if a function preserves the stored values
# of the destination sparse array.
# The current code may be inefficient since it actually
# accesses an unstored element, which in the case of a
# sparse array of arrays can allocate an array.
# Sparse arrays could be expected to define a cheap
# unstored element allocator, for example
# `get_prototypical_unstored(a::AbstractArray)`.
I = first(eachindex(as...))
preserves_unstored = iszero(f(map(a -> getunstoredindex(a, I), as)...))
if !preserves_unstored
# Doesn't preserve unstored values, loop over all elements.
for I in eachindex(as...)
a_dest[I] = map(f, map(a -> a[I], as)...)
end
end
# TODO: Define `eachstoredindex(as...)`.
for I in union(eachstoredindex.(as)...)
a_dest[I] = map(f, map(a -> a[I], as)...)
end
Expand Down Expand Up @@ -230,6 +248,11 @@ end
eachstoredindex(a::SparseArrayDOK) = keys(storage(a))
storedlength(a::SparseArrayDOK) = length(eachstoredindex(a))

function ArrayLayouts.zero!(a::SparseArrayDOK)
empty!(storage(a))
return a
end

# Specify the interface the type adheres to.
Derive.interface(::Type{<:SparseArrayDOK}) = SparseArrayInterface()

Expand Down
29 changes: 28 additions & 1 deletion test/basics/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test: @test, @testset
using ArrayLayouts: zero!
include("SparseArrayDOKs.jl")
using .SparseArrayDOKs: SparseArrayDOK, storedlength
using Test: @test, @testset

elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "Derive" for elt in elts
Expand Down Expand Up @@ -89,4 +90,30 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test b == a
@test b[1, 2] == 12
@test storedlength(b) == 1

a = SparseArrayDOK{elt}(2, 2)
a .= 2
@test storedlength(a) == length(a)
for I in eachindex(a)
@test a[I] == 2
end

a = SparseArrayDOK{elt}(2, 2)
fill!(a, 2)
@test storedlength(a) == length(a)
for I in eachindex(a)
@test a[I] == 2
end

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
zero!(a)
@test iszero(a)
@test iszero(storedlength(a))

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = zero(a)
@test iszero(b)
@test iszero(storedlength(b))
end

0 comments on commit f30e57d

Please sign in to comment.