From 0ee433f115f1f543102de3ad5fd35deb8e7d364e Mon Sep 17 00:00:00 2001 From: lxvm Date: Sat, 23 Sep 2023 14:46:57 -0400 Subject: [PATCH 01/14] simplify batch integrand_prototype --- src/problems/basic_problems.jl | 22 ++++++++++++++-------- src/scimlfunctions.jl | 16 ++++++++-------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/src/problems/basic_problems.jl b/src/problems/basic_problems.jl index 5bfafe888..57a736ef3 100644 --- a/src/problems/basic_problems.jl +++ b/src/problems/basic_problems.jl @@ -414,8 +414,10 @@ which are `Number`s or `AbstractVector`s with the same geometry as `u`. ### Constructors ``` -IntegralProblem(f,domain,p=NullParameters(); kwargs...) -IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...) +IntegralProblem(f::AbstractIntegralFunction,domain,p=NullParameters(); kwargs...) +IntegralProblem(f::AbstractIntegralFunction,lb,ub,p=NullParameters(); kwargs...) +IntegralProblem(f,domain,p=NullParameters(); nout=nothing, batch=nothing, kwargs...) +IntegralProblem(f,lb,ub,p=NullParameters(); nout=nothing, batch=nothing, kwargs...) ``` - f: the integrand, callable function `y = f(u,p)` for out-of-place (default) or an @@ -424,6 +426,10 @@ IntegralProblem(f,lb,ub,p=NullParameters(); kwargs...) - lb: Either a number or vector of lower bounds. - ub: Either a number or vector of upper bounds. - p: The parameters associated with the problem. +- nout: DEPRECATED (see `IntegralFunction`): length of the vector output of the integrand + (by default the integrand is assumed to be scalar) +- batch: DEPRECATED (see `BatchIntegralFunction`): number of points the integrand can + evaluate simultaneously (by default there is no batching) - kwargs: Keyword arguments copied to the solvers. Additionally, we can supply iip like IntegralProblem{iip}(...) as true or false to declare at @@ -469,19 +475,19 @@ function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...) @warn "`nout` and `batch` keywords are deprecated in favor of inplace `IntegralFunction`s or `BatchIntegralFunction`s. See the updated Integrals.jl documentation for details." end - max_batch = batch === nothing ? 0 : batch g = if isinplace(f, 3) - output_prototype = Vector{Float64}(undef, nout === nothing ? 1 : nout) - if max_batch == 0 + if batch === nothing + output_prototype = nout === nothing ? Array{Float64, 0}(undef) : Vector{Float64}(undef, nout) IntegralFunction(f, output_prototype) else - BatchIntegralFunction(f, output_prototype, max_batch=max_batch) + output_prototype = nout === nothing ? Float64[] : Matrix{Float64}(undef, nout, 0) + BatchIntegralFunction(f, output_prototype, max_batch=batch) end else - if max_batch == 0 + if batch === nothing IntegralFunction(f) else - BatchIntegralFunction(f, max_batch=max_batch) + BatchIntegralFunction(f, max_batch=batch) end end IntegralProblem(g, args...; kwargs...) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 2be746949..151c8c71d 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2361,8 +2361,8 @@ BatchIntegralFunction{iip,specialize}(f, [integrand_prototype]; max_batch=typemax(Int)) ``` Note that only `f` is required, and in the case of inplace integrands a mutable container -`integrand_prototype` to store the result of the integrand of one integrand, without a last -"batching" dimension. +`integrand_prototype` to store a batch of integrand evaluations, with a last "batching" +dimension. The keyword `max_batch` is used to set a soft limit on the number of points to batch at the same time so that memory usage is controlled. @@ -2375,12 +2375,12 @@ assumed to be out-of-place. Out-of-place functions must be of the form ``y = f(u,p)`` and in-place functions of the form ``f(y, u, p)``. Since `f` is allowed to return any type (e.g. real or complex numbers or arrays), in-place functions must provide a container `integrand_prototype` of the right type -for a single integrand evaluation. The integration algorithm will then allocate a ``y`` -array with the same element type as `integrand_prototype` and an additional last "batching" -dimension to store multiple integrand evaluations. In the out-of-place case, the algorithm -may infer the type of ``y`` by passing `f` an empty array of input points. This means ``y`` -is a vector in the out-of-place case, or a matrix/array in the in-place case. The number of -batched points may vary between subsequent calls to `f`. When in-place forms are used, +for ``y``. The only assumption that is enforced is that the last axes of `the `y`` and ``u`` +arrays are the same length and correspond to distinct batched points. The algorithm will +then allocate arrays `similar` to ``y`` to pass to the integrand. Since the algorithm may +vary the number of points to batch, the length of the batching dimension of ``y`` may vary +between subsequent calls to `f`. In the out-of-place case, the algorithm may infer the type +of ``y`` by passing `f` an empty array of input points. When in-place forms are used, in-place array operations may be used by algorithms to reduce allocations. If `integrand_prototype` is not provided, `f` is assumed to be out-of-place. From cbf77969151124f2939237b2c11e282489323b67 Mon Sep 17 00:00:00 2001 From: lxvm Date: Sat, 23 Sep 2023 22:01:22 -0400 Subject: [PATCH 02/14] add note about views --- src/scimlfunctions.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 151c8c71d..6afcbf108 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2379,7 +2379,8 @@ for ``y``. The only assumption that is enforced is that the last axes of `the `y arrays are the same length and correspond to distinct batched points. The algorithm will then allocate arrays `similar` to ``y`` to pass to the integrand. Since the algorithm may vary the number of points to batch, the length of the batching dimension of ``y`` may vary -between subsequent calls to `f`. In the out-of-place case, the algorithm may infer the type +between subsequent calls to `f`. To reduce allocations, views of ``y`` may also be passed to +the integrand. In the out-of-place case, the algorithm may infer the type of ``y`` by passing `f` an empty array of input points. When in-place forms are used, in-place array operations may be used by algorithms to reduce allocations. If `integrand_prototype` is not provided, `f` is assumed to be out-of-place. From 007bc432f28c2462e4db2b7379d2e091f000a5ca Mon Sep 17 00:00:00 2001 From: lxvm Date: Sat, 23 Sep 2023 22:01:35 -0400 Subject: [PATCH 03/14] fix dispatch in IntegralProblem constructor --- src/problems/basic_problems.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/problems/basic_problems.jl b/src/problems/basic_problems.jl index 57a736ef3..0faa97713 100644 --- a/src/problems/basic_problems.jl +++ b/src/problems/basic_problems.jl @@ -467,7 +467,7 @@ function IntegralProblem(f::AbstractIntegralFunction, ub::B, p = NullParameters(); kwargs...) where {B} - IntegralProblem(f, (lb, ub), p; kwargs...) + IntegralProblem{isinplace(f)}(f, (lb, ub), p; kwargs...) end function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...) From 10307a5e98c7bdfd1405be3ab5e1162e402ce1f1 Mon Sep 17 00:00:00 2001 From: lxvm Date: Sat, 28 Oct 2023 10:31:07 -0400 Subject: [PATCH 04/14] support deprecated lb and ub properties of IntegralProblem --- src/problems/basic_problems.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/problems/basic_problems.jl b/src/problems/basic_problems.jl index 0faa97713..b3061cdf2 100644 --- a/src/problems/basic_problems.jl +++ b/src/problems/basic_problems.jl @@ -493,6 +493,19 @@ function IntegralProblem(f, args...; nout = nothing, batch = nothing, kwargs...) IntegralProblem(g, args...; kwargs...) end +function Base.getproperty(prob::IntegralProblem, name::Symbol) + if name === :lb + domain = getfield(prob, :domain) + lb, ub = domain + return lb + elseif name === :ub + domain = getfield(prob, :domain) + lb, ub = domain + return ub + end + return Base.getfield(prob, name) +end + struct QuadratureProblem end @deprecate QuadratureProblem(args...; kwargs...) IntegralProblem(args...; kwargs...) From 3127f57a7dc666931671fc1a4f6d075d13538fe5 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 31 Oct 2023 23:58:40 -0400 Subject: [PATCH 05/14] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3019705b3..626b7ff25 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.5.0" +version = "2.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 585a1aabcb21054bdbceb5a5c48197305d0d582a Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 1 Nov 2023 00:23:52 -0400 Subject: [PATCH 06/14] Update Project.toml --- Project.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Project.toml b/Project.toml index 626b7ff25..3c0453fe2 100644 --- a/Project.toml +++ b/Project.toml @@ -53,14 +53,19 @@ ArrayInterface = "6, 7" ChainRulesCore = "1.16" CommonSolve = "0.2.4" ConstructionBase = "1" +Distributed = "1.6" DocStringExtensions = "0.8, 0.9" EnumX = "1" FillArrays = "1.6" FunctionWrappersWrappers = "0.1.3" IteratorInterfaceExtensions = "^0.1, ^1" +LinearAlgebra = "1.6" +Logging = "1.6" +Markdown = "1.6" PartialFunctions = "1.1" PrecompileTools = "1" Preferences = "1.3" +Printf = "1.6" RCall = "0.13.18" RecipesBase = "0.7.0, 0.8, 1.0" RecursiveArrayTools = "2.33" From e037ef6eebe1dc75168bc21e63869796a47156a5 Mon Sep 17 00:00:00 2001 From: Pepijn de Vos Date: Thu, 2 Nov 2023 17:54:08 +0100 Subject: [PATCH 07/14] Change typeof(x) <: y to x isa y --- src/ensemble/basic_ensemble_solve.jl | 6 ++-- src/ensemble/ensemble_analysis.jl | 32 +++++++++---------- src/ensemble/ensemble_solutions.jl | 12 +++---- src/integrator_interface.jl | 8 ++--- src/interpolation.jl | 38 +++++++++++----------- src/operators/diffeq_operator.jl | 4 +-- src/problems/discrete_problems.jl | 2 +- src/problems/problem_utils.jl | 2 +- src/remake.jl | 2 +- src/scimlfunctions.jl | 16 +++++----- src/solutions/ode_solutions.jl | 6 ++-- src/solutions/rode_solutions.jl | 8 ++--- src/solutions/solution_interface.jl | 48 ++++++++++++++-------------- 13 files changed, 92 insertions(+), 92 deletions(-) diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index 92f7a86d0..8e62ab278 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -98,7 +98,7 @@ function batch_func(i, prob, alg; kwargs...) new_prob = prob.prob_func(_prob, i, iter) rerun = true x = prob.output_func(solve(new_prob, alg; kwargs...), i) - if !(typeof(x) <: Tuple) + if !(x isa Tuple) rerun_warn() _x = (x, false) else @@ -110,7 +110,7 @@ function batch_func(i, prob, alg; kwargs...) _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob new_prob = prob.prob_func(_prob, i, iter) x = prob.output_func(solve(new_prob, alg; kwargs...), i) - if !(typeof(x) <: Tuple) + if !(x isa Tuple) rerun_warn() _x = (x, false) else @@ -163,7 +163,7 @@ function solve_batch(prob, alg, ensemblealg::EnsembleThreads, II, pmap_batch_siz return solve_batch(prob, alg, EnsembleSerial(), II, pmap_batch_size; kwargs...) end - if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1 + if prob.prob isa AbstractJumpProblem && length(II) != 1 probs = [deepcopy(prob.prob) for i in 1:nthreads] else probs = prob.prob diff --git a/src/ensemble/ensemble_analysis.jl b/src/ensemble/ensemble_analysis.jl index 509869848..c68fd5dad 100644 --- a/src/ensemble/ensemble_analysis.jl +++ b/src/ensemble/ensemble_analysis.jl @@ -7,7 +7,7 @@ get_timestep(sim, i) = (getindex(sol, i) for sol in sim) get_timepoint(sim, t) = (sol(t) for sol in sim) function componentwise_vectors_timestep(sim, i) arr = [get_timestep(sim, i)...] - if typeof(arr[1]) <: AbstractArray + if arr[1] isa AbstractArray return vecarr_to_vectors(VectorOfArray(arr)) else return arr @@ -15,7 +15,7 @@ function componentwise_vectors_timestep(sim, i) end function componentwise_vectors_timepoint(sim, t) arr = [get_timepoint(sim, t)...] - if typeof(arr[1]) <: AbstractArray + if arr[1] isa AbstractArray return vecarr_to_vectors(VectorOfArray(arr)) else return arr @@ -123,7 +123,7 @@ end function SciMLBase.EnsembleSummary(sim::SciMLBase.AbstractEnsembleSolution{T, N}, t = sim[1].t; quantiles = [0.05, 0.95]) where {T, N} - if typeof(sim[1]) <: SciMLSolution + if sim[1] isa SciMLSolution m, v = timeseries_point_meanvar(sim, t) med = timeseries_point_median(sim, t) qlow = timeseries_point_quantile(sim, quantiles[1], t) @@ -190,13 +190,13 @@ function componentwise_mean(A) mean = zero(x0) ./ 1 for x in A n += 1 - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) mean .+= x else mean += x end end - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) mean ./= n else mean /= n @@ -215,7 +215,7 @@ function componentwise_meanvar(A; bessel = true) delta2 = zero(x0) ./ 1 for x in A n += 1 - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) delta .= x .- mean mean .+= delta ./ n delta2 .= x .- mean @@ -231,13 +231,13 @@ function componentwise_meanvar(A; bessel = true) return NaN else if bessel - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) M2 .= M2 ./ (n .- 1) else M2 = M2 ./ (n .- 1) end else - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) M2 .= M2 ./ n else M2 = M2 ./ n @@ -257,7 +257,7 @@ function componentwise_meancov(A, B; bessel = true) dx = zero(x0) ./ 1 for (x, y) in zip(A, B) n += 1 - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) dx .= x .- meanx meanx .+= dx ./ n meany .+= (y .- meany) ./ n @@ -273,13 +273,13 @@ function componentwise_meancov(A, B; bessel = true) return NaN else if bessel - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) C .= C ./ (n .- 1) else C = C ./ (n .- 1) end else - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) C .= C ./ n else C = C ./ n @@ -293,7 +293,7 @@ function componentwise_meancor(A, B; bessel = true) mx, my, cov = componentwise_meancov(A, B; bessel = bessel) mx, vx = componentwise_meanvar(A; bessel = bessel) my, vy = componentwise_meanvar(B; bessel = bessel) - if typeof(vx) <: AbstractArray + if vx isa AbstractArray vx .= sqrt.(vx) vy .= sqrt.(vy) else @@ -316,7 +316,7 @@ function componentwise_weighted_meancov(A, B, W; weight_type = :reliability) dx = zero(x0) ./ 1 for (x, y, w) in zip(A, B, W) n += 1 - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) wsum .+= w wsum2 .+= w .* w dx .= x .- meanx @@ -336,19 +336,19 @@ function componentwise_weighted_meancov(A, B, W; weight_type = :reliability) return NaN else if weight_type == :population - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) C .= C ./ wsum else C = C ./ wsum end elseif weight_type == :reliability - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) C .= C ./ (wsum .- wsum2 ./ wsum) else C = C ./ (wsum .- wsum2 ./ wsum) end elseif weight_type == :frequency - if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray) + if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray) C .= C ./ (wsum .- 1) else C = C ./ (wsum .- 1) diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 265918f3e..83e9a6083 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -137,7 +137,7 @@ end ### Plot Recipes @recipe function f(sim::AbstractEnsembleSolution; - zcolors = typeof(sim.u) <: AbstractArray ? fill(nothing, length(sim.u)) : + zcolors = sim.u isa AbstractArray ? fill(nothing, length(sim.u)) : nothing, trajectories = eachindex(sim)) for i in trajectories @@ -154,16 +154,16 @@ end end @recipe function f(sim::EnsembleSummary; - trajectories = typeof(sim.u[1]) <: AbstractArray ? eachindex(sim.u[1]) : + trajectories = sim.u[1] isa AbstractArray ? eachindex(sim.u[1]) : 1, error_style = :ribbon, ci_type = :quantile) if ci_type == :SEM - if typeof(sim.u[1]) <: AbstractArray + if sim.u[1] isa AbstractArray u = vecarr_to_vectors(sim.u) else u = [sim.u.u] end - if typeof(sim.u[1]) <: AbstractArray + if sim.u[1] isa AbstractArray ci_low = vecarr_to_vectors(VectorOfArray([sqrt.(sim.v[i] / sim.num_monte) .* 1.96 for i in 1:length(sim.v)])) ci_high = ci_low @@ -173,12 +173,12 @@ end ci_high = ci_low end elseif ci_type == :quantile - if typeof(sim.med[1]) <: AbstractArray + if sim.med[1] isa AbstractArray u = vecarr_to_vectors(sim.med) else u = [sim.med.u] end - if typeof(sim.u[1]) <: AbstractArray + if sim.u[1] isa AbstractArray ci_low = u - vecarr_to_vectors(sim.qlow) ci_high = vecarr_to_vectors(sim.qhigh) - u else diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 52e6edb28..7046398a7 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -750,7 +750,7 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts) @recipe function f(integrator::DEIntegrator; denseplot = (integrator.opts.calck || - typeof(integrator) <: AbstractSDEIntegrator) && + integrator isa AbstractSDEIntegrator) && integrator.iter > 0, plotdensity = 10, plot_analytic = false, vars = nothing, idxs = nothing) @@ -797,7 +797,7 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts) else # just get values if x[j] == 0 push!(plot_vecs[j - 1], integrator.t) - elseif x[j] == 1 && !(typeof(integrator.u) <: AbstractArray) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) push!(plot_vecs[j - 1], integrator.u) else push!(plot_vecs[j - 1], integrator.u[x[j]]) @@ -816,7 +816,7 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts) else # Just get values if x[j] == 0 push!(plot_vecs[j], integrator.t) - elseif x[j] == 1 && !(typeof(integrator.u) <: AbstractArray) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) push!(plot_vecs[j], integrator.sol.prob.f(Val{:analytic}, integrator.t, integrator.sol[1])) @@ -840,7 +840,7 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts) end # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if typeof(idxs) <: Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) xlabel --> idxs[1] ylabel --> idxs[2] if length(idxs) > 2 diff --git a/src/interpolation.jl b/src/interpolation.jl index e64d91c35..d212ac82f 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -81,7 +81,7 @@ end continuity::Symbol = :left) where {I, D} t = id.t u = id.u - typeof(id) <: HermiteInterpolation && (du = id.du) + id isa HermiteInterpolation && (du = id.du) tdir = sign(t[end] - t[1]) idx = sortperm(tvals, rev = tdir < 0) i = 2 # Start the search thinking it's between t[1] and t[2] @@ -91,9 +91,9 @@ end error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") tdir * tvals[idx[1]] < tdir * t[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.") - if typeof(idxs) <: Number + if idxs isa Number vals = Vector{eltype(first(u))}(undef, length(tvals)) - elseif typeof(idxs) <: AbstractVector + elseif idxs isa AbstractVector vals = Vector{Vector{eltype(first(u))}}(undef, length(tvals)) else vals = Vector{eltype(u)}(undef, length(tvals)) @@ -101,7 +101,7 @@ end for j in idx tval = tvals[j] i = searchsortedfirst(@view(t[i:end]), tval, rev = tdir < 0) + i - 1 # It's in the interval t[i-1] to t[i] - avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual + avoid_constant_ends = deriv != Val{0} #|| tval isa ForwardDiff.Dual avoid_constant_ends && i == 1 && (i += 1) if !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value! if idxs === nothing @@ -118,11 +118,11 @@ end vals[j] = u[k][idxs] end else - typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] Θ = (tval - t[i - 1]) / dt idxs_internal = idxs - if typeof(id) <: HermiteInterpolation + if id isa HermiteInterpolation vals[j] = interpolant(Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i], idxs_internal, deriv) else @@ -143,7 +143,7 @@ times t (sorted), with values u and derivatives ks continuity::Symbol = :left) where {I, D} t = id.t u = id.u - typeof(id) <: HermiteInterpolation && (du = id.du) + id isa HermiteInterpolation && (du = id.du) tdir = sign(t[end] - t[1]) idx = sortperm(tvals, rev = tdir < 0) i = 2 # Start the search thinking it's between t[1] and t[2] @@ -156,7 +156,7 @@ times t (sorted), with values u and derivatives ks for j in idx tval = tvals[j] i = searchsortedfirst(@view(t[i:end]), tval, rev = tdir < 0) + i - 1 # It's in the interval t[i-1] to t[i] - avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual + avoid_constant_ends = deriv != Val{0} #|| tval isa ForwardDiff.Dual avoid_constant_ends && i == 1 && (i += 1) if !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value! if idxs === nothing @@ -173,19 +173,19 @@ times t (sorted), with values u and derivatives ks vals[j] = u[k][idxs] end else - typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] Θ = (tval - t[i - 1]) / dt idxs_internal = idxs if eltype(u) <: Union{AbstractArray, ArrayPartition} - if typeof(id) <: HermiteInterpolation + if id isa HermiteInterpolation interpolant!(vals[j], Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i], idxs_internal, deriv) else interpolant!(vals[j], Θ, id, dt, u[i - 1], u[i], idxs_internal, deriv) end else - if typeof(id) <: HermiteInterpolation + if id isa HermiteInterpolation vals[j] = interpolant(Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i], idxs_internal, deriv) else @@ -206,7 +206,7 @@ times t (sorted), with values u and derivatives ks continuity::Symbol = :left) where {I, D} t = id.t u = id.u - typeof(id) <: HermiteInterpolation && (du = id.du) + id isa HermiteInterpolation && (du = id.du) tdir = sign(t[end] - t[1]) t[end] == t[1] && tval != t[end] && error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") @@ -215,7 +215,7 @@ times t (sorted), with values u and derivatives ks tdir * tval < tdir * t[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.") @inbounds i = searchsortedfirst(t, tval, rev = tdir < 0) # It's in the interval t[i-1] to t[i] - avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual + avoid_constant_ends = deriv != Val{0} #|| tval isa ForwardDiff.Dual avoid_constant_ends && i == 1 && (i += 1) if !avoid_constant_ends && t[i] == tval lasti = lastindex(t) @@ -232,11 +232,11 @@ times t (sorted), with values u and derivatives ks val = u[i - 1][idxs] end else - typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] Θ = (tval - t[i - 1]) / dt idxs_internal = idxs - if typeof(id) <: HermiteInterpolation + if id isa HermiteInterpolation val = interpolant(Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i], idxs_internal, deriv) else @@ -256,7 +256,7 @@ times t (sorted), with values u and derivatives ks continuity::Symbol = :left) where {I, D} t = id.t u = id.u - typeof(id) <: HermiteInterpolation && (du = id.du) + id isa HermiteInterpolation && (du = id.du) tdir = sign(t[end] - t[1]) t[end] == t[1] && tval != t[end] && error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") @@ -265,7 +265,7 @@ times t (sorted), with values u and derivatives ks tdir * tval < tdir * t[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.") @inbounds i = searchsortedfirst(t, tval, rev = tdir < 0) # It's in the interval t[i-1] to t[i] - avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual + avoid_constant_ends = deriv != Val{0} #|| tval isa ForwardDiff.Dual avoid_constant_ends && i == 1 && (i += 1) if !avoid_constant_ends && t[i] == tval lasti = lastindex(t) @@ -282,11 +282,11 @@ times t (sorted), with values u and derivatives ks copy!(out, u[i - 1][idxs]) end else - typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) dt = t[i] - t[i - 1] Θ = (tval - t[i - 1]) / dt idxs_internal = idxs - if typeof(id) <: HermiteInterpolation + if id isa HermiteInterpolation interpolant!(out, Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i], idxs_internal, deriv) else diff --git a/src/operators/diffeq_operator.jl b/src/operators/diffeq_operator.jl index e508c7022..a97063b83 100644 --- a/src/operators/diffeq_operator.jl +++ b/src/operators/diffeq_operator.jl @@ -37,7 +37,7 @@ function (L::AffineDiffEqOperator)(u, p, t::Number) update_coefficients!(L, u, p, t) du = sum(A * u for A in L.As) for B in L.Bs - if typeof(B) <: Union{Number, AbstractArray} + if B isa Union{Number, AbstractArray} du .+= B else du .+= B(t) @@ -58,7 +58,7 @@ function (L::AffineDiffEqOperator)(du, u, p, t::Number) du .+= du_cache end for B in L.Bs - if typeof(B) <: Union{Number, AbstractArray} + if B isa Union{Number, AbstractArray} du .+= B else B(du_cache, t) diff --git a/src/problems/discrete_problems.jl b/src/problems/discrete_problems.jl index 9d6f643e9..a2ae622fb 100644 --- a/src/problems/discrete_problems.jl +++ b/src/problems/discrete_problems.jl @@ -146,7 +146,7 @@ Define a discrete problem with the identity map. """ function DiscreteProblem(u0::Union{AbstractArray, Number}, tspan::Tuple, p = NullParameters(); kwargs...) - iip = typeof(u0) <: AbstractArray + iip = u0 isa AbstractArray if iip f = DISCRETE_INPLACE_DEFAULT else diff --git a/src/problems/problem_utils.jl b/src/problems/problem_utils.jl index 936eab5b7..c72a451f3 100644 --- a/src/problems/problem_utils.jl +++ b/src/problems/problem_utils.jl @@ -23,7 +23,7 @@ function Base.summary(io::IO, prob::AbstractDEProblem) type_color, typeof(prob.u0), no_color, " and tType ", type_color, - typeof(prob.tspan) <: Function ? + prob.tspan isa Function ? "Unknown" : (prob.tspan === nothing ? "Nothing" : typeof(prob.tspan[1])), no_color, ". In-place: ", diff --git a/src/remake.jl b/src/remake.jl index 7bbc6de6a..99208f334 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -310,7 +310,7 @@ end # overloaded in MTK to intercept symbolic remake function process_p_u0_symbolic(prob, p, u0) - if typeof(prob) <: Union{AbstractDEProblem, OptimizationProblem, NonlinearProblem} + if prob isa Union{AbstractDEProblem, OptimizationProblem, NonlinearProblem} throw(ArgumentError("Please load `ModelingToolkit.jl` in order to support symbolic remake.")) else throw(ArgumentError("Symbolic remake for $(typeof(prob)) is currently not supported, consider opening an issue.")) diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index c74808df2..ce7d12627 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -2475,7 +2475,7 @@ function ODEFunction{iip, specialize}(f; sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, } - if mass_matrix === I && typeof(f) <: Tuple + if mass_matrix === I && f isa Tuple mass_matrix = ((I for i in 1:length(f))...,) end @@ -2690,7 +2690,7 @@ end f1 = ODEFunction(f1) f2 = ODEFunction(f2) - if !(typeof(f1) <: AbstractSciMLOperator || typeof(f1.f) <: AbstractSciMLOperator) && + if !(f1 isa AbstractSciMLOperator || f1.f isa AbstractSciMLOperator) && isinplace(f1) != isinplace(f2) throw(NonconformingFunctionsError(["f2"])) end @@ -2772,7 +2772,7 @@ SplitFunction(f::SplitFunction; kwargs...) = f jvp, vjp, jac_prototype, sparsity, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, colorvec, sys) where {iip} - f1 = typeof(f1) <: AbstractSciMLOperator ? f1 : ODEFunction(f1) + f1 = f1 isa AbstractSciMLOperator ? f1 : ODEFunction(f1) f2 = ODEFunction(f2) if isinplace(f1) != isinplace(f2) @@ -3132,7 +3132,7 @@ SDEFunction(f::SDEFunction; kwargs...) = f jvp, vjp, jac_prototype, Wfact, Wfact_t, paramjac, observed, syms, indepsym, paramsyms, colorvec, sys) - f1 = typeof(f1) <: AbstractSciMLOperator ? f1 : SDEFunction(f1) + f1 = f1 isa AbstractSciMLOperator ? f1 : SDEFunction(f1) f2 = SDEFunction(f2) SplitFunction{isinplace(f2), typeof(f1), typeof(f2), typeof(g), typeof(mass_matrix), typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp), @@ -3218,7 +3218,7 @@ SplitSDEFunction(f::SplitSDEFunction; kwargs...) = f jac_prototype, Wfact, Wfact_t, paramjac, syms, indepsym, paramsyms, observed, colorvec, sys) - f1 = typeof(f1) <: AbstractSciMLOperator ? f1 : SDEFunction(f1) + f1 = f1 isa AbstractSciMLOperator ? f1 : SDEFunction(f1) f2 = SDEFunction(f2) DynamicalSDEFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2), typeof(g), typeof(mass_matrix), @@ -3578,7 +3578,7 @@ DDEFunction(f::DDEFunction; kwargs...) = f paramjac, syms, indepsym, paramsyms, observed, colorvec) where {iip} - f1 = typeof(f1) <: AbstractSciMLOperator ? f1 : DDEFunction(f1) + f1 = f1 isa AbstractSciMLOperator ? f1 : DDEFunction(f1) f2 = DDEFunction(f2) DynamicalDDEFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2), typeof(mass_matrix), @@ -3780,7 +3780,7 @@ function NonlinearFunction{iip, specialize}(f; resid_prototype = __has_resid_prototype(f) ? f.resid_prototype : nothing) where { iip, specialize} - if mass_matrix === I && typeof(f) <: Tuple + if mass_matrix === I && f isa Tuple mass_matrix = ((I for i in 1:length(f))...,) end @@ -3965,7 +3965,7 @@ function BVPFunction{iip, specialize, twopoint}(f, bc; colorvec = __has_colorvec(f) ? f.colorvec : nothing, bccolorvec = __has_colorvec(bc) ? bc.colorvec : nothing, sys = __has_sys(f) ? f.sys : nothing) where {iip, specialize, twopoint} - if mass_matrix === I && typeof(f) <: Tuple + if mass_matrix === I && f isa Tuple mass_matrix = ((I for i in 1:length(f))...,) end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 1b65b8e3a..d5d7aa005 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -188,7 +188,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, N = length((size(prob.u0)..., length(u))) end - if typeof(prob.f) <: Tuple + if prob.f isa Tuple f = prob.f[1] else f = prob.f @@ -346,9 +346,9 @@ end function sensitivity_solution(sol::ODESolution, u, t) T = eltype(eltype(u)) N = length((size(sol.prob.u0)..., length(u))) - interp = if typeof(sol.interp) <: LinearInterpolation + interp = if sol.interp isa LinearInterpolation LinearInterpolation(t, u) - elseif typeof(sol.interp) <: ConstantInterpolation + elseif sol.interp isa ConstantInterpolation ConstantInterpolation(t, u) else SensitivityInterpolation(t, u) diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index 6597ac185..350c3c823 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -80,7 +80,7 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, T = eltype(eltype(u)) N = length((size(prob.u0)..., length(u))) - if typeof(prob.f) <: Tuple + if prob.f isa Tuple f = prob.f[1] else f = prob.f @@ -134,7 +134,7 @@ end function calculate_solution_errors!(sol::AbstractRODESolution; fill_uanalytic = true, timeseries_errors = true, dense_errors = true) - if typeof(sol.prob.f) <: Tuple + if sol.prob.f isa Tuple f = sol.prob.f[1] else f = sol.prob.f @@ -232,9 +232,9 @@ end function sensitivity_solution(sol::AbstractRODESolution, u, t) T = eltype(eltype(u)) N = length((size(sol.prob.u0)..., length(u))) - interp = if typeof(sol.interp) <: LinearInterpolation + interp = if sol.interp isa LinearInterpolation LinearInterpolation(t, u) - elseif typeof(sol.interp) <: ConstantInterpolation + elseif sol.interp isa ConstantInterpolation ConstantInterpolation(t, u) else SensitivityInterpolation(t, u) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 900fecfea..ed0f5f88c 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -233,13 +233,13 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug @recipe function f(sol::AbstractTimeseriesSolution; plot_analytic = false, denseplot = (sol.dense || - typeof(sol.prob) <: AbstractDiscreteProblem) && - !(typeof(sol) <: AbstractRODESolution) && + sol.prob isa AbstractDiscreteProblem) && + !(sol isa AbstractRODESolution) && !(hasfield(typeof(sol), :interp) && - typeof(sol.interp) <: SensitivityInterpolation), + sol.interp isa SensitivityInterpolation), plotdensity = min(Int(1e5), sol.tslocation == 0 ? - (typeof(sol.prob) <: AbstractDiscreteProblem ? + (sol.prob isa AbstractDiscreteProblem ? max(1000, 100 * length(sol)) : max(1000, 10 * length(sol))) : 1000 * sol.tslocation), @@ -271,7 +271,7 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug seriestype --> :path # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if typeof(idxs) <: Tuple && (issymbollike(idxs[1]) && issymbollike(idxs[2])) + if idxs isa Tuple && (issymbollike(idxs[1]) && issymbollike(idxs[2])) val = issymbollike(int_vars[1][2]) ? String(Symbol(int_vars[1][2])) : strs[int_vars[1][2]] xguide --> val @@ -330,7 +330,7 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug # Analytical solutions do not save enough information to have a good idea # of the axis ahead of time # Only set axis for animations - if sol.tslocation != 0 && !(typeof(sol) <: AbstractAnalyticalSolution) + if sol.tslocation != 0 && !(sol isa AbstractAnalyticalSolution) if all(getindex.(int_vars, 1) .== DEFAULT_PLOT_FUNC) mins = minimum(sol[int_vars[1][3], :]) maxs = maximum(sol[int_vars[1][3], :]) @@ -381,9 +381,9 @@ function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, axi if denseplot # Generate the points from the plot from dense function - if tspan === nothing && !(typeof(sol) <: AbstractAnalyticalSolution) + if tspan === nothing && !(sol isa AbstractAnalyticalSolution) plott = collect(densetspacer(sol.t[start_idx], sol.t[end_idx], plotdensity)) - elseif typeof(sol) <: AbstractAnalyticalSolution + elseif sol isa AbstractAnalyticalSolution tspan = sol.prob.tspan plott = collect(densetspacer(tspan[1], tspan[end], plotdensity)) else @@ -391,7 +391,7 @@ function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, axi end plot_timeseries = sol(plott) if plot_analytic - if typeof(sol.prob.f) <: Tuple + if sol.prob.f isa Tuple plot_analytic_timeseries = [sol.prob.f[1].analytic(sol.prob.u0, sol.prob.p, t) for t in plott] else @@ -442,7 +442,7 @@ function interpret_vars(vars, sol, syms) # Do syms conversion tmp_vars = [] for var in vars - if typeof(var) <: Union{Tuple, AbstractArray} #eltype(var) <: Symbol # Some kind of iterable + if var isa Union{Tuple, AbstractArray} #eltype(var) <: Symbol # Some kind of iterable tmp = [] for x in var if issymbollike(x) @@ -454,7 +454,7 @@ function interpret_vars(vars, sol, syms) push!(tmp, x) end end - if typeof(var) <: Tuple + if var isa Tuple var_int = tuple(tmp...) else var_int = tmp @@ -477,7 +477,7 @@ function interpret_vars(vars, sol, syms) end push!(tmp_vars, var_int) end - if typeof(vars) <: Tuple + if vars isa Tuple vars = tuple(tmp_vars...) else vars = tmp_vars @@ -486,23 +486,23 @@ function interpret_vars(vars, sol, syms) if vars === nothing # Default: plot all timeseries - if typeof(sol[1]) <: Union{Tuple, AbstractArray} + if sol[1] isa Union{Tuple, AbstractArray} vars = collect((DEFAULT_PLOT_FUNC, 0, i) for i in plot_indices(sol[1])) else vars = [(DEFAULT_PLOT_FUNC, 0, 1)] end end - if typeof(vars) <: Base.Integer + if vars isa Base.Integer vars = [(DEFAULT_PLOT_FUNC, 0, vars)] end - if typeof(vars) <: AbstractArray + if vars isa AbstractArray # If list given, its elements should be tuples, or we assume x = time tmp = Tuple[] for x in vars - if typeof(x) <: Tuple - if typeof(x[1]) <: Int + if x isa Tuple + if x[1] isa Int push!(tmp, tuple(DEFAULT_PLOT_FUNC, x...)) else push!(tmp, x) @@ -514,10 +514,10 @@ function interpret_vars(vars, sol, syms) vars = tmp end - if typeof(vars) <: Tuple + if vars isa Tuple # If tuple given... - if typeof(vars[end - 1]) <: AbstractArray - if typeof(vars[end]) <: AbstractArray + if vars[end - 1] isa AbstractArray + if vars[end] isa AbstractArray # If both axes are lists we zip (will fail if different lengths) vars = collect(zip([DEFAULT_PLOT_FUNC for i in eachindex(vars[end - 1])], vars[end - 1], vars[end])) @@ -526,12 +526,12 @@ function interpret_vars(vars, sol, syms) vars = [(DEFAULT_PLOT_FUNC, x, vars[end]) for x in vars[end - 1]] end else - if typeof(vars[2]) <: AbstractArray + if vars[2] isa AbstractArray # Just the y axis is a list vars = [(DEFAULT_PLOT_FUNC, vars[end - 1], y) for y in vars[end]] else # Both axes are numbers - if typeof(vars[1]) <: Int || issymbollike(vars[1]) + if vars[1] isa Int || issymbollike(vars[1]) vars = [tuple(DEFAULT_PLOT_FUNC, vars...)] else vars = [vars] @@ -564,7 +564,7 @@ function add_labels!(labels, x, dims, sol, strs) lys[end] = chop(lys[end]) # Take off the last comma if !issymbollike(x[2]) && x[2] == 0 && dims == 3 # if there are no dependence in syms, then we add "(t)" - if strs !== nothing && (typeof(x[3]) <: Int && endswith(strs[x[3]], r"(.*)")) || + if strs !== nothing && (x[3] isa Int && endswith(strs[x[3]], r"(.*)")) || (issymbollike(x[3]) && endswith(string(x[3]), r"(.*)")) tmp_lab = "$(lys...)" else @@ -627,7 +627,7 @@ function u_n(timeseries::AbstractArray, n::Int, sol, plott, plot_timeseries) # Returns the nth variable from the timeseries, t if n == 0 if n == 0 return plott - elseif n == 1 && !(typeof(sol[1]) <: Union{AbstractArray, ArrayPartition}) + elseif n == 1 && !(sol[1] isa Union{AbstractArray, ArrayPartition}) return timeseries else tmp = Vector{eltype(sol[1])}(undef, length(plot_timeseries)) From 6da7acff3dc86785d01f9ab789533730f004b365 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 05:04:24 -0400 Subject: [PATCH 08/14] Move over the rest of pirating rules Will require a DiffEqBase update --- Project.toml | 7 +- .../SciMLBaseChainRulesCoreExt.jl | 62 +++++++ ext/SciMLBaseZygoteExt.jl | 160 ++++++++++++++++-- src/SciMLBase.jl | 2 - src/solutions/zygote.jl | 22 --- 5 files changed, 216 insertions(+), 37 deletions(-) rename src/solutions/chainrules.jl => ext/SciMLBaseChainRulesCoreExt.jl (57%) delete mode 100644 src/solutions/zygote.jl diff --git a/Project.toml b/Project.toml index 3c0453fe2..f3f63e5ab 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "2.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -31,9 +30,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" @@ -41,6 +40,7 @@ RCall = "6f49c342-dc21-5d91-9882-a32aef131414" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +SciMLBaseChainRulesCoreExt = "ChainRulesCore" SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" @@ -78,11 +78,12 @@ Statistics = "1" SymbolicIndexingInterface = "0.2" Tables = "1" TruncatedStacktraces = "1" -ZygoteRules = "0.2" +Zygote = "0.6" julia = "1.6" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" diff --git a/src/solutions/chainrules.jl b/ext/SciMLBaseChainRulesCoreExt.jl similarity index 57% rename from src/solutions/chainrules.jl rename to ext/SciMLBaseChainRulesCoreExt.jl index 899f58f19..e59bc55d6 100644 --- a/src/solutions/chainrules.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -1,3 +1,8 @@ +module SciMLBaseChainRulesCoreExt + +import ChainRulesCore +import ChainRulesCore: NoTangent, @non_differentiable + function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{ >:ChainRulesCore.HasReverseMode, }, @@ -70,3 +75,60 @@ function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym) end VA[sym], ODESolution_getindex_pullback end + +function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...) + function ODEProblemAdjoint(ȳ) + (NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type) + end + + ODEProblem(args...; kwargs...), ODEProblemAdjoint +end + +function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...) + function SDEProblemAdjoint(ȳ) + (NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type) + end + + SDEProblem(args...; kwargs...), SDEProblemAdjoint +end + +function ChainRulesCore.rrule(::Type{ + <:ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, + T11, T12, + }}, u, + args...) where {T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, + T12} + function ODESolutionAdjoint(ȳ) + (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) + end + + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), + ODESolutionAdjoint +end + +function ChainRulesCore.rrule(::Type{ + <:ODESolution{uType, tType, isinplace, P, NP, F, G, K, + ND, + }}, u, + args...) where {uType, tType, isinplace, P, NP, F, G, K, ND} + function SDESolutionAdjoint(ȳ) + (NoTangent(), ȳ, ntuple(_ -> NoTangent(), length(args))...) + end + + SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint +end + +function ChainRulesCore.rrule(::DiffEqBase.EnsembleSolution, sim, time, converged) + out = EnsembleSolution(sim, time, converged) + function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} + arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] + for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]] + (NoTangent(), EnsembleSolution(arrarr, 0.0, true), NoTangent(), NoTangent()) + end + function EnsembleSolution_adjoint(p̄::EnsembleSolution) + (NoTangent(), p̄, NoTangent(), NoTangent()) + end + out, EnsembleSolution_adjoint +end + +end \ No newline at end of file diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index a401e34f9..2070345dd 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -1,9 +1,10 @@ module SciMLBaseZygoteExt -using Zygote: pullback +using Zygote +using Zygote: pullback, ZygoteRules using ZygoteRules: @adjoint -import ZygoteRules -using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved +using SciMLBase +using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt @@ -30,7 +31,7 @@ end @adjoint function getindex(VA::ODESolution, sym, j::Int) function ODESolution_getindex_pullback(Δ) i = issymbollike(sym) ? sym_to_index(sym, VA) : sym - du, dprob = if i === nothing + du, dprob = if i ==== nothing getter = getobserved(VA) grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] @@ -56,15 +57,15 @@ end VA[sym, j], ODESolution_getindex_pullback end -ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats) - out = EnsembleSolution(sim, time, converged) +@adjoint function EnsembleSolution(sim, time, converged, stats) + out = EnsembleSolution(sim, time, converged, stats) function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]] - (EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing) + (EnsembleSolution(arrarr, 0.0, true, stats), nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1}) - (EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing) + (EnsembleSolution(p̄, 0.0, true, stats), nothing, nothing, nothing) end function EnsembleSolution_adjoint(p̄::EnsembleSolution) (p̄, nothing, nothing, nothing) @@ -72,9 +73,148 @@ ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats) out, EnsembleSolution_adjoint end -ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, +@adjoint function getindex(VA::ODESolution, i::Int) + function ODESolution_getindex_pullback(Δ) + Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x))) + for (x, j) in zip(VA.u, 1:length(VA))] + (Δ′, nothing) + end + VA[i], ODESolution_getindex_pullback +end + +@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) - sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),) + sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),) +end + +@adjoint function getindex(VA::ODESolution, sym) + function ODESolution_getindex_pullback(Δ) + i = issymbollike(sym) ? sym_to_index(sym, VA) : sym + if i ==== nothing + throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) + else + Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] + for (x, j) in zip(VA.u, 1:length(VA))] + (Δ′, nothing) + end + end + VA[sym], ODESolution_getindex_pullback +end + +@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12 + }(u, + args...) where {T1, T2, T3, T4, T5, T6, T7, T8, + T9, T10, T11, T12} + function ODESolutionAdjoint(ȳ) + (ȳ, ntuple(_ -> nothing, length(args))...) + end + + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12}(u, args...), + ODESolutionAdjoint +end + +@adjoint function SDEProblem{uType, tType, isinplace, P, NP, F, G, K, ND}(u, + args...) where + {uType, tType, isinplace, P, NP, F, G, K, ND} + function SDESolutionAdjoint(ȳ) + (ȳ, ntuple(_ -> nothing, length(args))...) + end + + SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint +end + +@adjoint function NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, + args...) where { + T, + N, + uType, + R, + P, + A, + O, + uType2, +} + function NonlinearSolutionAdjoint(ȳ) + (ȳ, ntuple(_ -> nothing, length(args))...) + end + NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint +end + +@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.prob.u0) + _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) + (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) + end + sol.u, solu_adjoint +end + +@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.prob.u0) + _Δ = @. ifelse(Δ === nothing, zerou, Δ) + (DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),) + end + sol.u, solu_adjoint +end + +@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution, + ::Val{:u}) + function solu_adjoint(Δ) + zerou = zero(sol.u) + _Δ = @. ifelse(Δ === nothing, zerou, Δ) + (DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),) + end + sol.u, solu_adjoint +end + +function ∇tmap(cx, f, args...) + ys_and_backs = SciMLBase.tmap((args...) -> Zygote._pullback(cx, f, args...), args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> (NoTangent(), NoTangent()) + else + ys, backs = Zygote.unzip(ys_and_backs) + function ∇tmap_internal(Δ) + Δf_and_args_zipped = SciMLBase.tmap((f, δ) -> f(δ), backs, Δ) + Δf_and_args = Zygote.unzip(Δf_and_args_zipped) + Δf = reduce(Zygote.accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + ys, ∇tmap_internal + end +end + +function ∇responsible_map(cx, f, args...) + ys_and_backs = SciMLBase.responsible_map((args...) -> Zygote._pullback(cx, f, args...), + args...) + if isempty(ys_and_backs) + ys_and_backs, _ -> (NoTangent(), NoTangent()) + else + ys, backs = Zygote.unzip(ys_and_backs) + ys, + function ∇responsible_map_internal(Δ) + # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful. + Δf_and_args_zipped = SciMLBase.responsible_map((f, δ) -> f(δ), + Zygote._tryreverse(SciMLBase.responsible_map, + backs, Δ)...) + Δf_and_args = Zygote.unzip(Zygote._tryreverse(SciMLBase.responsible_map, + Δf_and_args_zipped)) + Δf = reduce(Zygote.accum, Δf_and_args[1]) + (Δf, Δf_and_args[2:end]...) + end + end +end + +@adjoint function SciMLBase.tmap(f, args::Union{AbstractArray, Tuple}...) + ∇tmap(__context__, f, args...) +end + +@adjoint function SciMLBase.responsible_map(f, + args::Union{AbstractArray, Tuple + }...) + ∇responsible_map(__context__, f, args...) end end diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 72c887d77..93b809bd1 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,8 +22,6 @@ import RuntimeGeneratedFunctions import EnumX import TruncatedStacktraces import ADTypes: AbstractADType -import ChainRulesCore -import ZygoteRules: @adjoint import FillArrays using Reexport diff --git a/src/solutions/zygote.jl b/src/solutions/zygote.jl deleted file mode 100644 index d41d07e0f..000000000 --- a/src/solutions/zygote.jl +++ /dev/null @@ -1,22 +0,0 @@ -@adjoint function getindex(VA::ODESolution, i::Int) - function ODESolution_getindex_pullback(Δ) - Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x))) - for (x, j) in zip(VA.u, 1:length(VA))] - (Δ′, nothing) - end - VA[i], ODESolution_getindex_pullback -end - -@adjoint function getindex(VA::ODESolution, sym) - function ODESolution_getindex_pullback(Δ) - i = issymbollike(sym) ? sym_to_index(sym, VA) : sym - if i === nothing - throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) - else - Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] - for (x, j) in zip(VA.u, 1:length(VA))] - (Δ′, nothing) - end - end - VA[sym], ODESolution_getindex_pullback -end From 68bf0d06ad70c6ad79ce4e4a809817dfa90541f5 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 05:09:33 -0400 Subject: [PATCH 09/14] remove zygote file --- src/SciMLBase.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 93b809bd1..b7401404b 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -714,7 +714,6 @@ include("solutions/optimization_solutions.jl") include("solutions/dae_solutions.jl") include("solutions/pde_solutions.jl") include("solutions/solution_interface.jl") -include("solutions/zygote.jl") include("ensemble/ensemble_solutions.jl") include("ensemble/ensemble_problems.jl") From b960eaed6058a6ed57b97546796fc0911ed65771 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 05:44:29 -0400 Subject: [PATCH 10/14] fix parsing --- ext/SciMLBaseZygoteExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 2070345dd..09ce9bdb5 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -31,7 +31,7 @@ end @adjoint function getindex(VA::ODESolution, sym, j::Int) function ODESolution_getindex_pullback(Δ) i = issymbollike(sym) ? sym_to_index(sym, VA) : sym - du, dprob = if i ==== nothing + du, dprob = if i === nothing getter = getobserved(VA) grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] @@ -90,7 +90,7 @@ end @adjoint function getindex(VA::ODESolution, sym) function ODESolution_getindex_pullback(Δ) i = issymbollike(sym) ? sym_to_index(sym, VA) : sym - if i ==== nothing + if i === nothing throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) else Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] From db00ab887cd34c4c918dd95069d0a32d8c141b09 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 05:46:46 -0400 Subject: [PATCH 11/14] version bump --- .github/workflows/Downstream.yml | 2 +- Project.toml | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 4afc74b32..7c2733d03 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - julia-version: [1,1.6] + julia-version: [1] os: [ubuntu-latest] package: - {user: SciML, repo: DelayDiffEq.jl, group: Interface} diff --git a/Project.toml b/Project.toml index f3f63e5ab..c94357eb4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.6.0" +version = "2.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -53,19 +53,19 @@ ArrayInterface = "6, 7" ChainRulesCore = "1.16" CommonSolve = "0.2.4" ConstructionBase = "1" -Distributed = "1.6" +Distributed = "1.9" DocStringExtensions = "0.8, 0.9" EnumX = "1" FillArrays = "1.6" FunctionWrappersWrappers = "0.1.3" IteratorInterfaceExtensions = "^0.1, ^1" -LinearAlgebra = "1.6" -Logging = "1.6" -Markdown = "1.6" +LinearAlgebra = "1.9" +Logging = "1.9" +Markdown = "1.9" PartialFunctions = "1.1" PrecompileTools = "1" Preferences = "1.3" -Printf = "1.6" +Printf = "1.9" RCall = "0.13.18" RecipesBase = "0.7.0, 0.8, 1.0" RecursiveArrayTools = "2.33" @@ -79,7 +79,7 @@ SymbolicIndexingInterface = "0.2" Tables = "1" TruncatedStacktraces = "1" Zygote = "0.6" -julia = "1.6" +julia = "1.9" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" From 8b5e58da77ac484a5d65e9d4706251f2916c9429 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 06:35:46 -0400 Subject: [PATCH 12/14] hotfix new extensions --- Project.toml | 2 +- ext/SciMLBaseChainRulesCoreExt.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index c94357eb4..34ef2a560 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.7.0" +version = "2.7.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/SciMLBaseChainRulesCoreExt.jl b/ext/SciMLBaseChainRulesCoreExt.jl index e59bc55d6..51334f937 100644 --- a/ext/SciMLBaseChainRulesCoreExt.jl +++ b/ext/SciMLBaseChainRulesCoreExt.jl @@ -1,5 +1,6 @@ module SciMLBaseChainRulesCoreExt +using SciMLBase import ChainRulesCore import ChainRulesCore: NoTangent, @non_differentiable @@ -118,7 +119,7 @@ function ChainRulesCore.rrule(::Type{ SDESolution{uType, tType, isinplace, P, NP, F, G, K, ND}(u, args...), SDESolutionAdjoint end -function ChainRulesCore.rrule(::DiffEqBase.EnsembleSolution, sim, time, converged) +function ChainRulesCore.rrule(::SciMLBase.EnsembleSolution, sim, time, converged) out = EnsembleSolution(sim, time, converged) function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N} arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i] From 17b664cdc9f4b3a6fe33fd4a9e1992da9214013d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 08:48:52 -0400 Subject: [PATCH 13/14] hotfix missing imports --- Project.toml | 2 +- ext/SciMLBaseZygoteExt.jl | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 34ef2a560..78a09fd27 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.7.1" +version = "2.7.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 09ce9bdb5..557a3844d 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -1,10 +1,12 @@ module SciMLBaseZygoteExt using Zygote -using Zygote: pullback, ZygoteRules -using ZygoteRules: @adjoint +using Zygote: @adjoint, pullback +import Zygote: literal_getproperty using SciMLBase -using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved +using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, + getobserved, build_solution, EnsembleSolution, + NonlinearSolution, AbstractTimeseriesSolution # This method resolves the ambiguity with the pullback defined in # RecursiveArrayToolsZygoteExt @@ -82,7 +84,7 @@ end VA[i], ODESolution_getindex_pullback end -@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution, +@adjoint function Zygote.literal_getproperty(sim::EnsembleSolution, ::Val{:u}) sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true, sim.stats),) end @@ -140,32 +142,32 @@ end NonlinearSolution{T, N, uType, R, P, A, O, uType2}(u, args...), NonlinearSolutionAdjoint end -@adjoint function ZygoteRules.literal_getproperty(sol::AbstractTimeseriesSolution, +@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution, ::Val{:u}) function solu_adjoint(Δ) zerou = zero(sol.prob.u0) _Δ = @. ifelse(Δ === nothing, (zerou,), Δ) - (DiffEqBase.build_solution(sol.prob, sol.alg, sol.t, _Δ),) + (build_solution(sol.prob, sol.alg, sol.t, _Δ),) end sol.u, solu_adjoint end -@adjoint function ZygoteRules.literal_getproperty(sol::AbstractNoTimeSolution, +@adjoint function literal_getproperty(sol::SciMLBase.AbstractNoTimeSolution, ::Val{:u}) function solu_adjoint(Δ) zerou = zero(sol.prob.u0) _Δ = @. ifelse(Δ === nothing, zerou, Δ) - (DiffEqBase.build_solution(sol.prob, sol.alg, _Δ, sol.resid),) + (build_solution(sol.prob, sol.alg, _Δ, sol.resid),) end sol.u, solu_adjoint end -@adjoint function ZygoteRules.literal_getproperty(sol::SciMLBase.OptimizationSolution, +@adjoint function literal_getproperty(sol::SciMLBase.OptimizationSolution, ::Val{:u}) function solu_adjoint(Δ) zerou = zero(sol.u) _Δ = @. ifelse(Δ === nothing, zerou, Δ) - (DiffEqBase.build_solution(sol.cache, sol.alg, _Δ, sol.objective),) + (build_solution(sol.cache, sol.alg, _Δ, sol.objective),) end sol.u, solu_adjoint end From ffe68aebedee5915190623cb08160d7ef1fbcce0 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 3 Nov 2023 15:57:28 -0400 Subject: [PATCH 14/14] hotfix fillarrays --- Project.toml | 2 +- ext/SciMLBaseZygoteExt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 78a09fd27..df38b6f3d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.7.2" +version = "2.7.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 557a3844d..c0171f4d4 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -77,7 +77,7 @@ end @adjoint function getindex(VA::ODESolution, i::Int) function ODESolution_getindex_pullback(Δ) - Δ′ = [(i == j ? Δ : FillArrays.Fill(zero(eltype(x)), size(x))) + Δ′ = [(i == j ? Δ : Zygote.FillArrays.Fill(zero(eltype(x)), size(x))) for (x, j) in zip(VA.u, 1:length(VA))] (Δ′, nothing) end