Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: format #755

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,16 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13
, T14, T15}(u,
@adjoint function ODESolution{
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12, T13, T14, T15}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u, args...),
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(
u, args...),
ODESolutionAdjoint
end

Expand Down
2 changes: 1 addition & 1 deletion src/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx)

function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution)
c = ic.clock

return @match c begin
PeriodicClock(dt, _...) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt
&SolverStepClock => begin
Expand Down
6 changes: 4 additions & 2 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,13 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s::Integer, i::Integer)
Base.@propagate_inbounds function Base.getindex(
x::AbstractEnsembleSolution, s::Integer, i::Integer)
return x.u[s].u[i]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...)
Base.@propagate_inbounds function Base.getindex(
x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...)
return x.u[s][i2, i3, idxs...]
end

Expand Down
2 changes: 1 addition & 1 deletion src/problems/linear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ function LinearProblem(A, b, args...; kwargs...)
else
LinearProblem{isinplace(A, 4)}(A, b, args...; kwargs...)
end
end
end
2 changes: 1 addition & 1 deletion src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ anydict(d) = Dict{Any, Any}(d)
anydict() = Dict{Any, Any}()

function _updated_u0_p_internal(
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
return state_values(prob), parameter_values(prob)
end
function _updated_u0_p_internal(
Expand Down
5 changes: 3 additions & 2 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2615,7 +2615,8 @@ end
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
typeof(vjp), typeof(jac_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(f1, f2, mass_matrix,
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
f1, f2, mass_matrix,
cache, analytic, tgrad, jac, jvp, vjp,
jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, initializeprobmap)
Expand Down Expand Up @@ -2649,7 +2650,7 @@ function SplitFunction{iip, specialize}(f1, f2;
sys = __has_sys(f1) ? f1.sys : nothing,
initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing,
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing
) where {iip,
) where {iip,
specialize
}
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
Expand Down
13 changes: 9 additions & 4 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
ps = parameter_values(discs)
for ts_idx in eachindex(discs)
partition = discs[ts_idx]
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
interp_val = ConstantInterpolation(partition.t, partition.u)(
t, nothing, deriv, nothing, continuity)
ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val)
end
end
Expand All @@ -296,7 +297,8 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
ps = parameter_values(discs)
for ts_idx in eachindex(discs)
partition = discs[ts_idx]
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
interp_val = ConstantInterpolation(partition.t, partition.u)(
t, nothing, deriv, nothing, continuity)
ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val)
end
end
Expand Down Expand Up @@ -374,7 +376,9 @@ function get_saveable_values(sys, ps, timeseries_idx)
end

function save_discretes!(integ::DEIntegrator, timeseries_idx)
save_discretes!(integ.sol, current_time(integ), get_saveable_values(integ, parameter_values(integ), timeseries_idx), timeseries_idx)
save_discretes!(integ.sol, current_time(integ),
get_saveable_values(integ, parameter_values(integ), timeseries_idx),
timeseries_idx)
end

save_discretes!(args...) = nothing
Expand Down Expand Up @@ -555,7 +559,8 @@ end

mask_discretes(::Nothing, _, _...) = nothing

function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, ::Union{Int, CartesianIndex})
function mask_discretes(
discretes::ParameterTimeseriesCollection, new_t, ::Union{Int, CartesianIndex})
masked_discretes = map(discretes) do disc
i = searchsortedlast(disc.t, new_t)
disc[i:i]
Expand Down
18 changes: 9 additions & 9 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug

function isdenseplot(sol)
(sol.dense || sol.prob isa AbstractDiscreteProblem) &&
!(sol isa AbstractRODESolution) &&
!(hasfield(typeof(sol), :interp) &&
sol.interp isa SensitivityInterpolation)
!(sol isa AbstractRODESolution) &&
!(hasfield(typeof(sol), :interp) &&
sol.interp isa SensitivityInterpolation)
end

@recipe function f(sol::AbstractTimeseriesSolution;
Expand Down Expand Up @@ -187,7 +187,8 @@ end
disc_vars = Tuple[]
cont_vars = Tuple[]
for var in vars
tsidxs = union(get_all_timeseries_indexes(sol, var[2]), get_all_timeseries_indexes(sol, var[3]))
tsidxs = union(get_all_timeseries_indexes(sol, var[2]),
get_all_timeseries_indexes(sol, var[3]))
if ContinuousTimeseries() in tsidxs
push!(cont_vars, var)
else
Expand All @@ -209,7 +210,6 @@ end
plot_vecs, labels = diffeq_to_arrays(sol, plot_analytic, denseplot,
plotdensity, tspan, vars, tscale, plotat)


# Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ...
if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC
val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2]
Expand Down Expand Up @@ -311,12 +311,12 @@ end
seriestype := :line
linestyle --> :dash
markershape --> :o
markersize --> repeat([2, 0], length(ts)-1)
markeralpha --> repeat([1, 0], length(ts)-1)
markersize --> repeat([2, 0], length(ts) - 1)
markeralpha --> repeat([1, 0], length(ts) - 1)
label --> string(hasname(yvar) ? getname(yvar) : yvar)

x = vec([xvals[1:end-1]'; xvals[2:end]'])
y = repeat(yvals, inner=2)[1:end-1]
x = vec([xvals[1:(end - 1)]'; xvals[2:end]'])
y = repeat(yvals, inner = 2)[1:(end - 1)]
x, y
end
end
Expand Down
12 changes: 8 additions & 4 deletions test/downstream/comprehensive_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,8 @@ end
newx = []
for i in eachindex(x)
if x[i] isa Symbol
push!(newx, allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
push!(newx,
allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
else
push!(newx, x[i])
end
Expand Down Expand Up @@ -590,7 +591,8 @@ end
newx = []
for i in eachindex(x)
if x[i] isa Symbol
push!(newx, allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
push!(newx,
allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
else
push!(newx, x[i])
end
Expand Down Expand Up @@ -896,9 +898,11 @@ end
@test_nowarn plot(sol; idxs = idx)
@test_nowarn plot(sol; idxs = [idx])
end
for idx in Iterators.flatten((Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
for idx in Iterators.flatten((
Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
@test_nowarn plot(sol; idxs = collect(idx))
if !(idx[1] isa Tuple || idx[2] isa Tuple || length(get_all_timeseries_indexes(sol, collect(idx))) > 1)
if !(idx[1] isa Tuple || idx[2] isa Tuple ||
length(get_all_timeseries_indexes(sol, collect(idx))) > 1)
@test_nowarn plot(sol; idxs = idx)
end
end
Expand Down
7 changes: 4 additions & 3 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,15 @@ sol10 = sol(0.1, idxs = 2)

plotfn(t, u) = (t, 2u)
all_idxs = [x, x + p * y, t, (plotfn, 0, 1), (plotfn, t, 1), (plotfn, 0, x),
(plotfn, t, x), (plotfn, t, p * y)]
(plotfn, t, x), (plotfn, t, p * y)]
sym_idxs = [:x, :t, (plotfn, :t, 1), (plotfn, 0, :x),
(plotfn, :t, :x)]
(plotfn, :t, :x)]
for idx in Iterators.flatten((all_idxs, sym_idxs))
@test_nowarn plot(sol; idxs = idx)
@test_nowarn plot(sol; idxs = [idx])
end
for idx in Iterators.flatten((Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
for idx in Iterators.flatten((
Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
@test_nowarn plot(sol; idxs = collect(idx))
if !(idx[1] isa Tuple || idx[2] isa Tuple)
@test_nowarn plot(sol; idxs = idx)
Expand Down
10 changes: 5 additions & 5 deletions test/serialization_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ using Serialization
using Test

for clock in [
SciMLBase.Clock(0.5),
SciMLBase.Clock(0.5; phase = 0.1),
SciMLBase.SolverStepClock,
SciMLBase.Continuous,
]
SciMLBase.Clock(0.5),
SciMLBase.Clock(0.5; phase = 0.1),
SciMLBase.SolverStepClock,
SciMLBase.Continuous
]
serialize("_tmp.jls", clock)
newclock = deserialize("_tmp.jls")
@test newclock == clock
Expand Down
Loading