Skip to content

Commit

Permalink
fix analytical plots
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Dec 27, 2023
1 parent d6bbac5 commit db52700
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 38 deletions.
56 changes: 24 additions & 32 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,34 +412,14 @@ function add_labels!(labels, x, dims, sol, strs)
end

function add_analytic_labels!(labels, x, dims, sol, strs)
lys = []
for j in 3:dims
if x[j] == 0 && dims == 2
push!(lys, "$(getindepsym_defaultt(sol)),")
else
if strs !== nothing
push!(lys, string("True ", strs[x[j]], ","))
else
push!(lys, "True u$(x[j]),")
end
end
end
lys[end] = lys[end][1:(end - 1)] # Take off the last comma
if x[2] == 0
tmp_lab = "$(lys...)($(getindepsym_defaultt(sol)))"
else
if strs !== nothing
tmp = string("True ", strs[x[2]])
tmp_lab = "($tmp,$(lys...))"
else
tmp_lab = "(True u$(x[2]),$(lys...))"
end
end
if x[1] != DEFAULT_PLOT_FUNC
push!(labels, "$(x[1])$(tmp_lab)")
if ((x[2] isa Integer && x[2] == 0) || isequal(x[2],getindepsym_defaultt(sol))) && dims == 2
push!(labels, "True $(strs[end])")
elseif x[1] !== DEFAULT_PLOT_FUNC
push!(labels, "True f($(join(strs, ',')))")

Check warning on line 418 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L417-L418

Added lines #L417 - L418 were not covered by tests
else
push!(labels, tmp_lab)
push!(labels, "True ($(join(strs, ',')))")

Check warning on line 420 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L420

Added line #L420 was not covered by tests
end
labels
end

function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic,
Expand Down Expand Up @@ -486,27 +466,39 @@ function solplot_vecs_and_labels(dims, vars, plott, sol, plot_analytic,
for x in vars
tmp = []
strs = String[]
for j in 2:dims
for j in 2:length(x)
if (x[j] isa Integer && x[j] == 0)
push!(tmp, plott)
push!(strs, "t")
elseif isequal(x[j],getindepsym_defaultt(sol))
push!(tmp, plott)
push!(strs, String(getname(x[j])))

Check warning on line 475 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L474-L475

Added lines #L474 - L475 were not covered by tests
elseif n == 1 && !(sol[:, 1] isa Union{AbstractArray, ArrayPartition})
push!(tmp,timeseries)
push!(strs, String(getname(x[j])))
elseif x[j] == 1 && !(sol[:, 1] isa Union{AbstractArray, ArrayPartition})
push!(tmp,plot_analytic_timeseries)
if !isempty(varsyms) && x[j] isa Integer
push!(strs, String(getname(varsyms[x[j]])))

Check warning on line 479 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L479

Added line #L479 was not covered by tests
elseif hasname(x[j])
push!(strs, String(getname(x[j])))

Check warning on line 481 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L481

Added line #L481 was not covered by tests
else
push!(strs, "u[$(x[j])]")
end
else
_tmp = Vector{eltype(sol[1])}(undef, length(plot_timeseries))
for j in 1:length(plot_timeseries)
_tmp[j] = plot_timeseries[j][n]
end
push!(tmp,_tmp)
push!(strs, String(getname(x[j])))
if !isempty(varsyms) && x[j] isa Integer
push!(strs, String(getname(varsyms[x[j]])))
elseif hasname(x[j])
push!(strs, String(getname(x[j])))

Check warning on line 494 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L486-L494

Added lines #L486 - L494 were not covered by tests
else
push!(strs, "u[$(x[j])]")

Check warning on line 496 in src/solutions/solution_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/solution_interface.jl#L496

Added line #L496 was not covered by tests
end
end
end
f = x[1]
tmp = f.(tmp...)
tmp = map(f,tmp...)
tmp = tuple((getindex.(tmp, i) for i in eachindex(tmp[1]))...)
for i in eachindex(tmp)
push!(plot_vecs[i], tmp[i])
Expand Down
8 changes: 2 additions & 6 deletions test/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,14 @@ end
push!(sol.u, ode.u0)
end

syms = SciMLBase.interpret_vars(nothing, sol, SciMLBase.getsyms(sol))
int_vars = SciMLBase.interpret_vars(nothing, sol, syms) # nothing = idxs
int_vars = SciMLBase.interpret_vars(nothing, sol) # nothing = idxs
plot_vecs, labels = SciMLBase.diffeq_to_arrays(sol,
true, # plot_analytic
true, # denseplot
10, # plotdensity
ode.tspan,
0.1, # axis_safety
nothing, # idxs
int_vars,
:identity, # tscale
nothing) # strs
:identity) # tscale
@test plot_vecs[2][:, 2] @. exp(-plot_vecs[1][:, 2])
end

Expand Down

0 comments on commit db52700

Please sign in to comment.