Skip to content

Commit

Permalink
Fix defining single line interface functions, add support for cat a…
Browse files Browse the repository at this point in the history
…nd slicing (#18)

* Fix defining single line interface functions, add support for `cat` and slicing

* Bump to v0.3.5
  • Loading branch information
mtfishman authored Dec 16, 2024
1 parent f30e57d commit 5c52fbf
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Derive"
uuid = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.4"
version = "0.3.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
121 changes: 121 additions & 0 deletions src/abstractarrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ using ArrayLayouts: ArrayLayouts
return ArrayLayouts.layout_getindex(a, I...)
end

@interface interface::AbstractArrayInterface function Base.setindex!(
a::AbstractArray, value, I...
)
# TODO: Change to this once broadcasting in `@interface` is supported:
# @interface interface a[I...] .= value
@interface interface map!(identity, @view(a[I...]), value)
return a
end

# TODO: Maybe define as `ArrayLayouts.layout_getindex(a, I...)` or
# `invoke(getindex, Tuple{AbstractArray,Vararg{Any}}, a, I...)`.
# TODO: Use `MethodError`?
Expand All @@ -28,6 +37,27 @@ end
return error("Not implemented.")
end

# TODO: Make this more general, use `Base.to_index`.
@interface interface::AbstractArrayInterface function Base.getindex(
a::AbstractArray{<:Any,N}, I::CartesianIndex{N}
) where {N}
return @interface interface getindex(a, Tuple(I)...)
end

# TODO: Use `MethodError`?
@interface ::AbstractArrayInterface function Base.setindex!(
a::AbstractArray{<:Any,N}, value, I::Vararg{Int,N}
) where {N}
return error("Not implemented.")
end

# TODO: Make this more general, use `Base.to_index`.
@interface interface::AbstractArrayInterface function Base.setindex!(
a::AbstractArray{<:Any,N}, value, I::CartesianIndex{N}
) where {N}
return @interface interface setindex!(a, value, Tuple(I)...)
end

@interface ::AbstractArrayInterface function Broadcast.BroadcastStyle(type::Type)
return Broadcast.DefaultArrayStyle{ndims(type)}()
end
Expand Down Expand Up @@ -203,3 +233,94 @@ end
## @interface ::AbstractMatrixInterface function Base.*(a1, a2)
## return ArrayLayouts.mul(a1, a2)
## end

# Concatenation

axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
function axis_cat(
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
)
return axis_cat(axis_cat(a1, a2), a_rest...)
end

unval(x) = x
unval(::Val{x}) where {x} = x

function cat_axes(as::AbstractArray...; dims)
return ntuple(length(first(axes.(as)))) do dim
return if dim in unval(dims)
axis_cat(map(axes -> axes[dim], axes.(as))...)
else
axes(first(as))[dim]
end
end
end

function cat! end

# Represents concatenating `args` over `dims`.
struct Cat{Args<:Tuple{Vararg{AbstractArray}},dims}
args::Args
end
to_cat_dims(dim::Integer) = Int(dim)
to_cat_dims(dim::Int) = (dim,)
to_cat_dims(dims::Val) = to_cat_dims(unval(dims))
to_cat_dims(dims::Tuple) = dims
Cat(args::AbstractArray...; dims) = Cat{typeof(args),to_cat_dims(dims)}(args)
cat_dims(::Cat{<:Any,dims}) where {dims} = dims

function Base.axes(a::Cat)
return cat_axes(a.args...; dims=cat_dims(a))
end
Base.eltype(a::Cat) = promote_type(eltype.(a.args)...)
function Base.similar(a::Cat)
ax = axes(a)
elt = eltype(a)
# TODO: This drops GPU information, maybe use MemoryLayout?
return similar(arraytype(interface(a.args...), elt), ax)
end

# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857
# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation
# This is very similar to the `Base.cat` implementation but handles zero values better.
function cat_offset!(
a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims
)
inds = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim)
end
a_dest[inds...] = a1
new_offsets = ntuple(ndims(a_dest)) do dim
dim in unval(dims) ? offsets[dim] + size(a1, dim) : offsets[dim]
end
cat_offset!(a_dest, new_offsets, a_rest...; dims)
return a_dest
end
function cat_offset!(a_dest::AbstractArray, offsets; dims)
return a_dest
end

@interface ::AbstractArrayInterface function cat!(
a_dest::AbstractArray, as::AbstractArray...; dims
)
offsets = ntuple(zero, ndims(a_dest))
# TODO: Fill `a_dest` with zeros if needed using `zero!`.
cat_offset!(a_dest, offsets, as...; dims)
return a_dest
end

@interface interface::AbstractArrayInterface function Base.cat(as::AbstractArray...; dims)
a_dest = similar(Cat(as...; dims))
@interface interface cat!(a_dest, as...; dims)
return a_dest
end

# TODO: Use `@derive` instead:
# ```julia
# @derive (T=AbstractArray,) begin
# cat!(a_dest::AbstractArray, as::T...; dims)
# end
# ```
function cat!(a_dest::AbstractArray, as::AbstractArray...; dims)
return @interface interface(as...) cat!(a_dest, as...; dims)
end
1 change: 1 addition & 0 deletions src/abstractinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
interface(x) = interface(typeof(x))
# TODO: Define as `DefaultInterface()`.
interface(::Type) = error("Interface unknown.")
interface(x1, x_rest...) = combine_interfaces(x1, x_rest...)

# Adapted from `Base.Broadcast.combine_styles`.
# Get the combined interfaces of the input objects.
Expand Down
15 changes: 12 additions & 3 deletions src/interface_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@ macro interface(expr...)
return esc(interface_expr(expr...))
end

# TODO: Use `MLStyle.@match`/`Moshi.@match`.
# f(args...)
iscallexpr(expr) = Meta.isexpr(expr, :call)
# a[I...]
isrefexpr(expr) = Meta.isexpr(expr, :ref)
# a[I...] = value
issetrefexpr(expr) = Meta.isexpr(expr, :(=)) && isrefexpr(expr.args[1])

function interface_expr(interface::Union{Symbol,Expr}, func::Expr)
# TODO: Use `MLStyle.@match`/`Moshi.@match`.
# f(args...)
Meta.isexpr(func, :call) && return interface_call(interface, func)
iscallexpr(func) && return interface_call(interface, func)
# a[I...]
Meta.isexpr(func, :ref) && return interface_ref(interface, func)
isrefexpr(func) && return interface_ref(interface, func)
# a[I...] = value
Meta.isexpr(func, :(=)) && return interface_setref(interface, func)
issetrefexpr(func) && return interface_setref(interface, func)
# Assume it is a function definition.
return interface_definition(interface, func)
end
Expand Down
13 changes: 13 additions & 0 deletions src/traits.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
using ArrayLayouts: ArrayLayouts
using LinearAlgebra: LinearAlgebra

# TODO: Create a macro:
#=
```
@derive_def AbstractArrayOps T begin
Base.getindex(::T, ::Any...)
Base.getindex(::T, ::Int...)
Base.setindex!(::T, ::Any, ::Int...)
Base.similar(::T, ::Type, ::Tuple{Vararg{Int}})
end
```
=#
# TODO: Define an `AbstractMatrixOps` trait, which is where
# matrix multiplication should be defined (both `mul!` and `*`).
#=
Expand All @@ -13,6 +24,7 @@ function derive(::Val{:AbstractArrayOps}, type)
return quote
Base.getindex(::$type, ::Any...)
Base.getindex(::$type, ::Int...)
Base.setindex!(::$type, ::Any, ::Any...)
Base.setindex!(::$type, ::Any, ::Int...)
Base.similar(::$type, ::Type, ::Tuple{Vararg{Int}})
Base.similar(::$type, ::Type, ::Tuple{Base.OneTo,Vararg{Base.OneTo}})
Expand All @@ -33,6 +45,7 @@ function derive(::Val{:AbstractArrayOps}, type)
Base.permutedims!(::Any, ::$type, ::Any)
Broadcast.BroadcastStyle(::Type{<:$type})
Base.copyto!(::$type, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}})
Base.cat(::$type...; kwargs...)
ArrayLayouts.MemoryLayout(::Type{<:$type})
LinearAlgebra.mul!(::AbstractMatrix, ::$type, ::$type, ::Number, ::Number)
end
Expand Down
34 changes: 34 additions & 0 deletions test/basics/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a[1, 2] = 12
b = similar(a)
copyto!(b, a)
@test b isa SparseArrayDOK{elt,2}
@test b == a
@test b[1, 2] == 12
@test storedlength(b) == 1
Expand Down Expand Up @@ -114,6 +115,39 @@ elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = zero(a)
@test b isa SparseArrayDOK{elt,2}
@test iszero(b)
@test iszero(storedlength(b))

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = SparseArrayDOK{elt}(4, 4)
b[2:3, 2:3] .= a
@test isone(storedlength(b))
@test b[2, 3] == 12

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = SparseArrayDOK{elt}(4, 4)
b[2:3, 2:3] = a
@test isone(storedlength(b))
@test b[2, 3] == 12

a = SparseArrayDOK{elt}(2, 2)
a[1, 2] = 12
b = SparseArrayDOK{elt}(4, 4)
c = @view b[2:3, 2:3]
c .= a
@test isone(storedlength(b))
@test b[2, 3] == 12

a1 = SparseArrayDOK{elt}(2, 2)
a1[1, 2] = 12
a2 = SparseArrayDOK{elt}(2, 2)
a2[2, 1] = 21
b = cat(a1, a2; dims=(1, 2))
@test b isa SparseArrayDOK{elt,2}
@test storedlength(b) == 2
@test b[1, 2] == 12
@test b[4, 3] == 21
end

0 comments on commit 5c52fbf

Please sign in to comment.