Skip to content

Commit

Permalink
Merge pull request #536 from pepijndevos/new-branch
Browse files Browse the repository at this point in the history
Change typeof(x) <: y to x isa y
  • Loading branch information
ChrisRackauckas authored Nov 2, 2023
2 parents 585a1aa + e037ef6 commit 5a7771d
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 92 deletions.
6 changes: 3 additions & 3 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function batch_func(i, prob, alg; kwargs...)
new_prob = prob.prob_func(_prob, i, iter)
rerun = true
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
if !(typeof(x) <: Tuple)
if !(x isa Tuple)
rerun_warn()
_x = (x, false)
else
Expand All @@ -117,7 +117,7 @@ function batch_func(i, prob, alg; kwargs...)
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
new_prob = prob.prob_func(_prob, i, iter)
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
if !(typeof(x) <: Tuple)
if !(x isa Tuple)
rerun_warn()
_x = (x, false)
else
Expand Down Expand Up @@ -170,7 +170,7 @@ function solve_batch(prob, alg, ensemblealg::EnsembleThreads, II, pmap_batch_siz
return solve_batch(prob, alg, EnsembleSerial(), II, pmap_batch_size; kwargs...)
end

if typeof(prob.prob) <: AbstractJumpProblem && length(II) != 1
if prob.prob isa AbstractJumpProblem && length(II) != 1
probs = [deepcopy(prob.prob) for i in 1:nthreads]
else
probs = prob.prob
Expand Down
32 changes: 16 additions & 16 deletions src/ensemble/ensemble_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ get_timestep(sim, i) = (getindex(sol, i) for sol in sim)
get_timepoint(sim, t) = (sol(t) for sol in sim)
function componentwise_vectors_timestep(sim, i)
arr = [get_timestep(sim, i)...]
if typeof(arr[1]) <: AbstractArray
if arr[1] isa AbstractArray
return vecarr_to_vectors(VectorOfArray(arr))
else
return arr
end
end
function componentwise_vectors_timepoint(sim, t)
arr = [get_timepoint(sim, t)...]
if typeof(arr[1]) <: AbstractArray
if arr[1] isa AbstractArray
return vecarr_to_vectors(VectorOfArray(arr))
else
return arr
Expand Down Expand Up @@ -123,7 +123,7 @@ end

function SciMLBase.EnsembleSummary(sim::SciMLBase.AbstractEnsembleSolution{T, N},
t = sim[1].t; quantiles = [0.05, 0.95]) where {T, N}
if typeof(sim[1]) <: SciMLSolution
if sim[1] isa SciMLSolution
m, v = timeseries_point_meanvar(sim, t)
med = timeseries_point_median(sim, t)
qlow = timeseries_point_quantile(sim, quantiles[1], t)
Expand Down Expand Up @@ -190,13 +190,13 @@ function componentwise_mean(A)
mean = zero(x0) ./ 1
for x in A
n += 1
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
mean .+= x
else
mean += x
end
end
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
mean ./= n
else
mean /= n
Expand All @@ -215,7 +215,7 @@ function componentwise_meanvar(A; bessel = true)
delta2 = zero(x0) ./ 1
for x in A
n += 1
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
delta .= x .- mean
mean .+= delta ./ n
delta2 .= x .- mean
Expand All @@ -231,13 +231,13 @@ function componentwise_meanvar(A; bessel = true)
return NaN
else
if bessel
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
M2 .= M2 ./ (n .- 1)
else
M2 = M2 ./ (n .- 1)
end
else
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
M2 .= M2 ./ n
else
M2 = M2 ./ n
Expand All @@ -257,7 +257,7 @@ function componentwise_meancov(A, B; bessel = true)
dx = zero(x0) ./ 1
for (x, y) in zip(A, B)
n += 1
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
dx .= x .- meanx
meanx .+= dx ./ n
meany .+= (y .- meany) ./ n
Expand All @@ -273,13 +273,13 @@ function componentwise_meancov(A, B; bessel = true)
return NaN
else
if bessel
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
C .= C ./ (n .- 1)
else
C = C ./ (n .- 1)
end
else
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
C .= C ./ n
else
C = C ./ n
Expand All @@ -293,7 +293,7 @@ function componentwise_meancor(A, B; bessel = true)
mx, my, cov = componentwise_meancov(A, B; bessel = bessel)
mx, vx = componentwise_meanvar(A; bessel = bessel)
my, vy = componentwise_meanvar(B; bessel = bessel)
if typeof(vx) <: AbstractArray
if vx isa AbstractArray
vx .= sqrt.(vx)
vy .= sqrt.(vy)
else
Expand All @@ -316,7 +316,7 @@ function componentwise_weighted_meancov(A, B, W; weight_type = :reliability)
dx = zero(x0) ./ 1
for (x, y, w) in zip(A, B, W)
n += 1
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
wsum .+= w
wsum2 .+= w .* w
dx .= x .- meanx
Expand All @@ -336,19 +336,19 @@ function componentwise_weighted_meancov(A, B, W; weight_type = :reliability)
return NaN
else
if weight_type == :population
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
C .= C ./ wsum
else
C = C ./ wsum
end
elseif weight_type == :reliability
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
C .= C ./ (wsum .- wsum2 ./ wsum)
else
C = C ./ (wsum .- wsum2 ./ wsum)
end
elseif weight_type == :frequency
if typeof(x0) <: AbstractArray && !(typeof(x0) <: StaticArraysCore.SArray)
if x0 isa AbstractArray && !(x0 isa StaticArraysCore.SArray)
C .= C ./ (wsum .- 1)
else
C = C ./ (wsum .- 1)
Expand Down
12 changes: 6 additions & 6 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ end
### Plot Recipes

@recipe function f(sim::AbstractEnsembleSolution;
zcolors = typeof(sim.u) <: AbstractArray ? fill(nothing, length(sim.u)) :
zcolors = sim.u isa AbstractArray ? fill(nothing, length(sim.u)) :
nothing,
trajectories = eachindex(sim))
for i in trajectories
Expand All @@ -156,16 +156,16 @@ end
end

@recipe function f(sim::EnsembleSummary;
trajectories = typeof(sim.u[1]) <: AbstractArray ? eachindex(sim.u[1]) :
trajectories = sim.u[1] isa AbstractArray ? eachindex(sim.u[1]) :
1,
error_style = :ribbon, ci_type = :quantile)
if ci_type == :SEM
if typeof(sim.u[1]) <: AbstractArray
if sim.u[1] isa AbstractArray
u = vecarr_to_vectors(sim.u)
else
u = [sim.u.u]
end
if typeof(sim.u[1]) <: AbstractArray
if sim.u[1] isa AbstractArray
ci_low = vecarr_to_vectors(VectorOfArray([sqrt.(sim.v[i] / sim.num_monte) .*
1.96 for i in 1:length(sim.v)]))
ci_high = ci_low
Expand All @@ -175,12 +175,12 @@ end
ci_high = ci_low
end
elseif ci_type == :quantile
if typeof(sim.med[1]) <: AbstractArray
if sim.med[1] isa AbstractArray
u = vecarr_to_vectors(sim.med)
else
u = [sim.med.u]
end
if typeof(sim.u[1]) <: AbstractArray
if sim.u[1] isa AbstractArray
ci_low = u - vecarr_to_vectors(sim.qlow)
ci_high = vecarr_to_vectors(sim.qhigh) - u
else
Expand Down
8 changes: 4 additions & 4 deletions src/integrator_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts)

@recipe function f(integrator::DEIntegrator;
denseplot = (integrator.opts.calck ||
typeof(integrator) <: AbstractSDEIntegrator) &&
integrator isa AbstractSDEIntegrator) &&
integrator.iter > 0,
plotdensity = 10,
plot_analytic = false, vars = nothing, idxs = nothing)
Expand Down Expand Up @@ -797,7 +797,7 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts)
else # just get values
if x[j] == 0
push!(plot_vecs[j - 1], integrator.t)
elseif x[j] == 1 && !(typeof(integrator.u) <: AbstractArray)
elseif x[j] == 1 && !(integrator.u isa AbstractArray)
push!(plot_vecs[j - 1], integrator.u)
else
push!(plot_vecs[j - 1], integrator.u[x[j]])
Expand All @@ -816,7 +816,7 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts)
else # Just get values
if x[j] == 0
push!(plot_vecs[j], integrator.t)
elseif x[j] == 1 && !(typeof(integrator.u) <: AbstractArray)
elseif x[j] == 1 && !(integrator.u isa AbstractArray)
push!(plot_vecs[j],
integrator.sol.prob.f(Val{:analytic}, integrator.t,
integrator.sol[1]))
Expand All @@ -840,7 +840,7 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts)
end

# Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ...
if typeof(idxs) <: Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol)
if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol)
xlabel --> idxs[1]
ylabel --> idxs[2]
if length(idxs) > 2
Expand Down
38 changes: 19 additions & 19 deletions src/interpolation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ end
continuity::Symbol = :left) where {I, D}
t = id.t
u = id.u
typeof(id) <: HermiteInterpolation && (du = id.du)
id isa HermiteInterpolation && (du = id.du)
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)
i = 2 # Start the search thinking it's between t[1] and t[2]
Expand All @@ -91,17 +91,17 @@ end
error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
tdir * tvals[idx[1]] < tdir * t[1] &&
error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
if typeof(idxs) <: Number
if idxs isa Number
vals = Vector{eltype(first(u))}(undef, length(tvals))
elseif typeof(idxs) <: AbstractVector
elseif idxs isa AbstractVector
vals = Vector{Vector{eltype(first(u))}}(undef, length(tvals))
else
vals = Vector{eltype(u)}(undef, length(tvals))
end
for j in idx
tval = tvals[j]
i = searchsortedfirst(@view(t[i:end]), tval, rev = tdir < 0) + i - 1 # It's in the interval t[i-1] to t[i]
avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual
avoid_constant_ends = deriv != Val{0} #|| tval isa ForwardDiff.Dual
avoid_constant_ends && i == 1 && (i += 1)
if !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value!
if idxs === nothing
Expand All @@ -118,11 +118,11 @@ end
vals[j] = u[k][idxs]
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
if typeof(id) <: HermiteInterpolation
if id isa HermiteInterpolation
vals[j] = interpolant(Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i],
idxs_internal, deriv)
else
Expand All @@ -143,7 +143,7 @@ times t (sorted), with values u and derivatives ks
continuity::Symbol = :left) where {I, D}
t = id.t
u = id.u
typeof(id) <: HermiteInterpolation && (du = id.du)
id isa HermiteInterpolation && (du = id.du)
tdir = sign(t[end] - t[1])
idx = sortperm(tvals, rev = tdir < 0)
i = 2 # Start the search thinking it's between t[1] and t[2]
Expand All @@ -156,7 +156,7 @@ times t (sorted), with values u and derivatives ks
for j in idx
tval = tvals[j]
i = searchsortedfirst(@view(t[i:end]), tval, rev = tdir < 0) + i - 1 # It's in the interval t[i-1] to t[i]
avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual
avoid_constant_ends = deriv != Val{0} #|| tval isa ForwardDiff.Dual
avoid_constant_ends && i == 1 && (i += 1)
if !avoid_constant_ends && t[i - 1] == tval # Can happen if it's the first value!
if idxs === nothing
Expand All @@ -173,19 +173,19 @@ times t (sorted), with values u and derivatives ks
vals[j] = u[k][idxs]
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
if eltype(u) <: Union{AbstractArray, ArrayPartition}
if typeof(id) <: HermiteInterpolation
if id isa HermiteInterpolation
interpolant!(vals[j], Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i],
idxs_internal, deriv)
else
interpolant!(vals[j], Θ, id, dt, u[i - 1], u[i], idxs_internal, deriv)
end
else
if typeof(id) <: HermiteInterpolation
if id isa HermiteInterpolation
vals[j] = interpolant(Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i],
idxs_internal, deriv)
else
Expand All @@ -206,7 +206,7 @@ times t (sorted), with values u and derivatives ks
continuity::Symbol = :left) where {I, D}
t = id.t
u = id.u
typeof(id) <: HermiteInterpolation && (du = id.du)
id isa HermiteInterpolation && (du = id.du)
tdir = sign(t[end] - t[1])
t[end] == t[1] && tval != t[end] &&
error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
Expand All @@ -215,7 +215,7 @@ times t (sorted), with values u and derivatives ks
tdir * tval < tdir * t[1] &&
error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
@inbounds i = searchsortedfirst(t, tval, rev = tdir < 0) # It's in the interval t[i-1] to t[i]
avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual
avoid_constant_ends = deriv != Val{0} #|| tval isa ForwardDiff.Dual
avoid_constant_ends && i == 1 && (i += 1)
if !avoid_constant_ends && t[i] == tval
lasti = lastindex(t)
Expand All @@ -232,11 +232,11 @@ times t (sorted), with values u and derivatives ks
val = u[i - 1][idxs]
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
if typeof(id) <: HermiteInterpolation
if id isa HermiteInterpolation
val = interpolant(Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i], idxs_internal,
deriv)
else
Expand All @@ -256,7 +256,7 @@ times t (sorted), with values u and derivatives ks
continuity::Symbol = :left) where {I, D}
t = id.t
u = id.u
typeof(id) <: HermiteInterpolation && (du = id.du)
id isa HermiteInterpolation && (du = id.du)
tdir = sign(t[end] - t[1])
t[end] == t[1] && tval != t[end] &&
error("Solution interpolation cannot extrapolate from a single timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
Expand All @@ -265,7 +265,7 @@ times t (sorted), with values u and derivatives ks
tdir * tval < tdir * t[1] &&
error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
@inbounds i = searchsortedfirst(t, tval, rev = tdir < 0) # It's in the interval t[i-1] to t[i]
avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual
avoid_constant_ends = deriv != Val{0} #|| tval isa ForwardDiff.Dual
avoid_constant_ends && i == 1 && (i += 1)
if !avoid_constant_ends && t[i] == tval
lasti = lastindex(t)
Expand All @@ -282,11 +282,11 @@ times t (sorted), with values u and derivatives ks
copy!(out, u[i - 1][idxs])
end
else
typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
id isa SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE)
dt = t[i] - t[i - 1]
Θ = (tval - t[i - 1]) / dt
idxs_internal = idxs
if typeof(id) <: HermiteInterpolation
if id isa HermiteInterpolation
interpolant!(out, Θ, id, dt, u[i - 1], u[i], du[i - 1], du[i], idxs_internal,
deriv)
else
Expand Down
Loading

0 comments on commit 5a7771d

Please sign in to comment.