Skip to content

Commit

Permalink
Merge pull request #10 from SciML/master
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
Shreyas-Ekanathan authored Aug 15, 2024
2 parents 1125bc1 + f5f1cc4 commit d836402
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ import FunctionWrappersWrappers
using DiffEqBase

import LinearAlgebra
import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm
import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!, lmul!, axpby!, opnorm, lu
import LinearAlgebra: LowerTriangular, UpperTriangular
import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros

import InteractiveUtils
import ArrayInterface

import StaticArrayInterface
import StaticArrays
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA,
StaticMatrix

Expand Down
17 changes: 13 additions & 4 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
const ROSENBROCK_INV_CUTOFF = 7 # https://github.com/SciML/OrdinaryDiffEq.jl/pull/1539

struct StaticWOperator{isinv, T} <: AbstractSciMLOperator{T}
struct StaticWOperator{isinv, T, F} <: AbstractSciMLOperator{T}
W::T
F::F
function StaticWOperator(W::T, callinv = true) where {T}
isinv = size(W, 1) <= ROSENBROCK_INV_CUTOFF
n = size(W, 1)
isinv = n <= ROSENBROCK_INV_CUTOFF

F = if isinv && callinv
# this should be in ArrayInterface but can't be for silly reasons
# doing to how StaticArrays and StaticArraysCore are split up
StaticArrays.LU(LowerTriangular(W), UpperTriangular(W), SVector{n}(1:n))
else
lu(W, check=false)
end
# when constructing W for the first time for the type
# inv(W) can be singular
_W = if isinv && callinv
inv(W)
else
W
end
new{isinv, T}(_W)
new{isinv, T, typeof(F)}(_W, F)
end
end
isinv(W::StaticWOperator{S}) where {S} = S
Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.W \ v
Base.:\(W::StaticWOperator, v::AbstractArray) = isinv(W) ? W.W * v : W.F \ v

function calc_tderivative!(integrator, cache, dtd1, repeat_step)
@inbounds begin
Expand Down

0 comments on commit d836402

Please sign in to comment.