Skip to content

Commit

Permalink
Merge pull request #290 from AayushSabharwal/as/indexing-rework
Browse files Browse the repository at this point in the history
feat!: rework AbstractVectorOfArray, use new SymbolicIndexingInterface
  • Loading branch information
ChrisRackauckas authored Dec 11, 2023
2 parents d869b10 + 0936086 commit 93125bc
Show file tree
Hide file tree
Showing 22 changed files with 696 additions and 449 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ RecipesBase = "0.7, 0.8, 1.0"
Requires = "1.0"
StaticArraysCore = "1.1"
Statistics = "1"
SymbolicIndexingInterface = "0.1, 0.2"
SymbolicIndexingInterface = "0.3"
Tables = "1"
Zygote = "0.6.56"
julia = "1.6"
Expand Down
16 changes: 8 additions & 8 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
include("pages.jl")

makedocs(sitename = "RecursiveArrayTools.jl",
authors = "Chris Rackauckas",
modules = [RecursiveArrayTools],
clean = true, doctest = false, linkcheck = true,
warnonly = [:missing_docs],
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/RecursiveArrayTools/stable/"),
pages = pages)
authors = "Chris Rackauckas",
modules = [RecursiveArrayTools],
clean = true, doctest = false, linkcheck = true,
warnonly = [:missing_docs],
format = Documenter.HTML(assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/RecursiveArrayTools/stable/"),
pages = pages)

deploydocs(repo = "github.com/SciML/RecursiveArrayTools.jl.git";
push_preview = true)
push_preview = true)
4 changes: 2 additions & 2 deletions ext/RecursiveArrayToolsMeasurementsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import RecursiveArrayTools
isdefined(Base, :get_extension) ? (import Measurements) : (import ..Measurements)

function RecursiveArrayTools.recursive_unitless_bottom_eltype(a::Type{
<:Measurements.Measurement
})
<:Measurements.Measurement,
})
typeof(oneunit(a))
end

Expand Down
11 changes: 7 additions & 4 deletions ext/RecursiveArrayToolsMonteCarloMeasurementsExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
module RecursiveArrayToolsMonteCarloMeasurementsExt

import RecursiveArrayTools
isdefined(Base, :get_extension) ? (import MonteCarloMeasurements) : (import ..MonteCarloMeasurements)
isdefined(Base, :get_extension) ? (import MonteCarloMeasurements) :
(import ..MonteCarloMeasurements)

function RecursiveArrayTools.recursive_unitless_bottom_eltype(a::Type{
<:MonteCarloMeasurements.Particles
})
<:MonteCarloMeasurements.Particles,
})
typeof(one(a))
end

function RecursiveArrayTools.recursive_unitless_eltype(a::Type{<:MonteCarloMeasurements.Particles})
function RecursiveArrayTools.recursive_unitless_eltype(a::Type{
<:MonteCarloMeasurements.Particles,
})
typeof(one(a))
end

Expand Down
12 changes: 6 additions & 6 deletions ext/RecursiveArrayToolsTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import RecursiveArrayTools
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)

function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {
T <:
Tracker.TrackedArray,
T2 <:
Tracker.TrackedArray,
N}
a::AbstractArray{T2, N}) where {
T <:
Tracker.TrackedArray,
T2 <:
Tracker.TrackedArray,
N}
@inbounds for i in eachindex(a)
b[i] = copy(a[i])
end
Expand Down
35 changes: 21 additions & 14 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end
ChainRulesCore.ProjectTo(x::VectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}()

function ChainRulesCore.rrule(T::Type{<:RecursiveArrayTools.GPUArraysCore.AbstractGPUArray},
xs::AbstractVectorOfArray)
xs::AbstractVectorOfArray)
T(xs), ȳ -> (ChainRulesCore.NoTangent(), ȳ)
end

Expand All @@ -28,7 +28,7 @@ end
end

@adjoint function getindex(VA::AbstractVectorOfArray,
i::Union{BitArray, AbstractArray{Bool}})
i::Union{BitArray, AbstractArray{Bool}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
Expand All @@ -48,7 +48,7 @@ end
end

@adjoint function getindex(VA::AbstractVectorOfArray,
i::Union{Int, AbstractArray{Int}})
i::Union{Int, AbstractArray{Int}})
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = [(i[j] ? Δ[j] : FillArrays.Fill(zero(eltype(x)), size(x)))
for (x, j) in zip(VA.u, 1:length(VA))]
Expand All @@ -65,8 +65,8 @@ end
end

@adjoint function getindex(VA::AbstractVectorOfArray, i::Int,
j::Union{Int, AbstractArray{Int}, CartesianIndex,
Colon, BitArray, AbstractArray{Bool}}...)
j::Union{Int, AbstractArray{Int}, CartesianIndex,
Colon, BitArray, AbstractArray{Bool}}...)
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
Δ′[i, j...] = Δ
Expand All @@ -76,11 +76,11 @@ end
end

@adjoint function ArrayPartition(x::S,
::Type{Val{copy_x}} = Val{false}) where {
S <:
Tuple,
copy_x
}
::Type{Val{copy_x}} = Val{false}) where {
S <:
Tuple,
copy_x,
}
function ArrayPartition_adjoint(_y)
y = Array(_y)
starts = vcat(0, cumsum(reduce(vcat, length.(x))))
Expand All @@ -93,14 +93,21 @@ end

@adjoint function VectorOfArray(u)
VectorOfArray(u),
y -> (VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]
for i in 1:size(y)[end]]),)
y -> begin
y isa Ref && (y = VectorOfArray(y[].u))
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
for i in 1:size(y.u)[end]]),)
end
end

@adjoint function DiffEqArray(u, t)
DiffEqArray(u, t),
y -> (DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i] for i in 1:size(y)[end]],
t), nothing)
y -> begin
y isa Ref && (y = VectorOfArray(y[].u))
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
for i in 1:size(y.u)[end]],
t), nothing)
end
end

@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
Expand Down
22 changes: 14 additions & 8 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ module RecursiveArrayTools

using DocStringExtensions
using RecipesBase, StaticArraysCore, Statistics,
ArrayInterface, LinearAlgebra
ArrayInterface, LinearAlgebra
using SymbolicIndexingInterface

import Adapt

import Tables, IteratorInterfaceExtensions

abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
abstract type AbstractVectorOfArray{T, N, A} end
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end

include("utils.jl")
Expand All @@ -31,18 +31,24 @@ Base.convert(T::Type{<:GPUArraysCore.AbstractGPUArray}, VA::AbstractVectorOfArra
import Requires
@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin include("../ext/RecursiveArrayToolsMeasurementsExt.jl") end
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/RecursiveArrayToolsZygoteExt.jl") end
Requires.@require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin
include("../ext/RecursiveArrayToolsMeasurementsExt.jl")
end
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
include("../ext/RecursiveArrayToolsTrackerExt.jl")
end
Requires.@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("../ext/RecursiveArrayToolsZygoteExt.jl")
end
end
end

export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
AllObserved, vecarr_to_vectors, tuples
AllObserved, vecarr_to_vectors, tuples

export recursivecopy, recursivecopy!, recursivefill!, vecvecapply, copyat_or_push!,
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype
vecvec_to_mat, recursive_one, recursive_mean, recursive_bottom_eltype,
recursive_unitless_bottom_eltype, recursive_unitless_eltype

export ArrayPartition

Expand Down
Loading

0 comments on commit 93125bc

Please sign in to comment.