Skip to content

Commit

Permalink
replace deprecated types
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg authored and ChrisRackauckas committed Oct 8, 2023
1 parent e2183d9 commit 70f5d22
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 37 deletions.
12 changes: 6 additions & 6 deletions ext/DiffEqBaseReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,23 @@ end
end

# `ReverseDiff.TrackedArray`
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray,
p::ReverseDiff.TrackedArray, args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0, p::ReverseDiff.TrackedArray,
args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray, p,
Expand All @@ -101,7 +101,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
end

# `AbstractArray{<:ReverseDiff.TrackedReal}`
function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
Expand All @@ -112,7 +112,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0,
Expand All @@ -121,7 +121,7 @@ function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
Expand Down
6 changes: 3 additions & 3 deletions ext/DiffEqBaseTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,23 @@ end
end
@inline DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u)

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::Tracker.TrackedArray,
p::Tracker.TrackedArray, args...; kwargs...)
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::Tracker.TrackedArray, p, args...;
kwargs...)
Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0, p::Tracker.TrackedArray, args...;
Expand Down
6 changes: 3 additions & 3 deletions src/DiffEqBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@ PrecompileTools.@recompile_invalidations begin
import PreallocationTools

import FunctionWrappersWrappers

using SciMLBase

using SciMLOperators: AbstractSciMLOperator, AbstractSciMLScalarOperator

using SciMLBase: @def, DEIntegrator, DEProblem,
using SciMLBase: @def, DEIntegrator, AbstractDEProblem,
AbstractDiffEqInterpolation,
DECallback, AbstractDEOptions, DECache, AbstractContinuousCallback,
AbstractDiscreteCallback, AbstractLinearProblem, AbstractNonlinearProblem,
AbstractOptimizationProblem, AbstractSteadyStateProblem,
AbstractJumpProblem,
AbstractNoiseProblem, AbstractEnsembleProblem, AbstractDynamicalODEProblem,
DEAlgorithm, StandardODEProblem, AbstractIntegralProblem,
AbstractDEAlgorithm, StandardODEProblem, AbstractIntegralProblem,
AbstractSensitivityAlgorithm, AbstractODEAlgorithm,
AbstractSDEAlgorithm, AbstractDDEAlgorithm, AbstractDAEAlgorithm,
AbstractSDDEAlgorithm, AbstractRODEAlgorithm, DAEInitializationAlgorithm,
Expand Down
50 changes: 25 additions & 25 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ NO_TSPAN_PROBS = Union{AbstractLinearProblem, AbstractNonlinearProblem,
AbstractIntegralProblem, AbstractSteadyStateProblem,
AbstractJumpProblem}

has_kwargs(_prob::DEProblem) = has_kwargs(typeof(_prob))
has_kwargs(_prob::AbstractDEProblem) = has_kwargs(typeof(_prob))
Base.@pure __has_kwargs(::Type{T}) where {T} = :kwargs fieldnames(T)
has_kwargs(::Type{T}) where {T} = __has_kwargs(T)

Expand Down Expand Up @@ -196,7 +196,7 @@ end
const NON_SOLVER_MESSAGE = """
The arguments to solve are incorrect.
The second argument must be a solver choice, `solve(prob,alg)`
where `alg` is a `<: DEAlgorithm`, e.g. `Tsit5()`.
where `alg` is a `<: AbstractDEAlgorithm`, e.g. `Tsit5()`.
Please double check the arguments being sent to the solver.
Expand Down Expand Up @@ -484,7 +484,7 @@ function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothin
end
end

function init(prob::Union{DEProblem, NonlinearProblem}, args...; sensealg = nothing,
function init(prob::Union{AbstractDEProblem, NonlinearProblem}, args...; sensealg = nothing,
u0 = nothing, p = nothing, kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
Expand All @@ -500,7 +500,7 @@ function init(prob::AbstractJumpProblem, args...; kwargs...)
init_call(prob, args...; kwargs...)
end

function init_up(prob::DEProblem, sensealg, u0, p, args...; kwargs...)
function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...)
alg = extract_alg(args, kwargs, prob.kwargs)
if isnothing(alg) # Default algorithm handling
_prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0,
Expand Down Expand Up @@ -566,7 +566,7 @@ mutable struct NullODEIntegrator{IIP, ProbType, T, SolType, F, P} <:
f::F
p::P
end
function build_null_integrator(prob::DEProblem, args...;
function build_null_integrator(prob::AbstractDEProblem, args...;
kwargs...)
sol = solve(prob, args...; kwargs...)
return NullODEIntegrator{isinplace(prob), typeof(prob), eltype(prob.tspan), typeof(sol),
Expand All @@ -592,7 +592,7 @@ function step!(integ::NullODEIntegrator, dt = nothing, stop_at_tdt = false)
return nothing
end

function build_null_solution(prob::DEProblem, args...;
function build_null_solution(prob::AbstractDEProblem, args...;
saveat = (),
save_everystep = true,
save_on = true,
Expand Down Expand Up @@ -635,7 +635,7 @@ end

"""
```julia
solve(prob::DEProblem, alg::Union{DEAlgorithm,Nothing}; kwargs...)
solve(prob::AbstractDEProblem, alg::Union{AbstractDEAlgorithm,Nothing}; kwargs...)
```
## Arguments
Expand Down Expand Up @@ -914,7 +914,7 @@ the extension to other types is straightforward.
to save size or because the user does not care about the others. Finally, with
`progress = true` you are enabling the progress bar.
"""
function solve(prob::DEProblem, args...; sensealg = nothing,
function solve(prob::AbstractDEProblem, args...; sensealg = nothing,
u0 = nothing, p = nothing, wrap = Val(true), kwargs...)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
Expand Down Expand Up @@ -987,8 +987,8 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing,
end
end

function solve_up(prob::Union{DEProblem, NonlinearProblem}, sensealg, u0, p, args...;
kwargs...)
function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p,
args...; kwargs...)
alg = extract_alg(args, kwargs, prob.kwargs)
if isnothing(alg) # Default algorithm handling
_prob = get_concrete_problem(prob, !(typeof(prob) <: DiscreteProblem); u0 = u0,
Expand Down Expand Up @@ -1070,12 +1070,12 @@ function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...)
prob
end

function solve(prob::PDEProblem, alg::DiffEqBase.DEAlgorithm, args...;
function solve(prob::PDEProblem, alg::AbstractDEAlgorithm, args...;
kwargs...)
solve(prob.prob, alg, args...; kwargs...)
end

function init(prob::PDEProblem, alg::DiffEqBase.DEAlgorithm, args...;
function init(prob::PDEProblem, alg::AbstractDEAlgorithm, args...;
kwargs...)
init(prob.prob, alg, args...; kwargs...)
end
Expand Down Expand Up @@ -1266,27 +1266,27 @@ handle_distribution_u0(_u0) = _u0
eval_u0(u0::Function) = true
eval_u0(u0) = false

function __solve(prob::DEProblem, args...; default_set = false, second_time = false,
function __solve(prob::AbstractDEProblem, args...; default_set = false, second_time = false,
kwargs...)
if second_time
throw(NoDefaultAlgorithmError())
elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, DEAlgorithm})
elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, AbstractDEAlgorithm})
throw(NonSolverError())
else
__solve(prob::DEProblem, nothing, args...; default_set = false, second_time = true,
kwargs...)
__solve(prob::AbstractDEProblem, nothing, args...; default_set = false,
second_time = true, kwargs...)
end
end

function __init(prob::DEProblem, args...; default_set = false, second_time = false,
function __init(prob::AbstractDEProblem, args...; default_set = false, second_time = false,
kwargs...)
if second_time
throw(NoDefaultAlgorithmError())
elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, DEAlgorithm})
elseif length(args) > 0 && !(typeof(args[1]) <: Union{Nothing, AbstractDEAlgorithm})
throw(NonSolverError())
else
__init(prob::DEProblem, nothing, args...; default_set = false, second_time = true,
kwargs...)
__init(prob::AbstractDEProblem, nothing, args...; default_set = false,
second_time = true, kwargs...)
end
end

Expand Down Expand Up @@ -1360,7 +1360,7 @@ Ignores all adjoint definitions (i.e. `sensealg`) and proceeds to do standard
AD through the `solve` functions. Generally only used internally for implementing
discrete sensitivity algorithms.
"""
struct SensitivityADPassThrough <: SciMLBase.DEAlgorithm end
struct SensitivityADPassThrough <: AbstractDEAlgorithm end

function ChainRulesCore.frule(::typeof(solve_up), prob,
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
Expand All @@ -1370,7 +1370,7 @@ function ChainRulesCore.frule(::typeof(solve_up), prob,
kwargs...)
end

function ChainRulesCore.rrule(::typeof(solve_up), prob::SciMLBase.DEProblem,
function ChainRulesCore.rrule(::typeof(solve_up), prob::AbstractDEProblem,
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...;
kwargs...)
Expand All @@ -1382,8 +1382,8 @@ end
### Legacy Dispatches to be Non-Breaking
###

@deprecate concrete_solve(prob::SciMLBase.DEProblem,
alg::Union{SciMLBase.DEAlgorithm, Nothing},
@deprecate concrete_solve(prob::AbstractDEProblem,
alg::Union{AbstractDEAlgorithm, Nothing},
u0 = prob.u0, p = prob.p, args...; kwargs...) solve(prob, alg,
args...;
u0 = u0,
Expand Down Expand Up @@ -1459,7 +1459,7 @@ end
else
nothing
end
elseif solve_args[1] isa DEAlgorithm
elseif solve_args[1] isa AbstractDEAlgorithm
solve_args[1]
else
nothing
Expand Down

0 comments on commit 70f5d22

Please sign in to comment.