From 2dc6dbe841467d1502aa865bcd957fa78a6e848e Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 28 Jun 2024 13:27:05 +0530 Subject: [PATCH 01/19] chore: add fwd sensitivity --- src/concrete_solve.jl | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 62e9d2de0..64d01cd06 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -773,9 +773,17 @@ function DiffEqBase._concrete_solve_adjoint( u0, p, originator::SciMLBase.ADOriginator, args...; saveat = eltype(prob.tspan)[], kwargs...) where {CS, CTS} - if !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(ForwardDiffSensitivityParameterCompatibilityError()) + # if !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + # (p isa AbstractArray && !Base.isconcretetype(eltype(p))) + # throw(ForwardDiffSensitivityParameterCompatibilityError()) + # end + + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) end if saveat isa Number @@ -806,12 +814,12 @@ function DiffEqBase._concrete_solve_adjoint( num_chunks = length(p) ÷ chunk_size num_chunks * chunk_size != length(p) && (num_chunks += 1) - pparts = typeof(p[1:1])[] + pparts = typeof(tunables[1:1])[] for j in 0:(num_chunks - 1) local chunk if ((j + 1) * chunk_size) <= length(p) chunk = ((j * chunk_size + 1):((j + 1) * chunk_size)) - pchunk = vec(p)[chunk] + pchunk = vec(tunables)[chunk] pdualpart = seed_duals(pchunk, prob.f, ForwardDiff.Chunk{chunk_size}()) else @@ -822,7 +830,7 @@ function DiffEqBase._concrete_solve_adjoint( end pdualvec = if j == 0 - vcat(pdualpart, p[((j + 1) * chunk_size + 1):end]) + vcat(pdualpart, tunables[((j + 1) * chunk_size + 1):end]) elseif j == num_chunks - 1 vcat(p[1:(j * chunk_size)], pdualpart) else @@ -830,7 +838,7 @@ function DiffEqBase._concrete_solve_adjoint( p[(((j + 1) * chunk_size) + 1):end]) end - pdual = ArrayInterface.restructure(p, pdualvec) + pdual = ArrayInterface.restructure(tunables, pdualvec) u0dual = convert.(eltype(pdualvec), u0) if (convert_tspan(sensealg) === nothing && @@ -927,7 +935,7 @@ function DiffEqBase._concrete_solve_adjoint( end push!(pparts, vec(_dp)) end - ArrayInterface.restructure(p, reduce(vcat, pparts)) + ArrayInterface.restructure(tunables, reduce(vcat, pparts)) end else dp = nothing @@ -982,9 +990,9 @@ function DiffEqBase._concrete_solve_adjoint( end if p === nothing || p === DiffEqBase.NullParameters() - pdual = p + pdual = tunables else - pdual = convert.(eltype(u0dual), p) + pdual = convert.(eltype(u0dual), tunables) end if (convert_tspan(sensealg) === nothing && From f52dcaf3ddcdcfe9ac98a7613837d39ef3baea89 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 25 Jul 2024 18:39:43 +0530 Subject: [PATCH 02/19] chore: replace p with tunables --- src/concrete_solve.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 64d01cd06..c0a9636d3 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -812,7 +812,7 @@ function DiffEqBase._concrete_solve_adjoint( end num_chunks = length(p) ÷ chunk_size - num_chunks * chunk_size != length(p) && (num_chunks += 1) + num_chunks * chunk_size != length(tunables) && (num_chunks += 1) pparts = typeof(tunables[1:1])[] for j in 0:(num_chunks - 1) @@ -823,8 +823,8 @@ function DiffEqBase._concrete_solve_adjoint( pdualpart = seed_duals(pchunk, prob.f, ForwardDiff.Chunk{chunk_size}()) else - chunk = ((j * chunk_size + 1):length(p)) - pchunk = vec(p)[chunk] + chunk = ((j * chunk_size + 1):length(tunables)) + pchunk = vec(tunables)[chunk] pdualpart = seed_duals(pchunk, prob.f, ForwardDiff.Chunk{length(chunk)}()) end @@ -832,10 +832,10 @@ function DiffEqBase._concrete_solve_adjoint( pdualvec = if j == 0 vcat(pdualpart, tunables[((j + 1) * chunk_size + 1):end]) elseif j == num_chunks - 1 - vcat(p[1:(j * chunk_size)], pdualpart) + vcat(tunables[1:(j * chunk_size)], pdualpart) else - vcat(p[1:(j * chunk_size)], pdualpart, - p[(((j + 1) * chunk_size) + 1):end]) + vcat(tunables[1:(j * chunk_size)], pdualpart, + tunables[(((j + 1) * chunk_size) + 1):end]) end pdual = ArrayInterface.restructure(tunables, pdualvec) From 9544b64a2c1cf6fa05cc91a56e17a263bc17f77d Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 25 Jul 2024 18:46:14 +0530 Subject: [PATCH 03/19] chore: update p -> tunables --- src/concrete_solve.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index c0a9636d3..a17d42c7d 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -803,15 +803,15 @@ function DiffEqBase._concrete_solve_adjoint( function forward_sensitivity_backpass(Δ) if !(p === nothing || p === DiffEqBase.NullParameters()) dp = @thunk begin - chunk_size = if CS === 0 && length(p) < 12 - length(p) + chunk_size = if CS === 0 && length(tunables) < 12 + length(tunables) elseif CS !== 0 CS else 12 end - num_chunks = length(p) ÷ chunk_size + num_chunks = length(tunables) ÷ chunk_size num_chunks * chunk_size != length(tunables) && (num_chunks += 1) pparts = typeof(tunables[1:1])[] From 0bbc1332b7d24c46df93b2df800c7f50cd04329b Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 26 Jul 2024 07:08:43 +0530 Subject: [PATCH 04/19] chore: rm dead code --- src/concrete_solve.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index a17d42c7d..b09e4362a 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -773,10 +773,6 @@ function DiffEqBase._concrete_solve_adjoint( u0, p, originator::SciMLBase.ADOriginator, args...; saveat = eltype(prob.tspan)[], kwargs...) where {CS, CTS} - # if !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - # (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - # throw(ForwardDiffSensitivityParameterCompatibilityError()) - # end if p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity From d53b48fb2bb44e2bca0a90ab71aea07067e76e57 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 26 Jul 2024 17:52:37 +0530 Subject: [PATCH 05/19] chore: replace restructure with replace --- src/concrete_solve.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index b09e4362a..7624b0237 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -834,7 +834,8 @@ function DiffEqBase._concrete_solve_adjoint( tunables[(((j + 1) * chunk_size) + 1):end]) end - pdual = ArrayInterface.restructure(tunables, pdualvec) + # pdual = ArrayInterface.restructure(tunables, pdualvec) + pdual = SciMLStructures.replace(Tunable(), p, pdualvec) u0dual = convert.(eltype(pdualvec), u0) if (convert_tspan(sensealg) === nothing && @@ -931,7 +932,8 @@ function DiffEqBase._concrete_solve_adjoint( end push!(pparts, vec(_dp)) end - ArrayInterface.restructure(tunables, reduce(vcat, pparts)) + # ArrayInterface.restructure(tunables, reduce(vcat, pparts)) + SciMLStructures.replace(Tunable(), p, reduce(vcat, pparts)) end else dp = nothing From 87d3deb5c9385b261f2c7866c17ffee8d5c123ef Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 26 Jul 2024 17:56:47 +0530 Subject: [PATCH 06/19] chore: rm comments --- src/concrete_solve.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 7624b0237..b87cd329d 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -834,7 +834,6 @@ function DiffEqBase._concrete_solve_adjoint( tunables[(((j + 1) * chunk_size) + 1):end]) end - # pdual = ArrayInterface.restructure(tunables, pdualvec) pdual = SciMLStructures.replace(Tunable(), p, pdualvec) u0dual = convert.(eltype(pdualvec), u0) @@ -932,7 +931,6 @@ function DiffEqBase._concrete_solve_adjoint( end push!(pparts, vec(_dp)) end - # ArrayInterface.restructure(tunables, reduce(vcat, pparts)) SciMLStructures.replace(Tunable(), p, reduce(vcat, pparts)) end else From e65be4f678ab28b769beea5e2f212497593e7079 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 29 Jul 2024 12:42:06 +0530 Subject: [PATCH 07/19] chore: remake adjoint problem with replace'd p --- src/concrete_solve.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index efaf181b1..a51b75202 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -868,7 +868,8 @@ function DiffEqBase._concrete_solve_adjoint( else _f = prob.f end - _prob = remake(prob, f = _f, u0 = u0dual, p = pdual, tspan = tspandual) + pdual_structure = SciMLStructures.replace(Tunable(), p, pdual) + _prob = remake(prob, f = _f, u0 = u0dual, p = pdual_structure, tspan = tspandual) if _prob isa SDEProblem _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, @@ -1023,7 +1024,8 @@ function DiffEqBase._concrete_solve_adjoint( _f = prob.f end - _prob = remake(prob, f = _f, u0 = u0dual, p = pdual, tspan = tspandual) + pdual_structure = SciMLStructures.replace(Tunable(), p, pdual) + _prob = remake(prob, f = _f, u0 = u0dual, p = pdual_structure, tspan = tspandual) if _prob isa SDEProblem _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, From e1bbd25a957fc621e07ebf8fe85abb066a3fffba Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 29 Jul 2024 15:30:12 +0530 Subject: [PATCH 08/19] chore: repack instead of replace --- src/concrete_solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index a51b75202..70977b932 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -869,7 +869,7 @@ function DiffEqBase._concrete_solve_adjoint( _f = prob.f end pdual_structure = SciMLStructures.replace(Tunable(), p, pdual) - _prob = remake(prob, f = _f, u0 = u0dual, p = pdual_structure, tspan = tspandual) + _prob = remake(prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) if _prob isa SDEProblem _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, @@ -1025,7 +1025,7 @@ function DiffEqBase._concrete_solve_adjoint( end pdual_structure = SciMLStructures.replace(Tunable(), p, pdual) - _prob = remake(prob, f = _f, u0 = u0dual, p = pdual_structure, tspan = tspandual) + _prob = remake(prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) if _prob isa SDEProblem _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, From 3f2ea4e40ee8ecec395988b1ce55d7c5d94abdcc Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 29 Jul 2024 19:12:38 +0530 Subject: [PATCH 09/19] chore: update callbacks forward mode path to SciMLSensitivity --- src/concrete_solve.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 70977b932..14a915562 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -661,12 +661,15 @@ function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem, args...; save_idxs = nothing, kwargs...) - if !(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(ForwardSensitivityParameterCompatibilityError()) + if p === nothing || p isa SciMLBase.NullParameters + tunables, repack = p, identity + elseif isscimlstructure(p) + tunables, repack, _ = canonicalize(Tunable(), p) + else + throw(SciMLStructuresCompatibilityError()) end - if p isa AbstractArray && eltype(p) <: ForwardDiff.Dual && + if tunables isa AbstractArray && eltype(tunables) <: ForwardDiff.Dual && !(eltype(u0) <: ForwardDiff.Dual) # Handle double differentiation case u0 = eltype(p).(u0) From cfdacbf9ce5fcb040169dfbc796c310896c93960 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 1 Aug 2024 20:26:09 +0530 Subject: [PATCH 10/19] test: check correct error type --- test/parameter_compatibility_errors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/parameter_compatibility_errors.jl b/test/parameter_compatibility_errors.jl index a3c48db55..acdad9f13 100644 --- a/test/parameter_compatibility_errors.jl +++ b/test/parameter_compatibility_errors.jl @@ -32,7 +32,7 @@ end grad(p) = Zygote.gradient(loss, p) p2 = [4; 5; 6] -@test_throws SciMLSensitivity.ForwardDiffSensitivityParameterCompatibilityError grad(p2) +@test_throws SciMLSensitivity.SciMLStructuresCompatibilityError grad(p2) function loss(p1) sol = solve(prob, Tsit5(), p = [p1, mystruct(-1, -2), control], @@ -48,7 +48,7 @@ function loss(p1) return sum(abs2, sol) end -@test_throws SciMLSensitivity.ForwardSensitivityParameterCompatibilityError grad(p2) +@test_throws SciMLSensitivity.SciMLStructuresCompatibilityError grad(p2) @test_throws SciMLSensitivity.ForwardSensitivityParameterCompatibilityError ODEForwardSensitivityProblem( f!, u0, From 458d738a9a6b588f4c6781eaa1f36cca5e970b4e Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 2 Aug 2024 13:05:21 +0530 Subject: [PATCH 11/19] chore: remove pdual_structure until UDEs --- src/concrete_solve.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 14a915562..81e40b7b8 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -871,7 +871,6 @@ function DiffEqBase._concrete_solve_adjoint( else _f = prob.f end - pdual_structure = SciMLStructures.replace(Tunable(), p, pdual) _prob = remake(prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) if _prob isa SDEProblem @@ -1027,7 +1026,6 @@ function DiffEqBase._concrete_solve_adjoint( _f = prob.f end - pdual_structure = SciMLStructures.replace(Tunable(), p, pdual) _prob = remake(prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) if _prob isa SDEProblem From 0fe619d8e25a7d90677567c3819643fa05f8de3b Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 2 Aug 2024 19:45:30 +0530 Subject: [PATCH 12/19] chore: format --- src/concrete_solve.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 81e40b7b8..5ecad1287 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -778,7 +778,6 @@ function DiffEqBase._concrete_solve_adjoint( u0, p, originator::SciMLBase.ADOriginator, args...; saveat = eltype(prob.tspan)[], kwargs...) where {CS, CTS} - if p === nothing || p isa SciMLBase.NullParameters tunables, repack = p, identity elseif isscimlstructure(p) @@ -871,7 +870,8 @@ function DiffEqBase._concrete_solve_adjoint( else _f = prob.f end - _prob = remake(prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) + _prob = remake( + prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) if _prob isa SDEProblem _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, @@ -1026,7 +1026,8 @@ function DiffEqBase._concrete_solve_adjoint( _f = prob.f end - _prob = remake(prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) + _prob = remake( + prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) if _prob isa SDEProblem _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, From c16fc12a0204cd73dd199268b20c7510075bdc81 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Fri, 9 Aug 2024 02:57:11 +0530 Subject: [PATCH 13/19] chore: try with replace again --- src/concrete_solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 5ecad1287..aa69d9373 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -871,7 +871,7 @@ function DiffEqBase._concrete_solve_adjoint( _f = prob.f end _prob = remake( - prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) + prob, f = _f, u0 = u0dual, p = pdual, tspan = tspandual) if _prob isa SDEProblem _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, @@ -1027,7 +1027,7 @@ function DiffEqBase._concrete_solve_adjoint( end _prob = remake( - prob, f = _f, u0 = u0dual, p = repack(pdual), tspan = tspandual) + prob, f = _f, u0 = u0dual, p = SciMLStructures.replace(Tunable(), p, pdual), tspan = tspandual) if _prob isa SDEProblem _prob.noise_rate_prototype !== nothing && (_prob = remake(_prob, From 4669ecc329cbf3b4987cfad9dc0eb56c54ae79c9 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 12 Aug 2024 18:38:09 +0530 Subject: [PATCH 14/19] test: add MTK test --- Project.toml | 3 ++- test/mtk.jl | 37 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 test/mtk.jl diff --git a/Project.toml b/Project.toml index caab92930..5cf98d255 100644 --- a/Project.toml +++ b/Project.toml @@ -112,6 +112,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -124,4 +125,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] +test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"] diff --git a/test/mtk.jl b/test/mtk.jl new file mode 100644 index 000000000..1e9c9f5ee --- /dev/null +++ b/test/mtk.jl @@ -0,0 +1,37 @@ +using ModelingToolkit, OrdinaryDiffEq +using OrdinaryDiffEq +using SciMLSensitivity +using ForwardDiff + +@parameters σ ρ β A[1:3] +@variables x(t) y(t) z(t) w(t) w2(t) + +eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z, + w ~ x + y + z + 2 * β, + ] + +@mtkbuild sys = ODESystem(eqs, t) + +u0 = [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0,] + +p = [σ => 28.0, + ρ => 10.0, + β => 8 / 3,] + # A => ones(3),] + +tspan = (0.0, 100.0) +prob = ODEProblem(sys, u0, tspan, p, jac = true) +sol = solve(prob, Tsit5()) + + +gt = rand(5501) +dmtk, = gradient(mtkparams) do p + new_sol = solve(prob, Rosenbrock23(), p = p) + mean(abs.(new_sol[sys.x] .- gt)) +end + diff --git a/test/runtests.jl b/test/runtests.jl index d26d6789a..d42a4f6ae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ end if GROUP == "All" || GROUP == "Core1" || GROUP == "Downstream" @testset "Core1" begin @time @safetestset "Forward Sensitivity" include("forward.jl") + @time @safetestset "MTK Forward Mode" include("mtk.jl") @time @safetestset "Sparse Adjoint Sensitivity" include("sparse_adjoint.jl") @time @safetestset "Adjoint Shapes" include("adjoint_shapes.jl") @time @safetestset "Second Order Sensitivity" include("second_order.jl") From da1f2ddac81fc2a60094648fc00277a58bef49f8 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 12 Aug 2024 19:15:45 +0530 Subject: [PATCH 15/19] test: add necesary ipmorts --- test/mtk.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/mtk.jl b/test/mtk.jl index 1e9c9f5ee..f832cd8b8 100644 --- a/test/mtk.jl +++ b/test/mtk.jl @@ -1,4 +1,5 @@ using ModelingToolkit, OrdinaryDiffEq +using ModelingToolkit: t_nounits as t, D_nounits as D using OrdinaryDiffEq using SciMLSensitivity using ForwardDiff From b71fac227f2b77610cb51f145ad78595bcc35927 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 12 Aug 2024 19:40:00 +0530 Subject: [PATCH 16/19] chore: formatting --- src/concrete_solve.jl | 3 ++- test/mtk.jl | 12 +++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 030ac7639..e044c8220 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -1033,7 +1033,8 @@ function DiffEqBase._concrete_solve_adjoint( end # use the callback from kwargs, not prob - _prob = remake(prob, f = _f, u0 = u0dual, p = SciMLStructures.replace(Tunable(), p, pdual), + _prob = remake(prob, f = _f, u0 = u0dual, + p = SciMLStructures.replace(Tunable(), p, pdual), tspan = tspandual, callback = nothing) if _prob isa SDEProblem diff --git a/test/mtk.jl b/test/mtk.jl index f832cd8b8..06f37b417 100644 --- a/test/mtk.jl +++ b/test/mtk.jl @@ -10,29 +10,27 @@ using ForwardDiff eqs = [D(D(x)) ~ σ * (y - x), D(y) ~ x * (ρ - z) - y, D(z) ~ x * y - β * z, - w ~ x + y + z + 2 * β, - ] + w ~ x + y + z + 2 * β +] @mtkbuild sys = ODESystem(eqs, t) u0 = [D(x) => 2.0, x => 1.0, y => 0.0, - z => 0.0,] + z => 0.0] p = [σ => 28.0, ρ => 10.0, - β => 8 / 3,] - # A => ones(3),] + β => 8 / 3] +# A => ones(3),] tspan = (0.0, 100.0) prob = ODEProblem(sys, u0, tspan, p, jac = true) sol = solve(prob, Tsit5()) - gt = rand(5501) dmtk, = gradient(mtkparams) do p new_sol = solve(prob, Rosenbrock23(), p = p) mean(abs.(new_sol[sys.x] .- gt)) end - From 517c496b448c34f35742f4de174998a1f0887cc6 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 12 Aug 2024 20:26:45 +0530 Subject: [PATCH 17/19] test: qualify gradient --- test/mtk.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/mtk.jl b/test/mtk.jl index 06f37b417..0b6f213b4 100644 --- a/test/mtk.jl +++ b/test/mtk.jl @@ -3,6 +3,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D using OrdinaryDiffEq using SciMLSensitivity using ForwardDiff +using Zygote @parameters σ ρ β A[1:3] @variables x(t) y(t) z(t) w(t) w2(t) @@ -30,7 +31,7 @@ prob = ODEProblem(sys, u0, tspan, p, jac = true) sol = solve(prob, Tsit5()) gt = rand(5501) -dmtk, = gradient(mtkparams) do p +dmtk, = Zygote.gradient(mtkparams) do p new_sol = solve(prob, Rosenbrock23(), p = p) mean(abs.(new_sol[sys.x] .- gt)) end From 3f8c62f1d8fdba0be5d0b609494e0a3877dc47dc Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 12 Aug 2024 20:58:23 +0530 Subject: [PATCH 18/19] test: qualify mtkparams --- test/mtk.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/mtk.jl b/test/mtk.jl index 0b6f213b4..90fb2283a 100644 --- a/test/mtk.jl +++ b/test/mtk.jl @@ -29,6 +29,7 @@ p = [σ => 28.0, tspan = (0.0, 100.0) prob = ODEProblem(sys, u0, tspan, p, jac = true) sol = solve(prob, Tsit5()) +mtkparams = SciMLSensitivity.parameter_values(sol) gt = rand(5501) dmtk, = Zygote.gradient(mtkparams) do p From 0e212eaa37eecc97bb228b8646cccff07ce46dcf Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Mon, 12 Aug 2024 21:43:13 +0530 Subject: [PATCH 19/19] test: add Statistics import --- test/mtk.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/mtk.jl b/test/mtk.jl index 90fb2283a..b42c14883 100644 --- a/test/mtk.jl +++ b/test/mtk.jl @@ -4,6 +4,7 @@ using OrdinaryDiffEq using SciMLSensitivity using ForwardDiff using Zygote +using Statistics @parameters σ ρ β A[1:3] @variables x(t) y(t) z(t) w(t) w2(t)