diff --git a/Project.toml b/Project.toml index caab92930..5cf98d255 100644 --- a/Project.toml +++ b/Project.toml @@ -112,6 +112,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -124,4 +125,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] +test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 9145cb3bb..e044c8220 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -661,12 +661,15 @@ function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem, args...; save_idxs = nothing, kwargs...) - if !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(ForwardSensitivityParameterCompatibilityError()) + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) end - if p isa AbstractArray && eltype(p) <: ForwardDiff.Dual && + if tunables isa AbstractArray && eltype(tunables) <: ForwardDiff.Dual && !(eltype(u0) <: ForwardDiff.Dual) # Handle double differentiation case u0 = eltype(p).(u0) @@ -778,9 +781,12 @@ function DiffEqBase._concrete_solve_adjoint( u0, p, originator::SciMLBase.ADOriginator, args...; saveat = eltype(prob.tspan)[], kwargs...) where {CS, CTS} - if !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(ForwardDiffSensitivityParameterCompatibilityError()) + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) end if saveat isa Number @@ -802,42 +808,42 @@ function DiffEqBase._concrete_solve_adjoint( function forward_sensitivity_backpass(Δ) if !(p === nothing || p === DiffEqBase.NullParameters()) dp = @thunk begin - chunk_size = if CS === 0 && length(p) < 12 - length(p) + chunk_size = if CS === 0 && length(tunables) < 12 + length(tunables) elseif CS !== 0 CS else 12 end - num_chunks = length(p) ÷ chunk_size - num_chunks * chunk_size != length(p) && (num_chunks += 1) + num_chunks = length(tunables) ÷ chunk_size + num_chunks * chunk_size != length(tunables) && (num_chunks += 1) - pparts = typeof(p[1:1])[] + pparts = typeof(tunables[1:1])[] for j in 0:(num_chunks - 1) local chunk if ((j + 1) * chunk_size) <= length(p) chunk = ((j * chunk_size + 1):((j + 1) * chunk_size)) - pchunk = vec(p)[chunk] + pchunk = vec(tunables)[chunk] pdualpart = seed_duals(pchunk, prob.f, ForwardDiff.Chunk{chunk_size}()) else - chunk = ((j * chunk_size + 1):length(p)) - pchunk = vec(p)[chunk] + chunk = ((j * chunk_size + 1):length(tunables)) + pchunk = vec(tunables)[chunk] pdualpart = seed_duals(pchunk, prob.f, ForwardDiff.Chunk{length(chunk)}()) end pdualvec = if j == 0 - vcat(pdualpart, p[((j + 1) * chunk_size + 1):end]) + vcat(pdualpart, tunables[((j + 1) * chunk_size + 1):end]) elseif j == num_chunks - 1 - vcat(p[1:(j * chunk_size)], pdualpart) + vcat(tunables[1:(j * chunk_size)], pdualpart) else - vcat(p[1:(j * chunk_size)], pdualpart, - p[(((j + 1) * chunk_size) + 1):end]) + vcat(tunables[1:(j * chunk_size)], pdualpart, + tunables[(((j + 1) * chunk_size) + 1):end]) end - pdual = ArrayInterface.restructure(p, pdualvec) + pdual = SciMLStructures.replace(Tunable(), p, pdualvec) u0dual = convert.(eltype(pdualvec), u0) if (convert_tspan(sensealg) === nothing && @@ -869,7 +875,6 @@ function DiffEqBase._concrete_solve_adjoint( else _f = prob.f end - # use the callback from kwargs, not prob _prob = remake(prob, f = _f, u0 = u0dual, p = pdual, tspan = tspandual, callback = nothing) @@ -937,7 +942,7 @@ function DiffEqBase._concrete_solve_adjoint( end push!(pparts, vec(_dp)) end - ArrayInterface.restructure(p, reduce(vcat, pparts)) + SciMLStructures.replace(Tunable(), p, reduce(vcat, pparts)) end else dp = nothing @@ -992,9 +997,9 @@ function DiffEqBase._concrete_solve_adjoint( end if p === nothing || p === DiffEqBase.NullParameters() - pdual = p + pdual = tunables else - pdual = convert.(eltype(u0dual), p) + pdual = convert.(eltype(u0dual), tunables) end if (convert_tspan(sensealg) === nothing && @@ -1028,7 +1033,8 @@ function DiffEqBase._concrete_solve_adjoint( end # use the callback from kwargs, not prob - _prob = remake(prob, f = _f, u0 = u0dual, p = pdual, + _prob = remake(prob, f = _f, u0 = u0dual, + p = SciMLStructures.replace(Tunable(), p, pdual), tspan = tspandual, callback = nothing) if _prob isa SDEProblem diff --git a/test/mtk.jl b/test/mtk.jl new file mode 100644 index 000000000..b42c14883 --- /dev/null +++ b/test/mtk.jl @@ -0,0 +1,39 @@ +using ModelingToolkit, OrdinaryDiffEq +using ModelingToolkit: t_nounits as t, D_nounits as D +using OrdinaryDiffEq +using SciMLSensitivity +using ForwardDiff +using Zygote +using Statistics + +@parameters σ ρ β A[1:3] +@variables x(t) y(t) z(t) w(t) w2(t) + +eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z, + w ~ x + y + z + 2 * β +] + +@mtkbuild sys = ODESystem(eqs, t) + +u0 = [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0] + +p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3] +# A => ones(3),] + +tspan = (0.0, 100.0) +prob = ODEProblem(sys, u0, tspan, p, jac = true) +sol = solve(prob, Tsit5()) +mtkparams = SciMLSensitivity.parameter_values(sol) + +gt = rand(5501) +dmtk, = Zygote.gradient(mtkparams) do p + new_sol = solve(prob, Rosenbrock23(), p = p) + mean(abs.(new_sol[sys.x] .- gt)) +end diff --git a/test/parameter_compatibility_errors.jl b/test/parameter_compatibility_errors.jl index a3c48db55..acdad9f13 100644 --- a/test/parameter_compatibility_errors.jl +++ b/test/parameter_compatibility_errors.jl @@ -32,7 +32,7 @@ end grad(p) = Zygote.gradient(loss, p) p2 = [4; 5; 6] -@test_throws SciMLSensitivity.ForwardDiffSensitivityParameterCompatibilityError grad(p2) +@test_throws SciMLSensitivity.SciMLStructuresCompatibilityError grad(p2) function loss(p1) sol = solve(prob, Tsit5(), p = [p1, mystruct(-1, -2), control], @@ -48,7 +48,7 @@ function loss(p1) return sum(abs2, sol) end -@test_throws SciMLSensitivity.ForwardSensitivityParameterCompatibilityError grad(p2) +@test_throws SciMLSensitivity.SciMLStructuresCompatibilityError grad(p2) @test_throws SciMLSensitivity.ForwardSensitivityParameterCompatibilityError ODEForwardSensitivityProblem( f!, u0, diff --git a/test/runtests.jl b/test/runtests.jl index d26d6789a..d42a4f6ae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ end if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" @testset "Core1" begin @time @safetestset "Forward Sensitivity" include("forward.jl") + @time @safetestset "MTK Forward Mode" include("mtk.jl") @time @safetestset "Sparse Adjoint Sensitivity" include("sparse_adjoint.jl") @time @safetestset "Adjoint Shapes" include("adjoint_shapes.jl") @time @safetestset "Second Order Sensitivity" include("second_order.jl")