Skip to content

Commit

Permalink
fix: fix several adjoints, copy, copyto! and zero methods for VoA
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 18, 2024
1 parent 030923c commit f1a40fc
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 18 deletions.
26 changes: 21 additions & 5 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ end
Colon, BitArray, AbstractArray{Bool}}...)
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
Δ′[i, j...] = Δ
if isempty(j)
Δ′.u[i] = Δ

Check warning on line 54 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L54

Added line #L54 was not covered by tests
else
Δ′[i, j...] = Δ
end
(Δ′, nothing, map(_ -> nothing, j)...)
end
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
Expand Down Expand Up @@ -104,13 +108,25 @@ end
end

@adjoint function Base.Array(VA::AbstractVectorOfArray)
Array(VA),
y -> (Array(y),)
adj = let VA=VA
function Array_adjoint(y)
VA = copy(VA)
copyto!(VA, y)
return (VA,)
end
end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
view(A, I...),
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
adj = let A = A, I = I
function view_adjoint(y)
A = zero(A)
view(A, I...) .= y
return (A, map(_ -> nothing, I)...)
end
end
view(A, I...), adj
end

ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
Expand Down
61 changes: 48 additions & 13 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,24 @@ function DiffEqArray(vec::AbstractVector{T},
p,
sys)
end

# ambiguity resolution
function DiffEqArray(vec::AbstractVector{VT},

Check warning on line 165 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L165

Added line #L165 was not covered by tests
ts::AbstractVector,
::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec,

Check warning on line 168 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L168

Added line #L168 was not covered by tests
ts,
nothing,
nothing)
end
function DiffEqArray(vec::AbstractVector{VT},

Check warning on line 173 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L173

Added line #L173 was not covered by tests
ts::AbstractVector,
::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec,

Check warning on line 176 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L176

Added line #L176 was not covered by tests
ts,
p,
nothing)
end
# Assume that the first element is representative of all other elements

function DiffEqArray(vec::AbstractVector,
Expand All @@ -174,9 +192,10 @@ function DiffEqArray(vec::AbstractVector,
something(parameters, []),
something(independent_variables, [])))
_size = size(vec[1])
T = eltype(vec[1])
return DiffEqArray{
eltype(eltype(vec)),
length(_size),
T,
length(_size) + 1,
typeof(vec),
typeof(ts),
typeof(p),
Expand Down Expand Up @@ -466,19 +485,25 @@ end
tuples(VA::DiffEqArray) = tuple.(VA.t, VA.u)

# Growing the array simply adds to the container vector
function Base.copy(VA::AbstractDiffEqArray)
typeof(VA)(copy(VA.u),
copy(VA.t),
(VA.p === nothing) ? nothing : copy(VA.p),
(VA.sys === nothing) ? nothing : copy(VA.sys))
function _copyfield(VA, fname)
if fname == :u
copy(VA.u)
elseif fname == :t
copy(VA.t)
else
getfield(VA, fname)
end
end
function Base.copy(VA::AbstractVectorOfArray)
typeof(VA)((_copyfield(VA, fname) for fname in fieldnames(typeof(VA)))...)
end
Base.copy(VA::AbstractVectorOfArray) = typeof(VA)(copy(VA.u))

Base.zero(VA::AbstractVectorOfArray) = VectorOfArray(Base.zero.(VA.u))

function Base.zero(VA::AbstractDiffEqArray)
u = Base.zero.(VA.u)
DiffEqArray(u, VA.t, parameter_values(VA), symbolic_container(VA))
function Base.zero(VA::AbstractVectorOfArray)
val = copy(VA)
for i in eachindex(VA.u)
val.u[i] = zero(VA.u[i])
end
return val
end

Base.sizehint!(VA::AbstractVectorOfArray{T, N}, i) where {T, N} = sizehint!(VA.u, i)
Expand Down Expand Up @@ -563,6 +588,16 @@ end
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N}
copyto!.(dest.u, src.u)
end
function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T, N}) where {T, N}
for (i, slice) in enumerate(eachslice(src, dims = ndims(src)))
copyto!(dest.u[i], slice)
end
dest
end
function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T}) where {T, N}
copyto!(dest.u, src)
dest
end
# Required for broadcasted setindex! when slicing across subarrays
# E.g. if `va = VectorOfArray([rand(3, 3) for i in 1:5])`
# Need this method for `va[2, :, :] .= 3.0`
Expand Down
37 changes: 37 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using RecursiveArrayTools, StaticArrays, Test
using FastBroadcast
using SymbolicIndexingInterface: SymbolCache

t = 1:3
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
Expand Down Expand Up @@ -149,6 +150,42 @@ testda = DiffEqArray(recursivecopy(testva.u), testts)
fill!(testda, testval)
@test all(x -> (x == testval), testda)

# copyto!
testva = VectorOfArray(collect(0.1:0.1:1.0))
arr = 0.2:0.2:2.0
copyto!(testva, arr)
@test Array(testva) == arr
testva = VectorOfArray([i * ones(3, 2) for i in 1:4])
arr = rand(3, 2, 4)
copyto!(testva, arr)
@test Array(testva) == arr
testva = VectorOfArray([
ones(3, 2, 2),
VectorOfArray([
2ones(3, 2),
VectorOfArray([3ones(3), 4ones(3)])
]),
DiffEqArray([
5ones(3, 2),
VectorOfArray([
6ones(3),
7ones(3),
]),
], [0.1, 0.2], [100.0, 200.0], SymbolCache([:x, :y], [:a, :b], :t))
])
arr = rand(3, 2, 2, 3)
copyto!(testva, arr)
@test Array(testva) == arr
# ensure structure and fields are maintained
@test testva.u[1] isa Array
@test testva.u[2] isa VectorOfArray
@test testva.u[2].u[2] isa VectorOfArray
@test testva.u[3] isa DiffEqArray
@test testva.u[3].u[2] isa VectorOfArray
@test testva.u[3].t == [0.1, 0.2]
@test testva.u[3].p == [100.0, 200.0]
@test testva.u[3].sys isa SymbolCache

# check any
recs = [collect(1:5), collect(6:10), collect(11:15)]
testts = rand(5)
Expand Down

0 comments on commit f1a40fc

Please sign in to comment.