diff --git a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl index cffbf869cb..4fdbd9f54b 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl @@ -14,13 +14,15 @@ import FunctionWrappersWrappers using DiffEqBase import LinearAlgebra -import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm +import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm, lu +import LinearAlgebra: LowerTriangular, UpperTriangular import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros import InteractiveUtils import ArrayInterface import StaticArrayInterface +import StaticArrays import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA, StaticMatrix diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index ee2a9d4fd4..a0f56c835b 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -1,10 +1,19 @@ const ROSENBROCK_INV_CUTOFF = 7 # https://github.com/SciML/OrdinaryDiffEq.jl/pull/1539 -struct StaticWOperator{isinv, T} <: AbstractSciMLOperator{T} +struct StaticWOperator{isinv, T, F} <: AbstractSciMLOperator{T} W::T + F::F function StaticWOperator(W::T, callinv = true) where {T} - isinv = size(W, 1) <= ROSENBROCK_INV_CUTOFF + n = size(W, 1) + isinv = n <= ROSENBROCK_INV_CUTOFF + F = if isinv && callinv + # this should be in ArrayInterface but can't be for silly reasons + # doing to how StaticArrays and StaticArraysCore are split up + StaticArrays.LU(LowerTriangular(W), UpperTriangular(W), SVector{n}(1:n)) + else + lu(W, check=false) + end # when constructing W for the first time for the type # inv(W) can be singular _W = if isinv && callinv @@ -12,11 +21,11 @@ struct StaticWOperator{isinv, T} <: AbstractSciMLOperator{T} else W end - new{isinv, T}(_W) + new{isinv, T, typeof(F)}(_W, F) end end isinv(W::StaticWOperator{S}) where {S} = S -Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.W \ v +Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.F \ v function calc_tderivative!(integrator, cache, dtd1, repeat_step) @inbounds begin