Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support callable parameters #2995

Merged
merged 3 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down Expand Up @@ -76,6 +77,7 @@ ChainRulesCore = "1"
Combinatorics = "1"
Compat = "3.42, 4"
ConstructionBase = "1"
DataInterpolations = "6.4"
DataStructures = "0.17, 0.18"
DeepDiffs = "1"
DiffEqBase = "6.103.0"
Expand All @@ -91,6 +93,7 @@ ExprTools = "0.1.10"
Expronicon = "0.8"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappers = "1.1"
FunctionWrappersWrappers = "0.1"
Graphs = "1.5.2"
InteractiveUtils = "1"
Expand Down Expand Up @@ -118,8 +121,8 @@ SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.29"
SymbolicUtils = "3.2"
Symbolics = "6.3"
SymbolicUtils = "3.7"
Symbolics = "6.12"
URIs = "1"
UnPack = "0.1, 1.0"
Unitful = "1.1"
Expand All @@ -129,6 +132,7 @@ julia = "1.9"
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -154,4 +158,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
4 changes: 3 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ using Base: RefValue
using Combinatorics
import Distributions
import FunctionWrappersWrappers
import FunctionWrappers: FunctionWrapper
using URIs: URI
using SciMLStructures
using Compat
Expand All @@ -63,7 +64,8 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables,
VariableSource, getname, variable, Connection, connect,
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
initial_state, transition, activeState, entry, hasnode,
ticksInState, timeInState, fixpoint_sub, fast_substitute
ticksInState, timeInState, fixpoint_sub, fast_substitute,
CallWithMetadata, CallWithParent
const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR)
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
Expand Down
10 changes: 10 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ function isparameter(x)
end
end

function iscalledparameter(x)
x = unwrap(x)
return isparameter(getmetadata(x, CallWithParent, nothing))
end

function getcalledparameter(x)
x = unwrap(x)
return getmetadata(x, CallWithParent)
end

"""
toparam(s)

Expand Down
43 changes: 29 additions & 14 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ struct DiscreteIndex
end

const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
const NonnumericMap = Dict{
Union{BasicSymbolic, Symbolics.CallWithMetadata}, Tuple{Int, Int}}
const UnknownIndexMap = Dict{
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
const TunableIndexMap = Dict{BasicSymbolic,
Expand All @@ -45,20 +47,20 @@ struct IndexCache
callback_to_clocks::Dict{Any, Vector{Int}}
tunable_idx::TunableIndexMap
constant_idx::ParamIndexMap
nonnumeric_idx::ParamIndexMap
nonnumeric_idx::NonnumericMap
observed_syms::Set{BasicSymbolic}
dependent_pars::Set{BasicSymbolic}
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
tunable_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
nonnumeric_buffer_sizes::Vector{BufferTemplate}
symbol_to_variable::Dict{Symbol, BasicSymbolic}
symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}
end

function IndexCache(sys::AbstractSystem)
unks = solved_unknowns(sys)
unk_idxs = UnknownIndexMap()
symbol_to_variable = Dict{Symbol, BasicSymbolic}()
symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}()

let idx = 1
for sym in unks
Expand Down Expand Up @@ -105,12 +107,11 @@ function IndexCache(sys::AbstractSystem)

tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()

function insert_by_type!(buffers::Dict{Any, Set{BasicSymbolic}}, sym)
function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S}
sym = unwrap(sym)
ctype = symtype(sym)
buf = get!(buffers, ctype, Set{BasicSymbolic}())
buf = get!(buffers, ctype, S())
push!(buf, sym)
end

Expand Down Expand Up @@ -142,7 +143,7 @@ function IndexCache(sys::AbstractSystem)
clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym)
push!(clocks, i)
else
insert_by_type!(constant_buffers, sym)
insert_by_type!(constant_buffers, sym, symtype(sym))
end
end
end
Expand Down Expand Up @@ -197,6 +198,9 @@ function IndexCache(sys::AbstractSystem)
for p in parameters(sys)
p = unwrap(p)
ctype = symtype(p)
if ctype <: FnType
ctype = fntype_to_function_type(ctype)
end
haskey(disc_idxs, p) && continue
haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue
insert_by_type!(
Expand All @@ -212,12 +216,13 @@ function IndexCache(sys::AbstractSystem)
else
nonnumeric_buffers
end,
p
p,
ctype
)
end

function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
idxs = ParamIndexMap()
function get_buffer_sizes_and_idxs(T, buffers::Dict)
idxs = T()
buffer_sizes = BufferTemplate[]
for (i, (T, buf)) in enumerate(buffers)
for (j, p) in enumerate(buf)
Expand All @@ -229,13 +234,18 @@ function IndexCache(sys::AbstractSystem)
idxs[rp] = (i, j)
idxs[rttp] = (i, j)
end
if T <: Symbolics.FnType
T = Any
end
push!(buffer_sizes, BufferTemplate(T, length(buf)))
end
return idxs, buffer_sizes
end

const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(
ParamIndexMap, constant_buffers)
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(
NonnumericMap, nonnumeric_buffers)

tunable_idxs = TunableIndexMap()
tunable_buffer_size = 0
Expand Down Expand Up @@ -401,7 +411,8 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
for temp in ic.discrete_buffer_sizes)
const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
for temp in ic.constant_buffer_sizes)
nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
nonnumeric_buf = Tuple(Union{BasicSymbolic, CallWithMetadata}[unwrap(variable(:DEF))
for _ in 1:(temp.length)]
for temp in ic.nonnumeric_buffer_sizes)
for p in ps
p = unwrap(p)
Expand Down Expand Up @@ -481,3 +492,7 @@ function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)
error("Unhandled portion $portion")
end
end

fntype_to_function_type(::Type{FnType{A, R, T}}) where {A, R, T} = T
fntype_to_function_type(::Type{FnType{A, R, Nothing}}) where {A, R} = FunctionWrapper{R, A}
fntype_to_function_type(::Type{FnType{A, R}}) where {A, R} = FunctionWrapper{R, A}
3 changes: 3 additions & 0 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ function MTKParameters(
if symbolic_type(val) !== NotSymbolic()
error("Could not evaluate value of parameter $sym. Missing values for variables in expression $val.")
end
if ctype <: FnType
ctype = fntype_to_function_type(ctype)
end
val = symconvert(ctype, val)
done = set_value(sym, val)
if !done && Symbolics.isarraysymbolic(sym)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,8 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
ps = [setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
ps = [sym isa CallWithMetadata ? sym :
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
for sym in get_ps(sys)]
@set! sys.ps = ps
else
Expand Down
12 changes: 10 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,11 @@ end
vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op)
vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op)
function vars(exprs; op = Differential)
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
if hasmethod(iterate, Tuple{typeof(exprs)})
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
else
vars!(Set(), unwrap(exprs); op)
end
end
vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
function vars!(vars, eq::Equation; op = Differential)
Expand Down Expand Up @@ -479,7 +483,11 @@ end

function collect_var!(unknowns, parameters, var, iv)
isequal(var, iv) && return nothing
if isparameter(var) || (iscall(var) && isparameter(operation(var)))
if iscalledparameter(var)
callable = getcalledparameter(var)
push!(parameters, callable)
collect_vars!(unknowns, parameters, arguments(var), iv)
elseif isparameter(var) || (iscall(var) && isparameter(operation(var)))
push!(parameters, var)
elseif !isconstant(var)
push!(unknowns, var)
Expand Down
4 changes: 2 additions & 2 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ eqs = [D(x) ~ σ(t - 1) * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z * κ]
@named de = ODESystem(eqs, t)
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ(t - 1), ρ, β))
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ, ρ, β))
f = eval(generate_function(de, [x, y, z], [σ, ρ, β])[2])
du = [0.0, 0.0, 0.0]
f(du, [1.0, 2.0, 3.0], [x -> x + 7, 2, 3], 5.0)
@test du ≈ [11, -3, -7]

eqs = [D(x) ~ x + 10σ(t - 1) + 100σ(t - 2) + 1000σ(t^2)]
@named de = ODESystem(eqs, t)
test_diffeq_inference("many internal iv-varying", de, t, (x,), (σ(t - 2), σ(t^2), σ(t - 1)))
test_diffeq_inference("many internal iv-varying", de, t, (x,), (σ,))
f = eval(generate_function(de, [x], [σ])[2])
du = [0.0]
f(du, [1.0], [t -> t + 2], 5.0)
Expand Down
45 changes: 45 additions & 0 deletions test/split_parameters.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
using ModelingToolkit, Test
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEq
using DataInterpolations
using BlockArrays: BlockedArray
using ModelingToolkit: t_nounits as t, D_nounits as D
using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION
using SciMLStructures: Tunable, Discrete, Constants
using StaticArrays: SizedVector
using SymbolicIndexingInterface: is_parameter, getp

x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)]

Expand Down Expand Up @@ -219,3 +221,46 @@ S = get_sensitivity(closed_loop, :u)
@test ps[ParameterIndex(Tunable(), 1:8)] == collect(1.0:8.0) .+ 0.5
@test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] == 5
end

@testset "Callable parameters" begin
@testset "As FunctionWrapper" begin
_f1(x) = 2x
struct Foo end
(::Foo)(x) = 3x
@variables x(t)
@parameters fn(::Real) = _f1
@mtkbuild sys = ODESystem(D(x) ~ fn(t), t)
@test is_parameter(sys, fn)
@test ModelingToolkit.defaults(sys)[fn] == _f1

getter = getp(sys, fn)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
@inferred getter(prob)
# cannot be inferred better since `FunctionWrapper` is only known to return `Real`
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10)
@test sol.u[end][] ≈ 2.0

prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()])
@inferred getter(prob)
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
@test sol.u[end][] ≈ 2.5
end

@testset "Concrete function type" begin
ts = 0.0:0.1:1.0
interp = LinearInterpolation(ts .^ 2, ts; extrapolate = true)
@variables x(t)
@parameters (fn::typeof(interp))(..)
@mtkbuild sys = ODESystem(D(x) ~ fn(x), t)
@test is_parameter(sys, fn)
getter = getp(sys, fn)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => interp])
@inferred getter(prob)
@inferred prob.f(prob.u0, prob.p, prob.tspan[1])
@test_nowarn sol = solve(prob, Tsit5())
@test_nowarn prob.ps[fn] = LinearInterpolation(ts .^ 3, ts; extrapolate = true)
@test_nowarn sol = solve(prob)
end
end
Loading