Skip to content

Commit

Permalink
actual fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pogudingleb committed Dec 21, 2024
1 parent e167762 commit 46e9289
Showing 1 changed file with 42 additions and 22 deletions.
64 changes: 42 additions & 22 deletions src/ODEexport.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
function __add_t(s)
if endswith(s, "(t)")
return s
end
return s * "(t)"
end

function __remove_t(s)
if endswith(s, "(t)")
return s[1:(end - 3)]
end
return s
end

"""
print_for_maple(ode, package)
Expand All @@ -7,16 +21,17 @@ Prints the ODE in the format accepted by maple packages
- DifferentialThomas if package=:DifferentialThomas
"""
function print_for_maple(ode::ODE, package = :SIAN)
varstr =
Dict(x => var_to_str(x) * "(t)" for x in vcat(ode.x_vars, ode.u_vars, ode.y_vars))
varstr = Dict(
x => (__add_t var_to_str)(x) for x in vcat(ode.x_vars, ode.u_vars, ode.y_vars)
)
merge!(varstr, Dict(p => var_to_str(p) for p in ode.parameters))
R_print, vars_print = Nemo.polynomial_ring(
base_ring(ode.poly_ring),
[varstr[v] for v in gens(ode.poly_ring)],
)
x_names = join(map(var_to_str, ode.x_vars), ", ")
y_names = join(map(var_to_str, ode.y_vars), ", ")
u_names = join(map(var_to_str, ode.u_vars), ", ")
x_names = join(map(__remove_t var_to_str, ode.x_vars), ", ")
y_names = join(map(__remove_t var_to_str, ode.y_vars), ", ")
u_names = join(map(__remove_t var_to_str, ode.u_vars), ", ")
u_string = length(u_names) > 0 ? ", [$u_names]" : ""
ranking = "[$x_names], [$y_names] $u_string"

Expand All @@ -39,7 +54,7 @@ function print_for_maple(ode::ODE, package = :SIAN)
end

# Form the equations
function _rhs_to_str(rhs)
function _rhs_to_str_t(rhs)
num, den = unpack_fraction(rhs)
res = string(evaluate(num, vars_print))
if den != 1
Expand All @@ -50,10 +65,10 @@ function print_for_maple(ode::ODE, package = :SIAN)

eqs = []
for (x, f) in ode.x_equations
push!(eqs, ("diff(" * var_to_str(x) * "(t), t)", _rhs_to_str(f)))
push!(eqs, ("diff(" * (__add_t var_to_str)(x) * ", t)", _rhs_to_str_t(f)))
end
for (y, g) in ode.y_equations
push!(eqs, (var_to_str(y) * "(t)", _rhs_to_str(g)))
push!(eqs, ((__add_t var_to_str)(y), _rhs_to_str_t(g)))
end
if package == :SIAN
result *= join(map(a -> a[1] * " = " * a[2], eqs), ",\n") * "\n];\n"
Expand All @@ -78,13 +93,13 @@ end

#------------------------------------------------------------------------------

function _rhs_to_str(lhs)
function _rhs_to_str_not(lhs)
num, den = unpack_fraction(lhs)
rslt = string(num)
if den != 1
rslt = "($rslt) / ($den)"
end
return rslt
return replace(rslt, "(t)" => "")
end

"""
Expand All @@ -99,7 +114,7 @@ function print_for_DAISY(ode::ODE)
result =
result *
"B_:={" *
join(map(var_to_str, vcat(ode.u_vars, ode.y_vars, ode.x_vars)), ", ") *
join(map(__remove_t var_to_str, vcat(ode.u_vars, ode.y_vars, ode.x_vars)), ", ") *
"}\$\n"

result = result * "FOR EACH EL_ IN B_ DO DEPEND EL_,T\$\n\n"
Expand All @@ -114,10 +129,10 @@ function print_for_DAISY(ode::ODE)
eqs = []

for (x, f) in ode.x_equations
push!(eqs, "df($(var_to_str(x)), t) = " * _rhs_to_str(f))
push!(eqs, "df($((__remove_t var_to_str)(x)), t) = " * _rhs_to_str_not(f))
end
for (y, g) in ode.y_equations
push!(eqs, "$(var_to_str(y)) = " * _rhs_to_str(g))
push!(eqs, "$((__remove_t var_to_str)(y)) = " * _rhs_to_str_not(g))
end

result = result * "C_:={" * join(eqs, ",\n") * "}\$\n"
Expand All @@ -138,36 +153,41 @@ Prints the ODE in the format accepted by GenSSI 2.0 (https://github.com/genssi-d
function print_for_GenSSI(ode::ODE)
result = "function model = SMTH()\n"

result *= "syms " * join(var_to_str.(ode.x_vars), " ") * "\n"
result *= "syms " * join((__remove_t var_to_str).(ode.x_vars), " ") * "\n"
result *= "syms " * join(var_to_str.(ode.parameters), " ") * "\n"
result *= "syms " * join(map(v -> var_to_str(v) * "0", ode.x_vars), " ") * "\n"
result *=
"syms " * join(map(v -> (__remove_t var_to_str)(v) * "0", ode.x_vars), " ") * "\n"
if length(ode.u_vars) > 0
result *= "syms " * join(var_to_str.(ode.u_vars), " ") * "\n"
result *= "syms " * join((__remove_t var_to_str).(ode.u_vars), " ") * "\n"
end

result *= "model.sym.p = [" * join(var_to_str.(ode.parameters), "; ") * "; "
result *= join(map(v -> var_to_str(v) * "0", ode.x_vars), "; ") * "];\n"
result *= join(map(v -> (__remove_t var_to_str)(v) * "0", ode.x_vars), "; ") * "];\n"

result *= "model.sym.x = [" * join(var_to_str.(ode.x_vars), "; ") * "];\n"
result *= "model.sym.g = [" * join(var_to_str.(ode.u_vars), "; ") * "];\n"
result *=
"model.sym.x = [" * join((__remove_t var_to_str).(ode.x_vars), "; ") * "];\n"
result *=
"model.sym.g = [" * join((__remove_t var_to_str).(ode.u_vars), "; ") * "];\n"

result *=
"model.sym.x0 = [" * join(map(v -> var_to_str(v) * "0", ode.x_vars), "; ") * "];\n"
"model.sym.x0 = [" *
join(map(v -> (__remove_t var_to_str)(v) * "0", ode.x_vars), "; ") *
"];\n"

result *= "model.sym.xdot = ["

eqs = []
for x in ode.x_vars
f = ode.x_equations[x]
push!(eqs, _rhs_to_str(f))
push!(eqs, _rhs_to_str_not(f))
end
result *= join(eqs, "\n") * "];\n"

result *= "model.sym.y = ["
eqs = []
for y in ode.y_vars
g = ode.y_equations[y]
push!(eqs, _rhs_to_str(g))
push!(eqs, _rhs_to_str_not(g))
end
result *= join(eqs, "\n") * "];\n"

Expand Down

0 comments on commit 46e9289

Please sign in to comment.