Skip to content

Commit

Permalink
fix prepare_alg
Browse files Browse the repository at this point in the history
  • Loading branch information
jClugstor committed Nov 7, 2024
1 parent d8a7649 commit 7ecc05c
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,21 @@ function DiffEqBase.prepare_alg(

# If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already)
# don't use a large chunksize as it will either error or not be beneficial
if nameof(alg_autodiff(alg)) == :AutoForwardDiff
if !(isbitstype(T) && sizeof(T) > 24) ||
(prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
return remake(alg, autodiff = constructorof(alg_autodiff(alg))(chunksize = 1, tag = _get_fwd_tag(alg_autodiff(alg))))
# If prob.f.f is a FunctionWrappersWrappers from ODEFunction, need to set chunksize to 1

if nameof(alg_autodiff(alg)) == :AutoForwardDiff && ((prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) || (isbitstype(T) && sizeof(T) > 24))
return remake(alg, autodiff = constructorof(alg_autodiff(alg))(chunksize = 1, tag = _get_fwd_tag(alg_autodiff(alg))))
end

# If the autodiff alg is AutoFiniteDiff, prob.f.f isa FunctionWrappersWrapper,
# and fdtype is complex, fdtype needs to change to something not complex
if nameof(alg_autodiff(alg)) == :AutoFiniteDiff
if alg_difftype(alg) == Val{:complex} && (prob.f isa ODEFunction && prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
@warn "AutoFiniteDiff fdtype complex is not compatible with this function"
return remake(alg, autodiff = constructorof(alg_autodiff(alg))(fdtype = Val{:forward}()))
end
return alg
else
return alg
end

L = StaticArrayInterface.known_length(typeof(u0))
Expand All @@ -77,6 +83,7 @@ function DiffEqBase.prepare_alg(
chunksize = cs, tag = _get_fwd_tag(alg_autodiff(alg))))
else # statically sized
cs = pick_static_chunksize(Val{L}())
cs = SciMLBase._unwrap_val(cs)
return remake(
alg, autodiff = constructorof(alg_autodiff(alg))(
chunksize = cs, tag = _get_fwd_tag(alg_autodiff(alg))))
Expand Down

0 comments on commit 7ecc05c

Please sign in to comment.