Skip to content

Commit

Permalink
fixup! feat: support callable parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 16, 2024
1 parent 0136759 commit 6d2aed8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables,
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
initial_state, transition, activeState, entry, hasnode,
ticksInState, timeInState, fixpoint_sub, fast_substitute,
CallWithMetadata
CallWithMetadata, CallWithParent
const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR)
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
jacobian_sparsity, isaffine, islinear, _iszero, _isone,
Expand Down
10 changes: 10 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ function isparameter(x)
end
end

function iscalledparameter(x)
x = unwrap(x)
return isparameter(getmetadata(x, CallWithParent, nothing))
end

function getcalledparameter(x)
x = unwrap(x)
return getmetadata(x, CallWithParent)
end

"""
toparam(s)
Expand Down
12 changes: 10 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,11 @@ end
vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op)
vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op)
function vars(exprs; op = Differential)
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
if hasmethod(iterate, Tuple{typeof(exprs)})
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
else
vars!(Set(), unwrap(exprs); op)
end
end
vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
function vars!(vars, eq::Equation; op = Differential)
Expand Down Expand Up @@ -479,7 +483,11 @@ end

function collect_var!(unknowns, parameters, var, iv)
isequal(var, iv) && return nothing
if isparameter(var) || (iscall(var) && isparameter(operation(var)))
if iscalledparameter(var)
callable = getcalledparameter(var)
push!(parameters, callable)
collect_vars!(unknowns, parameters, arguments(var), iv)
elseif isparameter(var) || (iscall(var) && isparameter(operation(var)))
push!(parameters, var)
elseif !isconstant(var)
push!(unknowns, var)
Expand Down

0 comments on commit 6d2aed8

Please sign in to comment.