Skip to content

Commit

Permalink
change prepare_alg
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor committed Nov 5, 2024
1 parent df1340a commit 198d795
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
22 changes: 15 additions & 7 deletions lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,15 @@ function DiffEqBase.prepare_alg(

# If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already)
# don't use a large chunksize as it will either error or not be beneficial
if !(nameof(alg_autodiff(alg)) == :AutoForwardDiff) ||
(isbitstype(T) && sizeof(T) > 24) ||
(prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
return remake(alg, chunk_size = Val{1}())
if nameof(alg_autodiff(alg)) == :AutoForwardDiff
if !(isbitstype(T) && sizeof(T) > 24) ||
(prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
return remake(alg, autodiff = constructorof(alg_autodiff(alg))(chunksize = 1, tag = _get_fwd_tag(alg_autodiff(alg))))
end
return alg
else
return alg
end

L = StaticArrayInterface.known_length(typeof(u0))
Expand All @@ -68,10 +72,14 @@ function DiffEqBase.prepare_alg(
end

cs = ForwardDiff.pickchunksize(x)
return remake(alg, chunk_size = Val{cs}())
return remake(alg,
autodiff = constructorof(alg_autodiff(alg))(
chunksize = cs, tag = _get_fwd_tag(alg_autodiff(alg))))
else # statically sized
cs = pick_static_chunksize(Val{L}())
return remake(alg, chunk_size = cs)
return remake(
alg, autodiff = constructorof(alg_autodiff(alg))(
chunksize = cs, tag = _get_fwd_tag(alg_autodiff(alg))))
end
end

Expand Down
1 change: 0 additions & 1 deletion lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
f.jac(J, uprev, p, t)
else
@unpack du1, uf, jac_config = cache
println(typeof(cache))
uf.f = nlsolve_f(f, alg)
uf.t = t
if !(p isa DiffEqBase.NullParameters)
Expand Down
1 change: 0 additions & 1 deletion lib/OrdinaryDiffEqRosenbrock/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ for Alg in [
stage_limiter! = trivial_limiter!)

AD_choice = _process_AD_choice(autodiff, chunk_size, diff_type)

$Alg{_unwrap_val(chunk_size), AD_choice, typeof(linsolve),
typeof(precs), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(step_limiter!),
Expand Down

0 comments on commit 198d795

Please sign in to comment.