Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Make package concretely typed #324

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions ext/ModelingToolkitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using HarmonicBalance:
is_rearranged,
rearrange_standard,
get_variables,
ParameterList,
DifferentialEquation,
get_independent_variables
using HarmonicBalance.KrylovBogoliubov:
Expand Down Expand Up @@ -90,7 +89,7 @@ function ModelingToolkit.ODEProblem(
eom::Union{HarmonicEquation,DifferentialEquation},
u0,
tspan::Tuple,
p::ParameterList;
p::AbstractDict;
in_place=true,
kwargs...,
)
Expand All @@ -105,14 +104,14 @@ function ModelingToolkit.ODEProblem(
end

function ModelingToolkit.NonlinearProblem(
eom::HarmonicEquation, u0, p::ParameterList; in_place=true, kwargs...
eom::HarmonicEquation, u0, p::AbstractDict; in_place=true, kwargs...
)
ss_prob = SteadyStateProblem(eom, u0, p::ParameterList; in_place=in_place, kwargs...)
ss_prob = SteadyStateProblem(eom, u0, p::AbstractDict; in_place=in_place, kwargs...)
return NonlinearProblem(ss_prob)
end

function ModelingToolkit.SteadyStateProblem(
eom::HarmonicEquation, u0, p::ParameterList; in_place=true, kwargs...
eom::HarmonicEquation, u0, p::AbstractDict; in_place=true, kwargs...
)
sys = ODESystem(eom)
param = varmap_to_vars(p, parameters(sys))
Expand Down
18 changes: 9 additions & 9 deletions src/HarmonicBalance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ using .FFTWExt

using PrecompileTools: @setup_workload, @compile_workload

@setup_workload begin
# Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# precompile file and potentially make loading faster.
@compile_workload begin
# all calls in this block will be precompiled, regardless of whether
# they belong to your package or not (on Julia 1.8 and higher)
include("precompilation.jl")
end
end
# @setup_workload begin
# # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the
# # precompile file and potentially make loading faster.
# @compile_workload begin
# # all calls in this block will be precompiled, regardless of whether
# # they belong to your package or not (on Julia 1.8 and higher)
# include("precompilation.jl")
# end
# end

end # module
2 changes: 1 addition & 1 deletion src/Problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ end
""" Compile the Jacobian from `prob`, inserting `fixed_parameters`.
Returns a function that takes a dictionary of variables and `swept_parameters` to give the Jacobian."""
function _compile_Jacobian(
prob::Problem, swept_parameters::ParameterRange, fixed_parameters::ParameterList
prob::Problem, swept_parameters::OrderedDict, fixed_parameters::OrderedDict
)
if prob.jacobian isa Matrix
compiled_J = compile_matrix(
Expand Down
45 changes: 35 additions & 10 deletions src/Result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,55 @@ Stores the steady states of a HarmonicEquation.
$(TYPEDFIELDS)

"""
mutable struct Result
struct Result{SolutionType<:Number,ParameterType<:Number,Dimension}
"The variable values of steady-state solutions."
solutions::Array{Vector{SteadyState}}
solutions::Array{Vector{Vector{SolutionType}},Dimension}
"Values of all parameters for all solutions."
swept_parameters::ParameterRange
swept_parameters::OrderedDict{Num,Vector{ParameterType}}
"The parameters fixed throughout the solutions."
fixed_parameters::ParameterList
fixed_parameters::OrderedDict{Num,ParameterType}
"The `Problem` used to generate this."
problem::Problem
"Maps strings such as \"stable\", \"physical\" etc to arrays of values, classifying the solutions (see method `classify_solutions!`)."
classes::Dict{String,Array}
classes::Dict{String,Array{BitVector,Dimension}}
"Create binary classification of the solutions, such that each solution point receives an identifier
based on its permutation of stable branches (allows to distinguish between different phases,
which may have the same number of stable solutions). It works by converting each bitstring
`[is_stable(solution_1), is_stable(solution_2), ...,]` into unique labels."
binary_labels::Array{Int64,Dimension}
"The Jacobian with `fixed_parameters` already substituted. Accepts a dictionary specifying the solution.
If problem.jacobian is a symbolic matrix, this holds a compiled function.
If problem.jacobian was `false`, this holds a function that rearranges the equations to find J
only after numerical values are inserted (preferable in cases where the symbolic J would be very large)."
jacobian::Function
"Seed used for the solver"
seed::UInt32
end

function Result(
solutions,
swept_parameters,
fixed_parameters,
problem,
classes,
binary_labels,
jacobian,
seed,
)
soltype = eltype(eltype(eltype(solutions)))
partype = eltype(eltype(swept_parameters).parameters[2])
dim = ndims(solutions)

function Result(sol, swept, fixed, problem, classes, J, seed)
return new(sol, swept, fixed, problem, classes, J, seed)
end
Result(sol, swept, fixed, problem, classes) = new(sol, swept, fixed, problem, classes)
Result(sol, swept, fixed, problem) = new(sol, swept, fixed, problem, Dict([]))
return Result{soltype,partype,dim}(
solutions,
swept_parameters,
fixed_parameters,
problem,
classes,
binary_labels,
jacobian,
seed,
)
end

Symbolics.get_variables(res::Result)::Vector{Num} = get_variables(res.problem)
Expand Down
2 changes: 1 addition & 1 deletion src/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function classify_binaries!(res::Result)
for (idx, el) in enumerate(unique(bin_label))
bin_label[findall(x -> x == el, bin_label)] .= idx
end
return res.classes["binary_labels"] = bin_label
return res.binary_labels .= bin_label
end

function clean_bitstrings(res::Result)
Expand Down
28 changes: 14 additions & 14 deletions src/plotting_Plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ the amplitude of the first quadratures multiplied by 2.
Default behaviour is to plot stable solutions as full lines, unstable as dashed.

If a sweep in two parameters were done, i.e., `dim(res)==2`, a one dimensional cut can be
plotted by using the keyword `cut` were it takes a `Pair{Num, Float64}` type entry. For example,
plotted by using the keyword `cut` were it takes a `Pair{Num, Float}` type entry. For example,
`plot(res, y="sqrt(u1^2+v1^2), cut=(λ => 0.2))` plots a cut at `λ = 0.2`.
###

Expand All @@ -44,18 +44,18 @@ To make the 2d plot less chaotic it is required to specify the specific `branch`
The x and y axes are taken automatically from `res`
"""
function Plots.plot(
res::Result, varargs...; cut=Pair(missing, missing), kwargs...
)::Plots.Plot
if dim(res) == 1
res::Result{S,P,D}, varargs...; cut=Pair(missing, missing), kwargs...
)::Plots.Plot where {S,P,D}
if D == 1
plot1D(res, varargs...; _set_Plots_default..., kwargs...)
elseif dim(res) == 2
elseif D == 2
if ismissing(cut.first)
plot2D(res, varargs...; _set_Plots_default..., kwargs...)
else
plot2D_cut(res, varargs...; cut=cut, _set_Plots_default..., kwargs...)
end
else
error("Data dimension ", dim(res), " not supported")
error("Data dimension ", D, " not supported")
end
end

Expand All @@ -69,9 +69,9 @@ function Plots.plot!(res::Result, varargs...; kwargs...)::Plots.Plot
end

""" Project the array `a` into the real axis, warning if its contents are complex. """
function _realify(a::Array{T} where {T<:Number}; warning="")
function _realify(a::Array{T}; warning="") where {T<:Number}
warned = false
a_real = similar(a, Float64)
a_real = similar(a, typeof(real(a[1])))
for i in eachindex(a)
if !isnan(a[i]) && !warned && !is_real(a[i])
@warn "Values with non-negligible complex parts have
Expand Down Expand Up @@ -178,7 +178,7 @@ function plot2D(

ylab, xlab = latexify.(string.(keys(res.swept_parameters)))
return p = plot!(
map(_realify, [Float64.(Y), Float64.(X), Z])...;
map(_realify, [real.(Y), real.(X), Z])...;
st=:surface,
color=:blue,
opacity=0.5,
Expand Down Expand Up @@ -252,7 +252,7 @@ function plot2D_cut(
for k in findall(branch -> !all(isnan.(branch)), branch_data) # skip NaN branches but keep indices
l = _is_labeled(p, k) ? nothing : k
Plots.plot!(
Float64.(X),
real.(X),
_realify(getindex.(branches, k));
color=k,
label=l,
Expand Down Expand Up @@ -308,8 +308,8 @@ function plot_phase_diagram_2D(

# cannot set heatmap ticks (Plots issue #3560)
return heatmap(
Float64.(X),
Float64.(Y),
real.(X),
real.(Y),
transpose(Z);
xlabel=xlab,
ylabel=ylab,
Expand All @@ -324,7 +324,7 @@ function plot_phase_diagram_1D(
X = first(values(res.swept_parameters))
Y = sum.(_get_mask(res, class, not_class))
return plot(
Float64.(X),
real.(X),
Y;
xlabel=latexify(string(keys(res.swept_parameters)...)),
ylabel="#",
Expand Down Expand Up @@ -425,7 +425,7 @@ function plot_spaghetti(
Plots.plot!(
_realify(getindex.(X, k)),
_realify(getindex.(Y, k)),
Float64.(Z);
real.(Z);
_set_Plots_default...,
color=k,
label=l,
Expand Down
Loading
Loading