Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix several adjoints, copy and zero methods for VoA #336

Merged
merged 1 commit into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@
Colon, BitArray, AbstractArray{Bool}}...)
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
Δ′[i, j...] = Δ
if isempty(j)
Δ′.u[i] = Δ

Check warning on line 54 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L54

Added line #L54 was not covered by tests
else
Δ′[i, j...] = Δ
end
(Δ′, nothing, map(_ -> nothing, j)...)
end
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
Expand Down Expand Up @@ -104,13 +108,25 @@
end

@adjoint function Base.Array(VA::AbstractVectorOfArray)
Array(VA),
y -> (Array(y),)
adj = let VA=VA
function Array_adjoint(y)
VA = copy(VA)
copyto!(VA, y)
return (VA,)
end
end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
view(A, I...),
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
adj = let A = A, I = I
function view_adjoint(y)
A = zero(A)
view(A, I...) .= y
return (A, map(_ -> nothing, I)...)
end
end
view(A, I...), adj
end

ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
Expand Down
61 changes: 48 additions & 13 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,24 @@
p,
sys)
end

# ambiguity resolution
function DiffEqArray(vec::AbstractVector{VT},

Check warning on line 165 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L165

Added line #L165 was not covered by tests
ts::AbstractVector,
::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec,

Check warning on line 168 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L168

Added line #L168 was not covered by tests
ts,
nothing,
nothing)
end
function DiffEqArray(vec::AbstractVector{VT},

Check warning on line 173 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L173

Added line #L173 was not covered by tests
ts::AbstractVector,
::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec,

Check warning on line 176 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L176

Added line #L176 was not covered by tests
ts,
p,
nothing)
end
# Assume that the first element is representative of all other elements

function DiffEqArray(vec::AbstractVector,
Expand All @@ -174,9 +192,10 @@
something(parameters, []),
something(independent_variables, [])))
_size = size(vec[1])
T = eltype(vec[1])
return DiffEqArray{
eltype(eltype(vec)),
length(_size),
T,
length(_size) + 1,
typeof(vec),
typeof(ts),
typeof(p),
Expand Down Expand Up @@ -466,19 +485,25 @@
tuples(VA::DiffEqArray) = tuple.(VA.t, VA.u)

# Growing the array simply adds to the container vector
function Base.copy(VA::AbstractDiffEqArray)
typeof(VA)(copy(VA.u),
copy(VA.t),
(VA.p === nothing) ? nothing : copy(VA.p),
(VA.sys === nothing) ? nothing : copy(VA.sys))
function _copyfield(VA, fname)
if fname == :u
copy(VA.u)
elseif fname == :t
copy(VA.t)
else
getfield(VA, fname)
end
end
function Base.copy(VA::AbstractVectorOfArray)
typeof(VA)((_copyfield(VA, fname) for fname in fieldnames(typeof(VA)))...)
end
Base.copy(VA::AbstractVectorOfArray) = typeof(VA)(copy(VA.u))

Base.zero(VA::AbstractVectorOfArray) = VectorOfArray(Base.zero.(VA.u))

function Base.zero(VA::AbstractDiffEqArray)
u = Base.zero.(VA.u)
DiffEqArray(u, VA.t, parameter_values(VA), symbolic_container(VA))
function Base.zero(VA::AbstractVectorOfArray)
val = copy(VA)
for i in eachindex(VA.u)
val.u[i] = zero(VA.u[i])
end
return val
end

Base.sizehint!(VA::AbstractVectorOfArray{T, N}, i) where {T, N} = sizehint!(VA.u, i)
Expand Down Expand Up @@ -563,6 +588,16 @@
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N}
copyto!.(dest.u, src.u)
end
function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T, N}) where {T, N}
for (i, slice) in enumerate(eachslice(src, dims = ndims(src)))
copyto!(dest.u[i], slice)
end
dest
end
function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T}) where {T, N}
copyto!(dest.u, src)
dest
end
# Required for broadcasted setindex! when slicing across subarrays
# E.g. if `va = VectorOfArray([rand(3, 3) for i in 1:5])`
# Need this method for `va[2, :, :] .= 3.0`
Expand Down
37 changes: 37 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using RecursiveArrayTools, StaticArrays, Test
using FastBroadcast
using SymbolicIndexingInterface: SymbolCache

t = 1:3
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
Expand Down Expand Up @@ -149,6 +150,42 @@ testda = DiffEqArray(recursivecopy(testva.u), testts)
fill!(testda, testval)
@test all(x -> (x == testval), testda)

# copyto!
testva = VectorOfArray(collect(0.1:0.1:1.0))
arr = 0.2:0.2:2.0
copyto!(testva, arr)
@test Array(testva) == arr
testva = VectorOfArray([i * ones(3, 2) for i in 1:4])
arr = rand(3, 2, 4)
copyto!(testva, arr)
@test Array(testva) == arr
testva = VectorOfArray([
ones(3, 2, 2),
VectorOfArray([
2ones(3, 2),
VectorOfArray([3ones(3), 4ones(3)])
]),
DiffEqArray([
5ones(3, 2),
VectorOfArray([
6ones(3),
7ones(3),
]),
], [0.1, 0.2], [100.0, 200.0], SymbolCache([:x, :y], [:a, :b], :t))
])
arr = rand(3, 2, 2, 3)
copyto!(testva, arr)
@test Array(testva) == arr
# ensure structure and fields are maintained
@test testva.u[1] isa Array
@test testva.u[2] isa VectorOfArray
@test testva.u[2].u[2] isa VectorOfArray
@test testva.u[3] isa DiffEqArray
@test testva.u[3].u[2] isa VectorOfArray
@test testva.u[3].t == [0.1, 0.2]
@test testva.u[3].p == [100.0, 200.0]
@test testva.u[3].sys isa SymbolCache

# check any
recs = [collect(1:5), collect(6:10), collect(11:15)]
testts = rand(5)
Expand Down
Loading