Skip to content

Commit

Permalink
remove unnecessary changes
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Nov 6, 2023
1 parent 5ab14bf commit b720c86
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
20 changes: 10 additions & 10 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))

grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
algebraic_vars = f.mass_matrix === I ? nothing :
[all(iszero, x) for x in eachcol(f.mass_matrix)]
Expand Down Expand Up @@ -142,7 +142,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2, Val(false))
algebraic_vars = f.mass_matrix === I ? nothing :
[all(iszero, x) for x in eachcol(f.mass_matrix)]
Expand Down Expand Up @@ -287,7 +287,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rosenbrock33Cache(u, uprev, du, du1, du2, k1, k2, k3, k4,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
Expand Down Expand Up @@ -369,7 +369,7 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rosenbrock34Cache(u, uprev, du, du1, du2, k1, k2, k3, k4,
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
Expand Down Expand Up @@ -498,7 +498,7 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
k5, k6,
Expand Down Expand Up @@ -558,7 +558,7 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
k5, k6,
Expand Down Expand Up @@ -616,7 +616,7 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
k5, k6,
Expand Down Expand Up @@ -674,7 +674,7 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
k5, k6,
Expand Down Expand Up @@ -788,7 +788,7 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
k5, k6, k7, k8,
Expand Down Expand Up @@ -849,7 +849,7 @@ function alg_cache(alg::Rodas5P, u, rate_prototype, ::Type{uEltypeNoUnits},
linsolve = init(linprob, alg.linsolve, alias_A = true, alias_b = true,
Pl = Pl, Pr = Pr,
assumptions = LinearSolve.OperatorAssumptions(true))
grad_config = build_grad_config(alg, f, tf, dT, t)
grad_config = build_grad_config(alg, f, tf, du1, t)
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
Rosenbrock5Cache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
k5, k6, k7, k8,
Expand Down
23 changes: 12 additions & 11 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ function derivative!(df::AbstractArray{<:Number}, f,
try
f(grad_config, xdual)
catch e
rethrow(e)
throw(FirstAutodiffTgradError(e))
end
else
Expand Down Expand Up @@ -356,28 +357,28 @@ function resize_grad_config!(grad_config::FiniteDiff.GradientCache, i)
grad_config
end

function build_grad_config(alg, f::F1, tf::F2, dT, t) where {F1, F2}
function build_grad_config(alg, f::F1, tf::F2, du1, t) where {F1, F2}
if !DiffEqBase.has_tgrad(f)
if alg_autodiff(alg) isa AutoForwardDiff
T = if standardtag(alg)
typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(dT)))
typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(du1)))
else
typeof(ForwardDiff.Tag(f, eltype(dT)))
typeof(ForwardDiff.Tag(f, eltype(du1)))
end

if dT isa Array
dualt = Dual{T, eltype(dT), 1}(first(dT),
ForwardDiff.Partials((one(eltype(dT)),)))
grad_config = similar(dT, typeof(dualt))
if du1 isa Array
dualt = Dual{T, eltype(du1), 1}(first(du1),
ForwardDiff.Partials((one(eltype(du1)),)))
grad_config = similar(du1, typeof(dualt))
fill!(grad_config, false)
else
grad_config = ArrayInterface.restructure(dT,
Dual{T, eltype(dT), 1}.(dT,
(ForwardDiff.Partials((one(eltype(dT)),)),)) .*
grad_config = ArrayInterface.restructure(du1,
Dual{T, eltype(du1), 1}.(du1,
(ForwardDiff.Partials((one(eltype(du1)),)),)) .*
false)
end
elseif alg_autodiff(alg) isa AutoFiniteDiff
grad_config = FiniteDiff.GradientCache(dT, t, alg_difftype(alg))
grad_config = FiniteDiff.GradientCache(du1, t, alg_difftype(alg))
else
error("$alg_autodiff not yet supported in build_grad_config function")
end
Expand Down
2 changes: 1 addition & 1 deletion src/generic_rosenbrock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ function gen_algcache(cacheexpr::Expr,constcachename::Symbol,algname::Symbol,tab
linsolve = init(linprob,alg.linsolve,alias_A=true,alias_b=true,
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
Pr = Diagonal(_vec(weight)))
grad_config = build_grad_config(alg,f,tf,dT,t)
grad_config = build_grad_config(alg,f,tf,du1,t)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,du2)
$cachename($(valsyms...))
end
Expand Down

0 comments on commit b720c86

Please sign in to comment.