From ab2788f03242d0d4e19c861b6405636062df1eed Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Sep 2024 11:57:32 +0530 Subject: [PATCH] fixup! feat: support callable parameters --- Project.toml | 4 ++- test/split_parameters.jl | 61 +++++++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index 9d22e4c70c..8bcfe57083 100644 --- a/Project.toml +++ b/Project.toml @@ -77,6 +77,7 @@ ChainRulesCore = "1" Combinatorics = "1" Compat = "3.42, 4" ConstructionBase = "1" +DataInterpolations = "6.4" DataStructures = "0.17, 0.18" DeepDiffs = "1" DiffEqBase = "6.103.0" @@ -131,6 +132,7 @@ julia = "1.9" AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e" +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -156,4 +158,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"] +test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"] diff --git a/test/split_parameters.jl b/test/split_parameters.jl index 5dcea5cd11..06ec3bb292 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -1,6 +1,7 @@ using ModelingToolkit, Test using ModelingToolkitStandardLibrary.Blocks using OrdinaryDiffEq +using DataInterpolations using BlockArrays: BlockedArray using ModelingToolkit: t_nounits as t, D_nounits as D using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION @@ -222,24 +223,44 @@ S = get_sensitivity(closed_loop, :u) end @testset "Callable parameters" begin - _f1(x) = 2x - struct Foo end - (::Foo)(x) = 3x - @variables x(t) - @parameters fn(..) = _f1 - @mtkbuild sys = ODESystem(D(x) ~ fn(x), t, [x], [fn]) - @test is_parameter(sys, fn) - @test ModelingToolkit.defaults(sys)[fn] == _f1 - - prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0)) - @test_broken @inferred prob.ps[fn] - @test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1]) - sol = solve(prob; abstol = 1e-10, reltol = 1e-10) - @test sol.u[end][] ≈ exp(2.0) - - prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()]) - @test_broken @inferred prob.ps[fn] - @test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1]) - sol = solve(prob; abstol = 1e-10, reltol = 1e-10) - @test sol.u[end][] ≈ exp(3.0) + @testset "As FunctionWrapper" begin + _f1(x) = 2x + struct Foo end + (::Foo)(x) = 3x + @variables x(t) + @parameters fn(::Real) = _f1 + @mtkbuild sys = ODESystem(D(x) ~ fn(t), t) + @test is_parameter(sys, fn) + @test ModelingToolkit.defaults(sys)[fn] == _f1 + + getter = getp(sys, fn) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0)) + @inferred getter(prob) + # cannot be inferred better since `FunctionWrapper` is only known to return `Real` + @inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1]) + sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10) + @test sol.u[end][] ≈ 2.0 + + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()]) + @inferred getter(prob) + @inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1]) + sol = solve(prob; abstol = 1e-10, reltol = 1e-10) + @test sol.u[end][] ≈ 2.5 + end + + @testset "Concrete function type" begin + ts = 0.0:0.1:1.0 + interp = LinearInterpolation(ts .^ 2, ts; extrapolate = true) + @variables x(t) + @parameters (fn::typeof(interp))(..) + @mtkbuild sys = ODESystem(D(x) ~ fn(x), t) + @test is_parameter(sys, fn) + getter = getp(sys, fn) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => interp]) + @inferred getter(prob) + @inferred prob.f(prob.u0, prob.p, prob.tspan[1]) + @test_nowarn sol = solve(prob, Tsit5()) + @test_nowarn prob.ps[fn] = LinearInterpolation(ts .^ 3, ts; extrapolate = true) + @test_nowarn sol = solve(prob) + end end