Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: rework AbstractVectorOfArray, use new SymbolicIndexingInterface #290

Merged
merged 45 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
11dd2b1
feat!: remove `AbstractArray` subtype
AayushSabharwal Oct 20, 2023
c6eafa1
feat: remove issymbollike
AayushSabharwal Oct 23, 2023
4e3f80d
feat: rework DiffEqArray, use new SymbolicIndexingInterface
AayushSabharwal Oct 23, 2023
0966545
fix: add Base.convert method, fix Base.summary
AayushSabharwal Oct 23, 2023
8ea2d63
feat: make VectorOfArrays broadcastable
AayushSabharwal Oct 25, 2023
3c499f5
feat: VectorOfArray equality testing
AayushSabharwal Oct 25, 2023
c638aa4
fix: fix copy method for DiffEqArray
AayushSabharwal Oct 25, 2023
5f722d7
feat: equality testing between VectorOfArray and AbstractArray
AayushSabharwal Oct 25, 2023
5e1a3aa
feat: Base.axes for VectorOfArray
AayushSabharwal Oct 25, 2023
a452f62
feat: fallback getindex for AbstractVectorOfArray
AayushSabharwal Oct 25, 2023
6f83cd4
feat: add Base.eltype for AbstractVectorOfArray
AayushSabharwal Oct 27, 2023
6c41391
feat: add AbstractArray methods, fix reverse and convert
AayushSabharwal Nov 2, 2023
8419443
fixup! feat: add AbstractArray methods, fix reverse and convert
AayushSabharwal Nov 2, 2023
67e6429
fixup! feat: add AbstractArray methods, fix reverse and convert
AayushSabharwal Nov 2, 2023
01c6af2
fixup! feat: add AbstractArray methods, fix reverse and convert
AayushSabharwal Nov 3, 2023
e0eef08
feat: update DiffEqArray constructors, support explicit symbols
AayushSabharwal Nov 6, 2023
249c2b2
refactor: improve symbolic indexing to support LabelledArrays
AayushSabharwal Nov 6, 2023
9501cf7
fix: remove references to DiffEqArray.sc
AayushSabharwal Nov 6, 2023
2727653
refactor: refactor tests with new constructors
AayushSabharwal Nov 6, 2023
d05db29
refactor: update DiffEqArray constructors
AayushSabharwal Nov 7, 2023
757cfa0
refactor: update to use new SII
AayushSabharwal Nov 7, 2023
d6d8536
fixup! refactor: update DiffEqArray constructors
AayushSabharwal Nov 7, 2023
a629cc7
fix: use new SII for tabletraits and plots
AayushSabharwal Nov 7, 2023
ca31986
refactor: format
AayushSabharwal Nov 7, 2023
b8ebfab
feat: deprecate linear indexing, fix tests
AayushSabharwal Nov 9, 2023
3cb88ed
refactor: drastically simplify getindex methods
AayushSabharwal Nov 9, 2023
975ebfc
fix: fix constructor
AayushSabharwal Nov 10, 2023
06ba929
feat: revert length method, add Base.first and Base.last methods
AayushSabharwal Nov 10, 2023
6e15122
feat: use trait dispatch for getindex method
AayushSabharwal Nov 10, 2023
c875b65
fix: add firstindex and lastindex methods with depwarn
AayushSabharwal Nov 14, 2023
364a905
test: fix DiffEqConstructor
AayushSabharwal Nov 14, 2023
53e9faf
fix: add IndexStyle method, fix reshape for VectorOfArray
AayushSabharwal Nov 16, 2023
919df69
test: mark measurements and units test as broken
AayushSabharwal Nov 16, 2023
2e607b8
fix: fix Base.similar for VectorOfArray
AayushSabharwal Nov 20, 2023
6a5e303
fix!: Julia 1.10 fix for ldiv
AayushSabharwal Nov 20, 2023
1220488
refactor: hacky fix for autodiff
AayushSabharwal Nov 20, 2023
d843222
refactor: use symbolic_container from SII for fallback implementation
AayushSabharwal Nov 24, 2023
7996890
refactor: deprecate direct parameter indexing instead of erroring
AayushSabharwal Nov 27, 2023
92fc778
refactor: fix VectorOfArray arithmetic
AayushSabharwal Nov 27, 2023
cbf61c6
test: more comprehensive tests
AayushSabharwal Nov 27, 2023
26c5873
refactor: use new getp and setp for parameter indexing
AayushSabharwal Nov 28, 2023
22a81f2
refactor: revert eachindex to iterating over indices of VA.u
AayushSabharwal Nov 29, 2023
7b2613c
refactor: add Base.view for AbstractVectorOfArray
AayushSabharwal Dec 1, 2023
59f4b86
Update src/vector_of_array.jl
ChrisRackauckas Dec 11, 2023
0936086
Update src/vector_of_array.jl
ChrisRackauckas Dec 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ RecipesBase = "0.7, 0.8, 1.0"
Requires = "1.0"
StaticArraysCore = "1.1"
Statistics = "1"
SymbolicIndexingInterface = "0.1, 0.2"
SymbolicIndexingInterface = "0.3"
Tables = "1"
Zygote = "0.6.56"
julia = "1.6"
Expand Down
16 changes: 8 additions & 8 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
include("pages.jl")

makedocs(sitename = "RecursiveArrayTools.jl",
authors = "Chris Rackauckas",
modules = [RecursiveArrayTools],
clean = true, doctest = false, linkcheck = true,
warnonly = [:missing_docs],
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/RecursiveArrayTools/stable/"),
pages = pages)
authors = "Chris Rackauckas",
modules = [RecursiveArrayTools],
clean = true, doctest = false, linkcheck = true,
warnonly = [:missing_docs],
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/RecursiveArrayTools/stable/"),
pages = pages)

deploydocs(repo = "github.com/SciML/RecursiveArrayTools.jl.git";
push_preview = true)
push_preview = true)
4 changes: 2 additions & 2 deletions ext/RecursiveArrayToolsMeasurementsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import RecursiveArrayTools
isdefined(Base, :get_extension) ? (import Measurements) : (import ..Measurements)

function RecursiveArrayTools.recursive_unitless_bottom_eltype(a::Type{
<:Measurements.Measurement
})
<:Measurements.Measurement,
})
typeof(oneunit(a))
end

Expand Down
11 changes: 7 additions & 4 deletions ext/RecursiveArrayToolsMonteCarloMeasurementsExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
module RecursiveArrayToolsMonteCarloMeasurementsExt

import RecursiveArrayTools
isdefined(Base, :get_extension) ? (import MonteCarloMeasurements) : (import ..MonteCarloMeasurements)
isdefined(Base, :get_extension) ? (import MonteCarloMeasurements) :
(import ..MonteCarloMeasurements)

function RecursiveArrayTools.recursive_unitless_bottom_eltype(a::Type{
<:MonteCarloMeasurements.Particles
})
<:MonteCarloMeasurements.Particles,
})
typeof(one(a))
end

function RecursiveArrayTools.recursive_unitless_eltype(a::Type{<:MonteCarloMeasurements.Particles})
function RecursiveArrayTools.recursive_unitless_eltype(a::Type{
<:MonteCarloMeasurements.Particles,
})
typeof(one(a))
end

Expand Down
12 changes: 6 additions & 6 deletions ext/RecursiveArrayToolsTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import RecursiveArrayTools
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)

function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {
T <:
Tracker.TrackedArray,
T2 <:
Tracker.TrackedArray,
N}
a::AbstractArray{T2, N}) where {
T <:
Tracker.TrackedArray,
T2 <:
Tracker.TrackedArray,
N}
@inbounds for i in eachindex(a)
b[i] = copy(a[i])
end
Expand Down
35 changes: 21 additions & 14 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()

function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
xs::AbstractVectorOfArray)
xs::AbstractVectorOfArray)
T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ)
end

Expand All @@ -28,7 +28,7 @@ end
end

@adjoint function getindex(VA::AbstractVectorOfArray,
i::Union{BitArray, AbstractArray{Bool}})
i::Union{BitArray, AbstractArray{Bool}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
Expand All @@ -48,7 +48,7 @@ end
end

@adjoint function getindex(VA::AbstractVectorOfArray,
i::Union{Int, AbstractArray{Int}})
i::Union{Int, AbstractArray{Int}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
Expand All @@ -65,8 +65,8 @@ end
end

@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
j::Union{Int, AbstractArray{Int}, CartesianIndex,
Colon, BitArray, AbstractArray{Bool}}...)
j::Union{Int, AbstractArray{Int}, CartesianIndex,
Colon, BitArray, AbstractArray{Bool}}...)
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
Δ′[i, j...] = Δ
Expand All @@ -76,11 +76,11 @@ end
end

@adjoint function ArrayPartition(x::S,
::Type{Val{copy_x}} = Val{false}) where {
S <:
Tuple,
copy_x
}
::Type{Val{copy_x}} = Val{false}) where {
S <:
Tuple,
copy_x,
}
function ArrayPartition_adjoint(_y)
y = Array(_y)
starts = vcat(0, cumsum(reduce(vcat, length.(x))))
Expand All @@ -93,14 +93,21 @@ end

@adjoint function VectorOfArray(u)
VectorOfArray(u),
y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
for i in 1:size(y)[end]]),)
y -> begin
y isa Ref && (y = VectorOfArray(y[].u))
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
for i in 1:size(y.u)[end]]),)
end
end

@adjoint function DiffEqArray(u, t)
DiffEqArray(u, t),
y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]],
t), nothing)
y -> begin
y isa Ref && (y = VectorOfArray(y[].u))
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
for i in 1:size(y.u)[end]],
t), nothing)
end
end

@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
Expand Down
22 changes: 14 additions & 8 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ module RecursiveArrayTools

using DocStringExtensions
using RecipesBase, StaticArraysCore, Statistics,
ArrayInterface, LinearAlgebra
ArrayInterface, LinearAlgebra
using SymbolicIndexingInterface

import Adapt

import Tables, IteratorInterfaceExtensions

abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
abstract type AbstractVectorOfArray{T, N, A} end
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end

include("utils.jl")
Expand All @@ -31,18 +31,24 @@ Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArra
import Requires
@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin include("../ext/RecursiveArrayToolsMeasurementsExt.jl") end
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/RecursiveArrayToolsZygoteExt.jl") end
Requires.@require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin
include("../ext/RecursiveArrayToolsMeasurementsExt.jl")
end
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
include("../ext/RecursiveArrayToolsTrackerExt.jl")
end
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("../ext/RecursiveArrayToolsZygoteExt.jl")
end
end
end

export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
AllObserved, vecarr_to_vectors, tuples
AllObserved, vecarr_to_vectors, tuples

export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_push!,
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype

export ArrayPartition

Expand Down
Loading
Loading