Gradient w.r.t parameters not working with MTKParameters
#1130
Comments
I get a different error when calling `oop_setter`:

```julia
julia> oop_setter(prob, p)
ERROR: TypeError: in validate_parameter_type, in Parameter ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Int64}(SciMLStructures.Tunable(), 3, false), expected ModelingToolkit.ParameterIndex, got a value of type Float64
Stacktrace:
 [1] validate_parameter_type(ic::ModelingToolkit.IndexCache, p::ModelingToolkit.ParameterIndex{…}, index::ModelingToolkit.ParameterIndex{…}, val::Float64)
   @ ModelingToolkit ~/.julia/packages/ModelingToolkit/hZRaA/src/systems/parameter_buffer.jl:489
 [2] remake_buffer(indp::ODESystem, oldbuf::MTKParameters{…}, vals::Dict{…})
   @ ModelingToolkit ~/.julia/packages/ModelingToolkit/hZRaA/src/systems/parameter_buffer.jl:529
 [3] remake_buffer(sys::ODESystem, oldbuffer::MTKParameters{…}, idxs::Vector{…}, vals::Vector{…})
   @ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/cwAFH/src/remake.jl:59
 [4] (::SymbolicIndexingInterface.OOPSetter{…})(valp::ODEProblem{…}, val::Vector{…})
   @ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/cwAFH/src/parameter_indexing.jl:740
 [5] top-level scope
   @ REPL[5]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
With:

```julia
julia> function symbolic_indexing(u0, p)
           _p = oop_setter(prob, p)
           _prob = remake(prob, u0 = u0, p = _p)
           soln = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
               sensealg = GaussAdjoint(autojacvec = ZygoteVJP()))
           sum(soln[x])
       end
symbolic_indexing (generic function with 1 method)

julia> du01, dp1 = Zygote.gradient(symbolic_indexing, u0, p)
ERROR: No matching function wrapper was found!
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::typeof(throw), args::FunctionWrappersWrappers.NoFunctionWrapperFoundError)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [3] _call
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:23 [inlined]
  [4] _pullback(::Zygote.Context{…}, ::typeof(FunctionWrappersWrappers._call), ::Tuple{}, ::Tuple{…}, ::FunctionWrappersWrappers.FunctionWrappersWrapper{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [5] _call
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13 [inlined]
--- the last 2 lines are repeated 3 more times ---
 [12] _pullback(::Zygote.Context{…}, ::typeof(FunctionWrappersWrappers._call), ::Tuple{…}, ::Tuple{…}, ::FunctionWrappersWrappers.FunctionWrappersWrapper{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [13] FunctionWrappersWrapper
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10 [inlined]
 [14] _pullback(::Zygote.Context{…}, ::FunctionWrappersWrappers.FunctionWrappersWrapper{…}, ::Vector{…}, ::MTKParameters{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [15] _apply
    @ ./boot.jl:838 [inlined]
 [16] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [17] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [18] ODEFunction
    @ ~/Julia/SciML/SciMLBase.jl/src/scimlfunctions.jl:2330 [inlined]
 [19] _pullback(::Zygote.Context{…}, ::ODEFunction{…}, ::Vector{…}, ::MTKParameters{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [20] #262
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:488 [inlined]
 [21] _pullback(ctx::Zygote.Context{…}, f::SciMLSensitivity.var"#262#263"{…}, args::MTKParameters{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [22] pullback
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90 [inlined]
 [23] pullback
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88 [inlined]
 [24] vec_pjac!(out::Vector{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:487
 [25] GaussIntegrand
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:519 [inlined]
 [26] (::SciMLSensitivity.var"#265#266"{…})(out::Vector{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:560
 [27] (::DiffEqCallbacks.SavingIntegrandSumAffect{…})(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/n5zrr/src/integrating_sum.jl:50
 [28] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/callbacks.jl:618 [inlined]
 [29] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/callbacks.jl:637 [inlined]
 [30] handle_callbacks!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/integrators/integrator_utils.jl:355
 [31] _loopfooter!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/integrators/integrator_utils.jl:243
 [32] loopfooter!
    @ ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/integrators/integrator_utils.jl:207 [inlined]
 [33] solve!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:552
 [34] #__solve#75
    @ ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:7 [inlined]
 [35] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:1 [inlined]
 [36] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:612
 [37] solve_call
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:569 [inlined]
 [38] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1080 [inlined]
 [39] solve_up
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1066 [inlined]
 [40] #solve#51
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
 [41] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::StepRangeLen{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:580
 [42] _adjoint_sensitivities
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:533 [inlined]
 [43] #adjoint_sensitivities#63
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/sensitivity_interface.jl:401 [inlined]
 [44] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#313"{…})(Δ::ODESolution{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/concrete_solve.jl:627
 [45] ZBack
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:212 [inlined]
 [46] (::Zygote.var"#kw_zpullback#56"{…})(dy::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:238
 [47] #294
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206 [inlined]
 [48] (::Zygote.var"#2169#back#296"{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [49] #solve#51
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
 [50] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [51] #294
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206 [inlined]
 [52] (::Zygote.var"#2169#back#296"{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [53] solve
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:993 [inlined]
 [54] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [55] symbolic_indexing
    @ ./REPL[16]:4 [inlined]
 [56] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [57] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:91
 [58] gradient(::Function, ::Vector{Float64}, ::Vararg{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:148
```
With SciML/ModelingToolkit.jl#3100 and #1131, the following works:

```julia
# Assumes `prob`, `x`, `u0`, and `p` are defined as in the Minimal Reproducible Example
using ModelingToolkit, OrdinaryDiffEq, SciMLSensitivity, SciMLStructures, Zygote

function symbolic_indexing(u0, p)
    # Out-of-place replacement of the tunable portion of `prob.p` (an MTKParameters)
    _p = SciMLStructures.replace(SciMLStructures.Tunable(), prob.p, p)
    _prob = remake(prob, u0 = u0, p = _p)
    soln = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))
    sum(soln[x])
end
```

```julia
julia> du01, dp1 = Zygote.gradient(symbolic_indexing, u0, p)
([-6.439844108628638, -0.699257997562027], [90.44500786842111, -0.6992683768440167, -23.971135312882016, -159.4738415565799])
```

The code at https://github.com/SciML/ModelingToolkit.jl/blob/d7fa2b9a03fa964c214c7fa5fd23574de1fd0db5/ext/MTKChainRulesCoreExt.jl#L87 doesn't handle the case where the tangents are arrays.
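One way to sanity-check a parameter gradient such as `dp1` is to compare it against a central finite-difference estimate. The sketch below is a minimal, dependency-free illustration in plain Julia; `fd_gradient` and `toy_loss` are hypothetical helpers introduced here for illustration and are not part of Zygote, FiniteDiff, or any SciML package.

```julia
# Hypothetical helper (not part of any SciML package): central finite-difference
# gradient of a scalar function `f` with respect to a plain vector `x`.
function fd_gradient(f, x::AbstractVector{<:Real}; h = 1e-6)
    g = similar(x, Float64)
    for i in eachindex(x)
        xp = float.(x); xp[i] += h   # broadcasting allocates a fresh copy
        xm = float.(x); xm[i] -= h
        # Central difference: truncation error is O(h^2) for smooth f
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

# Hypothetical toy loss standing in for `p -> symbolic_indexing(u0, p)`;
# its exact gradient is [2p[1] + 3, 2p[2]].
toy_loss(p) = sum(abs2, p) + 3 * p[1]

g = fd_gradient(toy_loss, [1.0, 2.0])
```

In the issue's setting, something like `fd_gradient(q -> symbolic_indexing(u0, q), p)` could be compared with `dp1` via `isapprox` and a loose `rtol`, since the finite-difference estimate inherits the error of the ODE solve (`reltol = 1e-6`).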
I've updated the
Describe the bug 🐞
Taking the gradient with respect to a vector of parameter values (which are replaced into the parameter object) does not work with `MTKParameters`.
Expected behavior
The gradient is computed without errors.
Minimal Reproducible Example 👇
Error & Stacktrace ⚠️
Environment (please complete the following information):
- Output of `using Pkg; Pkg.status()`: SciMLBase `test/downstream` environment