diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 2a62528de..ebfdf0c73 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -752,20 +752,19 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts) idxs = vars end - syms = getsyms(integrator) - int_vars = interpret_vars(idxs, integrator.sol, syms) - strs = cleansyms(syms) + int_vars = interpret_vars(idxs, integrator.sol) if denseplot # Generate the points from the plot from dense function - plott = collect(range(integrator.tprev; step = integrator.t, length = plotdensity)) - plot_timeseries = integrator(plott) + plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) if plot_analytic plot_analytic_timeseries = [integrator.sol.prob.f.analytic(integrator.sol.prob.u0, integrator.sol.prob.p, t) for t in plott] end - end # if not denseplot, we'll just get the values right from the integrator. + else + plott = nothing + end dims = length(int_vars[1]) for var in int_vars @@ -779,11 +778,18 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts) end labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = variable_symbols(integrator) + @show plott + for x in int_vars for j in 2:dims if denseplot - push!(plot_vecs[j - 1], - u_n(plot_timeseries, x[j], integrator.sol, plott, plot_timeseries)) + if (x[j] isa Integer && x[j] == 0) || isequal(x[j],getindepsym_defaultt(integrator)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) + end else # just get values if x[j] == 0 push!(plot_vecs[j - 1], integrator.t) @@ -793,6 +799,14 @@ Base.length(iter::TimeChoiceIterator) = length(iter.ts) push!(plot_vecs[j - 1], integrator.u[x[j]]) end end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end end add_labels!(labels, x, dims, integrator.sol, strs) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index b8b054ff8..56486edb4 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -147,7 +147,7 @@ function Base.show(io::IO, m::MIME"text/plain", A::AbstractPDESolution) show(io, m, A.u) end -DEFAULT_PLOT_FUNC(x...) = (x...,) +DEFAULT_PLOT_FUNC(x,y) = (x,y) DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug @recipe function f(sol::AbstractTimeseriesSolution; @@ -163,7 +163,7 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug max(1000, 100 * length(sol)) : max(1000, 10 * length(sol))) : 1000 * sol.tslocation), - tspan = nothing, axis_safety = 0.1, + tspan = nothing, vars = nothing, idxs = nothing) if vars !== nothing Base.depwarn("To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", @@ -173,60 +173,79 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug idxs = vars end - syms = getsyms(sol) - if idxs isa Symbol - int_vars = interpret_vars([idxs], sol, syms) + idxs = idxs === nothing ? (1:length(sol.u[1])) : idxs + + if !(idxs isa Union{Tuple, AbstractArray}) + vars = interpret_vars([idxs], sol) else - int_vars = interpret_vars(idxs, sol, syms) + vars = interpret_vars(idxs, sol) end - strs = cleansyms(syms) tscale = get(plotattributes, :xscale, :identity) plot_vecs, labels = diffeq_to_arrays(sol, plot_analytic, denseplot, - plotdensity, tspan, axis_safety, - idxs, int_vars, tscale, strs) + plotdensity, tspan, vars, tscale) tdir = sign(sol.t[end] - sol.t[1]) xflip --> tdir < 0 seriestype --> :path + @show labels + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && (symbolic_type(idxs[1]) != NotSymbolic() && symbolic_type(idxs[2]) != NotSymbolic()) - val = symbolic_type(int_vars[1][2]) != NotSymbolic() ? String(Symbol(int_vars[1][2])) : - strs[int_vars[1][2]] + if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC + val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2] + if val isa Integer + if val == 0 + val = "t" + else + val = "u[$val]" + end + end xguide --> val - val = symbolic_type(int_vars[1][3]) != NotSymbolic() ? String(Symbol(int_vars[1][3])) : - strs[int_vars[1][3]] + val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3] + if val isa Integer + if val == 0 + val = "t" + else + val = "u[$val]" + end + end yguide --> val if length(idxs) > 2 - val = symbolic_type(int_vars[1][4]) != NotSymbolic() ? String(Symbol(int_vars[1][4])) : - strs[int_vars[1][4]] + val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4] + if val isa Integer + if val == 0 + val = "t" + else + val = "u[$val]" + end + end zguide --> val end end - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(int_vars, 1))) && - getindex.(int_vars, 1) == zeros(length(int_vars))) || - (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(int_vars, 2))) && - getindex.(int_vars, 2) == zeros(length(int_vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(int_vars, 1)) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(int_vars, 2)) + if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) && + getindex.(vars, 1) == zeros(length(vars))) || + (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && + getindex.(vars, 2) == zeros(length(vars))) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) xguide --> "$(getindepsym_defaultt(sol))" end - if length(int_vars[1]) >= 3 && ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(int_vars, 3))) && - getindex.(int_vars, 3) == zeros(length(int_vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(int_vars, 3))) + if length(vars[1]) >= 3 && ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) && + getindex.(vars, 3) == zeros(length(vars))) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3))) yguide --> "$(getindepsym_defaultt(sol))" end - if length(int_vars[1]) >= 4 && ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(int_vars, 4))) && - getindex.(int_vars, 4) == zeros(length(int_vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(int_vars, 4))) + if length(vars[1]) >= 4 && ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) && + getindex.(vars, 4) == zeros(length(vars))) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4))) zguide --> "$(getindepsym_defaultt(sol))" end - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(int_vars, 2))) && - getindex.(int_vars, 2) == zeros(length(int_vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(int_vars, 2)) + if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && + getindex.(vars, 2) == zeros(length(vars))) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) if tspan === nothing if tdir > 0 xlims --> (sol.t[1], sol.t[end]) @@ -236,50 +255,14 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug else xlims --> (tspan[1], tspan[end]) end - else - mins = minimum(sol[int_vars[1][2], :]) - maxs = maximum(sol[int_vars[1][2], :]) - for iv in int_vars - mins = min(mins, minimum(sol[iv[2], :])) - maxs = max(maxs, maximum(sol[iv[2], :])) - end - xlims --> - ((1 - sign(mins) * axis_safety) * mins, (1 + sign(maxs) * axis_safety) * maxs) - end - - # Analytical solutions do not save enough information to have a good idea - # of the axis ahead of time - # Only set axis for animations - if sol.tslocation != 0 && !(sol isa AbstractAnalyticalSolution) - if all(getindex.(int_vars, 1) .== DEFAULT_PLOT_FUNC) - mins = minimum(sol[int_vars[1][3], :]) - maxs = maximum(sol[int_vars[1][3], :]) - for iv in int_vars - mins = min(mins, minimum(sol[iv[3], :])) - maxs = max(maxs, maximum(sol[iv[3], :])) - end - ylims --> - ((1 - sign(mins) * axis_safety) * mins, (1 + sign(maxs) * axis_safety) * maxs) - - if length(int_vars[1]) >= 4 - mins = minimum(sol[int_vars[1][4], :]) - maxs = maximum(sol[int_vars[1][4], :]) - for iv in int_vars - mins = min(mins, minimum(sol[iv[4], :])) - maxs = max(mins, maximum(sol[iv[4], :])) - end - zlims --> ((1 - sign(mins) * axis_safety) * mins, - (1 + sign(maxs) * axis_safety) * maxs) - end - end end label --> reshape(labels, 1, length(labels)) (plot_vecs...,) end -function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, axis_safety, - vars, int_vars, tscale, strs) +function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, + vars, tscale) if tspan === nothing if sol.tslocation == 0 end_idx = length(sol) @@ -309,7 +292,6 @@ function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, axi else plott = collect(densetspacer(tspan[1], tspan[end], plotdensity)) end - plot_timeseries = sol(plott) if plot_analytic if sol.prob.f isa Tuple plot_analytic_timeseries = [sol.prob.f[1].analytic(sol.prob.u0, sol.prob.p, @@ -338,7 +320,6 @@ function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, axi plott = collect(densetspacer(tspan[1], tspan[2], plotdensity)) end - plot_timeseries = sol.u[start_idx:end_idx] if plot_analytic plot_analytic_timeseries = sol.u_analytic[start_idx:end_idx] else @@ -347,70 +328,16 @@ function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, axi end end - dims = length(int_vars[1]) - for var in int_vars - @assert length(var) == dims + dims = length(vars[1]) - 1 + for var in vars + @assert length(var)-1 == dims end # Should check that all have the same dims! - plot_vecs, labels = solplot_vecs_and_labels(dims, int_vars, plot_timeseries, plott, sol, - plot_analytic, plot_analytic_timeseries, - strs) + plot_vecs, labels = solplot_vecs_and_labels(dims, vars, plott, sol, + plot_analytic, plot_analytic_timeseries) end -function interpret_vars(vars, sol, syms) - if vars !== nothing && syms !== nothing - # Do syms conversion - tmp_vars = [] - for var in vars - if var isa Union{Tuple, AbstractArray} #eltype(var) <: Symbol # Some kind of iterable - tmp = [] - for x in var - if symbolic_type(x) != NotSymbolic() - found = sym_to_index(x, syms) - push!(tmp, - found == nothing && getindepsym_defaultt(sol) == x ? 0 : - something(found, x)) - else - push!(tmp, x) - end - end - if var isa Tuple - var_int = tuple(tmp...) - else - var_int = tmp - end - elseif symbolic_type(var) != NotSymbolic() - found = sym_to_index(var, syms) - if (var isa Symbol) && has_sys(sol.prob.f) - var_int = if found === nothing && getindepsym_defaultt(sol) == var - 0 - elseif found !== nothing - found - elseif is_variable(sol, var) - variable_symbols(sol)[variable_index(sol, var)] - elseif is_parameter(sol, var) - parameter_symbols(sol)[parameter_index(sol, var)] - elseif is_independent_variable(sol, var) - independent_variable_symbols(sol)[1] - else - error("Tried to index solution with a Symbol that was not found in the system using `getproperty`.") - end - else - var_int = found == nothing && getindepsym_defaultt(sol) == var ? 0 : - something(found, var) - end - else - var_int = var - end - push!(tmp_vars, var_int) - end - if vars isa Tuple - vars = tuple(tmp_vars...) - else - vars = tmp_vars - end - end - +function interpret_vars(vars, sol) if vars === nothing # Default: plot all timeseries if sol[:, 1] isa Union{Tuple, AbstractArray} @@ -474,119 +401,54 @@ function interpret_vars(vars, sol, syms) end function add_labels!(labels, x, dims, sol, strs) - lys = [] - for j in 3:dims - if symbolic_type(x[j]) == NotSymbolic() && x[j] == 0 - push!(lys, "$(getindepsym_defaultt(sol)),") - elseif symbolic_type(x[j]) != NotSymbolic() - push!(lys, "$(x[j]),") - else - if strs !== nothing - push!(lys, "$(strs[x[j]]),") - else - push!(lys, "u$(x[j]),") - end - end - end - lys[end] = chop(lys[end]) # Take off the last comma - if symbolic_type(x[2]) == NotSymbolic() && x[2] == 0 && dims == 3 - # if there are no dependence in syms, then we add "(t)" - if strs !== nothing && (x[3] isa Int && endswith(strs[x[3]], r"(.*)")) || - (symbolic_type(x[3]) != NotSymbolic() && endswith(string(x[3]), r"(.*)")) - tmp_lab = "$(lys...)" - else - tmp_lab = "$(lys...)($(getindepsym_defaultt(sol)))" - end - else - if strs !== nothing && symbolic_type(x[2]) == NotSymbolic() && x[2] != 0 - tmp = strs[x[2]] - tmp_lab = "($tmp,$(lys...))" - else - if symbolic_type(x[2]) == NotSymbolic() && x[2] == 0 - tmp_lab = "($(getindepsym_defaultt(sol)),$(lys...))" - elseif symbolic_type(x[2]) != NotSymbolic() - tmp_lab = "($(x[2]),$(lys...))" - else - tmp_lab = "(u$(x[2]),$(lys...))" - end - end - end - if x[1] != DEFAULT_PLOT_FUNC - push!(labels, "$(x[1])$(tmp_lab)") + if ((x[2] isa Integer && x[2] == 0) || isequal(x[2],getindepsym_defaultt(sol))) && dims == 2 + push!(labels, strs[end]) + elseif x[1] !== DEFAULT_PLOT_FUNC + push!(labels, "f($(join(strs, ',')))") else - push!(labels, tmp_lab) + push!(labels, "($(join(strs, ',')))") end labels end function add_analytic_labels!(labels, x, dims, sol, strs) - lys = [] - for j in 3:dims - if x[j] == 0 && dims == 3 - push!(lys, "$(getindepsym_defaultt(sol)),") - else - if strs !== nothing - push!(lys, string("True ", strs[x[j]], ",")) - else - push!(lys, "True u$(x[j]),") - end - end - end - lys[end] = lys[end][1:(end - 1)] # Take off the last comma - if x[2] == 0 - tmp_lab = "$(lys...)($(getindepsym_defaultt(sol)))" - else - if strs !== nothing - tmp = string("True ", strs[x[2]]) - tmp_lab = "($tmp,$(lys...))" - else - tmp_lab = "(True u$(x[2]),$(lys...))" - end - end - if x[1] != DEFAULT_PLOT_FUNC - push!(labels, "$(x[1])$(tmp_lab)") - else - push!(labels, tmp_lab) - end -end - -function u_n(timeseries::Union{AbstractArray,RecursiveArrayTools.AbstractVectorOfArray}, n::Int, sol, plott, plot_timeseries) - # Returns the nth variable from the timeseries, t if n == 0 - if n == 0 - return plott - elseif n == 1 && !(sol[:, 1] isa Union{AbstractArray, ArrayPartition}) - return timeseries - else - tmp = Vector{eltype(sol[1])}(undef, length(plot_timeseries)) - for j in 1:length(plot_timeseries) - tmp[j] = plot_timeseries[j][n] - end - return tmp - end -end - -function u_n(timeseries::Union{AbstractArray,RecursiveArrayTools.AbstractVectorOfArray}, sym, sol, plott, plot_timeseries) - @assert symbolic_type(sym) != NotSymbolic() - if getindepsym_defaultt(sol) == Symbol(sym) - return plott + if ((x[2] isa Integer && x[2] == 0) || isequal(x[2],getindepsym_defaultt(sol))) && dims == 2 + push!(labels, "True $(strs[end])") + elseif x[1] !== DEFAULT_PLOT_FUNC + push!(labels, "True f($(join(strs, ',')))") else - getobserved(sol).((sym,), eachcol(timeseries), (sol.prob.p,), plott) + push!(labels, "True ($(join(strs, ',')))") end + labels end -function solplot_vecs_and_labels(dims, vars, plot_timeseries, plott, sol, plot_analytic, - plot_analytic_timeseries, strs) +function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic, + plot_analytic_timeseries) plot_vecs = [] labels = String[] + varsyms = variable_symbols(sol) for x in vars tmp = [] - for j in 2:dims - push!(tmp, u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) + strs = String[] + for j in 2:length(x) + if (x[j] isa Integer && x[j] == 0) || isequal(x[j],getindepsym_defaultt(sol)) + push!(tmp, plott) + push!(strs, "t") + else + push!(tmp, sol(plott; idxs = x[j])) + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end end f = x[1] - tmp = f.(tmp...) + tmp = map(f,tmp...) tmp = tuple((getindex.(tmp, i) for i in eachindex(tmp[1]))...) for i in eachindex(tmp) @@ -603,13 +465,40 @@ function solplot_vecs_and_labels(dims, vars, plot_timeseries, plott, sol, plot_a analytic_plot_vecs = [] for x in vars tmp = [] - for j in 2:dims - push!(tmp, - u_n(plot_analytic_timeseries, x[j], sol, plott, - plot_analytic_timeseries)) + strs = String[] + for j in 2:length(x) + if (x[j] isa Integer && x[j] == 0) + push!(tmp, plott) + push!(strs, "t") + elseif isequal(x[j],getindepsym_defaultt(sol)) + push!(tmp, plott) + push!(strs, String(getname(x[j]))) + elseif x[j] == 1 && !(sol[:, 1] isa Union{AbstractArray, ArrayPartition}) + push!(tmp,plot_analytic_timeseries) + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + else + _tmp = Vector{eltype(sol[1])}(undef, length(plot_timeseries)) + for j in 1:length(plot_timeseries) + _tmp[j] = plot_timeseries[j][n] + end + push!(tmp,_tmp) + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end end f = x[1] - tmp = f.(tmp...) + tmp = map(f,tmp...) tmp = tuple((getindex.(tmp, i) for i in eachindex(tmp[1]))...) for i in eachindex(tmp) push!(plot_vecs[i], tmp[i]) diff --git a/src/symbolic_utils.jl b/src/symbolic_utils.jl index 9057717e1..dac171f55 100644 --- a/src/symbolic_utils.jl +++ b/src/symbolic_utils.jl @@ -69,33 +69,3 @@ function getobserved(sol::AbstractOptimizationSolution) return DEFAULT_OBSERVED end end - -cleansyms(syms::Nothing) = nothing -cleansyms(syms::Tuple) = collect(cleansym(sym) for sym in syms) -cleansyms(syms::Vector) = cleansyms(Symbol.(syms)) -cleansyms(syms::Vector{Symbol}) = cleansym.(syms) -cleansyms(syms::LinearIndices) = nothing -cleansyms(syms::CartesianIndices) = nothing -cleansyms(syms::Base.OneTo) = nothing - -function cleansym(sym::Symbol) - str = String(sym) - # MTK generated names - rules = ("₊" => ".", "⦗" => "(", "⦘" => ")") - for r in rules - str = replace(str, r) - end - return str -end - -function sym_to_index(sym, prob::AbstractSciMLProblem) - return variable_index(prob.f, sym) -end -function sym_to_index(sym, sol::AbstractSciMLSolution) - idx = variable_index(sol.prob.f, sym) - if idx === nothing - idx = findfirst(isequal(sym), keys(sol.u[1])) - end - return idx -end -sym_to_index(sym, syms) = findfirst(isequal(Symbol(sym)), syms) diff --git a/test/solution_interface.jl b/test/solution_interface.jl index 798aa31d7..ed8edb31b 100644 --- a/test/solution_interface.jl +++ b/test/solution_interface.jl @@ -14,18 +14,14 @@ end push!(sol.u, ode.u0) end - syms = SciMLBase.interpret_vars(nothing, sol, SciMLBase.getsyms(sol)) - int_vars = SciMLBase.interpret_vars(nothing, sol, syms) # nothing = idxs + int_vars = SciMLBase.interpret_vars(nothing, sol) # nothing = idxs plot_vecs, labels = SciMLBase.diffeq_to_arrays(sol, true, # plot_analytic true, # denseplot 10, # plotdensity ode.tspan, - 0.1, # axis_safety - nothing, # idxs int_vars, - :identity, # tscale - nothing) # strs + :identity) # tscale @test plot_vecs[2][:, 2] ≈ @. exp(-plot_vecs[1][:, 2]) end