From f1a40fcd32786b29a20e0d73e077020ce01c3ea8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 17 Jan 2024 19:32:00 +0530 Subject: [PATCH] fix: fix several adjoints, copy, copyto! and zero methods for VoA --- ext/RecursiveArrayToolsZygoteExt.jl | 26 +++++++++--- src/vector_of_array.jl | 61 +++++++++++++++++++++++------ test/interface_tests.jl | 37 +++++++++++++++++ 3 files changed, 106 insertions(+), 18 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index c4ddc1a7..c4611137 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -50,7 +50,11 @@ end Colon, BitArray, AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))]) - Δ′[i, j...] = Δ + if isempty(j) + Δ′.u[i] = Δ + else + Δ′[i, j...] = Δ + end (Δ′, nothing, map(_ -> nothing, j)...) end VA[i, j...], AbstractVectorOfArray_getindex_adjoint @@ -104,13 +108,25 @@ end end @adjoint function Base.Array(VA::AbstractVectorOfArray) - Array(VA), - y -> (Array(y),) + adj = let VA=VA + function Array_adjoint(y) + VA = copy(VA) + copyto!(VA, y) + return (VA,) + end + end + Array(VA), adj end @adjoint function Base.view(A::AbstractVectorOfArray, I...) - view(A, I...), - y -> (view(y, I...), ntuple(_ -> nothing, length(I))...) + adj = let A = A, I = I + function view_adjoint(y) + A = zero(A) + view(A, I...) .= y + return (A, map(_ -> nothing, I)...) + end + end + view(A, I...), adj end ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a))) diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 16b58279..d348cac4 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -160,6 +160,24 @@ function DiffEqArray(vec::AbstractVector{T}, p, sys) end + +# ambiguity resolution +function DiffEqArray(vec::AbstractVector{VT}, + ts::AbstractVector, + ::NTuple{N, Int}) where {T, N, VT <: AbstractArray{T, N}} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), Nothing, Nothing}(vec, + ts, + nothing, + nothing) +end +function DiffEqArray(vec::AbstractVector{VT}, + ts::AbstractVector, + ::NTuple{N, Int}, p) where {T, N, VT <: AbstractArray{T, N}} + DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing}(vec, + ts, + p, + nothing) +end # Assume that the first element is representative of all other elements function DiffEqArray(vec::AbstractVector, @@ -174,9 +192,10 @@ function DiffEqArray(vec::AbstractVector, something(parameters, []), something(independent_variables, []))) _size = size(vec[1]) + T = eltype(vec[1]) return DiffEqArray{ - eltype(eltype(vec)), - length(_size), + T, + length(_size) + 1, typeof(vec), typeof(ts), typeof(p), @@ -466,19 +485,25 @@ end tuples(VA::DiffEqArray) = tuple.(VA.t, VA.u) # Growing the array simply adds to the container vector -function Base.copy(VA::AbstractDiffEqArray) - typeof(VA)(copy(VA.u), - copy(VA.t), - (VA.p === nothing) ? nothing : copy(VA.p), - (VA.sys === nothing) ? nothing : copy(VA.sys)) +function _copyfield(VA, fname) + if fname == :u + copy(VA.u) + elseif fname == :t + copy(VA.t) + else + getfield(VA, fname) + end +end +function Base.copy(VA::AbstractVectorOfArray) + typeof(VA)((_copyfield(VA, fname) for fname in fieldnames(typeof(VA)))...) end -Base.copy(VA::AbstractVectorOfArray) = typeof(VA)(copy(VA.u)) - -Base.zero(VA::AbstractVectorOfArray) = VectorOfArray(Base.zero.(VA.u)) -function Base.zero(VA::AbstractDiffEqArray) - u = Base.zero.(VA.u) - DiffEqArray(u, VA.t, parameter_values(VA), symbolic_container(VA)) +function Base.zero(VA::AbstractVectorOfArray) + val = copy(VA) + for i in eachindex(VA.u) + val.u[i] = zero(VA.u[i]) + end + return val end Base.sizehint!(VA::AbstractVectorOfArray{T, N}, i) where {T, N} = sizehint!(VA.u, i) @@ -563,6 +588,16 @@ end function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N} copyto!.(dest.u, src.u) end +function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T, N}) where {T, N} + for (i, slice) in enumerate(eachslice(src, dims = ndims(src))) + copyto!(dest.u[i], slice) + end + dest +end +function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T}) where {T, N} + copyto!(dest.u, src) + dest +end # Required for broadcasted setindex! when slicing across subarrays # E.g. if `va = VectorOfArray([rand(3, 3) for i in 1:5])` # Need this method for `va[2, :, :] .= 3.0` diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 6dd7e04e..61abf227 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -1,5 +1,6 @@ using RecursiveArrayTools, StaticArrays, Test using FastBroadcast +using SymbolicIndexingInterface: SymbolCache t = 1:3 testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) @@ -149,6 +150,42 @@ testda = DiffEqArray(recursivecopy(testva.u), testts) fill!(testda, testval) @test all(x -> (x == testval), testda) +# copyto! +testva = VectorOfArray(collect(0.1:0.1:1.0)) +arr = 0.2:0.2:2.0 +copyto!(testva, arr) +@test Array(testva) == arr +testva = VectorOfArray([i * ones(3, 2) for i in 1:4]) +arr = rand(3, 2, 4) +copyto!(testva, arr) +@test Array(testva) == arr +testva = VectorOfArray([ + ones(3, 2, 2), + VectorOfArray([ + 2ones(3, 2), + VectorOfArray([3ones(3), 4ones(3)]) + ]), + DiffEqArray([ + 5ones(3, 2), + VectorOfArray([ + 6ones(3), + 7ones(3), + ]), + ], [0.1, 0.2], [100.0, 200.0], SymbolCache([:x, :y], [:a, :b], :t)) +]) +arr = rand(3, 2, 2, 3) +copyto!(testva, arr) +@test Array(testva) == arr +# ensure structure and fields are maintained +@test testva.u[1] isa Array +@test testva.u[2] isa VectorOfArray +@test testva.u[2].u[2] isa VectorOfArray +@test testva.u[3] isa DiffEqArray +@test testva.u[3].u[2] isa VectorOfArray +@test testva.u[3].t == [0.1, 0.2] +@test testva.u[3].p == [100.0, 200.0] +@test testva.u[3].sys isa SymbolCache + # check any recs = [collect(1:5), collect(6:10), collect(11:15)] testts = rand(5)