From 65c10e6c50e9fe1a2fc3ebd5f38eb36f5d3931a8 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 1 Aug 2024 14:32:43 -0400 Subject: [PATCH 1/3] optimize StaticWOperator by using lu to allow saving the factorization --- .../src/derivative_utils.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index ee2a9d4fd4..32b002c1f9 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -1,22 +1,24 @@ const ROSENBROCK_INV_CUTOFF = 7 # https://github.com/SciML/OrdinaryDiffEq.jl/pull/1539 -struct StaticWOperator{isinv, T} <: AbstractSciMLOperator{T} +struct StaticWOperator{isinv, T, F} <: AbstractSciMLOperator{T} W::T + F::F function StaticWOperator(W::T, callinv = true) where {T} isinv = size(W, 1) <= ROSENBROCK_INV_CUTOFF + F = lu(W, check=false) # when constructing W for the first time for the type # inv(W) can be singular _W = if isinv && callinv - inv(W) + F\typeof(W)(I) else W end - new{isinv, T}(_W) + new{isinv, T, typeof(F)}(_W, F) end end isinv(W::StaticWOperator{S}) where {S} = S -Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.W \ v +Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.F \ v function calc_tderivative!(integrator, cache, dtd1, repeat_step) @inbounds begin From dc0161eecb3c8efd72644fd4ba569a86f0ed0b3e Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 14 Aug 2024 11:55:49 -0400 Subject: [PATCH 2/3] import lu --- .../src/OrdinaryDiffEqDifferentiation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl index cffbf869cb..31bb6ea3e3 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl @@ -14,7 +14,7 @@ import FunctionWrappersWrappers using DiffEqBase import LinearAlgebra -import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm +import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm, lu import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros import InteractiveUtils From 9474b35e4a5e5bde37c4d50bfa2f4a47900ec942 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 14 Aug 2024 12:48:22 -0400 Subject: [PATCH 3/3] improve performance --- .../src/OrdinaryDiffEqDifferentiation.jl | 2 ++ .../src/derivative_utils.jl | 13 ++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl index 31bb6ea3e3..4fdbd9f54b 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/OrdinaryDiffEqDifferentiation.jl @@ -15,12 +15,14 @@ using DiffEqBase import LinearAlgebra import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm, lu +import LinearAlgebra: LowerTriangular, UpperTriangular import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros import InteractiveUtils import ArrayInterface import StaticArrayInterface +import StaticArrays import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA, StaticMatrix diff --git a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl index 32b002c1f9..a0f56c835b 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl @@ -4,13 +4,20 @@ struct StaticWOperator{isinv, T, F} <: AbstractSciMLOperator{T} W::T F::F function StaticWOperator(W::T, callinv = true) where {T} - isinv = size(W, 1) <= ROSENBROCK_INV_CUTOFF + n = size(W, 1) + isinv = n <= ROSENBROCK_INV_CUTOFF - F = lu(W, check=false) + F = if isinv && callinv + # this should be in ArrayInterface but can't be for silly reasons + # doing to how StaticArrays and StaticArraysCore are split up + StaticArrays.LU(LowerTriangular(W), UpperTriangular(W), SVector{n}(1:n)) + else + lu(W, check=false) + end # when constructing W for the first time for the type # inv(W) can be singular _W = if isinv && callinv - F\typeof(W)(I) + inv(W) else W end