Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up for classification by refactoring solution interface #133

Merged
merged 13 commits into from
Jan 2, 2024
78 changes: 38 additions & 40 deletions src/classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,39 +20,19 @@ classify_solutions!(res, "sqrt(u1^2 + v1^2) > 1.0" , "large_amplitude")
```

"""
function classify_solutions!(res::Result, condition::String, name::String; physical=true)
values = classify_solutions(res, condition; physical=physical)
function classify_solutions!(res::Result, func::Union{String, Function}, name::String; physical=true)
values = classify_solutions(res, func; physical=physical)
res.classes[name] = values
end


function classify_solutions(res::Result, condition::String; physical=true)
expr = Num(eval(Meta.parse(condition)))
function cond_func(s::OrderedDict, res)
physical && !is_physical(s, res) && return false
s = [key => real(s[key]) for key in keys(s)] # make values real
Bool(substitute_all(expr, s).val)
function classify_solutions(res::Result, func; physical=true)
if physical
transform_solutions(res, func)
else
f_comp(soln) = func(soln) * _is_physical(soln)
transform_solutions(res, f_comp)
end
classify_solutions(res, cond_func)
end


# Classify solutions where for `f` which is a function accepting a solution dictionary
# specifies all params and variables).
function classify_solutions!(res::Result, f::Function, name::String)
values = classify_solutions(res, f)
res.classes[name] = values
end


"Return an array of booleans classifying the solution in `res`
according to `f` (`f` takes a solution dictionary, return a boolean)"
function classify_solutions(res::Result, f::Function)
values = similar(res.solutions, BitVector)
@maybethread for (idx, soln) in collect(enumerate(res.solutions))
values[idx] = [f(get_single_solution(res, index=idx, branch=b), res) for b in 1:length(soln)]
end
values
end


Expand All @@ -73,11 +53,14 @@ $(TYPEDSIGNATURES)
Returns true if the solution `soln` of the Result `res` is physical (= real number).
`im_tol` : an absolute threshold to distinguish real/complex numbers.
"""
function is_physical(soln::StateDict, res::Result; im_tol=IM_TOL)
function is_physical(soln::StateDict, res::Result)
var_values = [getindex(soln, v) for v in res.problem.variables]
return all(abs.(imag.(var_values)).<im_tol) && any(isnan.(var_values).==false)
return _is_physical(var_values)
end

_is_physical(soln::SteadyState; im_tol=IM_TOL) = all(abs.(imag.(soln)).<im_tol) && any(isnan.(soln).==false)
_is_physical(res::Result) = classify_solutions(res, _is_physical)


"""
$(TYPEDSIGNATURES)
Expand All @@ -86,12 +69,19 @@ Stable solutions are real and have all Jacobian eigenvalues Re[λ] <= 0.
`im_tol` : an absolute threshold to distinguish real/complex numbers.
`rel_tol`: Re(λ) considered <=0 if real.(λ) < rel_tol*abs(λmax)
"""
function is_stable(soln::StateDict, res::Result; im_tol=IM_TOL, rel_tol=1E-10)
is_physical(soln, res ,im_tol=im_tol) || return false # the solution is unphysical anyway
λs = eigvals(real.(res.jacobian(soln)))
return all([real.(λs) .< rel_tol*maximum(abs.(λs))]...)
function is_stable(soln::StateDict, res::Result; kwargs...)
_is_stable(values(soln) |> collect, res.jacobian; kwargs...)
end

function _is_stable(res::Result; kwargs...)
_isit(soln) = _is_stable(soln, res.jacobian; kwargs...)
end

function _is_stable(soln, J; rel_tol=1E-10)
λs = eigvals(real.(J(soln)))
oameye marked this conversation as resolved.
Show resolved Hide resolved
scale = maximum(Iterators.map(abs, λs))
_is_physical(soln) && all(x -> real(x) < rel_tol*scale, λs)
end

"""
$(TYPEDSIGNATURES)
Expand All @@ -100,12 +90,20 @@ Hopf-unstable solutions are real and have exactly two Jacobian eigenvalues with
are complex conjugates of each other.
`im_tol` : an absolute threshold to distinguish real/complex numbers.
"""
function is_Hopf_unstable(soln::StateDict, res::Result; im_tol=IM_TOL)
is_physical(soln, res, im_tol=im_tol) || return false # the solution is unphysical anyway
J = res.jacobian(soln)
unstable = filter(x -> real(x) > 0, eigvals(J))
function is_Hopf_unstable(soln::StateDict, res::Result)
_is_Hopf_unstable(values(soln) |> collect, res.jacobian)
end

function _is_Hopf_unstable(res::Result)
_isit(soln) = _is_Hopf_unstable(soln, res.jacobian)
end

function _is_Hopf_unstable(soln, J)
_is_physical(soln) || return false # the solution is unphysical anyway
λs = eigvals(J(soln))
unstable = filter(x -> real(x) > 0, λs)
(length(unstable) == 2 && abs(conj(unstable[1]) - unstable[2]) < im_tol && return true) || return false
return all([ real.(eigvals(J)) .< 0]...)
return all(x -> real(x) < 0, λs)
end


Expand Down Expand Up @@ -143,4 +141,4 @@ function filter_result!(res::Result, class::String)
for c in filter(x -> x != "binary_labels", keys(res.classes)) # binary_labels stores one Int per parameter set, ignore here
res.classes[c] = [s[bools] for s in res.classes[c]]
end
end
end
16 changes: 7 additions & 9 deletions src/solve_homotopy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ end


function _classify_default!(result)
classify_solutions!(result, is_physical, "physical")
classify_solutions!(result, is_stable, "stable")
classify_solutions!(result, is_Hopf_unstable, "Hopf")
classify_solutions!(result, _is_physical, "physical")
classify_solutions!(result, _is_stable(result), "stable")
classify_solutions!(result, _is_Hopf_unstable(result), "Hopf")
order_branches!(result, ["physical", "stable"]) # shuffle the branches to have relevant ones first
classify_binaries!(result) # assign binaries to solutions depending on which branches are stable
end
Expand Down Expand Up @@ -152,12 +152,10 @@ Substitute the values according to `fixed_parameters` and compile into a functio
"""
function compile_matrix(matrix, variables, fixed_parameters)
J = substitute_all(matrix, fixed_parameters)
matrix = [build_function(el, variables) for el in J]
matrix = eval.(matrix)
function m(s::OrderedDict)
vals = [s[var] for var in variables]
return [ComplexF64(Base.invokelatest(el, vals)) for el in matrix]
end
matrix = build_function(J, variables)
matrix = eval(matrix[1]) # compiled allocating function, see Symbolics manual
m(vals::Vector) = matrix(vals)
m(s::OrderedDict) = m([s[var] for var in variables]) # for the UI
return m
end

Expand Down
28 changes: 19 additions & 9 deletions src/transform_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ Takes a `Result` object and a string `f` representing a Symbolics.jl expression.
Returns an array with the values of `f` evaluated for the respective solutions.
Additional substitution rules can be specified in `rules` in the format `("a" => val)` or `(a => val)`
"""
function transform_solutions(res::Result, f::String; branches = 1:branch_count(res), rules=Dict(), target_type=ComplexF64)
# a string is used as input - a macro would not "see" the user's namespace while the user's namespace does not "see" the variables
transformed = [Vector{target_type}(undef, length(branches)) for k in res.solutions] # preallocate

func = _build_substituted(res, f; rules=rules)
function transform_solutions(res::Result, func; branches = 1:branch_count(res))

# preallocate an array for the numerical values, rewrite parts of it
# when looping through the solutions
n_vars = length(get_variables(res))
n_pars = length(res.swept_parameters)
vals = Vector{ComplexF64}(undef, n_vars + n_pars)

for idx in CartesianIndices(res.solutions)
params_values = res.swept_parameters[Tuple(idx)...]
vals[end-n_pars+1:end] .= params_values # param values are common to all branches
vtype = isa(Base.invokelatest(func, zeros(ComplexF64, n_vars+n_pars)), Bool) ? BitVector : Vector{ComplexF64}
transformed = _similar(vtype, res; branches=branches)

@maybethread for idx in CartesianIndices(res.solutions)
for i in 1:length(idx) # param values are common to all branches
vals[end-n_pars+i] = res.swept_parameters[idx[i]][i]
end
for (k, branch) in enumerate(branches)
vals[1:n_vars] .= res.solutions[idx][branch]
transformed[idx][k] = Base.invokelatest(func, vals)
Expand All @@ -34,6 +34,13 @@ function transform_solutions(res::Result, f::String; branches = 1:branch_count(r
return transformed
end

function transform_solutions(res::Result, f::String; rules=Dict(), kwargs...)
# a string is used as input
# a macro would not "see" the user's namespace while the user's namespace does not "see" the variables
func = _build_substituted(f, res; rules=rules)
transform_solutions(res, func; kwargs...)
end

transform_solutions(res::Result, fs::Vector{String}; kwargs...) = [transform_solutions(res, f; kwargs...) for f in fs]

# a simplified version meant to work with arrays of solutions
Expand Down Expand Up @@ -64,7 +71,7 @@ end

""" Parse `expr` into a Symbolics.jl expression, substituting the fixed parameters of `res`
The resulting function takes in the values of the variables and swept parameters. """
function _build_substituted(res::Result, expr; rules=Dict())
function _build_substituted(expr, res::Result; rules=Dict())

# define variables in rules in this namespace
new_keys = declare_variable.(string.(keys(Dict(rules))))
Expand All @@ -75,6 +82,9 @@ function _build_substituted(res::Result, expr; rules=Dict())

end

function _similar(type, res::Result; branches=1:branch_count(res))
[type(undef, length(branches)) for k in res.solutions]
end

## move masks here

Expand Down
Loading