Skip to content

Commit

Permalink
fix: update compat entires, reduce method ambiguities
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 19, 2023
1 parent 094c2bb commit f8cfd1c
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 112 deletions.
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand Down Expand Up @@ -39,15 +40,17 @@ GPUArraysCore = "0.1"
IteratorInterfaceExtensions = "1"
LabelledArrays = "1"
LinearAlgebra = "1"
Measurements = "2"
MonteCarloMeasurements = "1"
Measurements = "2.3"
MonteCarloMeasurements = "1.1"
NLsolve = "4"
OrdinaryDiffEq = "6"
Pkg = "1"
Random = "1"
RecipesBase = "0.7, 0.8, 1.0"
Requires = "1.0"
SafeTestsets = "0.1"
StaticArrays = "0.12"
SparseArrays = "1"
StaticArrays = "1.6"
StaticArraysCore = "1.1"
Statistics = "1"
StructArrays = "0.6"
Expand Down
1 change: 1 addition & 0 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using DocStringExtensions
using RecipesBase, StaticArraysCore, Statistics,
ArrayInterface, LinearAlgebra
using SymbolicIndexingInterface
using SparseArrays

import Adapt

Expand Down
147 changes: 55 additions & 92 deletions src/array_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,16 @@ Base.all(f, A::ArrayPartition) = all(f, (all(f, x) for x in A.x))
Base.all(f::Function, A::ArrayPartition) = all((all(f, x) for x in A.x))
Base.all(A::ArrayPartition) = all(identity, A)

function Base.copyto!(dest::AbstractArray, A::ArrayPartition)
@assert length(dest) == length(A)
cur = 1
@inbounds for i in 1:length(A.x)
dest[cur:(cur + length(A.x[i]) - 1)] .= vec(A.x[i])
cur += length(A.x[i])
for type in [AbstractArray, SparseArrays.AbstractCompressedVector, PermutedDimsArray]
@eval function Base.copyto!(dest::$(type), A::ArrayPartition)
@assert length(dest) == length(A)
cur = 1
@inbounds for i in 1:length(A.x)
dest[cur:(cur + length(A.x[i]) - 1)] .= vec(A.x[i])
cur += length(A.x[i])
end
dest

Check warning on line 187 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L180-L187

Added lines #L180 - L187 were not covered by tests
end
dest
end

function Base.copyto!(A::ArrayPartition, src::ArrayPartition)
Expand Down Expand Up @@ -419,30 +421,38 @@ end

ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(Vector(A))

function LinearAlgebra.ldiv!(A::Factorization, b::ArrayPartition)
(x = ldiv!(A, Array(b)); copyto!(b, x))
function __get_subtypes_in_module(mod, supertype; include_supertype = true, all=false, except=[])
return filter([getproperty(mod, name) for name in names(mod; all) if !in(name, except)]) do value
return value isa Type && (value <: supertype) && (include_supertype || value != supertype) && !in(value, except)
end
end

@static if VERSION >= v"1.9"
function LinearAlgebra.ldiv!(A::LinearAlgebra.SVD{T, Tr, M},
b::ArrayPartition) where {Tr, T, M <: AbstractArray{T}}
for factorization in vcat(__get_subtypes_in_module(LinearAlgebra, Factorization; include_supertype = false, all=true, except=[:LU, :LAPACKFactorizations]), LDLt{T,<:SymTridiagonal{T,V} where {V<:AbstractVector{T}}} where {T})
@eval function LinearAlgebra.ldiv!(A::T, b::ArrayPartition) where {T<:$factorization}

Check warning on line 431 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L431

Added line #L431 was not covered by tests
(x = ldiv!(A, Array(b)); copyto!(b, x))
end
end

function LinearAlgebra.ldiv!(A::LinearAlgebra.QRCompactWY{T, M, C},
b::ArrayPartition) where {
T <: Union{Float32, Float64, ComplexF64, ComplexF32},
M <: AbstractMatrix{T},
C <: AbstractMatrix{T},
}
(x = ldiv!(A, Array(b)); copyto!(b, x))
end
function LinearAlgebra.ldiv!(A::LinearAlgebra.SVD{T, Tr, M},

Check warning on line 436 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L436

Added line #L436 was not covered by tests
b::ArrayPartition) where {Tr, T, M <: AbstractArray{T}}
(x = ldiv!(A, Array(b)); copyto!(b, x))

Check warning on line 438 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L438

Added line #L438 was not covered by tests
end

function LinearAlgebra.ldiv!(A::LU, b::ArrayPartition)
LinearAlgebra._ipiv_rows!(A, 1:length(A.ipiv), b)
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), b))
return b
function LinearAlgebra.ldiv!(A::LinearAlgebra.QRCompactWY{T, M, C},

Check warning on line 441 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L441

Added line #L441 was not covered by tests
b::ArrayPartition) where {
T <: Union{Float32, Float64, ComplexF64, ComplexF32},
M <: AbstractMatrix{T},
C <: AbstractMatrix{T},
}
(x = ldiv!(A, Array(b)); copyto!(b, x))

Check warning on line 447 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L447

Added line #L447 was not covered by tests
end

for type in [LU, LU{T,Tridiagonal{T,V}} where {T,V}]
@eval function LinearAlgebra.ldiv!(A::$type, b::ArrayPartition)
LinearAlgebra._ipiv_rows!(A, 1:length(A.ipiv), b)
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), b))
return b

Check warning on line 454 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L451-L454

Added lines #L451 - L454 were not covered by tests
end
end

# block matrix indexing
Expand All @@ -458,78 +468,31 @@ end
# [U11 U12 U13] [ b1 ]
# [ 0 U22 U23] \ [ b2 ]
# [ 0 0 U33] [ b3 ]
function LinearAlgebra.ldiv!(A::UnitUpperTriangular, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in n:-1:1
Ajj = UnitUpperTriangular(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in (j - 1):-1:1
Aij = getblock(A, lens, i, j)
# bi = -Aij * xj + bi
mul!(vec(b[i]), Aij, xj, -1, true)
end
end
return bb
end

function LinearAlgebra.ldiv!(A::UpperTriangular, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in n:-1:1
Ajj = UpperTriangular(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in (j - 1):-1:1
Aij = getblock(A, lens, i, j)
# bi = -Aij * xj + bi
mul!(vec(b[i]), Aij, xj, -1, true)
end
end
return bb
end

function LinearAlgebra.ldiv!(A::UnitLowerTriangular, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in 1:n
Ajj = UnitLowerTriangular(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in (j + 1):n
Aij = getblock(A, lens, i, j)
# bi = -Aij * xj + b[i]
mul!(vec(b[i]), Aij, xj, -1, true)
for basetype in [UnitUpperTriangular, UpperTriangular, UnitLowerTriangular, LowerTriangular]
for type in [basetype, basetype{T, <:Adjoint{T}} where {T}, basetype{T, <:Transpose{T}} where {T}]
j_iter, i_iter = if basetype <: UnitUpperTriangular || basetype <: UpperTriangular
(:(n:-1:1), :(j-1:-1:1))
else
(:(1:n), :((j+1):n))
end
end
return bb
end
function _ldiv!(A::LowerTriangular, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in 1:n
Ajj = LowerTriangular(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in (j + 1):n
Aij = getblock(A, lens, i, j)
# bi = -Aij * xj + b[i]
mul!(vec(b[i]), Aij, xj, -1, true)
@eval function LinearAlgebra.ldiv!(A::$type, bb::ArrayPartition)
A = A.data
n = npartitions(bb)
b = bb.x
lens = map(length, b)
@inbounds for j in $j_iter
Ajj = $basetype(getblock(A, lens, j, j))
xj = ldiv!(Ajj, vec(b[j]))
for i in $i_iter
Aij = getblock(A, lens, i, j)

Check warning on line 487 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L478-L487

Added lines #L478 - L487 were not covered by tests
# bi = -Aij * xj + bi
mul!(vec(b[i]), Aij, xj, -1, true)
end
end
return bb

Check warning on line 492 in src/array_partition.jl

View check run for this annotation

Codecov / codecov/patch

src/array_partition.jl#L489-L492

Added lines #L489 - L492 were not covered by tests
end
end
return bb
end

function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:LinearAlgebra.Adjoint{T}},
bb::ArrayPartition) where {T}
_ldiv!(A, bb)
end
LinearAlgebra.ldiv!(A::LowerTriangular, bb::ArrayPartition) = _ldiv!(A, bb)

# TODO: optimize
function LinearAlgebra._ipiv_rows!(A::LU, order::OrdinalRange, B::ArrayPartition)
Expand Down
4 changes: 2 additions & 2 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
end

function DiffEqArray(vec::AbstractVector{T},
ts,
ts::AbstractVector,
::NTuple{N, Int},
p = nothing,
sys = nothing) where {T, N}
Expand Down Expand Up @@ -532,7 +532,7 @@ function Base.CartesianIndices(VA::AbstractVectorOfArray)
end

# Tools for creating similar objects
Base.eltype(::VectorOfArray{T}) where {T} = T
Base.eltype(::Type{<:AbstractVectorOfArray{T}}) where {T} = T

Check warning on line 535 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L535

Added line #L535 was not covered by tests
# TODO: Is there a better way to do this?
@inline function Base.similar(VA::AbstractVectorOfArray, args...)
if args[end] isa Type
Expand Down
2 changes: 1 addition & 1 deletion test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using RecursiveArrayTools, Test
using RecursiveArrayTools, StaticArrays, Test

t = 1:3
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
Expand Down
4 changes: 3 additions & 1 deletion test/qa.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using RecursiveArrayTools, Aqua
@testset "Aqua" begin
Aqua.find_persistent_tasks_deps(RecursiveArrayTools)
Aqua.test_ambiguities(RecursiveArrayTools, recursive = false, broken = true)
ambs = Aqua.detect_ambiguities(RecursiveArrayTools; recursive = true)
@warn "Number of method ambiguities: $(length(ambs))"
@test length(ambs) <= 2
Aqua.test_deps_compat(RecursiveArrayTools)
Aqua.test_piracies(RecursiveArrayTools)
Aqua.test_project_extras(RecursiveArrayTools)
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
@time @safetestset "Upstream Tests" begin
include("upstream.jl")
end
# @time @safetestset "Adjoint Tests" begin include("adjoints.jl") end
@time @safetestset "Adjoint Tests" begin include("adjoints.jl") end
@time @safetestset "Measurement Tests" begin
include("measurements.jl")
end
Expand All @@ -65,7 +65,7 @@ end
@time @safetestset "Event Tests with ArrayPartition" begin
include("downstream/downstream_events.jl")
end
VERSION >= v"1.9" && @time @safetestset "Measurements and Units" begin
@time @safetestset "Measurements and Units" begin
include("downstream/measurements_and_units.jl")
end
@time @safetestset "TrackerExt" begin
Expand Down
15 changes: 4 additions & 11 deletions test/upstream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,7 @@ end
ArrayPartition(zeros(1), [0.75])),
(0.0, 1.0)), AutoTsit5(Rodas5())).retcode == ReturnCode.Success

if VERSION < v"1.7"
@test solve(ODEProblem(dyn,
ArrayPartition(ArrayPartition(zeros(1), [-1.0]),
ArrayPartition(zeros(1), [0.75])),
(0.0, 1.0)), Rodas5()).retcode == ReturnCode.Success
else
@test_broken solve(ODEProblem(dyn,
ArrayPartition(ArrayPartition(zeros(1), [-1.0]),
ArrayPartition(zeros(1), [0.75])),
(0.0, 1.0)), Rodas5()).retcode == ReturnCode.Success
end
@test_broken solve(ODEProblem(dyn,
ArrayPartition(ArrayPartition(zeros(1), [-1.0]),
ArrayPartition(zeros(1), [0.75])),
(0.0, 1.0)), Rodas5()).retcode == ReturnCode.Success

0 comments on commit f8cfd1c

Please sign in to comment.