Skip to content

Commit

Permalink
refactor: format
Browse files Browse the repository at this point in the history
refactor: format
  • Loading branch information
AayushSabharwal committed Jul 31, 2024
1 parent fe9f207 commit 01b13d6
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 35 deletions.
7 changes: 4 additions & 3 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,16 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13
, T14, T15}(u,
@adjoint function ODESolution{
T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
T9, T10, T11, T12, T13, T14, T15}
function ODESolutionAdjoint(ȳ)
(ȳ, ntuple(_ -> nothing, length(args))...)
end

ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u, args...),
ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(
u, args...),
ODESolutionAdjoint
end

Expand Down
2 changes: 1 addition & 1 deletion src/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx)

function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution)
c = ic.clock

return @match c begin
PeriodicClock(dt, _...) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt
&SolverStepClock => begin
Expand Down
6 changes: 4 additions & 2 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,13 @@ end
end
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s::Integer, i::Integer)
Base.@propagate_inbounds function Base.getindex(
x::AbstractEnsembleSolution, s::Integer, i::Integer)
return x.u[s].u[i]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...)
Base.@propagate_inbounds function Base.getindex(
x::AbstractEnsembleSolution, s::Integer, i2::Integer, i3::Integer, idxs::Integer...)
return x.u[s][i2, i3, idxs...]
end

Expand Down
2 changes: 1 addition & 1 deletion src/problems/linear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ function LinearProblem(A, b, args...; kwargs...)
else
LinearProblem{isinplace(A, 4)}(A, b, args...; kwargs...)
end
end
end
2 changes: 1 addition & 1 deletion src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ anydict(d) = Dict{Any, Any}(d)
anydict() = Dict{Any, Any}()

function _updated_u0_p_internal(
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
return state_values(prob), parameter_values(prob)
end
function _updated_u0_p_internal(
Expand Down
5 changes: 3 additions & 2 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2615,7 +2615,8 @@ end
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
typeof(vjp), typeof(jac_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(f1, f2, mass_matrix,
typeof(sys), typeof(initializeprob), typeof(initializeprobmap)}(
f1, f2, mass_matrix,
cache, analytic, tgrad, jac, jvp, vjp,
jac_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, initializeprobmap)
Expand Down Expand Up @@ -2649,7 +2650,7 @@ function SplitFunction{iip, specialize}(f1, f2;
sys = __has_sys(f1) ? f1.sys : nothing,
initializeprob = __has_initializeprob(f1) ? f1.initializeprob : nothing,
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing
) where {iip,
) where {iip,
specialize
}
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
Expand Down
13 changes: 9 additions & 4 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
ps = parameter_values(discs)
for ts_idx in eachindex(discs)
partition = discs[ts_idx]
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
interp_val = ConstantInterpolation(partition.t, partition.u)(
t, nothing, deriv, nothing, continuity)
ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val)
end
end
Expand All @@ -296,7 +297,8 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
ps = parameter_values(discs)
for ts_idx in eachindex(discs)
partition = discs[ts_idx]
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
interp_val = ConstantInterpolation(partition.t, partition.u)(
t, nothing, deriv, nothing, continuity)
ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val)
end
end
Expand Down Expand Up @@ -374,7 +376,9 @@ function get_saveable_values(sys, ps, timeseries_idx)
end

function save_discretes!(integ::DEIntegrator, timeseries_idx)
save_discretes!(integ.sol, current_time(integ), get_saveable_values(integ, parameter_values(integ), timeseries_idx), timeseries_idx)
save_discretes!(integ.sol, current_time(integ),
get_saveable_values(integ, parameter_values(integ), timeseries_idx),
timeseries_idx)
end

save_discretes!(args...) = nothing
Expand Down Expand Up @@ -555,7 +559,8 @@ end

mask_discretes(::Nothing, _, _...) = nothing

function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, ::Union{Int, CartesianIndex})
function mask_discretes(
discretes::ParameterTimeseriesCollection, new_t, ::Union{Int, CartesianIndex})
masked_discretes = map(discretes) do disc
i = searchsortedlast(disc.t, new_t)
disc[i:i]
Expand Down
18 changes: 9 additions & 9 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug

function isdenseplot(sol)
(sol.dense || sol.prob isa AbstractDiscreteProblem) &&
!(sol isa AbstractRODESolution) &&
!(hasfield(typeof(sol), :interp) &&
sol.interp isa SensitivityInterpolation)
!(sol isa AbstractRODESolution) &&
!(hasfield(typeof(sol), :interp) &&
sol.interp isa SensitivityInterpolation)
end

@recipe function f(sol::AbstractTimeseriesSolution;
Expand Down Expand Up @@ -187,7 +187,8 @@ end
disc_vars = Tuple[]
cont_vars = Tuple[]
for var in vars
tsidxs = union(get_all_timeseries_indexes(sol, var[2]), get_all_timeseries_indexes(sol, var[3]))
tsidxs = union(get_all_timeseries_indexes(sol, var[2]),
get_all_timeseries_indexes(sol, var[3]))
if ContinuousTimeseries() in tsidxs
push!(cont_vars, var)
else
Expand All @@ -209,7 +210,6 @@ end
plot_vecs, labels = diffeq_to_arrays(sol, plot_analytic, denseplot,
plotdensity, tspan, vars, tscale, plotat)


# Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ...
if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC
val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2]
Expand Down Expand Up @@ -311,12 +311,12 @@ end
seriestype := :line
linestyle --> :dash
markershape --> :o
markersize --> repeat([2, 0], length(ts)-1)
markeralpha --> repeat([1, 0], length(ts)-1)
markersize --> repeat([2, 0], length(ts) - 1)
markeralpha --> repeat([1, 0], length(ts) - 1)
label --> string(hasname(yvar) ? getname(yvar) : yvar)

x = vec([xvals[1:end-1]'; xvals[2:end]'])
y = repeat(yvals, inner=2)[1:end-1]
x = vec([xvals[1:(end - 1)]'; xvals[2:end]'])
y = repeat(yvals, inner = 2)[1:(end - 1)]
x, y
end
end
Expand Down
12 changes: 8 additions & 4 deletions test/downstream/comprehensive_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,8 @@ end
newx = []
for i in eachindex(x)
if x[i] isa Symbol
push!(newx, allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
push!(newx,
allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
else
push!(newx, x[i])
end
Expand Down Expand Up @@ -590,7 +591,8 @@ end
newx = []
for i in eachindex(x)
if x[i] isa Symbol
push!(newx, allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
push!(newx,
allsyms[findfirst(y -> hasname(y) && x[i] == getname(y), allsyms)])
else
push!(newx, x[i])
end
Expand Down Expand Up @@ -896,9 +898,11 @@ end
@test_nowarn plot(sol; idxs = idx)
@test_nowarn plot(sol; idxs = [idx])
end
for idx in Iterators.flatten((Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
for idx in Iterators.flatten((
Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
@test_nowarn plot(sol; idxs = collect(idx))
if !(idx[1] isa Tuple || idx[2] isa Tuple || length(get_all_timeseries_indexes(sol, collect(idx))) > 1)
if !(idx[1] isa Tuple || idx[2] isa Tuple ||
length(get_all_timeseries_indexes(sol, collect(idx))) > 1)
@test_nowarn plot(sol; idxs = idx)
end
end
Expand Down
7 changes: 4 additions & 3 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,15 @@ sol10 = sol(0.1, idxs = 2)

plotfn(t, u) = (t, 2u)
all_idxs = [x, x + p * y, t, (plotfn, 0, 1), (plotfn, t, 1), (plotfn, 0, x),
(plotfn, t, x), (plotfn, t, p * y)]
(plotfn, t, x), (plotfn, t, p * y)]
sym_idxs = [:x, :t, (plotfn, :t, 1), (plotfn, 0, :x),
(plotfn, :t, :x)]
(plotfn, :t, :x)]
for idx in Iterators.flatten((all_idxs, sym_idxs))
@test_nowarn plot(sol; idxs = idx)
@test_nowarn plot(sol; idxs = [idx])
end
for idx in Iterators.flatten((Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
for idx in Iterators.flatten((
Iterators.product(all_idxs, all_idxs), Iterators.product(sym_idxs, sym_idxs)))
@test_nowarn plot(sol; idxs = collect(idx))
if !(idx[1] isa Tuple || idx[2] isa Tuple)
@test_nowarn plot(sol; idxs = idx)
Expand Down
10 changes: 5 additions & 5 deletions test/serialization_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ using Serialization
using Test

for clock in [
SciMLBase.Clock(0.5),
SciMLBase.Clock(0.5; phase = 0.1),
SciMLBase.SolverStepClock,
SciMLBase.Continuous,
]
SciMLBase.Clock(0.5),
SciMLBase.Clock(0.5; phase = 0.1),
SciMLBase.SolverStepClock,
SciMLBase.Continuous
]
serialize("_tmp.jls", clock)
newclock = deserialize("_tmp.jls")
@test newclock == clock
Expand Down

0 comments on commit 01b13d6

Please sign in to comment.