From 4a8f3b10441997038b35a98482b8f434ced8c05d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Dec 2023 19:43:06 -0500 Subject: [PATCH] Fix ODEInterface tests --- .github/workflows/CI.yml | 1 - .github/workflows/CompatHelper.yml | 2 +- Manifest.toml | 121 +++++++++++++--------- Project.toml | 6 +- ext/BoundaryValueDiffEqODEInterfaceExt.jl | 4 +- src/solve/multiple_shooting.jl | 4 +- src/utils.jl | 3 +- test/misc/odeinterface_wrapper.jl | 30 +++--- 8 files changed, 99 insertions(+), 72 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 16eb9ca4..dbf01864 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,7 +23,6 @@ jobs: - Others version: - '1' - - '~1.10.0-0' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 46670d75..0fe6c374 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -11,7 +11,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - julia-version: [1.5.0] + julia-version: [1] julia-arch: [x86] os: [ubuntu-latest] steps: diff --git a/Manifest.toml b/Manifest.toml index 2464fce9..cfb98313 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.0-rc2" +julia_version = "1.10.0" manifest_format = "2.0" -project_hash = "0ac16cb78a3540d2a0e82de32da74fda24f340a4" +project_hash = "d7c3c48e84b3bd6db2d27ce714362096378baf8a" [[deps.ADTypes]] git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245" @@ -11,9 +11,9 @@ version = "0.2.6" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "cde29ddf7e5726c9fb511f340244ea3481267608" +git-tree-sha1 = "0fb305e0253fd4e833d486914367a2ee2c2e78d0" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.7.2" +version = "4.0.1" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -53,9 +53,9 @@ version = "7.7.0" [[deps.ArrayLayouts]] deps = ["FillArrays", "LinearAlgebra"] -git-tree-sha1 = "b08a4043e1c14096ef8efe4dd97e07de5cacf240" +git-tree-sha1 = "a45ec4acc9d905f94b47243cff666820bb107789" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "1.4.5" +version = "1.5.2" weakdeps = ["SparseArrays"] [deps.ArrayLayouts.extensions] @@ -107,10 +107,10 @@ uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" [[deps.Compat]] -deps = ["UUIDs"] -git-tree-sha1 = "886826d76ea9e72b35fcd000e535588f7b60f21d" +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "75bd5b6fc5089df449b5d35fa501c846c9b6549b" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.10.1" +version = "4.12.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -147,15 +147,15 @@ uuid = "adafc99b-e345-5852-983c-f28acb93d879" version = "0.3.1" [[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" +version = "1.16.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" +git-tree-sha1 = "ac67408d9ddf207de5cfa9a97e114352430f01ed" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.15" +version = "0.18.16" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -168,9 +168,9 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" [[deps.DiffEqBase]] deps = ["ArrayInterface", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces"] -git-tree-sha1 = "8775b80752e9656000ab3800adad8ee22c9cb8f6" +git-tree-sha1 = "6af33c2eb7478db06bcf5c810e6f3dda53aac2ac" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" -version = "6.145.0" +version = "6.146.0" [deps.DiffEqBase.extensions] DiffEqBaseChainRulesCoreExt = "ChainRulesCore" @@ -229,10 +229,13 @@ uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" version = "1.0.4" [[deps.EnzymeCore]] -deps = ["Adapt"] -git-tree-sha1 = "2efe862de93cd87f620ad6ac9c9e3f83f1b2841b" +git-tree-sha1 = "59c44d8fbc651c0395d8a6eda64b05ce316f58b4" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" -version = "0.6.4" +version = "0.6.5" +weakdeps = ["Adapt"] + + [deps.EnzymeCore.extensions] + AdaptExt = "Adapt" [[deps.ExprTools]] git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" @@ -258,9 +261,9 @@ version = "0.3.2" [[deps.FastLapackInterface]] deps = ["LinearAlgebra"] -git-tree-sha1 = "b12f05108e405dadcc2aff0008db7f831374e051" +git-tree-sha1 = "d576a29bf8bcabf4b1deb9abe88a3d7f78306ab5" uuid = "29a986be-02c6-4525-aec4-84b980013641" -version = "2.0.0" +version = "2.0.1" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" @@ -283,9 +286,9 @@ version = "1.9.3" [[deps.FiniteDiff]] deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] -git-tree-sha1 = "c6e4a1fbe73b31a3dea94b1da449503b8830c306" +git-tree-sha1 = "73d1214fec245096717847c62d389a5d2ac86504" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.21.1" +version = "2.22.0" [deps.FiniteDiff.extensions] FiniteDiffBandedMatricesExt = "BandedMatrices" @@ -324,9 +327,9 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.5" +version = "0.1.6" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] @@ -452,10 +455,10 @@ deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LinearSolve]] -deps = ["ArrayInterface", "ConcreteStructs", "DocStringExtensions", "EnumX", "FastLapackInterface", "GPUArraysCore", "InteractiveUtils", "KLU", "Krylov", "Libdl", "LinearAlgebra", "MKL_jll", "PrecompileTools", "Preferences", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Sparspak", "StaticArraysCore", "UnPack"] -git-tree-sha1 = "ebdc72aa2f1ccbb9f9dd1e85698145024b762ac3" +deps = ["ArrayInterface", "ConcreteStructs", "DocStringExtensions", "EnumX", "FastLapackInterface", "GPUArraysCore", "InteractiveUtils", "KLU", "Krylov", "Libdl", "LinearAlgebra", "MKL_jll", "PrecompileTools", "Preferences", "RecursiveFactorization", "Reexport", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Sparspak", "StaticArraysCore", "UnPack"] +git-tree-sha1 = "6f8e084deabe3189416c4e505b1c53e1b590cae8" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" -version = "2.21.2" +version = "2.22.1" [deps.LinearSolve.extensions] LinearSolveBandedMatricesExt = "BandedMatrices" @@ -528,9 +531,9 @@ version = "2024.0.0+0" [[deps.MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "b211c553c199c111d998ecdaf7623d1b89b69f93" +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.12" +version = "0.5.13" [[deps.ManualMemory]] git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" @@ -587,10 +590,10 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" [[deps.NonlinearSolve]] -deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "EnumX", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SciMLOperators", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArrays", "UnPack"] -git-tree-sha1 = "72b036b728461272ae1b1c3f7096cb4c319d8793" +deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastBroadcast", "FastClosures", "FiniteDiff", "ForwardDiff", "LazyArrays", "LineSearches", "LinearAlgebra", "LinearSolve", "MaybeInplace", "PrecompileTools", "Preferences", "Printf", "RecursiveArrayTools", "Reexport", "SciMLBase", "SimpleNonlinearSolve", "SparseArrays", "SparseDiffTools", "StaticArraysCore", "TimerOutputs"] +git-tree-sha1 = "323d2a61f4adc4dfe404bf332b59690253b4f4f2" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" -version = "3.4.0" +version = "3.5.3" [deps.NonlinearSolve.extensions] NonlinearSolveBandedMatricesExt = "BandedMatrices" @@ -676,10 +679,10 @@ uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" version = "0.2.1" [[deps.PreallocationTools]] -deps = ["Adapt", "ArrayInterface", "ForwardDiff", "Requires"] -git-tree-sha1 = "01ac95fca7daabe77a9cb705862bd87016af9ddb" +deps = ["Adapt", "ArrayInterface", "ForwardDiff"] +git-tree-sha1 = "64bb68f76f789f5fe5930a80af310f19cdafeaed" uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46" -version = "0.4.13" +version = "0.4.17" [deps.PreallocationTools.extensions] PreallocationToolsReverseDiffExt = "ReverseDiff" @@ -718,15 +721,16 @@ uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" version = "1.3.4" [[deps.RecursiveArrayTools]] -deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "Requires", "SparseArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "96fdc4a33fa4282e6f3ed54de6be569b1aa43972" +deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "SparseArrays", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] +git-tree-sha1 = "5a904ad526cc9a2c5b464f6642ce9dd230fd69b6" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "3.2.6" +version = "3.7.0" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" RecursiveArrayToolsMeasurementsExt = "Measurements" RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" + RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] RecursiveArrayToolsTrackerExt = "Tracker" RecursiveArrayToolsZygoteExt = "Zygote" @@ -734,6 +738,7 @@ version = "3.2.6" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -777,11 +782,11 @@ version = "0.6.42" [[deps.SciMLBase]] deps = ["ADTypes", "ArrayInterface", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "Printf", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces"] -git-tree-sha1 = "b8f7a0807314cce87bf846ba5fd12c1b0ef512b7" +git-tree-sha1 = "40642998c5edee0d229a18c29084f656b690e464" repo-rev = "ap/nlls_bvp" repo-url = "https://github.com/SciML/SciMLBase.jl.git" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -version = "2.12.0" +version = "2.22.0" [deps.SciMLBase.extensions] SciMLBaseChainRulesCoreExt = "ChainRulesCore" @@ -820,10 +825,18 @@ deps = ["Distributed", "Mmap", "Random", "Serialization"] uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" [[deps.SimpleNonlinearSolve]] -deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "MaybeInplace", "PrecompileTools", "Reexport", "SciMLBase", "StaticArraysCore"] -git-tree-sha1 = "1a467a5767d712863e2108e86f7ab103f6d54b13" +deps = ["ADTypes", "ArrayInterface", "ConcreteStructs", "DiffEqBase", "FastClosures", "FiniteDiff", "ForwardDiff", "LinearAlgebra", "MaybeInplace", "PrecompileTools", "Reexport", "SciMLBase", "StaticArraysCore"] +git-tree-sha1 = "470c5f97af31fa35926b45eb01e53a46c8d7d35f" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" -version = "1.0.4" +version = "1.3.1" + + [deps.SimpleNonlinearSolve.extensions] + SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff" + SimpleNonlinearSolveStaticArraysExt = "StaticArrays" + + [deps.SimpleNonlinearSolve.weakdeps] + PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.SimpleTraits]] deps = ["InteractiveUtils", "MacroTools"] @@ -841,17 +854,19 @@ version = "1.10.0" [[deps.SparseDiffTools]] deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"] -git-tree-sha1 = "c281e11db4eacb36a292a054bac83c5a0aca2a26" +git-tree-sha1 = "3b38ae7a1cbe9b8b1344359599753957644b03d4" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -version = "2.15.0" +version = "2.16.0" [deps.SparseDiffTools.extensions] SparseDiffToolsEnzymeExt = "Enzyme" + SparseDiffToolsPolyesterForwardDiffExt = "PolyesterForwardDiff" SparseDiffToolsSymbolicsExt = "Symbolics" SparseDiffToolsZygoteExt = "Zygote" [deps.SparseDiffTools.weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -892,9 +907,9 @@ weakdeps = ["OffsetArrays", "StaticArrays"] [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "fba11dbe2562eecdfcac49a05246af09ee64d055" +git-tree-sha1 = "7b0e9c14c624e435076d19aea1e5cbdec2b9ca37" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.8.1" +version = "1.9.2" [deps.StaticArrays.extensions] StaticArraysChainRulesCoreExt = "ChainRulesCore" @@ -925,14 +940,14 @@ deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" version = "7.2.1+1" [[deps.SymbolicIndexingInterface]] -git-tree-sha1 = "65f4ed0f9e3125e0836df12c231cea3dd98bb165" +git-tree-sha1 = "b3103f4f50a3843e66297a2456921377c78f5e31" uuid = "2efcf032-c050-4f8e-a9bb-153293bab1f5" -version = "0.3.0" +version = "0.3.5" [[deps.TOML]] deps = ["Dates"] @@ -966,6 +981,12 @@ git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" version = "0.5.2" +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.23" + [[deps.TriangularSolve]] deps = ["CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "LoopVectorization", "Polyester", "Static", "VectorizationBase"] git-tree-sha1 = "fadebab77bf3ae041f77346dd1c290173da5a443" diff --git a/Project.toml b/Project.toml index a3132353..739a961f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "BoundaryValueDiffEq" uuid = "764a87c0-6b3e-53db-9096-fe964310641d" -version = "5.6.3" +version = "5.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -51,7 +51,7 @@ ForwardDiff = "0.10" JET = "0.8" LinearAlgebra = "1.9" LinearSolve = "2.20" -NonlinearSolve = "2.6.1, 3" +NonlinearSolve = "3.5" ODEInterface = "0.5" OrdinaryDiffEq = "6" PreallocationTools = "0.4" @@ -70,7 +70,7 @@ Test = "1" Tricks = "0.1" TruncatedStacktraces = "1" UnPack = "1" -julia = "1.9" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" diff --git a/ext/BoundaryValueDiffEqODEInterfaceExt.jl b/ext/BoundaryValueDiffEqODEInterfaceExt.jl index e478f298..2e8a2dea 100644 --- a/ext/BoundaryValueDiffEqODEInterfaceExt.jl +++ b/ext/BoundaryValueDiffEqODEInterfaceExt.jl @@ -29,7 +29,7 @@ function __solve(prob::BVProblem, alg::BVPM2; dt = 0.0, reltol = 1e-3, kwargs... n == -1 && dt ≤ 0 && throw(ArgumentError("`dt` must be positive.")) - mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n)) + mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n - 1)) n = length(mesh) - 1 no_odes = length(u0_) @@ -111,7 +111,7 @@ function __solve(prob::BVProblem, alg::BVPSOL; maxiters = 1000, reltol = 1e-3, d n == -1 && dt ≤ 0 && throw(ArgumentError("`dt` must be positive.")) u0 = __flatten_initial_guess(prob.u0) - mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n)) + mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n - 1)) if u0 === nothing # initial_guess function was provided u0 = mapreduce(@closure(t->vec(__initial_guess(prob.u0, prob.p, t))), hcat, mesh) diff --git a/src/solve/multiple_shooting.jl b/src/solve/multiple_shooting.jl index 5a841cbe..ce141e3c 100644 --- a/src/solve/multiple_shooting.jl +++ b/src/solve/multiple_shooting.jl @@ -34,7 +34,7 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;), internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true) solve_internal_odes! = @closure (resid_nodes, us, p, cur_nshoot, nodes, - odecache) -> __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot, + odecache) -> __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot, odecache, nodes, u0_size, N, ensemblealg, tspan) # This gets all the nshoots except the final SingleShooting case @@ -476,4 +476,4 @@ end end @assert !(1 in nshoots_vec) return nshoots_vec -end \ No newline at end of file +end diff --git a/src/utils.jl b/src/utils.jl index 3e68caad..63c6387b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -312,7 +312,8 @@ Takes the input initial guess and returns the mesh. """ @inline __extract_mesh(u₀, t₀, t₁, n::Int) = collect(range(t₀; stop = t₁, length = n + 1)) @inline __extract_mesh(u₀, t₀, t₁, dt::Number) = collect(t₀:dt:t₁) -@inline __extract_mesh(u₀::DiffEqArray, t₀, t₁, n) = u₀.t +@inline __extract_mesh(u₀::DiffEqArray, t₀, t₁, ::Int) = u₀.t +@inline __extract_mesh(u₀::DiffEqArray, t₀, t₁, ::Number) = u₀.t """ __has_initial_guess(u₀) -> Bool diff --git a/test/misc/odeinterface_wrapper.jl b/test/misc/odeinterface_wrapper.jl index 42e61f86..69e3a560 100644 --- a/test/misc/odeinterface_wrapper.jl +++ b/test/misc/odeinterface_wrapper.jl @@ -1,4 +1,5 @@ -using Test, BoundaryValueDiffEq, LinearAlgebra, ODEInterface, Random, RecursiveArrayTools +using Test, BoundaryValueDiffEq, LinearAlgebra, ODEInterface, Random, OrdinaryDiffEq, + RecursiveArrayTools # Adaptation of https://github.com/luchr/ODEInterface.jl/blob/958b6023d1dabf775033d0b89c5401b33100bca3/examples/BasicExamples/ex7.jl function ex7_f!(du, u, p, t) @@ -25,7 +26,6 @@ tspan = (-π / 2, π / 2) tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), u0, tspan, p; bcresid_prototype = (zeros(1), zeros(1))) -sol_bvpm2 = solve(tpprob, BVPM2(); dt = π / 20) @testset "BVPM2" begin @info "Testing BVPM2" @@ -38,14 +38,19 @@ sol_bvpm2 = solve(tpprob, BVPM2(); dt = π / 20) @test norm(resid_f, Inf) < 1e-6 end +# Just generate a solution for bvpsol +sol_ms = solve(tpprob, MultipleShooting(10, DP5(), NewtonRaphson()); + dt = π / 20, abstol = 1e-5, maxiters = 1000, + odesolve_kwargs = (; adaptive = false, dt = 0.01, abstol = 1e-6, maxiters = 1000)) + # Just test that it runs. BVPSOL only works with linearly separable BCs. @testset "BVPSOL" begin @info "Testing BVPSOL" @info "BVPSOL with Vector{<:AbstractArray}" - initial_u0 = [sol_bvpm2(t) .+ rand() for t in tspan[1]:(π / 20):tspan[2]] - tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan; + initial_u0 = [sol_ms(t) .+ rand() for t in tspan[1]:(π / 20):tspan[2]] + tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p; bcresid_prototype = (zeros(1), zeros(1))) # Just test that it runs. BVPSOL only works with linearly separable BCs. @@ -53,8 +58,8 @@ end @info "BVPSOL with VectorOfArray" - initial_u0 = VectorOfArray([sol_bvpm2(t) .+ rand() for t in tspan[1]:(π / 20):tspan[2]]) - tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan; + initial_u0 = VectorOfArray([sol_ms(t) .+ rand() for t in tspan[1]:(π / 20):tspan[2]]) + tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p; bcresid_prototype = (zeros(1), zeros(1))) # Just test that it runs. BVPSOL only works with linearly separable BCs. @@ -63,18 +68,19 @@ end @info "BVPSOL with DiffEqArray" ts = collect(tspan[1]:(π / 20):tspan[2]) - initial_u0 = DiffEqArray([sol_bvpm2(t) .+ rand() for t in ts], ts) - tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan; + initial_u0 = DiffEqArray([sol_ms(t) .+ rand() for t in ts], ts) + tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p; bcresid_prototype = (zeros(1), zeros(1))) sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20) @info "BVPSOL with initial guess function" - initial_u0 = (p, t) -> sol_bvpm2(t) .+ rand() - tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p; - bcresid_prototype = (zeros(1), zeros(1))) - sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20) + initial_u0 = (p, t) -> sol_ms(t) .+ rand() + # FIXME: Upstream fix + # tpprob = TwoPointBVProblem(ex7_f!, (ex7_2pbc1!, ex7_2pbc2!), initial_u0, tspan, p; + # bcresid_prototype = (zeros(1), zeros(1))) + # sol_bvpsol = solve(tpprob, BVPSOL(); dt = π / 20) end #=