Skip to content

Commit

Permalink
Fix ambiguity in __solve with polyalgorithm due to dispatch on Va (
Browse files Browse the repository at this point in the history
…#498)

* Fix ambiguity in `__solve` with polyalgorithm due to dispatch on `Val{N}`

Noticed downstream

```
  Candidates:
    kwcall(::NamedTuple, ::typeof(SciMLBase.__solve), prob::SciMLBase.AbstractNonlinearProblem, alg::NonlinearSolveBase.NonlinearSolvePolyAlgorithm{Val{N}}, args...) where N
      @ NonlinearSolveBase ~/.julia/packages/NonlinearSolveBase/54f3T/src/solve.jl:106
    kwcall(::NamedTuple, ::typeof(SciMLBase.__solve), prob::Union{SciMLBase.NonlinearLeastSquaresProblem{<:Union{Number, var"#s15"} where var"#s15"<:AbstractArray, iip, <:Union{var"#s14", var"#s13"} where {var"#s14"<:ForwardDiff.Dual{T, V, P}, var"#s13"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})}} where {iip, T, V, P}, SciMLBase.NonlinearProblem{<:Union{Number, var"#s15"} where var"#s15"<:AbstractArray, iip, <:Union{var"#s14", var"#s13"} where {var"#s14"<:ForwardDiff.Dual{T, V, P}, var"#s13"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})}} where {iip, T, V, P}}, alg::NonlinearSolveBase.AbstractNonlinearSolveAlgorithm, args...)
      @ NonlinearSolve ~/.julia/packages/NonlinearSolve/LuVnf/src/forward_diff.jl:14
    kwcall(::NamedTuple, ::typeof(SciMLBase.__solve), prob::Union{SciMLBase.NonlinearLeastSquaresProblem{<:Union{Number, var"#s15"} where var"#s15"<:AbstractArray, iip, <:Union{var"#s14", var"#s13"} where {var"#s14"<:ForwardDiff.Dual{T, V, P}, var"#s13"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})}} where {iip, T, V, P}, SciMLBase.NonlinearProblem{<:Union{Number, var"#s15"} where var"#s15"<:AbstractArray, iip, <:Union{var"#s14", var"#s13"} where {var"#s14"<:ForwardDiff.Dual{T, V, P}, var"#s13"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})}} where {iip, T, V, P}}, alg::NonlinearSolveBase.NonlinearSolvePolyAlgorithm, args...)
      @ NonlinearSolve ~/.julia/packages/NonlinearSolve/LuVnf/src/forward_diff.jl:14
  
  Possible fix, define
    kwcall(::NamedTuple, ::typeof(SciMLBase.__solve), ::Union{SciMLBase.NonlinearLeastSquaresProblem{<:Union{Number, var"#s15"} where var"#s15"<:AbstractArray, iip, <:Union{var"#s14", var"#s13"} where {var"#s14"<:ForwardDiff.Dual{T, V, P}, var"#s13"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})}} where {iip, T, V, P}, SciMLBase.NonlinearProblem{<:Union{Number, var"#s15"} where var"#s15"<:AbstractArray, iip, <:Union{var"#s14", var"#s13"} where {var"#s14"<:ForwardDiff.Dual{T, V, P}, var"#s13"<:(AbstractArray{<:ForwardDiff.Dual{T, V, P}})}} where {iip, T, V, P}}, ::NonlinearSolveBase.NonlinearSolvePolyAlgorithm{Val{N}}, ::Vararg{Any}) where N
```

https://github.com/SciML/OrdinaryDiffEq.jl/actions/runs/11786040241/job/32828542508?pr=2522

The issue is that the polyalgorithm code is dispatching on more details of the polyalgorithm than https://github.com/SciML/NonlinearSolve.jl/blob/master/src/forward_diff.jl#L13-L25 which leads to an ambiguity. But it only needs that information since it's a generated function. The easy fix is to just make the generated function be one step lower.

* Update forward_ad_tests.jl

* Update lib/NonlinearSolveBase/src/solve.jl
  • Loading branch information
ChrisRackauckas authored Nov 12, 2024
1 parent cbe4354 commit fe4950d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
7 changes: 6 additions & 1 deletion lib/NonlinearSolveBase/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ end
return Expr(:block, calls...)
end

@generated function SciMLBase.__solve(
function SciMLBase.__solve(prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm,
args...; kwargs...)
__generated_polysolve(prob, alg, args...; kwargs...)
end

@generated function __generated_polysolve(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...;
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs...
) where {N}
Expand Down
1 change: 1 addition & 0 deletions test/forward_ad_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ end
Broyden(),
Klement(),
DFSane(),
FastShortcutNonlinearPolyalg(),
nothing,
NLsolveJL(),
CMINPACK(),
Expand Down

0 comments on commit fe4950d

Please sign in to comment.