From 7ecc05cfce39af6808c43115d2186fb7bbbbc68a Mon Sep 17 00:00:00 2001 From: jClugstor Date: Thu, 7 Nov 2024 11:34:30 -0500 Subject: [PATCH] fix prepare_alg --- .../src/alg_utils.jl | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl b/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl index 76def40efd..391018191c 100644 --- a/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl @@ -50,15 +50,21 @@ function DiffEqBase.prepare_alg( # If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already) # don't use a large chunksize as it will either error or not be beneficial - if nameof(alg_autodiff(alg)) == :AutoForwardDiff - if !(isbitstype(T) && sizeof(T) > 24) || - (prob.f isa ODEFunction && - prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) - return remake(alg, autodiff = constructorof(alg_autodiff(alg))(chunksize = 1, tag = _get_fwd_tag(alg_autodiff(alg)))) + # If prob.f.f is a FunctionWrappersWrappers from ODEFunction, need to set chunksize to 1 + + if nameof(alg_autodiff(alg)) == :AutoForwardDiff && ((prob.f isa ODEFunction && + prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) || (isbitstype(T) && sizeof(T) > 24)) + return remake(alg, autodiff = constructorof(alg_autodiff(alg))(chunksize = 1, tag = _get_fwd_tag(alg_autodiff(alg)))) + end + + # If the autodiff alg is AutoFiniteDiff, prob.f.f isa FunctionWrappersWrapper, + # and fdtype is complex, fdtype needs to change to something not complex + if nameof(alg_autodiff(alg)) == :AutoFiniteDiff + if alg_difftype(alg) == Val{:complex} && (prob.f isa ODEFunction && prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper) + @warn "AutoFiniteDiff fdtype complex is not compatible with this function" + return remake(alg, autodiff = constructorof(alg_autodiff(alg))(fdtype = Val{:forward}())) end return alg - else - return alg end L = StaticArrayInterface.known_length(typeof(u0)) @@ -77,6 +83,7 @@ function DiffEqBase.prepare_alg( chunksize = cs, tag = _get_fwd_tag(alg_autodiff(alg)))) else # statically sized cs = pick_static_chunksize(Val{L}()) + cs = SciMLBase._unwrap_val(cs) return remake( alg, autodiff = constructorof(alg_autodiff(alg))( chunksize = cs, tag = _get_fwd_tag(alg_autodiff(alg))))