Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove uses of NLsolve #2081

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Expand All @@ -35,7 +34,6 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLNLSolve = "e9a6253c-8580-4d32-9898-8661bb511710"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Expand Down Expand Up @@ -68,7 +66,7 @@ Logging = "1.9"
LoopVectorization = "0.12"
MacroTools = "0.5"
MuladdMacro = "0.2.1"
NLsolve = "4.3"
NLsolve = "4"
NonlinearSolve = "3"
Polyester = "0.7"
PreallocationTools = "0.4"
Expand All @@ -77,7 +75,6 @@ Preferences = "1.3"
RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SciMLBase = "2"
SciMLNLSolve = "0.1"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "1"
SimpleUnPack = "1"
Expand All @@ -98,6 +95,7 @@ ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
Expand All @@ -112,4 +110,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg"]
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "Test", "Unitful", "ModelingToolkit", "Pkg",
"NLsolve"]
1 change: 0 additions & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ using ExponentialUtilities

using NonlinearSolve

using NLsolve
# Required by temporary fix in not in-place methods with 12+ broadcasts
# `MVector` is used by Nordsieck forms
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA
Expand Down
7 changes: 3 additions & 4 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
end
BrownFullBasicInit(abstol) = BrownFullBasicInit(; abstol = abstol, nlsolve = nothing)

using SciMLNLSolve
default_nlsolve(alg, isinplace, u, autodiff = false) = alg
function default_nlsolve(::Nothing, isinplace, u, autodiff = false)
NLSolveJL(autodiff = autodiff ? :forward : :central)
TrustRegion(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray, autodiff = false)
SimpleNewtonRaphson(autodiff = autodiff)
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())

Check warning on line 28 in src/initialize_dae.jl

View check run for this annotation

Codecov / codecov/patch

src/initialize_dae.jl#L28

Added line #L28 was not covered by tests
end

## Notes
Expand Down Expand Up @@ -564,7 +563,7 @@
if alg.nlsolve !== nothing
nlsolve = alg.nlsolve
else
nlsolve = NewtonRaphson(autodiff = isAD)
nlsolve = NewtonRaphson(autodiff = alg_autodiff(integrator.alg))

Check warning on line 566 in src/initialize_dae.jl

View check run for this annotation

Codecov / codecov/patch

src/initialize_dae.jl#L566

Added line #L566 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep it as type information now that we can use it effectively.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does alg_autodiff give? Can it give bool?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it changes the type level stuff to a bool as backwards compat support IIRC

Copy link
Member Author

@avik-pal avik-pal Dec 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

function alg_autodiff(alg)
    autodiff = _alg_autodiff(alg)
    if autodiff == Val(false)
        return AutoFiniteDiff()
    elseif autodiff == Val(true)
        return AutoForwardDiff()
    else
        return _unwrap_val(autodiff)
    end
end

This should constant propagate and give the correct AD type right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't rely on constant propagation. It can fail. But let's pull this in and make this better in a follow up.

end

nlfunc = NonlinearFunction(nlequation!; jac_prototype = f.jac_prototype)
Expand Down
Loading