diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e5c2c2bcc..a8a28749a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -24,6 +24,7 @@ jobs: - Core5 - Core6 - Core7 + - DiffEq - SDE1 - SDE2 - SDE3 diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index c1eff0e60..7c356e58d 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -104,7 +104,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f if !needs_jac && !issemiexplicitdae && !(autojacvec isa Bool) J = nothing else - if SciMLBase.forwarddiffs_model_time(alg) + if alg === nothing || SciMLBase.forwarddiffs_model_time(alg) # 1 chunk is fine because it's only t J = dualcache(similar(u0, numindvar, numindvar), ForwardDiff.pickchunksize(length(u0))) diff --git a/test/diffeq/Project.toml b/test/diffeq/Project.toml new file mode 100644 index 000000000..6dde5fc89 --- /dev/null +++ b/test/diffeq/Project.toml @@ -0,0 +1,5 @@ +[deps] +DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" + +[compat] +DifferentialEquations = "7" diff --git a/test/diffeq/default_alg_diff.jl b/test/diffeq/default_alg_diff.jl new file mode 100644 index 000000000..7387e3b77 --- /dev/null +++ b/test/diffeq/default_alg_diff.jl @@ -0,0 +1,21 @@ +using ComponentArrays, DifferentialEquations, Lux, Random, SciMLSensitivity, Zygote + +function f(du, u, p, t) + du .= first(nn(u, p, st)) + nothing +end + +nn = Dense(8, 8, tanh) +ps, st = Lux.setup(Random.default_rng(), nn) +ps = ComponentArray(ps) + +r = rand(Float32, 8, 64) + +function f2(x) + prob = ODEProblem(f, r, (0.0f0, 1.0f0), x) + sol = solve(prob; sensealg = InterpolatingAdjoint(; autodiff = true, autojacvec = true)) + sum(last(sol.u)) +end + +f2(ps) +Zygote.gradient(f2, ps) diff --git a/test/runtests.jl b/test/runtests.jl index 12c63705d..8228d6ffd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,12 @@ function activate_gpu_env() Pkg.instantiate() end +function activate_diffeq_env() + Pkg.activate("diffeq") + Pkg.develop(PackageSpec(path = dirname(@__DIR__))) + Pkg.instantiate() +end + @time @testset "SciMLSensitivity" begin if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" @testset "Core1" begin @@ -144,6 +150,13 @@ end end end + if GROUP == "DiffEq" + @testset "DiffEq" begin + activate_gpu_env() + @time @safetestset "Default DiffEq Alg" include("diffeq/default_alg_diff.jl") + end + end + if GROUP == "GPU" @testset "GPU" begin activate_gpu_env()