Skip to content

Commit

Permalink
fix: fix array variable handling in get_analysis_variable
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 6, 2025
1 parent 469eda8 commit ccfed45
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/systems/analysis_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ function get_analysis_variable(var, name, iv; perturb = true)
if perturb
name = Symbol(:d_, name)
end
if Symbolics.isarraysymbolic(var)
if symbolic_type(var) == ArraySymbolic()
T = Array{eltype(symtype(var)), ndims(var)}
pvar = unwrap(only(@variables $name(iv)::T))
pvar = setmetadata(pvar, Symbolics.ArrayShapeCtx, Symbolics.shape(var))
default = zeros(symtype(var), size(var))
default = zeros(eltype(symtype(var)), size(var))
else
T = symtype(var)
pvar = unwrap(only(@variables $name(iv)::T))
Expand Down Expand Up @@ -429,7 +429,7 @@ function apply_transformation(tf::PerturbOutput, sys::AbstractSystem)
ap_ivar = ap_var(ap.input)
new_var, new_def = get_analysis_variable(ap_ivar, nameof(ap), get_iv(sys))
for outsys in ap.outputs
push!(ap_sys_eqs, ap_var(outsys) ~ ap_ivar + new_var)
push!(ap_sys_eqs, ap_var(outsys) ~ ap_ivar + wrap(new_var))
end
# add variable
unks = copy(get_unknowns(ap_sys))
Expand All @@ -446,7 +446,7 @@ function apply_transformation(tf::PerturbOutput, sys::AbstractSystem)
out_var, out_def = get_analysis_variable(
ap_ivar, nameof(ap), get_iv(sys); perturb = false)
defs[out_var] = out_def
push!(ap_sys_eqs, out_var ~ ap_ivar + new_var)
push!(ap_sys_eqs, out_var ~ ap_ivar + wrap(new_var))
push!(unks, out_var)

return ap_sys, (new_var, out_var)
Expand Down Expand Up @@ -526,11 +526,11 @@ function apply_transformation(cst::ComplementarySensitivityTransform, sys::Abstr
# but comp sensitivity wants `output + du ~ input`. Thus, `du ~ -_du`.
eqs = copy(get_eqs(sys))
@set! sys.eqs = eqs
push!(eqs, du ~ -_du)
push!(eqs, du ~ -wrap(_du))

defs = copy(get_defaults(sys))
@set! sys.defaults = defs
defs[du] = -_du
defs[du] = -wrap(_du)
return sys, (du, u)
end

Expand Down

0 comments on commit ccfed45

Please sign in to comment.