Skip to content

Commit

Permalink
fixup! feat: support callable parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 16, 2024
1 parent 6d2aed8 commit ab2788f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 21 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ ChainRulesCore = "1"
Combinatorics = "1"
Compat = "3.42, 4"
ConstructionBase = "1"
DataInterpolations = "6.4"
DataStructures = "0.17, 0.18"
DeepDiffs = "1"
DiffEqBase = "6.103.0"
Expand Down Expand Up @@ -131,6 +132,7 @@ julia = "1.9"
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -156,4 +158,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
61 changes: 41 additions & 20 deletions test/split_parameters.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using ModelingToolkit, Test
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEq
using DataInterpolations
using BlockArrays: BlockedArray
using ModelingToolkit: t_nounits as t, D_nounits as D
using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION
Expand Down Expand Up @@ -222,24 +223,44 @@ S = get_sensitivity(closed_loop, :u)
end

@testset "Callable parameters" begin
_f1(x) = 2x
struct Foo end
(::Foo)(x) = 3x
@variables x(t)
@parameters fn(..) = _f1
@mtkbuild sys = ODESystem(D(x) ~ fn(x), t, [x], [fn])
@test is_parameter(sys, fn)
@test ModelingToolkit.defaults(sys)[fn] == _f1

prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
@test_broken @inferred prob.ps[fn]
@test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1])
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
@test sol.u[end][] exp(2.0)

prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()])
@test_broken @inferred prob.ps[fn]
@test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1])
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
@test sol.u[end][] exp(3.0)
@testset "As FunctionWrapper" begin
_f1(x) = 2x
struct Foo end
(::Foo)(x) = 3x
@variables x(t)
@parameters fn(::Real) = _f1
@mtkbuild sys = ODESystem(D(x) ~ fn(t), t)
@test is_parameter(sys, fn)
@test ModelingToolkit.defaults(sys)[fn] == _f1

getter = getp(sys, fn)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
@inferred getter(prob)
# cannot be inferred better since `FunctionWrapper` is only known to return `Real`
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10)
@test sol.u[end][] 2.0

prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()])
@inferred getter(prob)
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
@test sol.u[end][] 2.5
end

@testset "Concrete function type" begin
ts = 0.0:0.1:1.0
interp = LinearInterpolation(ts .^ 2, ts; extrapolate = true)
@variables x(t)
@parameters (fn::typeof(interp))(..)
@mtkbuild sys = ODESystem(D(x) ~ fn(x), t)
@test is_parameter(sys, fn)
getter = getp(sys, fn)
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => interp])
@inferred getter(prob)
@inferred prob.f(prob.u0, prob.p, prob.tspan[1])
@test_nowarn sol = solve(prob, Tsit5())
@test_nowarn prob.ps[fn] = LinearInterpolation(ts .^ 3, ts; extrapolate = true)
@test_nowarn sol = solve(prob)
end
end

0 comments on commit ab2788f

Please sign in to comment.