Skip to content

Commit

Permalink
add StackView{T} constructor with better default eltype inference
Browse files Browse the repository at this point in the history
This supports heterogeneous array inputs
  • Loading branch information
johnnychen94 committed Apr 9, 2021
1 parent 157294c commit 526de2f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 12 deletions.
20 changes: 18 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
[![Build Status](https://github.com/JuliaArrays/StackViews.jl/workflows/CI/badge.svg)](https://github.com/JuliaArrays/StackViews.jl/actions)
[![Coverage](https://codecov.io/gh/JuliaArrays/StackViews.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaArrays/StackViews.jl)

There are three ways to understand `StackView`:
`StackViews` provides only one array type: `StackView`. There are multiple ways to understand `StackView`:

- inverse of `eachslice`
- `cat` variant
- view object
- lazy version of `repeat` special case

## `StackView` as the inverse of `eachslice`

Expand Down Expand Up @@ -128,11 +129,26 @@ As == Ac # true
@btime arrsum_cart($Ac); # 128.888 ms (0 allocations: 0 bytes)
```

## `StackView` as a lazy version of `repeat` special case

`StackView` allows you to stack the same array object multiple times, which makes a special
version of `repeat` when there's only one none-1 repeat count:

```julia
A = rand(1000, 1000);
n = 100;
StackView([A for _ in 1:n]) == repeat(A, ntuple(_->1, ndims(A))..., n) # true
@btime StackView([$A for _ in 1:$n]); # 403.156 ns (2 allocations: 1.75 KiB)
@btime repeat($A, ntuple(_->1, ndims($A))..., $n) # 590.043 ms (4 allocations: 762.94 MiB)
```

## More examples

When arrays are of different types and sizes, `StackViews` just kills `cat`s:
When arrays are of different types and sizes, `StackView` just kills `cat`s:

```julia
julia> using StackViews, PaddedViews

julia> A = collect(reshape(1:8, 2, 4));

julia> B = collect(reshape(9:16, 4, 2));
Expand Down
36 changes: 27 additions & 9 deletions src/StackViews.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ const SlicesType = Union{Tuple, AbstractArray{<:AbstractArray}}

"""
StackView(slices...; [dims])
StackView(slices, [dim])
StackView(slices, [dims])
StackView{T}(args...; kwargs...)
Stack/concatenate a list of arrays `slices` along dimension `dim` without copying the data.
Stack/concatenate a list of arrays `slices` along dimension `dims` without copying the data.
If not specified, `dim` is defined as `ndims(first(slices))+1`, i.e., a new dimension in the tail.
If not specified, `dims` is defined as `ndims(first(slices))+1`, i.e., a new dimension in the tail.
This works better for normal Julia arrays when their memory layout is column-major.
# Example
Expand Down Expand Up @@ -80,20 +81,32 @@ struct StackView{T, N, D, A} <: AbstractArray{T, N}
StackView{T, N, 0}(::A) where {T, N, A} = throw(ArgumentError("`dims=Val(0)` is not supported, do you mean `dims=Val(1)`?"))
end
StackView(slices::AbstractArray...; dims=Val(_default_dims(slices))) = StackView(slices, dims)
StackView{T}(slices::AbstractArray...; dims=Val(_default_dims(slices))) where T = StackView{T}(slices, dims)
StackView(slices, dims::Int) = StackView(slices, Val(dims)) # type-unstable

function StackView(slices::SlicesType, dims::Val = Val(_default_dims(slices)))
return StackView{_default_eltype(slices)}(slices, dims)
end
function StackView{T}(slices::SlicesType, dims::Val = Val(_default_dims(slices))) where T
N = _max(dims, Val(_default_dims(slices)))
slices = map(OffsetArrays.no_offset_view, slices) # unify all the axes
StackView{_default_eltype(slices), _value(N), _value(dims)}(slices)
# unify all the axes to 1-based ranges
slices = map(OffsetArrays.no_offset_view, slices)
return StackView{T, _value(N), _value(dims)}(slices)
end

@inline _default_dims(slices) = ndims(first(slices)) + 1
@inline function _default_eltype(slices)
T = mapreduce(eltype, promote_type, slices)
isconcretetype(T) || throw(ArgumentError("Input arrays should be homogenous."))
_isconcretetype(T) || throw(ArgumentError("Input arrays should be homogenous."))
return T
end
function _isconcretetype(T)
# We relax the restriction and allow `Union`
# This is particularily useful for arrays with `missing` and `nothing`
isconcretetype(T) && return true
isa(T, Union) && return isconcretetype(T.a) && _isconcretetype(T.b)
return false
end

function Base.size(A::StackView{T,N,D}) where {T,N,D}
frame_size = size(first(A.slices))
Expand All @@ -102,9 +115,14 @@ function Base.size(A::StackView{T,N,D}) where {T,N,D}
end

function Base.axes(A::StackView{T,N,D}) where {T,N,D}
frame_size = axes(first(A.slices))
prev, post = Base.IteratorsMD.split(frame_size, Val(D-1))
return (_append_tuple(prev, Val(D-1), Base.OneTo(1))..., Base.OneTo(length(A.slices)), post...)
frame_axes = axes(first(A.slices))
prev, post = Base.IteratorsMD.split(frame_axes, Val(D-1))

# use homogenous range to make _append_tuple happy
fill_range = convert(typeof(first(frame_axes)), Base.OneTo(1))
return (_append_tuple(prev, Val(D-1), fill_range)...,
Base.OneTo(length(A.slices)),
post...)
end

@inline function Base.getindex(A::StackView{T,N,D}, i::Vararg{Int,N}) where {T,N,D}
Expand Down
32 changes: 31 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,41 @@ end
@test_throws ArgumentError StackView([A, B], 0)
@test_throws ArgumentError StackView([A, B], Val(0))

# repeat the array multiple times is allowed
A = [1, 2, 3, 4]
@test StackView([A for i in 1:3]) == repeat(A, 1, 3)

@testset "heterogeneous arrays" begin
A = Int[1, 2, 3, 4]
B = Union{Missing, Float64}[missing, 6, 7, 8]
sv = @inferred StackView([A, B])
@test eltype(sv) == Union{Missing, Float64}
T = Union{Missing, Float32}
sv = @inferred StackView{T}([A, B])
@test eltype(sv) == T

A = Int[1, 2, 3, 4]
B = [[1, ], [2, ], [3, ], [4, ]]
@test_throws ArgumentError StackView(A, B)
# Perhaps suprisingly, with explicit type T it bypasses the type check
# It's an undefined behavior, just put it here for the record
sv = @test_nowarn StackView{Int}(A, B)
@test_throws MethodError collect(sv)
end

@testset "axes" begin
# axes are unified to 1-based
A = OffsetArray(collect(reshape(1:8, 2, 4)), -1, 1)
B = OffsetArray(collect(reshape(9:16, 2, 4)), 1, -1)
@test axes(StackView([A, B], 1)) == (Base.OneTo(2), Base.OneTo(2), Base.OneTo(4))
sv = StackView([A, B], 1)
@inferred axes(sv)
@test axes(sv) === (Base.OneTo(2), Base.OneTo(2), Base.OneTo(4))

A = collect(reshape(1:8, 2, 4))
B = collect(reshape(9:16, 2, 4))
sv = StackView([A, B], 1)
@inferred axes(sv)
@test axes(sv) === (Base.OneTo(2), Base.OneTo(2), Base.OneTo(4))
end

@testset "setindex!" begin
Expand Down

0 comments on commit 526de2f

Please sign in to comment.