From 85baf2c32488d98d683b6757b3697aa199ff7294 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Nov 2023 20:59:53 -0500 Subject: [PATCH 1/2] Drop passing p and t to the JacVec operator --- src/derivative_utils.jl | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 5fe107e2f8..2c2e21786f 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -302,9 +302,22 @@ function SciMLOperators.update_coefficients!(W::WOperator, dtgamma = nothing, transform = nothing) if (u !== nothing) && (p !== nothing) && (t !== nothing) - update_coefficients!(W.J, u, p, t) + if W.J isa FunctionOperator + update_coefficients!(W.J, u, ifelse(W.J.p === nothing, nothing, p), + ifelse(W.J.t === nothing, nothing, t)) + else + update_coefficients!(W.J, u, p, t) + end update_coefficients!(W.mass_matrix, u, p, t) - !isnothing(W.jacvec) && update_coefficients!(W.jacvec, u, p, t) + if !isnothing(W.jacvec) + if W.jacvec isa FunctionOperator + update_coefficients!(W.jacvec, u, + ifelse(W.jacvec.p === nothing, nothing, p), + ifelse(W.jacvec.t === nothing, nothing, t)) + else + update_coefficients!(W.jacvec, u, p, t) + end + end end dtgamma !== nothing && (W.gamma = dtgamma) transform !== nothing && (W.transform = transform) @@ -874,7 +887,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, # be overridden with concrete_jac. _f = islin ? (isode ? f.f : f.f1.f) : f - jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u), p, t; + jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u); autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) J = jacvec W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec) @@ -890,7 +903,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, else deepcopy(f.jac_prototype) end - jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u), p, t; + jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u); autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec) From b1a6545729676cfaa9605c39632bb1f38c0293dd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Nov 2023 10:01:36 -0500 Subject: [PATCH 2/2] Do it properly --- src/derivative_utils.jl | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 2c2e21786f..c671634a02 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -302,22 +302,9 @@ function SciMLOperators.update_coefficients!(W::WOperator, dtgamma = nothing, transform = nothing) if (u !== nothing) && (p !== nothing) && (t !== nothing) - if W.J isa FunctionOperator - update_coefficients!(W.J, u, ifelse(W.J.p === nothing, nothing, p), - ifelse(W.J.t === nothing, nothing, t)) - else - update_coefficients!(W.J, u, p, t) - end + update_coefficients!(W.J, u, p, t) update_coefficients!(W.mass_matrix, u, p, t) - if !isnothing(W.jacvec) - if W.jacvec isa FunctionOperator - update_coefficients!(W.jacvec, u, - ifelse(W.jacvec.p === nothing, nothing, p), - ifelse(W.jacvec.t === nothing, nothing, t)) - else - update_coefficients!(W.jacvec, u, p, t) - end - end + !isnothing(W.jacvec) && update_coefficients!(W.jacvec, u, p, t) end dtgamma !== nothing && (W.gamma = dtgamma) transform !== nothing && (W.transform = transform) @@ -887,7 +874,7 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, # be overridden with concrete_jac. _f = islin ? (isode ? f.f : f.f1.f) : f - jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u); + jacvec = JacVec((du, u, p, t) -> _f(du, u, p, t), copy(u), p, t; autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) J = jacvec W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec) @@ -903,7 +890,12 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits}, else deepcopy(f.jac_prototype) end - jacvec = JacVec(UJacobianWrapper(_f, t, p), copy(u); + __f = if IIP + (du, u, p, t) -> _f(du, u, p, t) + else + (u, p, t) -> _f(u, p, t) + end + jacvec = JacVec(__f, copy(u), p, t; autodiff = alg_autodiff(alg), tag = OrdinaryDiffEqTag()) W = WOperator{IIP}(f.mass_matrix, dt, J, u, jacvec)